Skip to content

Commit

Permalink
ML: Add write type and have scripts return writes instead of calling …
Browse files Browse the repository at this point in the history
…bridge (#119)

The tests work now, but we can hold off on merging this until we've
checked that these changes are compatible with the orchestrator.

---------

Co-authored-by: Rory Tyler Hayford <[email protected]>
  • Loading branch information
siddharth-krishna and ngua authored May 9, 2024
1 parent bb24f11 commit 89470e8
Show file tree
Hide file tree
Showing 19 changed files with 305 additions and 238 deletions.
2 changes: 1 addition & 1 deletion inferno-ml-server-types/inferno-ml-server-types.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ source-repository head

common common
ghc-options:
-Wall -Werror -Wincomplete-uni-patterns -Wincomplete-record-updates
-Wall -Wincomplete-uni-patterns -Wincomplete-record-updates
-Wmissing-deriving-strategies

default-language: Haskell2010
Expand Down
8 changes: 6 additions & 2 deletions inferno-ml-server-types/src/Inferno/ML/Server/Client.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NoMonomorphismRestriction #-}

module Inferno.ML.Server.Client
( statusC,
Expand All @@ -21,7 +22,10 @@ import Servant.Client.Streaming (ClientM, client)
statusC :: ClientM (Maybe ())

-- | Run an inference parameter
inferenceC :: Id (InferenceParam uid gid p s) -> Maybe Int64 -> ClientM ()
inferenceC ::
Id (InferenceParam uid gid p s) ->
Maybe Int64 ->
ClientM (WriteStream IO)

-- | Cancel the existing inference job, if it exists
cancelC :: ClientM ()
Expand All @@ -39,5 +43,5 @@ statusC
:<|> checkBridgeC =
client api

api :: Proxy (InfernoMlServerAPI uid gid p s)
api :: Proxy (InfernoMlServerAPI uid gid p s t)
api = Proxy
21 changes: 7 additions & 14 deletions inferno-ml-server-types/src/Inferno/ML/Server/Client/Bridge.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,6 @@ import Servant ((:<|>) (..))
import Servant.Client.Streaming (ClientM, client)
import Web.HttpApiData (ToHttpApiData)

-- | Write a stream of @(t, Double)@ pairs to the bridge server, where @t@ will
-- typically represent some time value (e.g. @EpochTime@)
writePairsC ::
( ToJSON t,
ToHttpApiData p,
ToHttpApiData t
) =>
p ->
PairStream t IO ->
ClientM ()

-- | Get the value at the given time via the bridge, for the given entity @p@
valueAtC ::
( ToJSON t,
Expand All @@ -43,9 +32,13 @@ latestValueAndTimeBeforeC ::
t ->
p ->
ClientM IValue
writePairsC
:<|> valueAtC
:<|> latestValueAndTimeBeforeC =

-- | Get an array of values falling between the two times
valuesBetweenC ::
(ToHttpApiData p, ToHttpApiData t) => Int64 -> p -> t -> t -> ClientM IValue
valueAtC
:<|> latestValueAndTimeBeforeC
:<|> valuesBetweenC =
client api

api :: Proxy (BridgeAPI p t)
Expand Down
56 changes: 37 additions & 19 deletions inferno-ml-server-types/src/Inferno/ML/Server/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ import Servant
QueryParam',
ReqBody,
Required,
StreamBody,
StreamPost,
(:<|>),
(:>),
)
Expand All @@ -96,7 +96,7 @@ import Web.HttpApiData
)

-- API type for `inferno-ml-server`
type InfernoMlServerAPI uid gid p s =
type InfernoMlServerAPI uid gid p s t =
-- Check if the server is up and if any job is currently running:
--
-- * `Nothing` -> The server is evaluating a script
Expand All @@ -108,7 +108,7 @@ type InfernoMlServerAPI uid gid p s =
:<|> "inference"
:> Capture "id" (Id (InferenceParam uid gid p s))
:> QueryParam "res" Int64
:> Post '[JSON] ()
:> StreamPost NewlineFraming JSON (WriteStream IO)
:<|> "inference" :> "cancel" :> Put '[JSON] ()
-- Register the bridge. This is an `inferno-ml-server` endpoint, not a
-- bridge endpoint
Expand All @@ -120,24 +120,31 @@ type InfernoMlServerAPI uid gid p s =
-- by a bridge server connected to a data source, not by `inferno-ml-server`
type BridgeAPI p t =
"bridge"
:> "write"
:> "pairs"
:> Capture "p" p
:> StreamBody NewlineFraming JSON (PairStream t IO)
:> Post '[JSON] ()
:> "value-at"
:> QueryParam' '[Required] "res" Int64
:> QueryParam' '[Required] "p" p
:> QueryParam' '[Required] "time" t
:> Get '[JSON] IValue
:<|> "bridge"
:> "value-at"
:> QueryParam' '[Required] "res" Int64
:> QueryParam' '[Required] "p" p
:> "latest-value-and-time-before"
:> QueryParam' '[Required] "time" t
:> QueryParam' '[Required] "p" p
:> Get '[JSON] IValue
:<|> "bridge"
:> "latest-value-and-time-before"
:> QueryParam' '[Required] "time" t
:> "values-between"
:> QueryParam' '[Required] "res" Int64
:> QueryParam' '[Required] "p" p
:> QueryParam' '[Required] "t1" t
:> QueryParam' '[Required] "t2" t
:> Get '[JSON] IValue

type PairStream t m = ConduitT () (t, Double) m ()
-- | Stream of writes that an ML parameter script results in. Each element
-- in the stream is a chunk (sub-list) of the original values that the
-- inference script evaluates to. For example, given the following output:
-- @[ (1, [ (100, 5.0) .. (10000, 5000.0) ]) ]@; the stream items will be:
-- @(1, [ (100, 5.0) .. (500, 2500.0) ]), (1, [ (501, 2501.0) .. (10000, 5000.0) ])@.
-- This means the same output may appear more than once in the stream
type WriteStream m = ConduitT () (Int, [(EpochTime, IValue)]) m ()

-- | Information for contacting a bridge server that implements the 'BridgeAPI'
data BridgeInfo = BridgeInfo
Expand Down Expand Up @@ -697,6 +704,7 @@ data IValue
| ITuple (IValue, IValue)
| ITime EpochTime
| IEmpty
| IArray (Vector IValue)
deriving stock (Show, Eq, Generic)
deriving anyclass (NFData)

Expand All @@ -709,10 +717,19 @@ instance FromJSON IValue where
Object o -> ITime <$> o .: "time"
Array a
| [x, y] <- Vector.toList a ->
fmap ITuple $
(,) <$> parseJSON x <*> parseJSON y
| Vector.null a -> pure IEmpty
_ -> fail "Expected one of: string, double, time, tuple, unit (empty array)"
(,) <$> parseJSON x <*> parseJSON y <&> \case
-- We don't want to confuse a two-element array of tuples with
-- a tuple itself. For example, `"[[10.0, {\"time\": 10}], [10.0, {\"time\": 10}]]"`
-- should parse as a two-element array of `(double, time)` tuples,
-- not as a `((double, time), (double, time))`. I can't think of
-- any reason to support the latter. An alternative would be to
-- change tuple encoding to an object, but then we would be transmitting
-- a much larger amount of data on most requests
(f@(ITuple _), s@(ITuple _)) -> IArray $ Vector.fromList [f, s]
t -> ITuple t
| otherwise -> IArray <$> traverse (parseJSON @IValue) a
Null -> pure IEmpty
_ -> fail "Expected one of: string, double, time, tuple, unit (empty array), array"

instance ToJSON IValue where
toJSON = \case
Expand All @@ -721,7 +738,8 @@ instance ToJSON IValue where
ITuple t -> toJSON t
-- See `FromJSON` instance above
ITime t -> object ["time" .= t]
IEmpty -> toJSON ()
IEmpty -> toJSON Null
IArray is -> toJSON is

-- | Used to represent inputs to the script. 'Many' allows for an array input
data SingleOrMany a
Expand Down
60 changes: 54 additions & 6 deletions inferno-ml-server/exe/Client.hs
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}

-- NOTE
-- This executable is only intended for testing the inference endpoint with the
-- `nixosTest` (see `../../../tests/server.nix`)

module Client (main) where

import Control.Monad (void)
import Conduit
import Control.Monad (unless, void)
import Data.Coerce (coerce)
import qualified Data.Conduit.List as Conduit.List
import Data.Function (on)
import Data.Int (Int64)
import qualified Data.Map as Map
import Inferno.ML.Server.Client (inferenceC, registerBridgeC)
import Inferno.ML.Server.Types (BridgeInfo (BridgeInfo), Id (Id), toIPv4)
import Inferno.ML.Server.Types
( BridgeInfo (BridgeInfo),
IValue (IDouble),
Id (Id),
WriteStream,
toIPv4,
)
import Network.HTTP.Client (defaultManagerSettings, newManager)
import Servant.Client.Streaming
( mkClientEnv,
Expand All @@ -15,6 +30,7 @@ import Servant.Client.Streaming
withClientM,
)
import System.Exit (die)
import System.Posix.Types (EpochTime)
import Text.Read (readMaybe)
import UnliftIO (throwString)
import UnliftIO.Environment (getArgs)
Expand All @@ -35,8 +51,40 @@ main =
. registerBridgeC
. flip BridgeInfo 9999
$ toIPv4 (127, 0, 0, 1)
-- Run the given inference param. The test scripts should use `writePairs`.
-- The dummy bridge implementation will write this to a file for later
-- inspection
withClientM (inferenceC ipid Nothing) env . either throwIO . const $ pure ()
withClientM (inferenceC ipid Nothing) env . either throwIO $
verifyWrites (coerce ipid)
_ -> die "Usage: test-client <inference-parameter-id>"

-- Check that the returned write stream matches the expected value
verifyWrites ::
Int64 ->
WriteStream IO ->
IO ()
verifyWrites ipid c = do
expected <- getExpected
result <- rebuildWrites
unless (result == expected) . throwString . unwords $
[ "Expected: ",
show expected,
"but got:",
show result,
"for param",
show ipid
]
where
rebuildWrites :: IO [(Int, [(EpochTime, IValue)])]
rebuildWrites =
runConduit $
c
.| Conduit.List.groupBy ((==) `on` fst)
.| Conduit.List.concat
.| sinkList

getExpected :: IO [(Int, [(EpochTime, IValue)])]
getExpected =
maybe (throwString "Missing PID") pure . Map.lookup ipid $
Map.fromList
[ (1, [(1, [(151, IDouble 2.5), (251, IDouble 3.5)])]),
(2, [(2, [(300, IDouble 25.0)])]),
(3, [(3, [(100, IDouble 7.0)])])
]
19 changes: 5 additions & 14 deletions inferno-ml-server/exe/Dummy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@

module Dummy where

import Conduit (runConduit, sinkList, (.|))
import Control.Monad.Except (ExceptT (ExceptT))
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Reader (ReaderT (runReaderT))
import Data.Aeson (encodeFile)
import Data.Int (Int64)
import Data.Map (Map)
import qualified Data.Map as Map
Expand All @@ -21,7 +18,6 @@ import Inferno.ML.Server.Module.Types (PID (PID))
import "inferno-ml-server-types" Inferno.ML.Server.Types
( BridgeAPI,
IValue (IDouble, IEmpty),
PairStream,
)
import Lens.Micro.Platform
import Network.HTTP.Types (Status)
Expand All @@ -35,7 +31,6 @@ import Network.Wai.Handler.Warp
)
import Network.Wai.Logger (withStdoutLogger)
import Servant
import System.FilePath ((<.>), (</>))
import UnliftIO.Exception (throwIO, try)

main :: IO ()
Expand Down Expand Up @@ -83,14 +78,7 @@ newtype DummyEnv = DummyEnv
deriving stock (Generic)

server :: ServerT (BridgeAPI PID Int) DummyM
server = writePairs :<|> valueAt :<|> latestValueAndTimeBefore

-- Write the pairs to a JSON file for later inspection
writePairs :: PID -> PairStream Int IO -> DummyM ()
writePairs (PID p) c = liftIO $ encodeFile path =<< runConduit (c .| sinkList)
where
path :: FilePath
path = "./" </> show p <.> "json"
server = valueAt :<|> latestValueAndTimeBefore :<|> valuesBetween

-- Dummy implementation of `valueAt`, ignoring resolution for now
valueAt :: Int64 -> PID -> Int -> DummyM IValue
Expand All @@ -99,4 +87,7 @@ valueAt _ p t =
<&> maybe IEmpty IDouble . preview (at p . _Just . at t . _Just)

latestValueAndTimeBefore :: Int -> PID -> DummyM IValue
latestValueAndTimeBefore = const . const . throwIO $ userError "Unsupported"
latestValueAndTimeBefore _ _ = throwIO $ userError "Unsupported"

valuesBetween :: Int64 -> PID -> Int -> Int -> ReaderT DummyEnv IO IValue
valuesBetween _ _ _ _ = throwIO $ userError "Unsupported"
16 changes: 9 additions & 7 deletions inferno-ml-server/inferno-ml-server.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,8 @@ name: inferno-ml-server
version: 2023.4.3
synopsis: Server for Inferno ML
description: Server for Inferno ML
homepage:
https://github.com/plow-technologies/inferno.git#readme

bug-reports:
https://github.com/plow-technologies/inferno.git/issues

homepage: https://github.com/plow-technologies/inferno.git#readme
bug-reports: https://github.com/plow-technologies/inferno.git/issues
copyright: Plow-Technologies LLC
license: MIT
author: Rory Tyler hayford
Expand All @@ -22,7 +18,7 @@ source-repository head

common common
ghc-options:
-Wall -Werror -Wincomplete-uni-patterns -Wincomplete-record-updates
-Wall -Wincomplete-uni-patterns -Wincomplete-record-updates
-Wmissing-deriving-strategies -Wno-unticked-promoted-constructors

default-language: Haskell2010
Expand Down Expand Up @@ -126,7 +122,13 @@ executable test-client
hs-source-dirs: exe
ghc-options: -threaded -rtsopts -main-is Client
build-depends:
, aeson
, base
, bytestring
, conduit
, containers
, extra
, filepath
, http-client
, inferno-ml-server-types
, iproute
Expand Down
Loading

0 comments on commit 89470e8

Please sign in to comment.