[inferno-ml] Save evaluation info #121

Merged · 7 commits · May 22, 2024
3 changes: 3 additions & 0 deletions inferno-ml-server-types/CHANGELOG.md
@@ -1,6 +1,9 @@
# Revision History for inferno-ml-server-types
*Note*: we use https://pvp.haskell.org/ (MAJOR.MAJOR.MINOR.PATCH)

## 0.3.0
* Add support for tracking evaluation info

## 0.2.0
* Add `terminated` columns for DB types

3 changes: 2 additions & 1 deletion inferno-ml-server-types/inferno-ml-server-types.cabal
@@ -1,6 +1,6 @@
cabal-version: 2.4
name: inferno-ml-server-types
version: 0.2.0
version: 0.3.0
synopsis: Types for Inferno ML server
description: Types for Inferno ML server
homepage: https://github.com/plow-technologies/inferno.git#readme
@@ -64,4 +64,5 @@ library
, unix
, uri-bytestring
, uri-bytestring-aeson
, uuid
, vector
12 changes: 12 additions & 0 deletions inferno-ml-server-types/src/Inferno/ML/Server/Client.hs
@@ -13,6 +13,7 @@ where

import Data.Int (Int64)
import Data.Proxy (Proxy (Proxy))
import Data.UUID (UUID)
import Inferno.ML.Server.Types
import Servant ((:<|>) ((:<|>)))
import Servant.Client.Streaming (ClientM, client)
@@ -23,8 +24,19 @@ statusC :: ClientM (Maybe ())

-- | Run an inference parameter
inferenceC ::
-- | SQL identifier of the inference parameter to be run
Id (InferenceParam uid gid p s) ->
-- | Optional resolution for scripts that use e.g. @valueAt@; defaults to
-- 128 if not specified
Maybe Int64 ->
-- | Job identifier. This is used to save execution statistics for each
-- inference evaluation
UUID ->
-- | Note that every item in the output stream (first element of each
-- outer tuple) should be declared as writable by the corresponding
-- inference parameter. It is the responsibility of the runtime system
-- (not defined in this repository) to verify this before directing
-- the writes to their final destination
ClientM (WriteStream IO)
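
For orientation, here is a minimal sketch (not part of this PR) of calling the updated endpoint; Servant encodes the new argument as a required query parameter, i.e. `POST /inference/<id>?uuid=<uuid>`. The host, port, and parameter ID below are assumptions, and `Id`/`inferenceC` are taken at the concrete instantiation used by `inferno-ml-server`:

```haskell
-- Sketch only: assumes a server on localhost:8080 and an existing
-- inference parameter with ID 1
import Control.Exception (throwIO)
import Data.UUID (UUID)
import Network.HTTP.Client (defaultManagerSettings, newManager)
import Servant.Client.Streaming
  ( BaseUrl (BaseUrl),
    Scheme (Http),
    mkClientEnv,
    withClientM,
  )
import System.Random (randomIO)

runOnce :: IO ()
runOnce = do
  env <-
    mkClientEnv
      <$> newManager defaultManagerSettings
      <*> pure (BaseUrl Http "localhost" 8080 "")
  -- A fresh job identifier; evaluation statistics are saved under
  -- this UUID once the run completes
  uuid <- randomIO :: IO UUID
  withClientM (inferenceC (Id 1) Nothing uuid) env $
    -- A real caller would consume the stream (cf. `verifyWrites` in
    -- the test client); here it is simply discarded
    either throwIO (const (pure ()))
```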

-- | Cancel the existing inference job, if it exists
48 changes: 48 additions & 0 deletions inferno-ml-server-types/src/Inferno/ML/Server/Types.hs
@@ -39,6 +39,7 @@ import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text.Encoding
import Data.Time (UTCTime)
import Data.UUID (UUID)
import Data.Vector (Vector)
import qualified Data.Vector as Vector
import Data.Word (Word32, Word64)
@@ -108,6 +109,7 @@ type InfernoMlServerAPI uid gid p s t =
:<|> "inference"
:> Capture "id" (Id (InferenceParam uid gid p s))
:> QueryParam "res" Int64
:> QueryParam' '[Required] "uuid" UUID
:> StreamPost NewlineFraming JSON (WriteStream IO)
:<|> "inference" :> "cancel" :> Put '[JSON] ()
-- Register the bridge. This is an `inferno-ml-server` endpoint, not a
@@ -643,6 +645,52 @@ instance
ip ^. the @"user" & toField
]

-- | Information about execution time and resource usage. This is saved by
-- @inferno-ml-server@ after script evaluation completes and can be queried
-- later by using the same job identifier that was provided to the @/inference@
-- route
data EvaluationInfo uid gid p = EvaluationInfo
{ -- | Note that this is the job identifier provided to the inference
-- evaluation route, and is also the primary key of the database table
id :: UUID,
param :: Id (InferenceParam uid gid p VCObjectHash),
-- | When inference evaluation started
start :: UTCTime,
-- | When inference evaluation ended
end :: UTCTime,
-- | The number of bytes allocated between the @start@ and @end@. Note
-- that this is /total/ allocation over the course of evaluation, which
-- can be many times greater than peak memory usage. Nevertheless, this
-- can be useful to track memory usage over time and across different
-- script revisions
allocated :: Word64,
-- | Additional CPU time used between the @start@ and @end@. This is
-- converted from picoseconds to milliseconds
cpu :: Word64
}
deriving stock (Show, Eq, Generic)
deriving anyclass (FromJSON, ToJSON)
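
Since the UUID doubles as the table's primary key, retrieving saved statistics later is a single-row lookup. A hedged sketch using postgresql-simple (no such query is defined in this PR; `evalinfo` is the table name used by the server's INSERT):

```haskell
{-# LANGUAGE QuasiQuotes #-}

import Data.Maybe (listToMaybe)
import Database.PostgreSQL.Simple (Connection, Only (Only), query)
import Database.PostgreSQL.Simple.SqlQQ (sql)

-- Hypothetical lookup by job UUID, relying on the FromRow instance
-- defined below
lookupEvalInfo ::
  Connection -> UUID -> IO (Maybe (EvaluationInfo uid gid p))
lookupEvalInfo conn u =
  listToMaybe
    <$> query conn [sql| SELECT * FROM evalinfo WHERE id = ? |] (Only u)
```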

instance FromRow (EvaluationInfo uid gid p) where
fromRow =
EvaluationInfo
<$> field
<*> field
<*> field
<*> field
<*> fmap (fromIntegral @Int64) field
<*> fmap (fromIntegral @Int64) field

instance ToRow (EvaluationInfo uid gid p) where
toRow ei =
[ ei ^. the @"id" & toField,
ei ^. the @"param" & toField,
ei ^. the @"start" & toField,
ei ^. the @"end" & toField,
ei ^. the @"allocated" & toField,
ei ^. the @"cpu" & toField
]
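
The positional INSERT used by the server (see `Inferno.ML.Server.Inference` below) implies a six-column `evalinfo` table. The actual migration is not included in this diff; a plausible schema, with column types inferred from the instances above, might look like:

```haskell
{-# LANGUAGE QuasiQuotes #-}

import Database.PostgreSQL.Simple (Query)
import Database.PostgreSQL.Simple.SqlQQ (sql)

-- Hypothetical migration; the real schema is not part of this PR
createEvalInfo :: Query
createEvalInfo =
  [sql|
    CREATE TABLE IF NOT EXISTS evalinfo
      ( id        uuid PRIMARY KEY      -- job identifier from /inference
      , param     bigint NOT NULL       -- Id of the inference parameter
      , start     timestamptz NOT NULL  -- evaluation start
      , "end"     timestamptz NOT NULL  -- evaluation end (reserved word, quoted)
      , allocated bigint NOT NULL       -- total bytes allocated
      , cpu       bigint NOT NULL       -- CPU time in milliseconds
      )
  |]
```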

-- | A user, parameterized by the user and group types
data User uid gid = User
{ id :: uid,
3 changes: 3 additions & 0 deletions inferno-ml-server/CHANGELOG.md
@@ -1,3 +1,6 @@
## 2023.5.22
* Add support for tracking evaluation info

## 2023.4.3
* Add `terminated` column to DB types

4 changes: 3 additions & 1 deletion inferno-ml-server/exe/Client.hs
@@ -31,6 +31,7 @@ import Servant.Client.Streaming
)
import System.Exit (die)
import System.Posix.Types (EpochTime)
import System.Random (randomIO)
import Text.Read (readMaybe)
import UnliftIO (throwString)
import UnliftIO.Environment (getArgs)
@@ -41,6 +42,7 @@ main =
getArgs >>= \case
i : _ -> do
ipid <- maybe (throwString "Invalid ID") (pure . Id) $ readMaybe i
uuid <- randomIO
env <-
mkClientEnv
<$> newManager defaultManagerSettings
@@ -51,7 +53,7 @@ main =
. registerBridgeC
. flip BridgeInfo 9999
$ toIPv4 (127, 0, 0, 1)
withClientM (inferenceC ipid Nothing) env . either throwIO $
withClientM (inferenceC ipid Nothing uuid) env . either throwIO $
verifyWrites (coerce ipid)
_ -> die "Usage: test-client <inference-parameter-id>"

5 changes: 4 additions & 1 deletion inferno-ml-server/inferno-ml-server.cabal
@@ -1,6 +1,6 @@
cabal-version: 2.4
name: inferno-ml-server
version: 2023.4.3
version: 2023.5.22
synopsis: Server for Inferno ML
description: Server for Inferno ML
homepage: https://github.com/plow-technologies/inferno.git#readme
@@ -85,6 +85,7 @@ library
, text
, time
, unliftio
, uuid
, vector
, wai
, wai-logger
@@ -132,8 +133,10 @@ executable test-client
, http-client
, inferno-ml-server-types
, iproute
, random
, servant-client
, unliftio
, uuid

executable dummy-bridge
import: common
89 changes: 83 additions & 6 deletions inferno-ml-server/src/Inferno/ML/Server/Inference.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LexicalNegation #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
@@ -26,12 +27,16 @@ import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import qualified Data.Text as Text
import Data.Time (UTCTime, getCurrentTime)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Data.Traversable (for)
import Data.UUID (UUID)
import qualified Data.Vector as Vector
import Data.Word (Word64)
import Database.PostgreSQL.Simple
( Only (Only),
Query,
SqlError,
)
import Database.PostgreSQL.Simple.SqlQQ (sql)
import Foreign.C (CTime)
@@ -59,8 +64,11 @@ import Inferno.VersionControl.Types
( VCObject (VCFunction),
)
import Lens.Micro.Platform
import System.CPUTime (getCPUTime)
import System.FilePath (dropExtensions, (<.>))
import System.Mem (getAllocationCounter, setAllocationCounter)
import System.Posix.Types (EpochTime)
import UnliftIO (withRunInIO)
import UnliftIO.Async (wait, withAsync)
import UnliftIO.Directory
( createFileLink,
@@ -73,7 +81,12 @@ import UnliftIO.Directory
removePathForcibly,
withCurrentDirectory,
)
import UnliftIO.Exception (bracket_, catchIO, displayException)
import UnliftIO.Exception
( bracket_,
catch,
catchIO,
displayException,
)
import UnliftIO.IO.File (writeBinaryFileDurableAtomic)
import UnliftIO.IORef (readIORef)
import UnliftIO.MVar (putMVar, takeMVar, withMVar)
@@ -86,10 +99,9 @@ import UnliftIO.Timeout (timeout)
runInferenceParam ::
Id InferenceParam ->
Maybe Int64 ->
UUID ->
RemoteM (WriteStream IO)
-- FIXME / TODO Deal with default resolution, probably shouldn't need to be
-- passed on all requests
runInferenceParam ipid (fromMaybe 128 -> res) =
runInferenceParam ipid (fromMaybe 128 -> res) uuid =
withTimeoutMillis $ \t -> do
logTrace $ RunningInference ipid t
maybe (throwM (ScriptTimeout t)) pure
@@ -107,8 +119,13 @@ runInferenceParam ipid (fromMaybe 128 -> res) =
$ wait a

-- Actually runs the inference evaluation, within the configured timeout
--
-- NOTE: Do not fork anything else inside here; this is already running
-- in an `Async` and we want to be able to get execution statistics from
-- the runtime. Specifically, we are using `getAllocationCounter`, which
-- only captures allocations made _in this thread_
runInference :: Int -> RemoteM (Maybe (WriteStream IO))
runInference tmo = timeout tmo $ do
runInference tmo = timeout tmo . withEvaluationInfo $ do
view #interpreter >>= readIORef >>= \case
Nothing -> throwM BridgeNotRegistered
Just interpreter -> do
@@ -235,7 +252,67 @@ runInferenceParam ipid (fromMaybe 128 -> res) =
withTimeoutMillis :: (Int -> RemoteM b) -> RemoteM b
withTimeoutMillis =
(view (#config . #timeout) >>=)
. (. (* 1000000) . fromIntegral)
. (. (* 1_000_000) . fromIntegral)
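
The point-free definition above is equivalent to the following explicit form (a paraphrase for readability, not code from this PR): it reads the configured timeout and scales it to the microseconds expected by `timeout`. `NumericUnderscores` is already enabled in this module.

```haskell
-- Equivalent explicit form of the point-free definition above
withTimeoutMillis :: (Int -> RemoteM b) -> RemoteM b
withTimeoutMillis k = do
  t <- view (#config . #timeout)
  k $ fromIntegral t * 1_000_000
```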

withEvaluationInfo :: RemoteM a -> RemoteM a
withEvaluationInfo f = withRunInIO $ \r -> do
-- Reset the counter to its maximum so it cannot hit the lower limit
-- during evaluation; unlikely, but worth guarding against
setAllocationCounter maxBound
start <- getCurrentTime
bytes0 <- getAllocationCounter
(Review comment from the author) An alternative would be to read one of the memory-related fields from `GHC.Stats.RTSStats` and compare the before/after values. But that proved very unreliable unless a GC was triggered before and after -- sometimes the difference between beginning and ending max memory in use was negative. We could force a GC before and after each evaluation, but that can degrade performance, so reading allocated bytes is sufficient.

cpu0 <- getCPUTime
ws <- r f
end <- getCurrentTime
bytes1 <- getAllocationCounter
cpu1 <- getCPUTime

ws <$ r (saveEvaluationInfo (end, start) (bytes1, bytes0) (cpu1, cpu0))
where
saveEvaluationInfo ::
-- End and start times
(UTCTime, UTCTime) ->
-- Ending and beginning byte allocation
(Int64, Int64) ->
-- Ending and beginning CPU time
(Integer, Integer) ->
RemoteM ()
saveEvaluationInfo (end, start) (bytes1, bytes0) (cpu1, cpu0) =
insert `catch` logAndIgnore
where
insert :: RemoteM ()
insert =
executeStore q $
EvaluationInfo uuid ipid start end allocated cpuMillis
where
q :: Query
q = [sql| INSERT INTO evalinfo VALUES (?, ?, ?, ?, ?, ?) |]

-- Note that the allocation counter counts *down*, so the bytes
-- allocated are the initial reading minus the final one
allocated :: Word64
allocated =
fromIntegral
-- Clamp at zero in the unlikely event that the counter
-- reading increased over the course of evaluation, so the
-- conversion doesn't wrap around to `maxBound @Word64`
. max 0
$ bytes0 - bytes1

-- Convert the picoseconds of CPU time to milliseconds
cpuMillis :: Word64
cpuMillis = fromIntegral $ (cpu1 - cpu0) `div` 1_000_000_000

-- We don't want a DB error to completely break inference
-- evaluation. Inability to store the eval info is more of
-- an inconvenience than a fatal error
logAndIgnore :: SqlError -> RemoteM ()
logAndIgnore =
logTrace
. OtherWarn
. ("Failed to save eval info: " <>)
. Text.pack
. displayException
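
To make the bookkeeping above concrete, here is a self-contained sketch (not part of this PR) of the same measurement pattern: the per-thread allocation counter counts down as the thread allocates, and `getCPUTime` reports picoseconds.

```haskell
{-# LANGUAGE NumericUnderscores #-}

import Data.Int (Int64)
import Data.Word (Word64)
import System.CPUTime (getCPUTime)
import System.Mem (getAllocationCounter, setAllocationCounter)

-- Returns the action's result, bytes allocated in this thread, and
-- CPU milliseconds used
measure :: IO a -> IO (a, Int64, Word64)
measure act = do
  setAllocationCounter maxBound -- keep the counter away from its lower limit
  bytes0 <- getAllocationCounter
  cpu0 <- getCPUTime -- picoseconds
  x <- act
  bytes1 <- getAllocationCounter
  cpu1 <- getCPUTime
  pure
    ( x,
      max 0 (bytes0 - bytes1), -- the counter counts down
      fromIntegral ((cpu1 - cpu0) `div` 1_000_000_000) -- ps -> ms
    )
```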

getVcObject :: VCObjectHash -> RemoteM (VCMeta VCObject)
getVcObject vch =
16 changes: 15 additions & 1 deletion inferno-ml-server/src/Inferno/ML/Server/Types.hs
@@ -47,6 +47,7 @@ import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.Read as Text.Read
import Data.Time (UTCTime)
import Data.UUID (UUID)
import Data.Vector (Vector)
import Data.Word (Word64)
import Data.Yaml (decodeFileThrow)
@@ -68,7 +69,8 @@ import GHC.Generics (Generic)
import Inferno.Core (Interpreter)
import Inferno.ML.Server.Module.Types as M
import "inferno-ml-server-types" Inferno.ML.Server.Types as M hiding
( InferenceParam,
( EvaluationInfo,
InferenceParam,
InferenceScript,
InfernoMlServerAPI,
Model,
@@ -353,6 +355,8 @@ f ?? x = ($ x) <$> f
type InferenceParam =
Types.InferenceParam (EntityId UId) (EntityId GId) PID VCObjectHash

type EvaluationInfo = Types.EvaluationInfo (EntityId UId) (EntityId GId) PID

type Model = Types.Model (EntityId UId) (EntityId GId)

type ModelVersion = Types.ModelVersion (EntityId UId) (EntityId GId) Oid
@@ -389,6 +393,16 @@ pattern VCMeta ::
pattern VCMeta t a g n d p v o =
Inferno.VersionControl.Types.VCMeta t a g n d p v o

pattern EvaluationInfo ::
UUID ->
Id InferenceParam ->
UTCTime ->
UTCTime ->
Word64 ->
Word64 ->
EvaluationInfo
pattern EvaluationInfo u i s e m c = Types.EvaluationInfo u i s e m c

type InfernoMlServerAPI =
Types.InfernoMlServerAPI (EntityId UId) (EntityId GId) PID VCObjectHash EpochTime
