From 704588d7d9a577bad92fbdbfe5fecc59aa7abea2 Mon Sep 17 00:00:00 2001 From: chrishavlin Date: Fri, 9 Feb 2024 15:54:45 -0600 Subject: [PATCH 1/6] in progress --- setup.cfg | 4 +- src/yt_napari/_data_model.py | 85 +++++++++++--------- src/yt_napari/_gui_utilities.py | 90 +++++++++++++--------- src/yt_napari/_tests/test_gui_utilities.py | 6 +- 4 files changed, 105 insertions(+), 80 deletions(-) diff --git a/setup.cfg b/setup.cfg index 69505fa..8b45b9e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,10 +29,10 @@ project_urls = packages = find: install_requires = magicgui>=0.6.1 - napari>=0.4.13 + napari>=0.4.18 numpy packaging - pydantic + pydantic>2.0 qtpy unyt yt>=4.0.1 diff --git a/src/yt_napari/_data_model.py b/src/yt_napari/_data_model.py index 4330a23..3775986 100644 --- a/src/yt_napari/_data_model.py +++ b/src/yt_napari/_data_model.py @@ -8,110 +8,108 @@ from yt_napari.schemas import _manager -class ytField(BaseModel): +class _ytBaseModel(BaseModel): + pass + + +class ytField(_ytBaseModel): field_type: str = Field(None, description="a field type in the yt dataset") field_name: str = Field(None, description="a field in the yt dataset") - take_log: Optional[bool] = Field( + take_log: bool = Field( True, description="if true, will apply log10 to the selected data" ) -class Length_Value(BaseModel): +class Length_Value(_ytBaseModel): value: float = Field(None, description="Single unitful value.") unit: str = Field("code_length", description="the unit length string.") -class Left_Edge(BaseModel): +class Left_Edge(_ytBaseModel): value: Tuple[float, float, float] = Field( (0.0, 0.0, 0.0), description="3-element unitful tuple." ) unit: str = Field("code_length", description="the unit length string.") -class Right_Edge(BaseModel): +class Right_Edge(_ytBaseModel): value: Tuple[float, float, float] = Field( (1.0, 1.0, 1.0), description="3-element unitful tuple." ) unit: str = Field("code_length", description="the unit length string.") -class Length_Tuple(BaseModel): +class Length_Tuple(_ytBaseModel): value: Tuple[float, float, float] = Field( None, description="3-element unitful tuple." ) unit: str = Field("code_length", description="the unit length string.") -class Region(BaseModel): +class Region(_ytBaseModel): fields: List[ytField] = Field( None, description="list of fields to load for this selection" ) - left_edge: Optional[Left_Edge] = Field( + left_edge: Left_Edge = Field( None, description="the left edge (min x, min y, min z)", ) - right_edge: Optional[Right_Edge] = Field( + right_edge: Right_Edge = Field( None, description="the right edge (max x, max y, max z)", ) - resolution: Optional[Tuple[int, int, int]] = Field( + resolution: Tuple[int, int, int] = Field( (400, 400, 400), description="the resolution at which to sample between the edges.", ) - rescale: Optional[bool] = Field( - False, description="rescale the final image between 0,1" - ) + rescale: bool = Field(False, description="rescale the final image between 0,1") -class Slice(BaseModel): +class Slice(_ytBaseModel): fields: List[ytField] = Field( None, description="list of fields to load for this selection" ) normal: str = Field(None, description="the normal axis of the slice") - center: Optional[Length_Tuple] = Field( + center: Length_Tuple = Field( None, description="The center point of the slice, default domain center" ) - slice_width: Optional[Length_Value] = Field( + slice_width: Length_Value = Field( None, description="The slice width, defaults to full domain" ) - slice_height: Optional[Length_Value] = Field( + slice_height: Length_Value = Field( None, description="The slice width, defaults to full domain" ) - resolution: Optional[Tuple[int, int]] = Field( + resolution: Tuple[int, int] = Field( (400, 400), description="the resolution at which to sample the slice", ) - periodic: Optional[bool] = Field( + periodic: bool = Field( False, description="should the slice be periodic? default False." ) - rescale: Optional[bool] = Field( - False, description="rescale the final image between 0,1" - ) + rescale: bool = Field(False, description="rescale the final image between 0,1") -class SelectionObject(BaseModel): - regions: Optional[List[Region]] = Field( - None, description="a list of regions to load" - ) - slices: Optional[List[Slice]] = Field(None, description="a list of slices to load") +class SelectionObject(_ytBaseModel): + regions: List[Region] = Field(None, description="a list of regions to load") + slices: List[Slice] = Field(None, description="a list of slices to load") -class DataContainer(BaseModel): +class DataContainer(_ytBaseModel): filename: str = Field(None, description="the filename for the dataset") selections: SelectionObject = Field( None, description="selections to load in this dataset" ) - store_in_cache: Optional[bool] = Field( + store_in_cache: bool = Field( ytcfg.get("yt_napari", "in_memory_cache"), description="if enabled, will store references to yt datasets.", ) -class TimeSeriesFileSelection(BaseModel): +class TimeSeriesFileSelection(_ytBaseModel): directory: str = Field(None, description="The directory of the timseries") - file_pattern: Optional[str] = Field(None, description="The file pattern to match") - file_list: Optional[List[str]] = Field(None, description="List of files to load.") - file_range: Optional[Tuple[int, int, int]] = Field( + file_pattern: str = Field(None, description="The file pattern to match") + file_list: List[str] = Field(None, description="List of files to load.") + file_range: Tuple[int, int, int] = Field( None, description="Given files matched by file_pattern, " "this option will select a range. Argument order" @@ -119,12 +117,12 @@ class TimeSeriesFileSelection(BaseModel): ) -class Timeseries(BaseModel): +class Timeseries(_ytBaseModel): file_selection: TimeSeriesFileSelection selections: SelectionObject = Field( None, description="selections to load in this dataset" ) - load_as_stack: Optional[bool] = Field( + load_as_stack: bool = Field( False, description="If True, will stack images along a new dimension." ) # process_in_parallel: Optional[bool] = Field( @@ -132,7 +130,7 @@ class Timeseries(BaseModel): # ) -class InputModel(BaseModel): +class InputModel(_ytBaseModel): datasets: List[DataContainer] = Field( None, description="list of dataset containers to load" ) @@ -155,7 +153,7 @@ def _store_schema(schema_db: Optional[Union[PosixPath, str]] = None, **kwargs): m.write_new_schema(schema_contents, schema_prefix=prefix, **kwargs) -class MetadataModel(BaseModel): +class MetadataModel(_ytBaseModel): filename: str = Field(None, description="the filename for the dataset") include_field_list: bool = Field(True, description="whether to list the fields") _ds_attrs: Tuple[str] = ( @@ -164,3 +162,16 @@ class MetadataModel(BaseModel): "current_time", "domain_dimensions", ) + + +def _get_dm_listing(locals_dict): + _data_model_list = [] + for ky, val in locals_dict.items(): + if inspect.isclass(val) and issubclass(val, _ytBaseModel): + _data_model_list.append(ky) + _data_model_list.append(val) + _data_model_list.append(val.__module__ + "." + ky) + return tuple(_data_model_list) + + +_data_model_list = _get_dm_listing(locals()) diff --git a/src/yt_napari/_gui_utilities.py b/src/yt_napari/_gui_utilities.py index d52cc76..b976789 100644 --- a/src/yt_napari/_gui_utilities.py +++ b/src/yt_napari/_gui_utilities.py @@ -1,8 +1,9 @@ from collections import defaultdict -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Union, get_args, get_origin import pydantic from magicgui import type_map, widgets +from pydantic_core import PydanticUndefinedType from yt_napari import _data_model from yt_napari.logging import ytnapari_log @@ -24,7 +25,7 @@ def __init__(self): def register( self, - pydantic_model: Union[pydantic.BaseModel, pydantic.main.ModelMetaclass], + pydantic_model: pydantic.BaseModel, field: str, magicgui_factory: Callable = None, magicgui_args: Optional[tuple] = None, @@ -104,7 +105,7 @@ def get_pydantic_attr(self, pydantic_model, field: str, widget_instance): def add_pydantic_to_container( self, - py_model: Union[pydantic.BaseModel, pydantic.main.ModelMetaclass], + py_model: pydantic.BaseModel, container: widgets.Container, ignore_attrs: Optional[Union[str, List[str]]] = None, ): @@ -116,21 +117,24 @@ def add_pydantic_to_container( ignore_attrs, ] - for field, field_def in py_model.__fields__.items(): + for field, field_def in py_model.model_fields.items(): if field in ignore_attrs: continue - ftype = field_def.type_ - if isinstance(ftype, pydantic.BaseModel) or isinstance( - ftype, pydantic.main.ModelMetaclass - ): + ftype = field_def.annotation + if _is_base_model_or_yt_obj(field_def): # the field is a pydantic class, add a container for it and fill it new_widget_cls = widgets.Container - new_widget = new_widget_cls(name=field_def.name) + new_widget = new_widget_cls(name=field) self.add_pydantic_to_container(ftype, new_widget) + elif get_origin(field_def.annotation) is list: + new_widget_cls = widgets.Container + new_widget = new_widget_cls(name=field) + ftype_inner = get_args(field_def.annotation)[0] + self.add_pydantic_to_container(ftype_inner, new_widget) elif self.is_registered(py_model, field): new_widget = self.get_widget_instance(py_model, field) else: - new_widget = get_magicguidefault(field_def) + new_widget = get_magicguidefault(field, field_def) if isinstance(new_widget, widgets.EmptyWidget): msg = "magicgui could not identify a widget for " msg += f" {py_model}.{field}, which has type {ftype}" @@ -140,7 +144,7 @@ def add_pydantic_to_container( def get_pydantic_kwargs( self, container: widgets.Container, - py_model, + py_model: pydantic.BaseModel, pydantic_kwargs: dict, ignore_attrs: Optional[Union[str, List[str]]] = None, ): @@ -152,17 +156,16 @@ def get_pydantic_kwargs( ] # traverse model fields, pull out values from container - for field, field_def in py_model.__fields__.items(): + for field, field_def in py_model.model_fields.items(): if field in ignore_attrs: continue - ftype = field_def.type_ - if isinstance(ftype, pydantic.BaseModel) or isinstance( - ftype, pydantic.main.ModelMetaclass - ): + ftype = field_def.annotation + + if _is_base_model_or_yt_obj(field_def): new_kwargs = {} # new dictionary for the new nest level # any pydantic class will be a container, so pull that out to pass # to the recursive call - sub_container = getattr(container, field_def.name) + sub_container = getattr(container, field) self.get_pydantic_kwargs(sub_container, ftype, new_kwargs) if "typing.List" in str(field_def.outer_type_): new_kwargs = [ @@ -171,16 +174,14 @@ def get_pydantic_kwargs( pydantic_kwargs[field] = new_kwargs elif self.is_registered(py_model, field): - widget_instance = getattr( - container, field_def.name - ) # pull from container + widget_instance = getattr(container, field) # pull from container pydantic_kwargs[field] = self.get_pydantic_attr( py_model, field, widget_instance ) else: # not a pydantic class, just pull the field value from the container - if hasattr(container, field_def.name): - value = getattr(container, field_def.name).value + if hasattr(container, field): + value = getattr(container, field).value pydantic_kwargs[field] = value @@ -196,19 +197,21 @@ def get_filename(file_widget: widgets.FileEdit): return str(file_widget.value) -def get_magicguidefault(field_def: pydantic.fields.ModelField): +def get_magicguidefault(field_name: str, field_def: pydantic.fields.Field): # returns an instance of the default widget selected by magicgui - ftype = field_def.type_ + # returns an instance of the default widget selected by magicgui + ftype = field_def.annotation + opts_dict = dict(name=field_name, annotation=ftype) + if ( + not type(field_def.default) is PydanticUndefinedType + and field_def.default is not None + ): + opts_dict["value"] = field_def.default new_widget_cls, ops = type_map.get_widget_class( - None, - ftype, - dict(name=field_def.name, value=field_def.default, annotation=ftype), + annotation=ftype, + options=opts_dict, raise_on_unknown=False, ) - if field_def.default is None: - # for some widgets, explicitly passing None as a default will error - _ = ops.pop("value", None) - return new_widget_cls(**ops) @@ -225,8 +228,10 @@ def split_comma_sep_string(widget_instance) -> List[str]: return files.split(",") -def _get_pydantic_model_field(py_model, field: str) -> pydantic.fields.ModelField: - return py_model.__fields__[field] +def _get_pydantic_model_field( + py_model: pydantic.BaseModel, field: str +) -> pydantic.fields.Field: + return py_model.model_fields[field] # the following model-field tuples will be embedded in containers @@ -254,7 +259,7 @@ def _register_yt_data_model(translator: MagicPydanticRegistry): py_model, field, magicgui_factory=get_magicguidefault, - magicgui_args=(py_model.__fields__[field]), + magicgui_args=(field, py_model.model_fields[field]), pydantic_attr_factory=embed_in_list, ) translator.register( @@ -269,7 +274,10 @@ def _register_yt_data_model(translator: MagicPydanticRegistry): _data_model.TimeSeriesFileSelection, "file_list", magicgui_factory=get_magicguidefault, - magicgui_args=(_data_model.TimeSeriesFileSelection.__fields__["file_list"],), + magicgui_args=( + "file_list", + _data_model.TimeSeriesFileSelection.model_fields["file_list"], + ), pydantic_attr_factory=split_comma_sep_string, ) @@ -280,9 +288,7 @@ def _register_yt_data_model(translator: MagicPydanticRegistry): def get_yt_data_container( ignore_attrs: Optional[Union[str, List[str]]] = None, - pydantic_model_class: Optional[ - Union[pydantic.BaseModel, pydantic.main.ModelMetaclass] - ] = None, + pydantic_model_class: Optional[pydantic.BaseModel] = None, ) -> widgets.Container: if pydantic_model_class is None: pydantic_model_class = _data_model.DataContainer @@ -319,3 +325,11 @@ def get_yt_metadata_container(): data_container = widgets.Container() translator.add_pydantic_to_container(_data_model.MetadataModel, data_container) return data_container + + +def _is_base_model_or_yt_obj(field_info: pydantic.fields.FieldInfo): + ftype = field_info.annotation + ispydy = isinstance(ftype, pydantic.BaseModel) + if ispydy: + return ispydy + return ftype in _data_model._data_model_list diff --git a/src/yt_napari/_tests/test_gui_utilities.py b/src/yt_napari/_tests/test_gui_utilities.py index 1ccbca8..5dd0f6c 100644 --- a/src/yt_napari/_tests/test_gui_utilities.py +++ b/src/yt_napari/_tests/test_gui_utilities.py @@ -33,7 +33,7 @@ class LowerModel(pydantic.BaseModel): class TestModel(pydantic.BaseModel): field_1: int = 1 - vec_field1: Tuple[float, float] = (1.0, 2.0) + vec_field1: Tuple[float, float] = pydantic.Field((1.0, 2.0)) vec_field2: Tuple[float, float, float] bad_field: TypeVar("BadType") low_model: LowerModel @@ -99,12 +99,12 @@ def test_pydantic_magicgui_default(Model, backend, caplog): app = use_app(backend) # noqa: F841 model_field = Model.__fields__["field_1"] - c = gu.get_magicguidefault(model_field) + c = gu.get_magicguidefault("field_1", model_field) assert c.value == model_field.default c.close() model_field = Model.__fields__["bad_field"] - empty = gu.get_magicguidefault(model_field) + empty = gu.get_magicguidefault("bad_field", model_field) assert isinstance(empty, widgets.EmptyWidget) empty.close() From a8b76e28560c2557b42e4d1f9a4712f11eedb6da Mon Sep 17 00:00:00 2001 From: chavlin Date: Mon, 12 Feb 2024 10:16:44 -0600 Subject: [PATCH 2/6] in progress, passing test_gui_utilities --- src/yt_napari/_gui_utilities.py | 18 ++++++++++-------- src/yt_napari/_tests/test_gui_utilities.py | 4 ++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/yt_napari/_gui_utilities.py b/src/yt_napari/_gui_utilities.py index b976789..5d9ccff 100644 --- a/src/yt_napari/_gui_utilities.py +++ b/src/yt_napari/_gui_utilities.py @@ -126,13 +126,14 @@ def add_pydantic_to_container( new_widget_cls = widgets.Container new_widget = new_widget_cls(name=field) self.add_pydantic_to_container(ftype, new_widget) - elif get_origin(field_def.annotation) is list: - new_widget_cls = widgets.Container - new_widget = new_widget_cls(name=field) - ftype_inner = get_args(field_def.annotation)[0] - self.add_pydantic_to_container(ftype_inner, new_widget) elif self.is_registered(py_model, field): - new_widget = self.get_widget_instance(py_model, field) + if get_origin(field_def.annotation) is list: + new_widget_cls = widgets.Container + new_widget = new_widget_cls(name=field) + ftype_inner = get_args(field_def.annotation)[0] + self.add_pydantic_to_container(ftype_inner, new_widget) + else: + new_widget = self.get_widget_instance(py_model, field) else: new_widget = get_magicguidefault(field, field_def) if isinstance(new_widget, widgets.EmptyWidget): @@ -255,11 +256,12 @@ def _register_yt_data_model(translator: MagicPydanticRegistry): ) for py_model, field in _models_to_embed_in_list: + # lists are automatically embedded in pydantic + # containers if registered, only need to provide + # the function for building the pydantic args. translator.register( py_model, field, - magicgui_factory=get_magicguidefault, - magicgui_args=(field, py_model.model_fields[field]), pydantic_attr_factory=embed_in_list, ) translator.register( diff --git a/src/yt_napari/_tests/test_gui_utilities.py b/src/yt_napari/_tests/test_gui_utilities.py index 5dd0f6c..18674f9 100644 --- a/src/yt_napari/_tests/test_gui_utilities.py +++ b/src/yt_napari/_tests/test_gui_utilities.py @@ -98,12 +98,12 @@ def test_yt_widget(backend): def test_pydantic_magicgui_default(Model, backend, caplog): app = use_app(backend) # noqa: F841 - model_field = Model.__fields__["field_1"] + model_field = Model.model_fields["field_1"] c = gu.get_magicguidefault("field_1", model_field) assert c.value == model_field.default c.close() - model_field = Model.__fields__["bad_field"] + model_field = Model.model_fields["bad_field"] empty = gu.get_magicguidefault("bad_field", model_field) assert isinstance(empty, widgets.EmptyWidget) empty.close() From 11fce0e75ee40ccbd834184237490774f732963a Mon Sep 17 00:00:00 2001 From: chavlin Date: Mon, 12 Feb 2024 10:34:06 -0600 Subject: [PATCH 3/6] reader validation working --- src/yt_napari/_data_model.py | 2 +- src/yt_napari/_model_ingestor.py | 3 ++- src/yt_napari/_tests/test_schema_manager.py | 2 +- src/yt_napari/schemas/_version_comparison.py | 5 +++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/yt_napari/_data_model.py b/src/yt_napari/_data_model.py index 3775986..7d87822 100644 --- a/src/yt_napari/_data_model.py +++ b/src/yt_napari/_data_model.py @@ -139,7 +139,7 @@ class InputModel(_ytBaseModel): def _get_standard_schema_contents() -> Tuple[str, str]: - prefix = InputModel._schema_prefix + prefix = InputModel._schema_prefix.default schema_contents = InputModel.schema_json(indent=2) return prefix, schema_contents diff --git a/src/yt_napari/_model_ingestor.py b/src/yt_napari/_model_ingestor.py index 6779482..a0ab1ca 100644 --- a/src/yt_napari/_model_ingestor.py +++ b/src/yt_napari/_model_ingestor.py @@ -733,7 +733,8 @@ def load_from_json(json_paths: List[str]) -> List[Layer]: timeseries_layers = [] # timeseries layers handled separately for json_path in json_paths: # InputModel is a pydantic class, the following will validate the json - model = InputModel.parse_file(json_path) + with open(json_path, "r") as open_file: + model = InputModel.model_validate_json(open_file.read()) # now that we have a validated model, we can use the model attributes # to execute the code that will return our array for the image diff --git a/src/yt_napari/_tests/test_schema_manager.py b/src/yt_napari/_tests/test_schema_manager.py index 574b02a..f178a50 100644 --- a/src/yt_napari/_tests/test_schema_manager.py +++ b/src/yt_napari/_tests/test_schema_manager.py @@ -60,7 +60,7 @@ def get_expected(prefix, vstring): def test_schema_generation(tmp_path): _store_schema(schema_db=tmp_path) m = Manager(schema_db=tmp_path) - pfx = InputModel._schema_prefix + pfx = InputModel._schema_prefix.default expected_file = tmp_path.joinpath(m._filename(pfx, "0.0.1")) file_exists = expected_file.is_file() assert file_exists diff --git a/src/yt_napari/schemas/_version_comparison.py b/src/yt_napari/schemas/_version_comparison.py index a5561af..261b650 100644 --- a/src/yt_napari/schemas/_version_comparison.py +++ b/src/yt_napari/schemas/_version_comparison.py @@ -16,7 +16,7 @@ def _get_version_tuple(): def schema_version_is_valid( schema_version: str, dev_version_check: bool = True ) -> bool: - pfx = InputModel._schema_prefix + pfx = InputModel._schema_prefix.default if schema_version is None or pfx not in schema_version: # the schema does not match a known schema for this plugin return False @@ -67,6 +67,7 @@ def _schema_version_tuple_from_str(schema_version_raw: str) -> Tuple[int, int, i return _get_version_tuple() schema_end = schema_version_raw.split("/")[-1] - v_schema = schema_end.replace(InputModel._schema_prefix, "") + sc_prefix = InputModel._schema_prefix.default + v_schema = schema_end.replace(sc_prefix, "") v_schema = v_schema.replace("_", "").replace(".json", "") return tuple([int(v) for v in v_schema.split(".")]) From eeba550e17988714c668cc36ec483aff2cc4e6ea Mon Sep 17 00:00:00 2001 From: chavlin Date: Mon, 12 Feb 2024 12:19:07 -0600 Subject: [PATCH 4/6] passing all tests locally on 3.10, some warnings remain --- src/yt_napari/_gui_utilities.py | 48 +++++++++++++------- src/yt_napari/_model_ingestor.py | 2 +- src/yt_napari/_tests/test_model_ingestor.py | 2 +- src/yt_napari/_tests/test_regions_json.py | 2 +- src/yt_napari/_tests/test_slices_json.py | 6 +-- src/yt_napari/_tests/test_timeseries_json.py | 8 ++-- src/yt_napari/_tests/test_widget_reader.py | 20 ++++---- src/yt_napari/_widget_matadata.py | 2 +- src/yt_napari/_widget_reader.py | 20 ++++---- src/yt_napari/timeseries.py | 7 ++- 10 files changed, 67 insertions(+), 50 deletions(-) diff --git a/src/yt_napari/_gui_utilities.py b/src/yt_napari/_gui_utilities.py index 5d9ccff..49a52cd 100644 --- a/src/yt_napari/_gui_utilities.py +++ b/src/yt_napari/_gui_utilities.py @@ -33,6 +33,7 @@ def register( pydantic_attr_factory: Callable = None, pydantic_attr_args: Optional[tuple] = None, pydantic_attr_kwargs: Optional[dict] = None, + auto_list_handling: Optional[bool] = False, ): """ @@ -55,6 +56,9 @@ def register( a tuple containing arguments to pydantic_attr_factory pydantic_attr_kwargs : a dict containing keyword arguments to pydantic_attr_factory + auto_list_handling: bool + True if the field is an embedded list that should be + handled as a single entry in a container """ magicgui_args = set_default(magicgui_args, ()) magicgui_kwargs = set_default(magicgui_kwargs, {}) @@ -69,6 +73,7 @@ def register( pydantic_attr_args, pydantic_attr_kwargs, ), + "auto_list_handling": auto_list_handling, } self.registry[pydantic_model][field] = new_entry @@ -127,7 +132,8 @@ def add_pydantic_to_container( new_widget = new_widget_cls(name=field) self.add_pydantic_to_container(ftype, new_widget) elif self.is_registered(py_model, field): - if get_origin(field_def.annotation) is list: + auto_list = self.registry[py_model][field]["auto_list_handling"] + if get_origin(field_def.annotation) is list and auto_list: new_widget_cls = widgets.Container new_widget = new_widget_cls(name=field) ftype_inner = get_args(field_def.annotation)[0] @@ -161,24 +167,34 @@ def get_pydantic_kwargs( if field in ignore_attrs: continue ftype = field_def.annotation - if _is_base_model_or_yt_obj(field_def): new_kwargs = {} # new dictionary for the new nest level # any pydantic class will be a container, so pull that out to pass # to the recursive call sub_container = getattr(container, field) self.get_pydantic_kwargs(sub_container, ftype, new_kwargs) - if "typing.List" in str(field_def.outer_type_): + if get_origin(ftype) is list: new_kwargs = [ new_kwargs, ] pydantic_kwargs[field] = new_kwargs elif self.is_registered(py_model, field): - widget_instance = getattr(container, field) # pull from container - pydantic_kwargs[field] = self.get_pydantic_attr( - py_model, field, widget_instance - ) + auto_list = self.registry[py_model][field]["auto_list_handling"] + if get_origin(ftype) is list and auto_list: + inner = get_args(ftype)[0] + sub_container = getattr(container, field) # pull from container + new_kwargs = {} + self.get_pydantic_kwargs(sub_container, inner, new_kwargs) + new_kwargs = [ + new_kwargs, + ] + pydantic_kwargs[field] = new_kwargs + else: + widget_instance = getattr(container, field) # pull from container + pydantic_kwargs[field] = self.get_pydantic_attr( + py_model, field, widget_instance + ) else: # not a pydantic class, just pull the field value from the container if hasattr(container, field): @@ -222,11 +238,11 @@ def embed_in_list(widget_instance) -> list: return returnval -def split_comma_sep_string(widget_instance) -> List[str]: - files = widget_instance.value - for ch in " []": - files = files.replace(ch, "") - return files.split(",") +def handle_str_list_edit(widget_instance) -> List[str]: + # recent versions of magicgui will return a ListEdit + # where the value is a list of strings here, just + # return it. + return widget_instance.value def _get_pydantic_model_field( @@ -256,13 +272,11 @@ def _register_yt_data_model(translator: MagicPydanticRegistry): ) for py_model, field in _models_to_embed_in_list: - # lists are automatically embedded in pydantic - # containers if registered, only need to provide - # the function for building the pydantic args. + # identified lists are automatically handled a bit differently. translator.register( py_model, field, - pydantic_attr_factory=embed_in_list, + auto_list_handling=True, ) translator.register( _data_model.MetadataModel, @@ -280,7 +294,7 @@ def _register_yt_data_model(translator: MagicPydanticRegistry): "file_list", _data_model.TimeSeriesFileSelection.model_fields["file_list"], ), - pydantic_attr_factory=split_comma_sep_string, + pydantic_attr_factory=handle_str_list_edit, ) diff --git a/src/yt_napari/_model_ingestor.py b/src/yt_napari/_model_ingestor.py index a0ab1ca..f6c12e4 100644 --- a/src/yt_napari/_model_ingestor.py +++ b/src/yt_napari/_model_ingestor.py @@ -190,7 +190,7 @@ def selections_match(sel_1: Union[Slice, Region], sel_2: Union[Slice, Region]) - if not type(sel_2) is type(sel_1): return False - for attr in sel_1.__fields__.keys(): + for attr in sel_1.model_fields.keys(): if attr != "fields": val_1 = getattr(sel_1, attr) val_2 = getattr(sel_2, attr) diff --git a/src/yt_napari/_tests/test_model_ingestor.py b/src/yt_napari/_tests/test_model_ingestor.py index 60b55a3..403b4f2 100644 --- a/src/yt_napari/_tests/test_model_ingestor.py +++ b/src/yt_napari/_tests/test_model_ingestor.py @@ -461,7 +461,7 @@ def test_find_timeseries_file_selection(tmp_path, file_sel_dict): fdir = str(fdir) file_sel_dict["directory"] = fdir - tsfs = _mi.TimeSeriesFileSelection.parse_obj(file_sel_dict) + tsfs = _mi.TimeSeriesFileSelection.model_validate(file_sel_dict) files = _mi._find_timeseries_files(tsfs) if "file_list" not in file_sel_dict: diff --git a/src/yt_napari/_tests/test_regions_json.py b/src/yt_napari/_tests/test_regions_json.py index 39332c6..db6f21b 100644 --- a/src/yt_napari/_tests/test_regions_json.py +++ b/src/yt_napari/_tests/test_regions_json.py @@ -28,7 +28,7 @@ @pytest.mark.parametrize("jdict", jdicts) def test_load_region(jdict): jdict["datasets"][0]["selections"]["regions"][0]["rescale"] = True - m = InputModel.parse_obj(jdict) + m = InputModel.model_validate(jdict) layers, _ = _process_validated_model(m) im_data = layers[0][0] assert im_data.min() == 0.0 diff --git a/src/yt_napari/_tests/test_slices_json.py b/src/yt_napari/_tests/test_slices_json.py index d772856..4d6fa7a 100644 --- a/src/yt_napari/_tests/test_slices_json.py +++ b/src/yt_napari/_tests/test_slices_json.py @@ -49,18 +49,18 @@ @pytest.mark.parametrize("jdict", jdicts) def test_basic_slice_validation(jdict): - _ = InputModel.parse_obj(jdict) + _ = InputModel.model_validate(jdict) @pytest.mark.parametrize("jdict", jdicts) def test_slice_load(yt_ugrid_ds_fn, jdict): - im = InputModel.parse_obj(jdict) + im = InputModel.model_validate(jdict) layer_lists, _ = _process_validated_model(im) ref_layer = _choose_ref_layer(layer_lists) _ = ref_layer.align_sanitize_layers(layer_lists) jdict["datasets"][0]["selections"]["slices"][0]["rescale"] = True - im = InputModel.parse_obj(jdict) + im = InputModel.model_validate(jdict) layer_lists, _ = _process_validated_model(im) im_data = layer_lists[0][0] assert im_data.min() == 0 diff --git a/src/yt_napari/_tests/test_timeseries_json.py b/src/yt_napari/_tests/test_timeseries_json.py index 2caef64..47c3928 100644 --- a/src/yt_napari/_tests/test_timeseries_json.py +++ b/src/yt_napari/_tests/test_timeseries_json.py @@ -75,7 +75,7 @@ @pytest.mark.parametrize("jdict", jdicts) def test_basic_validation(jdict): - _ = InputModel.parse_obj(jdict) + _ = InputModel.model_validate(jdict) @pytest.mark.parametrize("jdict,expected_res", zip(jdicts, [(10, 10), (10, 10, 10)])) @@ -88,7 +88,7 @@ def test_full_load(tmp_path, jdict, expected_res): jdict_new = jdict.copy() jdict_new["timeseries"][0]["file_selection"] = f_dict - im = InputModel.parse_obj(jdict_new) + im = InputModel.model_validate(jdict_new) files = mi._find_timeseries_files(im.timeseries[0].file_selection) assert all([file in files for file in flist]) @@ -109,7 +109,7 @@ def test_unstacked_load(tmp_path, jdict): jdict_new["timeseries"][0]["file_selection"] = f_dict jdict_new["timeseries"][0]["load_as_stack"] = False - im = InputModel.parse_obj(jdict_new) + im = InputModel.model_validate(jdict_new) _, ts_layers = mi._process_validated_model(im) assert len(ts_layers) == 2 * nfiles # two fields per file @@ -150,7 +150,7 @@ def test_aspect_rat(tmp_path): ], } - im = InputModel.parse_obj(jdict_ar) + im = InputModel.model_validate(jdict_ar) _, ts_layers = mi._process_validated_model(im) for _, im_kwargs, _ in ts_layers: print(im_kwargs) diff --git a/src/yt_napari/_tests/test_widget_reader.py b/src/yt_napari/_tests/test_widget_reader.py index 7d9280a..794d18d 100644 --- a/src/yt_napari/_tests/test_widget_reader.py +++ b/src/yt_napari/_tests/test_widget_reader.py @@ -67,10 +67,9 @@ def test_save_widget_reader(make_napari_viewer, yt_ugrid_ds_fn, tmp_path): r._post_load_function = rebuild temp_file = tmp_path / "test.json" - with ( - patch("PyQt5.QtWidgets.QFileDialog.exec_") as mock_exec, - patch("PyQt5.QtWidgets.QFileDialog.selectedFiles") as mock_selectedFiles, + patch("qtpy.QtWidgets.QFileDialog.exec_") as mock_exec, + patch("qtpy.QtWidgets.QFileDialog.selectedFiles") as mock_selectedFiles, ): # Set the return values for the mocked functions mock_exec.return_value = 1 @@ -97,7 +96,7 @@ def test_save_widget_reader(make_napari_viewer, yt_ugrid_ds_fn, tmp_path): ] # ensure that the saved json is a valid model - _ = InputModel.parse_obj(saved_data) + _ = InputModel.model_validate(saved_data) r.deleteLater() @@ -188,8 +187,11 @@ def test_timeseries_widget_reader(make_napari_viewer, tmp_path): assert len(viewer.layers) == nfiles viewer.layers.clear() - filestr_list = "_ytnapari_load_grid-0001, _ytnapari_load_grid-0002" - tsr.ds_container.file_selection.file_list.value = filestr_list + file_list = [ + "_ytnapari_load_grid-0001", + "_ytnapari_load_grid-0002", + ] + tsr.ds_container.file_selection.file_list.value = file_list tsr.ds_container.file_selection.file_pattern.value = "" tsr.load_data() assert len(viewer.layers) == 2 @@ -198,8 +200,8 @@ def test_timeseries_widget_reader(make_napari_viewer, tmp_path): # Use patch to replace the actual QFileDialog functions with mock functions with ( - patch("PyQt5.QtWidgets.QFileDialog.exec_") as mock_exec, - patch("PyQt5.QtWidgets.QFileDialog.selectedFiles") as mock_selectedFiles, + patch("qtpy.QtWidgets.QFileDialog.exec_") as mock_exec, + patch("qtpy.QtWidgets.QFileDialog.selectedFiles") as mock_selectedFiles, ): # Set the return values for the mocked functions mock_exec.return_value = 1 # Assuming QDialog::Accepted is 1 @@ -231,6 +233,6 @@ def test_timeseries_widget_reader(make_napari_viewer, tmp_path): ] # ensure that the saved json is a valid model - _ = InputModel.parse_obj(saved_data) + _ = InputModel.model_validate(saved_data) tsr.deleteLater() diff --git a/src/yt_napari/_widget_matadata.py b/src/yt_napari/_widget_matadata.py index 622a88e..49cdec1 100644 --- a/src/yt_napari/_widget_matadata.py +++ b/src/yt_napari/_widget_matadata.py @@ -55,7 +55,7 @@ def inspect_file(self): ) # instantiate the base model - model = _data_model.MetadataModel.parse_obj(py_kwargs) + model = _data_model.MetadataModel.model_validate(py_kwargs) # process it! meta_data_dict, fields_by_type = _model_ingestor._process_metadata_model(model) diff --git a/src/yt_napari/_widget_reader.py b/src/yt_napari/_widget_reader.py index 5d77a3f..d66c1f8 100644 --- a/src/yt_napari/_widget_reader.py +++ b/src/yt_napari/_widget_reader.py @@ -145,9 +145,8 @@ def load_data(self): # instantiate pydantic objects, which are then handed off to the # same data ingestion function as the json loader. - py_kwargs = {} py_kwargs = self._validate_data_model() - model = _data_model.InputModel.parse_obj(py_kwargs) + model = _data_model.InputModel.model_validate(py_kwargs) # process each layer layer_list, _ = _model_ingestor._process_validated_model(model) @@ -264,9 +263,7 @@ def add_load_group_widgets(self): load_group.addWidget(ss.native) def save_selection(self): - py_kwargs = {} py_kwargs = self._validate_data_model() - # model = _data_model.InputModel.parse_obj(py_kwargs) file_dialog = QFileDialog() file_dialog.setFileMode(QFileDialog.AnyFile) @@ -281,17 +278,15 @@ def save_selection(self): json.dump(py_kwargs, json_file, indent=4) def load_data(self): - py_kwargs = {} py_kwargs = self._validate_data_model() - model = _data_model.InputModel.parse_obj(py_kwargs) - + model = _data_model.InputModel.model_validate(py_kwargs) if _use_threading: # pragma: no cover worker = time_series_load(model) worker.returned.connect(self.process_timeseries_layers) worker.start() else: _, layer_list = _model_ingestor._process_validated_model(model) - self.process_timeseries_layers(layer_list) + self.process_timeseries_layers(layer_list) def process_timeseries_layers(self, layer_list): for new_layer in layer_list: @@ -320,13 +315,14 @@ def _validate_data_model(self): ) if py_kwargs["file_selection"]["file_pattern"] == "": - py_kwargs["file_selection"]["file_pattern"] = None + _ = py_kwargs["file_selection"].pop("file_pattern") - if py_kwargs["file_selection"]["file_list"] == [""]: - py_kwargs["file_selection"]["file_list"] = None + flist = py_kwargs["file_selection"]["file_list"] + if flist == [""] or len(flist) == 0: + _ = py_kwargs["file_selection"].pop("file_list") if py_kwargs["file_selection"]["file_range"] == (0, 0, 0): - py_kwargs["file_selection"]["file_range"] = None + _ = py_kwargs["file_selection"].pop("file_range") # add selections in py_kwargs["selections"] = selections_by_type diff --git a/src/yt_napari/timeseries.py b/src/yt_napari/timeseries.py index fdfe958..f4e7f80 100644 --- a/src/yt_napari/timeseries.py +++ b/src/yt_napari/timeseries.py @@ -312,12 +312,17 @@ def _get_im_data( stack_scaling: Optional[float] = 1.0, **kwargs, ): - tfs = _dm.TimeSeriesFileSelection( + ts_kwargs = dict( file_pattern=file_pattern, directory=file_dir, file_list=file_list, file_range=file_range, ) + for ky in ["file_pattern", "directory", "file_list", "file_range"]: + if ts_kwargs[ky] is None: + _ = ts_kwargs.pop(ky) + + tfs = _dm.TimeSeriesFileSelection(**ts_kwargs) files = _mi._find_timeseries_files(tfs) im_data = [] From c3e73fb54724acf9fc0b1aaa1d2f3ea0ffeabd11 Mon Sep 17 00:00:00 2001 From: chavlin Date: Mon, 12 Feb 2024 12:24:25 -0600 Subject: [PATCH 5/6] final warnings fixed --- src/yt_napari/_data_model.py | 4 +++- src/yt_napari/_tests/test_schema_manager.py | 2 +- src/yt_napari/schemas/_manager.py | 10 +++++++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/yt_napari/_data_model.py b/src/yt_napari/_data_model.py index 7d87822..2c4c383 100644 --- a/src/yt_napari/_data_model.py +++ b/src/yt_napari/_data_model.py @@ -1,4 +1,5 @@ import inspect +import json from pathlib import PosixPath from typing import List, Optional, Tuple, Union @@ -140,7 +141,8 @@ class InputModel(_ytBaseModel): def _get_standard_schema_contents() -> Tuple[str, str]: prefix = InputModel._schema_prefix.default - schema_contents = InputModel.schema_json(indent=2) + schema_contents = InputModel.model_json_schema() + schema_contents = json.dumps(schema_contents, indent=2) return prefix, schema_contents diff --git a/src/yt_napari/_tests/test_schema_manager.py b/src/yt_napari/_tests/test_schema_manager.py index f178a50..cdf4e30 100644 --- a/src/yt_napari/_tests/test_schema_manager.py +++ b/src/yt_napari/_tests/test_schema_manager.py @@ -65,7 +65,7 @@ def test_schema_generation(tmp_path): file_exists = expected_file.is_file() assert file_exists - schema_contents = InputModel.schema_json(indent=2) + schema_contents = InputModel.model_json_schema() with pytest.raises(ValueError): m.write_new_schema(schema_contents, schema_prefix="bad_prefix") diff --git a/src/yt_napari/schemas/_manager.py b/src/yt_napari/schemas/_manager.py index 3f2c423..c6418f1 100644 --- a/src/yt_napari/schemas/_manager.py +++ b/src/yt_napari/schemas/_manager.py @@ -1,3 +1,4 @@ +import json import shutil from collections import defaultdict from os import PathLike @@ -55,7 +56,7 @@ def _filename(self, schema_prefix: str, schema_version: str) -> PosixPath: def write_new_schema( self, - schema_json: str, + schema_json: Union[str, dict], schema_prefix: Optional[str] = None, inc_micro: Optional[bool] = True, inc_minor: Optional[bool] = False, @@ -69,8 +70,9 @@ def write_new_schema( Parameters: ----------- - schema_json: str - the json string to write, assumes that it is already validated + schema_json: str or dict + the json string or dict to write, assumes that it is already validated. + If dict, will call json_dumps with indent=2. schema_prefix: Optional[str] file prefix for the schema. Version incrementing will only check schemas with matching prefix for determining the current version. @@ -119,6 +121,8 @@ def write_new_schema( # write out json to filename ytnapari_log.info(f"writing new schema {filename}") + if isinstance(schema_json, dict): + schema_json = json.dumps(schema_json, indent=2) with open(filename, "w") as f: f.write(schema_json) From eeb2bbb1c5c30e8e7413f8d06ca07a578ae8990f Mon Sep 17 00:00:00 2001 From: chavlin Date: Mon, 12 Feb 2024 12:44:25 -0600 Subject: [PATCH 6/6] coverage fixes --- src/yt_napari/_gui_utilities.py | 3 --- src/yt_napari/_tests/test_schema_manager.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/yt_napari/_gui_utilities.py b/src/yt_napari/_gui_utilities.py index 49a52cd..3bda5d8 100644 --- a/src/yt_napari/_gui_utilities.py +++ b/src/yt_napari/_gui_utilities.py @@ -345,7 +345,4 @@ def get_yt_metadata_container(): def _is_base_model_or_yt_obj(field_info: pydantic.fields.FieldInfo): ftype = field_info.annotation - ispydy = isinstance(ftype, pydantic.BaseModel) - if ispydy: - return ispydy return ftype in _data_model._data_model_list diff --git a/src/yt_napari/_tests/test_schema_manager.py b/src/yt_napari/_tests/test_schema_manager.py index cdf4e30..5d23615 100644 --- a/src/yt_napari/_tests/test_schema_manager.py +++ b/src/yt_napari/_tests/test_schema_manager.py @@ -23,7 +23,7 @@ def get_expected(prefix, vstring): # run again with defaults, should increment expected_file = get_expected(pfx, "0.0.2") - m.write_new_schema("any old string") + m.write_new_schema({"or a": "dictionary"}) assert expected_file.is_file() # test other increments