diff --git a/graphblas/core/base.py b/graphblas/core/base.py index 42a4de9a1..5658e99c1 100644 --- a/graphblas/core/base.py +++ b/graphblas/core/base.py @@ -263,23 +263,31 @@ def __call__( ) def __or__(self, other): - from .infix import _ewise_infix_expr + from .infix import _ewise_infix_expr, _ewise_mult_expr_types + if isinstance(other, _ewise_mult_expr_types): + raise TypeError("XXX") return _ewise_infix_expr(self, other, method="ewise_add", within="__or__") def __ror__(self, other): - from .infix import _ewise_infix_expr + from .infix import _ewise_infix_expr, _ewise_mult_expr_types + if isinstance(other, _ewise_mult_expr_types): + raise TypeError("XXX") return _ewise_infix_expr(other, self, method="ewise_add", within="__ror__") def __and__(self, other): - from .infix import _ewise_infix_expr + from .infix import _ewise_add_expr_types, _ewise_infix_expr + if isinstance(other, _ewise_add_expr_types): + raise TypeError("XXX") return _ewise_infix_expr(self, other, method="ewise_mult", within="__and__") def __rand__(self, other): - from .infix import _ewise_infix_expr + from .infix import _ewise_add_expr_types, _ewise_infix_expr + if isinstance(other, _ewise_add_expr_types): + raise TypeError("XXX") return _ewise_infix_expr(other, self, method="ewise_mult", within="__rand__") def __matmul__(self, other): diff --git a/graphblas/core/infix.py b/graphblas/core/infix.py index 09b6a6811..51714633c 100644 --- a/graphblas/core/infix.py +++ b/graphblas/core/infix.py @@ -126,6 +126,19 @@ class ScalarEwiseAddExpr(ScalarInfixExpr): _to_expr = _ewise_add_to_expr + # Allow e.g. `plus(x | y | z)` + __or__ = Scalar.__or__ + __ror__ = Scalar.__ror__ + _ewise_add = Scalar._ewise_add + _ewise_union = Scalar._ewise_union + + # Don't allow e.g. `plus(x | y & z)` + def __and__(self, other): + raise TypeError("XXX") + + def __rand__(self, other): + raise TypeError("XXX") + class ScalarEwiseMultExpr(ScalarInfixExpr): __slots__ = () @@ -135,6 +148,18 @@ class ScalarEwiseMultExpr(ScalarInfixExpr): _to_expr = _ewise_mult_to_expr + # Allow e.g. `plus(x & y & z)` + __and__ = Scalar.__and__ + __rand__ = Scalar.__rand__ + _ewise_mult = Scalar._ewise_mult + + # Don't allow e.g. `plus(x | y & z)` + def __or__(self, other): + raise TypeError("XXX") + + def __ror__(self, other): + raise TypeError("XXX") + class ScalarMatMulExpr(ScalarInfixExpr): __slots__ = () @@ -239,6 +264,15 @@ class VectorEwiseAddExpr(VectorInfixExpr): _to_expr = _ewise_add_to_expr + # Allow e.g. `plus(x | y | z)` + __or__ = Vector.__or__ + __ror__ = Vector.__ror__ + _ewise_add = Vector._ewise_add + _ewise_union = Vector._ewise_union + # Don't allow e.g. `plus(x | y & z)` + __and__ = ScalarEwiseAddExpr.__and__ # raises + __rand__ = ScalarEwiseAddExpr.__rand__ # raises + class VectorEwiseMultExpr(VectorInfixExpr): __slots__ = () @@ -248,6 +282,14 @@ class VectorEwiseMultExpr(VectorInfixExpr): _to_expr = _ewise_mult_to_expr + # Allow e.g. `plus(x & y & z)` + __and__ = Vector.__and__ + __rand__ = Vector.__rand__ + _ewise_mult = Vector._ewise_mult + # Don't allow e.g. `plus(x | y & z)` + __or__ = ScalarEwiseMultExpr.__or__ # raises + __ror__ = ScalarEwiseMultExpr.__ror__ # raises + class VectorMatMulExpr(VectorInfixExpr): __slots__ = "method_name" @@ -259,6 +301,11 @@ def __init__(self, left, right, *, method_name, size): self.method_name = method_name self._size = size + __matmul__ = Vector.__matmul__ + __rmatmul__ = Vector.__rmatmul__ + _inner = Vector._inner + _vxm = Vector._vxm + utils._output_types[VectorEwiseAddExpr] = Vector utils._output_types[VectorEwiseMultExpr] = Vector @@ -376,6 +423,15 @@ class MatrixEwiseAddExpr(MatrixInfixExpr): _to_expr = _ewise_add_to_expr + # Allow e.g. `plus(x | y | z)` + __or__ = Matrix.__or__ + __ror__ = Matrix.__ror__ + _ewise_add = Matrix._ewise_add + _ewise_union = Matrix._ewise_union + # Don't allow e.g. `plus(x | y & z)` + __and__ = VectorEwiseAddExpr.__and__ # raises + __rand__ = VectorEwiseAddExpr.__rand__ # raises + class MatrixEwiseMultExpr(MatrixInfixExpr): __slots__ = () @@ -385,6 +441,14 @@ class MatrixEwiseMultExpr(MatrixInfixExpr): _to_expr = _ewise_mult_to_expr + # Allow e.g. `plus(x & y & z)` + __and__ = Matrix.__and__ + __rand__ = Matrix.__rand__ + _ewise_mult = Matrix._ewise_mult + # Don't allow e.g. `plus(x | y & z)` + __or__ = VectorEwiseMultExpr.__or__ # raises + __ror__ = VectorEwiseMultExpr.__ror__ # raises + class MatrixMatMulExpr(MatrixInfixExpr): __slots__ = () @@ -397,6 +461,11 @@ def __init__(self, left, right, *, nrows, ncols): self._nrows = nrows self._ncols = ncols + __matmul__ = Matrix.__matmul__ + __rmatmul__ = Matrix.__rmatmul__ + _mxm = Matrix._mxm + _mxv = Matrix._mxv + utils._output_types[MatrixEwiseAddExpr] = Matrix utils._output_types[MatrixEwiseMultExpr] = Matrix @@ -514,5 +583,8 @@ def _matmul_infix_expr(left, right, *, within): ) +_ewise_add_expr_types = (MatrixEwiseAddExpr, VectorEwiseAddExpr, ScalarEwiseAddExpr) +_ewise_mult_expr_types = (MatrixEwiseMultExpr, VectorEwiseMultExpr, ScalarEwiseMultExpr) + # Import infixmethods, which has side effects from . import infixmethods # noqa: E402, F401 isort:skip diff --git a/graphblas/core/matrix.py b/graphblas/core/matrix.py index 5e1a76720..34789d68d 100644 --- a/graphblas/core/matrix.py +++ b/graphblas/core/matrix.py @@ -10,9 +10,16 @@ from . import _supports_udfs, automethods, ffi, lib, utils from .base import BaseExpression, BaseType, _check_mask, call from .descriptor import lookup as descriptor_lookup -from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater +from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, InfixExprBase, Updater from .mask import Mask, StructuralMask, ValueMask -from .operator import UNKNOWN_OPCLASS, find_opclass, get_semiring, get_typed_op, op_from_string +from .operator import ( + UNKNOWN_OPCLASS, + _get_typed_op_from_exprs, + find_opclass, + get_semiring, + get_typed_op, + op_from_string, +) from .scalar import ( _COMPLETE, _MATERIALIZE, @@ -1938,17 +1945,39 @@ def ewise_add(self, other, op=monoid.plus): # Functional syntax C << monoid.max(A | B) """ + return self._ewise_add(other, op) + + def _ewise_add(self, other, op=monoid.plus, is_infix=False): method_name = "ewise_add" - other = self._expect_type( - other, - (Matrix, TransposedMatrix, Vector), - within=method_name, - argname="other", - op=op, - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="binary") - # Per the spec, op may be a semiring, but this is weird, so don't. - self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if is_infix: + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector, MatrixEwiseAddExpr, VectorEwiseAddExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if isinstance(self, MatrixEwiseAddExpr): + self = op(self).new() + if isinstance(other, InfixExprBase): + other = op(other).new() + else: + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector), + within=method_name, + argname="other", + op=op, + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if other.ndim == 1: # Broadcast rowwise from the right if self._ncols != other._size: @@ -2006,13 +2035,39 @@ def ewise_mult(self, other, op=binary.times): # Functional syntax C << binary.gt(A & B) """ + return self._ewise_mult(other, op) + + def _ewise_mult(self, other, op=binary.times, is_infix=False): method_name = "ewise_mult" - other = self._expect_type( - other, (Matrix, TransposedMatrix, Vector), within=method_name, argname="other", op=op - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="binary") - # Per the spec, op may be a semiring, but this is weird, so don't. - self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if is_infix: + from .infix import MatrixEwiseMultExpr, VectorEwiseMultExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector, MatrixEwiseMultExpr, VectorEwiseMultExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if isinstance(self, MatrixEwiseMultExpr): + self = op(self).new() + if isinstance(other, InfixExprBase): + other = op(other).new() + else: + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector), + within=method_name, + argname="other", + op=op, + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if other.ndim == 1: # Broadcast rowwise from the right if self._ncols != other._size: @@ -2074,11 +2129,30 @@ def ewise_union(self, other, op, left_default, right_default): # Functional syntax C << binary.div(A | B, left_default=1, right_default=1) """ + return self._ewise_union(other, op, left_default, right_default) + + def _ewise_union(self, other, op, left_default, right_default, is_infix=False): method_name = "ewise_union" - other = self._expect_type( - other, (Matrix, TransposedMatrix, Vector), within=method_name, argname="other", op=op - ) - temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + if is_infix: + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector, MatrixEwiseAddExpr, VectorEwiseAddExpr), + within=method_name, + argname="other", + op=op, + ) + temp_op = _get_typed_op_from_exprs(op, self, other, kind="binary") + else: + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector), + within=method_name, + argname="other", + op=op, + ) + temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary") left_dtype = temp_op.type dtype = left_dtype if left_dtype._is_udt else None @@ -2117,8 +2191,12 @@ def ewise_union(self, other, op, left_default, right_default): else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary") - op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary") + if is_infix: + op1 = _get_typed_op_from_exprs(op, self, right, kind="binary") + op2 = _get_typed_op_from_exprs(op, left, other, kind="binary") + else: + op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary") + op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary") if op1 is not op2: left_dtype = unify(op1.type, op2.type, is_right_scalar=True) right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) @@ -2129,6 +2207,12 @@ def ewise_union(self, other, op, left_default, right_default): if op.opclass == "Monoid": op = op.binaryop + if is_infix: + if isinstance(self, MatrixEwiseAddExpr): + self = op(self, left_default=left, right_default=right).new() + if isinstance(other, InfixExprBase): + other = op(other, left_default=left, right_default=right).new() + expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if other.ndim == 1: # Broadcast rowwise from the right @@ -2198,10 +2282,27 @@ def mxv(self, other, op=semiring.plus_times): # Functional syntax C << semiring.min_plus(A @ v) """ + return self._mxv(other, op) + + def _mxv(self, other, op=semiring.plus_times, is_infix=False): method_name = "mxv" - other = self._expect_type(other, Vector, within=method_name, argname="other", op=op) - op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") - self._expect_op(op, "Semiring", within=method_name, argname="op") + if is_infix: + from .infix import MatrixMatMulExpr, VectorMatMulExpr + + other = self._expect_type( + other, (Vector, VectorMatMulExpr), within=method_name, argname="other", op=op + ) + op = _get_typed_op_from_exprs(op, self, other, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + if isinstance(self, MatrixMatMulExpr): + self = op(self).new() + if isinstance(other, VectorMatMulExpr): + other = op(other).new() + else: + other = self._expect_type(other, Vector, within=method_name, argname="other", op=op) + op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + expr = VectorExpression( method_name, "GrB_mxv", @@ -2241,12 +2342,33 @@ def mxm(self, other, op=semiring.plus_times): # Functional syntax C << semiring.min_plus(A @ B) """ + return self._mxm(other, op) + + def _mxm(self, other, op=semiring.plus_times, is_infix=False): method_name = "mxm" - other = self._expect_type( - other, (Matrix, TransposedMatrix), within=method_name, argname="other", op=op - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") - self._expect_op(op, "Semiring", within=method_name, argname="op") + if is_infix: + from .infix import MatrixMatMulExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, MatrixMatMulExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + if isinstance(self, MatrixMatMulExpr): + self = op(self).new() + if isinstance(other, MatrixMatMulExpr): + other = op(other).new() + else: + other = self._expect_type( + other, (Matrix, TransposedMatrix), within=method_name, argname="other", op=op + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + expr = MatrixExpression( method_name, "GrB_mxm", @@ -3862,6 +3984,12 @@ def to_dicts(self, order="rowwise"): reposition = Matrix.reposition power = Matrix.power + _ewise_add = Matrix._ewise_add + _ewise_mult = Matrix._ewise_mult + _ewise_union = Matrix._ewise_union + _mxv = Matrix._mxv + _mxm = Matrix._mxm + # Operator sugar __or__ = Matrix.__or__ __ror__ = Matrix.__ror__ diff --git a/graphblas/core/operator/__init__.py b/graphblas/core/operator/__init__.py index 509e84a04..d59c835b3 100644 --- a/graphblas/core/operator/__init__.py +++ b/graphblas/core/operator/__init__.py @@ -6,6 +6,7 @@ from .semiring import ParameterizedSemiring, Semiring from .unary import ParameterizedUnaryOp, UnaryOp from .utils import ( + _get_typed_op_from_exprs, aggregator_from_string, binary_from_string, get_semiring, diff --git a/graphblas/core/operator/base.py b/graphblas/core/operator/base.py index d66aa2f4a..59482b47d 100644 --- a/graphblas/core/operator/base.py +++ b/graphblas/core/operator/base.py @@ -111,7 +111,9 @@ def _call_op(op, left, right=None, thunk=None, **kwargs): if right is None and thunk is None: if isinstance(left, InfixExprBase): # op(A & B), op(A | B), op(A @ B) - return getattr(left.left, left.method_name)(left.right, op, **kwargs) + return getattr(left.left, f"_{left.method_name}")( + left.right, op, is_infix=True, **kwargs + ) if find_opclass(op)[1] == "Semiring": raise TypeError( f"Bad type when calling {op!r}. Got type: {type(left)}.\n" diff --git a/graphblas/core/operator/binary.py b/graphblas/core/operator/binary.py index 676ed0970..278ee3183 100644 --- a/graphblas/core/operator/binary.py +++ b/graphblas/core/operator/binary.py @@ -94,7 +94,9 @@ def __call__(self, left, right=None, *, left_default=None, right_default=None): f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y " "are Vectors or Matrices, and left_default and right_default are scalars." ) - return left.left.ewise_union(left.right, self, left_default, right_default) + return left.left._ewise_union( + left.right, self, left_default, right_default, is_infix=True + ) return _call_op(self, left, right) @property diff --git a/graphblas/core/operator/utils.py b/graphblas/core/operator/utils.py index 00df31db8..cd0b82d3c 100644 --- a/graphblas/core/operator/utils.py +++ b/graphblas/core/operator/utils.py @@ -2,6 +2,7 @@ from ... import backend, binary, config, indexunary, monoid, op, select, semiring, unary from ...dtypes import UINT64, lookup_dtype, unify +from ..expr import InfixExprBase from .base import ( _SS_OPERATORS, OpBase, @@ -132,6 +133,30 @@ def get_typed_op(op, dtype, dtype2=None, *, is_left_scalar=False, is_right_scala raise TypeError(f"Unable to get typed operator from object with type {type(op)}") +def _get_typed_op_from_exprs(op, left, right, *, kind=None): + if isinstance(left, InfixExprBase): + left_op = _get_typed_op_from_exprs(op, left.left, left.right, kind=kind) + left_dtype = left_op.type + else: + left_op = None + left_dtype = left.dtype + if isinstance(right, InfixExprBase): + right_op = _get_typed_op_from_exprs(op, right.left, right.right, kind=kind) + if right_op is left_op: + return right_op + right_dtype = right_op.type2 + else: + right_dtype = right.dtype + return get_typed_op( + op, + left_dtype, + right_dtype, + is_left_scalar=left._is_scalar, + is_right_scalar=right._is_scalar, + kind=kind, + ) + + def get_semiring(monoid, binaryop, name=None): """Get or create a Semiring object from a monoid and binaryop. diff --git a/graphblas/core/scalar.py b/graphblas/core/scalar.py index b822bd58a..9cdf3043e 100644 --- a/graphblas/core/scalar.py +++ b/graphblas/core/scalar.py @@ -629,7 +629,23 @@ def ewise_add(self, other, op=monoid.plus): # Functional syntax c << monoid.max(a | b) """ + return self._ewise_add(other, op) + + def _ewise_add(self, other, op=monoid.plus, is_infix=False): method_name = "ewise_add" + if is_infix: + from .infix import ScalarEwiseAddExpr + + # This is a little different than how we handle ewise_add for Vector and + # Matrix where we are super-careful to handle dtypes well to support UDTs. + # For Scalar, we're going to let dtypes in expressions resolve themselves. + # Scalars are more challenging, because they may be literal scalars. + # Also, we have not yet resolved `op` here, so errors may be different. + if isinstance(self, ScalarEwiseAddExpr): + self = op(self).new() + if isinstance(other, ScalarEwiseAddExpr): + other = op(other).new() + if type(other) is not Scalar: dtype = self.dtype if self.dtype._is_udt else None try: @@ -683,7 +699,23 @@ def ewise_mult(self, other, op=binary.times): # Functional syntax c << binary.gt(a & b) """ + return self._ewise_mult(other, op) + + def _ewise_mult(self, other, op=binary.times, is_infix=False): method_name = "ewise_mult" + if is_infix: + from .infix import ScalarEwiseMultExpr + + # This is a little different than how we handle ewise_mult for Vector and + # Matrix where we are super-careful to handle dtypes well to support UDTs. + # For Scalar, we're going to let dtypes in expressions resolve themselves. + # Scalars are more challenging, because they may be literal scalars. + # Also, we have not yet resolved `op` here, so errors may be different. + if isinstance(self, ScalarEwiseMultExpr): + self = op(self).new() + if isinstance(other, ScalarEwiseMultExpr): + other = op(other).new() + if type(other) is not Scalar: dtype = self.dtype if self.dtype._is_udt else None try: @@ -741,7 +773,23 @@ def ewise_union(self, other, op, left_default, right_default): # Functional syntax c << binary.div(a | b, left_default=1, right_default=1) """ + return self._ewise_union(other, op, left_default, right_default) + + def _ewise_union(self, other, op, left_default, right_default, is_infix=False): method_name = "ewise_union" + if is_infix: + from .infix import ScalarEwiseAddExpr + + # This is a little different than how we handle ewise_union for Vector and + # Matrix where we are super-careful to handle dtypes well to support UDTs. + # For Scalar, we're going to let dtypes in expressions resolve themselves. + # Scalars are more challenging, because they may be literal scalars. + # Also, we have not yet resolved `op` here, so errors may be different. + if isinstance(self, ScalarEwiseAddExpr): + self = op(self, left_default=left_default, right_default=right_default).new() + if isinstance(other, ScalarEwiseAddExpr): + other = op(other, left_default=left_default, right_default=right_default).new() + right_dtype = self.dtype dtype = right_dtype if right_dtype._is_udt else None if type(other) is not Scalar: diff --git a/graphblas/core/vector.py b/graphblas/core/vector.py index 9d19d80da..feb95ed02 100644 --- a/graphblas/core/vector.py +++ b/graphblas/core/vector.py @@ -9,9 +9,16 @@ from . import _supports_udfs, automethods, ffi, lib, utils from .base import BaseExpression, BaseType, _check_mask, call from .descriptor import lookup as descriptor_lookup -from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater +from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, InfixExprBase, Updater from .mask import Mask, StructuralMask, ValueMask -from .operator import UNKNOWN_OPCLASS, find_opclass, get_semiring, get_typed_op, op_from_string +from .operator import ( + UNKNOWN_OPCLASS, + _get_typed_op_from_exprs, + find_opclass, + get_semiring, + get_typed_op, + op_from_string, +) from .scalar import ( _COMPLETE, _MATERIALIZE, @@ -1038,15 +1045,41 @@ def ewise_add(self, other, op=monoid.plus): # Functional syntax w << monoid.max(u | v) """ + return self._ewise_add(other, op) + + def _ewise_add(self, other, op=monoid.plus, is_infix=False): from .matrix import Matrix, MatrixExpression, TransposedMatrix method_name = "ewise_add" - other = self._expect_type( - other, (Vector, Matrix, TransposedMatrix), within=method_name, argname="other", op=op - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="binary") - # Per the spec, op may be a semiring, but this is weird, so don't. - self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if is_infix: + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr + + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix, MatrixEwiseAddExpr, VectorEwiseAddExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if isinstance(self, VectorEwiseAddExpr): + self = op(self).new() + if isinstance(other, InfixExprBase): + other = op(other).new() + else: + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix), + within=method_name, + argname="other", + op=op, + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if other.ndim == 2: # Broadcast columnwise from the left if other._nrows != self._size: @@ -1103,15 +1136,40 @@ def ewise_mult(self, other, op=binary.times): # Functional syntax w << binary.gt(u & v) """ + return self._ewise_mult(other, op) + + def _ewise_mult(self, other, op=binary.times, is_infix=False): from .matrix import Matrix, MatrixExpression, TransposedMatrix method_name = "ewise_mult" - other = self._expect_type( - other, (Vector, Matrix, TransposedMatrix), within=method_name, argname="other", op=op - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="binary") - # Per the spec, op may be a semiring, but this is weird, so don't. - self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if is_infix: + from .infix import MatrixEwiseMultExpr, VectorEwiseMultExpr + + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix, MatrixEwiseMultExpr, VectorEwiseMultExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if isinstance(self, VectorEwiseMultExpr): + self = op(self).new() + if isinstance(other, InfixExprBase): + other = op(other).new() + else: + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix), + within=method_name, + argname="other", + op=op, + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if other.ndim == 2: # Broadcast columnwise from the left if other._nrows != self._size: @@ -1171,13 +1229,32 @@ def ewise_union(self, other, op, left_default, right_default): # Functional syntax w << binary.div(u | v, left_default=1, right_default=1) """ + return self._ewise_union(other, op, left_default, right_default) + + def _ewise_union(self, other, op, left_default, right_default, is_infix=False): from .matrix import Matrix, MatrixExpression, TransposedMatrix method_name = "ewise_union" - other = self._expect_type( - other, (Vector, Matrix, TransposedMatrix), within=method_name, argname="other", op=op - ) - temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + if is_infix: + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr + + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix, MatrixEwiseAddExpr, VectorEwiseAddExpr), + within=method_name, + argname="other", + op=op, + ) + temp_op = _get_typed_op_from_exprs(op, self, other, kind="binary") + else: + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix), + within=method_name, + argname="other", + op=op, + ) + temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary") left_dtype = temp_op.type dtype = left_dtype if left_dtype._is_udt else None @@ -1216,8 +1293,12 @@ def ewise_union(self, other, op, left_default, right_default): else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary") - op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary") + if is_infix: + op1 = _get_typed_op_from_exprs(op, self, right, kind="binary") + op2 = _get_typed_op_from_exprs(op, left, other, kind="binary") + else: + op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary") + op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary") if op1 is not op2: left_dtype = unify(op1.type, op2.type, is_right_scalar=True) right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) @@ -1228,6 +1309,12 @@ def ewise_union(self, other, op, left_default, right_default): if op.opclass == "Monoid": op = op.binaryop + if is_infix: + if isinstance(self, VectorEwiseAddExpr): + self = op(self, left_default=left, right_default=right).new() + if isinstance(other, InfixExprBase): + other = op(other, left_default=left, right_default=right).new() + expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if other.ndim == 2: # Broadcast columnwise from the left @@ -1296,14 +1383,35 @@ def vxm(self, other, op=semiring.plus_times): # Functional syntax C << semiring.min_plus(v @ A) """ + return self._vxm(other, op) + + def _vxm(self, other, op=semiring.plus_times, is_infix=False): from .matrix import Matrix, TransposedMatrix method_name = "vxm" - other = self._expect_type( - other, (Matrix, TransposedMatrix), within=method_name, argname="other", op=op - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") - self._expect_op(op, "Semiring", within=method_name, argname="op") + if is_infix: + from .infix import MatrixMatMulExpr, VectorMatMulExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, MatrixMatMulExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + if isinstance(self, VectorMatMulExpr): + self = op(self).new() + if isinstance(other, MatrixMatMulExpr): + other = op(other).new() + else: + other = self._expect_type( + other, (Matrix, TransposedMatrix), within=method_name, argname="other", op=op + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + expr = VectorExpression( method_name, "GrB_vxm", @@ -1645,10 +1753,27 @@ def inner(self, other, op=semiring.plus_times): `Matrix Multiplication <../user_guide/operations.html#matrix-multiply>`__ family of functions. """ + return self._inner(other, op) + + def _inner(self, other, op=semiring.plus_times, is_infix=False): method_name = "inner" - other = self._expect_type(other, Vector, within=method_name, argname="other", op=op) - op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") - self._expect_op(op, "Semiring", within=method_name, argname="op") + if is_infix: + from .infix import VectorMatMulExpr + + other = self._expect_type( + other, (Vector, VectorMatMulExpr), within=method_name, argname="other", op=op + ) + op = _get_typed_op_from_exprs(op, self, other, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + if isinstance(self, VectorMatMulExpr): + self = op(self).new() + if isinstance(other, VectorMatMulExpr): + other = op(other).new() + else: + other = self._expect_type(other, Vector, within=method_name, argname="other", op=op) + op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + expr = ScalarExpression( method_name, "GrB_vxm", diff --git a/graphblas/tests/test_infix.py b/graphblas/tests/test_infix.py index 72e1c8a42..e688086b9 100644 --- a/graphblas/tests/test_infix.py +++ b/graphblas/tests/test_infix.py @@ -1,6 +1,6 @@ import pytest -from graphblas import monoid, op +from graphblas import binary, monoid, op from graphblas.exceptions import DimensionMismatch from .conftest import autocompute @@ -367,3 +367,415 @@ def test_infix_expr_value_types(): expr._value = None assert expr._value is None assert expr._expr._value is None + + +def test_multi_infix_vector(): + D0 = Vector.from_scalar(0, 3).diag() + v1 = Vector.from_coo([0, 1], [1, 2], size=3) # 1 2 . + v2 = Vector.from_coo([1, 2], [1, 2], size=3) # . 1 2 + v3 = Vector.from_coo([2, 0], [1, 2], size=3) # 2 . 1 + # ewise_add + result = binary.plus((v1 | v2) | v3).new() + expected = Vector.from_scalar(3, size=3) + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | v3)).new() + assert result.isequal(expected) + result = monoid.min(v1 | v2 | v3).new() + expected = Vector.from_scalar(1, size=3) + assert result.isequal(expected) + # ewise_mult + result = monoid.max((v1 & v2) & v3).new() + expected = Vector(int, size=3) + assert result.isequal(expected) + result = monoid.max(v1 & (v2 & v3)).new() + assert result.isequal(expected) + result = monoid.min((v1 & v2) & v1).new() + expected = Vector.from_coo([1], [1], size=3) + assert result.isequal(expected) + # ewise_union + result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10).new() + expected = Vector.from_scalar(13, size=3) + assert result.isequal(expected) + result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10.0).new() + expected = Vector.from_scalar(13.0, size=3) + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | v3), left_default=10, right_default=10).new() + assert result.isequal(expected) + # inner + assert op.plus_plus(v1 @ v1).new().value == 6 + assert op.plus_plus(v1 @ (v1 @ D0)).new().value == 6 + assert op.plus_plus((D0 @ v1) @ v1).new().value == 6 + # matrix-vector ewise_add + result = binary.plus((D0 | v1) | v2).new() + expected = binary.plus(binary.plus(D0 | v1).new() | v2).new() + assert result.isequal(expected) + result = binary.plus(D0 | (v1 | v2)).new() + assert result.isequal(expected) + result = binary.plus((v1 | v2) | D0).new() + assert result.isequal(expected.T) + result = binary.plus(v1 | (v2 | D0)).new() + assert result.isequal(expected.T) + # matrix-vector ewise_mult + result = binary.plus((D0 & v1) & v2).new() + expected = binary.plus(binary.plus(D0 & v1).new() & v2).new() + assert result.isequal(expected) + assert result.nvals > 0 + result = binary.plus(D0 & (v1 & v2)).new() + assert result.isequal(expected) + result = binary.plus((v1 & v2) & D0).new() + assert result.isequal(expected.T) + result = binary.plus(v1 & (v2 & D0)).new() + assert result.isequal(expected.T) + # matrix-vector ewise_union + kwargs = {"left_default": 10, "right_default": 20} + result = binary.plus((D0 | v1) | v2, **kwargs).new() + expected = binary.plus(binary.plus(D0 | v1, **kwargs).new() | v2, **kwargs).new() + assert result.isequal(expected) + result = binary.plus(D0 | (v1 | v2), **kwargs).new() + expected = binary.plus(D0 | binary.plus(v1 | v2, **kwargs).new(), **kwargs).new() + assert result.isequal(expected) + result = binary.plus((v1 | v2) | D0, **kwargs).new() + expected = binary.plus(binary.plus(v1 | v2, **kwargs).new() | D0, **kwargs).new() + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | D0), **kwargs).new() + expected = binary.plus(v1 | binary.plus(v2 | D0, **kwargs).new(), **kwargs).new() + assert result.isequal(expected) + # vxm, mxv + result = op.plus_plus((D0 @ v1) @ D0).new() + assert result.isequal(v1) + result = op.plus_plus(D0 @ (v1 @ D0)).new() + assert result.isequal(v1) + result = op.plus_plus(v1 @ (D0 @ D0)).new() + assert result.isequal(v1) + result = op.plus_plus((D0 @ D0) @ v1).new() + assert result.isequal(v1) + result = op.plus_plus((v1 @ D0) @ D0).new() + assert result.isequal(v1) + result = op.plus_plus(D0 @ (D0 @ v1)).new() + assert result.isequal(v1) + + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2).__ror__(v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | (v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1 | (v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.__ror__(v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) | (v2 & v3) + + with pytest.raises(TypeError, match="XXX"): # TODO + v1 & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.__rand__(v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2).__rand__(v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & (v2 & v3) + + # We differentiate between infix and methods + with pytest.raises(TypeError, match="to automatically compute"): + v1.ewise_add(v2 & v3) + with pytest.raises(TypeError, match="Automatic computation"): + (v1 & v2).ewise_add(v3) + with pytest.raises(TypeError, match="to automatically compute"): + v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="Automatic computation"): + (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="to automatically compute"): + v1.ewise_mult(v2 | v3) + with pytest.raises(TypeError, match="Automatic computation"): + (v1 | v2).ewise_mult(v3) + + +@autocompute +def test_multi_infix_vector_auto(): + v1 = Vector.from_coo([0, 1], [1, 2], size=3) # 1 2 . + v2 = Vector.from_coo([1, 2], [1, 2], size=3) # . 1 2 + v3 = Vector.from_coo([2, 0], [1, 2], size=3) # 2 . 1 + # We differentiate between infix and methods + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_add(v2 & v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 & v2).ewise_add(v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_mult(v2 | v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 | v2).ewise_mult(v3) + + +def test_multi_infix_matrix(): + # Adapted from test_multi_infix_vector + D0 = Vector.from_scalar(0, 3).diag() + v1 = Matrix.from_coo([0, 1], [0, 0], [1, 2], nrows=3) # 1 2 . + v2 = Matrix.from_coo([1, 2], [0, 0], [1, 2], nrows=3) # . 1 2 + v3 = Matrix.from_coo([2, 0], [0, 0], [1, 2], nrows=3) # 2 . 1 + # ewise_add + result = binary.plus((v1 | v2) | v3).new() + expected = Matrix.from_scalar(3, 3, 1) + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | v3)).new() + assert result.isequal(expected) + result = monoid.min(v1 | v2 | v3).new() + expected = Matrix.from_scalar(1, 3, 1) + assert result.isequal(expected) + result = binary.plus(v1 | v1 | v1 | v1 | v1).new() + expected = (5 * v1).new() + assert result.isequal(expected) + # ewise_mult + result = monoid.max((v1 & v2) & v3).new() + expected = Matrix(int, 3, 1) + assert result.isequal(expected) + result = monoid.max(v1 & (v2 & v3)).new() + assert result.isequal(expected) + result = monoid.min((v1 & v2) & v1).new() + expected = Matrix.from_coo([1], [0], [1], nrows=3) + assert result.isequal(expected) + result = binary.plus(v1 & v1 & v1 & v1 & v1).new() + expected = (5 * v1).new() + assert result.isequal(expected) + # ewise_union + result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10).new() + expected = Matrix.from_scalar(13, 3, 1) + assert result.isequal(expected) + result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10.0).new() + expected = Matrix.from_scalar(13.0, 3, 1) + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | v3), left_default=10, right_default=10).new() + assert result.isequal(expected) + # mxm + assert op.plus_plus(v1.T @ v1).new()[0, 0].new().value == 6 + assert op.plus_plus(v1 @ (v1.T @ D0)).new()[0, 0].new().value == 2 + assert op.plus_plus((v1.T @ D0) @ v1).new()[0, 0].new().value == 6 + assert op.plus_plus(D0 @ D0 @ D0 @ D0 @ D0).new().isequal(D0) + + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2).__ror__(v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | (v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1 | (v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.__ror__(v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) | (v2 & v3) + + with pytest.raises(TypeError, match="XXX"): # TODO + v1 & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.__rand__(v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2).__rand__(v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & (v2 & v3) + + # We differentiate between infix and methods + with pytest.raises(TypeError, match="to automatically compute"): + v1.ewise_add(v2 & v3) + with pytest.raises(TypeError, match="Automatic computation"): + (v1 & v2).ewise_add(v3) + with pytest.raises(TypeError, match="to automatically compute"): + v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="Automatic computation"): + (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="to automatically compute"): + v1.ewise_mult(v2 | v3) + with pytest.raises(TypeError, match="Automatic computation"): + (v1 | v2).ewise_mult(v3) + + +@autocompute +def test_multi_infix_matrix_auto(): + v1 = Matrix.from_coo([0, 1], [0, 0], [1, 2], nrows=3) # 1 2 . + v2 = Matrix.from_coo([1, 2], [0, 0], [1, 2], nrows=3) # . 1 2 + v3 = Matrix.from_coo([2, 0], [0, 0], [1, 2], nrows=3) # 2 . 1 + # We differentiate between infix and methods + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_add(v2 & v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 & v2).ewise_add(v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_mult(v2 | v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 | v2).ewise_mult(v3) + + +def test_multi_infix_scalar(): + # Adapted from test_multi_infix_vector + v1 = Scalar.from_value(1) + v2 = Scalar.from_value(2) + v3 = Scalar(int) + # ewise_add + result = binary.plus((v1 | v2) | v3).new() + expected = 3 + assert result.isequal(expected) + result = binary.plus((1 | v2) | v3).new() + assert result.isequal(expected) + result = binary.plus((1 | v2) | 0).new() + assert result.isequal(expected) + result = binary.plus((v1 | 2) | v3).new() + assert result.isequal(expected) + result = binary.plus((v1 | 2) | 0).new() + assert result.isequal(expected) + result = binary.plus((v1 | v2) | 0).new() + assert result.isequal(expected) + + result = binary.plus(v1 | (v2 | v3)).new() + assert result.isequal(expected) + result = binary.plus(1 | (v2 | v3)).new() + assert result.isequal(expected) + result = binary.plus(1 | (2 | v3)).new() + assert result.isequal(expected) + result = binary.plus(1 | (v2 | 0)).new() + assert result.isequal(expected) + result = binary.plus(v1 | (2 | v3)).new() + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | 0)).new() + assert result.isequal(expected) + + result = monoid.min(v1 | v2 | v3).new() + expected = 1 + assert result.isequal(expected) + # ewise_mult + result = monoid.max((v1 & v2) & v3).new() + expected = None + assert result.isequal(expected) + result = monoid.max(v1 & (v2 & v3)).new() + assert result.isequal(expected) + result = monoid.min((v1 & v2) & v1).new() + expected = 1 + assert result.isequal(expected) + + result = monoid.min((1 & v2) & v1).new() + assert result.isequal(expected) + result = monoid.min((1 & v2) & 1).new() + assert result.isequal(expected) + result = monoid.min((v1 & 2) & v1).new() + assert result.isequal(expected) + result = monoid.min((v1 & 2) & 1).new() + assert result.isequal(expected) + result = monoid.min((v1 & v2) & 1).new() + assert result.isequal(expected) + + result = monoid.min(1 & (v2 & v1)).new() + assert result.isequal(expected) + result = monoid.min(1 & (2 & v1)).new() + assert result.isequal(expected) + result = monoid.min(1 & (v2 & 1)).new() + assert result.isequal(expected) + result = monoid.min(v1 & (2 & v1)).new() + assert result.isequal(expected) + result = monoid.min(v1 & (v2 & 1)).new() + assert result.isequal(expected) + + # ewise_union + result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10).new() + expected = 13 + assert result.isequal(expected) + result = binary.plus((1 | v2) | v3, left_default=10, right_default=10).new() + assert result.isequal(expected) + result = binary.plus((v1 | 2) | v3, left_default=10, right_default=10).new() + assert result.isequal(expected) + result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10.0).new() + assert result.isequal(expected) + result = binary.plus(v1 | (v2 | v3), left_default=10, right_default=10).new() + assert result.isequal(expected) + result = binary.plus(1 | (v2 | v3), left_default=10, right_default=10).new() + assert result.isequal(expected) + result = binary.plus(1 | (2 | v3), left_default=10, right_default=10).new() + assert result.isequal(expected) + result = binary.plus(v1 | (2 | v3), left_default=10, right_default=10).new() + assert result.isequal(expected) + + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2).__ror__(v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | (v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) | (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1 | (v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.__ror__(v2 & v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) | (v2 & v3) + + with pytest.raises(TypeError, match="XXX"): # TODO + v1 & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + v1.__rand__(v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 & v2) & (v2 | v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & v3 + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2).__rand__(v3) + with pytest.raises(TypeError, match="XXX"): # TODO + (v1 | v2) & (v2 & v3) + + # We differentiate between infix and methods + with pytest.raises(TypeError, match="to automatically compute"): + v1.ewise_add(v2 & v3) + with pytest.raises(TypeError, match="Automatic computation"): + (v1 & v2).ewise_add(v3) + with pytest.raises(TypeError, match="to automatically compute"): + v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="Automatic computation"): + (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="to automatically compute"): + v1.ewise_mult(v2 | v3) + with pytest.raises(TypeError, match="Automatic computation"): + (v1 | v2).ewise_mult(v3) + + +@autocompute +def test_multi_infix_scalar_auto(): + v1 = Scalar.from_value(1) + v2 = Scalar.from_value(2) + v3 = Scalar(int) + # We differentiate between infix and methods + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_add(v2 & v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 & v2).ewise_add(v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1) + with pytest.raises(TypeError, match="only valid for BOOL"): + v1.ewise_mult(v2 | v3) + with pytest.raises(TypeError, match="only valid for BOOL"): + (v1 | v2).ewise_mult(v3) diff --git a/graphblas/tests/test_matrix.py b/graphblas/tests/test_matrix.py index 3f66e46ef..c716c97a9 100644 --- a/graphblas/tests/test_matrix.py +++ b/graphblas/tests/test_matrix.py @@ -2805,6 +2805,8 @@ def test_ss_nbytes(A): @autocompute def test_auto(A, v): + from graphblas.core.infix import MatrixEwiseMultExpr + expected = binary.land[bool](A & A).new() B = A.dup(dtype=bool) for expr in [(B & B), binary.land[bool](A & A)]: @@ -2832,12 +2834,21 @@ def test_auto(A, v): ]: # print(type(expr).__name__, method) val1 = getattr(expected, method)(expected).new() - val2 = getattr(expected, method)(expr) - val3 = getattr(expr, method)(expected) - val4 = getattr(expr, method)(expr) - assert val1.isequal(val2) - assert val1.isequal(val3) - assert val1.isequal(val4) + if method in {"__or__", "__ror__"} and type(expr) is MatrixEwiseMultExpr: + # Doing e.g. `plus(A & B | C)` isn't allowed--make user be explicit + with pytest.raises(TypeError): + val2 = getattr(expected, method)(expr) + with pytest.raises(TypeError): + val3 = getattr(expr, method)(expected) + with pytest.raises(TypeError): + val4 = getattr(expr, method)(expr) + else: + val2 = getattr(expected, method)(expr) + assert val1.isequal(val2) + val3 = getattr(expr, method)(expected) + assert val1.isequal(val3) + val4 = getattr(expr, method)(expr) + assert val1.isequal(val4) for method in ["reduce_rowwise", "reduce_columnwise", "reduce_scalar"]: s1 = getattr(expected, method)(monoid.lor).new() s2 = getattr(expr, method)(monoid.lor) @@ -2946,7 +2957,7 @@ def test_expr_is_like_matrix(A): "setdiag", "update", } - ignore = {"__sizeof__"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union", "_mxm", "_mxv"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Matrix. You may need to " "add an entry to `matrix` or `matrix_vector` set in `graphblas.core.automethods` " @@ -3011,7 +3022,7 @@ def test_index_expr_is_like_matrix(A): "resize", "setdiag", } - ignore = {"__sizeof__"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union", "_mxm", "_mxv"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Matrix. You may need to " "add an entry to `matrix` or `matrix_vector` set in `graphblas.core.automethods` " diff --git a/graphblas/tests/test_scalar.py b/graphblas/tests/test_scalar.py index ba9903169..aeb19e170 100644 --- a/graphblas/tests/test_scalar.py +++ b/graphblas/tests/test_scalar.py @@ -360,7 +360,7 @@ def test_expr_is_like_scalar(s): } if s.is_cscalar: expected.add("_empty") - ignore = {"__sizeof__"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Scalar. You may need to " "add an entry to `scalar` set in `graphblas.core.automethods` " @@ -402,7 +402,7 @@ def test_index_expr_is_like_scalar(s): } if s.is_cscalar: expected.add("_empty") - ignore = {"__sizeof__"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Scalar. You may need to " "add an entry to `scalar` set in `graphblas.core.automethods` " diff --git a/graphblas/tests/test_vector.py b/graphblas/tests/test_vector.py index b66bc96c9..1c9a8d38c 100644 --- a/graphblas/tests/test_vector.py +++ b/graphblas/tests/test_vector.py @@ -1532,6 +1532,8 @@ def test_outer(v): @autocompute def test_auto(v): + from graphblas.core.infix import VectorEwiseMultExpr + v = v.dup(dtype=bool) expected = binary.land(v & v).new() assert 0 not in expected @@ -1581,15 +1583,24 @@ def test_auto(v): ]: # print(type(expr).__name__, method) val1 = getattr(expected, method)(expected).new() - val2 = getattr(expected, method)(expr) - val3 = getattr(expr, method)(expected) - val4 = getattr(expr, method)(expr) - assert val1.isequal(val2) - assert val1.isequal(val3) - assert val1.isequal(val4) - assert val1.isequal(val2.new()) - assert val1.isequal(val3.new()) - assert val1.isequal(val4.new()) + if method in {"__or__", "__ror__"} and type(expr) is VectorEwiseMultExpr: + # Doing e.g. `plus(x & y | z)` isn't allowed--make user be explicit + with pytest.raises(TypeError): + val2 = getattr(expected, method)(expr) + with pytest.raises(TypeError): + val3 = getattr(expr, method)(expected) + with pytest.raises(TypeError): + val4 = getattr(expr, method)(expr) + else: + val2 = getattr(expected, method)(expr) + assert val1.isequal(val2) + assert val1.isequal(val2.new()) + val3 = getattr(expr, method)(expected) + assert val1.isequal(val3) + assert val1.isequal(val3.new()) + val4 = getattr(expr, method)(expr) + assert val1.isequal(val4) + assert val1.isequal(val4.new()) s1 = expected.reduce(monoid.lor).new() s2 = expr.reduce(monoid.lor) assert s1.isequal(s2.new()) @@ -1653,7 +1664,7 @@ def test_expr_is_like_vector(v): "resize", "update", } - ignore = {"__sizeof__"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union", "_inner", "_vxm"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Vector. You may need to " "add an entry to `vector` or `matrix_vector` set in `graphblas.core.automethods` " @@ -1702,7 +1713,7 @@ def test_index_expr_is_like_vector(v): "from_values", "resize", } - ignore = {"__sizeof__"} + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union", "_inner", "_vxm"} assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Vector. You may need to " "add an entry to `vector` or `matrix_vector` set in `graphblas.core.automethods` "