Skip to content

Commit

Permalink
Ray model data logging (#216)
Browse files Browse the repository at this point in the history
* Model and dataset logging for Ray Tune using CMF

* Added top_n results

* Updated Docs
  • Loading branch information
rishabhsharma22 authored Dec 11, 2024
1 parent 1608a38 commit 462e9e4
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 17 deletions.
103 changes: 88 additions & 15 deletions cmflib/cmf_ray_logger.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,58 @@
from ray import tune
from ray.tune import Callback
from cmflib import cmf
import heapq

class CmfRayLogger(Callback):
#id_count = 1

def __init__(self, pipeline_name, file_path, pipeline_stage, data_dir=None,
             metric='accuracy', order='max', top_n=5):
    """
    Set up the bookkeeping used to log Ray Tune trials to CMF.

    pipeline_name: The name of the CMF Pipeline
    file_path: The path to metadata file
    pipeline_stage: The name for the stage of cmf_pipeline
    data_dir: Directory/File for data that is logged
    metric: The metric to track (e.g., 'accuracy', 'loss')
    order: 'max' for maximum, 'min' for minimum
    top_n: Number of top results to keep
    """
    self.pipeline_name = pipeline_name
    self.file_path = file_path
    self.pipeline_stage = pipeline_stage
    self.data_dir = data_dir
    self.metric = metric
    self.order = order
    self.top_n = top_n

    # Per-trial state, keyed by trial_id.
    self.cmf_obj = {}   # CMF handle created in on_trial_start
    self.cmf_run = {}   # set to True once on_trial_result has logged something

    # Bounded heap holding the best `top_n` trials seen so far.
    self.heap = []
    # heapq is a min-heap and on_trial_complete evicts with heappop(), i.e. it
    # drops the entry with the SMALLEST key. For the heap to retain the best
    # trials, the worst trial must therefore carry the smallest key:
    #   order == 'max' -> key = +value (lowest value is evicted)
    #   otherwise      -> key = -value (highest value is evicted)
    # The previous sign (-1 for 'max') evicted the best trial instead.
    self.heap_comparator = 1 if self.order == 'max' else -1

    # Best metric value / model path / CMF execution id for each trial.
    self.best_metric_values = {}
    self.best_models = {}
    self.execution_ids = {}

def on_trial_start(self, iteration, trials, trial, **info):
    """
    Create the CMF context/execution for a trial that is about to run.

    iteration, trials, info: passed in by the caller; not used here.
    trial: object providing `trial_id` and `config`.

    Stores the CMF handle and the id of the created execution under the
    trial's id, optionally logs `data_dir` as an input dataset, and resets
    the trial's best-metric / best-model slots to None.
    """
    trial_id = trial.trial_id
    trial_config = trial.config
    print(f"CMF Logging Started for Trial {trial_id}")

    cmf_logger = cmf.Cmf(filepath=self.file_path, pipeline_name=self.pipeline_name)
    self.cmf_obj[trial_id] = cmf_logger
    _ = cmf_logger.create_context(pipeline_stage=self.pipeline_stage)
    execution = cmf_logger.create_execution(
        execution_type=f"Trial_{trial_id}",
        create_new_execution=False,
        custom_properties={'Configuration': trial_config},
    )

    # Keep the execution id (not the execution type) so the execution can be
    # updated later, e.g. when the top-n trials are marked at experiment end.
    self.execution_ids[trial_id] = execution.id

    if self.data_dir:
        _ = cmf_logger.log_dataset(url=str(self.data_dir), event='input')

    # Nothing has been reported for this trial yet.
    self.best_metric_values[trial_id] = None
    self.best_models[trial_id] = None

def on_trial_result(self, iteration, trials, trial, result, **info):
trial_id = trial.trial_id
Expand All @@ -47,24 +69,55 @@ def on_trial_result(self, iteration, trials, trial, result, **info):
custom_properties = {'Output': curr_res})
self.cmf_run[trial_id] = True

# Track the current metric value and model path (if available)
metric_value = curr_res.get(self.metric, None)
model_path = curr_res.get('model_path', None) # Track the model path if available


# Update best metric and model for the trial if necessary
if metric_value is not None:
if self.best_metric_values[trial_id] is None:
self.best_metric_values[trial_id] = metric_value
self.best_models[trial_id] = model_path
else:
# Update best metric based on order (max/min)
if ((self.order == 'max' and metric_value > self.best_metric_values[trial_id]) or
(self.order == 'min' and metric_value < self.best_metric_values[trial_id])):
self.best_metric_values[trial_id] = metric_value
self.best_models[trial_id] = model_path

def on_trial_complete(self, iteration, trials, trial, **info):
    """
    Record a finished trial: update the top-n heap and log final results.

    iteration, trials, info: passed in by the caller; not used here.
    trial: object providing `trial_id` and `last_result`.

    When the trial reported a value for the tracked metric, its best value
    is pushed onto the bounded heap (evicting one entry once more than
    `top_n` are held) and logged as "Best_<metric>_Trial_<id>". The trial's
    metrics are then committed, its last result is logged, and the model
    associated with the best metric value, if any, is logged.
    """
    trial_id = trial.trial_id
    trial_result = trial.last_result

    best_metric_value = self.best_metric_values.get(trial_id)
    best_model_path = self.best_models.get(trial_id)
    execution_id = self.execution_ids[trial_id]
    cmf_logger = self.cmf_obj[trial_id]

    if best_metric_value is not None:
        # The signed key comes first so heapq orders entries by it; trial_id
        # follows as a tie-breaker for equal keys.
        entry = (
            self.heap_comparator * best_metric_value,
            trial_id,
            best_metric_value,
            best_model_path,
            execution_id,
        )
        heapq.heappush(self.heap, entry)
        if len(self.heap) > self.top_n:
            heapq.heappop(self.heap)  # Maintain top_n elements in the heap

        # NOTE(review): kept under the guard so a None metric is never
        # logged -- confirm this matches the intended nesting.
        _ = cmf_logger.log_execution_metrics(
            metrics_name=f"Best_{self.metric}_Trial_{trial_id}",
            custom_properties={f'Best_{self.metric}': best_metric_value,
                               'execution_id': execution_id},
        )

    # Commit the metrics for the trial and log its final state
    print(f"Trial {trial_id} completed, Commiting to CMF: with name Trial_{trial_id}_metrics")
    print()

    _ = cmf_logger.commit_metrics(f"Trial_{trial_id}_metrics")
    _ = cmf_logger.log_execution_metrics(
        metrics_name=f"Trial_{trial_id}_Result",
        custom_properties={'Result': trial_result},
    )

    if best_model_path:
        _ = cmf_logger.log_model(path=best_model_path,
                                 event='input',
                                 model_name=f"{trial_id}_model")

def on_trial_error(self, iteration, trials, trial, **info):
trial_id = trial.trial_id
Expand All @@ -75,4 +128,24 @@ def on_trial_error(self, iteration, trials, trial, **info):
if self.cmf_run[trial_id]:
_ = self.cmf_obj[trial_id].commit_metrics(f"Trial_{trial_id}_metrics")
_ = self.cmf_obj[trial_id].log_execution_metrics(metrics_name = f"Trial_{trial_id}_Result",
custom_properties = {'Result': '-inf'})
custom_properties = {'Result': '-inf'})

def on_experiment_end(self, trials, **info):
    """
    Mark every trial still held in the heap with {'in_top_n': True}.

    trials, info: passed in by the caller; not used here.

    Called at the end of the experiment. Entries are popped from the heap
    one at a time, so the heap is empty when this returns, and the stored
    execution of each popped trial is updated through that trial's CMF
    handle.
    """
    print(f"Marking top {self.top_n} trials with 'in_top_n = True' at experiment end.")

    # Log the top 'n' trials from the heap
    while self.heap:
        # First tuple field is the signed ordering key; it is not needed here.
        _, top_trial_id, top_metric_value, top_model_path, top_execution_id = heapq.heappop(self.heap)
        print(f"Top trial: {top_trial_id} with {self.metric}: {top_metric_value}")

        # Update the execution for this trial with 'in_top_n = True'
        _ = self.cmf_obj[top_trial_id].update_execution(
            execution_id=top_execution_id,
            custom_properties={'in_top_n': True},
        )

        print(f"Execution {top_execution_id} for Trial {top_trial_id} updated with 'in_top_n = True'.")
4 changes: 2 additions & 2 deletions docs/api/public/cmf_ray_logger.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ The `CmfRayLogger` class is designed to log Ray Tune metrics for the CMF (Common
To use `CmfRayLogger`, import it in your Python script:

```python
from cmf import cmf_ray_logger
from cmflib import cmf_ray_logger
```

## Usage
Expand Down Expand Up @@ -73,7 +73,7 @@ During each trial, `CmfRayLogger` will automatically create a CMF object with at
Here is a complete example of how to use `CmfRayLogger` with Ray Tune:

```Python
from cmf import cmf_ray_logger
from cmflib import cmf_ray_logger
from ray import tune

# Initialize the logger
Expand Down

0 comments on commit 462e9e4

Please sign in to comment.