Skip to content

Commit

Permalink
Merge #6 Exact-match cardinality in plot binning
Browse files Browse the repository at this point in the history
In (heatmap) plotting, preserve original values if the dataset
exactly matches the requested binning cardinality. Partially
addresses #5
  • Loading branch information
athewsey authored Oct 9, 2024
2 parents f485458 + 0e6d026 commit eb98131
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 17 deletions.
40 changes: 30 additions & 10 deletions llmeter/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import warnings
from functools import partial
from numbers import Real
from statistics import StatisticsError, quantiles
from typing import Sequence

Expand All @@ -26,12 +27,24 @@
matplotlib = DeferredError("Please install matplotlib to use plotting functions")


def binning(vector, bins: int | None = None) -> list[str]:
def binning(vector, bins: int | None = None) -> tuple[list[Real], bool]:
"""Map the elements of `vector` to a discrete set of `bins` representative values for plotting
If the cardinality of `vector` already exactly matches the number of `bins`, the same values
will be returned. If the number of `bins` is not specified, a heuristic is used to select one.
Returns:
result: Elements of `vector` mapped to the mid-points of the calculated bins
binned: Truthy if the data was binned, falsy if it was left as-is
TODO: Possibly extend `binned` to return the actual bin intervals instead of just True?
"""
if len(vector) == 0:
return []
# https://stats.stackexchange.com/a/114497
return [], False
cardinality = len(set(vector))

if bins is None:
cardinality = len(set(vector))
# https://stats.stackexchange.com/a/114497
if cardinality < len(vector) / 20:
m = cardinality
else:
Expand All @@ -42,9 +55,12 @@ def binning(vector, bins: int | None = None) -> list[str]:
m = int((max(vector) - min(vector)) // h) + 1
except StatisticsError:
m = cardinality // 4 + 1
return [x.mid for x in pd.cut(vector, bins=m)]
return [x.mid for x in pd.cut(vector, bins=m)], True

return [x.mid for x in pd.cut(vector, bins=bins)]
if cardinality == bins:
return [x for x in vector], False

return [x.mid for x in pd.cut(vector, bins=bins)], True


def plot_heatmap(
Expand All @@ -54,8 +70,12 @@ def plot_heatmap(
output_path: os.PathLike | str | None = None,
):
df = pd.DataFrame([p.to_dict() for p in result.responses])
token_out_bins = binning(df["num_tokens_output"], bins=bins_output_tokens)
token_in_bins = binning(df["num_tokens_input"], bins=bins_input_tokens)
token_out_bins, is_out_binned = binning(
df["num_tokens_output"], bins=bins_output_tokens
)
token_in_bins, is_in_binned = binning(
df["num_tokens_input"], bins=bins_input_tokens
)
df["num_tokens_input"] = token_in_bins
df["num_tokens_output"] = token_out_bins

Expand Down Expand Up @@ -93,8 +113,8 @@ def plot_heatmap(
fmt=".03g",
)
fg.set_titles("{row_name}-{col_name}")
fg.set_xlabels("Input tokens")
fg.set_ylabels("Output tokens")
fg.set_xlabels(f"Input tokens ({'Binned' if is_in_binned else 'Exact'})")
fg.set_ylabels(f"Output tokens ({'Binned' if is_out_binned else 'Exact'})")
for ax in fg.axes.ravel():
ax.invert_yaxis()
break
Expand Down
31 changes: 24 additions & 7 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,37 @@ def test_binning():
vector = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# Test with default bins
result = binning(vector)
result, is_binned = binning(vector)
assert len(result) == len(vector)
assert is_binned is True
assert isinstance(result, list)

# Test with specified bins
result = binning(vector, bins=5)
result, is_binned = binning(vector, bins=5)
assert len(result) == len(vector)
assert is_binned is True
assert len(set(result)) <= 5


def test_binning_with_repeated_values():
vector = [1, 1, 1, 2, 2, 3, 3, 3, 3]
result = binning(vector)
result, is_binned = binning(vector)
assert len(result) == len(vector)
assert is_binned is True
assert len(set(result)) <= len(set(vector))


def test_binning_preserves_matching_cardinality():
    """Check that `binning` leaves data untouched when `bins` equals its cardinality.

    The input has exactly four distinct values with a very skewed distribution.
    Asking for four bins must hand the original values back unchanged and
    report that no binning took place.
    """
    skewed = [1] * 100 + [2] * 50 + [3, 10000]
    distinct_before = set(skewed)

    values, was_binned = binning(skewed, 4)

    # Same length, flagged as not binned, same cardinality, and exact values kept:
    assert len(values) == len(skewed)
    assert was_binned is False
    assert len(set(values)) == len(distinct_before)
    assert values == skewed


@patch("llmeter.plotting.sns")
def test_plot_heatmap(mock_sns, sample_result: Result):
fig, ax = plot_heatmap(sample_result)
Expand Down Expand Up @@ -82,17 +96,19 @@ def test_plot_sweep_results(mock_subplots, mock_dataframe):

def test_binning_edge_cases():
# Test with empty vector
assert binning([]) == []
assert binning([]) == ([], False)

# Test with single value
result = binning([1])
result, is_binned = binning([1])
assert len(result) == 1
assert is_binned is True
assert result[0] == 1

# Test with large range of values
large_vector = list(range(1000))
result = binning(large_vector)
result, is_binned = binning(large_vector)
assert len(result) == len(large_vector)
assert is_binned is True
assert len(set(result)) < len(large_vector)


Expand All @@ -101,8 +117,9 @@ def test_binning_with_different_bin_sizes(
bins: None | Literal[5] | Literal[10] | Literal[20],
):
vector = list(range(100))
result = binning(vector, bins=bins)
result, is_binned = binning(vector, bins=bins)
assert len(result) == len(vector)
assert is_binned is True
if bins is not None:
assert len(set(result)) <= bins

Expand Down

0 comments on commit eb98131

Please sign in to comment.