Skip to content

Commit

Permalink
in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishavlin committed Feb 9, 2024
1 parent 43b2ff5 commit 704588d
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 80 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ project_urls =
packages = find:
install_requires =
magicgui>=0.6.1
napari>=0.4.13
napari>=0.4.18
numpy
packaging
pydantic
pydantic>2.0
qtpy
unyt
yt>=4.0.1
Expand Down
85 changes: 48 additions & 37 deletions src/yt_napari/_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,131 +8,129 @@
from yt_napari.schemas import _manager


class ytField(BaseModel):
class _ytBaseModel(BaseModel):
pass


class ytField(_ytBaseModel):
field_type: str = Field(None, description="a field type in the yt dataset")
field_name: str = Field(None, description="a field in the yt dataset")
take_log: Optional[bool] = Field(
take_log: bool = Field(
True, description="if true, will apply log10 to the selected data"
)


class Length_Value(BaseModel):
class Length_Value(_ytBaseModel):
value: float = Field(None, description="Single unitful value.")
unit: str = Field("code_length", description="the unit length string.")


class Left_Edge(BaseModel):
class Left_Edge(_ytBaseModel):
value: Tuple[float, float, float] = Field(
(0.0, 0.0, 0.0), description="3-element unitful tuple."
)
unit: str = Field("code_length", description="the unit length string.")


class Right_Edge(BaseModel):
class Right_Edge(_ytBaseModel):
value: Tuple[float, float, float] = Field(
(1.0, 1.0, 1.0), description="3-element unitful tuple."
)
unit: str = Field("code_length", description="the unit length string.")


class Length_Tuple(BaseModel):
class Length_Tuple(_ytBaseModel):
value: Tuple[float, float, float] = Field(
None, description="3-element unitful tuple."
)
unit: str = Field("code_length", description="the unit length string.")


class Region(BaseModel):
class Region(_ytBaseModel):
fields: List[ytField] = Field(
None, description="list of fields to load for this selection"
)
left_edge: Optional[Left_Edge] = Field(
left_edge: Left_Edge = Field(
None,
description="the left edge (min x, min y, min z)",
)
right_edge: Optional[Right_Edge] = Field(
right_edge: Right_Edge = Field(
None,
description="the right edge (max x, max y, max z)",
)
resolution: Optional[Tuple[int, int, int]] = Field(
resolution: Tuple[int, int, int] = Field(
(400, 400, 400),
description="the resolution at which to sample between the edges.",
)
rescale: Optional[bool] = Field(
False, description="rescale the final image between 0,1"
)
rescale: bool = Field(False, description="rescale the final image between 0,1")


class Slice(BaseModel):
class Slice(_ytBaseModel):
fields: List[ytField] = Field(
None, description="list of fields to load for this selection"
)
normal: str = Field(None, description="the normal axis of the slice")
center: Optional[Length_Tuple] = Field(
center: Length_Tuple = Field(
None, description="The center point of the slice, default domain center"
)
slice_width: Optional[Length_Value] = Field(
slice_width: Length_Value = Field(
None, description="The slice width, defaults to full domain"
)
slice_height: Optional[Length_Value] = Field(
slice_height: Length_Value = Field(
None, description="The slice width, defaults to full domain"
)
resolution: Optional[Tuple[int, int]] = Field(
resolution: Tuple[int, int] = Field(
(400, 400),
description="the resolution at which to sample the slice",
)
periodic: Optional[bool] = Field(
periodic: bool = Field(
False, description="should the slice be periodic? default False."
)
rescale: Optional[bool] = Field(
False, description="rescale the final image between 0,1"
)
rescale: bool = Field(False, description="rescale the final image between 0,1")


class SelectionObject(BaseModel):
regions: Optional[List[Region]] = Field(
None, description="a list of regions to load"
)
slices: Optional[List[Slice]] = Field(None, description="a list of slices to load")
class SelectionObject(_ytBaseModel):
regions: List[Region] = Field(None, description="a list of regions to load")
slices: List[Slice] = Field(None, description="a list of slices to load")


class DataContainer(BaseModel):
class DataContainer(_ytBaseModel):
filename: str = Field(None, description="the filename for the dataset")
selections: SelectionObject = Field(
None, description="selections to load in this dataset"
)
store_in_cache: Optional[bool] = Field(
store_in_cache: bool = Field(
ytcfg.get("yt_napari", "in_memory_cache"),
description="if enabled, will store references to yt datasets.",
)


class TimeSeriesFileSelection(BaseModel):
class TimeSeriesFileSelection(_ytBaseModel):
directory: str = Field(None, description="The directory of the timseries")
file_pattern: Optional[str] = Field(None, description="The file pattern to match")
file_list: Optional[List[str]] = Field(None, description="List of files to load.")
file_range: Optional[Tuple[int, int, int]] = Field(
file_pattern: str = Field(None, description="The file pattern to match")
file_list: List[str] = Field(None, description="List of files to load.")
file_range: Tuple[int, int, int] = Field(
None,
description="Given files matched by file_pattern, "
"this option will select a range. Argument order"
"is taken as start:stop:step.",
)


class Timeseries(BaseModel):
class Timeseries(_ytBaseModel):
file_selection: TimeSeriesFileSelection
selections: SelectionObject = Field(
None, description="selections to load in this dataset"
)
load_as_stack: Optional[bool] = Field(
load_as_stack: bool = Field(
False, description="If True, will stack images along a new dimension."
)
# process_in_parallel: Optional[bool] = Field(
# False, description="If True, will attempt to load selections in parallel."
# )


class InputModel(BaseModel):
class InputModel(_ytBaseModel):
datasets: List[DataContainer] = Field(
None, description="list of dataset containers to load"
)
Expand All @@ -155,7 +153,7 @@ def _store_schema(schema_db: Optional[Union[PosixPath, str]] = None, **kwargs):
m.write_new_schema(schema_contents, schema_prefix=prefix, **kwargs)


class MetadataModel(BaseModel):
class MetadataModel(_ytBaseModel):
filename: str = Field(None, description="the filename for the dataset")
include_field_list: bool = Field(True, description="whether to list the fields")
_ds_attrs: Tuple[str] = (
Expand All @@ -164,3 +162,16 @@ class MetadataModel(BaseModel):
"current_time",
"domain_dimensions",
)


def _get_dm_listing(locals_dict):
_data_model_list = []
for ky, val in locals_dict.items():
if inspect.isclass(val) and issubclass(val, _ytBaseModel):
_data_model_list.append(ky)
_data_model_list.append(val)
_data_model_list.append(val.__module__ + "." + ky)
return tuple(_data_model_list)


_data_model_list = _get_dm_listing(locals())
Loading

0 comments on commit 704588d

Please sign in to comment.