More than one model for inference #1599
Good morning,

Right now, it seems that the default Triton inference stage creates one worker that can handle inference requests for only one model. We need to support many models, selected dynamically based on the input. For now, I've used a broadcast as follows:

```python
# Author Laurent DECLERCQ, Konzeptplus ag <[email protected]>
# Version 20240408
import logging

import mrc
import mrc.core.operators as ops
from kp_std.kp_registry import KpRegistry
from morpheus.messages.multi_inference_message import MultiInferenceMessage
from morpheus.config import Config
from morpheus.pipeline.sender import Sender
from morpheus.pipeline.stage import Stage
from morpheus.pipeline.stage_schema import StageSchema
from mrc.core.node import Broadcast

logger = logging.getLogger(__name__)


class KpTritonDispatcherStage(Stage):

    def __init__(self, c: Config, model_names: list[str]):
        super().__init__(c)
        self._model_names = model_names
        self._count_models = len(model_names)
        self._create_ports(1, self._count_models)

    @property
    def name(self) -> str:
        return "kp-triton-dispatcher"

    def supports_cpp_node(self):
        return False

    def compute_schema(self, schema: StageSchema):
        # The output schema should have the same number of ports as the number of models.
        assert len(schema.output_schemas) == self._count_models, f"Expected {self._count_models} output schemas"
        for port_schema in schema.output_schemas:
            # Set the type of the output port to MultiInferenceMessage.
            port_schema.set_type(MultiInferenceMessage)

    def get_model_output_port(self, model_name: str) -> Sender:
        # Retrieve the output port for the given model name by looking up its index in the list of model names.
        return self.output_ports[self._model_names.index(model_name)]

    @staticmethod
    def _filter_message(x: MultiInferenceMessage, node_model: str) -> bool:
        # noinspection PyUnresolvedReferences
        message_model = KpRegistry.get(x.meta.payload['name']).get('model_name')
        if message_model == node_model:
            return True
        logger.warning(f"Message is expecting node with model: {message_model}. Node model is {node_model}. Skipping...")
        return False

    def _build(self, builder: mrc.Builder, input_nodes: list[mrc.SegmentObject]) -> list[mrc.SegmentObject]:
        assert len(input_nodes) == 1, "Only 1 input supported"
        # Create a broadcast node.
        broadcast = Broadcast(builder, "broadcast")
        # Connect the input node to the broadcast node.
        builder.make_edge(input_nodes[0], broadcast)
        nodes = []
        # For each model, create a node that filters messages for that model.
        for model_name in self._model_names:
            # Create a node to handle messages for the current model.
            node = builder.make_node(
                f"triton_inference_{model_name}",
                # Filter out messages that are not for the current model
                # (messages are broadcast to all nodes, so we need to filter them).
                ops.filter(lambda x, node_model=model_name: KpTritonDispatcherStage._filter_message(x, node_model))
            )
            # Connect the broadcast node to the model-specific node.
            builder.make_edge(broadcast, node)
            # Add the node to the list of nodes.
            nodes.append(node)
        # Finally, return the list of nodes.
        return nodes
```

Put simply, for each possible model I add a filter node and the corresponding inference stage. Messages are filtered out with a reactive operator (`ops.filter`). This works well as long as I don't use shared memory... Any clue?
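To make the routing behaviour of the dispatcher easier to reason about, here is a framework-free sketch of the same broadcast-then-filter pattern in plain Python. It does not use MRC or Morpheus: `REGISTRY`, the message dicts, and `broadcast_and_filter` are hypothetical stand-ins for `KpRegistry`, `x.meta.payload`, and the MRC `Broadcast` node followed by one `ops.filter` per model.

```python
from collections import defaultdict

# Hypothetical stand-in for KpRegistry: snapshot name -> model configuration.
REGISTRY = {
    "pack_pf-cybersec-edr-poc_listening_ports_snapshot": {"model_name": "socket_anomaly_model.tf"},
    "pack_pf-cybersec-edr-poc_users_snapshot": {"model_name": "users_anomaly_model.tf"},
}
MODEL_NAMES = ["socket_anomaly_model.tf", "users_anomaly_model.tf"]


def filter_message(payload: dict, node_model: str) -> bool:
    """Return True if the message targets the model served by this branch."""
    return REGISTRY[payload["name"]]["model_name"] == node_model


def broadcast_and_filter(messages: list[dict], model_names: list[str]):
    """Send every message to every branch; each branch keeps only its own messages."""
    branches = defaultdict(list)
    checks = 0
    for msg in messages:
        # Broadcast: every branch receives every message...
        for model_name in model_names:
            checks += 1
            # ...and the branch's filter drops messages meant for other models.
            if filter_message(msg, model_name):
                branches[model_name].append(msg)
    return dict(branches), checks


msgs = [
    {"name": "pack_pf-cybersec-edr-poc_users_snapshot"},
    {"name": "pack_pf-cybersec-edr-poc_listening_ports_snapshot"},
]
routed, checks = broadcast_and_filter(msgs, MODEL_NAMES)
print({name: len(batch) for name, batch in routed.items()}, checks)
```

Each message ends up in exactly one branch, but it is evaluated by all N filters on the way (here 2 messages × 2 models gives `checks == 4`).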
I'm wondering whether, for such a scenario, delegating inference requests to MLflow would be the solution. Put simply, Morpheus would preprocess messages received from a Kafka source stage, delegate the inference requests to MLflow, and finally postprocess the output.
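A minimal sketch of what the delegation step could look like, assuming each model is served with `mlflow models serve` and scored through its REST `/invocations` endpoint using the `dataframe_split` input format. The host name `mlflow-server:5000` and the helper `build_mlflow_request` are hypothetical; the feature columns are those registered for the users model. The sketch only builds the request with the standard library, so it can be inspected without a running server.

```python
import json
import urllib.request


def build_mlflow_request(server_url: str, columns: list[str], rows: list[list]) -> urllib.request.Request:
    """Build (but do not send) a scoring request for an MLflow model server."""
    # "dataframe_split" is one of the JSON input formats accepted by the MLflow scoring server.
    body = {"dataframe_split": {"columns": columns, "data": rows}}
    return urllib.request.Request(
        url=f"{server_url.rstrip('/')}/invocations",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Hypothetical server; one preprocessed row for the users model.
req = build_mlflow_request(
    "http://mlflow-server:5000",
    ["username", "gid", "directory"],
    [["alice", 1000, "/home/alice"]],
)
print(req.full_url, req.get_method())
# Actually sending it would be: urllib.request.urlopen(req), which returns the JSON predictions.
```

In such a design, one served endpoint per model would replace the per-model Triton stages, and the dispatch decision would reduce to choosing the URL.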
Thanks for your great answer, it's much appreciated. In the end, I reworked the pipeline substantially and the problem is gone. It's not perfect yet, but it works.

The pipeline in action:

```
====Pipeline Pre-build====
====Pre-Building Segment: main====
====Pre-Building Segment Complete!====
====Pipeline Pre-build Complete!====
====Registering Pipeline====
====Building Pipeline====
====Building Pipeline Complete!====
====Registering Pipeline Complete!====
====Starting Pipeline====
====Pipeline Started====
====Building Segment: main====
Added source: <from-kafka-ssl-0; KafkaSslSourceStage(bootstrap_servers=ai-pf-kafka-server:9092, input_topic=ai-pf-input, group_id=ai-pf, client_id=None, poll_interval=1s, disable_commit=False, auto_offset_reset=AutoOffsetReset.LATEST, stop_after=0, async_commits=True, ssl_config={'ssl.ca.location': '/opt/konzeptplus/resources/ssl/development/client/ca-cert.pem', 'ssl.certificate.location': '/opt/konzeptplus/resources/ssl/development/client/client-cert.pem', 'ssl.key.location': '/opt/konzeptplus/resources/ssl/development/client/client-key.pem', 'ssl.key.password': '', 'enable.ssl.certificate.verification': 'False'})>
└─> konzeptplus.MessageMeta
Added stage: <deserialize-1; DeserializeStage(ensure_sliceable_index=False, message_type=<class 'morpheus.messages.multi_message.MultiMessage'>, task_type=None, task_payload=None)>
└─ konzeptplus.MessageMeta -> morpheus.MultiMessage
Added stage: <num-2; NumStage()>
└─ morpheus.MultiMessage -> morpheus.MultiInferenceMessage
Added stage: <inference-5; TritonInferenceStage(model_name=socket_anomaly_model.tf, server_url=ai-pf-triton-server:8001, force_convert_inputs=False, use_shared_memory=True, needs_logits=None, inout_mapping=None)>
└─ morpheus.MultiInferenceMessage -> morpheus.MultiResponseMessage
Added stage: <inference-6; TritonInferenceStage(model_name=python_packages_anomaly_model.tf, server_url=ai-pf-triton-server:8001, force_convert_inputs=False, use_shared_memory=True, needs_logits=None, inout_mapping=None)>
└─ morpheus.MultiInferenceMessage -> morpheus.MultiResponseMessage
Added stage: <inference-7; TritonInferenceStage(model_name=startup_items_anomaly_model.tf, server_url=ai-pf-triton-server:8001, force_convert_inputs=False, use_shared_memory=True, needs_logits=None, inout_mapping=None)>
└─ morpheus.MultiInferenceMessage -> morpheus.MultiResponseMessage
Added stage: <inference-8; TritonInferenceStage(model_name=users_anomaly_model.tf, server_url=ai-pf-triton-server:8001, force_convert_inputs=False, use_shared_memory=True, needs_logits=None, inout_mapping=None)>
└─ morpheus.MultiInferenceMessage -> morpheus.MultiResponseMessage
Added stage: <serialize-4; SerializeStage()>
└─ morpheus.MultiResponseMessage -> konzeptplus.MessageMeta
Added stage: <classify-9; ClassificationStage(only_mean=False)>
└─ konzeptplus.MessageMeta -> konzeptplus.MessageMeta
Added stage: <to-kafka-ssl-10; WriteToKafkaSslStage(bootstrap_servers=ai-pf-kafka-server:9092, output_topic=ai-pf-output, client_id=None, ssl_config={'ssl.ca.location': '/opt/konzeptplus/resources/ssl/development/client/ca-cert.pem', 'ssl.certificate.location': '/opt/konzeptplus/resources/ssl/development/client/client-cert.pem', 'ssl.key.location': '/opt/konzeptplus/resources/ssl/development/client/client-key.pem', 'ssl.key.password': '', 'enable.ssl.certificate.verification': 'False'})>
└─ konzeptplus.MessageMeta -> konzeptplus.MessageMeta
====Building Segment Complete!====
2024-04-09 18:33:23.519098: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 6141 MB memory: -> device: 0, name: NVIDIA GeForce RTX 4070, pci bus id: 0000:01:00.0, compute capability: 8.9
WARNING:py.warnings:/opt/conda/envs/morpheus/lib/python3.10/site-packages/cudf/io/json.py:239: UserWarning: Using CPU via Pandas to write JSON dataset
  warnings.warn("Using CPU via Pandas to write JSON dataset")
WARNING:py.warnings:/opt/conda/envs/morpheus/lib/python3.10/site-packages/cudf/io/json.py:239: UserWarning: Using CPU via Pandas to write JSON dataset
  warnings.warn("Using CPU via Pandas to write JSON dataset")
WARNING:py.warnings:/opt/conda/envs/morpheus/lib/python3.10/site-packages/cudf/io/json.py:239: UserWarning: Using CPU via Pandas to write JSON dataset
  warnings.warn("Using CPU via Pandas to write JSON dataset")
WARNING:py.warnings:/opt/conda/envs/morpheus/lib/python3.10/site-packages/cudf/io/json.py:239: UserWarning: Using CPU via Pandas to write JSON dataset
  warnings.warn("Using CPU via Pandas to write JSON dataset")
```

Messages sent by the producer, and the results from the consumer:

```
(morpheus) nuxwin@morpheus-konzeptplus:~/projects/git/konzeptplus/nvidia/ai-pf-m-ml-001/scripts/kp_kafka$ ./producer.py --osquery normality --model packages
osquery mode: normality
model type: packages
osquery snapshot file: /home/nuxwin/projects/git/konzeptplus/nvidia/ai-pf-m-ml-001/resources/osquery/snapshots/python_packages_anomaly_model.tf/normality_1.json
Message delivered to ai-pf-input [0]
Message sent to all registered topics
(morpheus) nuxwin@morpheus-konzeptplus:~/projects/git/konzeptplus/nvidia/ai-pf-m-ml-001/scripts/kp_kafka$ ./producer.py --osquery normality --model app
osquery mode: normality
model type: app
osquery snapshot file: /home/nuxwin/projects/git/konzeptplus/nvidia/ai-pf-m-ml-001/resources/osquery/snapshots/startup_items_anomaly_model.tf/normality_1.json
Message delivered to ai-pf-input [0]
Message sent to all registered topics
(morpheus) nuxwin@morpheus-konzeptplus:~/projects/git/konzeptplus/nvidia/ai-pf-m-ml-001/scripts/kp_kafka$ ./producer.py --osquery normality --model user
osquery mode: normality
model type: user
osquery snapshot file: /home/nuxwin/projects/git/konzeptplus/nvidia/ai-pf-m-ml-001/resources/osquery/snapshots/users_anomaly_model.tf/normality_1.json
Message delivered to ai-pf-input [0]
Message sent to all registered topics
(morpheus) nuxwin@morpheus-konzeptplus:~/projects/git/konzeptplus/nvidia/ai-pf-m-ml-001/scripts/kp_kafka$ ./consumer.py
{
  "mean": 17.6371974521,
  "name": "pack_pf-cybersec-edr-poc_listening_ports_snapshot",
  "calendar_time": "Wed Jan 31 23:47:00 2024 UTC",
  "unix_time": 1706744820,
  "host_identifier": "soc-deb-srv01"
}
{
  "mean": 507.1062242296,
  "name": "pack_pf-cybersec-edr-poc_python_packages_snapshot",
  "calendar_time": "Thu Feb 1 10:05:50 2024 UTC",
  "unix_time": 1706781950,
  "host_identifier": "soc-deb-srv01"
}
{
  "mean": 163.1230002125,
  "name": "pack_pf-cybersec-edr-poc_startup_items_snapshot",
  "calendar_time": "Wed Jan 31 22:23:04 2024 UTC",
  "unix_time": 1706739784,
  "host_identifier": "soc-deb-srv01"
}
{
  "mean": 26.1308575657,
  "name": "pack_pf-cybersec-edr-poc_users_snapshot",
  "calendar_time": "Wed Jan 31 15:35:55 2024 UTC",
  "unix_time": 1706715355,
  "host_identifier": "soc-deb-srv01"
}
```

The dispatcher (broadcast):

```python
# Author: Laurent DECLERCQ, Konzeptplus ag <[email protected]>
# Version: 20240409
import logging

import mrc
import mrc.core.operators as ops
from konzeptplus.std.registry import Registry
from morpheus.messages.multi_inference_message import MultiInferenceMessage
from morpheus.config import Config
from morpheus.pipeline.sender import Sender
from morpheus.pipeline.stage import Stage
from morpheus.pipeline.stage_schema import StageSchema
from mrc.core.node import Broadcast

logger = logging.getLogger(__name__)


class InferenceRequestDispatcherStage(Stage):
    """
    Dispatches inference requests to the appropriate model nodes.

    This stage receives MultiInferenceMessage objects and forwards them to the appropriate model nodes.

    Parameters
    ----------
    c : `morpheus.config.Config`
        Pipeline configuration instance.
    model_names : list[str]
        List of model names to dispatch inference requests to. One output port will be created for each model name.
    """

    def __init__(self, c: Config, model_names: list[str]):
        super().__init__(c)
        self._model_names = model_names
        self._count_models = len(model_names)
        self._create_ports(1, self._count_models)

    @property
    def name(self) -> str:
        return "inference-request-dispatcher"

    def supports_cpp_node(self):
        return False

    def compute_schema(self, schema: StageSchema):
        # The output schema should have the same number of ports as the number of models.
        assert len(schema.output_schemas) == self._count_models, f"Expected {self._count_models} output schemas"
        for port_schema in schema.output_schemas:
            # Set the type of the output port to MultiInferenceMessage.
            port_schema.set_type(MultiInferenceMessage)

    def get_model_output_port(self, model_name: str) -> Sender:
        """
        Get the output port for the given model name.

        Parameters
        ----------
        model_name : str
            The name of the model.

        Returns
        -------
        Sender
            The output port for the given model name, suitable for `Pipeline.add_edge`.
        """
        # Retrieve the output port for the given model name by looking up its index in the list of model names.
        return self.output_ports[self._model_names.index(model_name)]

    @staticmethod
    def _filter_message(x: MultiInferenceMessage, node_model: str) -> bool:
        """
        Filter messages based on the model name.

        Parameters
        ----------
        x : MultiInferenceMessage
            The message to filter.
        node_model : str
            The model name to filter on.

        Returns
        -------
        bool
            True if the message is for the given model, False otherwise.
        """
        # noinspection PyUnresolvedReferences
        message_model = Registry.get(x.meta.payload['name']).get('model_name')
        if message_model == node_model:
            return True
        logger.info(f"Message is expecting node with model: {message_model}. Node model is {node_model}. Skipping...")
        return False

    def _build(self, builder: mrc.Builder, input_nodes: list[mrc.SegmentObject]) -> list[mrc.SegmentObject]:
        assert len(input_nodes) == 1, "Only 1 input supported"
        # Create a broadcast node.
        broadcast = Broadcast(builder, "broadcast")
        # Connect the input node to the broadcast node.
        builder.make_edge(input_nodes[0], broadcast)
        nodes = []
        # For each model, create a node that filters messages for that model.
        for model_name in self._model_names:
            # Create a node to handle messages for the current model.
            node = builder.make_node(
                f"triton_inference_{model_name}",
                # Filter out messages that are not for the current model
                # (messages are broadcast to all nodes, so we need to filter them).
                ops.filter(
                    lambda x, node_model=model_name: InferenceRequestDispatcherStage._filter_message(x, node_model)
                )
            )
            # Connect the broadcast node to the model-specific node.
            builder.make_edge(broadcast, node)
            # Add the node to the list of nodes.
            nodes.append(node)
        # Finally, return the list of nodes.
        return nodes
```

The pipeline:

```python
#!/opt/conda/envs/morpheus/bin/python
# Morpheus Python inference pipeline for the AI Post Finance Model project (inference).
#
# Author: Laurent DECLERCQ, Konzeptplus ag <[email protected]>
# Version 20240409
import click
import logging
import os

from konzeptplus.std.registry import Registry
from konzeptplus.stages.broadcast.inference_request_dispatcher_stage import InferenceRequestDispatcherStage
from konzeptplus.stages.input.kafka_ssl_source_stage import KafkaSslSourceStage
from konzeptplus.stages.output.write_to_kafka_ssl_stage import WriteToKafkaSslStage
from konzeptplus.stages.postprocess.classification_stage import ClassificationStage
from konzeptplus.stages.postprocess.serialize_stage import SerializeStage
from konzeptplus.stages.preprocess.num_stage import NumStage
from morpheus.config import Config, CppConfig, PipelineModes
from morpheus.pipeline.pipeline import Pipeline
from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage
from morpheus.stages.preprocess.deserialize_stage import DeserializeStage
from morpheus.utils.logger import configure_logging


@click.command()
@click.option(
    "--num_threads",
    default=4,
    type=click.IntRange(min=1, max=os.cpu_count()),
    help="Number of internal pipeline threads to use.",
)
@click.option(
    "--pipeline_batch_size",
    default=1,
    type=click.IntRange(min=1),
    help=("Internal batch size for the pipeline. Can be much larger than the model batch size. "
          "Also used for Kafka consumers."),
)
@click.option(
    "--model_max_batch_size",
    default=1,
    type=click.IntRange(min=1),
    help="Max batch size to use for the model."
)
@click.option(
    '--bootstrap_servers',
    default='ai-pf-kafka-server:9092',
    help="Comma-separated list of bootstrap servers."
)
@click.option(
    '--input_topic',
    type=str,
    default='ai-pf-input',
    help="Name of the Kafka topic from which messages will be consumed."
)
@click.option(
    '--output_topic',
    type=str,
    default='ai-pf-output',
    help="Name of the Kafka topic to which results will be written."
)
@click.option(
    '--ssl_config',
    default=[
        ['ssl.ca.location', '/opt/konzeptplus/resources/ssl/development/client/ca-cert.pem'],
        ['ssl.certificate.location', '/opt/konzeptplus/resources/ssl/development/client/client-cert.pem'],
        ['ssl.key.location', '/opt/konzeptplus/resources/ssl/development/client/client-key.pem'],
        ['ssl.key.password', ''],
        ['enable.ssl.certificate.verification', False]
    ],
    nargs=2,
    multiple=True,
    help='Kafka SSL configuration parameters.'
)
@click.option(
    '--group_id',
    type=str,
    default='ai-pf',
    help="Kafka input data consumer group identifier."
)
@click.option(
    "--model_fea_length",
    default=90,
    type=click.IntRange(min=1),
    help="Feature length to use for the model.",
)
@click.option(
    "--only_mean",
    default=False,
    is_flag=True,
    help="If set, this flag will only show the mean, else the full model calculations will be shown."
)
@click.option(
    "--triton_server_url",
    default="ai-pf-triton-server:8001",
    required=True,
    help="Triton server URL.")
def run_pipeline(
    num_threads,
    pipeline_batch_size,
    model_max_batch_size,
    model_fea_length,
    bootstrap_servers,
    input_topic,
    output_topic,
    ssl_config,
    group_id,
    only_mean,
    triton_server_url
):
    # Enable the default logger.
    configure_logging(log_level=logging.INFO)
    CppConfig.set_should_use_cpp(False)

    config = Config()
    config.mode = PipelineModes.OTHER
    config.num_threads = num_threads
    config.pipeline_batch_size = pipeline_batch_size
    config.model_max_batch_size = model_max_batch_size
    config.feature_length = model_fea_length

    # Create the pipeline.
    # We don't make use of the linear pipeline here, as we have multiple inference stages (branching).
    pipeline = Pipeline(config)

    # Make sure the SSL config is a dictionary.
    ssl_config = dict(ssl_config)

    # Define the list of models to use. One inference stage will be created for each model.
    models = [
        'socket_anomaly_model.tf',
        'python_packages_anomaly_model.tf',
        'startup_items_anomaly_model.tf',
        'users_anomaly_model.tf'
    ]

    # Sockets and Ports model configuration parameters.
    Registry.register('pack_pf-cybersec-edr-poc_listening_ports_snapshot', {
        'model_name': 'socket_anomaly_model.tf',
        'model_path': '/opt/konzeptplus/morpheus/models/triton-model-repo/socket_anomaly_model.tf',
        'model_type': 'anomaly_detection',
        'features': ['socket', 'path', 'name', 'port']
    })

    # Python packages model configuration parameters.
    Registry.register('pack_pf-cybersec-edr-poc_python_packages_snapshot', {
        'model_name': 'python_packages_anomaly_model.tf',
        'model_path': '/opt/konzeptplus/morpheus/models/triton-model-repo/python_packages_anomaly_model.tf',
        'model_type': 'anomaly_detection',
        'features': ['name', 'path', 'version']
    })

    # Startup items model configuration parameters.
    Registry.register('pack_pf-cybersec-edr-poc_startup_items_snapshot', {
        'model_name': 'startup_items_anomaly_model.tf',
        'model_path': '/opt/konzeptplus/morpheus/models/triton-model-repo/startup_items_anomaly_model.tf',
        'model_type': 'anomaly_detection',
        'features': ['name', 'path', 'source']
    })

    # User Accounts model configuration parameters.
    Registry.register('pack_pf-cybersec-edr-poc_users_snapshot', {
        'model_name': 'users_anomaly_model.tf',
        'model_path': '/opt/konzeptplus/morpheus/models/triton-model-repo/users_anomaly_model.tf',
        'model_type': 'anomaly_detection',
        'features': ['username', 'gid', 'directory']
    })

    # Add the source stage. We're receiving messages from a Kafka topic.
    source = pipeline.add_stage(KafkaSslSourceStage(
        config, bootstrap_servers=bootstrap_servers, input_topic=input_topic, group_id=group_id, ssl_config=ssl_config
    ))

    # Add the deserialization stage.
    deserialize_stage = pipeline.add_stage(DeserializeStage(config, ensure_sliceable_index=False))
    # Connect the source stage to the deserialization stage.
    # The deserialize stage will receive the messages from the source stage.
    pipeline.add_edge(source, deserialize_stage)

    # Add the num stage.
    num_stage = pipeline.add_stage(NumStage(config))
    # Connect the deserialization stage to the preprocessing stage.
    # The preprocessing stage will receive the messages from the deserialize stage.
    pipeline.add_edge(deserialize_stage, num_stage)

    # Add the inference request dispatcher stage.
    # This stage will dispatch the inference requests to the correct inference stage based on the model name.
    inference_request_dispatcher_stage = pipeline.add_stage(InferenceRequestDispatcherStage(config, model_names=models))
    # Connect the preprocessing stage to the Triton server dispatcher stage.
    # The dispatcher stage will receive the messages from the preprocessing stage.
    pipeline.add_edge(num_stage, inference_request_dispatcher_stage)

    # Add the serialization stage.
    serialize_stage = pipeline.add_stage(SerializeStage(config))

    # Create the inference stages for each model and connect them to the Triton server dispatcher stage.
    # The inference stages will receive the messages from the dispatcher stage.
    for model_name in models:
        # Add the inference stage for the model.
        inf_stage = pipeline.add_stage(TritonInferenceStage(
            config, model_name=model_name, server_url=triton_server_url, use_shared_memory=True
        ))
        # Connect the Triton server dispatcher stage to the inference stage.
        # The inference stage will receive the messages from the Triton server dispatcher stage.
        pipeline.add_edge(inference_request_dispatcher_stage.get_model_output_port(model_name), inf_stage)
        # Connect the inference stage to the serialization stage.
        # The serialize stage will receive the messages from the inference stage.
        pipeline.add_edge(inf_stage, serialize_stage)

    # Add the classification stage.
    classification_stage = pipeline.add_stage(ClassificationStage(config, only_mean=only_mean))
    # Connect the serialize stage to the classification stage.
    # The classification stage will receive the messages from the serialization stage.
    pipeline.add_edge(serialize_stage, classification_stage)

    # Add the output stage. We're sending messages to a Kafka topic.
    output_stage = pipeline.add_stage(WriteToKafkaSslStage(
        config,
        bootstrap_servers=bootstrap_servers,
        output_topic=output_topic,
        ssl_config=ssl_config
    ))
    # Connect the classification stage to the output stage.
    # The output stage will receive the messages from the classification stage.
    pipeline.add_edge(classification_stage, output_stage)

    # Run the pipeline.
    pipeline.run()


if __name__ == "__main__":
    # Configure and run the pipeline.
    run_pipeline()
```
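The pipeline relies on a `Registry` with `register(name, config)` and `get(name)` class methods, but its implementation is not shown. The following is a minimal hypothetical sketch of such a registry; the `model_names()` helper is an assumption added for illustration and is not part of the original code.

```python
from typing import Any


class Registry:
    """Hypothetical minimal registry: maps an osquery snapshot name to its model configuration."""

    # Class-level storage, so the registry behaves like a process-wide singleton.
    _entries: dict[str, dict[str, Any]] = {}

    @classmethod
    def register(cls, name: str, config: dict[str, Any]) -> None:
        # Refuse silent overwrites: two snapshots mapped twice is almost certainly a mistake.
        if name in cls._entries:
            raise KeyError(f"{name!r} is already registered")
        cls._entries[name] = config

    @classmethod
    def get(cls, name: str) -> dict[str, Any]:
        # Unknown names return an empty dict, so .get('model_name') yields None
        # and the dispatcher's filter treats the message as "no match".
        return cls._entries.get(name, {})

    @classmethod
    def model_names(cls) -> list[str]:
        # Unique model names in registration order (dicts preserve insertion order).
        return list(dict.fromkeys(entry["model_name"] for entry in cls._entries.values()))


Registry.register("pack_pf-cybersec-edr-poc_users_snapshot", {"model_name": "users_anomaly_model.tf"})
Registry.register("pack_pf-cybersec-edr-poc_startup_items_snapshot", {"model_name": "startup_items_anomaly_model.tf"})
models = Registry.model_names()
print(models)
```

Deriving the `models` list from the registrations, rather than maintaining it by hand, would keep the two in sync and make the dispatcher's port order follow the registration order.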
I like your idea regarding control messages. For now, we receive Kafka messages (osquery snapshots) and need to preprocess them to make them compatible with the models' input. We already infer the target model from the snapshot name. I'll soon add a training branch too, so I'll add a new field to the message (`task_type`). I hope your improvements will help us get a better process. Thank you again for your answer.
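With a `task_type` field, the dispatch key becomes a pair (task type, model) instead of the model alone. A minimal sketch under stated assumptions: the task type values `"inference"` and `"training"`, the default for messages without the field, and the one-port-per-pair layout are all hypothetical choices, not existing Morpheus behaviour.

```python
from itertools import product

# Hypothetical task types; "training" is the planned branch.
TASK_TYPES = ["inference", "training"]
MODELS = ["socket_anomaly_model.tf", "users_anomaly_model.tf"]
SNAPSHOT_TO_MODEL = {
    "pack_pf-cybersec-edr-poc_listening_ports_snapshot": "socket_anomaly_model.tf",
    "pack_pf-cybersec-edr-poc_users_snapshot": "users_anomaly_model.tf",
}

# One output port per (task_type, model) pair, in a fixed, reproducible order.
PORTS = list(product(TASK_TYPES, MODELS))
PORT_INDEX = {pair: index for index, pair in enumerate(PORTS)}


def target_port(payload: dict) -> int:
    """Return the output port index for a snapshot message."""
    # Messages produced before the new field existed default to inference.
    task_type = payload.get("task_type", "inference")
    # The target model is still inferred from the snapshot name.
    model_name = SNAPSHOT_TO_MODEL[payload["name"]]
    # An unknown task type or snapshot name raises KeyError instead of being silently dropped.
    return PORT_INDEX[(task_type, model_name)]


print(target_port({"name": "pack_pf-cybersec-edr-poc_users_snapshot"}))
print(target_port({"name": "pack_pf-cybersec-edr-poc_listening_ports_snapshot", "task_type": "training"}))
```

Computing the port index once per message in this way also avoids evaluating every branch's filter for every message.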
Glad to hear it's now working. I think the Broadcast solution is perfectly acceptable (it may just be a little inefficient, since every message is sent to every branch and then filtered). I would check back in after issue #1607 is completed and see if that works for your needs.
Do you have any other questions?
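The inefficiency of the broadcast approach comes from every message visiting all N branches and being checked N times. For comparison, here is a framework-free sketch that resolves the target once per message. `DirectRouter` and its queue-per-model layout are hypothetical; they are not MRC or Morpheus APIs and do not describe any planned Morpheus feature.

```python
from collections import deque

MODEL_NAMES = [
    "socket_anomaly_model.tf",
    "python_packages_anomaly_model.tf",
    "startup_items_anomaly_model.tf",
    "users_anomaly_model.tf",
]
SNAPSHOT_TO_MODEL = {
    "pack_pf-cybersec-edr-poc_listening_ports_snapshot": "socket_anomaly_model.tf",
    "pack_pf-cybersec-edr-poc_python_packages_snapshot": "python_packages_anomaly_model.tf",
    "pack_pf-cybersec-edr-poc_startup_items_snapshot": "startup_items_anomaly_model.tf",
    "pack_pf-cybersec-edr-poc_users_snapshot": "users_anomaly_model.tf",
}


class DirectRouter:
    """Hypothetical router: one lookup per message, one destination queue."""

    def __init__(self, model_names: list[str]):
        self.queues = {name: deque() for name in model_names}
        self.lookups = 0

    def push(self, payload: dict) -> str:
        # Resolve the target model once, instead of asking every branch to check it.
        self.lookups += 1
        model_name = SNAPSHOT_TO_MODEL[payload["name"]]
        self.queues[model_name].append(payload)
        return model_name


router = DirectRouter(MODEL_NAMES)
for snapshot in SNAPSHOT_TO_MODEL:
    router.push({"name": snapshot})

# With broadcast + filter, every branch evaluates every message.
broadcast_checks = len(SNAPSHOT_TO_MODEL) * len(MODEL_NAMES)
print(router.lookups, broadcast_checks)
```

For the four snapshots in this thread, direct routing performs 4 lookups where broadcast-and-filter performs 16 filter evaluations; the gap grows linearly with the number of models.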