diff --git a/cf/aggregate.py b/cf/aggregate.py index dffdca9f58..496189f3d0 100644 --- a/cf/aggregate.py +++ b/cf/aggregate.py @@ -2977,11 +2977,19 @@ def aggregate( # ------------------------------------------------------------ if axes is None: # Aggregation will be over as many axes as possible - aggregating_axes = meta[0].axis_ids + m0 = meta[0] + aggregating_axes = m0.axis_ids[:] + + # For DSG feature types, only consider aggregating the + # feature dimension(s). + if m0.featureType: + for axis in aggregating_axes[:]: + if not dsg_feature_type_axis(m0, axis): + aggregating_axes.remove(axis) + _create_hash_and_first_values( - meta, None, False, hfl_cache, rtol, atol + meta, aggregating_axes, False, hfl_cache, rtol, atol ) - else: # Specific aggregation axes have been selected aggregating_axes = [] @@ -3484,7 +3492,7 @@ def climatology_cells( def _create_hash_and_first_values( - meta, axes, donotchecknonaggregatingaxes, hfl_cache, rtol, atol + meta, aggregating_axes, donotchecknonaggregatingaxes, hfl_cache, rtol, atol ): """Updates each field's _Meta object. @@ -3492,7 +3500,8 @@ def _create_hash_and_first_values( meta: `list` of `_Meta` - axes: `None` or `list` + axes: sequence + The identities of the possible aggregating axes. donotchecknonaggregatingaxes: `bool` @@ -3509,6 +3518,9 @@ def _create_hash_and_first_values( field = m.field constructs = field.constructs.todict() + # Store the aggregating axis identities + m.aggregating_axes = aggregating_axes + m_sort_keys = m.sort_keys m_sort_indices = m.sort_indices @@ -3527,9 +3539,9 @@ def _create_hash_and_first_values( # -------------------------------------------------------- for identity in m.axis_ids: if ( - axes is not None + aggregating_axes is not None and donotchecknonaggregatingaxes - and identity not in axes + and identity not in aggregating_axes ): x = [None] * len(m.axis[identity]["keys"]) m_hash_values[identity] = x @@ -3671,12 +3683,12 @@ def _create_hash_and_first_values( coord = constructs[key] - axes = aux["axes"] + c_axes = aux["axes"] canonical_axes = aux["canonical_axes"] - if axes != canonical_axes: + if c_axes != canonical_axes: # Transpose the N-d auxiliary coordinate so that # it has the canonical axis order - iaxes = [axes.index(axis) for axis in canonical_axes] + iaxes = [c_axes.index(axis) for axis in canonical_axes] coord = coord.transpose(iaxes) sort_indices, needs_sorting = _sort_indices(m, canonical_axes) @@ -3722,14 +3734,14 @@ def _create_hash_and_first_values( else: for canonical_units, msr in m.msr.items(): hash_values = [] - for key, axes, canonical_axes in zip( + for key, c_axes, canonical_axes in zip( msr["keys"], msr["axes"], msr["canonical_axes"] ): cell_measure = constructs[key] - if axes != canonical_axes: + if c_axes != canonical_axes: # Transpose the cell measure so that it has # the canonical axis order - iaxes = [axes.index(axis) for axis in canonical_axes] + iaxes = [c_axes.index(axis) for axis in canonical_axes] cell_measure = cell_measure.transpose(iaxes) sort_indices, needs_sorting = _sort_indices( @@ -3836,12 +3848,12 @@ def _create_hash_and_first_values( field_anc = constructs[key] - axes = anc["axes"] + c_axes = anc["axes"] canonical_axes = anc["canonical_axes"] - if axes != canonical_axes: + if c_axes != canonical_axes: # Transpose the field ancillary so that it has the # canonical axis order - iaxes = [axes.index(axis) for axis in canonical_axes] + iaxes = [c_axes.index(axis) for axis in canonical_axes] field_anc = field_anc.transpose(iaxes) sort_indices, needs_sorting = _sort_indices(m, canonical_axes) @@ -3874,12 +3886,12 @@ def _create_hash_and_first_values( domain_anc = constructs[key] - axes = anc["axes"] + c_axes = anc["axes"] canonical_axes = anc["canonical_axes"] - if axes != canonical_axes: + if c_axes != canonical_axes: # Transpose the domain ancillary so that it has # the canonical axis order - iaxes = [axes.index(axis) for axis in canonical_axes] + iaxes = [c_axes.index(axis) for axis in canonical_axes] domain_anc = domain_anc.transpose(iaxes) sort_indices, needs_sorting = _sort_indices(m, canonical_axes) @@ -4131,7 +4143,7 @@ def _group_fields(meta, axis, info=False): group is represented by a `list` of `_Meta` objects. """ - axes = meta[0].axis_ids + axes = meta[0].aggregating_axes if axes: if axis in axes: