diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0d702f2d..a877a153 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -36,6 +36,7 @@ jobs: with: install_url: https://releases.nixos.org/nix/nix-2.13.3/install extra_nix_config: | + fallback = true substituters = https://cache.nixos.org https://cache.iog.io trusted-public-keys = cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY= hydra.iohk.io:f/Ea+s+dFdN+3Y/G+FDgSq+a5NEWhJGzdjvKNGv0/EQ= narinfo-cache-negative-ttl = 60 diff --git a/inferno-ml-server-types/CHANGELOG.md b/inferno-ml-server-types/CHANGELOG.md index d7993081..6450fca8 100644 --- a/inferno-ml-server-types/CHANGELOG.md +++ b/inferno-ml-server-types/CHANGELOG.md @@ -1,6 +1,9 @@ # Revision History for inferno-ml-server-types *Note*: we use https://pvp.haskell.org/ (MAJOR.MAJOR.MINOR.PATCH) +## 0.4.0 +* Change representation of script inputs/outputs + ## 0.3.0 * Add support for tracking evaluation info diff --git a/inferno-ml-server-types/inferno-ml-server-types.cabal b/inferno-ml-server-types/inferno-ml-server-types.cabal index 6817b0dc..76888089 100644 --- a/inferno-ml-server-types/inferno-ml-server-types.cabal +++ b/inferno-ml-server-types/inferno-ml-server-types.cabal @@ -1,6 +1,6 @@ cabal-version: 2.4 name: inferno-ml-server-types -version: 0.3.0 +version: 0.4.0 synopsis: Types for Inferno ML server description: Types for Inferno ML server homepage: https://github.com/plow-technologies/inferno.git#readme diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs index e098c65a..447057ad 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs @@ -17,7 +17,7 @@ import Conduit (ConduitT) import Control.Applicative (asum, optional) import Control.Category ((>>>)) import Control.DeepSeq (NFData (rnf), rwhnf) -import Control.Monad (void) +import Control.Monad (void, (<=<)) import Data.Aeson import Data.Aeson.Types (Parser) import qualified Data.Attoparsec.ByteString.Char8 as Attoparsec @@ -64,6 +64,7 @@ import Database.PostgreSQL.Simple.Types ) import Foreign.C (CUInt (CUInt)) import GHC.Generics (Generic) +import Inferno.Types.Syntax (Ident) import Inferno.Types.VersionControl ( VCObjectHash, byteStringToVCObjectHash, @@ -172,8 +173,9 @@ newtype Id a = Id Int64 -- | Row for the table containing inference script closures data InferenceScript uid gid = InferenceScript - { -- NOTE: This is the ID for each row, stored as a `bytea` (bytes of the hash) + { -- | This is the ID for each row, stored as a @bytea@ (bytes of the hash) hash :: VCObjectHash, + -- | Script closure obj :: VCMeta uid gid VCObject } deriving stock (Show, Eq, Generic) @@ -237,6 +239,8 @@ data Model uid gid = Model -- | The user who owns the model, if any. Note that owning a model -- will implicitly set permissions user :: Maybe uid, + -- | The time that this model was \"deleted\", if any. For active models, + -- this will be @Nothing@ terminated :: Maybe UTCTime } deriving stock (Show, Eq, Generic) @@ -341,6 +345,8 @@ data ModelVersion uid gid c = ModelVersion -- the PSQL large object table contents :: c, version :: Version, + -- | The time that this model version was \"deleted\", if any. For active + -- models versions, this will be @Nothing@ terminated :: Maybe UTCTime } deriving stock (Show, Eq, Generic) @@ -592,13 +598,24 @@ data InferenceParam uid gid p s = InferenceParam -- (e.g. a UUID for use with @inferno-lsp@) -- -- For existing inference params, this is the foreign key for the specific - -- script in the 'InferenceScript' table + -- script in the 'InferenceScript' table (i.e. a @VCObjectHash@) script :: s, -- | This needs to be linked to a specific version of a model rather -- than the @model@ table itself model :: Id (ModelVersion uid gid Oid), - inputs :: Vector (SingleOrMany p), - outputs :: Vector (SingleOrMany p), + -- | This is called @inputs@ but is also used for script outputs as + -- well. The access (input or output) is controlled by the 'ScriptInputType'. + -- For example, if this field is set to @[("input0", Single (p, Readable))]@, + -- the script will only have a single read-only input and will not be able to + -- write anywhere (note that we should disallow this scenario, as script + -- evaluation would not work properly) + -- + -- Mapping the input\/output to the Inferno identifier helps ensure that + -- Inferno identifiers are always pointing to the correct input\/output; + -- otherwise we would need to rely on the order of the original identifiers + inputs :: Map Ident (SingleOrMany p, ScriptInputType), + -- | The time that this parameter was \"deleted\", if any. For active + -- parameters, this will be @Nothing@ terminated :: Maybe UTCTime, user :: uid } @@ -619,10 +636,6 @@ instance <$> field <*> fmap wrappedTo (field @VCObjectHashRow) <*> field - -- HACK / FIXME This is a pretty awful hack (storing as `jsonb`), - -- but Postgres sub-arrays need to be the same length and writing - -- a custom parser might be painful - <*> fmap getAeson field <*> fmap getAeson field <*> field <*> field @@ -638,13 +651,41 @@ instance [ toField Default, ip ^. the @"script" & VCObjectHashRow & toField, ip ^. the @"model" & toField, - -- HACK / FIXME See above ip ^. the @"inputs" & Aeson & toField, - ip ^. the @"outputs" & Aeson & toField, toField Default, ip ^. the @"user" & toField ] +-- | Controls input interaction within a script, i.e. ability to read from +-- and\/or write to this input. Although the term \"input\" is used, those with +-- writes enabled can also be described as \"outputs\" +data ScriptInputType + = -- | Script input can be read, but not written + Readable + | -- | Script input can be written, i.e. can be used in array of + -- write objects returned from script evaluation + Writable + | -- | Script input can be both read from and written to; this allows + -- the same script identifier to point to the same PID with both + -- types of access enabled + ReadableWritable + deriving stock (Show, Eq, Generic) + deriving anyclass (NFData) + +instance FromJSON ScriptInputType where + parseJSON = withText "ScriptInputType" $ \case + "r" -> pure Readable + "w" -> pure Writable + "rw" -> pure ReadableWritable + s -> fail $ "Invalid script input type: " <> Text.unpack s + +instance ToJSON ScriptInputType where + toJSON = + String . \case + Readable -> "r" + Writable -> "w" + ReadableWritable -> "rw" + -- | Information about execution time and resource usage. This is saved by -- @inferno-ml-server@ after script evaluation completes and can be queried -- later by using the same job identifier that was provided to the @/inference@ @@ -762,22 +803,42 @@ instance FromJSON IValue where Number n -> pure . IDouble $ toRealFloat n -- It's easier to just mark the time explicitly in an object, -- rather than try to deal with distinguishing times and doubles - Object o -> ITime <$> o .: "time" + Object o -> + asum + [ ITime <$> o .: "time", + fmap IArray $ arrayP =<< o .: "array" + ] + -- Note that this preserves a plain JSON array for tuples. But we need + -- some straightforward way of distinguishing tuples and arrays; since + -- the bridge often transmits a large number of individual tuples (times + -- and values), it's better to use arrays for the tuples and a tagged object + -- for arrays themselves; we often will only deal with one large array, and + -- adding a few bytes to this is better than adding a few bytes to thousands + -- of encoded tuples Array a | [x, y] <- Vector.toList a -> - (,) <$> parseJSON x <*> parseJSON y <&> \case - -- We don't want to confuse a two-element array of tuples with - -- a tuple itself. For example, `"[[10.0, {\"time\": 10}], [10.0, {\"time\": 10}]]"` - -- should parse as a two-element array of `(double, time)` tuples, - -- not as a `((double, time), (double, time))`. I can't think of - -- any reason to support the latter. An alternative would be to - -- change tuple encoding to an object, but then we would be transmitting - -- a much larger amount of data on most requests - (f@(ITuple _), s@(ITuple _)) -> IArray $ Vector.fromList [f, s] - t -> ITuple t - | otherwise -> IArray <$> traverse (parseJSON @IValue) a + fmap ITuple $ (,) <$> parseJSON x <*> parseJSON y + | otherwise -> fail "Only two-element tuples are supported" Null -> pure IEmpty - _ -> fail "Expected one of: string, double, time, tuple, unit (empty array), array" + _ -> fail "Expected one of: string, double, time, tuple, null, array" + where + arrayP :: Vector Value -> Parser (Vector IValue) + arrayP a = + -- This is a bit tedious, but we want to make sure that the array elements + -- are homogeneous; parsing all elements to `IValue`s first can't guarantee + -- this + asum + [ -- This alternative means that `null` will be correctly parsed to NaN + -- when inside an array of doubles + fmap IDouble <$> traverse parseJSON a, + fmap ITuple <$> traverse parseJSON a, + fmap IText <$> traverse parseJSON a, + fmap ITime <$> traverse (withObject "EpochTime" (.: "time")) a, + -- Nested array support + fmap IArray + <$> traverse (withObject "IArray" (arrayP <=< (.: "array"))) a, + fail "Expected a heterogeneous array" + ] instance ToJSON IValue where toJSON = \case @@ -786,8 +847,9 @@ instance ToJSON IValue where ITuple t -> toJSON t -- See `FromJSON` instance above ITime t -> object ["time" .= t] + -- See `FromJSON` instance above + IArray is -> object ["array" .= is] IEmpty -> toJSON Null - IArray is -> toJSON is -- | Used to represent inputs to the script. 'Many' allows for an array input data SingleOrMany a diff --git a/inferno-ml-server/CHANGELOG.md b/inferno-ml-server/CHANGELOG.md index f692cbf9..34e71fe0 100644 --- a/inferno-ml-server/CHANGELOG.md +++ b/inferno-ml-server/CHANGELOG.md @@ -1,3 +1,6 @@ +## 2023.5.29 +* Change representation of script inputs/outputs + ## 2023.5.22 * Add support for tracking evaluation info diff --git a/inferno-ml-server/exe/Dummy.hs b/inferno-ml-server/exe/Dummy.hs index e5ecd68a..68e1043b 100644 --- a/inferno-ml-server/exe/Dummy.hs +++ b/inferno-ml-server/exe/Dummy.hs @@ -87,7 +87,9 @@ valueAt _ p t = <&> maybe IEmpty IDouble . preview (at p . _Just . at t . _Just) latestValueAndTimeBefore :: Int -> PID -> DummyM IValue -latestValueAndTimeBefore _ _ = throwIO $ userError "Unsupported" +latestValueAndTimeBefore _ _ = + throwIO $ + userError "Unsupported: latestValueAndTimeBefore" valuesBetween :: Int64 -> PID -> Int -> Int -> ReaderT DummyEnv IO IValue -valuesBetween _ _ _ _ = throwIO $ userError "Unsupported" +valuesBetween _ _ _ _ = throwIO $ userError "Unsupported: valuesBetween" diff --git a/inferno-ml-server/exe/ParseAndSave.hs b/inferno-ml-server/exe/ParseAndSave.hs index 38141841..1baf63f0 100644 --- a/inferno-ml-server/exe/ParseAndSave.hs +++ b/inferno-ml-server/exe/ParseAndSave.hs @@ -13,13 +13,14 @@ module ParseAndSave (main) where import Control.Category ((>>>)) import Control.Exception (Exception (displayException)) import Control.Monad (void) +import Data.Aeson (eitherDecode) import Data.ByteString (ByteString) import qualified Data.ByteString.Char8 as Char8 +import qualified Data.ByteString.Lazy.Char8 as Lazy.Char8 +import Data.Map.Strict (Map) import Data.Text (Text) import qualified Data.Text.IO as Text.IO import Data.Time.Clock.POSIX (getPOSIXTime) -import Data.Vector (Vector) -import qualified Data.Vector as Vector import Database.PostgreSQL.Simple ( Connection, Query, @@ -37,7 +38,7 @@ import Inferno.Core import Inferno.ML.Server.Module.Prelude (mkBridgePrelude) import Inferno.ML.Server.Types import Inferno.ML.Types.Value (customTypes) -import Inferno.Types.Syntax (Expr, TCScheme) +import Inferno.Types.Syntax (Expr, Ident, TCScheme) import Inferno.Types.VersionControl ( Pinned, VCObjectHash, @@ -50,35 +51,37 @@ import Inferno.VersionControl.Types ) import System.Environment (getArgs) import System.Exit (die) -import Text.Read (readMaybe) import UnliftIO.Exception (bracket, throwString) main :: IO () main = getArgs >>= \case - scriptp : p : conns : _ -> - maybe - (throwString "Invalid PID") - (parseAndSave scriptp (Char8.pack conns) . PID) - $ readMaybe p - _ -> die "Usage ./parse " - -parseAndSave :: FilePath -> ByteString -> PID -> IO () -parseAndSave p conns pid = do + scriptp : pstr : conns : _ -> + either throwString (parseAndSave scriptp (Char8.pack conns)) + . eitherDecode + $ Lazy.Char8.pack pstr + _ -> die "Usage ./parse " + +parseAndSave :: + FilePath -> + ByteString -> + Map Ident (SingleOrMany PID, ScriptInputType) -> + IO () +parseAndSave p conns inputs = do t <- Text.IO.readFile p now <- fromIntegral @Int . round <$> getPOSIXTime ast <- either (throwString . displayException) pure . (`parse` t) =<< mkInferno @_ @BridgeMlValue (mkBridgePrelude funs) customTypes - bracket (connectPostgreSQL conns) close (saveScriptAndParam ast now pid) + bracket (connectPostgreSQL conns) close (saveScriptAndParam ast now inputs) saveScriptAndParam :: (Expr (Pinned VCObjectHash) (), TCScheme) -> CTime -> - PID -> + Map Ident (SingleOrMany PID, ScriptInputType) -> Connection -> IO () -saveScriptAndParam x now pid conn = insertScript *> insertParam +saveScriptAndParam x now inputs conn = insertScript *> insertParam where insertScript :: IO () insertScript = @@ -96,17 +99,15 @@ saveScriptAndParam x now pid conn = insertScript *> insertParam . InferenceParam Nothing hash + -- Bit of a hack. We only have one model version in the + -- tests, so we can just hard-code the ID here (Id 1) inputs - mempty Nothing $ entityIdFromInteger 0 where q :: Query - q = [sql| INSERT INTO params VALUES (?, ?, ?, ?, ?, ?, ?) |] - - inputs :: Vector (SingleOrMany PID) - inputs = Vector.singleton $ Single pid + q = [sql| INSERT INTO params VALUES (?, ?, ?, ?, ?, ?) |] vcfunc :: VCObject vcfunc = uncurry VCFunction x diff --git a/inferno-ml-server/inferno-ml-server.cabal b/inferno-ml-server/inferno-ml-server.cabal index e4e64dc5..ef746678 100644 --- a/inferno-ml-server/inferno-ml-server.cabal +++ b/inferno-ml-server/inferno-ml-server.cabal @@ -1,6 +1,6 @@ cabal-version: 2.4 name: inferno-ml-server -version: 2023.5.22 +version: 2023.5.29 synopsis: Server for Inferno ML description: Server for Inferno ML homepage: https://github.com/plow-technologies/inferno.git#readme @@ -120,7 +120,7 @@ executable tests executable test-client import: common main-is: Client.hs - hs-source-dirs: exe + hs-source-dirs: test ghc-options: -threaded -rtsopts -main-is Client build-depends: , aeson @@ -166,8 +166,10 @@ executable parse-and-save hs-source-dirs: exe ghc-options: -threaded -rtsopts -main-is ParseAndSave build-depends: + , aeson , base , bytestring + , containers , inferno-core , inferno-ml , inferno-ml-server diff --git a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs index 8b993602..ee6e50ab 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs @@ -31,7 +31,6 @@ import Data.Time (UTCTime, getCurrentTime) import Data.Time.Clock.POSIX (getPOSIXTime) import Data.Traversable (for) import Data.UUID (UUID) -import qualified Data.Vector as Vector import Data.Word (Word64) import Database.PostgreSQL.Simple ( Only (Only), @@ -92,12 +91,14 @@ import UnliftIO.IORef (readIORef) import UnliftIO.MVar (putMVar, takeMVar, withMVar) import UnliftIO.Timeout (timeout) --- Run an inference param, locking the `MVar` held in the `Env`. This is to avoid --- running params in parallel (which may lead to problems with model caching, --- etc...) and also to indicate whether any active process is running in the --- `status` endpoint (using `tryTakeMVar`) +-- | Run an inference param, locking the @MVar@ held in the 'Env'. This is to +-- avoid running params in parallel (which may lead to problems with model +-- caching, etc...) and also to indicate whether any active process is running +-- in the @/status@ endpoint (using @tryTakeMVar@) runInferenceParam :: Id InferenceParam -> + -- | Optional resolution, defaulting to 128. This is needed in case the + -- parameter evaluates a script that calls e.g. @valueAt@ Maybe Int64 -> UUID -> RemoteM (WriteStream IO) @@ -167,9 +168,17 @@ runInferenceParam ipid (fromMaybe 128 -> res) uuid = ( mkIdent i, pids ^.. each & over mapped toSeries & VArray ) - where - ps :: [SingleOrMany PID] - ps = param ^.. #inputs . each + + -- Note that this both includes inputs (i.e. readable) + -- and outputs (i.e. writable, or readable/writable). + -- These need to be provided to the script in order + -- for the symbolic identifer (e.g. `output0`) to + -- resolve. We can discard the input type here, + -- however. The distinction is only relevant for the + -- runtime that runs as a script evaluation engine + -- and commits the output write object + ps :: [SingleOrMany PID] + ps = param ^.. #inputs . to Map.toAscList . each . _2 . _1 closure :: Map VCObjectHash VCObject closure = @@ -185,9 +194,10 @@ runInferenceParam ipid (fromMaybe 128 -> res) uuid = ) args where + -- See note above about inputs/outputs args :: [Expr (Maybe a) ()] args = - [0 .. param ^. #inputs & Vector.length & (- 1)] + [0 .. length ps - 1] <&> Var () Nothing LocalScope . Expl . ExtIdent @@ -198,18 +208,10 @@ runInferenceParam ipid (fromMaybe 128 -> res) uuid = dummy :: ImplExpl dummy = Expl . ExtIdent $ Right "dummy" - doEval expr =<< runImplEnvM mempty (mkEnvFromClosure localEnv closure) + either (throwInfernoError . Left . SomeInfernoError) yieldPairs + =<< flip (`evalExpr` implEnv) expr + =<< runImplEnvM mempty (mkEnvFromClosure localEnv closure) where - doEval :: - Expr (Maybe VCObjectHash) () -> - BridgeTermEnv RemoteM -> - RemoteM (WriteStream IO) - doEval x env = - either - (throwInfernoError . Left . SomeInfernoError) - yieldPairs - =<< evalExpr env implEnv x - yieldPairs :: Value BridgeMlValue (ImplEnvM RemoteM BridgeMlValue) -> RemoteM (WriteStream IO) diff --git a/inferno-ml-server/src/Inferno/ML/Server/Types.hs b/inferno-ml-server/src/Inferno/ML/Server/Types.hs index 069cc8a9..c2d4ef30 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Types.hs @@ -43,12 +43,12 @@ import Data.Data (Typeable) import Data.Generics.Labels () import Data.Generics.Wrapped (wrappedTo) import Data.Int (Int64) +import Data.Map.Strict (Map) import Data.Text (Text) import qualified Data.Text as Text import qualified Data.Text.Read as Text.Read import Data.Time (UTCTime) import Data.UUID (UUID) -import Data.Vector (Vector) import Data.Word (Word64) import Data.Yaml (decodeFileThrow) import Database.PostgreSQL.Simple @@ -77,6 +77,7 @@ import "inferno-ml-server-types" Inferno.ML.Server.Types as M hiding ModelVersion, ) import qualified "inferno-ml-server-types" Inferno.ML.Server.Types as Types +import Inferno.Types.Syntax (Ident) import Inferno.VersionControl.Types ( VCObject, VCObjectHash, @@ -372,13 +373,12 @@ pattern InferenceParam :: Maybe (Id InferenceParam) -> VCObjectHash -> Id ModelVersion -> - Vector (SingleOrMany PID) -> - Vector (SingleOrMany PID) -> + Map Ident (SingleOrMany PID, ScriptInputType) -> Maybe UTCTime -> EntityId UId -> InferenceParam -pattern InferenceParam iid s m is os mt uid = - Types.InferenceParam iid s m is os mt uid +pattern InferenceParam iid s m ios mt uid = + Types.InferenceParam iid s m ios mt uid pattern VCMeta :: CTime -> diff --git a/inferno-ml-server/exe/Client.hs b/inferno-ml-server/test/Client.hs similarity index 77% rename from inferno-ml-server/exe/Client.hs rename to inferno-ml-server/test/Client.hs index 3485220a..88518056 100644 --- a/inferno-ml-server/exe/Client.hs +++ b/inferno-ml-server/test/Client.hs @@ -1,5 +1,4 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeApplications #-} -- NOTE -- This executable is only intended for testing the inference endpoint with the @@ -10,8 +9,6 @@ module Client (main) where import Conduit import Control.Monad (unless, void) import Data.Coerce (coerce) -import qualified Data.Conduit.List as Conduit.List -import Data.Function (on) import Data.Int (Int64) import qualified Data.Map as Map import Inferno.ML.Server.Client (inferenceC, registerBridgeC) @@ -58,13 +55,13 @@ main = _ -> die "Usage: test-client " -- Check that the returned write stream matches the expected value -verifyWrites :: - Int64 -> - WriteStream IO -> - IO () +verifyWrites :: Int64 -> WriteStream IO -> IO () verifyWrites ipid c = do expected <- getExpected - result <- rebuildWrites + -- Note that there is only one chunk per PID in the output stream, so we + -- don't need to concatenate the results by PID. We can just sink it into + -- a list directly + result <- runConduit $ c .| sinkList unless (result == expected) . throwString . unwords $ [ "Expected: ", show expected, @@ -74,19 +71,21 @@ verifyWrites ipid c = do show ipid ] where - rebuildWrites :: IO [(Int, [(EpochTime, IValue)])] - rebuildWrites = - runConduit $ - c - .| Conduit.List.groupBy ((==) `on` fst) - .| Conduit.List.concat - .| sinkList - getExpected :: IO [(Int, [(EpochTime, IValue)])] getExpected = maybe (throwString "Missing PID") pure . Map.lookup ipid $ Map.fromList - [ (1, [(1, [(151, IDouble 2.5), (251, IDouble 3.5)])]), - (2, [(2, [(300, IDouble 25.0)])]), - (3, [(3, [(100, IDouble 7.0)])]) + [ ( 1, + [ (1, [(151, IDouble 2.5), (251, IDouble 3.5)]) + ] + ), + ( 2, + [ (2, [(300, IDouble 25.0)]) + ] + ), + ( 3, + [ (3, [(100, IDouble 7.0)]), + (4, [(100, IDouble 8.0)]) + ] + ) ] diff --git a/nix/inferno-ml/migrations/v1-create-tables.sql b/nix/inferno-ml/migrations/v1-create-tables.sql index 738aedbd..4d0bdbe4 100644 --- a/nix/inferno-ml/migrations/v1-create-tables.sql +++ b/nix/inferno-ml/migrations/v1-create-tables.sql @@ -67,8 +67,10 @@ create table if not exists params -- Script hash from `inferno-vc` , script bytea not null references scripts (id) , model integer references mversions (id) + -- Strictly speaking, this includes both inputs and outputs. The + -- corresponding Haskell type contains `(p, ScriptInputType)`, with + -- the second element determining readability and writability , inputs jsonb not null - , outputs jsonb not null -- See note above , terminated timestamptz , "user" integer references users (id) diff --git a/nix/inferno-ml/tests/scripts/mnist.inferno b/nix/inferno-ml/tests/scripts/mnist.inferno index 787ba93d..d5c16d16 100644 --- a/nix/inferno-ml/tests/scripts/mnist.inferno +++ b/nix/inferno-ml/tests/scripts/mnist.inferno @@ -1,4 +1,4 @@ -fun input0 -> +fun input0 input1 -> let input = ML.asTensor4 ML.#float [[[ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], @@ -34,7 +34,7 @@ fun input0 -> match ML.forward model [input] with { | [scores] -> let m = ML.toType ML.#double (ML.argmax 1 #false scores) in - [makeWrites input0 [(t, ML.asDouble m)]] + [makeWrites input0 [(t, ML.asDouble m)], makeWrites input1 [(t, ML.asDouble m + 1.0)]] | _ -> - [makeWrites input0 [(t, -1.0)]] + [makeWrites input0 [(t, -1.0)], makeWrites input1 [(t, -1.0)]] } diff --git a/nix/inferno-ml/tests/server.nix b/nix/inferno-ml/tests/server.nix index 955af67d..31b7962c 100644 --- a/nix/inferno-ml/tests/server.nix +++ b/nix/inferno-ml/tests/server.nix @@ -89,14 +89,22 @@ pkgs.nixosTest { # with an associated inference param for each text = let - dbstr = - "host='127.0.0.1' dbname='inferno' " - + "user='inferno' password=''"; + dbstr = "host='127.0.0.1' dbname='inferno' user='inferno' password=''"; + ios = + builtins.mapAttrs (_: builtins.toJSON) { + ones = { input0 = [ 1 "rw" ]; }; + contrived = { input0 = [ 2 "rw" ]; }; + # This test uses two outputs + mnist = { + input0 = [ 3 "rw" ]; + input1 = [ 4 "w" ]; + }; + }; in '' - parse-and-save ${./scripts/ones.inferno} 1 ${dbstr} - parse-and-save ${./scripts/contrived.inferno} 2 ${dbstr} - parse-and-save ${./scripts/mnist.inferno} 3 ${dbstr} + parse-and-save ${./scripts/ones.inferno} '${ios.ones}' ${dbstr} + parse-and-save ${./scripts/contrived.inferno} '${ios.contrived}' ${dbstr} + parse-and-save ${./scripts/mnist.inferno} '${ios.mnist}' ${dbstr} ''; } ) @@ -197,8 +205,6 @@ pkgs.nixosTest { def runtest(param): # Runs an test for an individual param using the client executable, # which confirms that the results are correct - # - # Note: The inference param DB ID and the associated PID are the same number node.succeed(f'run-inference-client-test {param}') node.wait_for_unit("multi-user.target")