Skip to content

TYP: Backport typing fixes from main (3) #28534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 22 additions & 46 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,8 @@ from numpy.lib._histograms_impl import (
)

from numpy.lib._index_tricks_impl import (
ndenumerate,
ndindex,
ravel_multi_index,
unravel_index,
mgrid,
Expand Down Expand Up @@ -1196,8 +1198,21 @@ __future_scalars__: Final[set[L["bytes", "str", "object"]]] = ...
__array_api_version__: Final[L["2023.12"]] = "2023.12"
test: Final[PytestTester] = ...

@type_check_only
class _DTypeMeta(type):
@property
def type(cls, /) -> type[generic] | None: ...
@property
def _abstract(cls, /) -> bool: ...
@property
def _is_numeric(cls, /) -> bool: ...
@property
def _parametric(cls, /) -> bool: ...
@property
def _legacy(cls, /) -> bool: ...

@final
class dtype(Generic[_SCT_co]):
class dtype(Generic[_SCT_co], metaclass=_DTypeMeta):
names: None | tuple[builtins.str, ...]
def __hash__(self) -> int: ...

Expand Down Expand Up @@ -3977,7 +3992,7 @@ bool_ = bool
# NOTE: Because mypy has some long-standing bugs related to `__new__`, `object_` can't
# be made generic.
@final
class object_(_RealMixin, generic):
class object_(_RealMixin, generic[Any]):
@overload
def __new__(cls, nothing_to_see_here: None = None, /) -> None: ... # type: ignore[misc]
@overload
Expand All @@ -3991,6 +4006,8 @@ class object_(_RealMixin, generic):
@overload # catch-all
def __new__(cls, value: Any = ..., /) -> object | NDArray[Self]: ... # type: ignore[misc]
def __init__(self, value: object = ..., /) -> None: ...
def __hash__(self, /) -> int: ...
def __call__(self, /, *args: object, **kwargs: object) -> Any: ...

if sys.version_info >= (3, 12):
def __release_buffer__(self, buffer: memoryview, /) -> None: ...
Expand Down Expand Up @@ -4453,6 +4470,9 @@ class timedelta64(_IntegralMixin, generic[_TD64ItemT_co], Generic[_TD64ItemT_co]
@overload
def __init__(self, value: _ConvertibleToTD64, format: _TimeUnitSpec = ..., /) -> None: ...

# inherited at runtime from `signedinteger`
def __class_getitem__(cls, type_arg: type | object, /) -> GenericAlias: ...

# NOTE: Only a limited number of units support conversion
# to builtin scalar types: `Y`, `M`, `ns`, `ps`, `fs`, `as`
def __int__(self: timedelta64[int], /) -> int: ...
Expand Down Expand Up @@ -4946,50 +4966,6 @@ class errstate:
) -> None: ...
def __call__(self, func: _CallableT) -> _CallableT: ...

class ndenumerate(Generic[_SCT_co]):
@property
def iter(self) -> flatiter[NDArray[_SCT_co]]: ...

@overload
def __new__(
cls, arr: _FiniteNestedSequence[_SupportsArray[dtype[_SCT]]],
) -> ndenumerate[_SCT]: ...
@overload
def __new__(cls, arr: str | _NestedSequence[str]) -> ndenumerate[str_]: ...
@overload
def __new__(cls, arr: bytes | _NestedSequence[bytes]) -> ndenumerate[bytes_]: ...
@overload
def __new__(cls, arr: builtins.bool | _NestedSequence[builtins.bool]) -> ndenumerate[np.bool]: ...
@overload
def __new__(cls, arr: int | _NestedSequence[int]) -> ndenumerate[int_]: ...
@overload
def __new__(cls, arr: float | _NestedSequence[float]) -> ndenumerate[float64]: ...
@overload
def __new__(cls, arr: complex | _NestedSequence[complex]) -> ndenumerate[complex128]: ...
@overload
def __new__(cls, arr: object) -> ndenumerate[object_]: ...

# The first overload is a (semi-)workaround for a mypy bug (tested with v1.10 and v1.11)
@overload
def __next__(
self: ndenumerate[np.bool | datetime64 | timedelta64 | number[Any] | flexible],
/,
) -> tuple[_Shape, _SCT_co]: ...
@overload
def __next__(self: ndenumerate[object_], /) -> tuple[_Shape, Any]: ...
@overload
def __next__(self, /) -> tuple[_Shape, _SCT_co]: ...

def __iter__(self) -> Self: ...

class ndindex:
@overload
def __init__(self, shape: tuple[SupportsIndex, ...], /) -> None: ...
@overload
def __init__(self, *shape: SupportsIndex) -> None: ...
def __iter__(self) -> Self: ...
def __next__(self) -> _Shape: ...

# TODO: The type of each `__next__` and `iters` return-type depends
# on the length and dtype of `args`; we can't describe this behavior yet
# as we lack variadics (PEP 646).
Expand Down
72 changes: 40 additions & 32 deletions numpy/_core/fromnumeric.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ from collections.abc import Sequence
from typing import (
Any,
Literal,
NoReturn,
Protocol,
SupportsIndex,
TypeAlias,
TypeVar,
overload,
type_check_only,
)

from _typeshed import Incomplete
from typing_extensions import Never, deprecated

import numpy as np
Expand Down Expand Up @@ -551,9 +552,6 @@ def ravel(
@overload
def ravel(a: ArrayLike, order: _OrderKACF = "C") -> np.ndarray[tuple[int], np.dtype[Any]]: ...

@overload
def nonzero(a: np.generic | np.ndarray[tuple[()], Any]) -> NoReturn: ...
@overload
def nonzero(a: _ArrayLike[Any]) -> tuple[NDArray[intp], ...]: ...

# this prevents `Any` from being returned with Pyright
Expand Down Expand Up @@ -813,7 +811,7 @@ def all(
keepdims: _BoolLike_co | _NoValueType = ...,
*,
where: _ArrayLikeBool_co | _NoValueType = ...,
) -> np.bool | NDArray[np.bool]: ...
) -> Incomplete: ...
@overload
def all(
a: ArrayLike,
Expand Down Expand Up @@ -850,7 +848,7 @@ def any(
keepdims: _BoolLike_co | _NoValueType = ...,
*,
where: _ArrayLikeBool_co | _NoValueType = ...,
) -> np.bool | NDArray[np.bool]: ...
) -> Incomplete: ...
@overload
def any(
a: ArrayLike,
Expand Down Expand Up @@ -1443,10 +1441,10 @@ def mean(
keepdims: Literal[False] | _NoValueType = ...,
*,
where: _ArrayLikeBool_co | _NoValueType = ...,
) -> complexfloating[Any, Any]: ...
) -> complexfloating[Any]: ...
@overload
def mean(
a: _ArrayLikeTD64_co,
a: _ArrayLike[np.timedelta64],
axis: None = ...,
dtype: None = ...,
out: None = ...,
Expand All @@ -1457,23 +1455,33 @@ def mean(
@overload
def mean(
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
axis: _ShapeLike | None = ...,
dtype: None = ...,
out: None = ...,
axis: _ShapeLike | None,
dtype: DTypeLike,
out: _ArrayT,
keepdims: bool | _NoValueType = ...,
*,
where: _ArrayLikeBool_co | _NoValueType = ...,
) -> Any: ...
) -> _ArrayT: ...
@overload
def mean(
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
axis: _ShapeLike | None = ...,
dtype: DTypeLike | None = ...,
*,
out: _ArrayT,
keepdims: bool | _NoValueType = ...,
where: _ArrayLikeBool_co | _NoValueType = ...,
) -> _ArrayT: ...
@overload
def mean(
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
axis: None,
dtype: _DTypeLike[_SCT],
out: None = ...,
keepdims: bool | _NoValueType = ...,
keepdims: Literal[False] | _NoValueType = ...,
*,
where: _ArrayLikeBool_co | _NoValueType = ...,
) -> _SCT | NDArray[_SCT]: ...
) -> _SCT: ...
@overload
def mean(
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
Expand All @@ -1487,43 +1495,43 @@ def mean(
@overload
def mean(
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
axis: None = ...,
*,
axis: _ShapeLike | None,
dtype: _DTypeLike[_SCT],
out: None = ...,
keepdims: bool | _NoValueType = ...,
out: None,
keepdims: Literal[True, 1],
*,
where: _ArrayLikeBool_co | _NoValueType = ...,
) -> _SCT | NDArray[_SCT]: ...
) -> NDArray[_SCT]: ...
@overload
def mean(
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
axis: _ShapeLike | None = ...,
dtype: DTypeLike = ...,
axis: _ShapeLike | None,
dtype: _DTypeLike[_SCT],
out: None = ...,
keepdims: bool | _NoValueType = ...,
*,
keepdims: bool | _NoValueType = ...,
where: _ArrayLikeBool_co | _NoValueType = ...,
) -> Any: ...
) -> _SCT | NDArray[_SCT]: ...
@overload
def mean(
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
axis: _ShapeLike | None,
dtype: DTypeLike,
out: _ArrayT,
keepdims: bool | _NoValueType = ...,
axis: _ShapeLike | None = ...,
*,
dtype: _DTypeLike[_SCT],
out: None = ...,
keepdims: bool | _NoValueType = ...,
where: _ArrayLikeBool_co | _NoValueType = ...,
) -> _ArrayT: ...
) -> _SCT | NDArray[_SCT]: ...
@overload
def mean(
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
axis: _ShapeLike | None = ...,
dtype: DTypeLike = ...,
*,
out: _ArrayT,
dtype: DTypeLike | None = ...,
out: None = ...,
keepdims: bool | _NoValueType = ...,
*,
where: _ArrayLikeBool_co | _NoValueType = ...,
) -> _ArrayT: ...
) -> Incomplete: ...

@overload
def std(
Expand Down
45 changes: 23 additions & 22 deletions numpy/dtypes.pyi
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
from typing import (
Any,
Final,
Generic,
Literal as L,
NoReturn,
TypeAlias,
final,
type_check_only,
)
# ruff: noqa: ANN401
from types import MemberDescriptorType
from typing import Any, ClassVar, Generic, NoReturn, TypeAlias, final, type_check_only
from typing import Literal as L

from typing_extensions import LiteralString, Self, TypeVar

import numpy as np

__all__ = [
__all__ = [ # noqa: RUF022
'BoolDType',
'Int8DType',
'ByteDType',
Expand Down Expand Up @@ -53,7 +48,7 @@ __all__ = [
_SCT_co = TypeVar("_SCT_co", bound=np.generic, covariant=True)

@type_check_only
class _SimpleDType(Generic[_SCT_co], np.dtype[_SCT_co]): # type: ignore[misc]
class _SimpleDType(np.dtype[_SCT_co], Generic[_SCT_co]): # type: ignore[misc] # pyright: ignore[reportGeneralTypeIssues]
names: None # pyright: ignore[reportIncompatibleVariableOverride]
def __new__(cls, /) -> Self: ...
def __getitem__(self, key: Any, /) -> NoReturn: ...
Expand All @@ -73,7 +68,7 @@ class _SimpleDType(Generic[_SCT_co], np.dtype[_SCT_co]): # type: ignore[misc]
def subdtype(self) -> None: ...

@type_check_only
class _LiteralDType(Generic[_SCT_co], _SimpleDType[_SCT_co]): # type: ignore[misc]
class _LiteralDType(_SimpleDType[_SCT_co], Generic[_SCT_co]): # type: ignore[misc]
@property
def flags(self) -> L[0]: ...
@property
Expand Down Expand Up @@ -234,10 +229,11 @@ class UInt64DType( # type: ignore[misc]
def str(self) -> L["<u8", ">u8"]: ...

# Standard C-named version/alias:
ByteDType: Final = Int8DType
UByteDType: Final = UInt8DType
ShortDType: Final = Int16DType
UShortDType: Final = UInt16DType
# NOTE: Don't make these `Final`: it will break stubtest
ByteDType = Int8DType
UByteDType = UInt8DType
ShortDType = Int16DType
UShortDType = UInt16DType

@final
class IntDType( # type: ignore[misc]
Expand Down Expand Up @@ -419,11 +415,11 @@ class ObjectDType( # type: ignore[misc]

@final
class BytesDType( # type: ignore[misc]
Generic[_ItemSize_co],
_TypeCodes[L["S"], L["S"], L[18]],
_NoOrder,
_NBit[L[1],_ItemSize_co],
_SimpleDType[np.bytes_],
Generic[_ItemSize_co],
):
def __new__(cls, size: _ItemSize_co, /) -> BytesDType[_ItemSize_co]: ...
@property
Expand All @@ -435,11 +431,11 @@ class BytesDType( # type: ignore[misc]

@final
class StrDType( # type: ignore[misc]
Generic[_ItemSize_co],
_TypeCodes[L["U"], L["U"], L[19]],
_NativeOrder,
_NBit[L[4],_ItemSize_co],
_SimpleDType[np.str_],
Generic[_ItemSize_co],
):
def __new__(cls, size: _ItemSize_co, /) -> StrDType[_ItemSize_co]: ...
@property
Expand All @@ -451,11 +447,11 @@ class StrDType( # type: ignore[misc]

@final
class VoidDType( # type: ignore[misc]
Generic[_ItemSize_co],
_TypeCodes[L["V"], L["V"], L[20]],
_NoOrder,
_NBit[L[1], _ItemSize_co],
np.dtype[np.void],
np.dtype[np.void], # pyright: ignore[reportGeneralTypeIssues]
Generic[_ItemSize_co],
):
# NOTE: `VoidDType(...)` raises a `TypeError` at the moment
def __new__(cls, length: _ItemSize_co, /) -> NoReturn: ...
Expand Down Expand Up @@ -578,8 +574,13 @@ class StringDType( # type: ignore[misc]
_NativeOrder,
_NBit[L[8], L[16]],
# TODO: Replace the (invalid) `str` with the scalar type, once implemented
np.dtype[str], # type: ignore[type-var]
np.dtype[str], # type: ignore[type-var] # pyright: ignore[reportGeneralTypeIssues,reportInvalidTypeArguments]
):
@property
def coerce(self) -> L[True]: ...
na_object: ClassVar[MemberDescriptorType] # does not get instantiated

#
def __new__(cls, /) -> StringDType: ...
def __getitem__(self, key: Any, /) -> NoReturn: ...
@property
Expand Down
Loading
Loading