diff --git a/inferno-ml-server-types/inferno-ml-server-types.cabal b/inferno-ml-server-types/inferno-ml-server-types.cabal index db094348..b9a08557 100644 --- a/inferno-ml-server-types/inferno-ml-server-types.cabal +++ b/inferno-ml-server-types/inferno-ml-server-types.cabal @@ -17,7 +17,7 @@ source-repository head common common ghc-options: - -Wall -Werror -Wincomplete-uni-patterns -Wincomplete-record-updates + -Wall -Wincomplete-uni-patterns -Wincomplete-record-updates -Wmissing-deriving-strategies default-language: Haskell2010 diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs index 39ed5abc..532fcda1 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE NoMonomorphismRestriction #-} module Inferno.ML.Server.Client ( statusC, @@ -21,7 +22,10 @@ import Servant.Client.Streaming (ClientM, client) statusC :: ClientM (Maybe ()) -- | Run an inference parameter -inferenceC :: Id (InferenceParam uid gid p s) -> Maybe Int64 -> ClientM () +inferenceC :: + Id (InferenceParam uid gid p s) -> + Maybe Int64 -> + ClientM (WriteStream IO) -- | Cancel the existing inference job, if it exists cancelC :: ClientM () @@ -39,5 +43,5 @@ statusC :<|> checkBridgeC = client api -api :: Proxy (InfernoMlServerAPI uid gid p s) +api :: Proxy (InfernoMlServerAPI uid gid p s t) api = Proxy diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Client/Bridge.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Client/Bridge.hs index e7490594..e0f565a2 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Client/Bridge.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Client/Bridge.hs @@ -12,17 +12,6 @@ import Servant ((:<|>) (..)) import Servant.Client.Streaming (ClientM, client) import Web.HttpApiData (ToHttpApiData) --- | Write a stream of @(t, Double)@ pairs to the bridge server, where @t@ will --- typically represent some time value (e.g. @EpochTime@) -writePairsC :: - ( ToJSON t, - ToHttpApiData p, - ToHttpApiData t - ) => - p -> - PairStream t IO -> - ClientM () - -- | Get the value at the given time via the bridge, for the given entity @p@ valueAtC :: ( ToJSON t, @@ -43,9 +32,13 @@ latestValueAndTimeBeforeC :: t -> p -> ClientM IValue -writePairsC - :<|> valueAtC - :<|> latestValueAndTimeBeforeC = + +-- | Get an array of values falling between the two times +valuesBetweenC :: + (ToHttpApiData p, ToHttpApiData t) => Int64 -> p -> t -> t -> ClientM IValue +valueAtC + :<|> latestValueAndTimeBeforeC + :<|> valuesBetweenC = client api api :: Proxy (BridgeAPI p t) diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs index 258da6f7..5aa1cff3 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs @@ -81,7 +81,7 @@ import Servant QueryParam', ReqBody, Required, - StreamBody, + StreamPost, (:<|>), (:>), ) @@ -96,7 +96,7 @@ import Web.HttpApiData ) -- API type for `inferno-ml-server` -type InfernoMlServerAPI uid gid p s = +type InfernoMlServerAPI uid gid p s t = -- Check if the server is up and if any job is currently running: -- -- * `Nothing` -> The server is evaluating a script @@ -108,7 +108,7 @@ type InfernoMlServerAPI uid gid p s = :<|> "inference" :> Capture "id" (Id (InferenceParam uid gid p s)) :> QueryParam "res" Int64 - :> Post '[JSON] () + :> StreamPost NewlineFraming JSON (WriteStream IO) :<|> "inference" :> "cancel" :> Put '[JSON] () -- Register the bridge. This is an `inferno-ml-server` endpoint, not a -- bridge endpoint @@ -120,24 +120,31 @@ type InfernoMlServerAPI uid gid p s = -- by a bridge server connected to a data source, not by `inferno-ml-server` type BridgeAPI p t = "bridge" - :> "write" - :> "pairs" - :> Capture "p" p - :> StreamBody NewlineFraming JSON (PairStream t IO) - :> Post '[JSON] () + :> "value-at" + :> QueryParam' '[Required] "res" Int64 + :> QueryParam' '[Required] "p" p + :> QueryParam' '[Required] "time" t + :> Get '[JSON] IValue :<|> "bridge" - :> "value-at" - :> QueryParam' '[Required] "res" Int64 - :> QueryParam' '[Required] "p" p + :> "latest-value-and-time-before" :> QueryParam' '[Required] "time" t + :> QueryParam' '[Required] "p" p :> Get '[JSON] IValue :<|> "bridge" - :> "latest-value-and-time-before" - :> QueryParam' '[Required] "time" t + :> "values-between" + :> QueryParam' '[Required] "res" Int64 :> QueryParam' '[Required] "p" p + :> QueryParam' '[Required] "t1" t + :> QueryParam' '[Required] "t2" t :> Get '[JSON] IValue -type PairStream t m = ConduitT () (t, Double) m () +-- | Stream of writes that an ML parameter script results in. Each element +-- in the stream is a chunk (sub-list) of the original values that the +-- inference script evaluates to. For example, given the following output: +-- @[ (1, [ (100, 5.0) .. (10000, 5000.0) ]) ]@; the stream items will be: +-- @(1, [ (100, 5.0) .. (500, 2500.0) ]), (1, [ (501, 2501.0) .. (10000, 5000.0) ])@. +-- This means the same output may appear more than once in the stream +type WriteStream m = ConduitT () (Int, [(EpochTime, IValue)]) m () -- | Information for contacting a bridge server that implements the 'BridgeAPI' data BridgeInfo = BridgeInfo @@ -697,6 +704,7 @@ data IValue | ITuple (IValue, IValue) | ITime EpochTime | IEmpty + | IArray (Vector IValue) deriving stock (Show, Eq, Generic) deriving anyclass (NFData) @@ -709,10 +717,19 @@ instance FromJSON IValue where Object o -> ITime <$> o .: "time" Array a | [x, y] <- Vector.toList a -> - fmap ITuple $ - (,) <$> parseJSON x <*> parseJSON y - | Vector.null a -> pure IEmpty - _ -> fail "Expected one of: string, double, time, tuple, unit (empty array)" + (,) <$> parseJSON x <*> parseJSON y <&> \case + -- We don't want to confuse a two-element array of tuples with + -- a tuple itself. For example, `"[[10.0, {\"time\": 10}], [10.0, {\"time\": 10}]]"` + -- should parse as a two-element array of `(double, time)` tuples, + -- not as a `((double, time), (double, time))`. I can't think of + -- any reason to support the latter. An alternative would be to + -- change tuple encoding to an object, but then we would be transmitting + -- a much larger amount of data on most requests + (f@(ITuple _), s@(ITuple _)) -> IArray $ Vector.fromList [f, s] + t -> ITuple t + | otherwise -> IArray <$> traverse (parseJSON @IValue) a + Null -> pure IEmpty + _ -> fail "Expected one of: string, double, time, tuple, unit (empty array), array" instance ToJSON IValue where toJSON = \case @@ -721,7 +738,8 @@ instance ToJSON IValue where ITuple t -> toJSON t -- See `FromJSON` instance above ITime t -> object ["time" .= t] - IEmpty -> toJSON () + IEmpty -> toJSON Null + IArray is -> toJSON is -- | Used to represent inputs to the script. 'Many' allows for an array input data SingleOrMany a diff --git a/inferno-ml-server/exe/Client.hs b/inferno-ml-server/exe/Client.hs index 70f742d6..10b3c35e 100644 --- a/inferno-ml-server/exe/Client.hs +++ b/inferno-ml-server/exe/Client.hs @@ -1,12 +1,27 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeApplications #-} + -- NOTE -- This executable is only intended for testing the inference endpoint with the -- `nixosTest` (see `../../../tests/server.nix`) module Client (main) where -import Control.Monad (void) +import Conduit +import Control.Monad (unless, void) +import Data.Coerce (coerce) +import qualified Data.Conduit.List as Conduit.List +import Data.Function (on) +import Data.Int (Int64) +import qualified Data.Map as Map import Inferno.ML.Server.Client (inferenceC, registerBridgeC) -import Inferno.ML.Server.Types (BridgeInfo (BridgeInfo), Id (Id), toIPv4) +import Inferno.ML.Server.Types + ( BridgeInfo (BridgeInfo), + IValue (IDouble), + Id (Id), + WriteStream, + toIPv4, + ) import Network.HTTP.Client (defaultManagerSettings, newManager) import Servant.Client.Streaming ( mkClientEnv, @@ -15,6 +30,7 @@ import Servant.Client.Streaming withClientM, ) import System.Exit (die) +import System.Posix.Types (EpochTime) import Text.Read (readMaybe) import UnliftIO (throwString) import UnliftIO.Environment (getArgs) @@ -35,8 +51,40 @@ main = . registerBridgeC . flip BridgeInfo 9999 $ toIPv4 (127, 0, 0, 1) - -- Run the given inference param. The test scripts should use `writePairs`. - -- The dummy bridge implementation will write this to a file for later - -- inspection - withClientM (inferenceC ipid Nothing) env . either throwIO . const $ pure () + withClientM (inferenceC ipid Nothing) env . either throwIO $ + verifyWrites (coerce ipid) _ -> die "Usage: test-client " + +-- Check that the returned write stream matches the expected value +verifyWrites :: + Int64 -> + WriteStream IO -> + IO () +verifyWrites ipid c = do + expected <- getExpected + result <- rebuildWrites + unless (result == expected) . throwString . unwords $ + [ "Expected: ", + show expected, + "but got:", + show result, + "for param", + show ipid + ] + where + rebuildWrites :: IO [(Int, [(EpochTime, IValue)])] + rebuildWrites = + runConduit $ + c + .| Conduit.List.groupBy ((==) `on` fst) + .| Conduit.List.concat + .| sinkList + + getExpected :: IO [(Int, [(EpochTime, IValue)])] + getExpected = + maybe (throwString "Missing PID") pure . Map.lookup ipid $ + Map.fromList + [ (1, [(1, [(151, IDouble 2.5), (251, IDouble 3.5)])]), + (2, [(2, [(300, IDouble 25.0)])]), + (3, [(3, [(100, IDouble 7.0)])]) + ] diff --git a/inferno-ml-server/exe/Dummy.hs b/inferno-ml-server/exe/Dummy.hs index 41ebbede..e5ecd68a 100644 --- a/inferno-ml-server/exe/Dummy.hs +++ b/inferno-ml-server/exe/Dummy.hs @@ -7,11 +7,8 @@ module Dummy where -import Conduit (runConduit, sinkList, (.|)) import Control.Monad.Except (ExceptT (ExceptT)) -import Control.Monad.IO.Class (liftIO) import Control.Monad.Reader (ReaderT (runReaderT)) -import Data.Aeson (encodeFile) import Data.Int (Int64) import Data.Map (Map) import qualified Data.Map as Map @@ -21,7 +18,6 @@ import Inferno.ML.Server.Module.Types (PID (PID)) import "inferno-ml-server-types" Inferno.ML.Server.Types ( BridgeAPI, IValue (IDouble, IEmpty), - PairStream, ) import Lens.Micro.Platform import Network.HTTP.Types (Status) @@ -35,7 +31,6 @@ import Network.Wai.Handler.Warp ) import Network.Wai.Logger (withStdoutLogger) import Servant -import System.FilePath ((<.>), ()) import UnliftIO.Exception (throwIO, try) main :: IO () @@ -83,14 +78,7 @@ newtype DummyEnv = DummyEnv deriving stock (Generic) server :: ServerT (BridgeAPI PID Int) DummyM -server = writePairs :<|> valueAt :<|> latestValueAndTimeBefore - --- Write the pairs to a JSON file for later inspection -writePairs :: PID -> PairStream Int IO -> DummyM () -writePairs (PID p) c = liftIO $ encodeFile path =<< runConduit (c .| sinkList) - where - path :: FilePath - path = "./" show p <.> "json" +server = valueAt :<|> latestValueAndTimeBefore :<|> valuesBetween -- Dummy implementation of `valueAt`, ignoring resolution for now valueAt :: Int64 -> PID -> Int -> DummyM IValue @@ -99,4 +87,7 @@ valueAt _ p t = <&> maybe IEmpty IDouble . preview (at p . _Just . at t . _Just) latestValueAndTimeBefore :: Int -> PID -> DummyM IValue -latestValueAndTimeBefore = const . const . throwIO $ userError "Unsupported" +latestValueAndTimeBefore _ _ = throwIO $ userError "Unsupported" + +valuesBetween :: Int64 -> PID -> Int -> Int -> ReaderT DummyEnv IO IValue +valuesBetween _ _ _ _ = throwIO $ userError "Unsupported" diff --git a/inferno-ml-server/inferno-ml-server.cabal b/inferno-ml-server/inferno-ml-server.cabal index 8582d08f..9ec63155 100644 --- a/inferno-ml-server/inferno-ml-server.cabal +++ b/inferno-ml-server/inferno-ml-server.cabal @@ -3,12 +3,8 @@ name: inferno-ml-server version: 2023.4.3 synopsis: Server for Inferno ML description: Server for Inferno ML -homepage: - https://github.com/plow-technologies/inferno.git#readme - -bug-reports: - https://github.com/plow-technologies/inferno.git/issues - +homepage: https://github.com/plow-technologies/inferno.git#readme +bug-reports: https://github.com/plow-technologies/inferno.git/issues copyright: Plow-Technologies LLC license: MIT author: Rory Tyler hayford @@ -22,7 +18,7 @@ source-repository head common common ghc-options: - -Wall -Werror -Wincomplete-uni-patterns -Wincomplete-record-updates + -Wall -Wincomplete-uni-patterns -Wincomplete-record-updates -Wmissing-deriving-strategies -Wno-unticked-promoted-constructors default-language: Haskell2010 @@ -126,7 +122,13 @@ executable test-client hs-source-dirs: exe ghc-options: -threaded -rtsopts -main-is Client build-depends: + , aeson , base + , bytestring + , conduit + , containers + , extra + , filepath , http-client , inferno-ml-server-types , iproute diff --git a/inferno-ml-server/src/Inferno/ML/Server.hs b/inferno-ml-server/src/Inferno/ML/Server.hs index 32e17f6f..9f368cab 100644 --- a/inferno-ml-server/src/Inferno/ML/Server.hs +++ b/inferno-ml-server/src/Inferno/ML/Server.hs @@ -17,7 +17,6 @@ import qualified Data.ByteString.Lazy.Char8 as ByteString.Lazy.Char8 import Data.Proxy (Proxy (Proxy)) import Database.PostgreSQL.Simple (withConnect) import Inferno.ML.Server.Bridge -import qualified Inferno.ML.Server.Client.Bridge as Bridge import Inferno.ML.Server.Inference import Inferno.ML.Server.Log import Inferno.ML.Server.Types @@ -93,7 +92,7 @@ runInEnv cfg f = withRemoteTracer $ \tracer -> do <*> newIORef Nothing where mkBridge :: IO Bridge - mkBridge = Bridge defaultBridgeClient <$> (newIORef =<< maybeDecodeBridge) + mkBridge = fmap Bridge $ newIORef =<< maybeDecodeBridge maybeDecodeBridge :: IO (Maybe BridgeInfo) maybeDecodeBridge = @@ -101,13 +100,6 @@ runInEnv cfg f = withRemoteTracer $ \tracer -> do False -> pure Nothing True -> decodeFileStrict bridgeCache - defaultBridgeClient :: BridgeClient - defaultBridgeClient = - BridgeClient - Bridge.valueAtC - Bridge.latestValueAndTimeBeforeC - Bridge.writePairsC - infernoMlRemote :: Env -> Application infernoMlRemote env = serve api $ hoistServer api (`toHandler` env) server where @@ -128,6 +120,7 @@ infernoMlRemote env = serve api $ hoistServer api (`toHandler` env) server e@NoSuchParameter {} -> errWith err404 e e@NoSuchScript {} -> errWith err404 e e@InvalidScript {} -> errWith err400 e + e@InvalidOutput {} -> errWith err400 e e@InfernoError {} -> errWith err500 e e@BridgeNotRegistered {} -> errWith err500 e e@ScriptTimeout {} -> errWith err500 e @@ -165,5 +158,5 @@ server = =<< tryTakeMVar =<< view #job where - logAndCancel :: (Id InferenceParam, Async (Maybe ())) -> RemoteM () + logAndCancel :: (Id InferenceParam, Async (Maybe (WriteStream IO))) -> RemoteM () logAndCancel (i, j) = logTrace (CancelingInference i) *> cancel j diff --git a/inferno-ml-server/src/Inferno/ML/Server/Bridge.hs b/inferno-ml-server/src/Inferno/ML/Server/Bridge.hs index d7afe3d0..af79dc94 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Bridge.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Bridge.hs @@ -6,7 +6,6 @@ module Inferno.ML.Server.Bridge ) where -import Conduit (mapMC, yieldMany, (.|)) import Control.DeepSeq (NFData) import Control.Monad.Catch (throwM) import Control.Monad.IO.Class (liftIO) @@ -14,6 +13,7 @@ import Control.Monad.Reader (asks) import Data.Aeson (encodeFile) import Data.Int (Int64) import Inferno.Core (mkInferno) +import qualified Inferno.ML.Server.Client.Bridge as Bridge import Inferno.ML.Server.Module.Bridge (mkBridgeFuns) import Inferno.ML.Server.Module.Prelude (mkBridgePrelude) import Inferno.ML.Server.Types @@ -29,8 +29,6 @@ import Servant.Client.Streaming runClientM, ) import System.Posix.Types (EpochTime) -import Torch (Tensor, asValue) -import UnliftIO.Exception (throwIO) import UnliftIO.IORef (atomicWriteIORef, readIORef) -- | Save the provided 'BridgeInfo' and update the Inferno interpreter to use @@ -46,35 +44,16 @@ registerBridgeInfo bi = do =<< view #interpreter where funs :: BridgeFuns RemoteM - funs = mkBridgeFuns valueAt latestValueAndTimeBefore writePairs + funs = mkBridgeFuns valueAt latestValueAndTimeBefore valuesBetween valueAt :: Int64 -> PID -> EpochTime -> RemoteM IValue - valueAt res pid t = callBridge =<< getBridgeRoute #valueAt ?? res ?? pid ?? t + valueAt res pid = callBridge . Bridge.valueAtC res pid latestValueAndTimeBefore :: EpochTime -> PID -> RemoteM IValue - latestValueAndTimeBefore t pid = - callBridge =<< getBridgeRoute #latestValueAndTimeBefore ?? t ?? pid + latestValueAndTimeBefore t = callBridge . Bridge.latestValueAndTimeBeforeC t - -- FIXME `writePairs` will be removed soon - writePairs :: PID -> Tensor -> RemoteM () - writePairs pid t = callBridge =<< getBridgeRoute #writePairs ?? pid ?? yieldTensor - where - -- Convert the (assumed two-dimensional) tensor into a list of pairs - -- for streaming to the bridge endpoint - yieldTensor :: PairStream Int IO - yieldTensor = - yieldMany (Torch.asValue @[[Double]] t) - .| mapMC mkPair - where - mkPair :: [Double] -> IO (Int, Double) - mkPair = \case - -- Since the input elements need to be homogeneous, the time value - -- needs to be stored as a double, which then needs to be converted - -- to an integer - [time, val] -> pure (round time, val) - _ -> - throwIO $ - InvalidScript "Expecting two-dimensional tensor of time/value pairs" + valuesBetween :: Int64 -> PID -> EpochTime -> EpochTime -> RemoteM IValue + valuesBetween res pid t1 = callBridge . Bridge.valuesBetweenC res pid t1 -- | Get the previously saved 'BridgeInfo', if any getBridgeInfo :: RemoteM (Maybe BridgeInfo) @@ -101,6 +80,3 @@ callBridge c = (view (#host . to show) bi) (view (#port . to fromIntegral) bi) mempty - -getBridgeRoute :: Lens' BridgeClient (a -> b) -> RemoteM (a -> b) -getBridgeRoute l = view $ #bridge . #client . l diff --git a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs index a67afa7c..c27ca6d0 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs @@ -11,12 +11,14 @@ module Inferno.ML.Server.Inference ) where +import Conduit (ConduitT, awaitForever, mapC, yieldMany, (.|)) import Control.Monad (void, when, (<=<)) import Control.Monad.Catch (throwM) import Control.Monad.Extra (loopM, unlessM, whenM) import Control.Monad.IO.Class (MonadIO (liftIO)) import Control.Monad.ListM (sortByM) import Data.Bifoldable (bitraverse_) +import Data.Conduit.List (chunksOf, sourceList) import Data.Foldable (foldl') import Data.Generics.Wrapped (wrappedTo) import Data.Int (Int64) @@ -25,6 +27,7 @@ import qualified Data.Map as Map import Data.Maybe (fromMaybe) import qualified Data.Text as Text import Data.Time.Clock.POSIX (getPOSIXTime) +import Data.Traversable (for) import qualified Data.Vector as Vector import Database.PostgreSQL.Simple ( Only (Only), @@ -46,15 +49,18 @@ import Inferno.Types.Syntax Scoped (LocalScope), ) import Inferno.Types.Value - ( Value (VArray, VCustom, VEpochTime), + ( ImplEnvM, + Value (VArray, VCustom, VEpochTime), runImplEnvM, ) import Inferno.Types.VersionControl (VCObjectHash) +import Inferno.Utils.Prettyprinter (renderPretty) import Inferno.VersionControl.Types ( VCObject (VCFunction), ) import Lens.Micro.Platform import System.FilePath (dropExtensions, (<.>)) +import System.Posix.Types (EpochTime) import UnliftIO.Async (wait, withAsync) import UnliftIO.Directory ( createFileLink, @@ -80,7 +86,7 @@ import UnliftIO.Timeout (timeout) runInferenceParam :: Id InferenceParam -> Maybe Int64 -> - RemoteM () + RemoteM (WriteStream IO) -- FIXME / TODO Deal with default resolution, probably shouldn't need to be -- passed on all requests runInferenceParam ipid (fromMaybe 128 -> res) = @@ -93,7 +99,7 @@ runInferenceParam ipid (fromMaybe 128 -> res) = -- Runs the inference parameter in a separate `Async` thread. The `Async` -- is stored in the server environment so it can be canceled at any point -- before the script finishes evaluating - run :: Int -> RemoteM (Maybe ()) + run :: Int -> RemoteM (Maybe (WriteStream IO)) run tmo = withAsync (runInference tmo) $ \a -> bracket_ ((`putMVar` (ipid, a)) =<< view #job) @@ -101,7 +107,7 @@ runInferenceParam ipid (fromMaybe 128 -> res) = $ wait a -- Actually runs the inference evaluation, within the configured timeout - runInference :: Int -> RemoteM (Maybe ()) + runInference :: Int -> RemoteM (Maybe (WriteStream IO)) runInference tmo = timeout tmo $ do view #interpreter >>= readIORef >>= \case Nothing -> throwM BridgeNotRegistered @@ -124,7 +130,7 @@ runInferenceParam ipid (fromMaybe 128 -> res) = InferenceParam -> CTime -> VCMeta VCObject -> - RemoteM () + RemoteM (WriteStream IO) runEval Interpreter {evalExpr, mkEnvFromClosure} param t vcm = vcm ^. #obj & \case VCFunction {} -> do @@ -180,20 +186,32 @@ runInferenceParam ipid (fromMaybe 128 -> res) = doEval :: Expr (Maybe VCObjectHash) () -> BridgeTermEnv RemoteM -> - RemoteM () + RemoteM (WriteStream IO) doEval x env = - -- FIXME / TODO Do we want to just void any results here? - -- Users should use the bridge API to write to a PID, so - -- the script shouldn't evaluate to anything. Should we - -- check that the script evaluates to `unit`? Or we might - -- want to retain the return value of the script (provided - -- it evaluates to a limited subset of Inferno types) and - -- display it in the UI? either (throwInfernoError . Left . SomeInfernoError) - (const (pure ())) + yieldPairs =<< evalExpr env implEnv x + yieldPairs :: + Value BridgeMlValue (ImplEnvM RemoteM BridgeMlValue) -> + RemoteM (WriteStream IO) + yieldPairs = \case + VArray vs -> + fmap ((.| mkChunks) . yieldMany) . for vs $ \case + VCustom (VExtended (VWrite vw)) -> pure vw + v -> throwM . InvalidOutput $ renderPretty v + v -> throwM . InvalidOutput $ renderPretty v + where + mkChunks :: + ConduitT + (PID, [(EpochTime, IValue)]) + (Int, [(EpochTime, IValue)]) + IO + () + mkChunks = awaitForever $ \(p, ws) -> + sourceList ws .| chunksOf 500 .| mapC (wrappedTo p,) + implEnv :: Map ExtIdent (Value BridgeMlValue m) implEnv = Map.fromList diff --git a/inferno-ml-server/src/Inferno/ML/Server/Module/Bridge.hs b/inferno-ml-server/src/Inferno/ML/Server/Module/Bridge.hs index d72e5392..e0364cc0 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Module/Bridge.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Module/Bridge.hs @@ -6,19 +6,15 @@ module Inferno.ML.Server.Module.Bridge where import Control.Category ((>>>)) -import Control.Monad.Catch (MonadThrow (throwM), handle) import Data.Int (Int64) -import qualified Data.Text as Text -import Inferno.Eval.Error (EvalError (RuntimeError)) import Inferno.ML.Server.Types import Inferno.Module.Cast (ToValue (toValue)) import Inferno.Types.Value ( ImplicitCast (ImplicitCast), - Value (VDouble, VOne, VTuple), + Value (..), liftImplEnvM, ) import System.Posix.Types (EpochTime) -import Torch (Tensor) -- | Create the functions that will be used for the Inferno primitives related -- to the data source. Effects defined in @RemoteM@ are wrapped in @ImplEnvM m ...@ @@ -27,17 +23,15 @@ mkBridgeFuns :: (Int64 -> PID -> EpochTime -> RemoteM IValue) -> -- | @latestValueAndTimeBefore@ (EpochTime -> PID -> RemoteM IValue) -> - -- | @writePairs@ - -- - -- FIXME `writePairs` will be removed soon - (PID -> Tensor -> RemoteM ()) -> + -- | @valuesBetween@ + (Int64 -> PID -> EpochTime -> EpochTime -> RemoteM IValue) -> BridgeFuns RemoteM -mkBridgeFuns valueAt latestValueAndTimeBefore writePairs = +mkBridgeFuns valueAt latestValueAndTimeBefore valuesBetween = BridgeFuns valueAtFun latestValueAndTimeBeforeFun latestValueAndTimeFun - writePairsFun + valuesBetweenFun where valueAtFun :: BridgeV RemoteM valueAtFun = toValue $ ImplicitCast @"resolution" inputFunction @@ -86,21 +80,18 @@ mkBridgeFuns valueAt latestValueAndTimeBefore writePairs = t@VTuple {} -> VOne t v -> v - -- FIXME `writePairs` will be removed soon - writePairsFun :: BridgeV RemoteM - writePairsFun = toValue writePairsFunction + valuesBetweenFun :: BridgeV RemoteM + valuesBetweenFun = + toValue $ + ImplicitCast @"resolution" valuesBetweenFunction where - writePairsFunction :: PID -> Tensor -> BridgeImplM RemoteM - writePairsFunction p = - handle raiseRuntime - . liftImplEnvM - . fmap (const (VTuple mempty)) - . writePairs p - - raiseRuntime :: RemoteError -> BridgeImplM RemoteM - raiseRuntime = \case - -- If the tensor provided as an argument isn't the correct shape, this - -- will be raised, so the user should be informed clearly (it will be - -- thrown as an `InvalidScript` internally) - InvalidScript t -> throwM . RuntimeError $ Text.unpack t - e -> throwM e + valuesBetweenFunction :: + InverseResolution -> + PID -> + EpochTime -> + EpochTime -> + BridgeImplM RemoteM + valuesBetweenFunction r pid t1 = + liftImplEnvM + . fmap fromIValue + . valuesBetween (resolutionToInt r) pid t1 diff --git a/inferno-ml-server/src/Inferno/ML/Server/Module/Prelude.hs b/inferno-ml-server/src/Inferno/ML/Server/Module/Prelude.hs index e6a97b81..d6a9c013 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Module/Prelude.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Module/Prelude.hs @@ -11,6 +11,7 @@ where import Control.Monad.Catch (MonadCatch, MonadThrow (throwM)) import Control.Monad.IO.Class (MonadIO) +import Data.Foldable (foldrM) import Data.Int (Int64) import qualified Data.IntMap as IntMap import qualified Data.Map as Map @@ -19,16 +20,18 @@ import Foreign.C (CTime (CTime)) import Inferno.Eval.Error (EvalError (RuntimeError)) import Inferno.ML.Module.Prelude (mlPrelude) import Inferno.ML.Server.Module.Types +import Inferno.ML.Server.Types (IValue) import Inferno.ML.Types.Value (MlValue (VExtended), mlQuoter) import Inferno.Module.Cast import Inferno.Module.Prelude (ModuleMap) import Inferno.Types.Syntax (ExtIdent (ExtIdent)) import Inferno.Types.Value ( ImplEnvM, - Value (VCustom, VEmpty, VEpochTime, VFun), + Value (VArray, VCustom, VEmpty, VEpochTime, VFun, VTuple), ) import Inferno.Types.VersionControl (VCObjectHash) import Lens.Micro.Platform +import System.Posix.Types (EpochTime) -- | Contains primitives for use in bridge prelude, including those to read\/write -- data @@ -37,8 +40,6 @@ import Lens.Micro.Platform -- open-source users will presumably not find these useful. Unfortunately, the -- Inferno interpreter used by the server needs to be initialized with these -- primitives --- --- FIXME `writePairs` will be removed soon bridgeModules :: forall m. (MonadThrow m, MonadIO m) => @@ -49,16 +50,14 @@ bridgeModules valueAt latestValueAndTimeBefore latestValueAndTime - writePairsFun + valuesBetween ) = [mlQuoter| module DataSource - @doc Write the value of a tensor to the parameter `p`. Note that the tensor - MUST be two-dimensional, assumed to contain a series of pairs representing - times (the first element) and values (the second element). A runtime error - will be raised if this condition is not satisfied. The input tensor must - must be of type `Double`; - writePairs : forall 'a. series of 'a -> tensor -> () := ###!writePairsFun###; + @doc Create a `write` object encapsulating an array of `(time, 'a)` values to be + written to a given parameter. All ML scripts must return an array of such `write` + objects, potentially empty, and this is the only way for them to write values to parameters.; + makeWrites : forall 'a. series of 'a -> array of (time, 'a) -> write := ###!makeWriteFun###; toResolution : int -> resolution := ###toResolution###; @@ -138,6 +137,11 @@ module DataSource -> time -> option of 'a := ###!valueAtOrAdjacent###; + @doc Returns all values between two times, using the implicit resolution. + + If the resolution is set to 1, this returns all the events (actual values, not approximations) in the given time window.; + valuesBetween : forall 'a. { implicit resolution : resolution } + => series of 'a -> time -> time -> array of ('a, time) := ###!valuesBetween###; |] where valueAtOrAdjacent :: BridgeV m @@ -179,6 +183,28 @@ module DataSource e :: EvalError e = RuntimeError "valueAt: expected a function" + makeWriteFun :: BridgeV m + makeWriteFun = + VFun $ \case + VCustom (VExtended (VSeries pid)) -> + pure . VFun $ \case + VArray vs -> + VCustom . VExtended . VWrite . (pid,) <$> extractPairs vs + _ -> throwM $ RuntimeError "makeWrite: expecting an array" + _ -> throwM $ RuntimeError "makeWrite: expecting a pid" + where + extractPairs :: + [Value c n] -> + ImplEnvM m BridgeMlValue [(EpochTime, IValue)] + extractPairs = flip foldrM mempty $ \v acc -> (: acc) <$> extractPair v + + extractPair :: + Value c n -> + ImplEnvM m BridgeMlValue (EpochTime, IValue) + extractPair = \case + VTuple [VEpochTime t, x] -> (t,) <$> toIValue x + _ -> throwM $ RuntimeError "extractPair: expected a tuple (time, 'a)" + mkBridgePrelude :: forall m. ( MonadIO m, diff --git a/inferno-ml-server/src/Inferno/ML/Server/Module/Types.hs b/inferno-ml-server/src/Inferno/ML/Server/Module/Types.hs index 6561df4f..3409e987 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Module/Types.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Module/Types.hs @@ -7,17 +7,20 @@ module Inferno.ML.Server.Module.Types where import Control.DeepSeq (NFData) +import Control.Monad.Catch (MonadThrow (throwM)) import Data.Aeson (FromJSON, ToJSON) import Data.Bits (countLeadingZeros) import Data.Generics.Labels () import Data.Int (Int64) +import qualified Data.Vector as Vector import Data.Word (Word8) import Database.PostgreSQL.Simple.FromField (FromField) import Database.PostgreSQL.Simple.ToField (ToField) import GHC.Generics (Generic) import Inferno.Eval (TermEnv) +import Inferno.Eval.Error (EvalError (RuntimeError)) import "inferno-ml-server-types" Inferno.ML.Server.Types - ( IValue (IDouble, IEmpty, IText, ITime, ITuple), + ( IValue (IArray, IDouble, IEmpty, IText, ITime, ITuple), ) import Inferno.ML.Types.Value (MlValue (VExtended)) import Inferno.Module.Cast @@ -27,27 +30,31 @@ import Inferno.Module.Cast ) import Inferno.Types.Value ( ImplEnvM, - Value (VCustom, VDouble, VEmpty, VEpochTime, VText, VTuple), + Value (VArray, VCustom, VDouble, VEmpty, VEpochTime, VText, VTuple), ) import Inferno.Types.VersionControl (VCObjectHash) import Prettyprinter (Pretty (pretty), cat, (<+>)) +import System.Posix.Types (EpochTime) import Web.HttpApiData (FromHttpApiData, ToHttpApiData) -- | Custom type for bridge prelude data BridgeValue = VResolution InverseResolution | VSeries PID + | VWrite (PID, [(EpochTime, IValue)]) deriving stock (Generic) instance Eq BridgeValue where VResolution r1 == VResolution r2 = r1 == r2 VSeries v1 == VSeries v2 = v1 == v2 + VWrite w1 == VWrite w2 = w1 == w2 _ == _ = False instance Pretty BridgeValue where pretty = \case VSeries p -> cat ["<<", "series" <+> pretty p, ">>"] VResolution e -> pretty @Int $ 2 ^ e + VWrite (p, vs) -> "Write" <+> pretty p <> ":" <+> pretty (show vs) -- | Unique ID for pollable data point (for the data source that can be -- queried using the bridge) @@ -98,8 +105,7 @@ data BridgeFuns m = BridgeFuns { valueAt :: BridgeV m, latestValueAndTimeBefore :: BridgeV m, latestValueAndTime :: BridgeV m, - -- FIXME `writePairs` will be removed soon - writePairsFun :: BridgeV m + valuesBetween :: BridgeV m } deriving stock (Generic) @@ -110,6 +116,17 @@ fromIValue = \case ITime t -> VEpochTime t ITuple (x, y) -> VTuple [fromIValue x, fromIValue y] IEmpty -> VEmpty + IArray v -> VArray $ Vector.toList $ fromIValue <$> v + +toIValue :: MonadThrow f => Value custom m -> f IValue +toIValue = \case + VText t -> pure $ IText t + VDouble d -> pure $ IDouble d + VEpochTime t -> pure $ ITime t + VTuple [x, y] -> curry ITuple <$> toIValue x <*> toIValue y + VEmpty -> pure IEmpty + VArray vs -> IArray . Vector.fromList <$> traverse toIValue vs + _ -> throwM $ RuntimeError "toIValue: got an unsupported value type" toResolution :: Int64 -> InverseResolution toResolution = diff --git a/inferno-ml-server/src/Inferno/ML/Server/Types.hs b/inferno-ml-server/src/Inferno/ML/Server/Types.hs index c8869a60..9eb7ef5f 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Types.hs @@ -87,7 +87,7 @@ import Network.HTTP.Client (Manager) import Numeric (readHex) import qualified Options.Applicative as Options import Plow.Logging (IOTracer, traceWith) -import Servant.Client.Streaming (ClientError, ClientM) +import Servant.Client.Streaming (ClientError) import System.Posix.Types (EpochTime) import Text.Read (readMaybe) import UnliftIO (Async) @@ -109,7 +109,7 @@ data Env = Env ( -- ID for the inference param Id InferenceParam, -- The actual job itself. This is stored so it can be canceled later - Async (Maybe ()) + Async (Maybe (Types.WriteStream IO)) ), bridge :: Bridge, manager :: Manager, @@ -122,19 +122,8 @@ data Env = Env -- | A bridge (host and port) of a server that can proxy a data source. This -- can be set using @POST /bridge@. It's not included directly into the NixOS -- image because then it would be difficult to change --- --- The client are @ClientM@ effects to perform specific queries, e.g. --- @valueAt@, which requires connecting to a data source -data Bridge = Bridge - { client :: BridgeClient, - info :: IORef (Maybe BridgeInfo) - } - deriving stock (Generic) - -data BridgeClient = BridgeClient - { valueAt :: Int64 -> PID -> EpochTime -> ClientM IValue, - latestValueAndTimeBefore :: EpochTime -> PID -> ClientM IValue, - writePairs :: PID -> PairStream Int IO -> ClientM () +newtype Bridge = Bridge + { info :: IORef (Maybe BridgeInfo) } deriving stock (Generic) @@ -278,6 +267,7 @@ data RemoteError | NoSuchScript VCObjectHash | NoSuchParameter Int64 | InvalidScript Text + | InvalidOutput Text | -- | Any error condition returned by Inferno script evaluation InfernoError SomeInfernoError | BridgeNotRegistered @@ -304,6 +294,11 @@ instance Exception RemoteError where NoSuchParameter iid -> unwords ["Parameter:", "'" <> show iid <> "'", "does not exist"] InvalidScript t -> Text.unpack t + InvalidOutput t -> + unwords + [ "Script output should be an array of `write` but was", + Text.unpack t + ] InfernoError (SomeInfernoError x) -> unwords [ "Inferno evaluation failed with:", @@ -395,7 +390,7 @@ pattern VCMeta t a g n d p v o = Inferno.VersionControl.Types.VCMeta t a g n d p v o type InfernoMlServerAPI = - Types.InfernoMlServerAPI (EntityId UId) (EntityId GId) PID VCObjectHash + Types.InfernoMlServerAPI (EntityId UId) (EntityId GId) PID VCObjectHash EpochTime -- Orphans diff --git a/inferno-ml/src/Inferno/ML/Types/Value.hs b/inferno-ml/src/Inferno/ML/Types/Value.hs index cbde1179..a8554ca7 100644 --- a/inferno-ml/src/Inferno/ML/Types/Value.hs +++ b/inferno-ml/src/Inferno/ML/Types/Value.hs @@ -40,7 +40,7 @@ instance Pretty x => FromValue (MlValue x) m T.ScriptModule where fromValue v = couldNotCast v customTypes :: [CustomType] -customTypes = ["tensor", "model"] +customTypes = ["tensor", "model", "write"] mlQuoter :: QuasiQuoter mlQuoter = moduleQuoter customTypes diff --git a/nix/inferno-ml/tests/scripts/contrived.inferno b/nix/inferno-ml/tests/scripts/contrived.inferno index 954e9e6e..d7b2d2e8 100644 --- a/nix/inferno-ml/tests/scripts/contrived.inferno +++ b/nix/inferno-ml/tests/scripts/contrived.inferno @@ -1,5 +1,5 @@ fun input0 -> -let t = Time.toTime (Time.seconds 200) -in let ?resolution = (toResolution 128) -in let v = valueAt input0 t ? 0.0 -in writePairs input0 (ML.asTensor2 ML.#double [[300.0, v + 5.0]]) + let t = Time.toTime (Time.seconds 200) in + let ?resolution = (toResolution 128) in + let v = valueAt input0 t ? 0.0 in + [makeWrites input0 [(Time.toTime (Time.seconds 300), v + 5.0)]] diff --git a/nix/inferno-ml/tests/scripts/mnist.inferno b/nix/inferno-ml/tests/scripts/mnist.inferno index 28a4f64f..787ba93d 100644 --- a/nix/inferno-ml/tests/scripts/mnist.inferno +++ b/nix/inferno-ml/tests/scripts/mnist.inferno @@ -1,39 +1,40 @@ fun input0 -> -let input = ML.asTensor4 ML.#float [[[ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -]]] in -let model = ML.loadModel "mnist.ts.pt" in -match ML.forward model [input] with { - | [scores] -> - let m = ML.toType ML.#double (ML.argmax 1 #false scores) - in let t = ML.asTensor1 ML.#double [100] - in writePairs input0 (ML.stack 1 [t, m]) - | _ -> writePairs input0 (ML.asTensor2 ML.#double [[100,0, 0.0]]) -} + let input = ML.asTensor4 ML.#float [[[ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ]]] in + let t = Time.toTime (Time.seconds 100) in + let model = ML.loadModel "mnist.ts.pt" in + match ML.forward model [input] with { + | [scores] -> + let m = ML.toType ML.#double (ML.argmax 1 #false scores) in + [makeWrites input0 [(t, ML.asDouble m)]] + | _ -> + [makeWrites input0 [(t, -1.0)]] + } diff --git a/nix/inferno-ml/tests/scripts/ones.inferno b/nix/inferno-ml/tests/scripts/ones.inferno index 9816eb6d..8bc22800 100644 --- a/nix/inferno-ml/tests/scripts/ones.inferno +++ b/nix/inferno-ml/tests/scripts/ones.inferno @@ -1,7 +1,8 @@ fun input0 -> -let mkV = fun t -> valueAt input0 (Time.toTime (Time.seconds t)) ? 0.0 -in let v1 = mkV 150 -in let v2 = mkV 250 -in let xs = ML.ones ML.#double [1, 2] -in let vs = ML.asTensor2 ML.#double [[150.0, v1], [250.0, v2]] -in writePairs input0 (ML.add xs vs) + let mkV = fun t -> valueAt input0 (Time.toTime (Time.seconds t)) ? 0.0 in + let ts = [150, 250] in + let vs = Array.map mkV ts in + let xs = ML.ones ML.#double [2] in + let ts1 = Array.map (fun t -> Time.toTime (Time.seconds (t + 1))) ts in + let vs1 = ML.asArray1 (ML.add xs (ML.asTensor1 ML.#double vs)) in + [makeWrites input0 (zip ts1 vs1)] diff --git a/nix/inferno-ml/tests/server.nix b/nix/inferno-ml/tests/server.nix index 4bd892dc..955af67d 100644 --- a/nix/inferno-ml/tests/server.nix +++ b/nix/inferno-ml/tests/server.nix @@ -194,19 +194,12 @@ pkgs.nixosTest { import json import time - def runtest(param, ex): - node.succeed(f'run-inference-client-test {param}') - # Load the JSON file written to by the dummy server (as its implementation - # of the `writePairs` bridge endpoint) and confirm that the results are - # correct + def runtest(param): + # Runs an test for an individual param using the client executable, + # which confirms that the results are correct # # Note: The inference param DB ID and the associated PID are the same number - - # Give the dummy bridge a second to write the file, just to be sure - time.sleep(1) - res = json.loads(node.succeed(f'cat /tmp/dummy/{param}.json')) - print(f'Inference param {param} should write values {ex}') - assert res == ex + node.succeed(f'run-inference-client-test {param}') node.wait_for_unit("multi-user.target") node.wait_for_unit("postgresql.service") @@ -228,12 +221,12 @@ pkgs.nixosTest { node.succeed('register-bridge') # `tests/scripts/ones.inferno` - runtest(1, [[151, 2.5], [251, 3.5]]) + runtest(1) # `tests/scripts/contrived.inferno` - runtest(2, [[300, 25.0]]) + runtest(2) # `tests/scripts/mnist.inferno` - runtest(3, [[100, 7.0]]) + runtest(3) ''; }