Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update _widget_reader.py #108

Merged
merged 13 commits into from
Nov 30, 2023
101 changes: 97 additions & 4 deletions src/yt_napari/_tests/test_widget_reader.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import json

# note: the cache is disabled for all the tests in this file due to flakiness
# in github CI. It may be that loading from a true file, rather than the
# yt_ugrid_ds_fn fixture would fix that...
import os
from functools import partial
from unittest.mock import patch

import numpy as np

from yt_napari import _widget_reader as _wr
from yt_napari._data_model import InputModel
from yt_napari._ds_cache import dataset_cache

# import ReaderWidget, SelectionEntry, TimeSeriesReader
from yt_napari._special_loaders import _construct_ugrid_timeseries

# note: the cache is disabled for all the tests in this file due to flakiness
# in github CI. It may be that loading from a true file, rather than the
# yt_ugrid_ds_fn fixture would fix that...


def test_widget_reader_add_selections(make_napari_viewer, yt_ugrid_ds_fn):
viewer = make_napari_viewer()
Expand Down Expand Up @@ -45,6 +49,57 @@ def _rebuild_data(final_shape, data):
return np.random.random(final_shape) * data.mean()


def test_save_widget_reader(make_napari_viewer, yt_ugrid_ds_fn, tmp_path):
viewer = make_napari_viewer()
r = _wr.ReaderWidget(napari_viewer=viewer)
r.ds_container.filename.value = yt_ugrid_ds_fn
r.ds_container.store_in_cache.value = False
r.add_new_button.click()
sel = list(r.active_selections.values())[0]
assert isinstance(sel, _wr.SelectionEntry)

mgui_region = sel.selection_container_raw
mgui_region.fields.field_type.value = "enzo"
mgui_region.fields.field_name.value = "Density"
mgui_region.resolution.value = (400, 400, 400)

rebuild = partial(_rebuild_data, mgui_region.resolution.value)
r._post_load_function = rebuild

temp_file = tmp_path / "test.json"

with patch("PyQt5.QtWidgets.QFileDialog.exec_") as mock_exec, patch(
"PyQt5.QtWidgets.QFileDialog.selectedFiles"
) as mock_selectedFiles:
# Set the return values for the mocked functions
mock_exec.return_value = 1
mock_selectedFiles.return_value = [temp_file]

r.save_selection()

assert os.path.exists(temp_file)
with open(temp_file, "r") as json_file:
saved_data = json.load(json_file)

assert (
saved_data["datasets"][0]["selections"]["regions"][0]["fields"][0]["field_type"]
== "enzo"
)
assert (
saved_data["datasets"][0]["selections"]["regions"][0]["fields"][0]["field_name"]
== "Density"
)
assert saved_data["datasets"][0]["selections"]["regions"][0]["resolution"] == [
400,
400,
400,
]

# ensure that the saved json is a valid model
_ = InputModel.parse_obj(saved_data)
r.deleteLater()


def test_widget_reader(make_napari_viewer, yt_ugrid_ds_fn):
viewer = make_napari_viewer()
r = _wr.ReaderWidget(napari_viewer=viewer)
Expand Down Expand Up @@ -138,4 +193,42 @@ def test_timeseries_widget_reader(make_napari_viewer, tmp_path):
tsr.load_data()
assert len(viewer.layers) == 2

temp_file = tmp_path / "test.json"

# Use patch to replace the actual QFileDialog functions with mock functions
with patch("PyQt5.QtWidgets.QFileDialog.exec_") as mock_exec, patch(
"PyQt5.QtWidgets.QFileDialog.selectedFiles"
) as mock_selectedFiles:
# Set the return values for the mocked functions
mock_exec.return_value = 1 # Assuming QDialog::Accepted is 1
mock_selectedFiles.return_value = [temp_file]

# Call the save_selection method
tsr.save_selection()

assert os.path.exists(temp_file)
with open(temp_file, "r") as json_file:
saved_data = json.load(json_file)

assert (
saved_data["timeseries"][0]["selections"]["regions"][0]["fields"][0][
"field_type"
]
== "stream"
)
assert (
saved_data["timeseries"][0]["selections"]["regions"][0]["fields"][0][
"field_name"
]
== "density"
)
assert saved_data["timeseries"][0]["selections"]["regions"][0]["resolution"] == [
10,
10,
10,
]

# ensure that the saved json is a valid model
_ = InputModel.parse_obj(saved_data)

tsr.deleteLater()
150 changes: 106 additions & 44 deletions src/yt_napari/_widget_reader.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
import json
from collections import defaultdict
from typing import Callable, Optional

import napari
from magicgui import widgets
from napari.qt.threading import thread_worker
from qtpy import QtCore
from qtpy.QtWidgets import QComboBox, QHBoxLayout, QPushButton, QVBoxLayout, QWidget
from qtpy.QtWidgets import (
QComboBox,
QFileDialog,
QHBoxLayout,
QPushButton,
QVBoxLayout,
QWidget,
)

from yt_napari import _data_model, _gui_utilities, _model_ingestor
from yt_napari._ds_cache import dataset_cache
from yt_napari._schema_version import schema_name
from yt_napari.viewer import _check_for_reference_layer


Expand Down Expand Up @@ -68,7 +77,9 @@ def add_spatial_selection_widgets(self):
self.layout().addLayout(removal_group_layout)

def add_load_group_widgets(self):
pass
"""
add the widgets related to the Load button
"""

def add_a_selection(self):
selection_type = self.new_selection_type.currentText()
Expand Down Expand Up @@ -106,6 +117,26 @@ def add_load_group_widgets(self):
load_group.addWidget(cc.native)
self.layout().addLayout(load_group)

ss = widgets.PushButton(text="Save Selection")
ss.clicked.connect(self.save_selection)
load_group.addWidget(ss.native)

def save_selection(self):
py_kwargs = {}
py_kwargs = self._validate_data_model()

file_dialog = QFileDialog()
file_dialog.setFileMode(QFileDialog.AnyFile)
file_dialog.setAcceptMode(QFileDialog.AcceptSave)
file_dialog.setNameFilter("JSON Files (*.json);;All Files (*)")

if file_dialog.exec_():
file_path = file_dialog.selectedFiles()[0]
if file_path:
# Save the JSON data to the selected file
with open(file_path, "w") as json_file:
json.dump(py_kwargs, json_file, indent=4)

def clear_cache(self):
dataset_cache.rm_all()

Expand All @@ -114,7 +145,29 @@ def load_data(self):
# instantiate pydantic objects, which are then handed off to the
# same data ingestion function as the json loader.

# first, get the pydantic args for each selection type, embed in lists
py_kwargs = {}
py_kwargs = self._validate_data_model()
model = _data_model.InputModel.parse_obj(py_kwargs)

# process each layer
layer_list, _ = _model_ingestor._process_validated_model(model)

# align all layers after checking for or setting the reference layer
ref_layer = _check_for_reference_layer(self.viewer.layers)
if ref_layer is None:
ref_layer = _model_ingestor._choose_ref_layer(layer_list)
layer_list = ref_layer.align_sanitize_layers(layer_list)

for new_layer in layer_list:
im_arr, im_kwargs, _ = new_layer
if self._post_load_function is not None:
im_arr = self._post_load_function(im_arr)

# add the new layer
self.viewer.add_image(im_arr, **im_kwargs)

def _validate_data_model(self):
# this function save json data
selections_by_type = defaultdict(list)
for selection in self.active_selections.values():
py_kwargs = selection.get_current_pydantic_kwargs()
Expand All @@ -129,34 +182,17 @@ def load_data(self):
py_kwargs,
ignore_attrs="selections",
)

# add selections in
py_kwargs["selections"] = selections_by_type

# now ready to instantiate the base model
py_kwargs = {
"$schema": schema_name,
"datasets": [
py_kwargs,
]
],
}
model = _data_model.InputModel.parse_obj(py_kwargs)

# process each layer
layer_list, _ = _model_ingestor._process_validated_model(model)

# align all layers after checking for or setting the reference layer
ref_layer = _check_for_reference_layer(self.viewer.layers)
if ref_layer is None:
ref_layer = _model_ingestor._choose_ref_layer(layer_list)
layer_list = ref_layer.align_sanitize_layers(layer_list)

for new_layer in layer_list:
im_arr, im_kwargs, _ = new_layer
if self._post_load_function is not None:
im_arr = self._post_load_function(im_arr)

# add the new layer
self.viewer.add_image(im_arr, **im_kwargs)
return py_kwargs


class SelectionEntry(QWidget):
Expand Down Expand Up @@ -223,7 +259,50 @@ def add_load_group_widgets(self):
load_group.addWidget(pb.native)
self.layout().addLayout(load_group)

ss = widgets.PushButton(text="Save Selection")
ss.clicked.connect(self.save_selection)
load_group.addWidget(ss.native)

def save_selection(self):
py_kwargs = {}
py_kwargs = self._validate_data_model()
# model = _data_model.InputModel.parse_obj(py_kwargs)

file_dialog = QFileDialog()
file_dialog.setFileMode(QFileDialog.AnyFile)
file_dialog.setAcceptMode(QFileDialog.AcceptSave)
file_dialog.setNameFilter("JSON Files (*.json);;All Files (*)")

if file_dialog.exec_():
file_path = file_dialog.selectedFiles()[0]
if file_path:
# Save the JSON data to the selected file
with open(file_path, "w") as json_file:
json.dump(py_kwargs, json_file, indent=4)

def load_data(self):
py_kwargs = {}
py_kwargs = self._validate_data_model()
model = _data_model.InputModel.parse_obj(py_kwargs)

if _use_threading: # pragma: no cover
worker = time_series_load(model)
worker.returned.connect(self.process_timeseries_layers)
worker.start()
else:
_, layer_list = _model_ingestor._process_validated_model(model)
self.process_timeseries_layers(layer_list)

def process_timeseries_layers(self, layer_list):
for new_layer in layer_list:
im_arr, im_kwargs, _ = new_layer
# probably can remove since the _special_loaders can be used
# if self._post_load_function is not None:
# im_arr = self._post_load_function(im_arr)
# add the new layer
self.viewer.add_image(im_arr, **im_kwargs)

def _validate_data_model(self):
# first, get the pydantic args for each selection type, embed in lists
selections_by_type = defaultdict(list)
for selection in self.active_selections.values():
Expand Down Expand Up @@ -254,32 +333,15 @@ def load_data(self):

# now ready to instantiate the base model
py_kwargs = {
"$schema": schema_name,
"timeseries": [
py_kwargs,
]
],
}

model = _data_model.InputModel.parse_obj(py_kwargs)

if _use_threading:
worker = time_series_load(model)
worker.returned.connect(self.process_timeseries_layers)
worker.start()
else:
_, layer_list = _model_ingestor._process_validated_model(model)
self.process_timeseries_layers(layer_list)

def process_timeseries_layers(self, layer_list):
for new_layer in layer_list:
im_arr, im_kwargs, _ = new_layer
# probably can remove since the _special_loaders can be used
# if self._post_load_function is not None:
# im_arr = self._post_load_function(im_arr)
# add the new layer
self.viewer.add_image(im_arr, **im_kwargs)
return py_kwargs


@thread_worker(progress=True)
def time_series_load(model):
def time_series_load(model): # pragma: no cover
_, layer_list = _model_ingestor._process_validated_model(model)
return layer_list
Loading