diff --git a/src/yt_napari/_model_ingestor.py b/src/yt_napari/_model_ingestor.py index b0870bf..14cf795 100644 --- a/src/yt_napari/_model_ingestor.py +++ b/src/yt_napari/_model_ingestor.py @@ -18,12 +18,22 @@ def _le_re_to_cen_wid( class LayerDomain: # container for domain info for a single layer + # left_edge, right_edge, resolution, n_d are all self explanatory. + # other parameters: + # + # new_dim_value: optional unyt_quantity. + # If n_d == 2, and upgrade_to_3D is subsequently called, then this value + # will be used for the new + # new_dim_axis: optional int. + # the index position to add the new_dim_position, default is last def __init__( self, left_edge: unyt_array, right_edge: unyt_array, resolution: tuple, - n_d: int = 3, + n_d: Optional[int] = 3, + new_dim_value: Optional[unyt_quantity] = None, + new_dim_axis: Optional[int] = 2, ): if len(left_edge) != len(right_edge): @@ -33,7 +43,10 @@ def __init__( if len(resolution) == 1: resolution = resolution * n_d # assume same in every dim else: - raise ValueError("length of resolution does not match edge arrays") + msg = f"{len(resolution)}:{len(left_edge)}" + raise ValueError( + f"length of resolution does not match edge arrays {msg}" + ) self.left_edge = left_edge self.right_edge = right_edge @@ -43,6 +56,36 @@ def __init__( self.aspect_ratio = self.width / self.width[0] self.requires_scale = np.any(self.aspect_ratio != unyt_array(1.0, "")) self.n_d = n_d + if new_dim_value is None: + new_dim_value = unyt_quantity(0.0, left_edge.units) + self.new_dim_value = new_dim_value + self.new_dim_axis = new_dim_axis + + def upgrade_to_3D(self): + if self.n_d == 3: + # already 3D, nothing to do + return + + if self.n_d == 2: + new_l_r = getattr(self, "new_dim_value") + axid = self.new_dim_axis + self.left_edge = _insert_to_unyt_array(self.left_edge, new_l_r, axid) + self.right_edge = _insert_to_unyt_array(self.right_edge, new_l_r, axid) + self.resolution = _insert_to_unyt_array(self.right_edge, 1, axid) + self.grid_width = _insert_to_unyt_array(self.grid_width, 0, axid) + self.aspect_ratio = _insert_to_unyt_array(self.aspect_ratio, 1.0, axid) + self.n_d = 3 + + +def _insert_to_unyt_array( + x: unyt_array, new_value: Union[float, unyt_array], position: int +) -> unyt_array: + # just for scalars + if isinstance(new_value, unyt_array): + # reminder: unyt_quantity is instance of unyt_array + new_value = new_value.to(x.units).d + + return unyt_array(np.insert(x.d, position, new_value), x.units) # define types for the napari layer tuples @@ -65,6 +108,9 @@ def __init__(self, ref_layer_domain: LayerDomain): self.grid_width = ref_layer_domain.grid_width self.aspect_ratio = ref_layer_domain.aspect_ratio + # and store the full domain + self.layer_domain = ref_layer_domain + def calculate_scale(self, other_layer: LayerDomain) -> unyt_array: # calculate the pixel scale for a layer relative to the reference @@ -96,6 +142,9 @@ def align_sanitize_layer(self, layer: SpatialLayer) -> Layer: # pull out the elements of the SpatialLayer tuple im_arr, im_kwargs, layer_type, domain = layer + # upgrade to 3D if necessary + domain = self.handle_dimensionality(domain) + # calculate scale and translation scale = self.calculate_scale(domain) translate = self.calculate_translation(domain) @@ -120,6 +169,22 @@ def align_sanitize_layers(self, layer_list: List[SpatialLayer]) -> List[Layer]: # layer_list return [self.align_sanitize_layer(layer) for layer in layer_list] + def handle_dimensionality(self, domain: LayerDomain) -> LayerDomain: + # upgrade from 2d to 3d if required: correct orientation is NOT + # guaranteed. + + if domain.n_d > self.layer_domain.n_d: + raise RuntimeError( + f"cannot add a {domain.n_d}D layer to a lower" + f"dimensionality scene. Layers must be added from" + f"high to low dimensionality." + ) + + if domain.n_d == 2 and self.layer_domain.n_d == 3: + domain.upgrade_to_3d(new_left_right=self.center[-1]) + + return domain + def create_metadata_dict( data: np.ndarray, diff --git a/src/yt_napari/_tests/test_model_ingestor.py b/src/yt_napari/_tests/test_model_ingestor.py index 8d918e7..55aa553 100644 --- a/src/yt_napari/_tests/test_model_ingestor.py +++ b/src/yt_napari/_tests/test_model_ingestor.py @@ -83,16 +83,69 @@ def test_layer_domain(domains_to_test): assert np.all(layer_domain.width == d.width) # check some instantiation things - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="length of edge arrays must match"): _ = _mi.LayerDomain(d.left_edge, unyt.unyt_array([1, 2], "m"), d.resolution) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="length of resolution does not"): _ = _mi.LayerDomain(d.left_edge, d.right_edge, (10, 12)) ld = _mi.LayerDomain(d.left_edge, d.right_edge, (10,)) assert len(ld.resolution) == 3 +def test_layer_domain_dimensionality(): + # sets of left_edge, right_edge, center, width, res + le = unyt.unyt_array([1.0, 1.0], "km") + re = unyt.unyt_array([2000.0, 2000.0], "m") + res = (10, 20) + ld = _mi.LayerDomain(le, re, res, n_d=2) + assert ld.n_d == 2 + + ld.upgrade_to_3D() + assert ld.n_d == 3 + assert len(ld.left_edge) == 3 + assert ld.left_edge[-1] == 0.0 + + ld = _mi.LayerDomain(le, re, res, n_d=2, new_dim_value=0.5) + ld.upgrade_to_3D() + assert ld.left_edge[2] == unyt.unyt_quantity(0.5, le.units) + + new_val = unyt.unyt_quantity(0.5, "km") + ld = _mi.LayerDomain(le, re, res, n_d=2, new_dim_value=new_val) + ld.upgrade_to_3D() + assert ld.left_edge[2].to("km") == new_val + + ld = _mi.LayerDomain(le, re, res, n_d=2, new_dim_value=new_val, new_dim_axis=0) + ld.upgrade_to_3D() + assert ld.left_edge[0].to("km") == new_val + + +_test_cases_insert = [ + ( + unyt.unyt_array([1.0, 1.0], "km"), + unyt.unyt_array( + [ + 1000.0, + ], + "m", + ), + unyt.unyt_array([1.0, 1.0, 1.0], "km"), + ), + ( + unyt.unyt_array([1.0, 1.0], "km"), + unyt.unyt_quantity(1000.0, "m"), + unyt.unyt_array([1.0, 1.0, 1.0], "km"), + ), + (unyt.unyt_array([1.0, 1.0], "km"), 0.5, unyt.unyt_array([1.0, 1.0, 0.5], "km")), +] + + +@pytest.mark.parametrize("x,x2,expected", _test_cases_insert) +def test_insert_to_unyt_array(x, x2, expected): + result = _mi._insert_to_unyt_array(x, x2, 2) + assert np.all(result == expected) + + def test_domain_tracking(domains_to_test): full_domain = _mi.PhysicalDomainTracker()