Refactor compute_stats #521

Draft: wants to merge 1 commit into base: main
lerobot/common/datasets/compute_stats.py (92 additions, 0 deletions)

@@ -211,3 +211,95 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
            )
        )
    return stats


# TODO(aliberts): refactor stats in save_episodes
# import numpy as np
# from lerobot.common.datasets.utils import load_image_as_numpy
# def aggregate_stats_v2(stats_list: list) -> dict:
#     """Aggregate stats from multiple compute_stats outputs into a single set of stats.

#     The final stats will have the union of all data keys from each of the stats dicts.

#     For instance:
#     - new_min = min(min_dataset_0, min_dataset_1, ...)
#     - new_max = max(max_dataset_0, max_dataset_1, ...)
#     - new_mean = (mean of all data, weighted by counts)
#     - new_std = (std of all data, recovered from the per-dataset means and variances via the parallel variance algorithm)
#     """
#     data_keys = set(key for stats in stats_list for key in stats.keys())
#     aggregated_stats = {key: {} for key in data_keys}

#     for key in data_keys:
#         # Collect stats for the current key from all datasets where it exists
#         stats_with_key = [stats[key] for stats in stats_list if key in stats]

#         # Aggregate 'min' and 'max' using np.minimum and np.maximum
#         aggregated_stats[key]['min'] = np.minimum.reduce([s['min'] for s in stats_with_key])
#         aggregated_stats[key]['max'] = np.maximum.reduce([s['max'] for s in stats_with_key])

#         # Extract means, variances (std^2), and counts
#         means = np.array([s['mean'] for s in stats_with_key])
#         variances = np.array([s['std']**2 for s in stats_with_key])
#         counts = np.array([s['count'] for s in stats_with_key])

#         # Ensure counts can broadcast with means/variances if they have additional dimensions
#         counts = counts.reshape(-1, *[1]*(means.ndim - 1))

#         # Compute total counts
#         total_count = counts.sum(axis=0)

#         # Compute the weighted mean
#         weighted_means = means * counts
#         total_mean = weighted_means.sum(axis=0) / total_count

#         # Compute the variance using the parallel algorithm
#         delta_means = means - total_mean
#         weighted_variances = (variances + delta_means**2) * counts
#         total_variance = weighted_variances.sum(axis=0) / total_count

#         # Store the aggregated stats
#         aggregated_stats[key]['mean'] = total_mean
#         aggregated_stats[key]['std'] = np.sqrt(total_variance)
#         aggregated_stats[key]['count'] = total_count

#     return aggregated_stats
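
If the draft above were uncommented, the parallel mean/variance step has a checkable property: aggregating per-dataset stats reproduces the stats of the concatenated data (up to floating point). A minimal sketch, assuming `aggregate_stats_v2` is in scope; the `action` key and the synthetic arrays are illustrative:

```python
import numpy as np

def stats_of(x: np.ndarray) -> dict:
    # Per-dimension stats in the same layout compute_episode_stats would emit.
    return {"min": x.min(0), "max": x.max(0), "mean": x.mean(0), "std": x.std(0), "count": len(x)}

a = np.random.randn(100, 3)        # "dataset 0"
b = np.random.randn(50, 3) + 1.0   # "dataset 1", shifted mean

agg = aggregate_stats_v2([{"action": stats_of(a)}, {"action": stats_of(b)}])

both = np.concatenate([a, b])
assert np.allclose(agg["action"]["mean"], both.mean(0))
assert np.allclose(agg["action"]["std"], both.std(0))  # parallel algorithm matches a direct pass
```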


# def compute_episode_stats(episode_buffer: dict, features: dict, episode_length: int, image_sampling: int = 10) -> dict:
#     stats = {}
#     for key, data in episode_buffer.items():
#         if features[key]["dtype"] in ["image", "video"]:
#             stats[key] = compute_image_stats(data, sampling=image_sampling)
#         else:
#             axes_to_reduce = 0  # Compute stats over the first axis
#             stats[key] = {
#                 "min": np.min(data, axis=axes_to_reduce),
#                 "max": np.max(data, axis=axes_to_reduce),
#                 "mean": np.mean(data, axis=axes_to_reduce),
#                 "std": np.std(data, axis=axes_to_reduce),
#                 "count": episode_length,
#             }
#     return stats
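
For the non-image branch, stats are reduced over the first (time) axis and tagged with the episode length as `count`. A minimal usage sketch, again assuming the commented-out function is enabled; the feature key and shapes are illustrative, not the actual LeRobotDataset schema:

```python
import numpy as np

episode_length = 50
episode_buffer = {"observation.state": np.random.randn(episode_length, 6)}
features = {"observation.state": {"dtype": "float32"}}

stats = compute_episode_stats(episode_buffer, features, episode_length)
print(stats["observation.state"]["mean"].shape)  # (6,): one value per state dimension
print(stats["observation.state"]["count"])       # 50
```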


# def compute_image_stats(image_paths: list[str], sampling: int = 10) -> dict:
#     images = []
#     samples = range(0, len(image_paths), sampling)
#     for idx in samples:
#         path = image_paths[idx]
#         img = load_image_as_numpy(path, channel_first=True)
#         images.append(img)

#     images = np.stack(images)
#     axes_to_reduce = (0, 2, 3)  # keep channel dim
#     image_stats = {
#         "min": np.min(images, axis=axes_to_reduce, keepdims=True),
#         "max": np.max(images, axis=axes_to_reduce, keepdims=True),
#         "mean": np.mean(images, axis=axes_to_reduce, keepdims=True),
#         "std": np.std(images, axis=axes_to_reduce, keepdims=True),
#     }
#     for key in image_stats:  # squeeze batch dim
#         image_stats[key] = np.squeeze(image_stats[key], axis=0)

#     # aggregate_stats_v2 expects a 'count' entry for every key; use the number of sampled frames
#     image_stats["count"] = len(images)

#     return image_stats
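
The `keepdims=True` / `np.squeeze` pair leaves each image stat shaped `(C, 1, 1)`, so it broadcasts directly against channel-first frames at normalization time. A shape-only sketch of that reduction in plain NumPy (no file I/O; the random `images` array stands in for frames loaded by `load_image_as_numpy`):

```python
import numpy as np

images = np.random.rand(5, 3, 64, 64).astype(np.float32)  # (N, C, H, W): 5 sampled frames

mean = np.mean(images, axis=(0, 2, 3), keepdims=True)  # (1, 3, 1, 1)
mean = np.squeeze(mean, axis=0)                        # (3, 1, 1)

normalized = images[0] - mean  # (3, 64, 64) minus (3, 1, 1) broadcasts cleanly
```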