diff --git a/llmeter/plotting.py b/llmeter/plotting.py index 7633513..9d52dfe 100644 --- a/llmeter/plotting.py +++ b/llmeter/plotting.py @@ -4,6 +4,7 @@ import os import warnings from functools import partial +from numbers import Real from statistics import StatisticsError, quantiles from typing import Sequence @@ -26,12 +27,24 @@ matplotlib = DeferredError("Please install matplotlib to use plotting functions") -def binning(vector, bins: int | None = None) -> list[str]: +def binning(vector, bins: int | None = None) -> tuple[list[Real], bool]: + """Map the elements of `vector` to a discrete set of `bins` representative values for plotting + + If the cardinality of `vector` already exactly matches the number of `bins`, the same values + will be returned. If the number of `bins` is not specified, a heuristic is used to select one. + + Returns: + result: Elements of `vector` mapped to the mid-points of the calculated bins + binned: Truthy if the data was binned, falsy if it was left as-is + + TODO: Possibly extend `binned` to return the actual bin intervals instead of just True? + """ if len(vector) == 0: - return [] - # https://stats.stackexchange.com/a/114497 + return [], False + cardinality = len(set(vector)) + if bins is None: - cardinality = len(set(vector)) + # https://stats.stackexchange.com/a/114497 if cardinality < len(vector) / 20: m = cardinality else: @@ -42,9 +55,12 @@ def binning(vector, bins: int | None = None) -> list[str]: m = int((max(vector) - min(vector)) // h) + 1 except StatisticsError: m = cardinality // 4 + 1 - return [x.mid for x in pd.cut(vector, bins=m)] + return [x.mid for x in pd.cut(vector, bins=m)], True - return [x.mid for x in pd.cut(vector, bins=bins)] + if cardinality == bins: + return [x for x in vector], False + + return [x.mid for x in pd.cut(vector, bins=bins)], True def plot_heatmap( @@ -54,8 +70,12 @@ def plot_heatmap( output_path: os.PathLike | str | None = None, ): df = pd.DataFrame([p.to_dict() for p in result.responses]) - token_out_bins = binning(df["num_tokens_output"], bins=bins_output_tokens) - token_in_bins = binning(df["num_tokens_input"], bins=bins_input_tokens) + token_out_bins, is_out_binned = binning( + df["num_tokens_output"], bins=bins_output_tokens + ) + token_in_bins, is_in_binned = binning( + df["num_tokens_input"], bins=bins_input_tokens + ) df["num_tokens_input"] = token_in_bins df["num_tokens_output"] = token_out_bins @@ -93,8 +113,8 @@ def plot_heatmap( fmt=".03g", ) fg.set_titles("{row_name}-{col_name}") - fg.set_xlabels("Input tokens") - fg.set_ylabels("Output tokens") + fg.set_xlabels(f"Input tokens ({'Binned' if is_in_binned else 'Exact'})") + fg.set_ylabels(f"Output tokens ({'Binned' if is_out_binned else 'Exact'})") for ax in fg.axes.ravel(): ax.invert_yaxis() break diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 4ae2f55..fbcde3e 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -28,23 +28,37 @@ def test_binning(): vector = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # Test with default bins - result = binning(vector) + result, is_binned = binning(vector) assert len(result) == len(vector) + assert is_binned is True assert isinstance(result, list) # Test with specified bins - result = binning(vector, bins=5) + result, is_binned = binning(vector, bins=5) assert len(result) == len(vector) + assert is_binned is True assert len(set(result)) <= 5 def test_binning_with_repeated_values(): vector = [1, 1, 1, 2, 2, 3, 3, 3, 3] - result = binning(vector) + result, is_binned = binning(vector) assert len(result) == len(vector) + assert is_binned is True assert len(set(result)) <= len(set(vector)) +def test_binning_preserves_matching_cardinality(): + # A vector with cardinality 4 but very skewed distribution: + vector = ([1] * 100) + ([2] * 50) + [3] + [10000] + result, is_binned = binning(vector, 4) + # Result should keep same cardinality and exact values: + assert len(result) == len(vector) + assert is_binned is False + assert len(set(result)) == len(set(vector)) + assert result == vector + + @patch("llmeter.plotting.sns") def test_plot_heatmap(mock_sns, sample_result: Result): fig, ax = plot_heatmap(sample_result) @@ -82,17 +96,19 @@ def test_plot_sweep_results(mock_subplots, mock_dataframe): def test_binning_edge_cases(): # Test with empty vector - assert binning([]) == [] + assert binning([]) == ([], False) # Test with single value - result = binning([1]) + result, is_binned = binning([1]) assert len(result) == 1 + assert is_binned is True assert result[0] == 1 # Test with large range of values large_vector = list(range(1000)) - result = binning(large_vector) + result, is_binned = binning(large_vector) assert len(result) == len(large_vector) + assert is_binned is True assert len(set(result)) < len(large_vector) @@ -101,8 +117,9 @@ def test_binning_with_different_bin_sizes( bins: None | Literal[5] | Literal[10] | Literal[20], ): vector = list(range(100)) - result = binning(vector, bins=bins) + result, is_binned = binning(vector, bins=bins) assert len(result) == len(vector) + assert is_binned is True if bins is not None: assert len(set(result)) <= bins