Description
Describe the issue:
The type annotations used for scalar/int binary operations like np.float32(1) * 2 imply that the scalar types are not closed under, e.g., multiplication with int:
reveal_type(np.int8(1)) # signedinteger[_8Bit]
reveal_type(np.int8(1) * np.int8(1)) # signedinteger[_8Bit]
reveal_type(np.int8(1) * 1) # signedinteger[_8Bit] | signedinteger[_32Bit | _64Bit]
As far as I can tell, mixed operations with int don't actually promote the type:
>>> np.int8(1) * 128
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
OverflowError: Python integer 128 out of bounds for int8
The union in the inferred type comes from here:
Lines 3768 to 3776 in a7eda47
class signedinteger(integer[_NBit1]):
    def __init__(self, value: _ConvertibleToInt = ..., /) -> None: ...
    __add__: _SignedIntOp[_NBit1]
    __radd__: _SignedIntOp[_NBit1]
    __sub__: _SignedIntOp[_NBit1]
    __rsub__: _SignedIntOp[_NBit1]
    __mul__: _SignedIntOp[_NBit1]
    __rmul__: _SignedIntOp[_NBit1]
And that uses:
numpy/numpy/_typing/_callable.pyi
Lines 207 to 222 in a7eda47
@type_check_only
class _SignedIntOp(Protocol[_NBit1]):
    @overload
    def __call__(self, other: bool, /) -> signedinteger[_NBit1]: ...
    @overload
    def __call__(self, other: int, /) -> signedinteger[_NBit1] | int_: ...
    @overload
    def __call__(self, other: float, /) -> floating[_NBit1] | float64: ...
    @overload
    def __call__(
        self, other: complex, /
    ) -> complexfloating[_NBit1, _NBit1] | complex128: ...
    @overload
    def __call__(
        self, other: signedinteger[_NBit2], /
    ) -> signedinteger[_NBit1] | signedinteger[_NBit2]: ...
I think that the problematic overload is:
def __call__(self, other: int, /) -> signedinteger[_NBit1] | int_: ...
Is there a reason that | int_ is needed there?
Reproduce the code example:
from __future__ import annotations
import numpy as np
from typing import Protocol, Self, reveal_type
class MultiplyWithInt(Protocol):
    def __mul__(self, other: int, /) -> Self:
        ...
a: MultiplyWithInt = 1
b: MultiplyWithInt = 1.0
c: MultiplyWithInt = 1j
d: MultiplyWithInt = np.uint8(1)
e: MultiplyWithInt = np.uint16(1)
f: MultiplyWithInt = np.uint32(1)
g: MultiplyWithInt = np.uint64(1)
h: MultiplyWithInt = np.int8(1) # type check error
i: MultiplyWithInt = np.int16(1) # type check error
j: MultiplyWithInt = np.int32(1) # type check error
k: MultiplyWithInt = np.int64(1)
l: MultiplyWithInt = np.float32(1.0) # type check error
m: MultiplyWithInt = np.float64(1.0)
n: MultiplyWithInt = np.complex64(1) # type check error
o: MultiplyWithInt = np.complex128(1)
reveal_type(np.uint8(1)) # unsignedinteger[_8Bit]
reveal_type(np.uint8(1) * 1) # Any
reveal_type(np.uint8(1) * np.uint8(1)) # unsignedinteger[_8Bit]
reveal_type(np.int8(1)) # signedinteger[_8Bit]
reveal_type(np.int8(1) * 1) # signedinteger[_8Bit] | signedinteger[_32Bit | _64Bit]
reveal_type(np.int8(1) * np.int8(1)) # signedinteger[_8Bit]
Error message:
No response
Python and NumPy Versions:
Python 3.12
NumPy 2.2.1
Runtime Environment:
No response
Context for the issue:
I'm trying to write generically typed code with rings like:
from typing import Protocol, Self, Literal
type _PositiveInteger = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
class RingElement(Protocol):
    """Elements supporting ring operations."""
    def __pos__(self) -> Self: ...
    def __neg__(self) -> Self: ...
    def __add__(self, other: Self, /) -> Self: ...
    def __mul__(self, other: Self | int, /) -> Self: ...
    def __rmul__(self, other: int, /) -> Self: ...
    def __pow__(self, other: _PositiveInteger, /) -> Self: ...
The int overloads on __mul__ and __rmul__ are there so that code like 2*x + y*2 type checks against this protocol. Both mypy and pyright, however, consider some of NumPy's scalar types incompatible with this protocol because, according to the annotations, they are not closed under multiplication with int.