From dd502e4a9cc0226c7ef60b4cf59a8b5cbb0fb6b5 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 28 Jun 2025 21:11:20 -0400 Subject: [PATCH 1/4] start separating iw plotting and array logic --- fastplotlib/widgets/image_widget/__init__.py | 1 + fastplotlib/widgets/image_widget/_array.py | 79 ++++++++++++++++++++ fastplotlib/widgets/image_widget/_widget.py | 3 +- 3 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 fastplotlib/widgets/image_widget/_array.py diff --git a/fastplotlib/widgets/image_widget/__init__.py b/fastplotlib/widgets/image_widget/__init__.py index 70a1aa8ae..2c217038e 100644 --- a/fastplotlib/widgets/image_widget/__init__.py +++ b/fastplotlib/widgets/image_widget/__init__.py @@ -2,6 +2,7 @@ if IMGUI: from ._widget import ImageWidget + from ._array import ImageWidgetArray else: diff --git a/fastplotlib/widgets/image_widget/_array.py b/fastplotlib/widgets/image_widget/_array.py new file mode 100644 index 000000000..bfc8c8c92 --- /dev/null +++ b/fastplotlib/widgets/image_widget/_array.py @@ -0,0 +1,79 @@ +import numpy as np +from numpy.typing import NDArray +from typing import Literal, Callable + + +class ImageWidgetArray: + def __init__( + self, + data: NDArray, + window_functions: dict = None, + frame_apply: Callable = None, + display_dims: Literal[2, 3] = 2, + dim_names: str = "tzxy", + ): + self._data = data + self._window_functions = window_functions + self._frame_apply = frame_apply + self._dim_names = dim_names + + for k in self._window_functions: + if k not in dim_names: + raise KeyError + + self._display_dims = display_dims + + @property + def data(self) -> NDArray: + return self._data + + @data.setter + def data(self, data: NDArray): + self._data = data + + @property + def window_functions(self) -> dict | None: + return self._window_functions + + @window_functions.setter + def window_functions(self, wf: dict | None): + self._window_functions = wf + + @property + def frame_apply(self, fa: Callable | None): + self._frame_apply = fa + + @frame_apply.setter + def frame_apply(self) -> Callable | None: + return self._frame_apply + + def _apply_window_functions(self, array: NDArray, key): + if self.window_functions is not None: + for dim_name in self._window_functions.keys(): + dim_index = self._dim_names.index(dim_name) + + window_size = self.window_functions[dim_name][1] + half_window_size = int((window_size - 1) / 2) + + max_bound = self._data.shape[dim_index] + + window_indices = range() + + else: + array = array[key] + + return array + + def __getitem__(self, key): + data = self._data + + + data = self._apply_window_functions(data, key) + + if self.frame_apply is not None: + data = self.frame_apply(data) + + if data.ndim != self._display_dims: + raise ValueError + + return data diff --git a/fastplotlib/widgets/image_widget/_widget.py b/fastplotlib/widgets/image_widget/_widget.py index 650097951..479e45914 100644 --- a/fastplotlib/widgets/image_widget/_widget.py +++ b/fastplotlib/widgets/image_widget/_widget.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import Callable from warnings import warn @@ -11,6 +10,7 @@ from ...utils import calculate_figure_shape, quick_min_max from ...tools import HistogramLUTTool from ._sliders import ImageWidgetSliders +from ._array import ImageWidgetArray # Number of dimensions that represent one image/one frame @@ -289,6 +289,7 @@ def _get_n_scrollable_dims(self, curr_arr: np.ndarray, rgb: bool) -> list[int]: def __init__( self, data: np.ndarray | list[np.ndarray], + array_types: ImageWidgetArray | list[ImageWidgetArray] = ImageWidgetArray, window_funcs: dict[str, tuple[Callable, int]] = None, frame_apply: Callable | dict[int, Callable] = None, figure_shape: tuple[int, int] = None, From 4f1fcd9a963d5e0e63c610958989a229a4fc6f8f Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Mon, 11 Aug 2025 02:24:04 -0400 Subject: [PATCH 2/4] some more basics down --- fastplotlib/widgets/image_widget/_array.py | 220 +++++++++++++++++---- 1 file changed, 185 insertions(+), 35 deletions(-) diff --git a/fastplotlib/widgets/image_widget/_array.py b/fastplotlib/widgets/image_widget/_array.py index bfc8c8c92..54d26fa57 100644 --- a/fastplotlib/widgets/image_widget/_array.py +++ b/fastplotlib/widgets/image_widget/_array.py @@ -1,27 +1,75 @@ import numpy as np from numpy.typing import NDArray from typing import Literal, Callable +from warnings import warn class ImageWidgetArray: def __init__( self, data: NDArray, - window_functions: dict = None, - frame_apply: Callable = None, - display_dims: Literal[2, 3] = 2, - dim_names: str = "tzxy", + rgb: bool = False, + window_function: Callable = None, + window_size: dict[str, int] = None, + frame_function: Callable = None, + n_display_dims: Literal[2, 3] = 2, + dim_names: tuple[str] = None, ): + """ + + Parameters + ---------- + data: NDArray + array-like data, must have 2 or more dimensions + + window_function: Callable, optional + function to apply to a window of data around the current index. + The callable must take an `axis` kwarg. + + window_size: dict[str, int] + dict of window sizes for each dim, maps dim names -> window size. + Example: {"t": 5, "z": 3}. + + If a dim is not provided the window size is 0 for that dim, i.e. no window is taken along that dimension + + frame_function + n_display_dims + dim_names + """ self._data = data - self._window_functions = window_functions - self._frame_apply = frame_apply + + self._window_size = window_function + self._window_size = window_size + + self._frame_function = frame_function + + self._rgb = rgb + + # default dim names for mn, tmn, and tzmn, ignore rgb dim if present + if dim_names is None: + if data.ndim == (2 + int(self.rgb)): + dim_names = ("m", "n") + + elif data.ndim == (3 + int(self.rgb)): + dim_names = ("t", "m", "n") + + elif data.ndim == (4 + int(self.rgb)): + dim_names = ("t", "z", "m", "n") + + else: + # create a tuple of str numbers for each time, ex: ("0", "1", "2", "3", "4", "5", "6") + dim_names = tuple(map(str, range(data.ndim))) + self._dim_names = dim_names - for k in self._window_functions: + for k in self._window_size: if k not in dim_names: raise KeyError - self._display_dims = display_dims + if n_display_dims not in (2, 3): + raise ValueError("`n_display_dims` must be an with a value of 2 or 3") + + self._n_display_dims = n_display_dims @property def data(self) -> NDArray: @@ -32,48 +80,150 @@ def data(self, data: NDArray): self._data = data @property - def window_functions(self) -> dict | None: - return self._window_functions + def rgb(self) -> bool: + return self._rgb + + @property + def ndim(self) -> int: + return self.data.ndim + + @property + def n_scrollable_dims(self) -> int: + return self.ndim - 2 - int(self.rgb) + + @property + def n_display_dims(self) -> int: + return self._n_display_dims + + @property + def dim_names(self) -> tuple[str]: + return self._dim_names + + @property + def window_function(self) -> Callable | None: + return self._window_size - @window_functions.setter - def window_functions(self, wf: dict | None): - self._window_functions = wf + @window_function.setter + def window_function(self, func: Callable | None): + self._window_size = func @property - def frame_apply(self, fa: Callable | None): - self._frame_apply = fa + def window_size(self) -> dict | None: + """dict of window sizes for each dim""" + return self._window_size - @frame_apply.setter - def frame_apply(self) -> Callable | None: - return self._frame_apply + @window_size.setter + def window_size(self, size: dict): + for k in list(size.keys()): + if k not in self.dim_names: + raise ValueError(f"specified window key: `k` not present in array with dim names: {self.dim_names}") - def _apply_window_functions(self, array: NDArray, key): - if self.window_functions is not None: - for dim_name in self._window_functions.keys(): - dim_index = self._dim_names.index(dim_name) + if not isinstance(size[k], int): + raise TypeError("window size values must be integers") - window_size = self.window_functions[dim_name][1] - half_window_size = int((window_size - 1) / 2) + if size[k] < 0: + raise ValueError(f"window size values must be greater than 2 and odd numbers") - max_bound = self._data.shape[dim_index] + if size[k] == 0: + # remove key + warn(f"specified window size of 0 for dim: {k}, removing dim from windows") + size.pop(k) - window_indices = range() + elif size[k] % 2 != 0: + # odd number, add 1 + warn(f"specified even number for window size of dim: {k}, adding one to make it even") + size[k] += 1 + self._window_size = size + + @property + def frame_function(self) -> Callable | None: + return self._frame_function + + @frame_function.setter + def frame_function(self, fa: Callable | None): + self._frame_function = fa + + def _apply_window_function(self, index: dict[str, int]): + if self.n_scrollable_dims == 0: + # 2D image, return full data + # TODO: would be smart to handle this in ImageWidget so + # that Texture buffer is not updated when it doesn't change!! + return self.data + + if self.window_size is None: + window_size = dict() else: - array = array[key] + window_size = self.window_size + + # create a slice object for every dim except the last 2, or 3 (if rgb) + multi_slice = list() + axes = list() + + for dim_number in range(self.n_scrollable_dims): + # get str name + dim_name = self.dim_names[dim_number] + + # don't go beyond max bound + max_bound = self.data.shape[dim_number] - return array + # check if a window is specific for this dim + if dim_name in window_size.keys(): + size = window_size[dim_name] + half_size = int((size - 1) / 2) - def __getitem__(self, key): - data = self._data + # create slice obj for this dim using this window + start = max(0, index[dim_name] - half_size) # start index, min allowed value is 0 + stop = min(max_bound, index[dim_name] + half_size) + + s = slice(start, stop) + multi_slice.append(s) + # add to axes list for window function + axes.append(dim_number) + else: + # no window size is specified for this scrollable dim, directly use integer index + multi_slice.append(index[dim_name]) - data = self._apply_window_functions(data, key) + # get sliced array + array_sliced = self.data[tuple(multi_slice)] - if self.frame_apply is not None: - data = self.frame_apply(data) + if self.window_function is not None: + # apply window function + return self.window_function(array_sliced, axis=axes) + + # not window function, return sliced array + return array_sliced + + def get(self, index: dict[str, int]): + """ + Get the data at the given index, process data through the window function and frame function. + + Note that we do not use __getitem__ here since the index is a dict specifying a single integer + index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + + Parameters + ---------- + index: dict[str, int] + Get the processed data at this index. + Example: get({"t": 1000, "z" 3}) + + """ + + if set(index.keys()) != set(self.dim_names): + raise ValueError( + f"Must specify index for every dim, you have specified an index: {index}\n" + f"All dim names are: {self.dim_names}" + ) + + window_output = self._apply_window_function(index) + + if self.frame_function is not None: + frame_output = self.frame_function(window_output) + else: + frame_output = window_output - if data.ndim != self._display_dims: + if frame_output.ndim != self.n_display_dims: raise ValueError - return data + return frame_output From 20f1878533de8dc09f42d0cebe6a2de6675bc126 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Mon, 11 Aug 2025 02:29:06 -0400 Subject: [PATCH 3/4] comment --- fastplotlib/widgets/image_widget/_array.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fastplotlib/widgets/image_widget/_array.py b/fastplotlib/widgets/image_widget/_array.py index 54d26fa57..fb4f4ae3a 100644 --- a/fastplotlib/widgets/image_widget/_array.py +++ b/fastplotlib/widgets/image_widget/_array.py @@ -152,6 +152,8 @@ def _apply_window_function(self, index: dict[str, int]): return self.data if self.window_size is None: + # for simplicity, so we can use the same for loop below to slice the array + # regardless of whether window_functions are specified or not window_size = dict() else: window_size = self.window_size @@ -167,7 +169,7 @@ def _apply_window_function(self, index: dict[str, int]): # don't go beyond max bound max_bound = self.data.shape[dim_number] - # check if a window is specific for this dim + # check if a window is specified for this dim if dim_name in window_size.keys(): size = window_size[dim_name] half_size = int((size - 1) / 2) @@ -175,7 +177,7 @@ def _apply_window_function(self, index: dict[str, int]): # create slice obj for this dim using this window start = max(0, index[dim_name] - half_size) # start index, min allowed value is 0 stop = min(max_bound, index[dim_name] + half_size) - + s = slice(start, stop) multi_slice.append(s) From 330f7f03349464810d6451687ee61ad1f1008ee3 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 17 Aug 2025 01:11:21 -0400 Subject: [PATCH 4/4] collapse into just having a window function, no frame_function --- fastplotlib/widgets/image_widget/_array.py | 40 ++++++++-------------- 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/fastplotlib/widgets/image_widget/_array.py b/fastplotlib/widgets/image_widget/_array.py index fb4f4ae3a..ad70548e6 100644 --- a/fastplotlib/widgets/image_widget/_array.py +++ b/fastplotlib/widgets/image_widget/_array.py @@ -9,9 +9,8 @@ def __init__( self, data: NDArray, rgb: bool = False, - window_function: Callable = None, + process_function: Callable = None, window_size: dict[str, int] = None, - frame_function: Callable = None, n_display_dims: Literal[2, 3] = 2, dim_names: tuple[str] = None, ): @@ -22,7 +21,7 @@ def __init__( data: NDArray array-like data, must have 2 or more dimensions - window_function: Callable, optional + process_function: Callable, optional function to apply to a window of data around the current index. The callable must take an `axis` kwarg. @@ -32,17 +31,17 @@ def __init__( If a dim is not provided the window size is 0 for that dim, i.e. no window is taken along that dimension - frame_function - n_display_dims - dim_names + n_display_dims: int, 2 or 3, default 2 + number of display dimensions + + dim_names: tuple[str], optional + dimension names as a tuple of strings, ex: ("t", "z", "x", "y") """ self._data = data - self._window_size = window_function + self._window_size = process_function self._window_size = window_size - self._frame_function = frame_function - self._rgb = rgb # default dim names for mn, tmn, and tzmn, ignore rgb dim if present @@ -136,14 +135,6 @@ def window_size(self, size: dict): self._window_size = size - @property - def frame_function(self) -> Callable | None: - return self._frame_function - - @frame_function.setter - def frame_function(self, fa: Callable | None): - self._frame_function = fa - def _apply_window_function(self, index: dict[str, int]): if self.n_scrollable_dims == 0: # 2D image, return full data @@ -220,12 +211,11 @@ def get(self, index: dict[str, int]): window_output = self._apply_window_function(index) - if self.frame_function is not None: - frame_output = self.frame_function(window_output) - else: - frame_output = window_output - - if frame_output.ndim != self.n_display_dims: - raise ValueError + if window_output.ndim != self.n_display_dims: + raise ValueError( + f"Output of the `process_function` must match the number of display dims." + f"`process_function` returned an array with {window_output.ndim} dims, " + f"expected {self.n_display_dims} dims" + ) - return frame_output + return window_output