From ff1fcb51a5b00e244d88e37b89c057766434ba1a Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Tue, 28 Mar 2023 10:59:51 -0500 Subject: [PATCH] Fix infix expression `_value` and `_expr` usage Previously, `infixexpr._value` was sometimes `MatrixExpression` and sometimes `Matrix`. Scary! --- .pre-commit-config.yaml | 6 +++--- environment.yml | 1 + graphblas/core/expr.py | 31 ++++++++++++++++++++----------- graphblas/core/formatting.py | 1 + graphblas/core/infix.py | 16 ++++++++-------- graphblas/core/recorder.py | 3 ++- graphblas/tests/test_infix.py | 18 ++++++++++++++++++ pyproject.toml | 3 ++- 8 files changed, 55 insertions(+), 24 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ab097216e..05469a926 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: - id: mixed-line-ending - id: trailing-whitespace - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.12.1 + rev: v0.12.2 hooks: - id: validate-pyproject name: Validate pyproject.toml @@ -47,7 +47,7 @@ repos: - id: black - id: black-jupyter - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.257 + rev: v0.0.259 hooks: - id: ruff args: [--fix-only] @@ -75,7 +75,7 @@ repos: additional_dependencies: [tomli] files: ^(graphblas|docs)/ - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.257 + rev: v0.0.259 hooks: - id: ruff - repo: https://github.com/sphinx-contrib/sphinx-lint diff --git a/environment.yml b/environment.yml index 5ffd588da..41eb3c43d 100644 --- a/environment.yml +++ b/environment.yml @@ -99,6 +99,7 @@ dependencies: # - snakeviz # - sphinx-lint # - sympy + # - tuna # - twine # - vim # - yesqa diff --git a/graphblas/core/expr.py b/graphblas/core/expr.py index affe06112..48839bcff 100644 --- a/graphblas/core/expr.py +++ b/graphblas/core/expr.py @@ -478,33 +478,34 @@ def __bool__(self): class InfixExprBase: - __slots__ = "left", "right", "_value", "__weakref__" + __slots__ = "left", "right", "_expr", "__weakref__" _is_scalar = False def __init__(self, left, right): self.left = left self.right = right - self._value = None + self._expr = None def new(self, dtype=None, *, mask=None, name=None, **opts): if ( mask is None - and self._value is not None - and (dtype is None or self._value.dtype == dtype) + and self._expr is not None + and self._expr._value is not None + and (dtype is None or self._expr._value.dtype == dtype) ): - rv = self._value + rv = self._expr._value if name is not None: rv.name = name - self._value = None + self._expr._value = None return rv expr = self._to_expr() return expr.new(dtype, mask=mask, name=name, **opts) def _to_expr(self): - if self._value is None: + if self._expr is None: # Rely on the default operator for `x @ y` - self._value = getattr(self.left, self.method_name)(self.right) - return self._value + self._expr = getattr(self.left, self.method_name)(self.right) + return self._expr def _get_value(self, attr=None, default=None): expr = self._to_expr() @@ -536,10 +537,18 @@ def __repr__(self): @property def dtype(self): - if self._value is not None: - return self._value.dtype return self._to_expr().dtype + @property + def _value(self): + if self._expr is None: + return None + return self._expr._value + + @_value.setter + def _value(self, val): + self._to_expr()._value = val + # Mistakes utils._output_types[AmbiguousAssignOrExtract] = AmbiguousAssignOrExtract diff --git a/graphblas/core/formatting.py b/graphblas/core/formatting.py index 305df05ae..52b7ed4d0 100644 --- a/graphblas/core/formatting.py +++ b/graphblas/core/formatting.py @@ -1,3 +1,4 @@ +# This file imports pandas, so it should only be imported when formatting import numpy as np from .. import backend, config, monoid, unary diff --git a/graphblas/core/infix.py b/graphblas/core/infix.py index 1fc7caa95..bd1d10a92 100644 --- a/graphblas/core/infix.py +++ b/graphblas/core/infix.py @@ -16,11 +16,11 @@ def _ewise_add_to_expr(self): - if self._value is not None: - return self._value + if self._expr is not None: + return self._expr if self.left.dtype == BOOL and self.right.dtype == BOOL: - self._value = self.left.ewise_add(self.right, lor) - return self._value + self._expr = self.left.ewise_add(self.right, lor) + return self._expr raise TypeError( "Bad dtypes for `x | y`! Automatic computation of `x | y` infix expressions is only valid " f"for BOOL dtypes. The argument dtypes are {self.left.dtype} and {self.right.dtype}.\n\n" @@ -30,11 +30,11 @@ def _ewise_add_to_expr(self): def _ewise_mult_to_expr(self): - if self._value is not None: - return self._value + if self._expr is not None: + return self._expr if self.left.dtype == BOOL and self.right.dtype == BOOL: - self._value = self.left.ewise_mult(self.right, land) - return self._value + self._expr = self.left.ewise_mult(self.right, land) + return self._expr raise TypeError( "Bad dtypes for `x & y`! Automatic computation of `x & y` infix expressions is only valid " f"for BOOL dtypes. The argument dtypes are {self.left.dtype} and {self.right.dtype}.\n\n" diff --git a/graphblas/core/recorder.py b/graphblas/core/recorder.py index 455166544..ce79c85ff 100644 --- a/graphblas/core/recorder.py +++ b/graphblas/core/recorder.py @@ -3,7 +3,6 @@ from ..dtypes import DataType from . import base, lib from .base import _recorder -from .formatting import CSS_STYLE from .mask import Mask from .matrix import TransposedMatrix from .operator import TypedOpBase @@ -103,6 +102,8 @@ def is_recording(self): return self._token is not None and _recorder.get(base._prev_recorder) is self def _repr_base_(self): + from .formatting import CSS_STYLE + status = ( '