diff --git a/src/napari_activelearning/_acquisition.py b/src/napari_activelearning/_acquisition.py index 7bab5c7..fb4f7e5 100644 --- a/src/napari_activelearning/_acquisition.py +++ b/src/napari_activelearning/_acquisition.py @@ -283,15 +283,10 @@ def compute_acquisition(self, dataset_metadata, acquisition_fun, n_samples = 0 img_sampling_positions = [] - pred_spatial_axes = list(dataset_metadata["images"]["axes"]) - if "C" in pred_spatial_axes: - pred_spatial_axes.remove("C") - pred_spatial_axes = "".join(pred_spatial_axes) - if "masks" in dataset_metadata: mask_axes = dataset_metadata["masks"]["source_axes"] else: - mask_axes = pred_spatial_axes + mask_axes = spatial_axes pred_sel = tuple( slice(None) if ax in pred_spatial_axes else None @@ -303,9 +298,9 @@ def compute_acquisition(self, dataset_metadata, acquisition_fun, if USING_TORCH: pos = pos[0].numpy() img = img[0].numpy() - img_sp = img_sp[0, ..., 0].numpy() + img_sp = img_sp.numpy().squeeze() else: - img_sp = img_sp[..., 0] + img_sp = img_sp.squeeze() pos = { ax: slice(pos_ax[0], pos_ax[1]) diff --git a/src/napari_activelearning/_utils.py b/src/napari_activelearning/_utils.py index 1004f8b..0887bbe 100644 --- a/src/napari_activelearning/_utils.py +++ b/src/napari_activelearning/_utils.py @@ -98,12 +98,11 @@ def _compute_transform(self, image): labels_dim = np.arange(cols * rows).reshape(rows, cols) labels = resize(labels_dim, - (image.shape[self.axes.index("Y")], - image.shape[self.axes.index("X")]), + (image.shape[self.ax_Y], image.shape[self.ax_X]), order=0) if image_channels > 1: - labels = np.expand_dims(labels, self.axes.index("C")) + labels = np.expand_dims(labels, self._ax_C) else: labels = labels[..., None]