From 6b6c331bd9fc54dff681c2fe95ff91da603e3aa3 Mon Sep 17 00:00:00 2001 From: Sergey Serebryakov Date: Thu, 24 Aug 2023 16:18:34 +0000 Subject: [PATCH 1/8] Docstrings and annotations for CmfQuery class. This commit adds doc strings and type annotation for CmfQuery class methods. --- cmflib/cmfquery.py | 840 +++++++++++++++++++++++++++------------------ 1 file changed, 501 insertions(+), 339 deletions(-) diff --git a/cmflib/cmfquery.py b/cmflib/cmfquery.py index 93d7e484..9c131ecd 100644 --- a/cmflib/cmfquery.py +++ b/cmflib/cmfquery.py @@ -14,414 +14,577 @@ # limitations under the License. ### +import json +import typing as t +import logging import pandas as pd from ml_metadata.metadata_store import metadata_store from ml_metadata.proto import metadata_store_pb2 as mlpb + from cmflib.mlmd_objects import CONTEXT_LIST -import json -import typing as t +__all__ = ["CmfQuery"] + +logger = logging.getLogger(__name__) + + +# TODO: Rename to CMF Client or just set CmfClient = CmfQuery class CmfQuery(object): - def __init__(self, filepath: str = "mlmd"): + """CMF Query (client) communicates with MLMD database and implements basic search and retrieval functionality. + + Args: + filepath: Path to the MLMD database file. + """ + def __init__(self, filepath: str = "mlmd") -> None: config = mlpb.ConnectionConfig() config.sqlite.filename_uri = filepath self.store = metadata_store.MetadataStore(config) - def _transform_to_dataframe(self, node): - #d = CmfQuery.__get_node_properties(node) - d = {"id": node.id} - d["name"] = getattr(node, "name", "") - for k, v in node.properties.items(): - if v.HasField('string_value'): - d[k] = v.string_value - elif v.HasField('int_value'): - d[k] = v.int_value - else: - d[k] = v.double_value - - for k, v in node.custom_properties.items(): - if v.HasField('string_value'): - d[k] = v.string_value - elif v.HasField('int_value'): - d[k] = v.int_value - else: - d[k] = v.double_value - - df = pd.DataFrame(d, index=[0, ]) + @staticmethod + def _copy(source: t.Mapping, target: t.Optional[t.Dict] = None, key_mapper: t.Optional[t.Dict] = None) -> t.Dict: + """Create copy of `_copy` and return use, reuse `target` if not None. + + Args: + source: Input dict-like object to create copy. + target: If not None, this will be reused and returned. If None, new dict will be created. + key_mapper: Dictionary containing how to map keys in `source`, e.g., {"key_in": "key_out"} means the + key in `source` named "key_in" should be renamed to "key_out" in output dictionary object. + Returns: + If `target` is not None, it is returned containing data from `source`. Else, new object is returned. + """ + target = target or {} + key_mapper = key_mapper or {} + + for key, value in source.items(): + if value.HasField("string_value"): + value = value.string_value + elif value.HasField("int_value"): + value = value.int_value + else: + value = value.double_value + + target[key_mapper.get(key, key)] = value + + return target + + @staticmethod + def _transform_to_dataframe(node, d: t.Optional[t.Dict] = None) -> pd.DataFrame: + """Transform MLMD entity `node` to pandas data frame. + + Args: + node: MLMD entity to transform. + d: Pre-populated dictionary of KV-pairs to associate with `node` (will become columns in output table). + Returns: + Pandas data frame with one row containing data from `node`. 
+ """ + if d is None: + d = {"id": node.id, "name": getattr(node, "name", "")} + + _ = CmfQuery._copy(node.properties, d) + _ = CmfQuery._copy(node.custom_properties, d) + + return pd.DataFrame( + d, + index=[ + 0, + ], + ) + + @staticmethod + def _as_pandas_df(elements: t.Iterable, transform_fn: t.Callable[[t.Any], pd.DataFrame]) -> pd.DataFrame: + """Convert elements in `elements` to rows in pandas data frame using `transform_fn` function. + + Args: + elements: Collection with items to be converted to tabular representation, each item becomes one row. + transform_fn: A callable object that takes one element in `elements` and returns its tabular representation + (pandas data frame with one row). + Returns: + Pandas data frame containing representation of elements in `elements` with one row being one element. + + TODO: maybe easier to initially transform elements to list of dicts? + """ + df = pd.DataFrame() + for element in elements: + df = pd.concat([df, transform_fn(element)], sort=True, ignore_index=True) return df - def get_pipeline_id(self, pipeline_name: str) -> int: - contexts = self.store.get_contexts_by_type("Parent_Context") - for ctx in contexts: - if ctx.name == pipeline_name: - return ctx.id - return -1 + def _get_pipelines(self, name: t.Optional[str] = None) -> t.List[mlpb.Context]: + """Return list of pipelines with the given name. + + Args: + name: Piepline name or None to return all pipelines. + Returns: + List of objects associated with pipelines. + """ + pipelines: t.List[mlpb.Context] = self.store.get_contexts_by_type("Parent_Context") + if name is not None: + pipelines = [pipeline for pipeline in pipelines if pipeline.name == name] + return pipelines + + def _get_pipeline(self, name: str) -> t.Optional[mlpb.Context]: + """Return a pipeline with the given name or None if one does not exist. + Args: + name: Pipeline name. + Returns: + A pipeline object if found, else None. + """ + pipelines: t.List = self._get_pipelines(name) + if pipelines: + if len(pipelines) >= 2: + logger.debug("Found %d pipelines with '%s' name.", len(pipelines), name) + return pipelines[0] + return None + + def _get_stages(self, pipeline_id: int) -> t.List[mlpb.Context]: + """Return stages for the given pipeline. + + Args: + pipeline_id: Pipeline ID. + Returns: + List of associated pipeline stages. + """ + return self.store.get_children_contexts_by_context(pipeline_id) + + def _get_executions(self, stage_id: int, execution_id: t.Optional[int] = None) -> t.List[mlpb.Execution]: + """Return executions of the given stage. + + Args: + stage_id: Stage identifier. + execution_id: If not None, return only execution with this ID. + Returns: + List of executions matching input parameters. + """ + executions: t.List[mlpb.Execution] = self.store.get_executions_by_context(stage_id) + if execution_id is not None: + executions = [execution for execution in executions if execution.id == execution_id] + return executions + def _get_executions_by_input_artifact_id(self, artifact_id: int) -> t.List[int]: + """Return stage executions that consumed given input artifact. - def get_pipeline_names(self) -> t.List: - names = [] - contexts = self.store.get_contexts_by_type("Parent_Context") - for ctx in contexts: - names.append(ctx.name) - return names + Args: + artifact_id: Identifier of the input artifact. + Returns: + List of stage executions that consumed the given artifact. 
+ """ + execution_ids = set( + event.execution_id + for event in self.store.get_events_by_artifact_ids([artifact_id]) + if event.type == mlpb.Event.INPUT + ) + return list(execution_ids) + + def _get_executions_by_output_artifact_id(self, artifact_id: int) -> t.List[int]: + """Return stage execution that produced given output artifact. + + Args: + artifact_id: Identifier of the output artifact. + Return: + List of stage executions, should probably be size of 1 or zero. + """ + execution_ids: t.List[int] = [ + event.execution_id + for event in self.store.get_events_by_artifact_ids([artifact_id]) + if event.type == mlpb.Event.OUTPUT + ] + if len(execution_ids) >= 2: + logger.warning("%d executions claim artifact (id=%d) as output.", len(execution_ids), artifact_id) + + return list(set(execution_ids)) + + def _get_artifact(self, name: str) -> t.Optional[mlpb.Artifact]: + """Return artifact with the given name or None. + + TODO: Different artifact types may have the same name (see `get_artifacts_by_type`, + `get_artifact_by_type_and_name`). + + Args: + name: Artifact name. + Returns: + Artifact or None (if not found). + """ + name = name.strip() + for artifact in self.store.get_artifacts(): + if artifact.name == name: + return artifact + return None + + def _get_output_artifacts(self, execution_ids: t.List[int]) -> t.List[int]: + """Return output artifacts for the given executions. + + Args: + execution_ids: List of execution identifiers to return output artifacts for. + Returns: + List of output artifact identifiers. + """ + artifact_ids: t.List[int] = [ + event.artifact_id + for event in self.store.get_events_by_execution_ids(set(execution_ids)) + if event.type == mlpb.Event.OUTPUT + ] + unique_artifact_ids = set(artifact_ids) + if len(unique_artifact_ids) != len(artifact_ids): + logger.warning("Multiple executions claim the same output artifacts") + + return list(unique_artifact_ids) + + def _get_input_artifacts(self, execution_ids: t.List[int]) -> t.List[int]: + """Return input artifacts for the given executions. + + Args: + execution_ids: List of execution identifiers to return input artifacts for. + Returns: + List of input artifact identifiers. + """ + artifact_ids = set( + event.artifact_id + for event in self.store.get_events_by_execution_ids(set(execution_ids)) + if event.type == mlpb.Event.INPUT + ) + return list(artifact_ids) - def get_pipeline_stages(self, pipeline_name: str) -> []: + def get_pipeline_names(self) -> t.List[str]: + """Return names of all pipelines in the MLMD database.""" + return [ctx.name for ctx in self._get_pipelines()] + + def get_pipeline_id(self, pipeline_name: str) -> int: + """Return pipeline identifier for the pipeline names `pipeline_name`. + Args: + pipeline_name: Name of the pipeline. + Returns: + Pipeline identifier or -1 if one does not exist. + """ + pipeline: t.Optional[mlpb.Context] = self._get_pipeline(pipeline_name) + return -1 if not pipeline else pipeline.id + + def get_pipeline_stages(self, pipeline_name: str) -> t.List[str]: + """Return list of pipeline stages for the pipeline with the given name. + + TODO: Can there be multiple pipelines with the same name? 
+ """ stages = [] - contexts = self.store.get_contexts_by_type("Parent_Context") - for ctx in contexts: - if ctx.name == pipeline_name: - child_contexts = self.store.get_children_contexts_by_context(ctx.id) - for cc in child_contexts: - stages.append(cc.name) + for pipeline in self._get_pipelines(pipeline_name): + stages.extend(stage.name for stage in self._get_stages(pipeline.id)) return stages - def get_all_exe_in_stage(self, stage_name: str) -> []: - df = pd.DataFrame() - contexts = self.store.get_contexts_by_type("Parent_Context") - executions = None - for ctx in contexts: - child_contexts = self.store.get_children_contexts_by_context(ctx.id) - for cc in child_contexts: - if cc.name == stage_name: - executions = self.store.get_executions_by_context(cc.id) - return executions + def get_all_exe_in_stage(self, stage_name: str) -> t.List[mlpb.Execution]: + """Return list of all executions for the stage with the given name. + TODO: Can stages from different pipelines have the same name?. Currently, the first matching stage is used to + identify its executions. Also see "get_all_executions_in_stage". + """ + for pipeline in self._get_pipelines(): + for stage in self._get_stages(pipeline.id): + if stage.name == stage_name: + return self.store.get_executions_by_context(stage.id) + return [] def get_all_executions_in_stage(self, stage_name: str) -> pd.DataFrame: - df = pd.DataFrame() - contexts = self.store.get_contexts_by_type("Parent_Context") - for ctx in contexts: - child_contexts = self.store.get_children_contexts_by_context(ctx.id) - for cc in child_contexts: - if cc.name == stage_name: - executions = self.store.get_executions_by_context(cc.id) - for exe in executions: - d1 = self._transform_to_dataframe(exe) - # df = df.append(d1, sort=True, ignore_index=True) - df = pd.concat([df, d1], sort=True, ignore_index=True) - - return df + """Return executions of the given stage as pandas data frame. - def get_artifact_df(self, node): - d = {"id": node.id, "type": self.store.get_artifact_types_by_id([node.type_id])[0].name, "uri": node.uri, - "name": node.name, "create_time_since_epoch": node.create_time_since_epoch, - "last_update_time_since_epoch": node.last_update_time_since_epoch} - for k, v in node.properties.items(): - if v.HasField('string_value'): - d[k] = v.string_value - elif v.HasField('int_value'): - d[k] = v.int_value - else: - d[k] = v.double_value - for k, v in node.custom_properties.items(): - if v.HasField('string_value'): - d[k] = v.string_value - elif v.HasField('int_value'): - d[k] = v.int_value - else: - d[k] = v.double_value - df = pd.DataFrame(d, index=[0, ]) + TODO: Multiple stages with the same name? This method collects executions from all such stages. 
Also, see + "get_all_exe_in_stage" + """ + df = pd.DataFrame() + for pipeline in self._get_pipelines(): + for stage in self._get_stages(pipeline.id): + if stage.name == stage_name: + for execution in self._get_executions(stage.id): + df = pd.concat([df, self._transform_to_dataframe(execution)], sort=True, ignore_index=True) return df - def get_all_artifacts(self) -> t.List: - artifact_list = [] - artifacts = self.store.get_artifacts() - for art in artifacts: - name = art.name - artifact_list.append(name) - return artifact_list - - def get_artifact(self, name: str): - artifact = None - artifacts = self.store.get_artifacts() - for art in artifacts: - if art.name == name: - artifact = art - break - return self.get_artifact_df(artifact) - - def get_all_artifacts_for_execution(self, execution_id: int) -> pd.DataFrame: # change here + def get_artifact_df(self, artifact: mlpb.Artifact, d: t.Optional[t.Dict] = None) -> pd.DataFrame: + """Return artifact's data frame representation. + + Args: + artifact: MLMD entity representing artifact. + d: Optional initial content for data frame. + Returns: + A data frame with the single row containing attributes of this artifact. + """ + d = d or {} + d.update( + { + "id": artifact.id, + "type": self.store.get_artifact_types_by_id([artifact.type_id])[0].name, + "uri": artifact.uri, + "name": artifact.name, + "create_time_since_epoch": artifact.create_time_since_epoch, + "last_update_time_since_epoch": artifact.last_update_time_since_epoch, + } + ) + return self._transform_to_dataframe(artifact, d) + + def get_all_artifacts(self) -> t.List[str]: + """Return names of all artifacts. + + TODO: Can multiple artifacts have the same name? + """ + return [artifact.name for artifact in self.store.get_artifacts()] + + get_artifact_names = get_all_artifacts + + def get_artifact(self, name: str) -> t.Optional[pd.DataFrame]: + """Return artifact's data frame representation using artifact name. + + Args: + name: artifact name. + Returns: + Pandas data frame with one row. + """ + artifact: t.Optional[mlpb.Artifact] = self._get_artifact(name) + if artifact: + return self.get_artifact_df(artifact) + return None + + def get_all_artifacts_for_execution(self, execution_id: int) -> pd.DataFrame: + """Return input and output artifacts for the given execution. + + Args: + execution_id: Execution identifier. + Return: + Data frame containing input and output artifacts for the given execution, one artifact per row. 
+ """ df = pd.DataFrame() - input_artifacts = [] - output_artifacts = [] - events = self.store.get_events_by_execution_ids([execution_id]) - for event in events: - if event.type == mlpb.Event.Type.INPUT: # 3 - INPUT #4 - Output - input_artifacts.extend(self.store.get_artifacts_by_id([event.artifact_id])) - else: - output_artifacts.extend(self.store.get_artifacts_by_id([event.artifact_id])) - for art in input_artifacts: - d1 = self.get_artifact_df(art) - d1["event"] = "INPUT" - #df = df.append(d1, sort=True, ignore_index=True) - df = pd.concat([df, d1], sort=True, ignore_index=True) - for art in output_artifacts: - d1 = self.get_artifact_df(art) - d1["event"] = "OUTPUT" - #df = df.append(d1, sort=True, ignore_index=True) - df = pd.concat([df, d1], sort=True, ignore_index=True) + for event in self.store.get_events_by_execution_ids([execution_id]): + event_type = "INPUT" if event.type == mlpb.Event.Type.INPUT else "OUTPUT" + for artifact in self.store.get_artifacts_by_id([event.artifact_id]): + df = pd.concat( + [df, self.get_artifact_df(artifact, {"event": event_type})], sort=True, ignore_index=True + ) return df def get_all_executions_for_artifact(self, artifact_name: str) -> pd.DataFrame: - selected_artifact = None - linked_execution = {} - events = [] - df = pd.DataFrame() - artifacts = self.store.get_artifacts() - for art in artifacts: - if art.name == artifact_name: - selected_artifact = art - break - if selected_artifact is not None: - events = self.store.get_events_by_artifact_ids([selected_artifact.id]) - - for evt in events: - linked_execution["Type"] = "INPUT" if evt.type == mlpb.Event.Type.INPUT else "OUTPUT" - linked_execution["execution_id"] = evt.execution_id - linked_execution["execution_name"] = self.store.get_executions_by_id([evt.execution_id])[0].name - ctx = self.store.get_contexts_by_execution(evt.execution_id)[0] - linked_execution["stage"] = self.store.get_contexts_by_execution(evt.execution_id)[0].name - - linked_execution["pipeline"] = self.store.get_parent_contexts_by_context(ctx.id)[0].name - d1 = pd.DataFrame(linked_execution, index=[0, ]) - #df = df.append(d1, sort=True, ignore_index=True) - df = pd.concat([df, d1], sort=True, ignore_index=True) + """Return executions that consumed and produced given artifact. - return df - - def get_one_hop_child_artifacts(self, artifact_name: str) -> pd.DataFrame: + Args: + artifact_name: Artifact name. + Returns: + Pandas data frame containing stage executions, one execution per row. 
+ """ df = pd.DataFrame() - artifact = None - artifacts = self.store.get_artifacts() - for art in artifacts: - if artifact_name.strip() == art.name: - artifact = art - break - # Get a list of artifacts within a 1-hop of the artifacts of interest - artifact_ids = [artifact.id] - executions_ids = set( - event.execution_id - for event in self.store.get_events_by_artifact_ids(artifact_ids) - if event.type == mlpb.Event.INPUT) - artifacts_ids = set( - event.artifact_id - for event in self.store.get_events_by_execution_ids(executions_ids) - if event.type == mlpb.Event.OUTPUT) - artifacts = self.store.get_artifacts_by_id(artifacts_ids) - for art in artifacts: - d1 = self.get_artifact_df(art) - #df = df.append(d1, sort=True, ignore_index=True) + artifact: t.Optional = self._get_artifact(artifact_name) + if not artifact: + return df + + for event in self.store.get_events_by_artifact_ids([artifact.id]): + ctx = self.store.get_contexts_by_execution(event.execution_id)[0] + linked_execution = { + "Type": "INPUT" if event.type == mlpb.Event.Type.INPUT else "OUTPUT", + "execution_id": event.execution_id, + "execution_name": self.store.get_executions_by_id([event.execution_id])[0].name, + "stage": self.store.get_contexts_by_execution(event.execution_id)[0].name, + "pipeline": self.store.get_parent_contexts_by_context(ctx.id)[0].name + } + d1 = pd.DataFrame( + linked_execution, + index=[ + 0, + ], + ) df = pd.concat([df, d1], sort=True, ignore_index=True) return df + def get_one_hop_child_artifacts(self, artifact_name: str) -> pd.DataFrame: + """Get artifacts produced by executions that consume given artifact. + + Args: + artifact name: Name of an artifact. + Return: + Output artifacts of all executions that consumed given artifact. + """ + artifact: t.Optional = self._get_artifact(artifact_name) + if not artifact: + return pd.DataFrame() + + # Get output artifacts of executions consumed the above artifact. + artifacts_ids = self._get_output_artifacts( + self._get_executions_by_input_artifact_id(artifact.id) + ) + + return self._as_pandas_df( + self.store.get_artifacts_by_id(artifacts_ids), + lambda _artifact: self.get_artifact_df(_artifact) + ) + def get_all_child_artifacts(self, artifact_name: str) -> pd.DataFrame: + """Return all downstream artifacts starting from the given artifact. + + TODO: Only output artifacts or all? 
+ """ df = pd.DataFrame() d1 = self.get_one_hop_child_artifacts(artifact_name) - #df = df.append(d1, sort=True, ignore_index=True) + # df = df.append(d1, sort=True, ignore_index=True) df = pd.concat([df, d1], sort=True, ignore_index=True) for row in d1.itertuples(): d1 = self.get_all_child_artifacts(row.name) - #df = df.append(d1, sort=True, ignore_index=True) + # df = df.append(d1, sort=True, ignore_index=True) df = pd.concat([df, d1], sort=True, ignore_index=True) - df = df.drop_duplicates(subset=None, keep='first', inplace=False) + df = df.drop_duplicates(subset=None, keep="first", inplace=False) return df def get_one_hop_parent_artifacts(self, artifact_name: str) -> pd.DataFrame: - df = pd.DataFrame() + """Return input artifacts for the execution that produced the given artifact.""" + artifact: t.Optional = self._get_artifact(artifact_name) + if not artifact: + return pd.DataFrame() - artifact = None - artifacts = self.store.get_artifacts() - for art in artifacts: - if artifact_name in art.name: - artifact = art - break - # Get a list of artifacts within a 1-hop of the artifacts of interest - artifact_ids = [artifact.id] - executions_ids = set( - event.execution_id - for event in self.store.get_events_by_artifact_ids(artifact_ids) - if event.type == mlpb.Event.OUTPUT) - artifacts_ids = set( - event.artifact_id - for event in self.store.get_events_by_execution_ids(executions_ids) - if event.type == mlpb.Event.INPUT) - artifacts = self.store.get_artifacts_by_id(artifacts_ids) - for art in artifacts: - d1 = self.get_artifact_df(art) - #df = df.append(d1, sort=True, ignore_index=True) - df = pd.concat([df, d1], sort=True, ignore_index=True) - return df + artifact_ids = self._get_input_artifacts( + self._get_executions_by_output_artifact_id(artifact.id) + ) + + return self._as_pandas_df( + self.store.get_artifacts_by_id(artifact_ids), + lambda _artifact: self.get_artifact_df(_artifact) + ) def get_all_parent_artifacts(self, artifact_name: str) -> pd.DataFrame: + """Return all upstream artifacts. + + TODO: All input and output artifacts? 
+ """ df = pd.DataFrame() d1 = self.get_one_hop_parent_artifacts(artifact_name) - #df = df.append(d1, sort=True, ignore_index=True) + # df = df.append(d1, sort=True, ignore_index=True) df = pd.concat([df, d1], sort=True, ignore_index=True) for row in d1.itertuples(): d1 = self.get_all_parent_artifacts(row.name) - #df = df.append(d1, sort=True, ignore_index=True) + # df = df.append(d1, sort=True, ignore_index=True) df = pd.concat([df, d1], sort=True, ignore_index=True) - df = df.drop_duplicates(subset=None, keep='first', inplace=False) + df = df.drop_duplicates(subset=None, keep="first", inplace=False) return df - def get_all_parent_executions(self, artifact_name:str)-> pd.DataFrame: - df = self.get_all_parent_artifacts(artifact_name) - artifact_ids = df.id.values.tolist() + def get_all_parent_executions(self, artifact_name: str) -> pd.DataFrame: + """Return all executions that produced upstream artifacts for the given artifact.""" + parent_artifacts = self.get_all_parent_artifacts(artifact_name) - executions_ids = set( + execution_ids = set( event.execution_id - for event in self.store.get_events_by_artifact_ids(artifact_ids) - if event.type == mlpb.Event.OUTPUT) - executions = self.store.get_executions_by_id(executions_ids) - - df = pd.DataFrame() - for exe in executions: - d1 = self._transform_to_dataframe(exe) - # df = df.append(d1, sort=True, ignore_index=True) - df = pd.concat([df, d1], sort=True, ignore_index=True) - return df - - def find_producer_execution(self, artifact_name: str) -> object: - artifact = None - artifacts = self.store.get_artifacts() - for art in artifacts: - if art.name == artifact_name: - artifact = art - break + for event in self.store.get_events_by_artifact_ids(parent_artifacts.id.values.tolist()) + if event.type == mlpb.Event.OUTPUT + ) + + return self._as_pandas_df( + self.store.get_executions_by_id(execution_ids), + lambda _execution: self._transform_to_dataframe(_execution) + ) + + def find_producer_execution(self, artifact_name: str) -> t.Optional[object]: + """Return execution that produced the given artifact. + + TODO: how come one artifact can have multiple producer executions? 
+ """ + artifact: t.Optional[mlpb.Artifact] = self._get_artifact(artifact_name) + if not artifact: + logger.debug("Artifact does not exist (name=%s).", artifact_name) + return None executions_ids = set( event.execution_id for event in self.store.get_events_by_artifact_ids([artifact.id]) - if event.type == mlpb.Event.OUTPUT) - return self.store.get_executions_by_id(executions_ids)[0] - - def get_metrics(self, metrics_name:str) ->pd.DataFrame: - metric = None - metrics = self.store.get_artifacts_by_type("Step_Metrics") - for m in metrics: - if m.name == metrics_name: - metric = m - break - if metric is None: - print("Error : The given metrics does not exist") + if event.type == mlpb.Event.OUTPUT + ) + if not executions_ids: + logger.debug( + "No producer execution exists for artifact (name=%s, id=%s).", artifact.name, artifact.id + ) return None - name = "" - for k, v in metric.custom_properties.items(): - if k == "Name": - name = v - break - df = pd.read_parquet(name) - return df - - def read_dataslice(self, name: str) -> pd.DataFrame: - """Reads the dataslice""" - # To do checkout if not there - df = pd.read_parquet(name) - return df - + executions: t.List[mlpb.Execution] = self.store.get_executions_by_id(executions_ids) + if not executions: + logger.debug("No executions exist for given IDs (ids=%s)", str(executions_ids)) + return None + if len(executions) >= 2: + logger.debug( + "Multiple executions (ids=%s) claim artifact (name=%s) as output.", + [e.id for e in executions], artifact.name + ) - @staticmethod - def __get_node_properties(node) -> dict: - # print(node) - node_dict = {} - for attr in dir(node): - if attr in CONTEXT_LIST: - if attr == "properties": - node_dict["properties"] = CmfQuery.__get_properties(node) - elif attr == "custom_properties": - node_dict["custom_properties"] = CmfQuery.__get_customproperties( - node - ) - else: - node_dict[attr] = getattr(node, attr, "") + return executions[0] - return node_dict + get_producer_execution = find_producer_execution - @staticmethod - def __get_properties(node) -> dict: - prop_dict = {} - for k, v in node.properties.items(): - if v.HasField("string_value"): - prop_dict[k] = v.string_value - elif v.HasField("int_value"): - prop_dict[k] = v.int_value - else: - prop_dict[k] = v.double_value - return prop_dict + def get_metrics(self, metrics_name: str) -> t.Optional[pd.DataFrame]: + """Return metric data frame. + + TODO: need better description. 
+ """ + for metric in self.store.get_artifacts_by_type("Step_Metrics"): + if metric.name == metrics_name: + name: t.Optional[str] = metric.custom_properties.get("Name", None) + if name: + return pd.read_parquet(name) + break + return None @staticmethod - def __get_customproperties(node) -> dict: - prop_dict = {} - for k, v in node.custom_properties.items(): - if k == "type": - k = "user_type" - if v.HasField("string_value"): - prop_dict[k] = v.string_value - elif v.HasField("int_value"): - prop_dict[k] = v.int_value - else: - prop_dict[k] = v.double_value - return prop_dict - - def dumptojson(self, pipeline_name: str, exec_id): - mlmd_json = {} - mlmd_json["Pipeline"] = [] - contexts = self.store.get_contexts_by_type("Parent_Context") - for ctx in contexts: - if ctx.name == pipeline_name: - ctx_dict = CmfQuery.__get_node_properties(ctx) - ctx_dict["stages"] = [] - stages = self.store.get_children_contexts_by_context(ctx.id) - for stage in stages: - stage_dict = CmfQuery.__get_node_properties(stage) - # ctx["stages"].append(stage_dict) - stage_dict["executions"] = [] - executions = self.store.get_executions_by_context(stage.id) - if exec_id is None: - list_executions = [exe for exe in executions] - elif exec_id is not None: - list_executions = [ - exe for exe in executions if exe.id == int(exec_id) - ] - else: - return "Invalid execution id given." - for exe in list_executions: - exe_dict = CmfQuery.__get_node_properties(exe) - exe_type = self.store.get_execution_types_by_id([exe.type_id]) - exe_dict["type"] = exe_type[0].name - exe_dict["events"] = [] - # name will be an empty string for executions that are created with - # create new execution as true(default) - # In other words name property will there only for execution - # that are created with create new execution flag set to false(special case) - exe_dict["name"] = exe.name if exe.name != "" else "" - events = self.store.get_events_by_execution_ids([exe.id]) - for evt in events: - evt_dict = CmfQuery.__get_node_properties(evt) - artifact = self.store.get_artifacts_by_id([evt.artifact_id]) - if artifact is not None: - artifact_type = self.store.get_artifact_types_by_id( - [artifact[0].type_id] - ) - artifact_dict = CmfQuery.__get_node_properties( - artifact[0] - ) - artifact_dict["type"] = artifact_type[0].name - evt_dict["artifact"] = artifact_dict - exe_dict["events"].append(evt_dict) - stage_dict["executions"].append(exe_dict) - ctx_dict["stages"].append(stage_dict) - mlmd_json["Pipeline"].append(ctx_dict) - json_str = json.dumps(mlmd_json) - # json_str = jsonpickle.encode(ctx_dict) - return json_str - - '''def materialize(self, artifact_name:str): + def read_dataslice(name: str) -> pd.DataFrame: + """Reads the data slice. + + TODO: Why is it here? + """ + # To do checkout if not there + df = pd.read_parquet(name) + return df + + def dumptojson(self, pipeline_name: str, exec_id: t.Optional[int] = None) -> t.Optional[str]: + """Return JSON-parsable string containing details about the given pipeline. + + TODO: Think if this method should return dict. 
+ """ + if exec_id is not None: + exec_id = int(exec_id) + + def _get_node_attributes(_node: t.Union[mlpb.Context, mlpb.Execution, mlpb.Event], _attrs: t.Dict) -> t.Dict: + for attr in CONTEXT_LIST: + if getattr(_node, attr, None) is not None: + _attrs[attr] = getattr(_node, attr) + + if "properties" in _attrs: + _attrs["properties"] = CmfQuery._copy(_attrs["properties"]) + if "custom_properties" in _attrs: + _attrs["custom_properties"] = CmfQuery._copy(_attrs["custom_properties"], + key_mapper={"type": "user_type"}) + return _attrs + + pipelines: t.List[t.Dict] = [] + for pipeline in self._get_pipelines(pipeline_name): + pipeline_attrs = _get_node_attributes(pipeline, {"stages": []}) + for stage in self._get_stages(pipeline.id): + stage_attrs = _get_node_attributes(stage, {"executions": []}) + for execution in self._get_executions(stage.id, execution_id=exec_id): + # name will be an empty string for executions that are created with + # create new execution as true(default) + # In other words name property will there only for execution + # that are created with create new execution flag set to false(special case) + exec_attrs = _get_node_attributes( + execution, + { + "type": self.store.get_execution_types_by_id([execution.type_id])[0].name, + "name": execution.name if execution.name != "" else "", + "events": [] + } + ) + for event in self.store.get_events_by_execution_ids([execution.id]): + event_attrs = _get_node_attributes(event, {"artifacts": []}) + for artifact in self.store.get_artifacts_by_id([event.artifact_id]): + artifact_attrs = _get_node_attributes( + artifact, + {"type": self.store.get_artifact_types_by_id([artifact.type_id])[0].name} + ) + event_attrs["artifacts"].append(artifact_attrs) + exec_attrs["events"].append(event_attrs) + stage_attrs["executions"].append(exec_attrs) + pipeline_attrs["stages"].append(stage_attrs) + pipelines.append(pipeline_attrs) + + return json.dumps({"Pipeline": pipelines}) + + """def materialize(self, artifact_name:str): artifacts = self.store.get_artifacts() for art in artifacts: if art.name == artifact_name: @@ -437,5 +600,4 @@ def dumptojson(self, pipeline_name: str, exec_id): elif (remote == "Remote"): remote = v - Cmf.materialize(path, git_repo, rev, remote)''' - + Cmf.materialize(path, git_repo, rev, remote)"""

From f5fc1a641895af75500f3adda227b6936813298b Mon Sep 17 00:00:00 2001
From: Sergey Serebryakov
Date: Sun, 10 Sep 2023 00:27:08 +0000
Subject: [PATCH 2/8] WIP updates.

- References to `client` in the source code (e.g., CmfClient instead of CmfQuery) are removed.
- Several bugs related to checking input dict parameters are fixed (e.g., `d = d or {}` where it should be `if d is None: d = {}`). Now, the correct object is returned when the input dict is just empty.
- Missing docstrings and type annotations are added.
- Additional checks and log messages are added.
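To illustrate the difference with a minimal, hypothetical helper (not code from this patch): `d = d or {}` replaces a caller-supplied empty dict with a brand-new one, so updates are lost to the caller, while the `is None` check keeps the caller's dict:

    def fill(d=None):
        d = d or {}  # bug: an empty dict passed by the caller is falsy and gets replaced
        d["key"] = "value"
        return d

    def fill_fixed(d=None):
        if d is None:  # fix: only create a new dict when no dict was passed at all
            d = {}
        d["key"] = "value"
        return d

    target = {}
    fill(target)        # target is still {} - the update went into a throwaway dict
    fill_fixed(target)  # target is now {"key": "value"}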
--- cmflib/cmfquery.py | 185 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 139 insertions(+), 46 deletions(-) diff --git a/cmflib/cmfquery.py b/cmflib/cmfquery.py index 9c131ecd..3efe28dd 100644 --- a/cmflib/cmfquery.py +++ b/cmflib/cmfquery.py @@ -15,8 +15,9 @@ ### import json -import typing as t import logging +import typing as t + import pandas as pd from ml_metadata.metadata_store import metadata_store from ml_metadata.proto import metadata_store_pb2 as mlpb @@ -28,15 +29,13 @@ logger = logging.getLogger(__name__) -# TODO: Rename to CMF Client or just set CmfClient = CmfQuery - - class CmfQuery(object): - """CMF Query (client) communicates with MLMD database and implements basic search and retrieval functionality. + """CMF Query communicates with the MLMD database and implements basic search and retrieval functionality. Args: filepath: Path to the MLMD database file. """ + def __init__(self, filepath: str = "mlmd") -> None: config = mlpb.ConnectionConfig() config.sqlite.filename_uri = filepath @@ -44,7 +43,7 @@ def __init__(self, filepath: str = "mlmd") -> None: @staticmethod def _copy(source: t.Mapping, target: t.Optional[t.Dict] = None, key_mapper: t.Optional[t.Dict] = None) -> t.Dict: - """Create copy of `_copy` and return use, reuse `target` if not None. + """Create copy of `source` and return it, reuse `target` if not None. Args: source: Input dict-like object to create copy. @@ -54,8 +53,10 @@ def _copy(source: t.Mapping, target: t.Optional[t.Dict] = None, key_mapper: t.Op Returns: If `target` is not None, it is returned containing data from `source`. Else, new object is returned. """ - target = target or {} - key_mapper = key_mapper or {} + if target is None: + target = {} + if key_mapper is None: + key_mapper = {} for key, value in source.items(): if value.HasField("string_value"): @@ -70,7 +71,9 @@ def _copy(source: t.Mapping, target: t.Optional[t.Dict] = None, key_mapper: t.Op return target @staticmethod - def _transform_to_dataframe(node, d: t.Optional[t.Dict] = None) -> pd.DataFrame: + def _transform_to_dataframe( + node: t.Union[mlpb.Execution, mlpb.Artifact], d: t.Optional[t.Dict] = None + ) -> pd.DataFrame: """Transform MLMD entity `node` to pandas data frame. Args: @@ -78,12 +81,52 @@ def _transform_to_dataframe(node, d: t.Optional[t.Dict] = None) -> pd.DataFrame: d: Pre-populated dictionary of KV-pairs to associate with `node` (will become columns in output table). Returns: Pandas data frame with one row containing data from `node`. + + TODO: (sergey) Overwriting `d` with key/values from `properties` and `custom_properties` is not safe, some + tests fail (`test_get_all_parent_executions`) - it happens to be the case that `id` gets overwritten. For + instance, this happens when artifact is of type Dataset, and custom_properties contain id equal to + `multi_nli`. Later, this `id` can be used to invoke other MLMD APIs and that fails. + + TODO: (sergey) Maybe add prefix for properties and custom_properties keys? 
""" if d is None: - d = {"id": node.id, "name": getattr(node, "name", "")} - - _ = CmfQuery._copy(node.properties, d) - _ = CmfQuery._copy(node.custom_properties, d) + d = {} + + keys_to_be_updated = set(d.keys()).intersection(node.properties.keys()) + if keys_to_be_updated: + logger.warning( + "Unsafe OP detected for node (type=%s, id=%i): existing node keys (%s) will be updated from " + "node's `properties`.", + type(node), + node.id, + keys_to_be_updated, + ) + if "id" in node.properties: + logger.warning( + "Unsafe OP detected for node (type=%s, id=%i): will update `id` from properties (value=%s)", + type(node), + node.id, + node.properties["id"], + ) + d = CmfQuery._copy(node.properties, d) + + keys_to_be_updated = set(d.keys()).intersection(node.properties.keys()) + if keys_to_be_updated: + logger.warning( + "Unsafe OP detected for node (type=%s, id=%i): existing node keys (%s) will be updated from " + "node's `custom_properties`.", + type(node), + node.id, + keys_to_be_updated, + ) + if "id" in node.custom_properties: + logger.warning( + "Unsafe OP detected for node (type=%s, id=%i): will update `id` from custom_properties (value=%s)", + type(node), + node.id, + node.properties["id"], + ) + d = CmfQuery._copy(node.custom_properties, d) return pd.DataFrame( d, @@ -117,6 +160,10 @@ def _get_pipelines(self, name: t.Optional[str] = None) -> t.List[mlpb.Context]: name: Piepline name or None to return all pipelines. Returns: List of objects associated with pipelines. + + TODO (sergey): Is `Parent_Context` value always used for pipelines? + TODO (sergey): Why `name` parameter when there's another method `_get_pipeline`? + TODO (sergey): Use `self.store.get_context_by_type_and_name` when name presents? """ pipelines: t.List[mlpb.Context] = self.store.get_contexts_by_type("Parent_Context") if name is not None: @@ -129,6 +176,8 @@ def _get_pipeline(self, name: str) -> t.Optional[mlpb.Context]: name: Pipeline name. Returns: A pipeline object if found, else None. + + TODO (sergey): Use `self.store.get_context_by_type_and_name` instead calling self._get_pipelines? """ pipelines: t.List = self._get_pipelines(name) if pipelines: @@ -197,9 +246,9 @@ def _get_executions_by_output_artifact_id(self, artifact_id: int) -> t.List[int] def _get_artifact(self, name: str) -> t.Optional[mlpb.Artifact]: """Return artifact with the given name or None. - TODO: Different artifact types may have the same name (see `get_artifacts_by_type`, + TODO: Different artifact types can have the same name (see `get_artifacts_by_type`, `get_artifact_by_type_and_name`). - + TODO: (sergey) Use `self.store.get_artifacts` with list_options (filter_query)? Args: name: Artifact name. Returns: @@ -218,6 +267,9 @@ def _get_output_artifacts(self, execution_ids: t.List[int]) -> t.List[int]: execution_ids: List of execution identifiers to return output artifacts for. Returns: List of output artifact identifiers. + + TODO: (sergey) The `test_get_one_hop_child_artifacts` prints the warning in this method (Multiple executions + claim the same output artifacts) """ artifact_ids: t.List[int] = [ event.artifact_id @@ -246,7 +298,11 @@ def _get_input_artifacts(self, execution_ids: t.List[int]) -> t.List[int]: return list(artifact_ids) def get_pipeline_names(self) -> t.List[str]: - """Return names of all pipelines in the MLMD database.""" + """Return names of all pipelines. + + Returns: + List of all pipeline names. 
+ """ return [ctx.name for ctx in self._get_pipelines()] def get_pipeline_id(self, pipeline_name: str) -> int: @@ -262,7 +318,13 @@ def get_pipeline_id(self, pipeline_name: str) -> int: def get_pipeline_stages(self, pipeline_name: str) -> t.List[str]: """Return list of pipeline stages for the pipeline with the given name. + Args: + pipeline_name: Name of the pipeline for which stages need to be returned. + Returns: + List of stage names associated with the given pipeline. + TODO: Can there be multiple pipelines with the same name? + TODO: (sergey) not clear from method name that this method returns stage names """ stages = [] for pipeline in self._get_pipelines(pipeline_name): @@ -272,6 +334,11 @@ def get_pipeline_stages(self, pipeline_name: str) -> t.List[str]: def get_all_exe_in_stage(self, stage_name: str) -> t.List[mlpb.Execution]: """Return list of all executions for the stage with the given name. + Args: + stage_name: Name of the stage. + Returns: + List of executions for the given stage. + TODO: Can stages from different pipelines have the same name?. Currently, the first matching stage is used to identify its executions. Also see "get_all_executions_in_stage". """ @@ -284,6 +351,11 @@ def get_all_exe_in_stage(self, stage_name: str) -> t.List[mlpb.Execution]: def get_all_executions_in_stage(self, stage_name: str) -> pd.DataFrame: """Return executions of the given stage as pandas data frame. + Args: + stage_name: Stage name. + Returns: + Data frame with all executions associated with the given stage. + TODO: Multiple stages with the same name? This method collects executions from all such stages. Also, see "get_all_exe_in_stage" """ @@ -292,7 +364,10 @@ def get_all_executions_in_stage(self, stage_name: str) -> pd.DataFrame: for stage in self._get_stages(pipeline.id): if stage.name == stage_name: for execution in self._get_executions(stage.id): - df = pd.concat([df, self._transform_to_dataframe(execution)], sort=True, ignore_index=True) + ex_as_df: pd.DataFrame = self._transform_to_dataframe( + execution, {"id": execution.id, "name": execution.name} + ) + df = pd.concat([df, ex_as_df], sort=True, ignore_index=True) return df def get_artifact_df(self, artifact: mlpb.Artifact, d: t.Optional[t.Dict] = None) -> pd.DataFrame: @@ -303,8 +378,12 @@ def get_artifact_df(self, artifact: mlpb.Artifact, d: t.Optional[t.Dict] = None) d: Optional initial content for data frame. Returns: A data frame with the single row containing attributes of this artifact. + + TODO: (sergey) there are no "public" methods that return `mlpb.Artifact`. + TODO: (sergey) what's the difference between this method and `get_artifact`? """ - d = d or {} + if d is None: + d = {} d.update( { "id": artifact.id, @@ -320,7 +399,11 @@ def get_artifact_df(self, artifact: mlpb.Artifact, d: t.Optional[t.Dict] = None) def get_all_artifacts(self) -> t.List[str]: """Return names of all artifacts. - TODO: Can multiple artifacts have the same name? + Returns: + List of all artifact names. + + TODO: (sergey) Can multiple artifacts have the same name? + TODO: (sergey) Maybe rename to get_artifact_names (to be consistent with `get_pipeline_names`)? """ return [artifact.name for artifact in self.store.get_artifacts()] @@ -330,9 +413,11 @@ def get_artifact(self, name: str) -> t.Optional[pd.DataFrame]: """Return artifact's data frame representation using artifact name. Args: - name: artifact name. + name: Artifact name. Returns: - Pandas data frame with one row. + Pandas data frame with one row containing attributes of this artifact. 
+ + TODO: (sergey) what's the difference between this method and `get_artifact_df`? """ artifact: t.Optional[mlpb.Artifact] = self._get_artifact(name) if artifact: @@ -346,6 +431,8 @@ def get_all_artifacts_for_execution(self, execution_id: int) -> pd.DataFrame: execution_id: Execution identifier. Return: Data frame containing input and output artifacts for the given execution, one artifact per row. + + TODO: (sergey) briefly describe in what cases an execution may not have any artifacts. """ df = pd.DataFrame() for event in self.store.get_events_by_execution_ids([execution_id]): @@ -363,6 +450,9 @@ def get_all_executions_for_artifact(self, artifact_name: str) -> pd.DataFrame: artifact_name: Artifact name. Returns: Pandas data frame containing stage executions, one execution per row. + + TODO: (sergey) build list of dicts and then convert to data frame - will be quicker. + TODO: (sergey) can multiple contexts (pipeline and stage) be associated with one execution? """ df = pd.DataFrame() @@ -371,13 +461,14 @@ def get_all_executions_for_artifact(self, artifact_name: str) -> pd.DataFrame: return df for event in self.store.get_events_by_artifact_ids([artifact.id]): + # TODO: (sergey) seems to be the same as stage below. What's this context for (stage or pipeline)? ctx = self.store.get_contexts_by_execution(event.execution_id)[0] linked_execution = { "Type": "INPUT" if event.type == mlpb.Event.Type.INPUT else "OUTPUT", "execution_id": event.execution_id, "execution_name": self.store.get_executions_by_id([event.execution_id])[0].name, "stage": self.store.get_contexts_by_execution(event.execution_id)[0].name, - "pipeline": self.store.get_parent_contexts_by_context(ctx.id)[0].name + "pipeline": self.store.get_parent_contexts_by_context(ctx.id)[0].name, } d1 = pd.DataFrame( linked_execution, @@ -401,18 +492,20 @@ def get_one_hop_child_artifacts(self, artifact_name: str) -> pd.DataFrame: return pd.DataFrame() # Get output artifacts of executions consumed the above artifact. - artifacts_ids = self._get_output_artifacts( - self._get_executions_by_input_artifact_id(artifact.id) - ) + artifacts_ids = self._get_output_artifacts(self._get_executions_by_input_artifact_id(artifact.id)) return self._as_pandas_df( - self.store.get_artifacts_by_id(artifacts_ids), - lambda _artifact: self.get_artifact_df(_artifact) + self.store.get_artifacts_by_id(artifacts_ids), lambda _artifact: self.get_artifact_df(_artifact) ) def get_all_child_artifacts(self, artifact_name: str) -> pd.DataFrame: """Return all downstream artifacts starting from the given artifact. + Args: + artifact_name: Artifact name. + Returns: + Data frame containing all child artifacts. + TODO: Only output artifacts or all? """ df = pd.DataFrame() @@ -427,18 +520,19 @@ def get_all_child_artifacts(self, artifact_name: str) -> pd.DataFrame: return df def get_one_hop_parent_artifacts(self, artifact_name: str) -> pd.DataFrame: - """Return input artifacts for the execution that produced the given artifact.""" + """Return input artifacts for the execution that produced the given artifact. 
+ + Args: + artifact_name + """ artifact: t.Optional = self._get_artifact(artifact_name) if not artifact: return pd.DataFrame() - artifact_ids = self._get_input_artifacts( - self._get_executions_by_output_artifact_id(artifact.id) - ) + artifact_ids: t.List[int] = self._get_input_artifacts(self._get_executions_by_output_artifact_id(artifact.id)) return self._as_pandas_df( - self.store.get_artifacts_by_id(artifact_ids), - lambda _artifact: self.get_artifact_df(_artifact) + self.store.get_artifacts_by_id(artifact_ids), lambda _artifact: self.get_artifact_df(_artifact) ) def get_all_parent_artifacts(self, artifact_name: str) -> pd.DataFrame: @@ -459,7 +553,7 @@ def get_all_parent_artifacts(self, artifact_name: str) -> pd.DataFrame: def get_all_parent_executions(self, artifact_name: str) -> pd.DataFrame: """Return all executions that produced upstream artifacts for the given artifact.""" - parent_artifacts = self.get_all_parent_artifacts(artifact_name) + parent_artifacts: pd.DataFrame = self.get_all_parent_artifacts(artifact_name) execution_ids = set( event.execution_id @@ -469,10 +563,10 @@ def get_all_parent_executions(self, artifact_name: str) -> pd.DataFrame: return self._as_pandas_df( self.store.get_executions_by_id(execution_ids), - lambda _execution: self._transform_to_dataframe(_execution) + lambda _exec: self._transform_to_dataframe(_exec, {"id": _exec.id, "name": _exec.name}), ) - def find_producer_execution(self, artifact_name: str) -> t.Optional[object]: + def find_producer_execution(self, artifact_name: str) -> t.Optional[mlpb.Execution]: """Return execution that produced the given artifact. TODO: how come one artifact can have multiple producer executions? @@ -488,9 +582,7 @@ def find_producer_execution(self, artifact_name: str) -> t.Optional[object]: if event.type == mlpb.Event.OUTPUT ) if not executions_ids: - logger.debug( - "No producer execution exists for artifact (name=%s, id=%s).", artifact.name, artifact.id - ) + logger.debug("No producer execution exists for artifact (name=%s, id=%s).", artifact.name, artifact.id) return None executions: t.List[mlpb.Execution] = self.store.get_executions_by_id(executions_ids) @@ -501,7 +593,8 @@ def find_producer_execution(self, artifact_name: str) -> t.Optional[object]: if len(executions) >= 2: logger.debug( "Multiple executions (ids=%s) claim artifact (name=%s) as output.", - [e.id for e in executions], artifact.name + [e.id for e in executions], + artifact.name, ) return executions[0] @@ -547,8 +640,9 @@ def _get_node_attributes(_node: t.Union[mlpb.Context, mlpb.Execution, mlpb.Event if "properties" in _attrs: _attrs["properties"] = CmfQuery._copy(_attrs["properties"]) if "custom_properties" in _attrs: - _attrs["custom_properties"] = CmfQuery._copy(_attrs["custom_properties"], - key_mapper={"type": "user_type"}) + _attrs["custom_properties"] = CmfQuery._copy( + _attrs["custom_properties"], key_mapper={"type": "user_type"} + ) return _attrs pipelines: t.List[t.Dict] = [] @@ -566,15 +660,14 @@ def _get_node_attributes(_node: t.Union[mlpb.Context, mlpb.Execution, mlpb.Event { "type": self.store.get_execution_types_by_id([execution.type_id])[0].name, "name": execution.name if execution.name != "" else "", - "events": [] - } + "events": [], + }, ) for event in self.store.get_events_by_execution_ids([execution.id]): event_attrs = _get_node_attributes(event, {"artifacts": []}) for artifact in self.store.get_artifacts_by_id([event.artifact_id]): artifact_attrs = _get_node_attributes( - artifact, - {"type": 
self.store.get_artifact_types_by_id([artifact.type_id])[0].name} + artifact, {"type": self.store.get_artifact_types_by_id([artifact.type_id])[0].name} ) event_attrs["artifacts"].append(artifact_attrs) exec_attrs["events"].append(event_attrs)

From df0e1eaa9feda41d69839dea8d9463e52d6aa92c Mon Sep 17 00:00:00 2001
From: Sergey Serebryakov
Date: Wed, 20 Sep 2023 22:12:47 +0000
Subject: [PATCH 3/8] Removing questions that have been answered.

Fixing one possible bug related to accessing a column in a data frame when this data frame is empty.
Adding key mapper classes to help map source to target keys when copying dictionaries.
---
 cmflib/cmfquery.py | 299 +++++++++++++++++++++++++++------------------
 1 file changed, 182 insertions(+), 117 deletions(-)

diff --git a/cmflib/cmfquery.py b/cmflib/cmfquery.py
index 3efe28dd..c422903a 100644
--- a/cmflib/cmfquery.py
+++ b/cmflib/cmfquery.py
@@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. ### - +import abc import json import logging import typing as t +from enum import Enum import pandas as pd from ml_metadata.metadata_store import metadata_store from ml_metadata.proto import metadata_store_pb2 as mlpb @@ -29,9 +30,88 @@ logger = logging.getLogger(__name__) +class _KeyMapper(abc.ABC): + """Map one key (string) to another key (string) using a predefined strategy. + Args: + on_collision: What to do if the mapped key already exists in the target dictionary. + """ + + class OnCollision(Enum): + """What to do when a mapped key exists in the target dictionary.""" + + DO_NOTHING = 0 + """Ignore the collision and overwrite the value associated with this key.""" + RESOLVE = 1 + """Resolve it by appending `_INDEX` where INDEX is the smallest positive integer that avoids this collision.""" + RAISE_ERROR = 2 + """Raise an exception.""" + + def __init__(self, on_collision: OnCollision = OnCollision.DO_NOTHING) -> None: + self.on_collision = on_collision + + def get(self, d: t.Mapping, key: t.Any) -> t.Any: + """Return new (mapped) key. + Args: + d: Dictionary to update with the mapped key. + key: Source key name. + Returns: + A mapped (target) key to be used with the `d` dictionary. + """ + new_key = self._get(key) + if new_key in d: + if self.on_collision == _KeyMapper.OnCollision.RAISE_ERROR: + raise ValueError(f"Mapped key ({key} -> {new_key}) already exists.") + elif self.on_collision == _KeyMapper.OnCollision.RESOLVE: + _base_key, index = new_key, 0 + while new_key in d: + index += 1 + new_key = f"{_base_key}_{index}" + return new_key + + @abc.abstractmethod + def _get(self, key: t.Any) -> t.Any: + """Map a source key to a target key. + Args: + key: Source key. + Returns: + Target key. + """ + raise NotImplementedError() + + +class _DictMapper(_KeyMapper): + """Use dictionaries to specify key mappings (source -> target).""" + + def __init__(self, mappings: t.Mapping, **kwargs) -> None: + super().__init__(**kwargs) + self.mappings = mappings + + def _get(self, key: t.Any) -> t.Any: + return self.mappings.get(key, key) + + +class _PrefixMapper(_KeyMapper): + """Prepend a constant prefix to produce a mapped key.""" + + def __init__(self, prefix: str, **kwargs) -> None: + super().__init__(**kwargs) + self.prefix = prefix + + def _get(self, key: t.Any) -> t.Any: + return self.prefix + key + + class CmfQuery(object): """CMF Query communicates with the MLMD database and implements basic search and retrieval functionality. + This class has been designed to work with the CMF framework.
CMF alters names of pipelines, stages and artifacts + in various ways. This means that actual names in the MLMD database will be different from those originally provided + by users via CMF API. When methods in this class accept `name` parameters, it is expected that values of these + parameters are fully-qualified names of respective entities. + + TODO: (sergey) need to provide concrete examples and detailed description on how to actually use methods of this + class correctly, e.g., how to determine these fully-qualified names. + Args: filepath: Path to the MLMD database file. """ @@ -42,21 +122,27 @@ def __init__(self, filepath: str = "mlmd") -> None: self.store = metadata_store.MetadataStore(config) @staticmethod - def _copy(source: t.Mapping, target: t.Optional[t.Dict] = None, key_mapper: t.Optional[t.Dict] = None) -> t.Dict: + def _copy( + source: t.Mapping, target: t.Optional[t.Dict] = None, key_mapper: t.Optional[t.Union[t.Dict, _KeyMapper]] = None + ) -> t.Dict: """Create copy of `source` and return it, reuse `target` if not None. Args: source: Input dict-like object to create copy. target: If not None, this will be reused and returned. If None, new dict will be created. key_mapper: Dictionary containing how to map keys in `source`, e.g., {"key_in": "key_out"} means the - key in `source` named "key_in" should be renamed to "key_out" in output dictionary object. + key in `source` named "key_in" should be renamed to "key_out" in output dictionary object, or + instance of _KeyMapper. Returns: If `target` is not None, it is returned containing data from `source`. Else, new object is returned. """ if target is None: target = {} if key_mapper is None: - key_mapper = {} + key_mapper = _DictMapper({}) + elif isinstance(key_mapper, dict): + key_mapper = _DictMapper(key_mapper) + assert isinstance(key_mapper, _KeyMapper), f"Invalid key_mapper type (type={type(key_mapper)})." for key, value in source.items(): if value.HasField("string_value"): @@ -66,7 +152,7 @@ def _copy(source: t.Mapping, target: t.Optional[t.Dict] = None, key_mapper: t.Op else: value = value.double_value - target[key_mapper.get(key, key)] = value + target[key_mapper.get(target, key)] = value return target @@ -81,52 +167,20 @@ def _transform_to_dataframe( d: Pre-populated dictionary of KV-pairs to associate with `node` (will become columns in output table). Returns: Pandas data frame with one row containing data from `node`. - - TODO: (sergey) Overwriting `d` with key/values from `properties` and `custom_properties` is not safe, some - tests fail (`test_get_all_parent_executions`) - it happens to be the case that `id` gets overwritten. For - instance, this happens when artifact is of type Dataset, and custom_properties contain id equal to - `multi_nli`. Later, this `id` can be used to invoke other MLMD APIs and that fails. - - TODO: (sergey) Maybe add prefix for properties and custom_properties keys? 
""" if d is None: d = {} - keys_to_be_updated = set(d.keys()).intersection(node.properties.keys()) - if keys_to_be_updated: - logger.warning( - "Unsafe OP detected for node (type=%s, id=%i): existing node keys (%s) will be updated from " - "node's `properties`.", - type(node), - node.id, - keys_to_be_updated, - ) - if "id" in node.properties: - logger.warning( - "Unsafe OP detected for node (type=%s, id=%i): will update `id` from properties (value=%s)", - type(node), - node.id, - node.properties["id"], - ) - d = CmfQuery._copy(node.properties, d) - - keys_to_be_updated = set(d.keys()).intersection(node.properties.keys()) - if keys_to_be_updated: - logger.warning( - "Unsafe OP detected for node (type=%s, id=%i): existing node keys (%s) will be updated from " - "node's `custom_properties`.", - type(node), - node.id, - keys_to_be_updated, - ) - if "id" in node.custom_properties: - logger.warning( - "Unsafe OP detected for node (type=%s, id=%i): will update `id` from custom_properties (value=%s)", - type(node), - node.id, - node.properties["id"], - ) - d = CmfQuery._copy(node.custom_properties, d) + d = CmfQuery._copy( + source=node.properties, + target=d, + key_mapper=_PrefixMapper("properties_", on_collision=_KeyMapper.OnCollision.RESOLVE), + ) + d = CmfQuery._copy( + source=node.custom_properties, + target=d, + key_mapper=_PrefixMapper("custom_properties_", on_collision=_KeyMapper.OnCollision.RESOLVE), + ) return pd.DataFrame( d, @@ -145,8 +199,6 @@ def _as_pandas_df(elements: t.Iterable, transform_fn: t.Callable[[t.Any], pd.Dat (pandas data frame with one row). Returns: Pandas data frame containing representation of elements in `elements` with one row being one element. - - TODO: maybe easier to initially transform elements to list of dicts? """ df = pd.DataFrame() for element in elements: @@ -154,18 +206,14 @@ def _as_pandas_df(elements: t.Iterable, transform_fn: t.Callable[[t.Any], pd.Dat return df def _get_pipelines(self, name: t.Optional[str] = None) -> t.List[mlpb.Context]: + pipelines: t.List[mlpb.Context] = self.store.get_contexts_by_type("Parent_Context") """Return list of pipelines with the given name. Args: name: Piepline name or None to return all pipelines. Returns: List of objects associated with pipelines. - - TODO (sergey): Is `Parent_Context` value always used for pipelines? - TODO (sergey): Why `name` parameter when there's another method `_get_pipeline`? - TODO (sergey): Use `self.store.get_context_by_type_and_name` when name presents? """ - pipelines: t.List[mlpb.Context] = self.store.get_contexts_by_type("Parent_Context") if name is not None: pipelines = [pipeline for pipeline in pipelines if pipeline.name == name] return pipelines @@ -176,8 +224,6 @@ def _get_pipeline(self, name: str) -> t.Optional[mlpb.Context]: name: Pipeline name. Returns: A pipeline object if found, else None. - - TODO (sergey): Use `self.store.get_context_by_type_and_name` instead calling self._get_pipelines? """ pipelines: t.List = self._get_pipelines(name) if pipelines: @@ -238,19 +284,17 @@ def _get_executions_by_output_artifact_id(self, artifact_id: int) -> t.List[int] for event in self.store.get_events_by_artifact_ids([artifact_id]) if event.type == mlpb.Event.OUTPUT ] - if len(execution_ids) >= 2: - logger.warning("%d executions claim artifact (id=%d) as output.", len(execution_ids), artifact_id) + # According to CMF, it's OK to have multiple executions that produce the same exact artifact. 
+ # if len(execution_ids) >= 2: + # logger.warning("%d executions claim artifact (id=%d) as output.", len(execution_ids), artifact_id) return list(set(execution_ids)) def _get_artifact(self, name: str) -> t.Optional[mlpb.Artifact]: """Return artifact with the given name or None. - - TODO: Different artifact types can have the same name (see `get_artifacts_by_type`, - `get_artifact_by_type_and_name`). - TODO: (sergey) Use `self.store.get_artifacts` with list_options (filter_query)? Args: - name: Artifact name. + name: Fully-qualified name (e.g., artifact hash is added to the name), so name collisions across different + artifact types are not issues here. Returns: Artifact or None (if not found). """ @@ -263,13 +307,14 @@ def _get_artifact(self, name: str) -> t.Optional[mlpb.Artifact]: def _get_output_artifacts(self, execution_ids: t.List[int]) -> t.List[int]: """Return output artifacts for the given executions. + Artifacts are uniquely identified by their hashes in CMF, and so, when executions produce the same exact file, + they will claim this artifact as an output artifact, and so same artifact can have multiple producer + executions. + Args: execution_ids: List of execution identifiers to return output artifacts for. Returns: List of output artifact identifiers. - - TODO: (sergey) The `test_get_one_hop_child_artifacts` prints the warning in this method (Multiple executions - claim the same output artifacts) """ artifact_ids: t.List[int] = [ event.artifact_id @@ -319,12 +364,10 @@ def get_pipeline_stages(self, pipeline_name: str) -> t.List[str]: """Return list of pipeline stages for the pipeline with the given name. Args: - pipeline_name: Name of the pipeline for which stages need to be returned. + pipeline_name: Name of the pipeline for which stages need to be returned. In CMF, there are no different + pipelines with the same name. Returns: List of stage names associated with the given pipeline. - - TODO: Can there be multiple pipelines with the same name? - TODO: (sergey) not clear from method name that this method returns stage names """ stages = [] for pipeline in self._get_pipelines(pipeline_name): @@ -335,12 +378,10 @@ def get_all_exe_in_stage(self, stage_name: str) -> t.List[mlpb.Execution]: """Return list of all executions for the stage with the given name. Args: - stage_name: Name of the stage. + stage_name: Name of the stage. Before stages are recorded in MLMD, they are modified (e.g., pipeline name + will become part of the stage name). So stage names from different pipelines will not collide. Returns: List of executions for the given stage. - - TODO: Can stages from different pipelines have the same name?. Currently, the first matching stage is used to - identify its executions. Also see "get_all_executions_in_stage". """ for pipeline in self._get_pipelines(): for stage in self._get_stages(pipeline.id): @@ -352,12 +393,9 @@ def get_all_executions_in_stage(self, stage_name: str) -> pd.DataFrame: """Return executions of the given stage as pandas data frame. Args: - stage_name: Stage name. + stage_name: Stage name. See doc strings for the prev method. Returns: Data frame with all executions associated with the given stage. - - TODO: Multiple stages with the same name? This method collects executions from all such stages. Also, see - "get_all_exe_in_stage" """ df = pd.DataFrame() for pipeline in self._get_pipelines(): @@ -378,9 +416,6 @@ def get_artifact_df(self, artifact: mlpb.Artifact, d: t.Optional[t.Dict] = None) d: Optional initial content for data frame. 
Returns: A data frame with the single row containing attributes of this artifact. - - TODO: (sergey) there are no "public" methods that return `mlpb.Artifact`. - TODO: (sergey) what's the difference between this method and `get_artifact`? """ if d is None: d = {} @@ -401,9 +436,6 @@ def get_all_artifacts(self) -> t.List[str]: Returns: List of all artifact names. - - TODO: (sergey) Can multiple artifacts have the same name? - TODO: (sergey) Maybe rename to get_artifact_names (to be consistent with `get_pipeline_names`)? """ return [artifact.name for artifact in self.store.get_artifacts()] @@ -416,8 +448,6 @@ def get_artifact(self, name: str) -> t.Optional[pd.DataFrame]: name: Artifact name. Returns: Pandas data frame with one row containing attributes of this artifact. - - TODO: (sergey) what's the difference between this method and `get_artifact_df`? """ artifact: t.Optional[mlpb.Artifact] = self._get_artifact(name) if artifact: @@ -431,8 +461,6 @@ def get_all_artifacts_for_execution(self, execution_id: int) -> pd.DataFrame: execution_id: Execution identifier. Return: Data frame containing input and output artifacts for the given execution, one artifact per row. - - TODO: (sergey) briefly describe in what cases an execution may not have any artifacts. """ df = pd.DataFrame() for event in self.store.get_events_by_execution_ids([execution_id]): @@ -450,9 +478,6 @@ def get_all_executions_for_artifact(self, artifact_name: str) -> pd.DataFrame: artifact_name: Artifact name. Returns: Pandas data frame containing stage executions, one execution per row. - - TODO: (sergey) build list of dicts and then convert to data frame - will be quicker. - TODO: (sergey) can multiple contexts (pipeline and stage) be associated with one execution? """ df = pd.DataFrame() @@ -461,14 +486,13 @@ def get_all_executions_for_artifact(self, artifact_name: str) -> pd.DataFrame: return df for event in self.store.get_events_by_artifact_ids([artifact.id]): - # TODO: (sergey) seems to be the same as stage below. What's this context for (stage or pipeline)? - ctx = self.store.get_contexts_by_execution(event.execution_id)[0] + stage_ctx = self.store.get_contexts_by_execution(event.execution_id)[0] linked_execution = { "Type": "INPUT" if event.type == mlpb.Event.Type.INPUT else "OUTPUT", "execution_id": event.execution_id, "execution_name": self.store.get_executions_by_id([event.execution_id])[0].name, - "stage": self.store.get_contexts_by_execution(event.execution_id)[0].name, - "pipeline": self.store.get_parent_contexts_by_context(ctx.id)[0].name, + "stage": stage_ctx.name, + "pipeline": self.store.get_parent_contexts_by_context(stage_ctx.id)[0].name, } d1 = pd.DataFrame( linked_execution, @@ -505,8 +529,6 @@ def get_all_child_artifacts(self, artifact_name: str) -> pd.DataFrame: artifact_name: Artifact name. Returns: Data frame containing all child artifacts. - - TODO: Only output artifacts or all? """ df = pd.DataFrame() d1 = self.get_one_hop_child_artifacts(artifact_name) @@ -520,11 +542,7 @@ def get_all_child_artifacts(self, artifact_name: str) -> pd.DataFrame: return df def get_one_hop_parent_artifacts(self, artifact_name: str) -> pd.DataFrame: - """Return input artifacts for the execution that produced the given artifact. 
- - Args: - artifact_name - """ + """Return input artifacts for the execution that produced the given artifact.""" artifact: t.Optional = self._get_artifact(artifact_name) if not artifact: return pd.DataFrame() @@ -536,10 +554,7 @@ def get_one_hop_parent_artifacts(self, artifact_name: str) -> pd.DataFrame: ) def get_all_parent_artifacts(self, artifact_name: str) -> pd.DataFrame: - """Return all upstream artifacts. - - TODO: All input and output artifacts? - """ + """Return all upstream artifacts.""" df = pd.DataFrame() d1 = self.get_one_hop_parent_artifacts(artifact_name) # df = df.append(d1, sort=True, ignore_index=True) @@ -554,6 +569,9 @@ def get_all_parent_artifacts(self, artifact_name: str) -> pd.DataFrame: def get_all_parent_executions(self, artifact_name: str) -> pd.DataFrame: """Return all executions that produced upstream artifacts for the given artifact.""" parent_artifacts: pd.DataFrame = self.get_all_parent_artifacts(artifact_name) + if parent_artifacts.shape[0] == 0: + # If it's empty, there's no `id` column and the code below raises an exception. + return pd.DataFrame() execution_ids = set( event.execution_id @@ -569,7 +587,8 @@ def get_all_parent_executions(self, artifact_name: str) -> pd.DataFrame: def find_producer_execution(self, artifact_name: str) -> t.Optional[mlpb.Execution]: """Return execution that produced the given artifact. - TODO: how come one artifact can have multiple producer executions? + One artifact can have multiple producer executions (names of artifacts are fully-qualified with hashes). So, + if two executions produced the same exact artifact, this one artifact will have multiple parent executions. """ artifact: t.Optional[mlpb.Artifact] = self._get_artifact(artifact_name) if not artifact: @@ -602,10 +621,7 @@ def find_producer_execution(self, artifact_name: str) -> t.Optional[mlpb.Executi get_producer_execution = find_producer_execution def get_metrics(self, metrics_name: str) -> t.Optional[pd.DataFrame]: - """Return metric data frame. - - TODO: need better description. - """ + """Return metric data frame.""" for metric in self.store.get_artifacts_by_type("Step_Metrics"): if metric.name == metrics_name: name: t.Optional[str] = metric.custom_properties.get("Name", None) @@ -616,18 +632,16 @@ def get_metrics(self, metrics_name: str) -> t.Optional[pd.DataFrame]: @staticmethod def read_dataslice(name: str) -> pd.DataFrame: - """Reads the data slice. - - TODO: Why is it here? - """ + """Reads the data slice.""" # To do checkout if not there df = pd.read_parquet(name) return df def dumptojson(self, pipeline_name: str, exec_id: t.Optional[int] = None) -> t.Optional[str]: """Return JSON-parsable string containing details about the given pipeline. - - TODO: Think if this method should return dict. + Args: + pipeline_name: Name of an AI pipelines. + exec_id: Optional stage execution ID - filter stages by this execution ID. """ if exec_id is not None: exec_id = int(exec_id) @@ -640,6 +654,7 @@ def _get_node_attributes(_node: t.Union[mlpb.Context, mlpb.Execution, mlpb.Event if "properties" in _attrs: _attrs["properties"] = CmfQuery._copy(_attrs["properties"]) if "custom_properties" in _attrs: + # TODO: (sergey) why do we need to rename "type" to "user_type" if we just copy into a new dictionary? 
_attrs["custom_properties"] = CmfQuery._copy( _attrs["custom_properties"], key_mapper={"type": "user_type"} ) @@ -694,3 +709,53 @@ def _get_node_attributes(_node: t.Union[mlpb.Context, mlpb.Execution, mlpb.Event remote = v Cmf.materialize(path, git_repo, rev, remote)""" + + +def test_on_collision() -> None: + from unittest import TestCase + + tc = TestCase() + + tc.assertEqual(3, len(_KeyMapper.OnCollision)) + tc.assertEqual(0, _KeyMapper.OnCollision.DO_NOTHING.value) + tc.assertEqual(1, _KeyMapper.OnCollision.RESOLVE.value) + tc.assertEqual(2, _KeyMapper.OnCollision.RAISE_ERROR.value) + + +def test_dict_mapper() -> None: + from unittest import TestCase + + tc = TestCase() + + dm = _DictMapper({"src_key": "tgt_key"}, on_collision=_KeyMapper.OnCollision.RESOLVE) + tc.assertEqual("tgt_key", dm.get({}, "src_key")) + tc.assertEqual("other_key", dm.get({}, "other_key")) + tc.assertEqual("existing_key_1", dm.get({"existing_key": "value"}, "existing_key")) + tc.assertEqual("existing_key_2", dm.get({"existing_key": "value", "existing_key_1": "value_1"}, "existing_key")) + + dm = _DictMapper({"src_key": "tgt_key"}, on_collision=_KeyMapper.OnCollision.DO_NOTHING) + tc.assertEqual("existing_key", dm.get({"existing_key": "value"}, "existing_key")) + + +def test_prefix_mapper() -> None: + from unittest import TestCase + + tc = TestCase() + + pm = _PrefixMapper("nested_", on_collision=_KeyMapper.OnCollision.RESOLVE) + tc.assertEqual("nested_src_key", pm.get({}, "src_key")) + + tc.assertEqual("nested_existing_key_1", pm.get({"nested_existing_key": "value"}, "existing_key")) + tc.assertEqual( + "nested_existing_key_2", + pm.get({"nested_existing_key": "value", "nested_existing_key_1": "value_1"}, "existing_key"), + ) + + dm = _PrefixMapper("nested_", on_collision=_KeyMapper.OnCollision.DO_NOTHING) + tc.assertEqual("nested_existing_key", dm.get({"nested_existing_key": "value"}, "existing_key")) + + +if __name__ == "__main__": + test_on_collision() + test_dict_mapper() + test_prefix_mapper() From 31ad9e99bba4bf6c730bfe259f45c586329a3ca7 Mon Sep 17 00:00:00 2001 From: Sergey Serebryakov Date: Tue, 3 Oct 2023 06:01:42 +0000 Subject: [PATCH 4/8] Graph-like API to traverse CMF metadata. The API implements graph-like API to traverse CMF metadata in a graph-like manner. The entry point is the `MetadataStore` class that retrieves from metadata store pipelines, stages, executions and artifacts. Users can specify search query to specify what they want to return (the query is basically the value for the `filter_query` parameter of the `ListOptions` class ML Metadata (MLMD) library). Each node mentioned above (pipeline, stages, executions and artifacts) havs its own Python wrapper class that provides developer-friendly API to access node's parameters and travers graph of machine learning concepts (e.g., get all stages of a pipeline or get all executions of a stage). The graph API also provides the `Properties` wrapper for MLMD's properties and custom_propertied nodes' fields. This wrapper implements the `Mapping` API and automatically converts MLMD's values to Python values on the fly. 
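
A usage sketch (illustrative only, not part of the patch; it assumes an MLMD database
file named "mlmd" produced by CMF is available in the working directory):

    from cmflib.contrib.graph_api import MetadataStore

    store = MetadataStore("mlmd")
    for pipeline in store.pipelines():
        for stage in pipeline.stages():
            for execution in stage.executions():
                # `inputs` and `outputs` are lists of Artifact wrappers.
                print(execution.name, [a.name for a in execution.inputs])

    # The same `filter_query` syntax MLMD supports can be used to narrow searches.
    models = store.artifacts(query="type = 'Model'")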
--- cmflib/contrib/graph_api.py | 402 ++++++++++++++++++++++++++++++++++++ 1 file changed, 402 insertions(+) create mode 100644 cmflib/contrib/graph_api.py diff --git a/cmflib/contrib/graph_api.py b/cmflib/contrib/graph_api.py new file mode 100644 index 00000000..871331d5 --- /dev/null +++ b/cmflib/contrib/graph_api.py @@ -0,0 +1,402 @@ +import typing as t +from typing import KT, Iterator, T_co, VT_co + +from cmfquery import CmfQuery +from ml_metadata import ListOptions +from ml_metadata.proto import metadata_store_pb2 as mlpb + +MlmdNode = t.Union[mlpb.Context, mlpb.Execution, mlpb.Artifact] +Node = t.TypeVar("Node", bound="Base") + + +_PIPELINE_CONTEXT_NAME = "Parent_Context" +"""Name of a context type for pipelines.""" + +_STAGE_CONTEXT_NAME = "Pipeline_Stage" +"""Name of a context type for pipeline stages.""" + + +class Properties(t.Mapping): + """Read-only wrapper around MessageMapContainer that converts values to python types on the fly. + This is used to represent `properties` and `custom_properties` of all MLMD nodes (pipelines, stages, executions + and artifacts). + """ + + def __init__(self) -> None: + self._properties: t.Optional[t.Mapping] = None + """This is really google.protobuf.pyext._message.MessageMapContainer that inherits from MutableMapping""" + + def __str__(self) -> str: + return str({k: v for k, v in self.items()}) + + def __iter__(self) -> Iterator[T_co]: + for k in self._properties.keys(): + yield k + + def __len__(self) -> int: + return len(self._properties) + + def __getitem__(self, __key: KT) -> VT_co: + return get_python_value(self._properties[__key]) + + +class Base: + """Base class for wrappers that provide user-friendly API for MLMD's pipelines, stages, executions and artifacts. + + Instance of child classes are not supposed to be directly created by users, so class members are "protected". + """ + + def __init__(self) -> None: + self._db: t.Optional[CmfQuery] = None + """Data access layer for MLMD.""" + + self._node: t.Optional[MlmdNode] = None + """Reference to an entity in MLMD that this class wraps.""" + + def __str__(self) -> str: + return ( + f"{self.__class__.__name__}(id={self.id}, name={self.name}, properties={self.properties}, " + f"custom_properties={self.custom_properties})" + ) + + @property + def id(self) -> int: + return self._node.id + + @property + def name(self) -> str: + return self._node.name + + @property + def properties(self) -> Properties: + _properties = Properties() + _properties._properties = self._node.properties + return _properties + + @property + def custom_properties(self) -> Properties: + _properties = Properties() + _properties._properties = self._node.custom_properties + return _properties + + # @property + # def type_id(self) -> int: + # return self._node.type_id + + # @property + # def type(self) -> str: + # return self._node.type + + # @property + # def external_id(self) -> id: + # return self._node.external_id + + @classmethod + def _create(cls, db: CmfQuery, node: MlmdNode, attrs: t.Optional[t.Dict] = None) -> Node: + """Create class instance (users are not supposed to call this method by themselves). + Args: + db: Data access layer. + node: MLMD's node. + attrs: Optional attributes to set on newly created class instance (with `setattr` function). + Returns: + Instance of one of child classes. 
+ """ + obj = cls() + obj._db = db + obj._node = node + if attrs: + for name, value in attrs.items(): + setattr(obj, name, value) + return obj + + @classmethod + def _unique(cls, nodes: t.List[MlmdNode]) -> t.List[MlmdNode]: + """Return unique input elements in the input list using the `id` attribute as a unique key. + Args: + nodes: List of input elements. + Returns: + New list containing unique elements in `nodes`. Duplicates are identified using the `id` attribute. + """ + ids = set(node.id for node in nodes) + return [node for node in nodes if node.id in ids] + + @classmethod + def _one(cls, nodes: t.List[t.Any]) -> t.Any: + """Ensure input list contains exactly one element and return it. + Args: + nodes: List of input elements. + Returns: + First element in the list. + Raises: + ValueError error if length of `nodes` is not 1. + """ + if len(nodes) != 1: + raise ValueError(f"List (len={len(nodes)}) expected to contain one element.") + return nodes[0] + + +class Pipeline(Base): + """Class that represents AI pipelines.""" + + def __init__(self) -> None: + super().__init__() + + def stages(self, query: t.Optional[str] = None) -> t.List["Stage"]: + """Return list of all stages in this pipeline.""" + _mandatory_query = f"parent_contexts_a.name = '{self.name}'" + stage_contexts: t.List[mlpb.Context] = self._db._get_stages(pipeline_id=self.id) + return [Stage._create(self._db, ctx, {"_pipeline": self}) for ctx in stage_contexts] + + def executions(self) -> t.List["Execution"]: + """Return list of all executions in this pipeline""" + executions: t.List[Execution] = [] + for stage in self.stages(): + executions.extend(stage.executions()) + return executions + + def artifacts(self) -> t.List["Artifact"]: + """Return list of all unique artifacts consumed and produced by this pipeline.""" + artifacts: t.List[Artifact] = [] + for execution in self.executions(): + artifacts.extend(execution.inputs) + artifacts.extend(execution.outputs) + return self._unique(artifacts) + + +class Stage(Base): + """Class that represents pipeline stages.""" + + def __init__(self) -> None: + super().__init__() + + self._pipeline: t.Optional[Pipeline] = None + """Parent pipeline (lazily initialized).""" + + @property + def pipeline(self) -> Pipeline: + """Return parent pipeline.""" + if self._pipeline is None: + pipeline_context: mlpb.Context = self._one( + self._db.store.get_parent_contexts_by_context(context_id=self.id) + ) + self._pipeline = Pipeline._create(self._db, pipeline_context) + return self._pipeline + + def executions(self) -> t.List["Execution"]: + """Return list of all executions for this stage.""" + executions: t.List[mlpb.Execution] = self._db._get_executions(stage_id=self.id) + return [Execution._create(self._db, execution, {"_stage": self}) for execution in executions] + + def artifacts(self) -> t.List["Artifact"]: + """Return list of unique artifacts consumed and produced by this pipeline.""" + artifacts: t.List[Artifact] = [] + for execution in self.executions(): + artifacts.extend(execution.inputs) + artifacts.extend(execution.outputs) + return self._unique(artifacts) + + +class Execution(Base): + """Class that represents stage executions.""" + + def __init__(self) -> None: + super().__init__() + + self._stage: t.Optional[Stage] = None + """Stage for this execution (lazily initialized).""" + + @property + def stage(self) -> Stage: + if self._stage is None: + stage_context: mlpb.Context = self._one(self._db.store.get_contexts_by_execution(execution_id=self.id)) + self._stage = Stage._create(self._db, 
stage_context) + return self._stage + + @property + def inputs(self) -> t.List["Artifact"]: + """Return list of unique input artifacts for this execution.""" + artifacts: t.List[mlpb.Artifact] = self._unique( + self._db.store.self.store.get_artifacts_by_id(self._db._get_input_artifacts([self.id])) + ) + return [Artifact._create(self._db, artifact) for artifact in artifacts] + + @property + def outputs(self) -> t.List["Artifact"]: + """Return list of unique output artifacts for this execution.""" + artifacts: t.List[mlpb.Artifact] = self._unique( + self._db.store.self.store.get_artifacts_by_id(self._db._get_output_artifacts([self.id])) + ) + return [Artifact._create(self._db, artifact) for artifact in artifacts] + + +class Artifact(Base): + """Class that represents artifacts.""" + + def __init__(self) -> None: + super().__init__() + + self._consumed_by: t.Optional[t.List[Execution]] = None + """List of unique executions that consumed this artifact (lazily initialized).""" + + self._produced_by: t.Optional[t.List[Execution]] = None + """List of unique executions that produced this artifact (lazily initialized).""" + + @property + def uri(self) -> str: + return self._node.uri + + @property + def consumed_by(self) -> t.List[Execution]: + """Return all executions that have consumed this artifact.""" + if self._consumed_by is None: + executions: t.List[mlpb.Execution] = self._unique( + self._db.store.get_executions_by_id(self._db._get_executions_by_input_artifact_id(artifact_id=self.id)) + ) + self._consumed_by = [Execution._create(self._db, execution) for execution in executions] + return self._consumed_by + + @property + def produced_by(self) -> t.List[Execution]: + """Return all executions that have produced this artifact""" + if self._produced_by is None: + executions: t.List[mlpb.Execution] = self._unique( + self._db.store.get_executions_by_id(self._db._get_executions_by_output_artifact_id(artifact_id=self.id)) + ) + self._produced_by = [Execution._create(self._db, execution) for execution in executions] + return self._produced_by + + +class MetadataStore: + """`Entry point` for traversing the MLMD database using graph-like API. + + Many methods in this class support the `query` string argument. This is the same as the `filter_query` field in + the `ListOptions`, which is an input argument of several methods in MLMD that retrieve nodes of various types. The + not-so-detailed description of what this string can look like can be found here: + https://github.com/google/ml-metadata/blob/master/ml_metadata/proto/metadata_store.proto + Many MLMD node types have same attributes, such as `id`, `name`, `properties`, `custom_properties` and others. The + following functionality has been tested. + Common node attributes: + "id = 1113", "id != 1113", "id IN (1113, 1)", "name = 'text-generation'", "name != 'text-generation'", + "name LIKE '%-generation'" + + Args: + file_path: Path to an MLMD database. + + """ + + def __init__(self, file_path: str) -> None: + self._db = CmfQuery(filepath=file_path) + """Data access layer for MLMD.""" + + def _check_context_types( + self, + contexts: t.List[mlpb.Context], + type_name: str, + ): + ctx_type: mlpb.ContextType = self._db.store.get_context_type(type_name) + for context in contexts: + if context.type_id != ctx_type.id: + raise ValueError( + f"MLMD query returned contexts of the wrong type (actual_type_id={context.type_id}, " + f"expected_type_id={ctx_type.id}). " + f"Did you forget to specify the type as part of the query (type = '{type_name}')?" 
+ ) + + def _search_contexts( + self, ctx_type_name: str, ctx_wrapper_cls: t.Type[t.Union["Pipeline", "Stage"]], query: t.Optional[str] = None + ) -> t.Union[t.List["Pipeline"], t.List["Stage"]]: + if query is None: + query = f"type = '{ctx_type_name}'" + contexts: t.List[mlpb.Context] = self._db.store.get_contexts(self.list_options(query)) + self._check_context_types(contexts, ctx_type_name) + return [ctx_wrapper_cls._create(self._db, ctx) for ctx in contexts] + + def pipelines(self, query: t.Optional[str] = None) -> t.List["Pipeline"]: + """Retrieve pipelines. + Pipelines are represented as contexts in MLMD with type attributed equal to `Parent_Context`. + Args: + query: The `filter_query` field for the `ListOptions` instance. See class doc strings for examples. + Raises: + ValueError when query is present, and results in contexts that are not pipelines. + Known limitations: + When query is not None, it must define type (type = 'Parent_Context'), when ID is present this may not be + required though. + No filtering is supported by stage ID (no big deal). + Query examples: + Filtering by basic node attributes. + See class foc strings. + Filter by stage attributes (child_contexts_a is the stage context). There is a bug that prevents using the + ID for filtering by stage ID. It's fixed in MLMD version 1.14.0 (CMF uses the earlier version). + "child_contexts_a.name LIKE 'text-generation/%'" + """ + return self._search_contexts(_PIPELINE_CONTEXT_NAME, Pipeline, query) + + def stages(self, query: t.Optional[str] = None) -> t.List["Stage"]: + """Retrieve stages. + Stages are represented as contexts in MLMD with type attributed equal to `Pipeline_Stage`. + Args: + query: The `filter_query` field for the `ListOptions` instance. See class doc strings for examples. + Raises: + ValueError when query is present, and results in contexts that are not stages. + Known limitations: + When query is not None, it must define type (type = 'Stage_Context'), when ID is present this may not be + required though. + No filtering is supported by pipeline ID (pretty significant feature). + Query examples: + Filtering by basic node attributes + See class foc strings. + Filter by pipeline attributes (parent_contexts_a is the pipeline context). There is a bug that prevents + using the ID for filtering by pipeline ID. It's fixed in MLMD version 1.14.0 (CMF uses the earlier version). + "parent_contexts_a.name = 'text-classification'", "parent_contexts_a.name LIKE '%-generation'", + "parent_contexts_a.type = 'Parent_Context'" + """ + return self._search_contexts(_STAGE_CONTEXT_NAME, Stage, query) + + def executions(self) -> t.List["Execution"]: + """Retrieve stage executions. + See `pipelines` method for more details. 
+ """ + executions: t.List[mlpb.Execution] = self._db.store.get_executions() + return [Execution._create(self._db, execution) for execution in executions] + + def artifacts(self, query: t.Optional[str] = None) -> t.List["Artifact"]: + """ + # Find artifact with this ID + id = 1315 + # Find artifacts with a particular ArtifactType + type = 'Model' + type = 'Dataset' + # Find artifacts using pattern matching + name LIKE 'models/%' + name LIKE '%falcon%' + name LIKE 'datasets/%' + # Search using properties and custom_properties: + properties.url.string_value LIKE '%a655dead548f56fe3409321b3569a3%' + properties.pipeline_tag.string_value = 'text-classification' + custom_properties.pipeline_tag.string_value = 'text-classification' + """ + artifacts: t.List[mlpb.Artifact] = self._db.store.get_artifacts(self.list_options(query)) + return [Artifact._create(self._db, artifact) for artifact in artifacts] + + @staticmethod + def list_options(query: t.Optional[str] = None) -> t.Optional[ListOptions]: + list_options: t.Optional[ListOptions] = None + if query: + list_options = ListOptions(filter_query=query) + return list_options + + +def get_python_value(value: mlpb.Value) -> t.Union[str, int, float]: + """Convert MLMD value to a python value. + Args: + value: MLMD value. + Returns: + Python value. + """ + if value.HasField("string_value"): + return value.string_value + elif value.HasField("int_value"): + return value.int_value + elif value.HasField("double_value"): + return value.double_value + raise NotImplementedError("Only string, int and double fields are supported.") From 30119a4e466b3f8e371dabf971869c7aa06868ef Mon Sep 17 00:00:00 2001 From: Sergey Serebryakov Date: Mon, 16 Oct 2023 01:34:53 +0000 Subject: [PATCH 5/8] Refactoring implementation and unit tests. - Adding unit tests (97% coverage). - Renaming certain classes, redesigning implementation of several methods. - Adding `Type` class that represents concepts such as ContextType, ExecutionType and ArtifactType in MLMD library. --- cmflib/contrib/graph_api.py | 601 +++++++++++++++++++++------------ test/contrib/test_graph_api.py | 317 +++++++++++++++++ 2 files changed, 708 insertions(+), 210 deletions(-) create mode 100644 test/contrib/test_graph_api.py diff --git a/cmflib/contrib/graph_api.py b/cmflib/contrib/graph_api.py index 871331d5..76bd861e 100644 --- a/cmflib/contrib/graph_api.py +++ b/cmflib/contrib/graph_api.py @@ -1,25 +1,86 @@ +### +# Copyright (2023) Hewlett Packard Enterprise Development LP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +### + +""" +Introduction: + This module implements a read-only graph-like API to ML metadata database. It provides high level wrappers such as + `Pipeline`, `Stage`, `Execution` and `Artifact` that provide user-friendly interface to MLMD concepts (`Context`, + `Execution` and `Artifact`). + +Warning: + The information provided by classes in this module will become outdated if underlying MLMD database is modified + at the same time. 
+ + The mechanism to model relations implemented in this module is not directly following MLMD database scheme. For + instance, users can go directly from artifacts to executions using methods implemented in the `Artifact` class + not using the concept of events from MLMD library. + + For convenience purposes, this class provides two wrappers for `Context` concept in MLMD - `Pipeline` and `Stage`. + + Users should not create node wrappers (`Pipeline`, `Stage`, `Execution` and `Artifact`) by themselves. Instead, + users should use `MlmdGraph` class to query nodes, and then traverse graph using nodes' APIs. + +How to get started: + Users should always start exploring data in MLMD database with the `MlmdGraph` class implemented in this module. +""" + import typing as t -from typing import KT, Iterator, T_co, VT_co +from typing import Iterator, TypeVar from cmfquery import CmfQuery -from ml_metadata import ListOptions -from ml_metadata.proto import metadata_store_pb2 as mlpb +from ml_metadata import MetadataStore +from ml_metadata.proto import metadata_store_pb2 + +__all__ = [ + "MlmdNode", + "MlmdType", + "Node", + "Properties", + "Type", + "Base", + "Pipeline", + "Stage", + "Execution", + "Artifact", + "MlmdGraph", + "unique", + "one", +] + +MlmdNode = t.Union[metadata_store_pb2.Context, metadata_store_pb2.Execution, metadata_store_pb2.Artifact] +"""Type for all nodes in MLMD library.""" + +MlmdType = t.Union[metadata_store_pb2.ContextType, metadata_store_pb2.ExecutionType, metadata_store_pb2.ArtifactType] +"""Type for all node types in MLMD library.""" -MlmdNode = t.Union[mlpb.Context, mlpb.Execution, mlpb.Artifact] Node = t.TypeVar("Node", bound="Base") +"""Type for all nodes implemented in this module.""" -_PIPELINE_CONTEXT_NAME = "Parent_Context" -"""Name of a context type for pipelines.""" - -_STAGE_CONTEXT_NAME = "Pipeline_Stage" -"""Name of a context type for pipeline stages.""" +# These are not for exports in `typing` module, so defining them here (only needed for `Properties` class). +KT = TypeVar("KT") +T_co = TypeVar("T_co", covariant=True) +VT_co = TypeVar("VT_co", covariant=True) class Properties(t.Mapping): - """Read-only wrapper around MessageMapContainer that converts values to python types on the fly. + """Read-only wrapper around MessageMapContainer from MLMD that converts values to python types on the fly. + This is used to represent `properties` and `custom_properties` of all MLMD nodes (pipelines, stages, executions - and artifacts). + and artifacts). Users should not create instances of this class directly. """ def __init__(self) -> None: @@ -37,21 +98,151 @@ def __len__(self) -> int: return len(self._properties) def __getitem__(self, __key: KT) -> VT_co: - return get_python_value(self._properties[__key]) + return _python_value(self._properties[__key]) + + +class Type: + """Semantic type description for all graph nodes (pipelines, stages, executions and artifacts). 
+ + Information source: + https://github.com/google/ml-metadata/blob/master/ml_metadata/proto/metadata_store.proto + The following types and associated methods should be used in implementations: + | Nodes | MlmdNodes | Metadata store function | + |-----------------|---------------|---------------------------| + | Pipeline, Stage | ContextType | get_context_types_by_id | + | Execution | ExecutionType | get_execution_types_by_id | + | Artifact | ArtifactType | get_artifact_types_by_id | + All types seem to share many common attributes: + id:int64 name:str version:str description:str external_id:str properties:map + ContextType + + + + + + + ExecutionType + + + + + + + ArtifactType + + + + + + + + Users should not create instances of this class directly - use the `Type.get` method instead. This class maintains + internal mapping from a type ID to instances of this class. This means there is one instance exists for each MLMD + type. + """ + + # The values below come from the CMF core library. They are standard type names that CMF uses to differentiate + # between contexts (pipelines and stages), and natively supported artifacts (datasets, models and metrics) - + # concepts exposed via CMF public API. + + # No fields are defined for execution and generic artifact types because CMF does not enforce standard + # names at this level. One aspect to keep in mind that MLMD uses the `type_kind` attribute of a type to + # differentiate between context, execution and artifact types (that may be stored in one table in a relational + # backend database). + + # PIPELINE and STAGE are context types. + + PIPELINE = "Parent_Context" + """Name of a context type (ContextType) for pipelines.""" + + STAGE = "Pipeline_Stage" + """Name of a context type (ContextType) for pipeline stages.""" + + # DATASET, MODEL and METRICS are artifact types + + DATASET = "Dataset" + """Type name (ArtifactType) for `dataset` artifacts.""" + + MODEL = "Model" + """Type name (ArtifactType) for `model` artifacts.""" + + METRICS = "Metrics" + """Type name ArtifactType for execution metrics artifact.""" + + def __init__(self) -> None: + self._type: t.Optional[MlmdType] = None + + @property + def id(self) -> int: + return self._type.id + + @property + def name(self) -> str: + return self._type.name + + @property + def version(self) -> str: + return self._type.version + + @property + def description(self) -> str: + return self._type.description + + @property + def external_id(self) -> str: + return self._type.external_id + + @property + def properties(self) -> Properties: + _properties = Properties() + _properties._properties = self._type.properties + return _properties + + _types: t.Dict[int, "Type"] = {} + """A mapping from type ID to type description.""" + + @staticmethod + def get(store: MetadataStore, node: MlmdNode) -> "Type": + """Factory method to return the type description. + + Every type (identified by type ID) has only one instance, so getting types should be relatively lightweight + operation (database is accessed only once for each type ID). + + Args: + store: Metadata store object from MLMD library. + node: A node to return type for. + Returns: + Instance of this class that wraps one of MLMD classes (ContextType, ExecutionType or ArtifactType). 
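+
+        Example (illustrative sketch, not from the original code; `graph` is assumed to be an `MlmdGraph`
+        instance backed by a database that contains at least one pipeline):
+
+            pipeline = graph.pipelines()[0]
+            # Pipelines are stored as contexts of type `Parent_Context`.
+            assert pipeline.type.name == Type.PIPELINE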
+ """ + if node.type_id not in Type._types: + if isinstance(node, metadata_store_pb2.Context): + node_type, get_type_fn = metadata_store_pb2.ContextType, store.get_context_types_by_id + elif isinstance(node, metadata_store_pb2.Execution): + node_type, get_type_fn = metadata_store_pb2.ExecutionType, store.get_execution_types_by_id + elif isinstance(node, metadata_store_pb2.Artifact): + node_type, get_type_fn = metadata_store_pb2.ArtifactType, store.get_artifact_types_by_id + else: + raise NotImplementedError(f"No type description for MLMD node (node={node}).") + + type_ = Type() + type_._type = one( + get_type_fn([node.type_id]), + error=ValueError(f"Broken MLMD database. Expecting exactly one type for type_id = {node.type_id}."), + ) + Type._types[node.type_id] = type_ + return Type._types[node.type_id] class Base: - """Base class for wrappers that provide user-friendly API for MLMD's pipelines, stages, executions and artifacts. + """Base class for node wrappers providing user-friendly API for pipelines, stages, executions and artifacts in + MLMD database. - Instance of child classes are not supposed to be directly created by users, so class members are "protected". + Instance of child classes are not supposed to be directly created by users. This wrapper does not expose `type` + and `type_id` fields as defined in MLMD. Instead, users should use the `type` property in this class to access + type information (calling this method should be a lightweight operation). """ def __init__(self) -> None: self._db: t.Optional[CmfQuery] = None - """Data access layer for MLMD.""" + """Data access layer for MLMD database.""" self._node: t.Optional[MlmdNode] = None - """Reference to an entity in MLMD that this class wraps.""" + """Reference to an entity in MLMD database that this class wraps.""" + + def __hash__(self): + """Compute hash. + TODO: Is type_id enough to differentiate between pipeline and stage contexts? + """ + assert self._node is not None, "Internal error: self._node is None in Base.__hash__." + # I can't use `self._node.__class__.__name__` since Pipelines and Stages have the same MLMD class `Context`. + return hash((self.__class__.__name__, self._node.type_id, self.id)) + + def __eq__(self, other: Node) -> bool: + assert self._node is not None, "Internal error: self._node is None in Base.__eq__." + assert other._node is not None, "Internal error: other._node is None in Base.__eq__." + return isinstance(self, type(other)) and self._node.type_id == other._node.type_id and self.id == other.id def __str__(self) -> str: return ( @@ -79,73 +270,31 @@ def custom_properties(self) -> Properties: _properties._properties = self._node.custom_properties return _properties - # @property - # def type_id(self) -> int: - # return self._node.type_id - - # @property - # def type(self) -> str: - # return self._node.type - - # @property - # def external_id(self) -> id: - # return self._node.external_id + @property + def external_id(self) -> str: + return self._node.external_id - @classmethod - def _create(cls, db: CmfQuery, node: MlmdNode, attrs: t.Optional[t.Dict] = None) -> Node: - """Create class instance (users are not supposed to call this method by themselves). - Args: - db: Data access layer. - node: MLMD's node. - attrs: Optional attributes to set on newly created class instance (with `setattr` function). - Returns: - Instance of one of child classes. 
- """ - obj = cls() - obj._db = db - obj._node = node - if attrs: - for name, value in attrs.items(): - setattr(obj, name, value) - return obj - - @classmethod - def _unique(cls, nodes: t.List[MlmdNode]) -> t.List[MlmdNode]: - """Return unique input elements in the input list using the `id` attribute as a unique key. - Args: - nodes: List of input elements. - Returns: - New list containing unique elements in `nodes`. Duplicates are identified using the `id` attribute. - """ - ids = set(node.id for node in nodes) - return [node for node in nodes if node.id in ids] + # The type_id:int and type:str are not provided, use `type` instead to get full type description. - @classmethod - def _one(cls, nodes: t.List[t.Any]) -> t.Any: - """Ensure input list contains exactly one element and return it. - Args: - nodes: List of input elements. - Returns: - First element in the list. - Raises: - ValueError error if length of `nodes` is not 1. - """ - if len(nodes) != 1: - raise ValueError(f"List (len={len(nodes)}) expected to contain one element.") - return nodes[0] + @property + def type(self) -> Type: + return Type.get(self._db.store, self._node) class Pipeline(Base): - """Class that represents AI pipelines.""" + """Class that represents AI pipelines by wrapping the `Context` concept in MLMD. + + Users should not create instances of this class - use `MlmdGraph` class instead or other node wrappers. + """ def __init__(self) -> None: super().__init__() - def stages(self, query: t.Optional[str] = None) -> t.List["Stage"]: + def stages(self) -> t.List["Stage"]: """Return list of all stages in this pipeline.""" - _mandatory_query = f"parent_contexts_a.name = '{self.name}'" - stage_contexts: t.List[mlpb.Context] = self._db._get_stages(pipeline_id=self.id) - return [Stage._create(self._db, ctx, {"_pipeline": self}) for ctx in stage_contexts] + # noinspection PyProtectedMember + stage_contexts: t.List[metadata_store_pb2.Context] = self._db._get_stages(pipeline_id=self.id) + return [_graph_node(Stage, self._db, ctx, {"_pipeline": self}) for ctx in stage_contexts] def executions(self) -> t.List["Execution"]: """Return list of all executions in this pipeline""" @@ -155,16 +304,19 @@ def executions(self) -> t.List["Execution"]: return executions def artifacts(self) -> t.List["Artifact"]: - """Return list of all unique artifacts consumed and produced by this pipeline.""" + """Return list of all unique artifacts consumed and produced by executions of this pipeline.""" artifacts: t.List[Artifact] = [] for execution in self.executions(): - artifacts.extend(execution.inputs) - artifacts.extend(execution.outputs) - return self._unique(artifacts) + artifacts.extend(execution.inputs()) + artifacts.extend(execution.outputs()) + return unique(artifacts, "id") class Stage(Base): - """Class that represents pipeline stages.""" + """Class that represents pipeline stages by wrapping the `Context` concept in MLMD. + + Users should not create instances of this class - use `MlmdGraph` class instead or other node wrappers. 
+ """ def __init__(self) -> None: super().__init__() @@ -176,28 +328,31 @@ def __init__(self) -> None: def pipeline(self) -> Pipeline: """Return parent pipeline.""" if self._pipeline is None: - pipeline_context: mlpb.Context = self._one( - self._db.store.get_parent_contexts_by_context(context_id=self.id) + self._pipeline = _graph_node( + Pipeline, self._db, one(self._db.store.get_parent_contexts_by_context(context_id=self.id)) ) - self._pipeline = Pipeline._create(self._db, pipeline_context) return self._pipeline def executions(self) -> t.List["Execution"]: - """Return list of all executions for this stage.""" - executions: t.List[mlpb.Execution] = self._db._get_executions(stage_id=self.id) - return [Execution._create(self._db, execution, {"_stage": self}) for execution in executions] + """Return list of all executions of this stage.""" + # noinspection PyProtectedMember + executions: t.List[metadata_store_pb2.Execution] = self._db._get_executions(stage_id=self.id) + return [_graph_node(Execution, self._db, execution, {"_stage": self}) for execution in executions] def artifacts(self) -> t.List["Artifact"]: - """Return list of unique artifacts consumed and produced by this pipeline.""" + """Return list of unique artifacts consumed and produced by executions of this stage.""" artifacts: t.List[Artifact] = [] for execution in self.executions(): - artifacts.extend(execution.inputs) - artifacts.extend(execution.outputs) - return self._unique(artifacts) + artifacts.extend(execution.inputs()) + artifacts.extend(execution.outputs()) + return unique(artifacts, "id") class Execution(Base): - """Class that represents stage executions.""" + """Class that represents stage executions wrapping the `Execution` concept in MLMD. + + Users should not create instances of this class - use `MlmdGraph` class instead or other node wrappers. 
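+
+    Example (illustrative sketch; `graph` is assumed to be an `MlmdGraph` instance with at least one execution):
+
+        execution = graph.executions()[0]
+        # Navigate up to the stage and pipeline, and down to the consumed and produced artifacts.
+        print(execution.stage.pipeline.name, len(execution.inputs()), len(execution.outputs()))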
+ """ def __init__(self) -> None: super().__init__() @@ -207,30 +362,37 @@ def __init__(self) -> None: @property def stage(self) -> Stage: + """Return stage of this execution.""" if self._stage is None: - stage_context: mlpb.Context = self._one(self._db.store.get_contexts_by_execution(execution_id=self.id)) - self._stage = Stage._create(self._db, stage_context) + self._stage = _graph_node( + Stage, self._db, one(self._db.store.get_contexts_by_execution(execution_id=self.id)) + ) return self._stage - @property def inputs(self) -> t.List["Artifact"]: """Return list of unique input artifacts for this execution.""" - artifacts: t.List[mlpb.Artifact] = self._unique( - self._db.store.self.store.get_artifacts_by_id(self._db._get_input_artifacts([self.id])) + # noinspection PyProtectedMember + artifacts: t.List[metadata_store_pb2.Artifact] = unique( + self._db.store.get_artifacts_by_id(self._db._get_input_artifacts([self.id])), key="id" ) - return [Artifact._create(self._db, artifact) for artifact in artifacts] + return [_graph_node(Artifact, self._db, artifact) for artifact in artifacts] - @property def outputs(self) -> t.List["Artifact"]: """Return list of unique output artifacts for this execution.""" - artifacts: t.List[mlpb.Artifact] = self._unique( - self._db.store.self.store.get_artifacts_by_id(self._db._get_output_artifacts([self.id])) + # noinspection PyProtectedMember + artifacts: t.List[metadata_store_pb2.Artifact] = unique( + self._db.store.get_artifacts_by_id(self._db._get_output_artifacts([self.id])), key="id" ) - return [Artifact._create(self._db, artifact) for artifact in artifacts] + return [_graph_node(Artifact, self._db, artifact) for artifact in artifacts] class Artifact(Base): - """Class that represents artifacts.""" + """Class that represents artifacts in MLMD by wrapping the `Artifact` concept in MLMD. + + Users should not create instances of this class - use `MlmdGraph` class instead or other node wrappers. + + TODO (sergey) Need to brainstorm the idea of providing artifact-specific classes derived from this class. + """ def __init__(self) -> None: super().__init__() @@ -245,30 +407,50 @@ def __init__(self) -> None: def uri(self) -> str: return self._node.uri - @property def consumed_by(self) -> t.List[Execution]: - """Return all executions that have consumed this artifact.""" + """Return all executions that have consumed this artifact. + + Users must not modify the returned list. + """ if self._consumed_by is None: - executions: t.List[mlpb.Execution] = self._unique( - self._db.store.get_executions_by_id(self._db._get_executions_by_input_artifact_id(artifact_id=self.id)) + # noinspection PyProtectedMember + executions: t.List[metadata_store_pb2.Execution] = unique( + self._db.store.get_executions_by_id(self._db._get_executions_by_input_artifact_id(artifact_id=self.id)), + key="id", ) - self._consumed_by = [Execution._create(self._db, execution) for execution in executions] + self._consumed_by = [_graph_node(Execution, self._db, execution) for execution in executions] return self._consumed_by - @property def produced_by(self) -> t.List[Execution]: - """Return all executions that have produced this artifact""" + """Return all executions that have produced this artifact. + + Users must not modify the returned list. How come one artifact is produced by multiple executions? The CMF + uses hashes of artifacts to determine the artifacts' uniqueness. 
If multiple executions have happened to produce + artifacts with the same content (so, hashes are the same), then there will be one record of this in MLMD + database. + """ if self._produced_by is None: - executions: t.List[mlpb.Execution] = self._unique( - self._db.store.get_executions_by_id(self._db._get_executions_by_output_artifact_id(artifact_id=self.id)) + # noinspection PyProtectedMember + executions: t.List[metadata_store_pb2.Execution] = unique( + self._db.store.get_executions_by_id( + self._db._get_executions_by_output_artifact_id(artifact_id=self.id) + ), + key="id", ) - self._produced_by = [Execution._create(self._db, execution) for execution in executions] + self._produced_by = [_graph_node(Execution, self._db, execution) for execution in executions] return self._produced_by -class MetadataStore: +class MlmdGraph: """`Entry point` for traversing the MLMD database using graph-like API. + [Opinionated]. Graph libraries typically implement a `Graph` class that provides at least two methods to iterate + over nodes and edges. In addition, other methods are also implemented that compute various graph characteristics. + At the moment of implementing this initial version, it is a bit confusing to differentiate between various MLMD + node kinds (contexts, executions and artifacts) and CMF (semantic) types (pipelines, stages, executions, models, + artifacts, datasets, models and metrics). This is the reason for not implementing the `nodes` method. Instead, + multiple methods are implemented to return CMF nodes. + Many methods in this class support the `query` string argument. This is the same as the `filter_query` field in the `ListOptions`, which is an input argument of several methods in MLMD that retrieve nodes of various types. The not-so-detailed description of what this string can look like can be found here: @@ -288,115 +470,114 @@ def __init__(self, file_path: str) -> None: self._db = CmfQuery(filepath=file_path) """Data access layer for MLMD.""" - def _check_context_types( - self, - contexts: t.List[mlpb.Context], - type_name: str, - ): - ctx_type: mlpb.ContextType = self._db.store.get_context_type(type_name) - for context in contexts: - if context.type_id != ctx_type.id: - raise ValueError( - f"MLMD query returned contexts of the wrong type (actual_type_id={context.type_id}, " - f"expected_type_id={ctx_type.id}). " - f"Did you forget to specify the type as part of the query (type = '{type_name}')?" - ) - - def _search_contexts( - self, ctx_type_name: str, ctx_wrapper_cls: t.Type[t.Union["Pipeline", "Stage"]], query: t.Optional[str] = None - ) -> t.Union[t.List["Pipeline"], t.List["Stage"]]: - if query is None: - query = f"type = '{ctx_type_name}'" - contexts: t.List[mlpb.Context] = self._db.store.get_contexts(self.list_options(query)) - self._check_context_types(contexts, ctx_type_name) - return [ctx_wrapper_cls._create(self._db, ctx) for ctx in contexts] - - def pipelines(self, query: t.Optional[str] = None) -> t.List["Pipeline"]: - """Retrieve pipelines. - Pipelines are represented as contexts in MLMD with type attributed equal to `Parent_Context`. - Args: - query: The `filter_query` field for the `ListOptions` instance. See class doc strings for examples. - Raises: - ValueError when query is present, and results in contexts that are not pipelines. - Known limitations: - When query is not None, it must define type (type = 'Parent_Context'), when ID is present this may not be - required though. - No filtering is supported by stage ID (no big deal). 
- Query examples: - Filtering by basic node attributes. - See class foc strings. - Filter by stage attributes (child_contexts_a is the stage context). There is a bug that prevents using the - ID for filtering by stage ID. It's fixed in MLMD version 1.14.0 (CMF uses the earlier version). - "child_contexts_a.name LIKE 'text-generation/%'" - """ - return self._search_contexts(_PIPELINE_CONTEXT_NAME, Pipeline, query) + def pipelines(self) -> t.List["Pipeline"]: + """Return all pipelines.""" + pipelines: t.List[metadata_store_pb2.Context] = self._db.store.get_contexts_by_type(Type.PIPELINE) + return [_graph_node(Pipeline, self._db, pipeline) for pipeline in pipelines] - def stages(self, query: t.Optional[str] = None) -> t.List["Stage"]: - """Retrieve stages. - Stages are represented as contexts in MLMD with type attributed equal to `Pipeline_Stage`. - Args: - query: The `filter_query` field for the `ListOptions` instance. See class doc strings for examples. - Raises: - ValueError when query is present, and results in contexts that are not stages. - Known limitations: - When query is not None, it must define type (type = 'Stage_Context'), when ID is present this may not be - required though. - No filtering is supported by pipeline ID (pretty significant feature). - Query examples: - Filtering by basic node attributes - See class foc strings. - Filter by pipeline attributes (parent_contexts_a is the pipeline context). There is a bug that prevents - using the ID for filtering by pipeline ID. It's fixed in MLMD version 1.14.0 (CMF uses the earlier version). - "parent_contexts_a.name = 'text-classification'", "parent_contexts_a.name LIKE '%-generation'", - "parent_contexts_a.type = 'Parent_Context'" - """ - return self._search_contexts(_STAGE_CONTEXT_NAME, Stage, query) + def stages(self) -> t.List["Stage"]: + """Return all stages.""" + stages: t.List[metadata_store_pb2.Context] = self._db.store.get_contexts_by_type(Type.STAGE) + return [_graph_node(Stage, self._db, stage) for stage in stages] def executions(self) -> t.List["Execution"]: - """Retrieve stage executions. - See `pipelines` method for more details. 
- """ - executions: t.List[mlpb.Execution] = self._db.store.get_executions() - return [Execution._create(self._db, execution) for execution in executions] + """Return all stage executions.""" + executions: t.List[metadata_store_pb2.Execution] = self._db.store.get_executions() + return [_graph_node(Execution, self._db, execution) for execution in executions] - def artifacts(self, query: t.Optional[str] = None) -> t.List["Artifact"]: - """ - # Find artifact with this ID - id = 1315 - # Find artifacts with a particular ArtifactType - type = 'Model' - type = 'Dataset' - # Find artifacts using pattern matching - name LIKE 'models/%' - name LIKE '%falcon%' - name LIKE 'datasets/%' - # Search using properties and custom_properties: - properties.url.string_value LIKE '%a655dead548f56fe3409321b3569a3%' - properties.pipeline_tag.string_value = 'text-classification' - custom_properties.pipeline_tag.string_value = 'text-classification' - """ - artifacts: t.List[mlpb.Artifact] = self._db.store.get_artifacts(self.list_options(query)) - return [Artifact._create(self._db, artifact) for artifact in artifacts] + def artifacts(self) -> t.List["Artifact"]: + """Return all artifacts.""" + artifacts: t.List[metadata_store_pb2.Artifact] = self._db.store.get_artifacts() + return [_graph_node(Artifact, self._db, artifact) for artifact in artifacts] - @staticmethod - def list_options(query: t.Optional[str] = None) -> t.Optional[ListOptions]: - list_options: t.Optional[ListOptions] = None - if query: - list_options = ListOptions(filter_query=query) - return list_options +def unique(items: t.List, key: t.Optional[t.Union[str, t.Callable]] = None) -> t.List: + """Return unique input items in the input list, possible using `key` to determine uniqueness of items. + + Always maintains order of items when `key` is not none. When key is none, order is maintained starting python 3.6. + + Args: + items: List of input items. + key: Attribute name or a function that computes `item` key. The `operator.attrgetter` is used + when type is string (so nested structured are supported, e.g., "type.id"). If it is None, items are + considered to be keys. + Returns: + New list containing unique items according to specified key. + """ + if key is None: + return list(dict.fromkeys(items)) + if isinstance(key, str): + from operator import attrgetter -def get_python_value(value: mlpb.Value) -> t.Union[str, int, float]: + key = attrgetter(key) + + keys, unique_items = set(), [] + for item in items: + item_key = key(item) + if item_key not in keys: + keys.add(item_key) + unique_items.append(item) + return unique_items + + +def one(items: t.List[t.Any], return_none_if_empty: bool = False, error: t.Any = None) -> t.Any: + """Return the only element in the input list. + + Args: + items: List of input items. + return_none_if_empty: If true, return None when `items` is empty, if false - raise `error`. + error: This is thrown when `items` contains wrong number of items. + Returns: + The only item in the list or None if list is empty and return_none_if_empty is true. + Raises: + ValueError error if length of `nodes` is not 1 when `error` is None or `error`. + """ + if not items and return_none_if_empty: + return None + if len(items) != 1: + if error is None: + error = ValueError(f"List (len={len(items)}) expected to contain one element.") + raise error + return items[0] + + +def _python_value(value: metadata_store_pb2.Value) -> t.Union[str, int, float]: """Convert MLMD value to a python value. Args: value: MLMD value. Returns: Python value. 
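+
+    Example (illustrative doctest-style sketch):
+
+        >>> _python_value(metadata_store_pb2.Value(string_value="falcon"))
+        'falcon'
+        >>> _python_value(metadata_store_pb2.Value(int_value=3))
+        3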
+ + TODO: debug under what circumstances this function receives a non-`metadata_store_pb2.Value` values (this + happens when testing `Type.properties`). + """ + if isinstance(value, (str, int, float)): + return value + elif isinstance(value, metadata_store_pb2.Value): + if value.HasField("string_value"): + return value.string_value + elif value.HasField("int_value"): + return value.int_value + elif value.HasField("double_value"): + return value.double_value + raise NotImplementedError(f"Unsupported `metadata_store_pb2.Value` value (value={value}).") + raise NotImplementedError(f"Unsupported value type (type={type(value)})") + + +def _graph_node(node_type: t.Type[Node], db: CmfQuery, mlmd_node: MlmdNode, attrs: t.Optional[t.Dict] = None) -> Node: + """Create class instance (users are not supposed to call this method by themselves). + Args: + node_type: Graph node type to create (Pipeline, Stage, Execution or Stage) derived from Base. + db: Data access layer. + mlmd_node: Node in MLMD database. + attrs: Optional attributes to set on newly created class instance (with `setattr` function). + Returns: + Instance of one of child classes. """ - if value.HasField("string_value"): - return value.string_value - elif value.HasField("int_value"): - return value.int_value - elif value.HasField("double_value"): - return value.double_value - raise NotImplementedError("Only string, int and double fields are supported.") + node = node_type() + node._db = db + node._node = mlmd_node + if attrs: + for name, value in attrs.items(): + setattr(node, name, value) + return node diff --git a/test/contrib/test_graph_api.py b/test/contrib/test_graph_api.py new file mode 100644 index 00000000..58ca56be --- /dev/null +++ b/test/contrib/test_graph_api.py @@ -0,0 +1,317 @@ +### +# Copyright (2023) Hewlett Packard Enterprise Development LP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +### + +import typing as t +from pathlib import Path +from tempfile import gettempdir +from unittest import TestCase, skipUnless + +from cmflib.contrib.graph_api import ( + Artifact, + Execution, + MlmdGraph, + Node, + Pipeline, + Properties, + Stage, + Type, + _graph_node, + one, + unique, +) + +_TEST_MLMD_FILE = Path(gettempdir(), "mlmd.sqlite") +"""This file must exist for test in this module to run, else they will be skipped.""" + + +class AttrTest: + """Helper class to run test for an object's attributed. + + Args: + attr: Name of an object attribute. + expected_type: Expected type of attribute value. + test_fn: callable object that takes an attribute value and returns true/false (additional checks). + """ + + def __init__(self, attr: str, expected_type: t.Type, test_fn: t.Optional[t.Callable] = None) -> None: + self.attr = attr + self.expected_type = expected_type + self.test_fn = test_fn or (lambda v: True) + + def __call__(self, test_case: TestCase, obj: t.Any) -> None: + """Run a test against this attribute. + + Args: + test_case: Test case instance that runs this test. + obj: Object to run attribute test for. 
+ """ + # Test attribute exists and check attribute's value type. + test_case.assertTrue(hasattr(obj, self.attr), f"Attribute does not exist (attr={self.attr}, node={obj}).") + value = getattr(obj, self.attr) + test_case.assertIsInstance( + value, + self.expected_type, + f"Wrong attr type (attr={self.attr}, actual_type={type(value)}, " + f"expected_type={self.expected_type}, object={obj}).", + ) + # Run additional tests. + test_case.assertTrue( + self.test_fn(value), + f"Test is false for object attribute (attr={self.attr}, value={value}, " + f"type={self.expected_type}, object={obj}).", + ) + + +class ObjTest: + """Run tests against some object. + + Args: + expected_type: Expected type of this object. + attr_tests: Tests for object's attributes. + """ + + def __init__(self, expected_type: t.Type, attr_tests: t.Optional[t.List[AttrTest]] = None) -> None: + self.expected_type = expected_type + self.attr_tests = attr_tests or [] + + def __call__(self, test_case: TestCase, obj: t.Any) -> None: + """Run a test against this object. + + Args: + test_case: Test case instance that runs this test. + obj: Object to run this test for. + """ + test_case.assertIsInstance(obj, self.expected_type) + for attr_test in self.attr_tests: + attr_test(test_case, obj) + + +@skipUnless(_TEST_MLMD_FILE.exists(), f"The `mlmd.sqlite` does not exist in {gettempdir()}.") +class TestNodes(TestCase): + def setUp(self): + self.md = MlmdGraph(_TEST_MLMD_FILE.as_posix()) + """The `md` stands for metadata.""" + + def test_types(self) -> None: + # Test type names for accidental change. + self.assertEqual(Type.PIPELINE, "Parent_Context") + self.assertEqual(Type.STAGE, "Pipeline_Stage") + self.assertEqual(Type.DATASET, "Dataset") + self.assertEqual(Type.MODEL, "Model") + self.assertEqual(Type.METRICS, "Metrics") + + # Tester instance for `Type` instances. + attr_test = ObjTest( + Type, + [ + AttrTest("id", int, lambda v: v >= 0), + AttrTest("name", str, lambda v: len(v) > 0), + AttrTest("version", str, lambda v: len(v) >= 0), + AttrTest("description", str, lambda v: len(v) >= 0), + AttrTest("external_id", str, lambda v: len(v) >= 0), + AttrTest("properties", Properties), + ], + ) + + # Test types for pipelines, stages, executions and artifacts. + def _test_types(_nodes: t.List[Node], _expected_type_name: t.Optional[str] = None) -> None: + self.assertTrue(len(_nodes) > 0) + for _node in _nodes: + if _expected_type_name is not None: + self.assertEqual(_node.type.name, _expected_type_name) + attr_test(self, _node.type) + + _test_types(self.md.pipelines(), Type.PIPELINE) + _test_types(self.md.stages(), Type.STAGE) + _test_types(self.md.executions()) + _test_types(self.md.artifacts()) + + def test_base(self) -> None: + pipelines: t.List[Pipeline] = self.md.pipelines() + self.assertTrue(len(pipelines) > 0) + + pipeline: Pipeline = pipelines[0] + # Test these properties are read-only. 
+ self.assertRaises(AttributeError, setattr, pipeline, "id", 1) + self.assertRaises(AttributeError, setattr, pipeline, "name", "new name") + + # Test properties and custom properties are read only + with self.assertRaises(TypeError): + pipeline.properties["key"] = "value" + with self.assertRaises(TypeError): + pipeline.custom_properties["key"] = "value" + + # Test users can iterate over properties and custom properties + for k, v in pipeline.properties.items(): + print("properties (k, v):", k, v) + for k, v in pipeline.custom_properties.items(): + print("custom_properties (k, v):", k, v) + + def test_pipelines(self) -> None: + pipelines: t.List[Pipeline] = self.md.pipelines() + self.assertTrue(len(pipelines) > 0) + for pipeline in pipelines: + self.assertIsInstance(pipeline, Pipeline) + self.assertEqual(pipeline.type.name, Type.PIPELINE) + + self._test_list_of_nodes(pipeline.stages(), Stage, Type.STAGE) + self._test_list_of_nodes(pipeline.executions(), Execution, None) + self._test_list_of_nodes(pipeline.artifacts(), Artifact, None) + + def test_stages(self) -> None: + stages: t.List[Stage] = self.md.stages() + self.assertTrue(len(stages) > 0) + for stage in stages: + self.assertIsInstance(stage.pipeline, Pipeline) + self.assertEqual(stage.pipeline.type.name, Type.PIPELINE) + self._test_list_of_nodes(stage.executions(), Execution, None) + self._test_list_of_nodes(stage.artifacts(), Artifact, None, empty_ok=True) + + def test_executions(self) -> None: + executions: t.List[Execution] = self.md.executions() + self.assertTrue(len(executions) > 0) + for execution in executions: + self.assertIsInstance(execution.stage, Stage) + self._test_list_of_nodes(execution.inputs(), Artifact, None, empty_ok=True) + self._test_list_of_nodes(execution.outputs(), Artifact, None, empty_ok=True) + + def test_artifacts(self) -> None: + artifacts: t.List[Artifact] = self.md.artifacts() + self.assertTrue(len(artifacts) > 0) + for artifact in artifacts: + # TODO (sergey) in the test MLMD dataset produced_by() returns empty lists sometimes + self._test_list_of_nodes(artifact.produced_by(), Execution, None, empty_ok=True) + self._test_list_of_nodes(artifact.consumed_by(), Execution, None, empty_ok=True) + + def _test_list_of_nodes( + self, nodes: t.List[Node], node_type_cls: t.Type, type_name: t.Optional[str] = None, empty_ok: bool = False + ) -> None: + if empty_ok: + self.assertTrue(len(nodes) >= 0) + else: + self.assertTrue(len(nodes) > 0) + for node in nodes: + self.assertIsInstance(node, node_type_cls) + if type_name is not None: + self.assertEqual(node.type.name, type_name) + + +@skipUnless(_TEST_MLMD_FILE.exists(), f"The `mlmd.sqlite` does not exist in {gettempdir()}.") +class TestGraphAPI(TestCase): + def setUp(self): + self.md = MlmdGraph(_TEST_MLMD_FILE.as_posix()) + """The `md` stands for metadata.""" + + def _compare_nodes(self, node1: Node, node2: Node, must_equal: bool = True) -> None: + """Check that two graph nodes are the same or different by invoking `hash` and `__eq__` methods.""" + assert_fn = self.assertEqual if must_equal else self.assertNotEqual + assert_fn(hash(node1), hash(node2)) + assert_fn(node1, node2) + + def _test_nodes(self, nodes: t.List[Node], obj_test: ObjTest, other_types: t.List[t.Type[Node]]) -> None: + """Run common tests for graph nodes.""" + # Check node type and common attributes + for node in nodes: + obj_test(self, node) + + # Check __hash__ and __eq__ methods + _graph_node(obj_test.expected_type, nodes[0]._db, nodes[0]._node) + self._compare_nodes( + nodes[0], 
_graph_node(obj_test.expected_type, nodes[0]._db, nodes[0]._node), must_equal=True + ) + self._compare_nodes(nodes[0], nodes[1], must_equal=False) + for other_type in other_types: + self._compare_nodes(nodes[0], _graph_node(other_type, nodes[0]._db, nodes[0]._node), must_equal=False) + + def _test_psa(self, nodes: t.List[Node], expected_node_type: t.Type, other_types: t.List[t.Type[Node]]) -> None: + """Run tests against metadata nodes for Pipelines, Stages and Attributes.""" + self.assertTrue(len(nodes) > 0) + self._test_nodes( + nodes, + ObjTest( + expected_node_type, + [ + AttrTest("id", int, lambda v: v >= 0), + AttrTest("name", str, lambda v: len(v) > 0), + AttrTest("type", Type), + ], + ), + other_types, + ) + + def test_pipelines(self) -> None: + self._test_psa(self.md.pipelines(), Pipeline, [Stage, Execution, Artifact]) + + def test_stages(self) -> None: + self._test_psa(self.md.stages(), Stage, [Pipeline, Execution, Artifact]) + + def test_artifacts(self) -> None: + self._test_psa(self.md.artifacts(), Artifact, [Pipeline, Execution, Stage]) + + def test_executions(self) -> None: + """Run tests for executions. + + TODO: (sergey) value of the `name` attribute is empty for some or all executions. Is this expected? + """ + executions: t.List[Execution] = self.md.executions() + self.assertTrue(len(executions) > 0) + self._test_nodes( + executions, + ObjTest( + Execution, + [ + AttrTest("id", int, lambda v: v >= 0), + AttrTest("name", str, lambda v: len(v) >= 0), + AttrTest("type", Type), + ], + ), + [Pipeline, Stage, Artifact], + ) + + +@skipUnless(_TEST_MLMD_FILE.exists(), f"The `mlmd.sqlite` does not exist in {gettempdir()}.") +class TestListOperators(TestCase): + def setUp(self): + self.md = MlmdGraph(_TEST_MLMD_FILE.as_posix()) + """The `md` stands for metadata.""" + + def test_one(self) -> None: + # Check default parameters + self.assertEqual(one([1]), 1) + self.assertRaises(ValueError, one, []) + self.assertRaises(RuntimeError, one, [], error=RuntimeError()) + + # Check input is empty + self.assertIsNone(one([], return_none_if_empty=True)) + self.assertIsNone(one([], return_none_if_empty=True, error=RuntimeError())) + + # Check input contains multiple elements + self.assertRaises(ValueError, one, [1, 2]) + self.assertRaises(ValueError, one, [1, 2], return_none_if_empty=True) + self.assertRaises(RuntimeError, one, [1, 2], return_none_if_empty=True, error=RuntimeError()) + + def test_unique(self) -> None: + # + self.assertListEqual(unique([1, 2, 3, 4]), [1, 2, 3, 4]) + self.assertListEqual(unique([12, 12, 45, 66, 67, 888]), [12, 45, 66, 67, 888]) + self.assertListEqual(unique([12, 45, 66, 67, 66, 888]), [12, 45, 66, 67, 888]) + self.assertListEqual(unique([45, 12, 45, 66, 67, 66, 888]), [45, 12, 66, 67, 888]) + self.assertListEqual(unique([12, 45, 66, 67, 888, 888]), [12, 45, 66, 67, 888]) + + # + ps: t.List[Pipeline] = self.md.pipelines() + self.assertListEqual(unique([ps[0], ps[1]], "id"), [ps[0], ps[1]]) + self.assertListEqual(unique([ps[0], ps[0]], "id"), [ps[0]]) From c79e3f2a1524103a144e7d5fc1f44bd28e97dda1 Mon Sep 17 00:00:00 2001 From: Sergey Serebryakov Date: Tue, 17 Oct 2023 04:38:02 +0000 Subject: [PATCH 6/8] Work-in-progress commit. - Implementation of multiple analytic functions. - Unified mechanism to traverse graph of artifacts along their dependency paths. 
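
For illustration only, a minimal sketch of how the traversal mechanism introduced in this patch might be used together with the graph API from the previous patch; the module paths, the MLMD file location and the presence of a Dataset artifact are assumptions, not part of the patch:

from cmflib.contrib.graph_api import MlmdGraph, Type          # assumed module path
from cmflib.contrib.query_engine import Accept, Stop, Traverse, Visitor

graph = MlmdGraph("/tmp/mlmd.sqlite")                          # hypothetical MLMD file

# Pick any dataset artifact as the traversal anchor (assumes at least one exists).
dataset = next(a for a in graph.artifacts() if a.type.name == Type.DATASET)

# Walk the lineage downstream from the dataset, collecting every model that depends on it.
visitor = Traverse.downstream(dataset, Visitor(acceptor=Accept.by_type(Type.MODEL), stopper=Stop.never))
for model in visitor.artifacts:
    print(model.id, model.name)
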
--- cmflib/contrib/query_engine.py | 280 +++++++++++++++++++++++++++++++++ 1 file changed, 280 insertions(+) create mode 100644 cmflib/contrib/query_engine.py diff --git a/cmflib/contrib/query_engine.py b/cmflib/contrib/query_engine.py new file mode 100644 index 00000000..3f203803 --- /dev/null +++ b/cmflib/contrib/query_engine.py @@ -0,0 +1,280 @@ +import typing as t + +from contrib.graph_api import Artifact, Execution, Type, one, unique + +__all__ = ["Visitor", "Accept", "Stop", "Traverse", "QueryEngine"] + + +class Visitor: + """Class that `visits` MLMD artifact nodes. + + Args: + acceptor: Callable that takes one artifact and returns True if this node should be accepted (stored in + `Visitor.artifacts` list). + stopper: Callable that takes a list of accepted artifacts and returns True of graph traversal should stop. + """ + + def __init__(self, acceptor: t.Optional[t.Callable] = None, stopper: t.Optional[t.Callable] = None) -> None: + self._acceptor = acceptor or Accept.all + self._stopper = stopper or Stop.never + + self.artifacts: t.List[Artifact] = [] + """List of accepted artifacts.""" + + self.stopped: bool = False + """True if this visitor has accepted all requested artifacts.""" + + def visit(self, artifact: Artifact) -> bool: + """Process artifact and return True if traversal should stop. + Args: + artifact: New artifact to process. + Returns: + True if graph traversal should stop. + """ + if self._acceptor(artifact): + self.artifacts.append(artifact) + self.stopped = self._stopper(self.artifacts) + return self.stopped + + +class Accept: + """Class that implements various acceptor functions.""" + + @staticmethod + def all(_artifact: Artifact) -> bool: + """Accept all artifacts.""" + return True + + @staticmethod + def by_type(type_: str) -> t.Callable: + """Accept artifacts of this particular type.""" + + def _accept(artifact: Artifact) -> bool: + return artifact.type.name == type_ + + return _accept + + @staticmethod + def by_id(id_: t.Union[int, t.Set[int]]) -> t.Callable: + """Accept artifacts with this ID or IDs.""" + + def _accept_one(artifact: Artifact) -> bool: + return artifact.id == id_ + + def _accept_many(artifact: Artifact) -> bool: + return artifact.id in id_ + + return _accept_one if isinstance(id_, int) else _accept_many + + +class Stop: + """Class that implements various graph traversal stoppers.""" + + @staticmethod + def never(_artifacts: t.List[Artifact]) -> bool: + """Never stop, visit all nodes.""" + return False + + @staticmethod + def by_accepted_count(count: int) -> t.Callable: + """Accept this number of artifacts and stop.""" + + def _stop(artifacts: t.List[Artifact]) -> bool: + return len(artifacts) == count + + return _stop + + +class Traverse: + """Upstream and downstream traversal algorithms for artifacts.""" + + @staticmethod + def _traverse(artifact: Artifact, visitor: Visitor, direction: str) -> Visitor: + """Traverse artifacts in upstream or downstream direction. + + Artifact traversal follows `dependency` path. When traversing downstream, all output artifacts of some execution + depend on any input artifact, while when traversing upstream, only those output artifacts of some execution + are considered that are inputs to previously visited executions. + + Args: + artifact: Anchor artifact to start with. + visitor: An instance of `Visitor` class that decides what artifacts need to be accepted and when traversal + should stop. + direction: One of `upstream` or `downstream`. 
+ """ + + if direction not in ("upstream", "downstream"): + raise ValueError(f"Internal Error: unsupported traversal direction (`{direction}`).") + + def _next_executions(_artifact: Artifact) -> t.List[Execution]: + return _artifact.consumed_by() if direction == "downstream" else _artifact.produced_by() + + def _next_artifacts(_execution: Execution) -> t.List[Artifact]: + return _execution.outputs() if direction == "downstream" else _execution.inputs() + + visited: t.Set[int] = set() + pending: t.List[Execution] = _next_executions(artifact).copy() + + while pending: + execution: Execution = pending.pop() + if execution.id in visited: + continue + visited.add(execution.id) + for artifact in _next_artifacts(execution): + if visitor.visit(artifact): + pending.clear() + break + pending.extend((e for e in _next_executions(artifact) if e.id not in visited)) + + return visitor + + @staticmethod + def downstream(artifact: Artifact, visitor: Visitor) -> Visitor: + return Traverse._traverse(artifact, visitor, "downstream") + + @staticmethod + def upstream(artifact: Artifact, visitor: Visitor) -> Visitor: + return Traverse._traverse(artifact, visitor, "upstream") + + +class QueryEngine: + """Query and search engine for ML and pipeline metadata. + + This implementation uses the graph API to access metadata. Basic metadata search is supported by graph API + (`MlmdGraph`) by providing the following features: + - Iterating over graph nodes: pipelines, stages, executions and artifacts. + - Traversing the metadata graph using the following relations: + pipeline -> stages, executions, artifacts + stage -> pipeline, executions, artifacts + execution -> stage, artifacts (inputs and outputs) + artifact -> executions (consumed_by and produced_by) + + This class is based on `MlmdGraph` to provide high-level query and search features for multiple common use cases. + """ + + def __init__(self) -> None: + ... + + def is_model_trained_on_dataset(self, model: Artifact, dataset: Artifact) -> bool: + """Return true if this model was trained on this dataset. + + Args: + model: Machine learning model + dataset: Training dataset. + Returns: + True if this model was trained on this dataset. + + TODO: (sergey) How do I know if this dataset was used as a train and not test or validation dataset? + """ + _check_artifact_type(model, Type.MODEL) + _check_artifact_type(dataset, Type.DATASET) + visitor: Visitor = Traverse.downstream(dataset, Visitor(Accept.by_id(model.id), Stop.by_accepted_count(1))) + return visitor.stopped + + def is_on_the_same_lineage_path(self, artifacts: t.List[Artifact]) -> bool: + """Determine if all artifacts belong to one lineage path. + + Args: + artifacts: List of artifacts. + Returns: + True when all artifacts are connected via dependency chain. + """ + if not artifacts: + return False + + artifacts = unique(artifacts.copy(), "id") + if len(artifacts) == 1: + return True + + anchor_artifact = artifacts.pop(0) + ids = set((artifact.id for artifact in artifacts)) + # TODO (sergey) can this accept the same node multiple times so that visited nodes are counted wrong? + visitor = Visitor(Accept.by_id(ids), Stop.by_accepted_count(len(ids))) + + visitor = Traverse.upstream(anchor_artifact, visitor) + if not visitor.stopped: + visitor = Traverse.downstream(anchor_artifact, visitor) + + return visitor.stopped + + def get_datasets_by_dataset(self, dataset: Artifact) -> t.List[Artifact]: + """Return all datasets produced by executions that directly or indirectly depend on this dataset. 
+ + Args: + dataset Training dataset. + Returns: + Datasets that directly or indirectly depend on input dataset. + """ + _check_artifact_type(dataset, Type.DATASET) + visitor: Visitor = Traverse.downstream(dataset, Visitor(Accept.by_type(Type.DATASET))) + return visitor.artifacts + + def get_models_by_dataset(self, dataset: Artifact) -> t.List[Artifact]: + """Get all models that depend on this dataset. + Args: + dataset: Dataset + Returns: + List of unique models trained on the given dataset. + """ + _check_artifact_type(dataset, Type.DATASET) + visitor = Traverse.downstream(dataset, Visitor(Accept.by_type(Type.MODEL))) + return visitor.artifacts + + def get_metrics_by_models(self, models: t.List[Artifact]) -> t.List[Artifact]: + """Return metrics for each model. + + A model and its metrics are `siblings`, i.e., they must be in the list of output artifacts of some execution. + + Args: + models: List of models. + Returns: + List of artifacts that have `Type.METRICS` type. There's one to one correspondence of models in an input + list and metrics in an output list. + """ + metrics: t.List[t.Optional[Artifact]] = [None] * len(models) + for idx, model in enumerate(models): + if model.type.name != Type.MODEL: + raise ValueError(f"Input artifact is not a model (idx={idx}, artifact={model}).") + execution: Execution = one( + model.produced_by, + error=NotImplementedError( + f"Multiple producer executions ({len(model.produced_by)}) are not supported yet." + ), + ) + metrics[idx] = one( + [a for a in execution.outputs() if a.type.name == Type.METRICS], + return_none_if_empty=True, + error=NotImplementedError("Multiple metrics in one execution are not supported yet."), + ) + return metrics + + def get_metrics_by_executions(self, executions: t.List[Execution]) -> t.List[Artifact]: + """Return metrics for each execution. + + Args: + executions: List of executions. + Returns: + List of metrics. There's one to one correspondence of executions in an input list and metrics in + an output list. + """ + metrics: t.List[t.Optional[Artifact]] = [None] * len(executions) + for idx, execution in enumerate(executions): + metrics[idx] = one( + [a for a in execution.outputs if a.type.name == Type.METRICS], + return_none_if_empty=True, + error=NotImplementedError("Multiple metrics in one execution are not supported yet."), + ) + return metrics + + +def _check_artifact_type(artifact: Artifact, type_name: str) -> None: + """Helper function to check input artifact has required type. + + Args: + artifact: Input artifact. + type_name: Name of a type that this artifact is expected to be. + Raises: + ValueError when types mismatch. + """ + if artifact.type.name != type_name: + raise ValueError(f"Invalid artifact type (type={artifact.type}). Expected type is '{type_name}'.") From 0cd835de39b05ec9bde912a57e2b40c7cf74798d Mon Sep 17 00:00:00 2001 From: Sergey Serebryakov Date: Wed, 18 Oct 2023 23:36:51 +0000 Subject: [PATCH 7/8] Work in progress updates - New traverse API. - Base methods. - Allow users to specify selection criteria for artifacts in some methods. 
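
A hedged sketch of the new selection-criteria parameters described above; the module paths, the MLMD file location and the "falcon" name filter are illustrative assumptions:

from cmflib.contrib.graph_api import MlmdGraph, Type          # assumed module path
from cmflib.contrib.query_engine import Accept, QueryEngine

engine = QueryEngine()
graph = MlmdGraph("/tmp/mlmd.sqlite")                          # hypothetical MLMD file
dataset = next(a for a in graph.artifacts() if a.type.name == Type.DATASET)

# The Model type check is applied automatically; select_fn only narrows the result further
# (the "falcon" substring is a made-up example criterion).
models = engine.get_models_by_dataset(dataset, select_fn=lambda model: "falcon" in model.name)

# Siblings produced by the same execution can be filtered the same way, e.g. a model's metrics.
if models:
    metrics = engine.get_siblings(models[0], select=Accept.by_type(Type.METRICS))
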
--- cmflib/contrib/query_engine.py | 193 +++++++++++++++++++++++++++------ 1 file changed, 159 insertions(+), 34 deletions(-) diff --git a/cmflib/contrib/query_engine.py b/cmflib/contrib/query_engine.py index 3f203803..41d46a8c 100644 --- a/cmflib/contrib/query_engine.py +++ b/cmflib/contrib/query_engine.py @@ -1,6 +1,6 @@ import typing as t -from contrib.graph_api import Artifact, Execution, Type, one, unique +from contrib.graph_api import Artifact, Execution, Node, Type, one, unique __all__ = ["Visitor", "Accept", "Stop", "Traverse", "QueryEngine"] @@ -41,7 +41,7 @@ class Accept: """Class that implements various acceptor functions.""" @staticmethod - def all(_artifact: Artifact) -> bool: + def all(_node: Node) -> bool: """Accept all artifacts.""" return True @@ -49,8 +49,8 @@ def all(_artifact: Artifact) -> bool: def by_type(type_: str) -> t.Callable: """Accept artifacts of this particular type.""" - def _accept(artifact: Artifact) -> bool: - return artifact.type.name == type_ + def _accept(node: Node) -> bool: + return node.type.name == type_ return _accept @@ -58,11 +58,11 @@ def _accept(artifact: Artifact) -> bool: def by_id(id_: t.Union[int, t.Set[int]]) -> t.Callable: """Accept artifacts with this ID or IDs.""" - def _accept_one(artifact: Artifact) -> bool: - return artifact.id == id_ + def _accept_one(node: Node) -> bool: + return node.id == id_ - def _accept_many(artifact: Artifact) -> bool: - return artifact.id in id_ + def _accept_many(node: Node) -> bool: + return node.id in id_ return _accept_one if isinstance(id_, int) else _accept_many @@ -137,6 +137,7 @@ def upstream(artifact: Artifact, visitor: Visitor) -> Visitor: return Traverse._traverse(artifact, visitor, "upstream") +# noinspection PyMethodMayBeStatic class QueryEngine: """Query and search engine for ML and pipeline metadata. @@ -150,25 +151,42 @@ class QueryEngine: artifact -> executions (consumed_by and produced_by) This class is based on `MlmdGraph` to provide high-level query and search features for multiple common use cases. + + TODO (sergey) Some methods assume (for simplicity) that any given artifact has only one producer execution. """ def __init__(self) -> None: ... - def is_model_trained_on_dataset(self, model: Artifact, dataset: Artifact) -> bool: - """Return true if this model was trained on this dataset. + def is_direct_descendant(self, parent: Artifact, child: Artifact) -> bool: + """Return true if the given child is the direct descendant of the given parent. + + This implementation assumes that the iven child was produced by exactly one execution. The parent and + child must then be related as: [parent] --input-> execution --output--> [child] Args: - model: Machine learning model - dataset: Training dataset. + parent: Candidate parent. + child: Candidate child. Returns: - True if this model was trained on this dataset. + True if given child is the direct descendant of the given parent. + """ + # FIXME (sergey) assumption is there's only one producer execution. + execution: Execution = one(child.produced_by()) + input_ids: t.Set[int] = set((artifact.id for artifact in execution.inputs())) + return parent.id in input_ids - TODO: (sergey) How do I know if this dataset was used as a train and not test or validation dataset? + def is_descendant(self, ancestor: Artifact, descendant: Artifact) -> bool: + """Identify if `descendant` artifact depends on `ancestor` artifact implicitly or explicitly. + + Args: + ancestor: Artifact that is supposedly resulted in descendant. 
+ descendant: Artifact that is supposedly depend on ancestor. + Returns: + True if ancestor and descendant are on the same lineage path, and descendant is the downstream for ancestor. """ - _check_artifact_type(model, Type.MODEL) - _check_artifact_type(dataset, Type.DATASET) - visitor: Visitor = Traverse.downstream(dataset, Visitor(Accept.by_id(model.id), Stop.by_accepted_count(1))) + visitor: Visitor = Traverse.downstream( + ancestor, Visitor(Accept.by_id(descendant.id), Stop.by_accepted_count(1)) + ) return visitor.stopped def is_on_the_same_lineage_path(self, artifacts: t.List[Artifact]) -> bool: @@ -177,7 +195,7 @@ def is_on_the_same_lineage_path(self, artifacts: t.List[Artifact]) -> bool: Args: artifacts: List of artifacts. Returns: - True when all artifacts are connected via dependency chain. + True when all artifacts are connected via a dependency chain. """ if not artifacts: return False @@ -197,28 +215,107 @@ def is_on_the_same_lineage_path(self, artifacts: t.List[Artifact]) -> bool: return visitor.stopped - def get_datasets_by_dataset(self, dataset: Artifact) -> t.List[Artifact]: + def get_siblings( + self, artifact: Artifact, select: t.Optional[t.Callable[[Artifact], bool]] = None + ) -> t.List[Artifact]: + """Get all siblings of the given node. + + Siblings in this method are defined as other artifacts produced by the same execution that produced this + artifact. It is assumed there's one producer execution. The artifact itself is not sibling to itself. + + Args: + artifact: Artifact. + select: Callable object that accepts an artifact and returns True if this artifact should be added to list + of siblings. The artifact itself is excluded from siblings automatically. + """ + # FIXME (sergey) assumption is there's only one producer execution. + execution: Execution = one(artifact.produced_by()) + siblings = [sibling for sibling in execution.outputs() if sibling.id != artifact.id] + return _select(siblings, select) + + def get_direct_descendants( + self, + parent: Artifact, + select_executions_fn: t.Optional[t.Callable[[Execution], bool]] = None, + select_artifact_fn: t.Optional[t.Callable[[Artifact], bool]] = None, + ) -> t.List[t.Tuple[Execution, t.List[Artifact]]]: + """Return all direct descendants (immediate children) of the given parent artifact. + + Immediate descendants are those for which `is_direct_descendant` method returns True - see doc strings for more + details. + + """ + descendants: t.List[t.Tuple[Execution, t.List[Artifact]]] = [] + executions = _select(parent.consumed_by(), select_executions_fn) + for execution in executions: + execution_outputs = _select(execution.outputs(), select_artifact_fn) + if execution_outputs: + descendants.append((execution, execution_outputs)) + return descendants + + def get_descendants( + self, artifact: Artifact, select_fn: t.Optional[t.Callable[[Artifact], bool]] = None + ) -> t.List[Artifact]: + """Get all descendants of the given artifact. + + A descendant is defined as an artifact reachable starting the anchor node and traversing downstream. + + Args: + artifact: Anchor artifact. + select_fn: Only those artifacts for which this callable returns True are returned. + Returns: + List of all descendants artifacts for which select_fn(descendant) == True. + """ + visitor: Visitor = Traverse.downstream(artifact, Visitor(acceptor=select_fn)) + return visitor.artifacts + + def is_model_trained_on_dataset(self, model: Artifact, dataset: Artifact) -> bool: + """Return true if this model was trained on this dataset. 
+ + This method verifies there is a path between the `dataset` and the `model`, in other words, the model should + be reachable when traversing MLMD graph starting from the `dataset` node in the downstream direction. + Depending on adopted best practices, there maybe an easier (better to some extent) approach to test this. In + some cases, a model and a dataset must be connected to the same execution in order to be related via + `trained_on` relation. + + Args: + model: Machine learning model + dataset: Training dataset. + Returns: + True if this model was trained on this dataset. + """ + _check_artifact_type(model, Type.MODEL) + _check_artifact_type(dataset, Type.DATASET) + return self.is_descendant(ancestor=dataset, descendant=model) + + def get_datasets_by_dataset( + self, dataset: Artifact, select_fn: t.Optional[t.Callable[[Artifact], bool]] + ) -> t.List[Artifact]: """Return all datasets produced by executions that directly or indirectly depend on this dataset. Args: - dataset Training dataset. + dataset: Training dataset. + select_fn: Function to select datasets matching certain criteria. This function does not need to check + artifact type - datasets are selected automatically. Returns: Datasets that directly or indirectly depend on input dataset. """ _check_artifact_type(dataset, Type.DATASET) - visitor: Visitor = Traverse.downstream(dataset, Visitor(Accept.by_type(Type.DATASET))) - return visitor.artifacts + return self.get_descendants(dataset, _combine_select_fn(Accept.by_type(Type.DATASET), select_fn)) - def get_models_by_dataset(self, dataset: Artifact) -> t.List[Artifact]: + def get_models_by_dataset( + self, dataset: Artifact, select_fn: t.Optional[t.Callable[[Artifact], bool]] + ) -> t.List[Artifact]: """Get all models that depend on this dataset. Args: dataset: Dataset + select_fn: Function to select models matching certain criteria. This function does not need to check + artifact type - models are selected automatically. Returns: List of unique models trained on the given dataset. """ _check_artifact_type(dataset, Type.DATASET) - visitor = Traverse.downstream(dataset, Visitor(Accept.by_type(Type.MODEL))) - return visitor.artifacts + return self.get_descendants(dataset, _combine_select_fn(Accept.by_type(Type.MODEL), select_fn)) def get_metrics_by_models(self, models: t.List[Artifact]) -> t.List[Artifact]: """Return metrics for each model. @@ -233,16 +330,9 @@ def get_metrics_by_models(self, models: t.List[Artifact]) -> t.List[Artifact]: """ metrics: t.List[t.Optional[Artifact]] = [None] * len(models) for idx, model in enumerate(models): - if model.type.name != Type.MODEL: - raise ValueError(f"Input artifact is not a model (idx={idx}, artifact={model}).") - execution: Execution = one( - model.produced_by, - error=NotImplementedError( - f"Multiple producer executions ({len(model.produced_by)}) are not supported yet." - ), - ) + _check_artifact_type(model, Type.MODEL) metrics[idx] = one( - [a for a in execution.outputs() if a.type.name == Type.METRICS], + self.get_siblings(model, Accept.by_type(Type.METRICS)), return_none_if_empty=True, error=NotImplementedError("Multiple metrics in one execution are not supported yet."), ) @@ -259,6 +349,7 @@ def get_metrics_by_executions(self, executions: t.List[Execution]) -> t.List[Art """ metrics: t.List[t.Optional[Artifact]] = [None] * len(executions) for idx, execution in enumerate(executions): + # FIXME (sergey): It is assumed there's one or zero metric artifacts for each execution. 
metrics[idx] = one( [a for a in execution.outputs if a.type.name == Type.METRICS], return_none_if_empty=True, @@ -278,3 +369,37 @@ def _check_artifact_type(artifact: Artifact, type_name: str) -> None: """ if artifact.type.name != type_name: raise ValueError(f"Invalid artifact type (type={artifact.type}). Expected type is '{type_name}'.") + + +def _select(items: t.List, select_fn: t.Optional[t.Callable[[t.Any], bool]] = None) -> t.List: + """Select items from `items` using select_fn condition. + + Args: + items: Input list of items. + select_fn: Callable object that accepts one item from items and returns True if it should be selected. + Returns: + Input list if `select_fn` is None, else those items for which select_fn(item) returns True. + """ + if not select_fn: + return items + return [item for item in items if select_fn(item)] + + +def _combine_select_fn(func_a: t.Callable, func_b: t.Optional[t.Callable] = None) -> t.Callable: + """Combine two selection functions using AND operator. + + Selection function is a function that takes one parameter and returns True or False. + + Args: + func_a: First selection function is always present. + func_b: Second selection function is optional. + Returns: + func_a if func_b is None else new function that implements func_a(input) and func_b(input) + """ + if func_b is None: + return func_a + + def _combined_fn(node: t.Any) -> bool: + return func_a(node) and func_b(node) + + return _combined_fn From f9be98c30516620bbfbb760b4fbc9e7656c643d3 Mon Sep 17 00:00:00 2001 From: Sergey Serebryakov Date: Mon, 23 Oct 2023 17:42:33 +0000 Subject: [PATCH 8/8] Work in progress updates. - Adding HPE header. - Improving doc strings for classes that are responsible for traversing MLMD graph of artifacts. --- cmflib/contrib/query_engine.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/cmflib/contrib/query_engine.py b/cmflib/contrib/query_engine.py index 41d46a8c..76279f8f 100644 --- a/cmflib/contrib/query_engine.py +++ b/cmflib/contrib/query_engine.py @@ -1,3 +1,19 @@ +### +# Copyright (2023) Hewlett Packard Enterprise Development LP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +### + import typing as t from contrib.graph_api import Artifact, Execution, Node, Type, one, unique @@ -8,6 +24,9 @@ class Visitor: """Class that `visits` MLMD artifact nodes. + It is used to traverse MLMD graph and to collect artifacts of interest. An artifact is collected if it is accepted + by an acceptor function. + Args: acceptor: Callable that takes one artifact and returns True if this node should be accepted (stored in `Visitor.artifacts` list). @@ -38,7 +57,12 @@ def visit(self, artifact: Artifact) -> bool: class Accept: - """Class that implements various acceptor functions.""" + """Class that implements various acceptor functions. + + They are used to inform graph visitors if a particular artifact (or node in general) should be collected. 
All
+    acceptor functions must accept one parameter of type `Node` and return True if the node is "accepted", i.e., it
+    matches the user's criteria.
+    """
 
     @staticmethod
     def all(_node: Node) -> bool:
@@ -68,7 +92,12 @@ def _accept_many(node: Node) -> bool:
 
 
 class Stop:
-    """Class that implements various graph traversal stoppers."""
+    """Class that implements various graph traversal stoppers.
+
+    The visitor instance can instruct the traversal algorithm to stop after visiting each artifact in a graph. Stopper
+    functions are used to determine whether traversal should be stopped. A stopper function accepts one argument, the
+    list of artifacts collected (accepted) so far, and returns True if traversal should be stopped.
+    """
 
     @staticmethod
     def never(_artifacts: t.List[Artifact]) -> bool: