[inferno-ml] Link multiple models to params #124

Merged on Jun 7, 2024 (8 commits)
3 changes: 3 additions & 0 deletions inferno-ml-server-types/CHANGELOG.md
@@ -1,6 +1,9 @@
# Revision History for inferno-ml-server-types
*Note*: we use https://pvp.haskell.org/ (MAJOR.MAJOR.MINOR.PATCH)

## 0.6.0
* Support linking multiple models to inference parameters

## 0.5.0
* Add `resolution` to `InferenceParam`

2 changes: 1 addition & 1 deletion inferno-ml-server-types/inferno-ml-server-types.cabal
@@ -1,6 +1,6 @@
cabal-version: 2.4
name: inferno-ml-server-types
version: 0.5.0
version: 0.6.0
synopsis: Types for Inferno ML server
description: Types for Inferno ML server
homepage: https://github.com/plow-technologies/inferno.git#readme
21 changes: 15 additions & 6 deletions inferno-ml-server-types/src/Inferno/ML/Server/Types.hs
@@ -600,9 +600,16 @@ data InferenceParam uid gid p s = InferenceParam
-- For existing inference params, this is the foreign key for the specific
-- script in the 'InferenceScript' table (i.e. a @VCObjectHash@)
script :: s,
-- | This needs to be linked to a specific version of a model rather
-- than the @model@ table itself
model :: Id (ModelVersion uid gid Oid),
-- | All of the (specific versions of) models that can be used with this
-- parameter. @inferno-ml-server@ will copy the contents of each of the
-- model versions when evaluating inference scripts. Inference scripts can
-- reference any of the linked model versions by referring to the parent
-- model\'s name, e.g. @loadModel "name.ts.pt"@
--
-- Each element represents the ID of a specific model version. However, due
-- to limitations in PostgreSQL, there is no referential integrity; i.e.
-- the elements are treated as plain integers
Review comment on lines +609 to +611

Contributor:
What PostgreSQL limitations prevent referential integrity? AFAICT this looks like a many-to-one relation between modelVersion and inferenceParam, which can be modeled as an inferenceParam foreign key on the modelVersion table.

Contributor Author:
The reason I phrased it like that is that arrays of foreign keys are in a recent SQL standard, but Postgres doesn't support them.

models :: Vector (Id (ModelVersion uid gid Oid)),
-- | This is called @inputs@ but is also used for script outputs as
-- well. The access (input or output) is controlled by the 'ScriptInputType'.
-- For example, if this field is set to @[("input0", Single (p, Readable))]@,
@@ -629,7 +636,9 @@ data InferenceParam uid gid p s = InferenceParam
instance
( FromJSON p,
Typeable p,
FromField uid
FromField uid,
Typeable gid,
Typeable uid
) =>
FromRow (InferenceParam uid gid p VCObjectHash)
where
@@ -653,7 +662,7 @@ instance
toRow ip =
[ toField Default,
ip ^. the @"script" & VCObjectHashRow & toField,
ip ^. the @"model" & toField,
ip ^. the @"models" & toField,
ip ^. the @"inputs" & Aeson & toField,
ip ^. the @"resolution" & Aeson & toField,
toField Default,
@@ -781,7 +790,7 @@ fromIPv4 :: IPv4 -> (Int, Int, Int, Int)
fromIPv4 =
wrappedTo >>> Data.IP.fromIPv4 >>> \case
[a, b, c, d] -> (a, b, c, d)
-- Should not happen
-- Should not happen, `fromIPv4` always produces a 4-element list
_ -> error "Invalid IP address"

-- Bridge-related types
5 changes: 5 additions & 0 deletions inferno-ml-server/CHANGELOG.md
@@ -1,3 +1,8 @@
# Revision History for `inferno-ml-server`

## 2023.6.5
* Support linking multiple models to inference parameters

## 2023.6.1
* Add `resolution` to `InferenceParam`

3 changes: 2 additions & 1 deletion inferno-ml-server/exe/ParseAndSave.hs
@@ -21,6 +21,7 @@ import Data.Map.Strict (Map)
import Data.Text (Text)
import qualified Data.Text.IO as Text.IO
import Data.Time.Clock.POSIX (getPOSIXTime)
import qualified Data.Vector as Vector
import Database.PostgreSQL.Simple
( Connection,
Query,
@@ -101,7 +102,7 @@ saveScriptAndParam x now inputs conn = insertScript *> insertParam
hash
-- Bit of a hack. We only have one model version in the
-- tests, so we can just hard-code the ID here
(Id 1)
(Vector.singleton (Id 1))
inputs
128
Nothing
4 changes: 3 additions & 1 deletion inferno-ml-server/inferno-ml-server.cabal
@@ -1,6 +1,6 @@
cabal-version: 2.4
name: inferno-ml-server
version: 2023.6.1
version: 2023.6.5
synopsis: Server for Inferno ML
description: Server for Inferno ML
homepage: https://github.com/plow-technologies/inferno.git#readme
@@ -114,8 +114,10 @@ executable tests
, inferno-ml-server
, microlens-platform
, mtl
, plow-log
, text
, unliftio
, vector

executable test-client
import: common
4 changes: 2 additions & 2 deletions inferno-ml-server/src/Inferno/ML/Server.hs
@@ -81,7 +81,7 @@ main = runServer =<< mkOptions

runInEnv :: Config -> (Env -> IO ()) -> IO ()
runInEnv cfg f = withRemoteTracer $ \tracer -> do
traceWith tracer StartingServer
traceWith tracer $ InfoTrace StartingServer
withConnect (view #store cfg) $ \conn ->
f
=<< Env cfg conn tracer
@@ -159,4 +159,4 @@ server =
=<< view #job
where
logAndCancel :: (Id InferenceParam, Async (Maybe (WriteStream IO))) -> RemoteM ()
logAndCancel (i, j) = logTrace (CancelingInference i) *> cancel j
logAndCancel (i, j) = logWarn (CancelingInference i) *> cancel j
2 changes: 1 addition & 1 deletion inferno-ml-server/src/Inferno/ML/Server/Bridge.hs
@@ -36,7 +36,7 @@ import UnliftIO.IORef (atomicWriteIORef, readIORef)
-- data from\/to the data source)
registerBridgeInfo :: BridgeInfo -> RemoteM ()
registerBridgeInfo bi = do
logTrace $ RegisteringBridge bi
logInfo $ RegisteringBridge bi
liftIO $ encodeFile bridgeCache bi
(`atomicWriteIORef` Just bi) =<< view (#bridge . #info)
interpreter <- mkInferno @_ @BridgeMlValue (mkBridgePrelude funs) customTypes
45 changes: 23 additions & 22 deletions inferno-ml-server/src/Inferno/ML/Server/Inference.hs
@@ -7,20 +7,20 @@

module Inferno.ML.Server.Inference
( runInferenceParam,
getAndCacheModel,
getAndCacheModels,
linkVersionedModel,
)
where

import Conduit (ConduitT, awaitForever, mapC, yieldMany, (.|))
import Control.Monad (void, when, (<=<))
import Control.Monad.Catch (throwM)
import Control.Monad.Extra (loopM, unlessM, whenM)
import Control.Monad.Extra (loopM, unlessM, whenJust, whenM)
import Control.Monad.IO.Class (MonadIO (liftIO))
import Control.Monad.ListM (sortByM)
import Data.Bifoldable (bitraverse_)
import Data.Conduit.List (chunksOf, sourceList)
import Data.Foldable (foldl')
import Data.Foldable (foldl', traverse_)
import Data.Generics.Wrapped (wrappedTo)
import Data.Int (Int64)
import Data.Map (Map)
@@ -30,6 +30,7 @@ import Data.Time (UTCTime, getCurrentTime)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Data.Traversable (for)
import Data.UUID (UUID)
import Data.Vector (Vector)
import Data.Word (Word64)
import Database.PostgreSQL.Simple
( Only (Only),
@@ -104,7 +105,7 @@ runInferenceParam ::
RemoteM (WriteStream IO)
runInferenceParam ipid mres uuid =
withTimeoutMillis $ \t -> do
logTrace $ RunningInference ipid t
logInfo $ RunningInference ipid t
maybe (throwM (ScriptTimeout t)) pure
=<< (`withMVar` const (run t))
=<< view #lock
@@ -139,8 +140,9 @@ runInferenceParam ipid mres uuid =
-- need to be updated to use an absolute path to a versioned model,
-- e.g. `loadModel "~/inferno/.cache/..."`)
withCurrentDirectory (view #path cache) $ do
logTrace $ EvaluatingScript ipid
linkVersionedModel =<< getAndCacheModel cache (view #model param)
logInfo $ EvaluatingScript ipid
traverse_ linkVersionedModel
=<< getAndCacheModels cache (view #models param)
runEval interpreter param t obj
where
runEval ::
@@ -319,7 +321,7 @@ runInferenceParam ipid mres uuid =
-- an inconvenience than a fatal error
logAndIgnore :: SqlError -> RemoteM ()
logAndIgnore =
logTrace
logWarn
. OtherWarn
. ("Failed to save eval info: " <>)
. Text.pack
@@ -361,29 +363,28 @@ getParameter iid =
LIMIT 1
|]

-- | First retrieves the specified model version from the database, then fetches
-- the associated parent model. The contents of the model version are retrieved
-- (the Postgres large object), then copied to the model cache if it has not
-- yet been cached. Older previously saved model versions(s) are evicted if the
-- cache 'maxSize' is exceeded by adding the model version contents
-- | For all of the model version IDs declared in the param, fetch the model
-- version and the parent model, and then cache them
--
-- The contents of the model version are retrieved (the Postgres large object),
-- then copied to the model cache if it has not yet been cached. Previously
-- saved model version(s) are evicted if the cache 'maxSize' is exceeded by
-- adding the model version contents; this is based on access time
--
-- NOTE: This action assumes that the current working directory is the model
-- cache! It can be run using e.g. 'withCurrentDirectory'
getAndCacheModel :: ModelCache -> Id ModelVersion -> RemoteM FilePath
getAndCacheModel cache mid = do
-- Both the individual version is required (in order to fetch the contents)
-- as well as the parent model row (for the model name)
mversion <- getModelVersion mid
model <- getModel $ view #model mversion
copyAndCache model mversion
getAndCacheModels ::
ModelCache -> Vector (Id ModelVersion) -> RemoteM (Vector FilePath)
getAndCacheModels cache =
traverse (uncurry copyAndCache) <=< getModelsAndVersions
where
copyAndCache :: Model -> ModelVersion -> RemoteM FilePath
copyAndCache model mversion =
versioned <$ do
unlessM (doesPathExist versioned) $ do
logTrace $ CopyingModel mid
mversion ^. #id & (`whenJust` logInfo . CopyingModel)
bitraverse_ checkCacheSize (writeBinaryFileDurableAtomic versioned)
=<< getModelSizeAndContents (view #contents mversion)
=<< getModelVersionSizeAndContents (view #contents mversion)
where
-- Cache the model with its specific version, i.e.
-- `<name>.ts.pt.<version>`, which will later be
@@ -425,7 +426,7 @@ getAndCacheModel cache mid = do
tryRemoveFile :: RemoteM ()
tryRemoveFile =
catchIO (removeFile m) $
logTrace
logWarn
. OtherWarn
. Text.pack
. displayException
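
The doc comment on `getAndCacheModels` says that previously saved model versions are evicted, based on access time, when adding new contents would exceed the cache `maxSize`. The eviction code itself is not visible in this diff, so the following is only a minimal sketch of that idea as a pure function. The `Entry` type, its fields, `toEvict`, and the file names are all hypothetical, and evicting the least recently accessed entries first is an assumption, not something the diff confirms.

```haskell
module Main (main) where

import Data.List (sortOn)

-- | Hypothetical cache entry (not a type from this PR): file path, size in
-- bytes, and last access time as seconds since an arbitrary epoch.
data Entry = Entry
  { entryPath :: FilePath
  , entrySize :: Integer
  , entryAccessed :: Integer
  }
  deriving (Show, Eq)

-- | Pick the entries to evict so that @incoming@ bytes fit under @maxSize@.
-- Assumption: the least recently accessed entries are evicted first.
-- Returns an empty list when the incoming contents already fit.
toEvict :: Integer -> Integer -> [Entry] -> [Entry]
toEvict maxSize incoming entries = go used (sortOn entryAccessed entries)
  where
    -- Total bytes currently held in the cache
    used = sum (map entrySize entries)

    -- Keep evicting while the projected total is still over the limit;
    -- when the guard fails, matching falls through to the final clause
    go total (e : es)
      | total + incoming > maxSize = e : go (total - entrySize e) es
    go _ _ = []

main :: IO ()
main = do
  let cache =
        [ Entry "a.ts.pt.v1" 40 100
        , Entry "b.ts.pt.v1" 30 300
        , Entry "c.ts.pt.v2" 20 200
        ]
  -- 90 bytes are cached, 100 are allowed, and 25 are incoming, so evicting
  -- the oldest entry (40 bytes) is enough
  print (map entryPath (toEvict 100 25 cache))
```

Keeping the decision pure makes it easy to test separately from the file removal, which the PR performs in `RemoteM` and logs with `logWarn` on failure.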
65 changes: 28 additions & 37 deletions inferno-ml-server/src/Inferno/ML/Server/Inference/Model.hs
@@ -2,17 +2,18 @@
{-# LANGUAGE ScopedTypeVariables #-}

module Inferno.ML.Server.Inference.Model
( getModel,
getModelVersion,
getModelSizeAndContents,
( getModelsAndVersions,
getModelVersionSizeAndContents,
)
where

import Control.Monad.IO.Class (MonadIO (liftIO))
import Data.ByteString (ByteString)
import Data.Generics.Wrapped (wrappedTo)
import Data.Foldable (toList)
import Data.Vector (Vector)
import Database.PostgreSQL.Simple
( Only (Only, fromOnly),
( In (In),
Only (Only, fromOnly),
Query,
withTransaction,
)
@@ -33,47 +34,37 @@ import Lens.Micro.Platform
import UnliftIO (MonadUnliftIO (withRunInIO))
import UnliftIO.Exception (bracket)

-- | Get the model row itself. This is to access things like the name,
-- permissions, etc... that are not contained in the model version table
-- (see 'getModelVersion')
getModel :: Id Model -> RemoteM Model
getModel mid =
firstOrThrow (NoSuchModel (wrappedTo mid))
=<< queryStore q (Only mid)
-- | Get an array of model versions along with the parent model of each; note
-- that this does not retrieve the model version contents -- each version only
-- contains the 'Oid' of the large object
getModelsAndVersions ::
Vector (Id ModelVersion) -> RemoteM (Vector (Model, ModelVersion))
getModelsAndVersions =
fmap (fmap joinToTuple)
. queryStore q
. Only
. In
. toList
where
q :: Query
q =
[sql|
SELECT * FROM models WHERE id = ?
AND terminated IS NULL
|]

-- Get a row from the model versions table, which contains the actual contents,
-- description, etc... The foreign key of the version row can be used to get
-- the invariant model metadata (see 'getModel')
--
-- This does not include the actual contents of the model, which need to be
-- fetched separately using 'loImport'
getModelVersion :: Id ModelVersion -> RemoteM ModelVersion
getModelVersion mid =
firstOrThrow (NoSuchModel (wrappedTo mid))
=<< queryStore q (Only mid)
where
q :: Query
q =
[sql|
SELECT * FROM mversions WHERE id = ?
AND terminated IS NULL
SELECT M.*, V.*
FROM mversions V
INNER JOIN models M ON V.model = M.id
WHERE V.id IN ?
AND V.terminated IS NULL
AND M.terminated IS NULL
|]

-- | Get the actual serialized bytes of the model, which is stored in the Postgres
-- large object table (and must be explicitly imported using 'loImport'), along
-- with the number of bytes
getModelSizeAndContents :: Oid -> RemoteM (Integer, ByteString)
getModelSizeAndContents m =
getModelVersionSizeAndContents :: Oid -> RemoteM (Integer, ByteString)
getModelVersionSizeAndContents m =
view #store >>= \conn -> withRunInIO $ \r ->
withTransaction conn . r $ do
size <- getModelSize m
size <- getModelVersionSize m
bs <-
liftIO
. bracket (loOpen conn m ReadMode) (loClose conn)
@@ -84,8 +75,8 @@ getModelSizeAndContents m =
-- | Get the size of the model contents themselves (byte count of large object).
-- It is better to do this via Postgres rather than using @ByteString.length@
-- on the returned bytes
getModelSize :: Oid -> RemoteM Integer
getModelSize oid =
getModelVersionSize :: Oid -> RemoteM Integer
getModelVersionSize oid =
fmap fromOnly $
firstOrThrow (OtherRemoteError "Could not get model size")
=<< queryStore q (Only oid)
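
The removed `getModel` and `getModelVersion` threw `NoSuchModel` when a row was missing, whereas the new `getModelsAndVersions` query filters with `WHERE V.id IN ?`, and an SQL `IN` filter returns rows only for IDs that match. Combined with the lack of referential integrity on the `models` array, a stale ID would then yield no row rather than an error. Whether the PR guards against this elsewhere is not visible in this diff. A minimal sketch of an application-side check follows, with IDs simplified to `Int`; `missingIds` is a hypothetical helper and not part of the PR.

```haskell
module Main (main) where

import Data.List (nub)

-- | Hypothetical helper (not part of this PR): the IDs that were requested
-- but did not come back from the lookup, in request order and without
-- duplicates.
missingIds :: [Int] -> [Int] -> [Int]
missingIds requested found = filter (`notElem` found) (nub requested)

main :: IO ()
main = do
  -- ID 2 was requested but no row came back for it
  print (missingIds [1, 2, 3] [1, 3])
  -- Every requested ID was found, regardless of row order
  print (missingIds [1, 2] [2, 1])
```

A caller could compare the requested IDs with the IDs of the returned versions and raise an error when the result is non-empty, restoring the behavior of the earlier per-row lookups.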