Skip to content

Commit 32dc797

Browse files
MaanasAroracharris
authored andcommitted
BUG: Any dtype should call square on arr ** 2 (#29392)
* BUG: update fast_scalar_power to handle special-case squaring for any array type except object arrays * BUG: fix missing declaration * TST: add test to ensure `arr**2` calls square for structured dtypes * STY: remove whitespace * BUG: replace new variable `is_square` with direct op comparison in `fast_scalar_power` function
1 parent 20d973c commit 32dc797

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

numpy/_core/src/multiarray/number.c

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ static int
332332
fast_scalar_power(PyObject *o1, PyObject *o2, int inplace, PyObject **result)
333333
{
334334
PyObject *fastop = NULL;
335+
335336
if (PyLong_CheckExact(o2)) {
336337
int overflow = 0;
337338
long exp = PyLong_AsLongAndOverflow(o2, &overflow);
@@ -363,7 +364,12 @@ fast_scalar_power(PyObject *o1, PyObject *o2, int inplace, PyObject **result)
363364
}
364365

365366
PyArrayObject *a1 = (PyArrayObject *)o1;
366-
if (!(PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1))) {
367+
if (PyArray_ISOBJECT(a1)) {
368+
return 1;
369+
}
370+
if (fastop != n_ops.square && !PyArray_ISFLOAT(a1) && !PyArray_ISCOMPLEX(a1)) {
371+
// we special-case squaring for any array type
372+
// gh-29388
367373
return 1;
368374
}
369375

numpy/_core/tests/test_multiarray.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4211,6 +4211,13 @@ def pow_for(exp, arr):
42114211
assert_equal(obj_arr ** -1, pow_for(-1, obj_arr))
42124212
assert_equal(obj_arr ** 2, pow_for(2, obj_arr))
42134213

4214+
def test_pow_calls_square_structured_dtype(self):
4215+
# gh-29388
4216+
dt = np.dtype([('a', 'i4'), ('b', 'i4')])
4217+
a = np.array([(1, 2), (3, 4)], dtype=dt)
4218+
with pytest.raises(TypeError, match="ufunc 'square' not supported"):
4219+
a ** 2
4220+
42144221
def test_pos_array_ufunc_override(self):
42154222
class A(np.ndarray):
42164223
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):

0 commit comments

Comments
 (0)