diff --git a/examples/desktop/scatter/scatter_size.py b/examples/desktop/scatter/scatter_size.py new file mode 100644 index 000000000..5b6987b7c --- /dev/null +++ b/examples/desktop/scatter/scatter_size.py @@ -0,0 +1,56 @@ +""" +Scatter Plot +============ +Example showing point size change for scatter plot. +""" + +# test_example = true +import numpy as np +import fastplotlib as fpl + +# grid with 2 rows and 3 columns +grid_shape = (2,1) + +# pan-zoom controllers for each view +# views are synced if they have the +# same controller ID +controllers = [ + [0], + [0] +] + + +# you can give string names for each subplot within the gridplot +names = [ + ["scalar_size"], + ["array_size"] +] + +# Create the grid plot +plot = fpl.GridPlot( + shape=grid_shape, + controllers=controllers, + names=names, + size=(1000, 1000) +) + +# get y_values using sin function +angles = np.arange(0, 20*np.pi+0.001, np.pi / 20) +y_values = 30*np.sin(angles) # 1 thousand points +x_values = np.array([x for x in range(len(y_values))], dtype=np.float32) + +data = np.column_stack([x_values, y_values]) + +plot["scalar_size"].add_scatter(data=data, sizes=5, colors="blue") # add a set of scalar sizes + +non_scalar_sizes = np.abs((y_values / np.pi)) # ensure minimum size of 5 +plot["array_size"].add_scatter(data=data, sizes=non_scalar_sizes, colors="red") + +for graph in plot: + graph.auto_scale(maintain_aspect=True) + +plot.show() + +if __name__ == "__main__": + print(__doc__) + fpl.run() \ No newline at end of file diff --git a/examples/desktop/screenshots/scatter_size.png b/examples/desktop/screenshots/scatter_size.png new file mode 100644 index 000000000..db637d270 --- /dev/null +++ b/examples/desktop/screenshots/scatter_size.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a4cefd4cf57e54e1ef7883edea54806dfde57939d0a395c5a7758124e41b8beb +size 63485 diff --git a/examples/notebooks/scatter_sizes_animation.ipynb b/examples/notebooks/scatter_sizes_animation.ipynb new file mode 100644 index 000000000..061f444d6 --- /dev/null +++ b/examples/notebooks/scatter_sizes_animation.ipynb @@ -0,0 +1,71 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from time import time\n", + "\n", + "import numpy as np\n", + "import fastplotlib as fpl\n", + "\n", + "plot = fpl.Plot()\n", + "\n", + "points = np.array([[-1,0,1],[-1,0,1]], dtype=np.float32).swapaxes(0,1)\n", + "size_delta_scales = np.array([10, 40, 100], dtype=np.float32)\n", + "min_sizes = 6\n", + "\n", + "def update_positions():\n", + " current_time = time()\n", + " newPositions = points + np.sin(((current_time / 4) % 1)*np.pi)\n", + " plot.graphics[0].data = newPositions\n", + " plot.camera.width = 4*np.max(newPositions[0,:])\n", + " plot.camera.height = 4*np.max(newPositions[1,:])\n", + "\n", + "def update_sizes():\n", + " current_time = time()\n", + " sin_sample = np.sin(((current_time / 4) % 1)*np.pi)\n", + " size_delta = sin_sample*size_delta_scales\n", + " plot.graphics[0].sizes = min_sizes + size_delta\n", + "\n", + "points = np.array([[0,0], \n", + " [1,1], \n", + " [2,2]])\n", + "scatter = plot.add_scatter(points, colors=[\"red\", \"green\", \"blue\"], sizes=12)\n", + "plot.add_animations(update_positions, update_sizes)\n", + "plot.show(autoscale=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fastplotlib-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/notebooks/scatter_sizes_grid.ipynb b/examples/notebooks/scatter_sizes_grid.ipynb new file mode 100644 index 000000000..ff64184f7 --- /dev/null +++ b/examples/notebooks/scatter_sizes_grid.ipynb @@ -0,0 +1,86 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "Scatter Plot\n", + "============\n", + "Example showing point size change for scatter plot.\n", + "\"\"\"\n", + "\n", + "# test_example = true\n", + "import numpy as np\n", + "import fastplotlib as fpl\n", + "\n", + "# grid with 2 rows and 3 columns\n", + "grid_shape = (2,1)\n", + "\n", + "# pan-zoom controllers for each view\n", + "# views are synced if they have the \n", + "# same controller ID\n", + "controllers = [\n", + " [0],\n", + " [0]\n", + "]\n", + "\n", + "\n", + "# you can give string names for each subplot within the gridplot\n", + "names = [\n", + " [\"scalar_size\"],\n", + " [\"array_size\"]\n", + "]\n", + "\n", + "# Create the grid plot\n", + "plot = fpl.GridPlot(\n", + " shape=grid_shape,\n", + " controllers=controllers,\n", + " names=names,\n", + " size=(1000, 1000)\n", + ")\n", + "\n", + "# get y_values using sin function\n", + "angles = np.arange(0, 20*np.pi+0.001, np.pi / 20)\n", + "y_values = 30*np.sin(angles) # 1 thousand points\n", + "x_values = np.array([x for x in range(len(y_values))], dtype=np.float32)\n", + "\n", + "data = np.column_stack([x_values, y_values])\n", + "\n", + "plot[\"scalar_size\"].add_scatter(data=data, sizes=5, colors=\"blue\") # add a set of scalar sizes\n", + "\n", + "non_scalar_sizes = np.abs((y_values / np.pi)) # ensure minimum size of 5\n", + "plot[\"array_size\"].add_scatter(data=data, sizes=non_scalar_sizes, colors=\"red\")\n", + "\n", + "for graph in plot:\n", + " graph.auto_scale(maintain_aspect=True)\n", + "\n", + "plot.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fastplotlib-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/fastplotlib/graphics/_features/__init__.py b/fastplotlib/graphics/_features/__init__.py index 8e78a6260..a6ce9c3a3 100644 --- a/fastplotlib/graphics/_features/__init__.py +++ b/fastplotlib/graphics/_features/__init__.py @@ -1,5 +1,6 @@ from ._colors import ColorFeature, CmapFeature, ImageCmapFeature, HeatmapCmapFeature from ._data import PointsDataFeature, ImageDataFeature, HeatmapDataFeature +from ._sizes import PointsSizesFeature from ._present import PresentFeature from ._thickness import ThicknessFeature from ._base import GraphicFeature, GraphicFeatureIndexable, FeatureEvent, to_gpu_supported_dtype @@ -11,6 +12,7 @@ "ImageCmapFeature", "HeatmapCmapFeature", "PointsDataFeature", + "PointsSizesFeature", "ImageDataFeature", "HeatmapDataFeature", "PresentFeature", diff --git a/fastplotlib/graphics/_features/_sizes.py b/fastplotlib/graphics/_features/_sizes.py index e69de29bb..377052918 100644 --- a/fastplotlib/graphics/_features/_sizes.py +++ b/fastplotlib/graphics/_features/_sizes.py @@ -0,0 +1,108 @@ +from typing import Any + +import numpy as np + +import pygfx + +from ._base import ( + GraphicFeatureIndexable, + cleanup_slice, + FeatureEvent, + to_gpu_supported_dtype, + cleanup_array_slice, +) + + +class PointsSizesFeature(GraphicFeatureIndexable): + """ + Access to the vertex buffer data shown in the graphic. + Supports fancy indexing if the data array also supports it. + """ + + def __init__(self, parent, sizes: Any, collection_index: int = None): + sizes = self._fix_sizes(sizes, parent) + super(PointsSizesFeature, self).__init__( + parent, sizes, collection_index=collection_index + ) + + @property + def buffer(self) -> pygfx.Buffer: + return self._parent.world_object.geometry.sizes + + def __getitem__(self, item): + return self.buffer.data[item] + + def _fix_sizes(self, sizes, parent): + graphic_type = parent.__class__.__name__ + + n_datapoints = parent.data().shape[0] + if not isinstance(sizes, (list, tuple, np.ndarray)): + sizes = np.full(n_datapoints, sizes, dtype=np.float32) # force it into a float to avoid weird gpu errors + elif not isinstance(sizes, np.ndarray): # if it's not a ndarray already, make it one + sizes = np.array(sizes, dtype=np.float32) # read it in as a numpy.float32 + if (sizes.ndim != 1) or (sizes.size != parent.data().shape[0]): + raise ValueError( + f"sequence of `sizes` must be 1 dimensional with " + f"the same length as the number of datapoints" + ) + + sizes = to_gpu_supported_dtype(sizes) + + if any(s < 0 for s in sizes): + raise ValueError("All sizes must be positive numbers greater than or equal to 0.0.") + + if sizes.ndim == 1: + if graphic_type == "ScatterGraphic": + sizes = np.array(sizes) + else: + raise ValueError(f"Sizes must be an array of shape (n,) where n == the number of data points provided.\ + Received shape={sizes.shape}.") + + return np.array(sizes) + + def __setitem__(self, key, value): + if isinstance(key, np.ndarray): + # make sure 1D array of int or boolean + key = cleanup_array_slice(key, self._upper_bound) + + # put sizes into right shape if they're only indexing datapoints + if isinstance(key, (slice, int, np.ndarray, np.integer)): + value = self._fix_sizes(value, self._parent) + # otherwise assume that they have the right shape + # numpy will throw errors if it can't broadcast + + if value.size != self.buffer.data[key].size: + raise ValueError(f"{value.size} is not equal to buffer size {self.buffer.data[key].size}.\ + If you want to set size to a non-scalar value, make sure it's the right length!") + + self.buffer.data[key] = value + self._update_range(key) + # avoid creating dicts constantly if there are no events to handle + if len(self._event_handlers) > 0: + self._feature_changed(key, value) + + def _update_range(self, key): + self._update_range_indices(key) + + def _feature_changed(self, key, new_data): + if key is not None: + key = cleanup_slice(key, self._upper_bound) + if isinstance(key, (int, np.integer)): + indices = [key] + elif isinstance(key, slice): + indices = range(key.start, key.stop, key.step) + elif isinstance(key, np.ndarray): + indices = key + elif key is None: + indices = None + + pick_info = { + "index": indices, + "collection-index": self._collection_index, + "world_object": self._parent.world_object, + "new_data": new_data, + } + + event_data = FeatureEvent(type="sizes", pick_info=pick_info) + + self._call_event_handlers(event_data) \ No newline at end of file diff --git a/fastplotlib/graphics/scatter.py b/fastplotlib/graphics/scatter.py index 9e162c57a..141db2af3 100644 --- a/fastplotlib/graphics/scatter.py +++ b/fastplotlib/graphics/scatter.py @@ -5,16 +5,16 @@ from ..utils import parse_cmap_values from ._base import Graphic -from ._features import PointsDataFeature, ColorFeature, CmapFeature +from ._features import PointsDataFeature, ColorFeature, CmapFeature, PointsSizesFeature class ScatterGraphic(Graphic): - feature_events = ("data", "colors", "cmap", "present") + feature_events = ("data", "sizes", "colors", "cmap", "present") def __init__( self, data: np.ndarray, - sizes: Union[int, np.ndarray, list] = 1, + sizes: Union[int, float, np.ndarray, list] = 1, colors: np.ndarray = "w", alpha: float = 1.0, cmap: str = None, @@ -86,24 +86,11 @@ def __init__( self, self.colors(), cmap_name=cmap, cmap_values=cmap_values ) - if isinstance(sizes, int): - sizes = np.full(self.data().shape[0], sizes, dtype=np.float32) - elif isinstance(sizes, np.ndarray): - if (sizes.ndim != 1) or (sizes.size != self.data().shape[0]): - raise ValueError( - f"numpy array of `sizes` must be 1 dimensional with " - f"the same length as the number of datapoints" - ) - elif isinstance(sizes, list): - if len(sizes) != self.data().shape[0]: - raise ValueError( - "list of `sizes` must have the same length as the number of datapoints" - ) - + self.sizes = PointsSizesFeature(self, sizes) super(ScatterGraphic, self).__init__(*args, **kwargs) world_object = pygfx.Points( - pygfx.Geometry(positions=self.data(), sizes=sizes, colors=self.colors()), + pygfx.Geometry(positions=self.data(), sizes=self.sizes(), colors=self.colors()), material=pygfx.PointsMaterial(vertex_colors=True, vertex_sizes=True), )