diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py
new file mode 100644
index 000000000..25e2b86b2
--- /dev/null
+++ b/altair/utils/_transformed_data.py
@@ -0,0 +1,522 @@
+from typing import List, Optional, Tuple, Dict, Iterable, overload, Union
+
+from altair import (
+    Chart,
+    FacetChart,
+    LayerChart,
+    HConcatChart,
+    VConcatChart,
+    ConcatChart,
+    data_transformers,
+)
+from altair.utils.core import _DataFrameLike
+from altair.utils.schemapi import Undefined
+
+Scope = Tuple[int, ...]
+FacetMapping = Dict[Tuple[str, Scope], Tuple[str, Scope]]
+
+
+@overload
+def transformed_data(
+    chart: Union[Chart, FacetChart],
+    row_limit: Optional[int] = None,
+    exclude: Optional[Iterable[str]] = None,
+) -> Optional[_DataFrameLike]:
+    ...
+
+
+@overload
+def transformed_data(
+    chart: Union[LayerChart, HConcatChart, VConcatChart, ConcatChart],
+    row_limit: Optional[int] = None,
+    exclude: Optional[Iterable[str]] = None,
+) -> List[_DataFrameLike]:
+    ...
+
+
+def transformed_data(chart, row_limit=None, exclude=None):
+    """Evaluate a Chart's transforms
+
+    Evaluate the data transforms associated with a Chart and return the
+    transformed data as one or more DataFrames
+
+    Parameters
+    ----------
+    chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart
+        Altair chart to evaluate transforms on
+    row_limit : int (optional)
+        Maximum number of rows to return for each DataFrame. None (default) for unlimited
+    exclude : iterable of str
+        Set of the names of charts to exclude
+
+    Returns
+    -------
+    DataFrame or list of DataFrames or None
+        If the input chart is a Chart or FacetChart, returns a DataFrame of the
+        transformed data. Otherwise, returns a list of DataFrames of the
+        transformed data
+    """
+    try:
+        from vegafusion import runtime, get_local_tz, get_inline_datasets_for_spec  # type: ignore
+    except ImportError as err:
+        raise ImportError(
+            "transformed_data requires the vegafusion-python-embed and vegafusion packages\n"
+            "These can be installed with pip using:\n"
+            "    pip install vegafusion[embed]\n"
+            "Or with conda using:\n"
+            "    conda install -c conda-forge vegafusion-python-embed vegafusion"
+        ) from err
+
+    if isinstance(chart, Chart):
+        # Add mark if none is specified to satisfy Vega-Lite
+        if chart.mark == Undefined:
+            chart = chart.mark_point()
+
+    # Deep copy chart so that we can rename marks without affecting caller
+    chart = chart.copy(deep=True)
+
+    # Ensure that all views are named so that we can look them up in the
+    # resulting Vega specification
+    chart_names = name_views(chart, 0, exclude=exclude)
+
+    # Compile to Vega and extract inline DataFrames
+    with data_transformers.enable("vegafusion-inline"):
+        vega_spec = chart.to_dict(format="vega")
+        inline_datasets = get_inline_datasets_for_spec(vega_spec)
+
+    # Build mapping from mark names to vega datasets
+    facet_mapping = get_facet_mapping(vega_spec)
+    dataset_mapping = get_datasets_for_view_names(vega_spec, chart_names, facet_mapping)
+
+    # Build a list of vega dataset names that corresponds to the order
+    # of the chart components
+    dataset_names = []
+    for chart_name in chart_names:
+        if chart_name in dataset_mapping:
+            dataset_names.append(dataset_mapping[chart_name])
+        else:
+            raise ValueError("Failed to locate all datasets")
+
+    # Extract transformed datasets with VegaFusion
+    datasets, warnings = runtime.pre_transform_datasets(
+        vega_spec,
+        dataset_names,
+        get_local_tz(),
+        row_limit=row_limit,
+        inline_datasets=inline_datasets,
+    )
+
+    if isinstance(chart, (Chart, FacetChart)):
+        # Return DataFrame (or None if it was excluded) if input was a simple Chart
+        if not datasets:
+            return None
+        else:
+            return datasets[0]
+    else:
+        # Otherwise return the list of DataFrames
+        return datasets
+
+
+def name_views(
+    chart: Union[
+        Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, ConcatChart
+    ],
+    i: int = 0,
+    exclude: Optional[Iterable[str]] = None,
+) -> List[str]:
+    """Name unnamed chart views
+
+    Name unnamed chart views so that we can look them up later in
+    the compiled Vega spec.
+
+    Note: This function mutates the input chart by applying names to
+    unnamed views.
+
+    Parameters
+    ----------
+    chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart
+        Altair chart to apply names to
+    i : int (default 0)
+        Starting chart index
+    exclude : iterable of str
+        Names of charts to exclude
+
+    Returns
+    -------
+    list of str
+        List of the names of the charts and subcharts
+    """
+    exclude = set(exclude) if exclude is not None else set()
+    if isinstance(chart, (Chart, FacetChart)):
+        if chart.name not in exclude:
+            if chart.name in (None, Undefined):
+                # Add name since none is specified
+                chart.name = Chart._get_name()
+            return [chart.name]
+        else:
+            return []
+    else:
+        if isinstance(chart, LayerChart):
+            subcharts = chart.layer
+        elif isinstance(chart, HConcatChart):
+            subcharts = chart.hconcat
+        elif isinstance(chart, VConcatChart):
+            subcharts = chart.vconcat
+        elif isinstance(chart, ConcatChart):
+            subcharts = chart.concat
+        else:
+            raise ValueError(
+                "transformed_data accepts an instance of "
+                "Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart\n"
+                f"Received value of type: {type(chart)}"
+            )
+
+        chart_names: List[str] = []
+        for subchart in subcharts:
+            for name in name_views(subchart, i=i + len(chart_names), exclude=exclude):
+                chart_names.append(name)
+        return chart_names
+
+
+def get_group_mark_for_scope(vega_spec: dict, scope: Scope) -> Optional[dict]:
+    """Get the group mark at a particular scope
+
+    Parameters
+    ----------
+    vega_spec : dict
+        Top-level Vega specification dictionary
+    scope : tuple of int
+        Scope tuple. If empty, the original Vega specification is returned.
+        Otherwise, the nested group mark at the scope specified is returned.
+
+    Returns
+    -------
+    dict or None
+        Top-level Vega spec (if scope is empty)
+        or group mark (if scope is non-empty)
+        or None (if group mark at scope does not exist)
+
+    Examples
+    --------
+    >>> spec = {
+    ...     "marks": [
+    ...         {
+    ...             "type": "group",
+    ...             "marks": [{"type": "symbol"}]
+    ...         },
+    ...         {
+    ...             "type": "group",
+    ...             "marks": [{"type": "rect"}]}
+    ...     ]
+    ... }
+    >>> get_group_mark_for_scope(spec, (1,))
+    {'type': 'group', 'marks': [{'type': 'rect'}]}
+    """
+    group = vega_spec
+
+    # Find group at scope
+    for scope_value in scope:
+        group_index = 0
+        child_group = None
+        for mark in group.get("marks", []):
+            if mark.get("type") == "group":
+                if group_index == scope_value:
+                    child_group = mark
+                    break
+                group_index += 1
+        if child_group is None:
+            return None
+        group = child_group
+
+    return group
+
+
+def get_datasets_for_scope(vega_spec: dict, scope: Scope) -> List[str]:
+    """Get the names of the datasets that are defined at a given scope
+
+    Parameters
+    ----------
+    vega_spec : dict
+        Top-level Vega specification
+    scope : tuple of int
+        Scope tuple. If empty, the names of top-level datasets are returned.
+        Otherwise, the names of the datasets defined in the nested group mark
+        at the specified scope are returned.
+
+    Returns
+    -------
+    list of str
+        List of the names of the datasets defined at the specified scope
+
+    Examples
+    --------
+    >>> spec = {
+    ...     "data": [
+    ...         {"name": "data1"}
+    ...     ],
+    ...     "marks": [
+    ...         {
+    ...             "type": "group",
+    ...             "data": [
+    ...                 {"name": "data2"}
+    ...             ],
+    ...             "marks": [{"type": "symbol"}]
+    ...         },
+    ...         {
+    ...             "type": "group",
+    ...             "data": [
+    ...                 {"name": "data3"},
+    ...                 {"name": "data4"},
+    ...             ],
+    ...             "marks": [{"type": "rect"}]
+    ...         }
+    ...     ]
+    ... }
+
+    >>> get_datasets_for_scope(spec, ())
+    ['data1']
+
+    >>> get_datasets_for_scope(spec, (0,))
+    ['data2']
+
+    >>> get_datasets_for_scope(spec, (1,))
+    ['data3', 'data4']
+
+    An empty list is returned when no group mark exists at the scope
+    >>> get_datasets_for_scope(spec, (1, 3))
+    []
+    """
+    group = get_group_mark_for_scope(vega_spec, scope) or {}
+
+    # get datasets from group
+    datasets = []
+    for dataset in group.get("data", []):
+        datasets.append(dataset["name"])
+
+    # Add facet dataset
+    facet_dataset = group.get("from", {}).get("facet", {}).get("name", None)
+    if facet_dataset:
+        datasets.append(facet_dataset)
+    return datasets
+
+
+def get_definition_scope_for_data_reference(
+    vega_spec: dict, data_name: str, usage_scope: Scope
+) -> Optional[Scope]:
+    """Return the scope that a dataset is defined at, for a given usage scope
+
+    Parameters
+    ----------
+    vega_spec : dict
+        Top-level Vega specification
+    data_name : str
+        The name of a dataset reference
+    usage_scope : tuple of int
+        The scope that the dataset is referenced in
+
+    Returns
+    -------
+    tuple of int
+        The scope where the referenced dataset is defined,
+        or None if no such dataset is found
+
+    Examples
+    --------
+    >>> spec = {
+    ...     "data": [
+    ...         {"name": "data1"}
+    ...     ],
+    ...     "marks": [
+    ...         {
+    ...             "type": "group",
+    ...             "data": [
+    ...                 {"name": "data2"}
+    ...             ],
+    ...             "marks": [{
+    ...                 "type": "symbol",
+    ...                 "encode": {
+    ...                     "update": {
+    ...                         "x": {"field": "x", "data": "data1"},
+    ...                         "y": {"field": "y", "data": "data2"},
+    ...                     }
+    ...                 }
+    ...             }]
+    ...         }
+    ...     ]
+    ... }
+
+    data1 is referenced at scope (0,) and defined at scope ()
+    >>> get_definition_scope_for_data_reference(spec, "data1", (0,))
+    ()
+
+    data2 is referenced at scope (0,) and defined at scope (0,)
+    >>> get_definition_scope_for_data_reference(spec, "data2", (0,))
+    (0,)
+
+    data2 is not visible at scope () (the top level)
+    because it's defined at scope (0,)
+    >>> repr(get_definition_scope_for_data_reference(spec, "data2", ()))
+    'None'
+    """
+    for i in reversed(range(len(usage_scope) + 1)):
+        scope = usage_scope[:i]
+        datasets = get_datasets_for_scope(vega_spec, scope)
+        if data_name in datasets:
+            return scope
+    return None
+
+
+def get_facet_mapping(group: dict, scope: Scope = ()) -> FacetMapping:
+    """Create mapping from facet definitions to source datasets
+
+    Parameters
+    ----------
+    group : dict
+        Top-level Vega spec or nested group mark
+    scope : tuple of int
+        Scope of the group dictionary within a top-level Vega spec
+
+    Returns
+    -------
+    dict
+        Dictionary from (facet_name, facet_scope) to (dataset_name, dataset_scope)
+
+    Examples
+    --------
+    >>> spec = {
+    ...     "data": [
+    ...         {"name": "data1"}
+    ...     ],
+    ...     "marks": [
+    ...         {
+    ...             "type": "group",
+    ...             "from": {
+    ...                 "facet": {
+    ...                     "name": "facet1",
+    ...                     "data": "data1",
+    ...                     "groupby": ["colA"]
+    ...                 }
+    ...             }
+    ...         }
+    ...     ]
+    ... }
+    >>> get_facet_mapping(spec)
+    {('facet1', (0,)): ('data1', ())}
+    """
+    facet_mapping = {}
+    group_index = 0
+    mark_group = get_group_mark_for_scope(group, scope) or {}
+    for mark in mark_group.get("marks", []):
+        if mark.get("type", None) == "group":
+            # Get facet for this group
+            group_scope = scope + (group_index,)
+            facet = mark.get("from", {}).get("facet", None)
+            if facet is not None:
+                facet_name = facet.get("name", None)
+                facet_data = facet.get("data", None)
+                if facet_name is not None and facet_data is not None:
+                    definition_scope = get_definition_scope_for_data_reference(
+                        group, facet_data, scope
+                    )
+                    if definition_scope is not None:
+                        facet_mapping[(facet_name, group_scope)] = (
+                            facet_data,
+                            definition_scope,
+                        )
+
+            # Handle children recursively
+            child_mapping = get_facet_mapping(group, scope=group_scope)
+            facet_mapping.update(child_mapping)
+            group_index += 1
+
+    return facet_mapping
+
+
+def get_from_facet_mapping(
+    scoped_dataset: Tuple[str, Scope], facet_mapping: FacetMapping
+) -> Tuple[str, Scope]:
+    """Apply facet mapping to a scoped dataset
+
+    Parameters
+    ----------
+    scoped_dataset : (str, tuple of int)
+        A dataset name and scope tuple
+    facet_mapping : dict from (str, tuple of int) to (str, tuple of int)
+        The facet mapping produced by get_facet_mapping
+
+    Returns
+    -------
+    (str, tuple of int)
+        Dataset name and scope tuple that has been mapped as many times as possible
+
+    Examples
+    --------
+    Facet mapping as produced by get_facet_mapping
+    >>> facet_mapping = {("facet1", (0,)): ("data1", ()), ("facet2", (0, 1)): ("facet1", (0,))}
+    >>> get_from_facet_mapping(("facet2", (0, 1)), facet_mapping)
+    ('data1', ())
+    """
+    while scoped_dataset in facet_mapping:
+        scoped_dataset = facet_mapping[scoped_dataset]
+    return scoped_dataset
+
+
+def get_datasets_for_view_names(
+    group: dict,
+    vl_chart_names: List[str],
+    facet_mapping: FacetMapping,
+    scope: Scope = (),
+) -> Dict[str, Tuple[str, Scope]]:
+    """Get the Vega datasets that correspond to the provided Altair view names
+
+    Parameters
+    ----------
+    group : dict
+        Top-level Vega spec or nested group mark
+    vl_chart_names : list of str
+        List of the Vega-Lite chart names
+    facet_mapping : dict from (str, tuple of int) to (str, tuple of int)
+        The facet mapping produced by get_facet_mapping
+    scope : tuple of int
+        Scope of the group dictionary within a top-level Vega spec
+
+    Returns
+    -------
+    dict from str to (str, tuple of int)
+        Dict from Altair view names to scoped datasets
+    """
+    datasets = {}
+    group_index = 0
+    mark_group = get_group_mark_for_scope(group, scope) or {}
+    for mark in mark_group.get("marks", []):
+        for vl_chart_name in vl_chart_names:
+            if mark.get("name", "") == f"{vl_chart_name}_cell":
+                data_name = mark.get("from", {}).get("facet", {}).get("data", None)
+                scoped_data_name = (data_name, scope)
+                datasets[vl_chart_name] = get_from_facet_mapping(
+                    scoped_data_name, facet_mapping
+                )
+                break
+
+        name = mark.get("name", "")
+        if mark.get("type", "") == "group":
+            group_data_names = get_datasets_for_view_names(
+                group, vl_chart_names, facet_mapping, scope=scope + (group_index,)
+            )
+            for k, v in group_data_names.items():
+                datasets.setdefault(k, v)
+            group_index += 1
+        else:
+            for vl_chart_name in vl_chart_names:
+                if name.startswith(vl_chart_name) and name.endswith("_marks"):
+                    data_name = mark.get("from", {}).get("data", None)
+                    scoped_data = get_definition_scope_for_data_reference(
+                        group, data_name, scope
+                    )
+                    if scoped_data is not None:
+                        datasets[vl_chart_name] = get_from_facet_mapping(
+                            (data_name, scoped_data), facet_mapping
+                        )
+                        break
+
+    return datasets
diff --git a/altair/utils/core.py b/altair/utils/core.py
index 41e886001..61e370b1d 100644
--- a/altair/utils/core.py
+++ b/altair/utils/core.py
@@ -22,6 +22,11 @@
 else:
     from typing_extensions import ParamSpec
 
+if sys.version_info >= (3, 8):
+    from typing import Protocol
+else:
+    from typing_extensions import Protocol
+
 try:
     from pandas.api.types import infer_dtype as _infer_dtype
 except ImportError:
@@ -32,6 +37,11 @@
 _P = ParamSpec("_P")
 
 
+class _DataFrameLike(Protocol):
+    def __dataframe__(self, *args, **kwargs):
+        ...
+
+
 def infer_dtype(value):
     """Infer the dtype of the value.
diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py
index 33844a6a2..9e97e3dff 100644
--- a/altair/vegalite/v5/api.py
+++ b/altair/vegalite/v5/api.py
@@ -8,7 +8,7 @@
 from toolz.curried import pipe as _pipe
 import itertools
 import sys
-from typing import cast, List, Optional, Any
+from typing import cast, List, Optional, Any, Iterable
 
 # Have to rename it here as else it overlaps with schema.core.Type
 from typing import Type as TypingType
@@ -21,6 +21,7 @@
 from .display import renderers, VEGALITE_VERSION, VEGAEMBED_VERSION, VEGA_VERSION
 from .theme import themes
 from .compiler import vegalite_compilers
+from ...utils.core import _DataFrameLike
 
 if sys.version_info >= (3, 11):
     from typing import Self
@@ -2657,6 +2658,32 @@ def to_dict(
             validate=validate, format=format, ignore=ignore, context=context
         )
 
+    def _transformed_data(
+        self,
+        row_limit: Optional[int] = None,
+        exclude: Optional[Iterable[str]] = None,
+    ) -> Optional[_DataFrameLike]:
+        """Evaluate a Chart's transforms
+
+        Evaluate the data transforms associated with a Chart and return the
+        transformed data as a DataFrame
+
+        Parameters
+        ----------
+        row_limit : int (optional)
+            Maximum number of rows to return for each DataFrame. None (default) for unlimited
+        exclude : iterable of str
+            Set of the names of charts to exclude
+
+        Returns
+        -------
+        DataFrame
+            Transformed data as a DataFrame
+        """
+        from altair.utils._transformed_data import transformed_data
+
+        return transformed_data(self, row_limit=row_limit, exclude=exclude)
+
     def add_params(self, *params) -> Self:
         """Add one or more parameters to the chart."""
         if not params:
@@ -2832,6 +2859,32 @@ def __init__(
             **kwds,
         )
 
+    def _transformed_data(
+        self,
+        row_limit: Optional[int] = None,
+        exclude: Optional[Iterable[str]] = None,
+    ) -> Optional[_DataFrameLike]:
+        """Evaluate a RepeatChart's transforms
+
+        Evaluate the data transforms associated with a RepeatChart and return the
+        transformed data as a DataFrame
+
+        Parameters
+        ----------
+        row_limit : int (optional)
+            Maximum number of rows to return for each DataFrame. None (default) for unlimited
+        exclude : iterable of str
+            Set of the names of charts to exclude
+
+        Raises
+        ------
+        NotImplementedError
+            RepeatChart does not yet support transformed_data
+        """
+        raise NotImplementedError(
+            "transformed_data is not yet implemented for RepeatChart"
+        )
+
     def interactive(self, name=None, bind_x=True, bind_y=True) -> Self:
         """Make chart axes scales interactive
 
@@ -2917,6 +2970,32 @@ def __or__(self, other):
         copy |= other
         return copy
 
+    def _transformed_data(
+        self,
+        row_limit: Optional[int] = None,
+        exclude: Optional[Iterable[str]] = None,
+    ) -> List[_DataFrameLike]:
+        """Evaluate a ConcatChart's transforms
+
+        Evaluate the data transforms associated with a ConcatChart and return the
+        transformed data for each subplot as a list of DataFrames
+
+        Parameters
+        ----------
+        row_limit : int (optional)
+            Maximum number of rows to return for each DataFrame. None (default) for unlimited
+        exclude : iterable of str
+            Set of the names of charts to exclude
+
+        Returns
+        -------
+        list of DataFrame
+            Transformed data for each subplot as a list of DataFrames
+        """
+        from altair.utils._transformed_data import transformed_data
+
+        return transformed_data(self, row_limit=row_limit, exclude=exclude)
+
     def interactive(self, name=None, bind_x=True, bind_y=True) -> Self:
         """Make chart axes scales interactive
 
@@ -2988,6 +3067,32 @@ def __or__(self, other):
         copy |= other
         return copy
 
+    def _transformed_data(
+        self,
+        row_limit: Optional[int] = None,
+        exclude: Optional[Iterable[str]] = None,
+    ) -> List[_DataFrameLike]:
+        """Evaluate a HConcatChart's transforms
+
+        Evaluate the data transforms associated with a HConcatChart and return the
+        transformed data for each subplot as a list of DataFrames
+
+        Parameters
+        ----------
+        row_limit : int (optional)
+            Maximum number of rows to return for each DataFrame. None (default) for unlimited
+        exclude : iterable of str
+            Set of the names of charts to exclude
+
+        Returns
+        -------
+        list of DataFrame
+            Transformed data for each subplot as a list of DataFrames
+        """
+        from altair.utils._transformed_data import transformed_data
+
+        return transformed_data(self, row_limit=row_limit, exclude=exclude)
+
     def interactive(self, name=None, bind_x=True, bind_y=True) -> Self:
         """Make chart axes scales interactive
 
@@ -3059,6 +3164,32 @@ def __and__(self, other):
         copy &= other
         return copy
 
+    def _transformed_data(
+        self,
+        row_limit: Optional[int] = None,
+        exclude: Optional[Iterable[str]] = None,
+    ) -> List[_DataFrameLike]:
+        """Evaluate a VConcatChart's transforms
+
+        Evaluate the data transforms associated with a VConcatChart and return the
+        transformed data for each subplot as a list of DataFrames
+
+        Parameters
+        ----------
+        row_limit : int (optional)
+            Maximum number of rows to return for each DataFrame. None (default) for unlimited
+        exclude : iterable of str
+            Set of the names of charts to exclude
+
+        Returns
+        -------
+        list of DataFrame
+            Transformed data for each subplot as a list of DataFrames
+        """
+        from altair.utils._transformed_data import transformed_data
+
+        return transformed_data(self, row_limit=row_limit, exclude=exclude)
+
     def interactive(self, name=None, bind_x=True, bind_y=True) -> Self:
         """Make chart axes scales interactive
 
@@ -3129,6 +3260,32 @@ def __init__(self, data=Undefined, layer=(), **kwargs):
         for prop in combined_dict:
             self[prop] = combined_dict[prop]
 
+    def _transformed_data(
+        self,
+        row_limit: Optional[int] = None,
+        exclude: Optional[Iterable[str]] = None,
+    ) -> List[_DataFrameLike]:
+        """Evaluate a LayerChart's transforms
+
+        Evaluate the data transforms associated with a LayerChart and return the
+        transformed data for each layer as a list of DataFrames
+
+        Parameters
+        ----------
+        row_limit : int (optional)
+            Maximum number of rows to return for each DataFrame. None (default) for unlimited
+        exclude : iterable of str
+            Set of the names of charts to exclude
+
+        Returns
+        -------
+        list of DataFrame
+            Transformed data for each layer as a list of DataFrames
+        """
+        from altair.utils._transformed_data import transformed_data
+
+        return transformed_data(self, row_limit=row_limit, exclude=exclude)
+
     def __iadd__(self, other):
         _check_if_valid_subspec(other, "LayerChart")
         _check_if_can_be_layered(other)
@@ -3218,6 +3375,32 @@ def __init__(
             data=data, spec=spec, facet=facet, params=params, **kwargs
         )
 
+    def _transformed_data(
+        self,
+        row_limit: Optional[int] = None,
+        exclude: Optional[Iterable[str]] = None,
+    ) -> Optional[_DataFrameLike]:
+        """Evaluate a FacetChart's transforms
+
+        Evaluate the data transforms associated with a FacetChart and return the
+        transformed data as a DataFrame
+
+        Parameters
+        ----------
+        row_limit : int (optional)
+            Maximum number of rows to return for each DataFrame. None (default) for unlimited
+        exclude : iterable of str
+            Set of the names of charts to exclude
+
+        Returns
+        -------
+        DataFrame
+            Transformed data as a DataFrame
+        """
+        from altair.utils._transformed_data import transformed_data
+
+        return transformed_data(self, row_limit=row_limit, exclude=exclude)
+
     def interactive(self, name=None, bind_x=True, bind_y=True) -> Self:
         """Make chart axes scales interactive
 
diff --git a/pyproject.toml b/pyproject.toml
index d17777c24..10e3768b4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -71,7 +71,8 @@ dev = [
     "pandas-stubs",
     "types-jsonschema",
    "types-setuptools",
-    "pyarrow>=11"
+    "pyarrow>=11",
+    "vegafusion[embed]"
 ]
 doc = [
     "sphinx",
diff --git a/tests/test_transformed_data.py b/tests/test_transformed_data.py
new file mode 100644
index 000000000..1604f2026
--- /dev/null
+++ b/tests/test_transformed_data.py
@@ -0,0 +1,128 @@
+from altair.utils.execeval import eval_block
+import altair as alt
+from tests import examples_methods_syntax
+from vega_datasets import data
+import pkgutil
+import pytest
+
+
+# fmt: off
+@pytest.mark.parametrize("filename,rows,cols", [
+    ("annual_weather_heatmap.py", 366, ["monthdate_date_end", "max_temp_max"]),
+    ("anscombe_plot.py", 44, ["Series", "X", "Y"]),
+    ("bar_chart_sorted.py", 6, ["site", "sum_yield"]),
+    ("bar_chart_trellis_compact.py", 27, ["p", "p_end"]),
+    ("beckers_barley_trellis_plot.py", 120, ["year", "site"]),
+    ("beckers_barley_wrapped_facet.py", 120, ["site", "median_yield"]),
+    ("bump_chart.py", 100, ["rank", "yearmonth_date"]),
+    ("comet_chart.py", 120, ["variety", "delta"]),
+    ("connected_scatterplot.py", 55, ["miles", "gas"]),
+    ("diverging_stacked_bar_chart.py", 40, ["value", "percentage_start"]),
+    ("donut_chart.py", 6, ["value_start", "value_end"]),
+    ("gapminder_bubble_plot.py", 187, ["income", "population"]),
+    ("grouped_bar_chart2.py", 9, ["Group", "Value_start"]),
+    ("hexbins.py", 84, ["xFeaturePos", "mean_temp_max"]),
+    ("histogram_heatmap.py", 378, ["bin_maxbins_40_Rotten_Tomatoes_Rating", "__count"]),
+    ("histogram_scatterplot.py", 64, ["bin_maxbins_10_Rotten_Tomatoes_Rating", "__count"]),
+    ("interactive_legend.py", 1708, ["sum_count_start", "series"]),
+    ("iowa_electricity.py", 51, ["net_generation_start", "year"]),
+    ("isotype.py", 37, ["animal", "x"]),
+    ("isotype_grid.py", 100, ["row", "col"]),
+    ("lasagna_plot.py", 492, ["yearmonthdate_date", "sum_price"]),
+    ("layered_area_chart.py", 51, ["source", "net_generation"]),
+    ("layered_bar_chart.py", 51, ["source", "net_generation"]),
+    ("layered_histogram.py", 113, ["bin_maxbins_100_Measurement"]),
+    ("line_chart_with_cumsum.py", 52, ["cumulative_wheat"]),
+    ("line_percent.py", 30, ["sex", "perc"]),
+    ("line_with_log_scale.py", 15, ["year", "sum_people"]),
+    ("multifeature_scatter_plot.py", 150, ["petalWidth", "species"]),
+    ("natural_disasters.py", 686, ["Deaths", "Year"]),
+    ("normalized_stacked_area_chart.py", 51, ["source", "net_generation_start"]),
+    ("normalized_stacked_bar_chart.py", 60, ["site", "sum_yield_start"]),
+    ("parallel_coordinates.py", 600, ["key", "value"]),
+    ("percentage_of_total.py", 5, ["PercentOfTotal", "TotalTime"]),
+    ("pie_chart.py", 6, ["category", "value_start"]),
+    ("pyramid.py", 3, ["category", "value_start"]),
+    ("stacked_bar_chart_sorted_segments.py", 60, ["variety", "site"]),
+    ("stem_and_leaf.py", 100, ["stem", "leaf"]),
+    ("streamgraph.py", 1708, ["series", "sum_count"]),
+    ("top_k_items.py", 10, ["rank", "IMDB_Rating_start"]),
+    ("top_k_letters.py", 9, ["rank", "letters"]),
+    ("top_k_with_others.py", 10, ["ranked_director", "mean_aggregate_gross"]),
+    ("trellis_area_sort_array.py", 492, ["date", "price"]),
+    ("trellis_histogram.py", 20, ["Origin", "__count"]),
+    ("us_population_over_time.py", 38, ["sex", "people_start"]),
+    ("us_population_over_time_facet.py", 285, ["year", "sum_people"]),
+    ("wilkinson-dot-plot.py", 21, ["data", "id"]),
+    ("window_rank.py", 12, ["team", "diff"]),
+])
+# fmt: on
+def test_primitive_chart_examples(filename, rows, cols):
+    source = pkgutil.get_data(examples_methods_syntax.__name__, filename)
+    chart = eval_block(source)
+    df = chart._transformed_data()
+    assert len(df) == rows
+    assert set(cols).issubset(set(df.columns))
+
+
+# fmt: off
+@pytest.mark.parametrize("filename,all_rows,all_cols", [
+    ("errorbars_with_std.py", [10, 10], [["upper_yield"], ["extent_yield"]]),
+    ("candlestick_chart.py", [44, 44], [["low"], ["close"]]),
+    ("co2_concentration.py", [713, 7, 7], [["first_date"], ["scaled_date"], ["end"]]),
+    ("falkensee.py", [2, 38, 38], [["event"], ["population"], ["population"]]),
+    ("heat_lane.py", [10, 10], [["bin_count_start"], ["y2"]]),
+    ("histogram_responsive.py", [20, 20], [["__count"], ["__count"]]),
+    ("histogram_with_a_global_mean_overlay.py", [9, 1], [["__count"], ["mean_IMDB_Rating"]]),
+    ("horizon_graph.py", [20, 20], [["x"], ["ny"]]),
+    ("interactive_cross_highlight.py", [64, 64, 13], [["__count"], ["__count"], ["Major_Genre"]]),
+    ("interval_selection.py", [123, 123], [["price_start"], ["date"]]),
+    ("layered_chart_with_dual_axis.py", [12, 12], [["month_date"], ["average_precipitation"]]),
+    ("layered_heatmap_text.py", [9, 9], [["Cylinders"], ["mean_horsepower"]]),
+    ("multiline_highlight.py", [560, 560], [["price"], ["date"]]),
+    ("multiline_tooltip.py", [300, 300, 300, 0, 300], [["x"], ["y"], ["y"], ["x"], ["x"]]),
+    ("pie_chart_with_labels.py", [6, 6], [["category"], ["value"]]),
+    ("radial_chart.py", [6, 6], [["values"], ["values_start"]]),
+    ("scatter_linked_table.py", [392, 14, 14, 14], [["Year"], ["Year"], ["Year"], ["Year"]]),
+    ("scatter_marginal_hist.py", [34, 150, 27], [["__count"], ["species"], ["__count"]]),
+    ("scatter_with_layered_histogram.py", [2, 19], [["gender"], ["__count"]]),
+    ("scatter_with_minimap.py", [1461, 1461], [["date"], ["date"]]),
+    ("scatter_with_rolling_mean.py", [1461, 1461], [["date"], ["rolling_mean"]]),
+    ("seattle_weather_interactive.py", [1461, 5], [["date"], ["__count"]]),
+    ("select_detail.py", [20, 1000], [["id"], ["x"]]),
+    ("simple_scatter_with_errorbars.py", [5, 5], [["x"], ["upper_ymin"]]),
+    ("stacked_bar_chart_with_text.py", [60, 60], [["site"], ["site"]]),
+    ("us_employment.py", [120, 1, 2], [["month"], ["president"], ["president"]]),
+    ("us_population_pyramid_over_time.py", [19, 38, 19], [["gender"], ["year"], ["gender"]]),
+])
+# fmt: on
+def test_compound_chart_examples(filename, all_rows, all_cols):
+    source = pkgutil.get_data(examples_methods_syntax.__name__, filename)
+    chart = eval_block(source)
+    print(chart)
+
+    dfs = chart._transformed_data()
+    assert len(dfs) == len(all_rows)
+    for df, rows, cols in zip(dfs, all_rows, all_cols):
+        assert len(df) == rows
+        assert set(cols).issubset(set(df.columns))
+
+
+def test_transformed_data_exclude():
+    source = data.wheat()
+    bar = alt.Chart(source).mark_bar().encode(x="year:O", y="wheat:Q")
+    rule = alt.Chart(source).mark_rule(color="red").encode(y="mean(wheat):Q")
+    some_annotation = (
+        alt.Chart(name="some_annotation")
+        .mark_text(fontWeight="bold")
+        .encode(text=alt.value("Just some text"), y=alt.datum(85), x=alt.value(200))
+    )
+
+    chart = (bar + rule + some_annotation).properties(width=600)
+    datasets = chart._transformed_data(exclude=["some_annotation"])
+
+    assert len(datasets) == 2
+    assert len(datasets[0]) == 52
+    assert "wheat_start" in datasets[0]
+    assert len(datasets[1]) == 1
+    assert "mean_wheat" in datasets[1]
diff --git a/tools/update_init_file.py b/tools/update_init_file.py
index 41bb8f4d8..712d3f57e 100644
--- a/tools/update_init_file.py
+++ b/tools/update_init_file.py
@@ -6,7 +6,7 @@
 import sys
 from pathlib import Path
 from os.path import abspath, dirname, join
-from typing import TypeVar, Type, cast, List, Any, Optional
+from typing import TypeVar, Type, cast, List, Any, Optional, Iterable
 
 import black
 
@@ -81,6 +81,7 @@ def _is_relevant_attribute(attr_name):
         or attr is Any
         or attr is Literal
         or attr is Optional
+        or attr is Iterable
        or attr_name == "TypingDict"
     ):
         return False
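
Usage note (not part of the diff): a minimal sketch of how the new private API is exercised, mirroring the new tests above. It assumes vegafusion[embed] is installed and uses the vega_datasets wheat dataset (52 rows of year/wheat); the aggregate output column name "mean_wheat" follows Vega-Lite's aggregate naming convention and is illustrative rather than guaranteed by this patch.

    import altair as alt
    from vega_datasets import data

    source = data.wheat()  # 52 rows with "year" and "wheat" columns

    # A bar chart with an implicit aggregate transform: mean wheat per year
    chart = alt.Chart(source).mark_bar().encode(x="year:O", y="mean(wheat):Q")

    # Evaluate the chart's transforms with VegaFusion and get the result back
    # as a DataFrame; row_limit caps the number of rows returned
    df = chart._transformed_data(row_limit=100)
    print(df.columns)  # expected to include "year" and "mean_wheat"

For compound charts (layer/concat), the same call returns a list of DataFrames, one per subchart, and exclude=[...] drops named views, as test_transformed_data_exclude demonstrates.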