diff --git a/setup.cfg b/setup.cfg index 69505fa..8b45b9e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,10 +29,10 @@ project_urls = packages = find: install_requires = magicgui>=0.6.1 - napari>=0.4.13 + napari>=0.4.18 numpy packaging - pydantic + pydantic>2.0 qtpy unyt yt>=4.0.1 diff --git a/src/yt_napari/_data_model.py b/src/yt_napari/_data_model.py index 4330a23..3775986 100644 --- a/src/yt_napari/_data_model.py +++ b/src/yt_napari/_data_model.py @@ -8,110 +8,108 @@ from yt_napari.schemas import _manager -class ytField(BaseModel): +class _ytBaseModel(BaseModel): + pass + + +class ytField(_ytBaseModel): field_type: str = Field(None, description="a field type in the yt dataset") field_name: str = Field(None, description="a field in the yt dataset") - take_log: Optional[bool] = Field( + take_log: bool = Field( True, description="if true, will apply log10 to the selected data" ) -class Length_Value(BaseModel): +class Length_Value(_ytBaseModel): value: float = Field(None, description="Single unitful value.") unit: str = Field("code_length", description="the unit length string.") -class Left_Edge(BaseModel): +class Left_Edge(_ytBaseModel): value: Tuple[float, float, float] = Field( (0.0, 0.0, 0.0), description="3-element unitful tuple." ) unit: str = Field("code_length", description="the unit length string.") -class Right_Edge(BaseModel): +class Right_Edge(_ytBaseModel): value: Tuple[float, float, float] = Field( (1.0, 1.0, 1.0), description="3-element unitful tuple." ) unit: str = Field("code_length", description="the unit length string.") -class Length_Tuple(BaseModel): +class Length_Tuple(_ytBaseModel): value: Tuple[float, float, float] = Field( None, description="3-element unitful tuple." ) unit: str = Field("code_length", description="the unit length string.") -class Region(BaseModel): +class Region(_ytBaseModel): fields: List[ytField] = Field( None, description="list of fields to load for this selection" ) - left_edge: Optional[Left_Edge] = Field( + left_edge: Left_Edge = Field( None, description="the left edge (min x, min y, min z)", ) - right_edge: Optional[Right_Edge] = Field( + right_edge: Right_Edge = Field( None, description="the right edge (max x, max y, max z)", ) - resolution: Optional[Tuple[int, int, int]] = Field( + resolution: Tuple[int, int, int] = Field( (400, 400, 400), description="the resolution at which to sample between the edges.", ) - rescale: Optional[bool] = Field( - False, description="rescale the final image between 0,1" - ) + rescale: bool = Field(False, description="rescale the final image between 0,1") -class Slice(BaseModel): +class Slice(_ytBaseModel): fields: List[ytField] = Field( None, description="list of fields to load for this selection" ) normal: str = Field(None, description="the normal axis of the slice") - center: Optional[Length_Tuple] = Field( + center: Length_Tuple = Field( None, description="The center point of the slice, default domain center" ) - slice_width: Optional[Length_Value] = Field( + slice_width: Length_Value = Field( None, description="The slice width, defaults to full domain" ) - slice_height: Optional[Length_Value] = Field( + slice_height: Length_Value = Field( None, description="The slice width, defaults to full domain" ) - resolution: Optional[Tuple[int, int]] = Field( + resolution: Tuple[int, int] = Field( (400, 400), description="the resolution at which to sample the slice", ) - periodic: Optional[bool] = Field( + periodic: bool = Field( False, description="should the slice be periodic? default False." ) - rescale: Optional[bool] = Field( - False, description="rescale the final image between 0,1" - ) + rescale: bool = Field(False, description="rescale the final image between 0,1") -class SelectionObject(BaseModel): - regions: Optional[List[Region]] = Field( - None, description="a list of regions to load" - ) - slices: Optional[List[Slice]] = Field(None, description="a list of slices to load") +class SelectionObject(_ytBaseModel): + regions: List[Region] = Field(None, description="a list of regions to load") + slices: List[Slice] = Field(None, description="a list of slices to load") -class DataContainer(BaseModel): +class DataContainer(_ytBaseModel): filename: str = Field(None, description="the filename for the dataset") selections: SelectionObject = Field( None, description="selections to load in this dataset" ) - store_in_cache: Optional[bool] = Field( + store_in_cache: bool = Field( ytcfg.get("yt_napari", "in_memory_cache"), description="if enabled, will store references to yt datasets.", ) -class TimeSeriesFileSelection(BaseModel): +class TimeSeriesFileSelection(_ytBaseModel): directory: str = Field(None, description="The directory of the timseries") - file_pattern: Optional[str] = Field(None, description="The file pattern to match") - file_list: Optional[List[str]] = Field(None, description="List of files to load.") - file_range: Optional[Tuple[int, int, int]] = Field( + file_pattern: str = Field(None, description="The file pattern to match") + file_list: List[str] = Field(None, description="List of files to load.") + file_range: Tuple[int, int, int] = Field( None, description="Given files matched by file_pattern, " "this option will select a range. Argument order" @@ -119,12 +117,12 @@ class TimeSeriesFileSelection(BaseModel): ) -class Timeseries(BaseModel): +class Timeseries(_ytBaseModel): file_selection: TimeSeriesFileSelection selections: SelectionObject = Field( None, description="selections to load in this dataset" ) - load_as_stack: Optional[bool] = Field( + load_as_stack: bool = Field( False, description="If True, will stack images along a new dimension." ) # process_in_parallel: Optional[bool] = Field( @@ -132,7 +130,7 @@ class Timeseries(BaseModel): # ) -class InputModel(BaseModel): +class InputModel(_ytBaseModel): datasets: List[DataContainer] = Field( None, description="list of dataset containers to load" ) @@ -155,7 +153,7 @@ def _store_schema(schema_db: Optional[Union[PosixPath, str]] = None, **kwargs): m.write_new_schema(schema_contents, schema_prefix=prefix, **kwargs) -class MetadataModel(BaseModel): +class MetadataModel(_ytBaseModel): filename: str = Field(None, description="the filename for the dataset") include_field_list: bool = Field(True, description="whether to list the fields") _ds_attrs: Tuple[str] = ( @@ -164,3 +162,16 @@ class MetadataModel(BaseModel): "current_time", "domain_dimensions", ) + + +def _get_dm_listing(locals_dict): + _data_model_list = [] + for ky, val in locals_dict.items(): + if inspect.isclass(val) and issubclass(val, _ytBaseModel): + _data_model_list.append(ky) + _data_model_list.append(val) + _data_model_list.append(val.__module__ + "." + ky) + return tuple(_data_model_list) + + +_data_model_list = _get_dm_listing(locals()) diff --git a/src/yt_napari/_gui_utilities.py b/src/yt_napari/_gui_utilities.py index d52cc76..b976789 100644 --- a/src/yt_napari/_gui_utilities.py +++ b/src/yt_napari/_gui_utilities.py @@ -1,8 +1,9 @@ from collections import defaultdict -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Union, get_args, get_origin import pydantic from magicgui import type_map, widgets +from pydantic_core import PydanticUndefinedType from yt_napari import _data_model from yt_napari.logging import ytnapari_log @@ -24,7 +25,7 @@ def __init__(self): def register( self, - pydantic_model: Union[pydantic.BaseModel, pydantic.main.ModelMetaclass], + pydantic_model: pydantic.BaseModel, field: str, magicgui_factory: Callable = None, magicgui_args: Optional[tuple] = None, @@ -104,7 +105,7 @@ def get_pydantic_attr(self, pydantic_model, field: str, widget_instance): def add_pydantic_to_container( self, - py_model: Union[pydantic.BaseModel, pydantic.main.ModelMetaclass], + py_model: pydantic.BaseModel, container: widgets.Container, ignore_attrs: Optional[Union[str, List[str]]] = None, ): @@ -116,21 +117,24 @@ def add_pydantic_to_container( ignore_attrs, ] - for field, field_def in py_model.__fields__.items(): + for field, field_def in py_model.model_fields.items(): if field in ignore_attrs: continue - ftype = field_def.type_ - if isinstance(ftype, pydantic.BaseModel) or isinstance( - ftype, pydantic.main.ModelMetaclass - ): + ftype = field_def.annotation + if _is_base_model_or_yt_obj(field_def): # the field is a pydantic class, add a container for it and fill it new_widget_cls = widgets.Container - new_widget = new_widget_cls(name=field_def.name) + new_widget = new_widget_cls(name=field) self.add_pydantic_to_container(ftype, new_widget) + elif get_origin(field_def.annotation) is list: + new_widget_cls = widgets.Container + new_widget = new_widget_cls(name=field) + ftype_inner = get_args(field_def.annotation)[0] + self.add_pydantic_to_container(ftype_inner, new_widget) elif self.is_registered(py_model, field): new_widget = self.get_widget_instance(py_model, field) else: - new_widget = get_magicguidefault(field_def) + new_widget = get_magicguidefault(field, field_def) if isinstance(new_widget, widgets.EmptyWidget): msg = "magicgui could not identify a widget for " msg += f" {py_model}.{field}, which has type {ftype}" @@ -140,7 +144,7 @@ def add_pydantic_to_container( def get_pydantic_kwargs( self, container: widgets.Container, - py_model, + py_model: pydantic.BaseModel, pydantic_kwargs: dict, ignore_attrs: Optional[Union[str, List[str]]] = None, ): @@ -152,17 +156,16 @@ def get_pydantic_kwargs( ] # traverse model fields, pull out values from container - for field, field_def in py_model.__fields__.items(): + for field, field_def in py_model.model_fields.items(): if field in ignore_attrs: continue - ftype = field_def.type_ - if isinstance(ftype, pydantic.BaseModel) or isinstance( - ftype, pydantic.main.ModelMetaclass - ): + ftype = field_def.annotation + + if _is_base_model_or_yt_obj(field_def): new_kwargs = {} # new dictionary for the new nest level # any pydantic class will be a container, so pull that out to pass # to the recursive call - sub_container = getattr(container, field_def.name) + sub_container = getattr(container, field) self.get_pydantic_kwargs(sub_container, ftype, new_kwargs) if "typing.List" in str(field_def.outer_type_): new_kwargs = [ @@ -171,16 +174,14 @@ def get_pydantic_kwargs( pydantic_kwargs[field] = new_kwargs elif self.is_registered(py_model, field): - widget_instance = getattr( - container, field_def.name - ) # pull from container + widget_instance = getattr(container, field) # pull from container pydantic_kwargs[field] = self.get_pydantic_attr( py_model, field, widget_instance ) else: # not a pydantic class, just pull the field value from the container - if hasattr(container, field_def.name): - value = getattr(container, field_def.name).value + if hasattr(container, field): + value = getattr(container, field).value pydantic_kwargs[field] = value @@ -196,19 +197,21 @@ def get_filename(file_widget: widgets.FileEdit): return str(file_widget.value) -def get_magicguidefault(field_def: pydantic.fields.ModelField): +def get_magicguidefault(field_name: str, field_def: pydantic.fields.Field): # returns an instance of the default widget selected by magicgui - ftype = field_def.type_ + # returns an instance of the default widget selected by magicgui + ftype = field_def.annotation + opts_dict = dict(name=field_name, annotation=ftype) + if ( + not type(field_def.default) is PydanticUndefinedType + and field_def.default is not None + ): + opts_dict["value"] = field_def.default new_widget_cls, ops = type_map.get_widget_class( - None, - ftype, - dict(name=field_def.name, value=field_def.default, annotation=ftype), + annotation=ftype, + options=opts_dict, raise_on_unknown=False, ) - if field_def.default is None: - # for some widgets, explicitly passing None as a default will error - _ = ops.pop("value", None) - return new_widget_cls(**ops) @@ -225,8 +228,10 @@ def split_comma_sep_string(widget_instance) -> List[str]: return files.split(",") -def _get_pydantic_model_field(py_model, field: str) -> pydantic.fields.ModelField: - return py_model.__fields__[field] +def _get_pydantic_model_field( + py_model: pydantic.BaseModel, field: str +) -> pydantic.fields.Field: + return py_model.model_fields[field] # the following model-field tuples will be embedded in containers @@ -254,7 +259,7 @@ def _register_yt_data_model(translator: MagicPydanticRegistry): py_model, field, magicgui_factory=get_magicguidefault, - magicgui_args=(py_model.__fields__[field]), + magicgui_args=(field, py_model.model_fields[field]), pydantic_attr_factory=embed_in_list, ) translator.register( @@ -269,7 +274,10 @@ def _register_yt_data_model(translator: MagicPydanticRegistry): _data_model.TimeSeriesFileSelection, "file_list", magicgui_factory=get_magicguidefault, - magicgui_args=(_data_model.TimeSeriesFileSelection.__fields__["file_list"],), + magicgui_args=( + "file_list", + _data_model.TimeSeriesFileSelection.model_fields["file_list"], + ), pydantic_attr_factory=split_comma_sep_string, ) @@ -280,9 +288,7 @@ def _register_yt_data_model(translator: MagicPydanticRegistry): def get_yt_data_container( ignore_attrs: Optional[Union[str, List[str]]] = None, - pydantic_model_class: Optional[ - Union[pydantic.BaseModel, pydantic.main.ModelMetaclass] - ] = None, + pydantic_model_class: Optional[pydantic.BaseModel] = None, ) -> widgets.Container: if pydantic_model_class is None: pydantic_model_class = _data_model.DataContainer @@ -319,3 +325,11 @@ def get_yt_metadata_container(): data_container = widgets.Container() translator.add_pydantic_to_container(_data_model.MetadataModel, data_container) return data_container + + +def _is_base_model_or_yt_obj(field_info: pydantic.fields.FieldInfo): + ftype = field_info.annotation + ispydy = isinstance(ftype, pydantic.BaseModel) + if ispydy: + return ispydy + return ftype in _data_model._data_model_list diff --git a/src/yt_napari/_tests/test_gui_utilities.py b/src/yt_napari/_tests/test_gui_utilities.py index 1ccbca8..5dd0f6c 100644 --- a/src/yt_napari/_tests/test_gui_utilities.py +++ b/src/yt_napari/_tests/test_gui_utilities.py @@ -33,7 +33,7 @@ class LowerModel(pydantic.BaseModel): class TestModel(pydantic.BaseModel): field_1: int = 1 - vec_field1: Tuple[float, float] = (1.0, 2.0) + vec_field1: Tuple[float, float] = pydantic.Field((1.0, 2.0)) vec_field2: Tuple[float, float, float] bad_field: TypeVar("BadType") low_model: LowerModel @@ -99,12 +99,12 @@ def test_pydantic_magicgui_default(Model, backend, caplog): app = use_app(backend) # noqa: F841 model_field = Model.__fields__["field_1"] - c = gu.get_magicguidefault(model_field) + c = gu.get_magicguidefault("field_1", model_field) assert c.value == model_field.default c.close() model_field = Model.__fields__["bad_field"] - empty = gu.get_magicguidefault(model_field) + empty = gu.get_magicguidefault("bad_field", model_field) assert isinstance(empty, widgets.EmptyWidget) empty.close()