diff --git a/fastplotlib/widgets/image_widget/__init__.py b/fastplotlib/widgets/image_widget/__init__.py index 70a1aa8a..2c217038 100644 --- a/fastplotlib/widgets/image_widget/__init__.py +++ b/fastplotlib/widgets/image_widget/__init__.py @@ -2,6 +2,7 @@ if IMGUI: from ._widget import ImageWidget + from ._array import ImageWidgetArray else: diff --git a/fastplotlib/widgets/image_widget/_array.py b/fastplotlib/widgets/image_widget/_array.py new file mode 100644 index 00000000..ad70548e --- /dev/null +++ b/fastplotlib/widgets/image_widget/_array.py @@ -0,0 +1,221 @@ +import numpy as np +from numpy.typing import NDArray +from typing import Literal, Callable +from warnings import warn + + +class ImageWidgetArray: + def __init__( + self, + data: NDArray, + rgb: bool = False, + process_function: Callable = None, + window_size: dict[str, int] = None, + n_display_dims: Literal[2, 3] = 2, + dim_names: tuple[str] = None, + ): + """ + + Parameters + ---------- + data: NDArray + array-like data, must have 2 or more dimensions + + process_function: Callable, optional + function to apply to a window of data around the current index. + The callable must take an `axis` kwarg. + + window_size: dict[str, int] + dict of window sizes for each dim, maps dim names -> window size. + Example: {"t": 5, "z": 3}. + + If a dim is not provided the window size is 0 for that dim, i.e. no window is taken along that dimension + + n_display_dims: int, 2 or 3, default 2 + number of display dimensions + + dim_names: tuple[str], optional + dimension names as a tuple of strings, ex: ("t", "z", "x", "y") + """ + self._data = data + + self._window_size = process_function + self._window_size = window_size + + self._rgb = rgb + + # default dim names for mn, tmn, and tzmn, ignore rgb dim if present + if dim_names is None: + if data.ndim == (2 + int(self.rgb)): + dim_names = ("m", "n") + + elif data.ndim == (3 + int(self.rgb)): + dim_names = ("t", "m", "n") + + elif data.ndim == (4 + int(self.rgb)): + dim_names = ("t", "z", "m", "n") + + else: + # create a tuple of str numbers for each time, ex: ("0", "1", "2", "3", "4", "5", "6") + dim_names = tuple(map(str, range(data.ndim))) + + self._dim_names = dim_names + + for k in self._window_size: + if k not in dim_names: + raise KeyError + + if n_display_dims not in (2, 3): + raise ValueError("`n_display_dims` must be an with a value of 2 or 3") + + self._n_display_dims = n_display_dims + + @property + def data(self) -> NDArray: + return self._data + + @data.setter + def data(self, data: NDArray): + self._data = data + + @property + def rgb(self) -> bool: + return self._rgb + + @property + def ndim(self) -> int: + return self.data.ndim + + @property + def n_scrollable_dims(self) -> int: + return self.ndim - 2 - int(self.rgb) + + @property + def n_display_dims(self) -> int: + return self._n_display_dims + + @property + def dim_names(self) -> tuple[str]: + return self._dim_names + + @property + def window_function(self) -> Callable | None: + return self._window_size + + @window_function.setter + def window_function(self, func: Callable | None): + self._window_size = func + + @property + def window_size(self) -> dict | None: + """dict of window sizes for each dim""" + return self._window_size + + @window_size.setter + def window_size(self, size: dict): + for k in list(size.keys()): + if k not in self.dim_names: + raise ValueError(f"specified window key: `k` not present in array with dim names: {self.dim_names}") + + if not isinstance(size[k], int): + raise TypeError("window size values must be integers") + + if size[k] < 0: + raise ValueError(f"window size values must be greater than 2 and odd numbers") + + if size[k] == 0: + # remove key + warn(f"specified window size of 0 for dim: {k}, removing dim from windows") + size.pop(k) + + elif size[k] % 2 != 0: + # odd number, add 1 + warn(f"specified even number for window size of dim: {k}, adding one to make it even") + size[k] += 1 + + self._window_size = size + + def _apply_window_function(self, index: dict[str, int]): + if self.n_scrollable_dims == 0: + # 2D image, return full data + # TODO: would be smart to handle this in ImageWidget so + # that Texture buffer is not updated when it doesn't change!! + return self.data + + if self.window_size is None: + # for simplicity, so we can use the same for loop below to slice the array + # regardless of whether window_functions are specified or not + window_size = dict() + else: + window_size = self.window_size + + # create a slice object for every dim except the last 2, or 3 (if rgb) + multi_slice = list() + axes = list() + + for dim_number in range(self.n_scrollable_dims): + # get str name + dim_name = self.dim_names[dim_number] + + # don't go beyond max bound + max_bound = self.data.shape[dim_number] + + # check if a window is specified for this dim + if dim_name in window_size.keys(): + size = window_size[dim_name] + half_size = int((size - 1) / 2) + + # create slice obj for this dim using this window + start = max(0, index[dim_name] - half_size) # start index, min allowed value is 0 + stop = min(max_bound, index[dim_name] + half_size) + + s = slice(start, stop) + multi_slice.append(s) + + # add to axes list for window function + axes.append(dim_number) + else: + # no window size is specified for this scrollable dim, directly use integer index + multi_slice.append(index[dim_name]) + + # get sliced array + array_sliced = self.data[tuple(multi_slice)] + + if self.window_function is not None: + # apply window function + return self.window_function(array_sliced, axis=axes) + + # not window function, return sliced array + return array_sliced + + def get(self, index: dict[str, int]): + """ + Get the data at the given index, process data through the window function and frame function. + + Note that we do not use __getitem__ here since the index is a dict specifying a single integer + index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + + Parameters + ---------- + index: dict[str, int] + Get the processed data at this index. + Example: get({"t": 1000, "z" 3}) + + """ + + if set(index.keys()) != set(self.dim_names): + raise ValueError( + f"Must specify index for every dim, you have specified an index: {index}\n" + f"All dim names are: {self.dim_names}" + ) + + window_output = self._apply_window_function(index) + + if window_output.ndim != self.n_display_dims: + raise ValueError( + f"Output of the `process_function` must match the number of display dims." + f"`process_function` returned an array with {window_output.ndim} dims, " + f"expected {self.n_display_dims} dims" + ) + + return window_output diff --git a/fastplotlib/widgets/image_widget/_widget.py b/fastplotlib/widgets/image_widget/_widget.py index 65009795..479e4591 100644 --- a/fastplotlib/widgets/image_widget/_widget.py +++ b/fastplotlib/widgets/image_widget/_widget.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import Callable from warnings import warn @@ -11,6 +10,7 @@ from ...utils import calculate_figure_shape, quick_min_max from ...tools import HistogramLUTTool from ._sliders import ImageWidgetSliders +from ._array import ImageWidgetArray # Number of dimensions that represent one image/one frame @@ -289,6 +289,7 @@ def _get_n_scrollable_dims(self, curr_arr: np.ndarray, rgb: bool) -> list[int]: def __init__( self, data: np.ndarray | list[np.ndarray], + array_types: ImageWidgetArray | list[ImageWidgetArray] = ImageWidgetArray, window_funcs: dict[str, tuple[Callable, int]] = None, frame_apply: Callable | dict[int, Callable] = None, figure_shape: tuple[int, int] = None,