From 52f28b63b03cf8224f183b1a4ca38c957b41a1e7 Mon Sep 17 00:00:00 2001
From: Rory Tyler Hayford <52039264+ngua@users.noreply.github.com>
Date: Thu, 30 May 2024 15:01:24 +0700
Subject: [PATCH] [inferno-ml] Change input/output representation (#122)

Changes the internal representation of `InferenceParam` inputs/outputs.
Instead of having two separate `inputs` and `outputs` as vectors, they are
now both in a single field and annotated with a `ScriptInputType`, which
controls readability/writability of the input. Also, instead of being a
`Vector`, they are now explicitly mapped to Inferno `Ident`s.

The rationale for merging `inputs`/`outputs`:

- We need to provide outputs as well as inputs as arguments during script
  evaluation. If we don't do this, the identifiers used in `makeWrites` will
  not resolve to anything
- We can't just concatenate the old `inputs` and `outputs` fields, because
  users may wish to have identically named inputs and outputs; for example,
  they may want to both read from and write to `input0`. This would lead to
  the following during script eval:
```
fun input0 input0 -> ...
```
- Accordingly, we need some way of allowing identically named identifiers to
  be both inputs and outputs. Combining the two argument types and storing
  whether they are readable, writable, or both solves this

The rationale for storing the Inferno identifiers:

- When scripts are created (e.g. via an Inferno LSP server), the only thing
  that is provided is the list of Inferno `Ident`s
- These `[Ident]` need to correspond exactly to the `inputs` that the
  `InferenceParam` contains
- Even if we sort this list of identifiers, we need to make sure that the
  order of the `InferenceParam`'s inputs is exactly the same as the order of
  the original `[Ident]`. This could lead to bug-prone assumptions or
  workarounds, since we wouldn't have the original `[Ident]` in the param
- To make sure that we always know which order is correct, we can store the
  `Ident`s along with the actual inputs in a `Map` and then use
  `Map.toAscList` to get the correct order for script arguments

Also, I fixed the JSON encoding/decoding so that NaNs can be transmitted
to/from `inferno-ml-server`. Previously, `[null, 0.0, 0.0]` would parse to
`[IEmpty, IDouble 0.0, IDouble 0.0]`, which is of course incorrect.
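For illustration, here is a minimal sketch of the new decoding behaviour. It
is not part of the patch: it assumes `IValue (..)` is exported from
`Inferno.ML.Server.Types`, and the JSON literals are made up.

```
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}

import Data.Aeson (decode)
import qualified Data.Vector as Vector
import Inferno.ML.Server.Types (IValue (..))

main :: IO ()
main = do
  -- Arrays are now wrapped in a tagged object, so aeson's Double parser can
  -- degrade a `null` element to NaN instead of the old (incorrect) IEmpty
  case decode @IValue "{\"array\": [null, 0.0, 0.0]}" of
    Just (IArray xs) | IDouble d <- Vector.head xs -> print (isNaN d) -- True
    _ -> putStrLn "unexpected parse"
  -- A plain two-element JSON array is still decoded as a tuple
  case decode @IValue "[2.5, {\"time\": 151}]" of
    Just (ITuple (IDouble x, ITime _)) -> print x -- 2.5
    _ -> putStrLn "unexpected parse"
```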
Now it correctly parses to `[IDouble NaN, IDouble 0.0, IDouble 0.0]` --- .github/workflows/build.yml | 1 + inferno-ml-server-types/CHANGELOG.md | 3 + .../inferno-ml-server-types.cabal | 2 +- .../src/Inferno/ML/Server/Types.hs | 112 ++++++++++++++---- inferno-ml-server/CHANGELOG.md | 3 + inferno-ml-server/exe/Dummy.hs | 6 +- inferno-ml-server/exe/ParseAndSave.hs | 43 +++---- inferno-ml-server/inferno-ml-server.cabal | 6 +- .../src/Inferno/ML/Server/Inference.hs | 42 +++---- .../src/Inferno/ML/Server/Types.hs | 10 +- inferno-ml-server/{exe => test}/Client.hs | 37 +++--- .../migrations/v1-create-tables.sql | 4 +- nix/inferno-ml/tests/scripts/mnist.inferno | 6 +- nix/inferno-ml/tests/server.nix | 22 ++-- 14 files changed, 190 insertions(+), 107 deletions(-) rename inferno-ml-server/{exe => test}/Client.hs (77%) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0d702f2d..a877a153 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -36,6 +36,7 @@ jobs: with: install_url: https://releases.nixos.org/nix/nix-2.13.3/install extra_nix_config: | + fallback = true substituters = https://cache.nixos.org https://cache.iog.io trusted-public-keys = cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY= hydra.iohk.io:f/Ea+s+dFdN+3Y/G+FDgSq+a5NEWhJGzdjvKNGv0/EQ= narinfo-cache-negative-ttl = 60 diff --git a/inferno-ml-server-types/CHANGELOG.md b/inferno-ml-server-types/CHANGELOG.md index d7993081..6450fca8 100644 --- a/inferno-ml-server-types/CHANGELOG.md +++ b/inferno-ml-server-types/CHANGELOG.md @@ -1,6 +1,9 @@ # Revision History for inferno-ml-server-types *Note*: we use https://pvp.haskell.org/ (MAJOR.MAJOR.MINOR.PATCH) +## 0.4.0 +* Change representation of script inputs/outputs + ## 0.3.0 * Add support for tracking evaluation info diff --git a/inferno-ml-server-types/inferno-ml-server-types.cabal b/inferno-ml-server-types/inferno-ml-server-types.cabal index 6817b0dc..76888089 100644 --- a/inferno-ml-server-types/inferno-ml-server-types.cabal +++ b/inferno-ml-server-types/inferno-ml-server-types.cabal @@ -1,6 +1,6 @@ cabal-version: 2.4 name: inferno-ml-server-types -version: 0.3.0 +version: 0.4.0 synopsis: Types for Inferno ML server description: Types for Inferno ML server homepage: https://github.com/plow-technologies/inferno.git#readme diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs index e098c65a..447057ad 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs @@ -17,7 +17,7 @@ import Conduit (ConduitT) import Control.Applicative (asum, optional) import Control.Category ((>>>)) import Control.DeepSeq (NFData (rnf), rwhnf) -import Control.Monad (void) +import Control.Monad (void, (<=<)) import Data.Aeson import Data.Aeson.Types (Parser) import qualified Data.Attoparsec.ByteString.Char8 as Attoparsec @@ -64,6 +64,7 @@ import Database.PostgreSQL.Simple.Types ) import Foreign.C (CUInt (CUInt)) import GHC.Generics (Generic) +import Inferno.Types.Syntax (Ident) import Inferno.Types.VersionControl ( VCObjectHash, byteStringToVCObjectHash, @@ -172,8 +173,9 @@ newtype Id a = Id Int64 -- | Row for the table containing inference script closures data InferenceScript uid gid = InferenceScript - { -- NOTE: This is the ID for each row, stored as a `bytea` (bytes of the hash) + { -- | This is the ID for each row, stored as a @bytea@ (bytes of the hash) hash :: VCObjectHash, + -- | Script closure obj :: VCMeta 
uid gid VCObject } deriving stock (Show, Eq, Generic) @@ -237,6 +239,8 @@ data Model uid gid = Model -- | The user who owns the model, if any. Note that owning a model -- will implicitly set permissions user :: Maybe uid, + -- | The time that this model was \"deleted\", if any. For active models, + -- this will be @Nothing@ terminated :: Maybe UTCTime } deriving stock (Show, Eq, Generic) @@ -341,6 +345,8 @@ data ModelVersion uid gid c = ModelVersion -- the PSQL large object table contents :: c, version :: Version, + -- | The time that this model version was \"deleted\", if any. For active + -- models versions, this will be @Nothing@ terminated :: Maybe UTCTime } deriving stock (Show, Eq, Generic) @@ -592,13 +598,24 @@ data InferenceParam uid gid p s = InferenceParam -- (e.g. a UUID for use with @inferno-lsp@) -- -- For existing inference params, this is the foreign key for the specific - -- script in the 'InferenceScript' table + -- script in the 'InferenceScript' table (i.e. a @VCObjectHash@) script :: s, -- | This needs to be linked to a specific version of a model rather -- than the @model@ table itself model :: Id (ModelVersion uid gid Oid), - inputs :: Vector (SingleOrMany p), - outputs :: Vector (SingleOrMany p), + -- | This is called @inputs@ but is also used for script outputs as + -- well. The access (input or output) is controlled by the 'ScriptInputType'. + -- For example, if this field is set to @[("input0", Single (p, Readable))]@, + -- the script will only have a single read-only input and will not be able to + -- write anywhere (note that we should disallow this scenario, as script + -- evaluation would not work properly) + -- + -- Mapping the input\/output to the Inferno identifier helps ensure that + -- Inferno identifiers are always pointing to the correct input\/output; + -- otherwise we would need to rely on the order of the original identifiers + inputs :: Map Ident (SingleOrMany p, ScriptInputType), + -- | The time that this parameter was \"deleted\", if any. For active + -- parameters, this will be @Nothing@ terminated :: Maybe UTCTime, user :: uid } @@ -619,10 +636,6 @@ instance <$> field <*> fmap wrappedTo (field @VCObjectHashRow) <*> field - -- HACK / FIXME This is a pretty awful hack (storing as `jsonb`), - -- but Postgres sub-arrays need to be the same length and writing - -- a custom parser might be painful - <*> fmap getAeson field <*> fmap getAeson field <*> field <*> field @@ -638,13 +651,41 @@ instance [ toField Default, ip ^. the @"script" & VCObjectHashRow & toField, ip ^. the @"model" & toField, - -- HACK / FIXME See above ip ^. the @"inputs" & Aeson & toField, - ip ^. the @"outputs" & Aeson & toField, toField Default, ip ^. the @"user" & toField ] +-- | Controls input interaction within a script, i.e. ability to read from +-- and\/or write to this input. Although the term \"input\" is used, those with +-- writes enabled can also be described as \"outputs\" +data ScriptInputType + = -- | Script input can be read, but not written + Readable + | -- | Script input can be written, i.e. 
can be used in array of + -- write objects returned from script evaluation + Writable + | -- | Script input can be both read from and written to; this allows + -- the same script identifier to point to the same PID with both + -- types of access enabled + ReadableWritable + deriving stock (Show, Eq, Generic) + deriving anyclass (NFData) + +instance FromJSON ScriptInputType where + parseJSON = withText "ScriptInputType" $ \case + "r" -> pure Readable + "w" -> pure Writable + "rw" -> pure ReadableWritable + s -> fail $ "Invalid script input type: " <> Text.unpack s + +instance ToJSON ScriptInputType where + toJSON = + String . \case + Readable -> "r" + Writable -> "w" + ReadableWritable -> "rw" + -- | Information about execution time and resource usage. This is saved by -- @inferno-ml-server@ after script evaluation completes and can be queried -- later by using the same job identifier that was provided to the @/inference@ @@ -762,22 +803,42 @@ instance FromJSON IValue where Number n -> pure . IDouble $ toRealFloat n -- It's easier to just mark the time explicitly in an object, -- rather than try to deal with distinguishing times and doubles - Object o -> ITime <$> o .: "time" + Object o -> + asum + [ ITime <$> o .: "time", + fmap IArray $ arrayP =<< o .: "array" + ] + -- Note that this preserves a plain JSON array for tuples. But we need + -- some straightforward way of distinguishing tuples and arrays; since + -- the bridge often transmits a large number of individual tuples (times + -- and values), it's better to use arrays for the tuples and a tagged object + -- for arrays themselves; we often will only deal with one large array, and + -- adding a few bytes to this is better than adding a few bytes to thousands + -- of encoded tuples Array a | [x, y] <- Vector.toList a -> - (,) <$> parseJSON x <*> parseJSON y <&> \case - -- We don't want to confuse a two-element array of tuples with - -- a tuple itself. For example, `"[[10.0, {\"time\": 10}], [10.0, {\"time\": 10}]]"` - -- should parse as a two-element array of `(double, time)` tuples, - -- not as a `((double, time), (double, time))`. I can't think of - -- any reason to support the latter. 
An alternative would be to - -- change tuple encoding to an object, but then we would be transmitting - -- a much larger amount of data on most requests - (f@(ITuple _), s@(ITuple _)) -> IArray $ Vector.fromList [f, s] - t -> ITuple t - | otherwise -> IArray <$> traverse (parseJSON @IValue) a + fmap ITuple $ (,) <$> parseJSON x <*> parseJSON y + | otherwise -> fail "Only two-element tuples are supported" Null -> pure IEmpty - _ -> fail "Expected one of: string, double, time, tuple, unit (empty array), array" + _ -> fail "Expected one of: string, double, time, tuple, null, array" + where + arrayP :: Vector Value -> Parser (Vector IValue) + arrayP a = + -- This is a bit tedious, but we want to make sure that the array elements + -- are homogeneous; parsing all elements to `IValue`s first can't guarantee + -- this + asum + [ -- This alternative means that `null` will be correctly parsed to NaN + -- when inside an array of doubles + fmap IDouble <$> traverse parseJSON a, + fmap ITuple <$> traverse parseJSON a, + fmap IText <$> traverse parseJSON a, + fmap ITime <$> traverse (withObject "EpochTime" (.: "time")) a, + -- Nested array support + fmap IArray + <$> traverse (withObject "IArray" (arrayP <=< (.: "array"))) a, + fail "Expected a heterogeneous array" + ] instance ToJSON IValue where toJSON = \case @@ -786,8 +847,9 @@ instance ToJSON IValue where ITuple t -> toJSON t -- See `FromJSON` instance above ITime t -> object ["time" .= t] + -- See `FromJSON` instance above + IArray is -> object ["array" .= is] IEmpty -> toJSON Null - IArray is -> toJSON is -- | Used to represent inputs to the script. 'Many' allows for an array input data SingleOrMany a diff --git a/inferno-ml-server/CHANGELOG.md b/inferno-ml-server/CHANGELOG.md index f692cbf9..34e71fe0 100644 --- a/inferno-ml-server/CHANGELOG.md +++ b/inferno-ml-server/CHANGELOG.md @@ -1,3 +1,6 @@ +## 2023.5.29 +* Change representation of script inputs/outputs + ## 2023.5.22 * Add support for tracking evaluation info diff --git a/inferno-ml-server/exe/Dummy.hs b/inferno-ml-server/exe/Dummy.hs index e5ecd68a..68e1043b 100644 --- a/inferno-ml-server/exe/Dummy.hs +++ b/inferno-ml-server/exe/Dummy.hs @@ -87,7 +87,9 @@ valueAt _ p t = <&> maybe IEmpty IDouble . preview (at p . _Just . at t . 
_Just) latestValueAndTimeBefore :: Int -> PID -> DummyM IValue -latestValueAndTimeBefore _ _ = throwIO $ userError "Unsupported" +latestValueAndTimeBefore _ _ = + throwIO $ + userError "Unsupported: latestValueAndTimeBefore" valuesBetween :: Int64 -> PID -> Int -> Int -> ReaderT DummyEnv IO IValue -valuesBetween _ _ _ _ = throwIO $ userError "Unsupported" +valuesBetween _ _ _ _ = throwIO $ userError "Unsupported: valuesBetween" diff --git a/inferno-ml-server/exe/ParseAndSave.hs b/inferno-ml-server/exe/ParseAndSave.hs index 38141841..1baf63f0 100644 --- a/inferno-ml-server/exe/ParseAndSave.hs +++ b/inferno-ml-server/exe/ParseAndSave.hs @@ -13,13 +13,14 @@ module ParseAndSave (main) where import Control.Category ((>>>)) import Control.Exception (Exception (displayException)) import Control.Monad (void) +import Data.Aeson (eitherDecode) import Data.ByteString (ByteString) import qualified Data.ByteString.Char8 as Char8 +import qualified Data.ByteString.Lazy.Char8 as Lazy.Char8 +import Data.Map.Strict (Map) import Data.Text (Text) import qualified Data.Text.IO as Text.IO import Data.Time.Clock.POSIX (getPOSIXTime) -import Data.Vector (Vector) -import qualified Data.Vector as Vector import Database.PostgreSQL.Simple ( Connection, Query, @@ -37,7 +38,7 @@ import Inferno.Core import Inferno.ML.Server.Module.Prelude (mkBridgePrelude) import Inferno.ML.Server.Types import Inferno.ML.Types.Value (customTypes) -import Inferno.Types.Syntax (Expr, TCScheme) +import Inferno.Types.Syntax (Expr, Ident, TCScheme) import Inferno.Types.VersionControl ( Pinned, VCObjectHash, @@ -50,35 +51,37 @@ import Inferno.VersionControl.Types ) import System.Environment (getArgs) import System.Exit (die) -import Text.Read (readMaybe) import UnliftIO.Exception (bracket, throwString) main :: IO () main = getArgs >>= \case - scriptp : p : conns : _ -> - maybe - (throwString "Invalid PID") - (parseAndSave scriptp (Char8.pack conns) . PID) - $ readMaybe p - _ -> die "Usage ./parse " - -parseAndSave :: FilePath -> ByteString -> PID -> IO () -parseAndSave p conns pid = do + scriptp : pstr : conns : _ -> + either throwString (parseAndSave scriptp (Char8.pack conns)) + . eitherDecode + $ Lazy.Char8.pack pstr + _ -> die "Usage ./parse " + +parseAndSave :: + FilePath -> + ByteString -> + Map Ident (SingleOrMany PID, ScriptInputType) -> + IO () +parseAndSave p conns inputs = do t <- Text.IO.readFile p now <- fromIntegral @Int . round <$> getPOSIXTime ast <- either (throwString . displayException) pure . (`parse` t) =<< mkInferno @_ @BridgeMlValue (mkBridgePrelude funs) customTypes - bracket (connectPostgreSQL conns) close (saveScriptAndParam ast now pid) + bracket (connectPostgreSQL conns) close (saveScriptAndParam ast now inputs) saveScriptAndParam :: (Expr (Pinned VCObjectHash) (), TCScheme) -> CTime -> - PID -> + Map Ident (SingleOrMany PID, ScriptInputType) -> Connection -> IO () -saveScriptAndParam x now pid conn = insertScript *> insertParam +saveScriptAndParam x now inputs conn = insertScript *> insertParam where insertScript :: IO () insertScript = @@ -96,17 +99,15 @@ saveScriptAndParam x now pid conn = insertScript *> insertParam . InferenceParam Nothing hash + -- Bit of a hack. We only have one model version in the + -- tests, so we can just hard-code the ID here (Id 1) inputs - mempty Nothing $ entityIdFromInteger 0 where q :: Query - q = [sql| INSERT INTO params VALUES (?, ?, ?, ?, ?, ?, ?) 
|] - - inputs :: Vector (SingleOrMany PID) - inputs = Vector.singleton $ Single pid + q = [sql| INSERT INTO params VALUES (?, ?, ?, ?, ?, ?) |] vcfunc :: VCObject vcfunc = uncurry VCFunction x diff --git a/inferno-ml-server/inferno-ml-server.cabal b/inferno-ml-server/inferno-ml-server.cabal index e4e64dc5..ef746678 100644 --- a/inferno-ml-server/inferno-ml-server.cabal +++ b/inferno-ml-server/inferno-ml-server.cabal @@ -1,6 +1,6 @@ cabal-version: 2.4 name: inferno-ml-server -version: 2023.5.22 +version: 2023.5.29 synopsis: Server for Inferno ML description: Server for Inferno ML homepage: https://github.com/plow-technologies/inferno.git#readme @@ -120,7 +120,7 @@ executable tests executable test-client import: common main-is: Client.hs - hs-source-dirs: exe + hs-source-dirs: test ghc-options: -threaded -rtsopts -main-is Client build-depends: , aeson @@ -166,8 +166,10 @@ executable parse-and-save hs-source-dirs: exe ghc-options: -threaded -rtsopts -main-is ParseAndSave build-depends: + , aeson , base , bytestring + , containers , inferno-core , inferno-ml , inferno-ml-server diff --git a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs index 8b993602..ee6e50ab 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs @@ -31,7 +31,6 @@ import Data.Time (UTCTime, getCurrentTime) import Data.Time.Clock.POSIX (getPOSIXTime) import Data.Traversable (for) import Data.UUID (UUID) -import qualified Data.Vector as Vector import Data.Word (Word64) import Database.PostgreSQL.Simple ( Only (Only), @@ -92,12 +91,14 @@ import UnliftIO.IORef (readIORef) import UnliftIO.MVar (putMVar, takeMVar, withMVar) import UnliftIO.Timeout (timeout) --- Run an inference param, locking the `MVar` held in the `Env`. This is to avoid --- running params in parallel (which may lead to problems with model caching, --- etc...) and also to indicate whether any active process is running in the --- `status` endpoint (using `tryTakeMVar`) +-- | Run an inference param, locking the @MVar@ held in the 'Env'. This is to +-- avoid running params in parallel (which may lead to problems with model +-- caching, etc...) and also to indicate whether any active process is running +-- in the @/status@ endpoint (using @tryTakeMVar@) runInferenceParam :: Id InferenceParam -> + -- | Optional resolution, defaulting to 128. This is needed in case the + -- parameter evaluates a script that calls e.g. @valueAt@ Maybe Int64 -> UUID -> RemoteM (WriteStream IO) @@ -167,9 +168,17 @@ runInferenceParam ipid (fromMaybe 128 -> res) uuid = ( mkIdent i, pids ^.. each & over mapped toSeries & VArray ) - where - ps :: [SingleOrMany PID] - ps = param ^.. #inputs . each + + -- Note that this both includes inputs (i.e. readable) + -- and outputs (i.e. writable, or readable/writable). + -- These need to be provided to the script in order + -- for the symbolic identifer (e.g. `output0`) to + -- resolve. We can discard the input type here, + -- however. The distinction is only relevant for the + -- runtime that runs as a script evaluation engine + -- and commits the output write object + ps :: [SingleOrMany PID] + ps = param ^.. #inputs . to Map.toAscList . each . _2 . _1 closure :: Map VCObjectHash VCObject closure = @@ -185,9 +194,10 @@ runInferenceParam ipid (fromMaybe 128 -> res) uuid = ) args where + -- See note above about inputs/outputs args :: [Expr (Maybe a) ()] args = - [0 .. param ^. #inputs & Vector.length & (- 1)] + [0 .. 
length ps - 1] <&> Var () Nothing LocalScope . Expl . ExtIdent @@ -198,18 +208,10 @@ runInferenceParam ipid (fromMaybe 128 -> res) uuid = dummy :: ImplExpl dummy = Expl . ExtIdent $ Right "dummy" - doEval expr =<< runImplEnvM mempty (mkEnvFromClosure localEnv closure) + either (throwInfernoError . Left . SomeInfernoError) yieldPairs + =<< flip (`evalExpr` implEnv) expr + =<< runImplEnvM mempty (mkEnvFromClosure localEnv closure) where - doEval :: - Expr (Maybe VCObjectHash) () -> - BridgeTermEnv RemoteM -> - RemoteM (WriteStream IO) - doEval x env = - either - (throwInfernoError . Left . SomeInfernoError) - yieldPairs - =<< evalExpr env implEnv x - yieldPairs :: Value BridgeMlValue (ImplEnvM RemoteM BridgeMlValue) -> RemoteM (WriteStream IO) diff --git a/inferno-ml-server/src/Inferno/ML/Server/Types.hs b/inferno-ml-server/src/Inferno/ML/Server/Types.hs index 069cc8a9..c2d4ef30 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Types.hs @@ -43,12 +43,12 @@ import Data.Data (Typeable) import Data.Generics.Labels () import Data.Generics.Wrapped (wrappedTo) import Data.Int (Int64) +import Data.Map.Strict (Map) import Data.Text (Text) import qualified Data.Text as Text import qualified Data.Text.Read as Text.Read import Data.Time (UTCTime) import Data.UUID (UUID) -import Data.Vector (Vector) import Data.Word (Word64) import Data.Yaml (decodeFileThrow) import Database.PostgreSQL.Simple @@ -77,6 +77,7 @@ import "inferno-ml-server-types" Inferno.ML.Server.Types as M hiding ModelVersion, ) import qualified "inferno-ml-server-types" Inferno.ML.Server.Types as Types +import Inferno.Types.Syntax (Ident) import Inferno.VersionControl.Types ( VCObject, VCObjectHash, @@ -372,13 +373,12 @@ pattern InferenceParam :: Maybe (Id InferenceParam) -> VCObjectHash -> Id ModelVersion -> - Vector (SingleOrMany PID) -> - Vector (SingleOrMany PID) -> + Map Ident (SingleOrMany PID, ScriptInputType) -> Maybe UTCTime -> EntityId UId -> InferenceParam -pattern InferenceParam iid s m is os mt uid = - Types.InferenceParam iid s m is os mt uid +pattern InferenceParam iid s m ios mt uid = + Types.InferenceParam iid s m ios mt uid pattern VCMeta :: CTime -> diff --git a/inferno-ml-server/exe/Client.hs b/inferno-ml-server/test/Client.hs similarity index 77% rename from inferno-ml-server/exe/Client.hs rename to inferno-ml-server/test/Client.hs index 3485220a..88518056 100644 --- a/inferno-ml-server/exe/Client.hs +++ b/inferno-ml-server/test/Client.hs @@ -1,5 +1,4 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeApplications #-} -- NOTE -- This executable is only intended for testing the inference endpoint with the @@ -10,8 +9,6 @@ module Client (main) where import Conduit import Control.Monad (unless, void) import Data.Coerce (coerce) -import qualified Data.Conduit.List as Conduit.List -import Data.Function (on) import Data.Int (Int64) import qualified Data.Map as Map import Inferno.ML.Server.Client (inferenceC, registerBridgeC) @@ -58,13 +55,13 @@ main = _ -> die "Usage: test-client " -- Check that the returned write stream matches the expected value -verifyWrites :: - Int64 -> - WriteStream IO -> - IO () +verifyWrites :: Int64 -> WriteStream IO -> IO () verifyWrites ipid c = do expected <- getExpected - result <- rebuildWrites + -- Note that there is only one chunk per PID in the output stream, so we + -- don't need to concatenate the results by PID. We can just sink it into + -- a list directly + result <- runConduit $ c .| sinkList unless (result == expected) . 
throwString . unwords $ [ "Expected: ", show expected, @@ -74,19 +71,21 @@ verifyWrites ipid c = do show ipid ] where - rebuildWrites :: IO [(Int, [(EpochTime, IValue)])] - rebuildWrites = - runConduit $ - c - .| Conduit.List.groupBy ((==) `on` fst) - .| Conduit.List.concat - .| sinkList - getExpected :: IO [(Int, [(EpochTime, IValue)])] getExpected = maybe (throwString "Missing PID") pure . Map.lookup ipid $ Map.fromList - [ (1, [(1, [(151, IDouble 2.5), (251, IDouble 3.5)])]), - (2, [(2, [(300, IDouble 25.0)])]), - (3, [(3, [(100, IDouble 7.0)])]) + [ ( 1, + [ (1, [(151, IDouble 2.5), (251, IDouble 3.5)]) + ] + ), + ( 2, + [ (2, [(300, IDouble 25.0)]) + ] + ), + ( 3, + [ (3, [(100, IDouble 7.0)]), + (4, [(100, IDouble 8.0)]) + ] + ) ] diff --git a/nix/inferno-ml/migrations/v1-create-tables.sql b/nix/inferno-ml/migrations/v1-create-tables.sql index 738aedbd..4d0bdbe4 100644 --- a/nix/inferno-ml/migrations/v1-create-tables.sql +++ b/nix/inferno-ml/migrations/v1-create-tables.sql @@ -67,8 +67,10 @@ create table if not exists params -- Script hash from `inferno-vc` , script bytea not null references scripts (id) , model integer references mversions (id) + -- Strictly speaking, this includes both inputs and outputs. The + -- corresponding Haskell type contains `(p, ScriptInputType)`, with + -- the second element determining readability and writability , inputs jsonb not null - , outputs jsonb not null -- See note above , terminated timestamptz , "user" integer references users (id) diff --git a/nix/inferno-ml/tests/scripts/mnist.inferno b/nix/inferno-ml/tests/scripts/mnist.inferno index 787ba93d..d5c16d16 100644 --- a/nix/inferno-ml/tests/scripts/mnist.inferno +++ b/nix/inferno-ml/tests/scripts/mnist.inferno @@ -1,4 +1,4 @@ -fun input0 -> +fun input0 input1 -> let input = ML.asTensor4 ML.#float [[[ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], @@ -34,7 +34,7 @@ fun input0 -> match ML.forward model [input] with { | [scores] -> let m = ML.toType ML.#double (ML.argmax 1 #false scores) in - [makeWrites input0 [(t, ML.asDouble m)]] + [makeWrites input0 [(t, ML.asDouble m)], makeWrites input1 [(t, ML.asDouble m + 1.0)]] | _ -> - [makeWrites input0 [(t, -1.0)]] + [makeWrites input0 [(t, -1.0)], makeWrites input1 [(t, -1.0)]] } diff --git a/nix/inferno-ml/tests/server.nix b/nix/inferno-ml/tests/server.nix index 955af67d..31b7962c 100644 --- a/nix/inferno-ml/tests/server.nix +++ b/nix/inferno-ml/tests/server.nix @@ -89,14 +89,22 @@ pkgs.nixosTest { # with an associated inference param for each text = let - dbstr = - "host='127.0.0.1' dbname='inferno' " - + "user='inferno' password=''"; + dbstr = "host='127.0.0.1' dbname='inferno' user='inferno' password=''"; + ios = + builtins.mapAttrs (_: builtins.toJSON) { + ones = { input0 = [ 1 "rw" ]; }; + contrived = { input0 = [ 2 "rw" ]; }; + # This test uses two outputs + mnist = { + input0 = [ 3 "rw" ]; + input1 = [ 4 "w" ]; + }; + }; in '' - parse-and-save ${./scripts/ones.inferno} 1 ${dbstr} - parse-and-save ${./scripts/contrived.inferno} 2 ${dbstr} - parse-and-save ${./scripts/mnist.inferno} 3 ${dbstr} + parse-and-save ${./scripts/ones.inferno} '${ios.ones}' ${dbstr} + parse-and-save ${./scripts/contrived.inferno} '${ios.contrived}' ${dbstr} + parse-and-save ${./scripts/mnist.inferno} '${ios.mnist}' ${dbstr} ''; } ) @@ -197,8 +205,6 @@ pkgs.nixosTest { def runtest(param): # Runs an test for an individual param using the client 
executable, # which confirms that the results are correct - # - # Note: The inference param DB ID and the associated PID are the same number node.succeed(f'run-inference-client-test {param}') node.wait_for_unit("multi-user.target")
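As a closing illustration of the new representation, here is a minimal sketch
(not part of the patch) of decoding the `mnist` entry from `server.nix` above
into the new `Map Ident (SingleOrMany p, ScriptInputType)` and using
`Map.toAscList` to fix the script-argument order. It assumes `ScriptInputType`
and `SingleOrMany` are exported from `Inferno.ML.Server.Types`, that `Ident`
has the JSON-key and `Show` instances the patch already relies on, and it uses
`Int` in place of the real PID type.

```
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}

import Data.Aeson (eitherDecode)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Inferno.ML.Server.Types (ScriptInputType, SingleOrMany)
import Inferno.Types.Syntax (Ident)

main :: IO ()
main =
  -- Same shape as the `mnist` entry passed to parse-and-save above:
  -- input0 is readable/writable (PID 3), input1 is write-only (PID 4)
  case eitherDecode @(Map Ident (SingleOrMany Int, ScriptInputType)) json of
    Left err -> putStrLn err
    Right inputs ->
      -- Map.toAscList yields a stable, identifier-sorted order; this is the
      -- order in which the script's arguments are applied during evaluation
      mapM_ (\(i, (_pid, access)) -> print (i, access)) (Map.toAscList inputs)
  where
    json = "{\"input0\": [3, \"rw\"], \"input1\": [4, \"w\"]}"
```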