Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve exact-match cardinality when binning data for plots #6

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions llmeter/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import warnings
from functools import partial
from numbers import Real
from statistics import StatisticsError, quantiles
from typing import Sequence

Expand All @@ -26,12 +27,24 @@
matplotlib = DeferredError("Please install matplotlib to use plotting functions")


def binning(vector, bins: int | None = None) -> list[str]:
def binning(vector, bins: int | None = None) -> tuple[list[Real], bool]:
"""Map the elements of `vector` to a discrete set of `bins` representative values for plotting

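If `bins` is given explicitly and the cardinality of `vector` already exactly matches it, the
original values are returned unchanged. If `bins` is not specified, a heuristic is used to select
the number of bins, and a non-empty `vector` is always binned (even when the heuristic picks a
bin count equal to the cardinality).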

Returns:
result: Elements of `vector` mapped to the mid-points of the calculated bins
binned: Truthy if the data was binned, falsy if it was left as-is

TODO: Possibly extend `binned` to return the actual bin intervals instead of just True?
"""
if len(vector) == 0:
return []
# https://stats.stackexchange.com/a/114497
return [], False
cardinality = len(set(vector))

if bins is None:
cardinality = len(set(vector))
# https://stats.stackexchange.com/a/114497
if cardinality < len(vector) / 20:
m = cardinality
else:
Expand All @@ -42,9 +55,12 @@ def binning(vector, bins: int | None = None) -> list[str]:
m = int((max(vector) - min(vector)) // h) + 1
except StatisticsError:
m = cardinality // 4 + 1
return [x.mid for x in pd.cut(vector, bins=m)]
return [x.mid for x in pd.cut(vector, bins=m)], True

return [x.mid for x in pd.cut(vector, bins=bins)]
if cardinality == bins:
return [x for x in vector], False

return [x.mid for x in pd.cut(vector, bins=bins)], True


def plot_heatmap(
Expand All @@ -54,8 +70,12 @@ def plot_heatmap(
output_path: os.PathLike | str | None = None,
):
df = pd.DataFrame([p.to_dict() for p in result.responses])
token_out_bins = binning(df["num_tokens_output"], bins=bins_output_tokens)
token_in_bins = binning(df["num_tokens_input"], bins=bins_input_tokens)
token_out_bins, is_out_binned = binning(
df["num_tokens_output"], bins=bins_output_tokens
)
token_in_bins, is_in_binned = binning(
df["num_tokens_input"], bins=bins_input_tokens
)
df["num_tokens_input"] = token_in_bins
df["num_tokens_output"] = token_out_bins

Expand Down Expand Up @@ -93,8 +113,8 @@ def plot_heatmap(
fmt=".03g",
)
fg.set_titles("{row_name}-{col_name}")
fg.set_xlabels("Input tokens")
fg.set_ylabels("Output tokens")
fg.set_xlabels(f"Input tokens ({'Binned' if is_in_binned else 'Exact'})")
fg.set_ylabels(f"Output tokens ({'Binned' if is_out_binned else 'Exact'})")
for ax in fg.axes.ravel():
ax.invert_yaxis()
break
Expand Down
31 changes: 24 additions & 7 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,37 @@ def test_binning():
vector = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# Test with default bins
result = binning(vector)
result, is_binned = binning(vector)
assert len(result) == len(vector)
assert is_binned is True
assert isinstance(result, list)

# Test with specified bins
result = binning(vector, bins=5)
result, is_binned = binning(vector, bins=5)
assert len(result) == len(vector)
assert is_binned is True
assert len(set(result)) <= 5


def test_binning_with_repeated_values():
vector = [1, 1, 1, 2, 2, 3, 3, 3, 3]
result = binning(vector)
result, is_binned = binning(vector)
assert len(result) == len(vector)
assert is_binned is True
assert len(set(result)) <= len(set(vector))


def test_binning_preserves_matching_cardinality():
# A vector with cardinality 4 but very skewed distribution:
vector = ([1] * 100) + ([2] * 50) + [3] + [10000]
result, is_binned = binning(vector, 4)
# Result should keep same cardinality and exact values:
assert len(result) == len(vector)
assert is_binned is False
assert len(set(result)) == len(set(vector))
assert result == vector


@patch("llmeter.plotting.sns")
def test_plot_heatmap(mock_sns, sample_result: Result):
fig, ax = plot_heatmap(sample_result)
Expand Down Expand Up @@ -82,17 +96,19 @@ def test_plot_sweep_results(mock_subplots, mock_dataframe):

def test_binning_edge_cases():
# Test with empty vector
assert binning([]) == []
assert binning([]) == ([], False)

# Test with single value
result = binning([1])
result, is_binned = binning([1])
assert len(result) == 1
assert is_binned is True
assert result[0] == 1

# Test with large range of values
large_vector = list(range(1000))
result = binning(large_vector)
result, is_binned = binning(large_vector)
assert len(result) == len(large_vector)
assert is_binned is True
assert len(set(result)) < len(large_vector)


Expand All @@ -101,8 +117,9 @@ def test_binning_with_different_bin_sizes(
bins: None | Literal[5] | Literal[10] | Literal[20],
):
vector = list(range(100))
result = binning(vector, bins=bins)
result, is_binned = binning(vector, bins=bins)
assert len(result) == len(vector)
assert is_binned is True
if bins is not None:
assert len(set(result)) <= bins

Expand Down