Skip to content

Commit

Permalink
Merge pull request #120 from chrishavlin/covering_grid_support
Browse files Browse the repository at this point in the history
adding covering grid support
  • Loading branch information
chrishavlin authored Apr 3, 2024
2 parents 5038bad + 5b7b239 commit 3d390d3
Show file tree
Hide file tree
Showing 14 changed files with 482 additions and 85 deletions.
132 changes: 108 additions & 24 deletions docs/examples/ytnapari_scene_01_intro.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions docs/examples/ytnapari_scene_04_timeseries.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
"\n",
"One difference between `yt-napari` and `yt` proper is that when sampling a time series, you first specify a selection object **independently** from a dataset object to define the extents and field of selection. That selection is then applied across all specified timesteps.\n",
"\n",
"The currently available selection objects are a `Slice` or 3D gridded `Region`. The arguments follow the same convention as a usual `yt` dataset selection object (i.e., `ds.slice`, `ds.region`) for specifying the geometric bounds of the selection with the additional constraint that you must specify a single field and the resolution you want to sample at:"
"The currently available selection objects are a 2D `Slice` or 3D gridded region, either a `Region` of a `CoveringGrid`. The arguments follow the same convention as a usual `yt` dataset selection object (i.e., `ds.slice`, `ds.region`, `ds.covering_grid`) for specifying the geometric bounds of the selection with the additional constraint that you must specify a single field and the resolution you want to sample at:"
]
},
{
Expand Down Expand Up @@ -238,7 +238,7 @@
"id": "edd2babf-5aae-4d2f-8079-96a68b594b22",
"metadata": {},
"source": [
"Once you create a `Slice` or `Region`, you can pass that to `add_to_viewer` and it will be used to sample each timestep specified. \n",
"Once you create a `Slice`, `Region` or `CoveringGrid`, you can pass that to `add_to_viewer` and it will be used to sample each timestep specified. \n",
"\n",
"## Slices through a timeseries\n",
"\n",
Expand Down Expand Up @@ -1131,7 +1131,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
23 changes: 23 additions & 0 deletions src/yt_napari/_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,26 @@ class Region(_ytBaseModel):
rescale: bool = Field(False, description="rescale the final image between 0,1")


class CoveringGrid(_ytBaseModel):
fields: List[ytField] = Field(
None, description="list of fields to load for this selection"
)
left_edge: Left_Edge = Field(
None,
description="the left edge (min x, min y, min z)",
)
right_edge: Right_Edge = Field(
None,
description="the right edge (max x, max y, max z)",
)
level: int = Field(0, description="Grid level to sample at")
num_ghost_zones: int = Field(
0,
description="Number of ghost zones to include",
)
rescale: bool = Field(False, description="rescale the final image between 0,1")


class Slice(_ytBaseModel):
fields: List[ytField] = Field(
None, description="list of fields to load for this selection"
Expand Down Expand Up @@ -93,6 +113,9 @@ class Slice(_ytBaseModel):
class SelectionObject(_ytBaseModel):
regions: List[Region] = Field(None, description="a list of regions to load")
slices: List[Slice] = Field(None, description="a list of slices to load")
covering_grids: List[CoveringGrid] = Field(
None, description="a list of covering grids to load"
)


class DataContainer(_ytBaseModel):
Expand Down
40 changes: 37 additions & 3 deletions src/yt_napari/_gui_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,16 @@ def get_widget_instance(self, pydantic_model, field: str):
func, args, kwargs = self.registry[pydantic_model][field]["magicgui"]
return func(*args, **kwargs)

def get_pydantic_attr(self, pydantic_model, field: str, widget_instance):
def get_pydantic_attr(
self, pydantic_model, field: str, widget_instance, required: bool = True
):
# given a widget instance, return an object that can be used to set a
# pydantic field
if self.is_registered(pydantic_model, field, required=True):
if self.is_registered(pydantic_model, field, required=required):
func, args, kwargs = self.registry[pydantic_model][field]["pydantic"]
return func(widget_instance, *args, **kwargs)
else:
raise RuntimeError("Could not retrieve pydantic attribute.")

def add_pydantic_to_container(
self,
Expand Down Expand Up @@ -214,6 +218,15 @@ def get_filename(file_widget: widgets.FileEdit):
return str(file_widget.value)


def get_int_box_widget(*args, **kwargs):
# could remove the need for this if the model uses pathlib.Path for typing
return widgets.IntText(*args, **kwargs)


def get_int_val(int_box: widgets.IntText):
return int(int_box.value)


def get_magicguidefault(field_name: str, field_def: pydantic.fields.Field):
# returns an instance of the default widget selected by magicgui
# returns an instance of the default widget selected by magicgui
Expand All @@ -229,6 +242,10 @@ def get_magicguidefault(field_name: str, field_def: pydantic.fields.Field):
options=opts_dict,
raise_on_unknown=False,
)

if new_widget_cls == widgets.TupleEdit:
ops["options"] = {"min": -1e12, "max": 1e12}

return new_widget_cls(**ops)


Expand All @@ -255,8 +272,10 @@ def _get_pydantic_model_field(
_models_to_embed_in_list = (
(_data_model.Slice, "fields"),
(_data_model.Region, "fields"),
(_data_model.CoveringGrid, "fields"),
(_data_model.DataContainer, "selections"),
(_data_model.SelectionObject, "regions"),
(_data_model.SelectionObject, "covering_grids"),
(_data_model.SelectionObject, "slices"),
)

Expand Down Expand Up @@ -297,6 +316,21 @@ def _register_yt_data_model(translator: MagicPydanticRegistry):
pydantic_attr_factory=handle_str_list_edit,
)

translator.register(
_data_model.CoveringGrid,
"level",
magicgui_factory=get_int_box_widget,
magicgui_kwargs={"name": "level"},
pydantic_attr_factory=get_int_val,
)
translator.register(
_data_model.CoveringGrid,
"num_ghost_zones",
magicgui_factory=get_int_box_widget,
magicgui_kwargs={"name": "num_ghost_zones"},
pydantic_attr_factory=get_int_val,
)


translator = MagicPydanticRegistry()
_register_yt_data_model(translator)
Expand All @@ -318,7 +352,7 @@ def get_yt_data_container(
return data_container


_valid_selections = ("Region", "Slice")
_valid_selections = ("Region", "Slice", "CoveringGrid")


def get_yt_selection_container(selection_type: str, return_native: bool = False):
Expand Down
50 changes: 40 additions & 10 deletions src/yt_napari/_model_ingestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from yt_napari import _special_loaders
from yt_napari._data_model import (
CoveringGrid,
DataContainer,
InputModel,
MetadataModel,
Expand All @@ -29,6 +30,30 @@ def _le_re_to_cen_wid(
return center, width


def _get_covering_grid(
ds, left_edge, right_edge, level, num_ghost_zones, test_dims=None
):
# returns a covering grid instance and the resolution of the covering grid
if test_dims is None:
test_dims = (4, 4, 4)
nghostzones = num_ghost_zones
temp_cg = ds.covering_grid(level, left_edge, test_dims, num_ghost_zones=nghostzones)
effective_dds = temp_cg.dds
dims = (right_edge - left_edge) / effective_dds
# get the actual covering grid
frb = ds.covering_grid(level, left_edge, dims, num_ghost_zones=nghostzones)
return frb, dims


def _get_region_frb(ds, LE, RE, res):
frb = ds.r[
LE[0] : RE[0] : complex(0, res[0]), # noqa: E203
LE[1] : RE[1] : complex(0, res[1]), # noqa: E203
LE[2] : RE[2] : complex(0, res[2]), # noqa: E203
]
return frb


class LayerDomain:
# container for domain info for a single layer
# left_edge, right_edge, resolution, n_d are all self explanatory.
Expand Down Expand Up @@ -434,7 +459,13 @@ def _load_3D_regions(
layer_list: list,
timeseries_container: Optional[TimeseriesContainer] = None,
) -> list:
for sel in selections.regions:

sels = []
for seltype in ("regions", "covering_grids"):
if getattr(selections, seltype) is not None:
sels += [sel for sel in getattr(selections, seltype)]

for sel in sels:
# get the left, right edge as a unitful array, initialize the layer
# domain tracking for this layer and update the global domain extent
if sel.left_edge is None:
Expand All @@ -446,16 +477,15 @@ def _load_3D_regions(
RE = ds.domain_right_edge
else:
RE = ds.arr(sel.right_edge.value, sel.right_edge.unit)
res = sel.resolution
layer_domain = LayerDomain(left_edge=LE, right_edge=RE, resolution=res)

# create the fixed resolution buffer
frb = ds.r[
LE[0] : RE[0] : complex(0, res[0]), # noqa: E203
LE[1] : RE[1] : complex(0, res[1]), # noqa: E203
LE[2] : RE[2] : complex(0, res[2]), # noqa: E203
]
if isinstance(sel, Region):
res = sel.resolution
frb = _get_region_frb(ds, LE, RE, res)
elif isinstance(sel, CoveringGrid):
frb, dims = _get_covering_grid(ds, LE, RE, sel.level, sel.num_ghost_zones)
res = dims

layer_domain = LayerDomain(left_edge=LE, right_edge=RE, resolution=res)
for field_container in sel.fields:
field = (field_container.field_type, field_container.field_name)

Expand Down Expand Up @@ -600,7 +630,7 @@ def _load_selections_from_ds(
layer_list: List[SpatialLayer],
timeseries_container: Optional[TimeseriesContainer] = None,
) -> List[SpatialLayer]:
if selections.regions is not None:
if selections.regions is not None or selections.covering_grids is not None:
layer_list = _load_3D_regions(
ds, selections, layer_list, timeseries_container=timeseries_container
)
Expand Down
45 changes: 45 additions & 0 deletions src/yt_napari/_tests/test_covering_grid_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

from yt_napari._data_model import InputModel
from yt_napari._model_ingestor import _choose_ref_layer, _process_validated_model
from yt_napari._schema_version import schema_name

jdicts = []
jdicts.append(
{
"$schema": schema_name,
"datasets": [
{
"filename": "_ytnapari_load_grid",
"selections": {
"covering_grids": [
{
"fields": [{"field_name": "density", "field_type": "gas"}],
"left_edge": {"value": (0.4, 0.4, 0.4)},
"right_edge": {"value": (0.5, 0.5, 0.5)},
"level": 0,
"rescale": 1,
}
]
},
}
],
}
)


@pytest.mark.parametrize("jdict", jdicts)
def test_covering_grid_validation(jdict):
_ = InputModel.model_validate(jdict)


@pytest.mark.parametrize("jdict", jdicts)
def test_slice_load(yt_ugrid_ds_fn, jdict):
im = InputModel.model_validate(jdict)
layer_lists, _ = _process_validated_model(im)
ref_layer = _choose_ref_layer(layer_lists)
_ = ref_layer.align_sanitize_layers(layer_lists)

im_data = layer_lists[0][0]
assert im_data.min() == 0
assert im_data.max() == 1
2 changes: 0 additions & 2 deletions src/yt_napari/_tests/test_ds_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def test_ds_cache(caplog):

dataset_cache.rm_ds(ds_name)
assert dataset_cache.exists(ds_name) is False
assert len(dataset_cache.available) == 0

ds_none = dataset_cache.get_ds("doesnotexist")
assert ds_none is None
Expand All @@ -35,7 +34,6 @@ def test_ds_cache(caplog):
dataset_cache.add_ds(ds, ds_name)
assert dataset_cache.exists(ds_name)
dataset_cache.rm_all()
assert len(dataset_cache.available) == 0
assert dataset_cache.most_recent is None


Expand Down
4 changes: 4 additions & 0 deletions src/yt_napari/_tests/test_gui_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def get_value_from_nested(container_widget, extra_string):
pyvalue = reg.get_pydantic_attr(Model, "field_1", widget_instance)
assert pyvalue == "2_testxyz"

with pytest.raises(RuntimeError, match="Could not retrieve pydantic attribute."):
reg.get_pydantic_attr(
Model, "field_does_not_exist", widget_instance, required=False
)
widget_instance.close()


Expand Down
7 changes: 7 additions & 0 deletions src/yt_napari/_tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ def test_region(yt_ds_0):
assert np.all(np.log10(data4) == data)


def test_covering_grid(yt_ds_0):
cg = ts.CoveringGrid(_field)
data = cg.sample_ds(yt_ds_0)
# sampled at level 0 for full domain, so should get out the base dimensions
assert data.shape == tuple(yt_ds_0.domain_dimensions)


def test_slice(yt_ds_0):
sample_res = (20, 20)
slc = ts.Slice(_field, "x", resolution=sample_res)
Expand Down
7 changes: 7 additions & 0 deletions src/yt_napari/_tests/test_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def test_viewer(make_napari_viewer, yt_ds, caplog):
expected_layers += 1
assert len(viewer.layers) == expected_layers

LE = yt_ds.domain_left_edge
dds = yt_ds.domain_width / yt_ds.domain_dimensions
RE = yt_ds.arr(LE + dds * 10)
sc.add_covering_grid(viewer, yt_ds, ("gas", "density"), left_edge=LE, right_edge=RE)
expected_layers += 1
assert len(viewer.layers) == expected_layers

# build a new scene so it builds from prior
sc = Scene()
sc.add_region(viewer, yt_ds, ("gas", "density"))
Expand Down
27 changes: 26 additions & 1 deletion src/yt_napari/_tests/test_widget_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _rebuild_data(final_shape, data):
# the yt file thats being loaded from the pytest fixture is a saved
# dataset created from an in-memory uniform grid, and the re-loaded
# dataset will not have the full functionality of a ds. so here, we
# inject a correctly shaped random array here. If we start using full
# inject a correctly shaped random array. If we start using full
# test datasets from yt in testing, this should be changed.
rng = np.random.default_rng()
return rng.random(final_shape) * data.mean()
Expand Down Expand Up @@ -237,3 +237,28 @@ def test_timeseries_widget_reader(make_napari_viewer, tmp_path):
_ = InputModel.model_validate(saved_data)

tsr.deleteLater()


def test_covering_grid_selection(make_napari_viewer, yt_ugrid_ds_fn):
viewer = make_napari_viewer()
r = _wr.ReaderWidget(napari_viewer=viewer)
r.ds_container.filename.value = yt_ugrid_ds_fn
r.ds_container.store_in_cache.value = False
r.new_selection_type.setCurrentIndex(2)
r.add_new_button.click()
assert len(r.active_selections) == 1
sel = list(r.active_selections.values())[0]
assert isinstance(sel, _wr.SelectionEntry)
assert sel.selection_type == "CoveringGrid"

mgui_region = sel.selection_container_raw
mgui_region.fields.field_type.value = "gas"
mgui_region.fields.field_name.value = "density"
mgui_region.level.value = 0

mgui_region.left_edge.value.value = (-1.5,) * 3
mgui_region.right_edge.value.value = (1.5,) * 3
rebuild = partial(_rebuild_data, (64, 64, 64))
r._post_load_function = rebuild
r.load_data()
r.deleteLater()
Loading

0 comments on commit 3d390d3

Please sign in to comment.