Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update _widget_reader.py #108

Merged
merged 13 commits into from
Nov 30, 2023
103 changes: 99 additions & 4 deletions src/yt_napari/_tests/test_widget_reader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import json

# note: the cache is disabled for all the tests in this file due to flakiness
# in github CI. It may be that loading from a true file, rather than the
# yt_ugrid_ds_fn fixture would fix that...
import os
from functools import partial
from unittest.mock import patch

import numpy as np

Expand All @@ -8,10 +15,6 @@
# import ReaderWidget, SelectionEntry, TimeSeriesReader
from yt_napari._special_loaders import _construct_ugrid_timeseries

# note: the cache is disabled for all the tests in this file due to flakiness
# in github CI. It may be that loading from a true file, rather than the
# yt_ugrid_ds_fn fixture would fix that...


def test_widget_reader_add_selections(make_napari_viewer, yt_ugrid_ds_fn):
viewer = make_napari_viewer()
Expand Down Expand Up @@ -45,6 +48,61 @@
return np.random.random(final_shape) * data.mean()


def test_save_widget_reader(make_napari_viewer, yt_ugrid_ds_fn):
viewer = make_napari_viewer()
r = _wr.ReaderWidget(napari_viewer=viewer)
r.ds_container.filename.value = yt_ugrid_ds_fn
r.ds_container.store_in_cache.value = False
r.add_new_button.click()
sel = list(r.active_selections.values())[0]
assert isinstance(sel, _wr.SelectionEntry)

mgui_region = sel.selection_container_raw
mgui_region.fields.field_type.value = "enzo"
mgui_region.fields.field_name.value = "Density"
mgui_region.resolution.value = (400, 400, 400)

rebuild = partial(_rebuild_data, mgui_region.resolution.value)
r._post_load_function = rebuild

temp_file = "test.json"

with patch("PyQt5.QtWidgets.QFileDialog.exec_") as mock_exec, patch(
"PyQt5.QtWidgets.QFileDialog.selectedFiles"
) as mock_selectedFiles:
# Set the return values for the mocked functions
mock_exec.return_value = 1
mock_selectedFiles.return_value = [temp_file]

r.save_selection()

assert os.path.exists(temp_file)
with open(temp_file, "r") as json_file:
saved_data = json.load(json_file)

assert (
saved_data["datasets"][0]["selections"]["regions"][0]["fields"][0]["field_type"]
== "enzo"
)
assert (
saved_data["datasets"][0]["selections"]["regions"][0]["fields"][0]["field_name"]
== "Density"
)
assert saved_data["datasets"][0]["selections"]["regions"][0]["resolution"] == [
400,
400,
400,
]

os.remove(temp_file)
r.deleteLater()


def simulate_file_selection(args, **kwargs):
# Simulate filling in a file name, selecting a file type, and closing the dialog
return ("selected_file.json", "JSON Files (.json)")

Check warning on line 103 in src/yt_napari/_tests/test_widget_reader.py

View check run for this annotation

Codecov / codecov/patch

src/yt_napari/_tests/test_widget_reader.py#L103

Added line #L103 was not covered by tests


def test_widget_reader(make_napari_viewer, yt_ugrid_ds_fn):
viewer = make_napari_viewer()
r = _wr.ReaderWidget(napari_viewer=viewer)
Expand Down Expand Up @@ -138,4 +196,41 @@
tsr.load_data()
assert len(viewer.layers) == 2

temp_file = "test.json"

# Use patch to replace the actual QFileDialog functions with mock functions
with patch("PyQt5.QtWidgets.QFileDialog.exec_") as mock_exec, patch(
"PyQt5.QtWidgets.QFileDialog.selectedFiles"
) as mock_selectedFiles:
# Set the return values for the mocked functions
mock_exec.return_value = 1 # Assuming QDialog::Accepted is 1
mock_selectedFiles.return_value = [temp_file]

# Call the save_selection method
tsr.save_selection()

assert os.path.exists(temp_file)
with open(temp_file, "r") as json_file:
saved_data = json.load(json_file)

assert (
saved_data["timeseries"][0]["selections"]["regions"][0]["fields"][0][
"field_type"
]
== "stream"
)
assert (
saved_data["timeseries"][0]["selections"]["regions"][0]["fields"][0][
"field_name"
]
== "density"
)
assert saved_data["timeseries"][0]["selections"]["regions"][0]["resolution"] == [
10,
10,
10,
]

os.remove(temp_file)

tsr.deleteLater()
144 changes: 102 additions & 42 deletions src/yt_napari/_widget_reader.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
import json
from collections import defaultdict
from typing import Callable, Optional

import napari
from magicgui import widgets
from napari.qt.threading import thread_worker
from qtpy import QtCore
from qtpy.QtWidgets import QComboBox, QHBoxLayout, QPushButton, QVBoxLayout, QWidget
from qtpy.QtWidgets import (
QComboBox,
QFileDialog,
QHBoxLayout,
QPushButton,
QVBoxLayout,
QWidget,
)

from yt_napari import _data_model, _gui_utilities, _model_ingestor
from yt_napari._ds_cache import dataset_cache
from yt_napari._schema_version import schema_name
from yt_napari.viewer import _check_for_reference_layer


Expand Down Expand Up @@ -106,6 +115,26 @@
load_group.addWidget(cc.native)
self.layout().addLayout(load_group)

ss = widgets.PushButton(text="Save Selection")
ss.clicked.connect(self.save_selection)
load_group.addWidget(ss.native)

def save_selection(self):
py_kwargs = {}
py_kwargs = self._validate_data_model()

file_dialog = QFileDialog()
file_dialog.setFileMode(QFileDialog.AnyFile)
file_dialog.setAcceptMode(QFileDialog.AcceptSave)
file_dialog.setNameFilter("JSON Files (*.json);;All Files (*)")

if file_dialog.exec_():
file_path = file_dialog.selectedFiles()[0]
if file_path:
# Save the JSON data to the selected file
with open(file_path, "w") as json_file:
json.dump(py_kwargs, json_file, indent=4)

def clear_cache(self):
dataset_cache.rm_all()

Expand All @@ -114,7 +143,29 @@
# instantiate pydantic objects, which are then handed off to the
# same data ingestion function as the json loader.

# first, get the pydantic args for each selection type, embed in lists
py_kwargs = {}
py_kwargs = self._validate_data_model()
model = _data_model.InputModel.parse_obj(py_kwargs)

# process each layer
layer_list, _ = _model_ingestor._process_validated_model(model)

# align all layers after checking for or setting the reference layer
ref_layer = _check_for_reference_layer(self.viewer.layers)
if ref_layer is None:
ref_layer = _model_ingestor._choose_ref_layer(layer_list)
layer_list = ref_layer.align_sanitize_layers(layer_list)

for new_layer in layer_list:
im_arr, im_kwargs, _ = new_layer
if self._post_load_function is not None:
im_arr = self._post_load_function(im_arr)

# add the new layer
self.viewer.add_image(im_arr, **im_kwargs)

def _validate_data_model(self):
# this function save json data
selections_by_type = defaultdict(list)
for selection in self.active_selections.values():
py_kwargs = selection.get_current_pydantic_kwargs()
Expand All @@ -129,34 +180,17 @@
py_kwargs,
ignore_attrs="selections",
)

# add selections in
py_kwargs["selections"] = selections_by_type

# now ready to instantiate the base model
py_kwargs = {
"$schema": schema_name,
"datasets": [
py_kwargs,
]
],
}
model = _data_model.InputModel.parse_obj(py_kwargs)

# process each layer
layer_list, _ = _model_ingestor._process_validated_model(model)

# align all layers after checking for or setting the reference layer
ref_layer = _check_for_reference_layer(self.viewer.layers)
if ref_layer is None:
ref_layer = _model_ingestor._choose_ref_layer(layer_list)
layer_list = ref_layer.align_sanitize_layers(layer_list)

for new_layer in layer_list:
im_arr, im_kwargs, _ = new_layer
if self._post_load_function is not None:
im_arr = self._post_load_function(im_arr)

# add the new layer
self.viewer.add_image(im_arr, **im_kwargs)
return py_kwargs


class SelectionEntry(QWidget):
Expand Down Expand Up @@ -223,7 +257,50 @@
load_group.addWidget(pb.native)
self.layout().addLayout(load_group)

ss = widgets.PushButton(text="Save Selection")
ss.clicked.connect(self.save_selection)
load_group.addWidget(ss.native)

def save_selection(self):
py_kwargs = {}
py_kwargs = self._validate_data_model()
# model = _data_model.InputModel.parse_obj(py_kwargs)

file_dialog = QFileDialog()
file_dialog.setFileMode(QFileDialog.AnyFile)
file_dialog.setAcceptMode(QFileDialog.AcceptSave)
file_dialog.setNameFilter("JSON Files (*.json);;All Files (*)")

if file_dialog.exec_():
file_path = file_dialog.selectedFiles()[0]
if file_path:
# Save the JSON data to the selected file
with open(file_path, "w") as json_file:
json.dump(py_kwargs, json_file, indent=4)

def load_data(self):
py_kwargs = {}
py_kwargs = self._validate_data_model()
model = _data_model.InputModel.parse_obj(py_kwargs)

if _use_threading:
worker = time_series_load(model)
worker.returned.connect(self.process_timeseries_layers)
worker.start()

Check warning on line 289 in src/yt_napari/_widget_reader.py

View check run for this annotation

Codecov / codecov/patch

src/yt_napari/_widget_reader.py#L287-L289

Added lines #L287 - L289 were not covered by tests
else:
_, layer_list = _model_ingestor._process_validated_model(model)
self.process_timeseries_layers(layer_list)

def process_timeseries_layers(self, layer_list):
for new_layer in layer_list:
im_arr, im_kwargs, _ = new_layer
# probably can remove since the _special_loaders can be used
# if self._post_load_function is not None:
# im_arr = self._post_load_function(im_arr)
# add the new layer
self.viewer.add_image(im_arr, **im_kwargs)

def _validate_data_model(self):
# first, get the pydantic args for each selection type, embed in lists
selections_by_type = defaultdict(list)
for selection in self.active_selections.values():
Expand Down Expand Up @@ -254,29 +331,12 @@

# now ready to instantiate the base model
py_kwargs = {
"$schema": schema_name,
"timeseries": [
py_kwargs,
]
],
}

model = _data_model.InputModel.parse_obj(py_kwargs)

if _use_threading:
worker = time_series_load(model)
worker.returned.connect(self.process_timeseries_layers)
worker.start()
else:
_, layer_list = _model_ingestor._process_validated_model(model)
self.process_timeseries_layers(layer_list)

def process_timeseries_layers(self, layer_list):
for new_layer in layer_list:
im_arr, im_kwargs, _ = new_layer
# probably can remove since the _special_loaders can be used
# if self._post_load_function is not None:
# im_arr = self._post_load_function(im_arr)
# add the new layer
self.viewer.add_image(im_arr, **im_kwargs)
return py_kwargs


@thread_worker(progress=True)
Expand Down
Loading