diff --git a/.github/workflows/compiler_sanitizers.yml b/.github/workflows/compiler_sanitizers.yml
new file mode 100644
index 000000000000..9477e0be1bd1
--- /dev/null
+++ b/.github/workflows/compiler_sanitizers.yml
@@ -0,0 +1,127 @@
+name: Test with compiler sanitizers
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+      - maintenance/**
+
+defaults:
+  run:
+    shell: bash
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+permissions:
+  contents: read # to fetch code (actions/checkout)
+
+jobs:
+  clang_ASAN:
+    # To enable this workflow on a fork, comment out:
+    if: github.repository == 'numpy/numpy'
+    runs-on: macos-latest
+    steps:
+    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      with:
+        submodules: recursive
+        fetch-tags: true
+        persist-credentials: false
+    - name: Set up pyenv
+      run: |
+        git clone https://github.com/pyenv/pyenv.git "$HOME/.pyenv"
+        PYENV_ROOT="$HOME/.pyenv"
+        PYENV_BIN="$PYENV_ROOT/bin"
+        PYENV_SHIMS="$PYENV_ROOT/shims"
+        echo "$PYENV_BIN" >> $GITHUB_PATH
+        echo "$PYENV_SHIMS" >> $GITHUB_PATH
+        echo "PYENV_ROOT=$PYENV_ROOT" >> $GITHUB_ENV
+    - name: Check pyenv is working
+      run:
+        pyenv --version
+    - name: Set up LLVM
+      run: |
+        brew install llvm@19
+        LLVM_PREFIX=$(brew --prefix llvm@19)
+        echo CC="$LLVM_PREFIX/bin/clang" >> $GITHUB_ENV
+        echo CXX="$LLVM_PREFIX/bin/clang++" >> $GITHUB_ENV
+        echo LDFLAGS="-L$LLVM_PREFIX/lib" >> $GITHUB_ENV
+        echo CPPFLAGS="-I$LLVM_PREFIX/include" >> $GITHUB_ENV
+    - name: Build Python with address sanitizer
+      run: |
+        CONFIGURE_OPTS="--with-address-sanitizer" pyenv install 3.13
+        pyenv global 3.13
+    - name: Install dependencies
+      run: |
+        pip install -r requirements/build_requirements.txt
+        pip install -r requirements/ci_requirements.txt
+        pip install -r requirements/test_requirements.txt
+        # xdist captures stdout/stderr, but we want the ASAN output
+        pip uninstall -y pytest-xdist
+    - name: Build
+      run:
+        python -m spin build -j2 -- -Db_sanitize=address
+    - name: Test
+      run: |
+        # pass -s to pytest to see ASAN errors and warnings, otherwise pytest captures them
+        ASAN_OPTIONS=detect_leaks=0:symbolize=1:strict_init_order=true:allocator_may_return_null=1:halt_on_error=1 \
+        python -m spin test -- -v -s --timeout=600 --durations=10
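+
+  # A local-reproduction sketch (commentary, not part of the workflow): the
+  # ASAN job above boils down to these two commands, assuming an
+  # ASAN-instrumented CPython on PATH and the Homebrew LLVM clang configured
+  # as CC/CXX as in the "Set up LLVM" step:
+  #
+  #   python -m spin build -j2 -- -Db_sanitize=address
+  #   ASAN_OPTIONS=detect_leaks=0:halt_on_error=1 \
+  #     python -m spin test -- -v -s --timeout=600 --durations=10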
+
+  clang_TSAN:
+    # To enable this workflow on a fork, comment out:
+    if: github.repository == 'numpy/numpy'
+    runs-on: macos-latest
+    steps:
+    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      with:
+        submodules: recursive
+        fetch-tags: true
+        persist-credentials: false
+    - name: Set up pyenv
+      run: |
+        git clone https://github.com/pyenv/pyenv.git "$HOME/.pyenv"
+        PYENV_ROOT="$HOME/.pyenv"
+        PYENV_BIN="$PYENV_ROOT/bin"
+        PYENV_SHIMS="$PYENV_ROOT/shims"
+        echo "$PYENV_BIN" >> $GITHUB_PATH
+        echo "$PYENV_SHIMS" >> $GITHUB_PATH
+        echo "PYENV_ROOT=$PYENV_ROOT" >> $GITHUB_ENV
+    - name: Check pyenv is working
+      run:
+        pyenv --version
+    - name: Set up LLVM
+      run: |
+        brew install llvm@19
+        LLVM_PREFIX=$(brew --prefix llvm@19)
+        echo CC="$LLVM_PREFIX/bin/clang" >> $GITHUB_ENV
+        echo CXX="$LLVM_PREFIX/bin/clang++" >> $GITHUB_ENV
+        echo LDFLAGS="-L$LLVM_PREFIX/lib" >> $GITHUB_ENV
+        echo CPPFLAGS="-I$LLVM_PREFIX/include" >> $GITHUB_ENV
+    - name: Build Python with thread sanitizer support
+      run: |
+        # free-threaded Python is much more likely to trigger races
+        CONFIGURE_OPTS="--with-thread-sanitizer" pyenv install 3.13t
+        pyenv global 3.13t
+    - name: Install dependencies
+      run: |
+        # TODO: remove when a released cython supports free-threaded python
+        pip install -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple cython
+        pip install -r requirements/build_requirements.txt
+        pip install -r requirements/ci_requirements.txt
+        pip install -r requirements/test_requirements.txt
+        # xdist captures stdout/stderr, but we want the TSAN output
+        pip uninstall -y pytest-xdist
+    - name: Build
+      run:
+        python -m spin build -j2 -- -Db_sanitize=thread
+    - name: Test
+      run: |
+        # These tests are slow, so only run tests in files that do "import threading" to make them count
+        TSAN_OPTIONS=allocator_may_return_null=1:halt_on_error=1 \
+        python -m spin test \
+        `find numpy -name "test*.py" | xargs grep -l "import threading" | tr '\n' ' '` \
+        -- -v -s --timeout=600 --durations=10
diff --git a/numpy/_core/src/multiarray/convert_datatype.c b/numpy/_core/src/multiarray/convert_datatype.c
index 1dff38a1d1ef..00251af5bf68 100644
--- a/numpy/_core/src/multiarray/convert_datatype.c
+++ b/numpy/_core/src/multiarray/convert_datatype.c
@@ -62,46 +62,24 @@ static PyObject *
 PyArray_GetObjectToGenericCastingImpl(void);
 
-/**
- * Fetch the casting implementation from one DType to another.
- *
- * @param from The implementation to cast from
- * @param to The implementation to cast to
- *
- * @returns A castingimpl (PyArrayDTypeMethod *), None or NULL with an
- *          error set.
- */
-NPY_NO_EXPORT PyObject *
-PyArray_GetCastingImpl(PyArray_DTypeMeta *from, PyArray_DTypeMeta *to)
+static PyObject *
+create_casting_impl(PyArray_DTypeMeta *from, PyArray_DTypeMeta *to)
 {
-    PyObject *res;
-    if (from == to) {
-        res = (PyObject *)NPY_DT_SLOTS(from)->within_dtype_castingimpl;
-    }
-    else {
-        res = PyDict_GetItemWithError(NPY_DT_SLOTS(from)->castingimpls, (PyObject *)to);
-    }
-    if (res != NULL || PyErr_Occurred()) {
-        Py_XINCREF(res);
-        return res;
-    }
     /*
-     * The following code looks up CastingImpl based on the fact that anything
+     * Look up CastingImpl based on the fact that anything
      * can be cast to and from objects or structured (void) dtypes.
-     *
-     * The last part adds casts dynamically based on legacy definition
      */
     if (from->type_num == NPY_OBJECT) {
-        res = PyArray_GetObjectToGenericCastingImpl();
+        return PyArray_GetObjectToGenericCastingImpl();
     }
     else if (to->type_num == NPY_OBJECT) {
-        res = PyArray_GetGenericToObjectCastingImpl();
+        return PyArray_GetGenericToObjectCastingImpl();
     }
     else if (from->type_num == NPY_VOID) {
-        res = PyArray_GetVoidToGenericCastingImpl();
+        return PyArray_GetVoidToGenericCastingImpl();
     }
     else if (to->type_num == NPY_VOID) {
-        res = PyArray_GetGenericToVoidCastingImpl();
+        return PyArray_GetGenericToVoidCastingImpl();
     }
     /*
      * Reject non-legacy dtypes. They need to use the new API to add casts and
@@ -125,42 +103,105 @@ PyArray_GetCastingImpl(PyArray_DTypeMeta *from, PyArray_DTypeMeta *to)
                 from->singleton, to->type_num);
         if (castfunc == NULL) {
             PyErr_Clear();
-            /* Remember that this cast is not possible */
-            if (PyDict_SetItem(NPY_DT_SLOTS(from)->castingimpls,
-                        (PyObject *) to, Py_None) < 0) {
-                return NULL;
-            }
             Py_RETURN_NONE;
         }
     }
-
-        /* PyArray_AddLegacyWrapping_CastingImpl find the correct casting level: */
-        /*
-         * TODO: Possibly move this to the cast registration time. But if we do
-         *       that, we have to also update the cast when the casting safety
-         *       is registered.
+        /* Create a cast using the state of the legacy casting setup defined
+         * during the setup of the DType.
+         *
+         * Ideally we would do this when we create the DType, but legacy user
+         * DTypes don't have a way to signal that a DType is done setting up
+         * casts. Without such a mechanism, the safest way to know that a
+         * DType is done setting up is to register the cast lazily the first
+         * time a user does the cast.
+         *
+         * We *could* register the casts when we create the wrapping
+         * DTypeMeta, but that means the internals of the legacy user DType
+         * system would need to update the state of the casting safety flags
+         * in the cast implementations stored on the DTypeMeta. That's an
+         * inversion of abstractions and would be tricky to do without
+         * creating circular dependencies inside NumPy.
          */
         if (PyArray_AddLegacyWrapping_CastingImpl(from, to, -1) < 0) {
             return NULL;
         }
+        /* castingimpls is unconditionally filled by
+         * AddLegacyWrapping_CastingImpl, so this won't create a recursive
+         * critical section
+         */
         return PyArray_GetCastingImpl(from, to);
     }
+}
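+
+/*
+ * Call-flow sketch (illustrative commentary, not executed code): the first
+ * cast between a given pair of legacy DTypes takes the slow path exactly
+ * once, e.g.
+ *
+ *   PyArray_GetCastingImpl(from, to)       -> cache miss in castingimpls
+ *     ensure_castingimpl_exists(from, to)  -> locks the castingimpls dict
+ *       create_casting_impl(from, to)      -> wraps the legacy castfunc
+ *   PyArray_GetCastingImpl(from, to)       -> cache hit from then on
+ */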
 
-    if (res == NULL) {
+static PyObject *
+ensure_castingimpl_exists(PyArray_DTypeMeta *from, PyArray_DTypeMeta *to)
+{
+    int return_error = 0;
+    PyObject *res = NULL;
+
+    /* Need to create the cast. This might happen at runtime so we enter a
+       critical section to avoid races */
+
+    Py_BEGIN_CRITICAL_SECTION(NPY_DT_SLOTS(from)->castingimpls);
+
+    /* check if another thread filled it while this thread was blocked on
+       acquiring the critical section */
+    if (PyDict_GetItemRef(NPY_DT_SLOTS(from)->castingimpls, (PyObject *)to,
+                          &res) < 0) {
+        return_error = 1;
+    }
+    else if (res == NULL) {
+        res = create_casting_impl(from, to);
+        if (res == NULL) {
+            return_error = 1;
+        }
+        else if (PyDict_SetItem(NPY_DT_SLOTS(from)->castingimpls,
+                                (PyObject *)to, res) < 0) {
+            return_error = 1;
+        }
+    }
+    Py_END_CRITICAL_SECTION();
+    if (return_error) {
+        Py_XDECREF(res);
         return NULL;
     }
-    if (from == to) {
+    if (from == to && res == Py_None) {
         PyErr_Format(PyExc_RuntimeError,
                 "Internal NumPy error, within-DType cast missing for %S!", from);
         Py_DECREF(res);
         return NULL;
     }
-    if (PyDict_SetItem(NPY_DT_SLOTS(from)->castingimpls,
-            (PyObject *)to, res) < 0) {
-        Py_DECREF(res);
+    return res;
+}
+
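+/*
+ * ensure_castingimpl_exists is the slow half of a double-checked lookup:
+ * PyArray_GetCastingImpl below takes the lock-free fast path on a dict hit,
+ * while a miss locks the dict, re-checks it, creates the cast, and publishes
+ * it before unlocking. The re-check under the critical section is what makes
+ * concurrent first use safe on free-threaded builds.
+ */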
+/**
+ * Fetch the casting implementation from one DType to another.
+ *
+ * @param from The implementation to cast from
+ * @param to The implementation to cast to
+ *
+ * @returns A castingimpl (PyArrayDTypeMethod *), None or NULL with an
+ *          error set.
+ */
+NPY_NO_EXPORT PyObject *
+PyArray_GetCastingImpl(PyArray_DTypeMeta *from, PyArray_DTypeMeta *to)
+{
+    PyObject *res = NULL;
+    if (from == to) {
+        if ((NPY_DT_SLOTS(from)->within_dtype_castingimpl) != NULL) {
+            res = Py_XNewRef(
+                    (PyObject *)NPY_DT_SLOTS(from)->within_dtype_castingimpl);
+        }
+    }
+    else if (PyDict_GetItemRef(NPY_DT_SLOTS(from)->castingimpls,
+                               (PyObject *)to, &res) < 0) {
         return NULL;
     }
-    return res;
+    if (res != NULL) {
+        return res;
+    }
+
+    return ensure_castingimpl_exists(from, to);
 }
@@ -409,7 +450,7 @@ _get_cast_safety_from_castingimpl(PyArrayMethodObject *castingimpl,
  * implementations fully to have them available for doing the actual cast
  * later.
  *
- * @param from The descriptor to cast from 
+ * @param from The descriptor to cast from
  * @param to The descriptor to cast to (may be NULL)
  * @param to_dtype If `to` is NULL, must pass the to_dtype (otherwise this
  *        is ignored).
@@ -2031,6 +2072,11 @@ PyArray_AddCastingImplementation(PyBoundArrayMethodObject *meth)
 /**
  * Add a new casting implementation using a PyArrayMethod_Spec.
  *
+ * Using this function outside of module initialization without holding a
+ * critical section on the castingimpls dict may lead to a race to fill the
+ * dict. Use PyArray_GetCastingImpl to lazily register casts at runtime
+ * safely.
+ *
 * @param spec The specification to use as a source
 * @param private If private, allow slots not publicly exposed.
 * @return 0 on success -1 on failure
diff --git a/numpy/_core/src/multiarray/dtypemeta.c b/numpy/_core/src/multiarray/dtypemeta.c
index a60e6fd59fd9..0b1b0fb39192 100644
--- a/numpy/_core/src/multiarray/dtypemeta.c
+++ b/numpy/_core/src/multiarray/dtypemeta.c
@@ -1252,6 +1252,12 @@ dtypemeta_wrap_legacy_descriptor(
             return -1;
         }
     }
+    else {
+        // ensure the within dtype cast is populated for legacy user dtypes
+        if (PyArray_GetCastingImpl(dtype_class, dtype_class) == NULL) {
+            return -1;
+        }
+    }
 
     return 0;
 }
diff --git a/numpy/_core/tests/test_multithreading.py b/numpy/_core/tests/test_multithreading.py
index 2ddca57dbd0b..133268d276ee 100644
--- a/numpy/_core/tests/test_multithreading.py
+++ b/numpy/_core/tests/test_multithreading.py
@@ -1,10 +1,13 @@
+import concurrent.futures
+import string
 import threading
 
 import numpy as np
 import pytest
 
 from numpy.testing import IS_WASM
 from numpy.testing._private.utils import run_threaded
+from numpy._core import _rational_tests
 
 if IS_WASM:
     pytest.skip(allow_module_level=True, reason="no threading support in wasm")
@@ -165,3 +168,106 @@ def closure(b):
         x = np.repeat(x0, 2, axis=0)[::2]
 
     run_threaded(closure, max_workers=10, pass_barrier=True)
+
+
+def test_structured_advanced_indexing():
+    # Test that copyswap(n) used by integer array indexing is threadsafe
+    # for structured datatypes, see gh-15387. This test can behave randomly.
+
+    # Create a deeply nested dtype to make a failure more likely:
+    dt = np.dtype([("", "f8")])
+    dt = np.dtype([("", dt)] * 2)
+    dt = np.dtype([("", dt)] * 2)
+    # The array should be large enough to likely run into threading issues
+    arr = np.random.uniform(size=(6000, 8)).view(dt)[:, 0]
+
+    rng = np.random.default_rng()
+
+    def func(arr):
+        indx = rng.integers(0, len(arr), size=6000, dtype=np.intp)
+        arr[indx]
+
+    tpe = concurrent.futures.ThreadPoolExecutor(max_workers=8)
+    futures = [tpe.submit(func, arr) for _ in range(10)]
+    for f in futures:
+        f.result()
+
+    assert arr.dtype is dt
+
+
+def test_structured_threadsafety2():
+    # Nonzero (and some other functions) should be threadsafe for
+    # structured datatypes, see gh-15387. This test can behave randomly.
+
+    # Create a deeply nested dtype to make a failure more likely:
+    dt = np.dtype([("", "f8")])
+    dt = np.dtype([("", dt)])
+    dt = np.dtype([("", dt)] * 2)
+    # The array should be large enough to likely run into threading issues
+    arr = np.random.uniform(size=(5000, 4)).view(dt)[:, 0]
+
+    def func(arr):
+        arr.nonzero()
+
+    tpe = concurrent.futures.ThreadPoolExecutor(max_workers=8)
+    futures = [tpe.submit(func, arr) for _ in range(10)]
+    for f in futures:
+        f.result()
+
+    assert arr.dtype is dt
+
+
+def test_stringdtype_multithreaded_access_and_mutation(
+        dtype, random_string_list):
+    # this test uses an RNG and may crash or cause deadlocks if there is a
+    # threading bug
+    rng = np.random.default_rng(0x4D3D3D3)
+
+    def func(arr):
+        rnd = rng.random()
+        # either write to random locations in the array, compute a ufunc, or
+        # re-initialize the array
+        if rnd < 0.25:
+            num = np.random.randint(0, arr.size)
+            arr[num] = arr[num] + "hello"
+        elif rnd < 0.5:
+            if rnd < 0.375:
+                np.add(arr, arr)
+            else:
+                np.add(arr, arr, out=arr)
+        elif rnd < 0.75:
+            if rnd < 0.875:
+                np.multiply(arr, np.int64(2))
+            else:
+                np.multiply(arr, np.int64(2), out=arr)
+        else:
+            arr[:] = random_string_list
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
+        arr = np.array(random_string_list, dtype=dtype)
+        futures = [tpe.submit(func, arr) for _ in range(500)]
+
+        for f in futures:
+            f.result()
+
+
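+# The test below exercises the lazy legacy-cast registration path made
+# thread-safe in convert_datatype.c: with pass_barrier=True all workers
+# start at once and race to set up the rational casts on first use.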
""" import os +import string import sys import tempfile from contextlib import contextmanager @@ -10,9 +11,11 @@ import hypothesis import pytest import numpy +import numpy as np from numpy._core._multiarray_tests import get_fpu_mode -from numpy.testing._private.utils import NOGIL_BUILD +from numpy._core.tests._natype import pd_NA +from numpy.testing._private.utils import NOGIL_BUILD, get_stringdtype_dtype try: from scipy_doctest.conftest import dt_config @@ -204,12 +207,12 @@ def warnings_errors_and_rng(test=None): dt_config.check_namespace['StringDType'] = numpy.dtypes.StringDType # temporary skips - dt_config.skiplist = set([ + dt_config.skiplist = { 'numpy.savez', # unclosed file 'numpy.matlib.savez', 'numpy.__array_namespace_info__', 'numpy.matlib.__array_namespace_info__', - ]) + } # xfail problematic tutorials dt_config.pytest_extra_xfail = { @@ -231,3 +234,28 @@ def warnings_errors_and_rng(test=None): 'numpy/f2py/_backends/_distutils.py', ] + +@pytest.fixture +def random_string_list(): + chars = list(string.ascii_letters + string.digits) + chars = np.array(chars, dtype="U1") + ret = np.random.choice(chars, size=100 * 10, replace=True) + return ret.view("U100") + + +@pytest.fixture(params=[True, False]) +def coerce(request): + return request.param + + +@pytest.fixture( + params=["unset", None, pd_NA, np.nan, float("nan"), "__nan__"], + ids=["unset", "None", "pandas.NA", "np.nan", "float('nan')", "string nan"], +) +def na_object(request): + return request.param + + +@pytest.fixture() +def dtype(na_object, coerce): + return get_stringdtype_dtype(na_object, coerce) diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 4a97ff111cd7..01fe6327713c 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -4,6 +4,7 @@ """ import os import sys +import pathlib import platform import re import gc @@ -19,6 +20,7 @@ import sysconfig import concurrent.futures import threading +import importlib.metadata import numpy as np from numpy._core import ( @@ -26,9 +28,11 @@ from numpy import isfinite, isnan, isinf import numpy.linalg._umath_linalg from numpy._utils import _rename_parameter +from numpy._core.tests._natype import pd_NA from io import StringIO + __all__ = [ 'assert_equal', 'assert_almost_equal', 'assert_approx_equal', 'assert_array_equal', 'assert_array_less', 'assert_string_equal', @@ -42,7 +46,7 @@ 'HAS_REFCOUNT', "IS_WASM", 'suppress_warnings', 'assert_array_compare', 'assert_no_gc_cycles', 'break_cycles', 'HAS_LAPACK64', 'IS_PYSTON', 'IS_MUSL', 'check_support_sve', 'NOGIL_BUILD', - 'IS_EDITABLE', 'run_threaded', + 'IS_EDITABLE', 'IS_INSTALLED', 'NUMPY_ROOT', 'run_threaded', ] @@ -54,10 +58,40 @@ class KnownFailureException(Exception): KnownFailureTest = KnownFailureException # backwards compat verbose = 0 +NUMPY_ROOT = pathlib.Path(np.__file__).parent + +try: + np_dist = importlib.metadata.distribution('numpy') +except importlib.metadata.PackageNotFoundError: + IS_INSTALLED = IS_EDITABLE = False +else: + IS_INSTALLED = True + try: + if sys.version_info >= (3, 13): + IS_EDITABLE = np_dist.origin.dir_info.editable + else: + # Backport importlib.metadata.Distribution.origin + import json, types # noqa: E401 + origin = json.loads( + np_dist.read_text('direct_url.json') or '{}', + object_hook=lambda data: types.SimpleNamespace(**data), + ) + IS_EDITABLE = origin.dir_info.editable + except AttributeError: + IS_EDITABLE = False + + # spin installs numpy directly via meson, instead of using meson-python, and + # runs the module by 
+
+    # spin installs numpy directly via meson, instead of using meson-python, and
+    # runs the module by setting PYTHONPATH. This is problematic because the
+    # resulting installation lacks the Python metadata (.dist-info), and numpy
+    # might already be installed in the environment, causing us to find its
+    # metadata, even though we are not actually loading that package.
+    # Work around this issue by checking if the numpy root matches.
+    if not IS_EDITABLE and np_dist.locate_file('numpy') != NUMPY_ROOT:
+        IS_INSTALLED = False
+
 IS_WASM = platform.machine() in ["wasm32", "wasm64"]
 IS_PYPY = sys.implementation.name == 'pypy'
 IS_PYSTON = hasattr(sys, "pyston_version_info")
-IS_EDITABLE = not bool(np.__path__) or 'editable' in np.__path__[0]
 HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None and not IS_PYSTON
 HAS_LAPACK64 = numpy.linalg._umath_linalg._ilp64
@@ -101,14 +135,15 @@ def GetPerformanceAttributes(object, counter, instance=None,
     # thread's CPU usage is either 0 or 100).  To read counters like this,
     # you should copy this function, but keep the counter open, and call
     # CollectQueryData() each time you need to know.
-    # See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp (dead link)
+    # See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp
+    # (dead link)
     # My older explanation for this was that the "AddCounter" process
     # forced the CPU to 100%, but the above makes more sense :)
     import win32pdh
     if format is None:
         format = win32pdh.PDH_FMT_LONG
-    path = win32pdh.MakeCounterPath( (machine, object, instance, None,
-                                      inum, counter))
+    path = win32pdh.MakeCounterPath((machine, object, instance, None,
+                                     inum, counter))
     hq = win32pdh.OpenQuery()
     try:
         hc = win32pdh.AddCounter(hq, path)
@@ -166,7 +201,7 @@ def jiffies(_proc_pid_stat=f'/proc/{os.getpid()}/stat', _load_time=[]):
                 l = f.readline().split(' ')
             return int(l[13])
         except Exception:
-            return int(100*(time.time()-_load_time[0]))
+            return int(100 * (time.time() - _load_time[0]))
 else:
     # os.getpid is not in all platforms available.
     # Using time is safe but inaccurate, especially when process
@@ -182,7 +217,7 @@ def jiffies(_load_time=[]):
         import time
         if not _load_time:
             _load_time.append(time.time())
-        return int(100*(time.time()-_load_time[0]))
+        return int(100 * (time.time() - _load_time[0]))
 
 
 def build_err_msg(arrays, err_msg, header='Items are not equal:',
@@ -190,7 +225,7 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
     msg = ['\n' + header]
     err_msg = str(err_msg)
     if err_msg:
-        if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header):
+        if err_msg.find('\n') == -1 and len(err_msg) < 79 - len(header):
             msg = [msg[0] + ' ' + err_msg]
         else:
             msg.append(err_msg)
@@ -659,14 +694,14 @@ def assert_approx_equal(actual, desired, significant=7, err_msg='',
     # Normalized the numbers to be in range (-10.0,10.0)
     # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual))))))
     with np.errstate(invalid='ignore'):
-        scale = 0.5*(np.abs(desired) + np.abs(actual))
+        scale = 0.5 * (np.abs(desired) + np.abs(actual))
         scale = np.power(10, np.floor(np.log10(scale)))
     try:
-        sc_desired = desired/scale
+        sc_desired = desired / scale
     except ZeroDivisionError:
         sc_desired = 0.0
     try:
-        sc_actual = actual/scale
+        sc_actual = actual / scale
     except ZeroDivisionError:
         sc_actual = 0.0
     msg = build_err_msg(
@@ -687,7 +722,7 @@ def assert_approx_equal(actual, desired, significant=7, err_msg='',
             return
     except (TypeError, NotImplementedError):
         pass
-    if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)):
+    if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant - 1)):
         raise AssertionError(msg)
@@ -1379,10 +1414,10 @@ def check_support_sve(__cache=[]):
     """
     gh-22982
     """
-    
+
     if __cache:
         return __cache[0]
-    
+
     import subprocess
     cmd = 'lscpu'
     try:
@@ -1543,7 +1578,7 @@ def measure(code_str, times=1, label=None):
             i += 1
             exec(code, globs, locs)
     elapsed = jiffies() - elapsed
-    return 0.01*elapsed
+    return 0.01 * elapsed
@@ -1557,7 +1592,7 @@ def _assert_valid_refcount(op):
     import gc
     import numpy as np
 
-    b = np.arange(100*100).reshape(100, 100)
+    b = np.arange(100 * 100).reshape(100, 100)
     c = b
     i = 1
@@ -1735,7 +1770,7 @@ def assert_array_almost_equal_nulp(x, y, nulp=1):
     ax = np.abs(x)
     ay = np.abs(y)
     ref = nulp * np.spacing(np.where(ax > ay, ax, ay))
-    if not np.all(np.abs(x-y) <= ref):
+    if not np.all(np.abs(x - y) <= ref):
         if np.iscomplexobj(x) or np.iscomplexobj(y):
             msg = f"Arrays are not equal to {nulp} ULP"
         else:
@@ -1851,7 +1886,7 @@ def nulp_diff(x, y, dtype=None):
                          (x.shape, y.shape))
 
     def _diff(rx, ry, vdt):
-        diff = np.asarray(rx-ry, dtype=vdt)
+        diff = np.asarray(rx - ry, dtype=vdt)
         return np.abs(diff)
 
     rx = integer_repr(x)
@@ -2596,7 +2631,7 @@ def check_free_memory(free_bytes):
         except ValueError as exc:
             raise ValueError(f'Invalid environment variable {env_var}: {exc}')
 
-        msg = (f'{free_bytes/1e9} GB memory required, but environment variable '
+        msg = (f'{free_bytes / 1e9} GB memory required, but environment variable '
               f'NPY_AVAILABLE_MEM={env_value} set')
     else:
         mem_free = _get_mem_available()
@@ -2607,7 +2642,9 @@ def check_free_memory(free_bytes):
                 "the test.")
             mem_free = -1
         else:
-            msg = f'{free_bytes/1e9} GB memory required, but {mem_free/1e9} GB available'
+            free_bytes_gb = free_bytes / 1e9
+            mem_free_gb = mem_free / 1e9
+            msg = f'{free_bytes_gb} GB memory required, but {mem_free_gb} GB available'
 
     return msg if mem_free < free_bytes else None
@@ -2700,8 +2737,23 @@ def run_threaded(func, max_workers=8, pass_count=False,
                 barrier = threading.Barrier(max_workers)
                 args.append(barrier)
             if pass_count:
-                futures = [tpe.submit(func, i, *args) for i in range(max_workers)]
+                all_args = [(func, i, *args) for i in range(max_workers)]
             else:
-                futures = [tpe.submit(func, *args) for _ in range(max_workers)]
+                all_args = [(func, *args) for _ in range(max_workers)]
+            try:
+                futures = []
+                for arg in all_args:
+                    futures.append(tpe.submit(*arg))
+            finally:
+                # if submission failed partway (e.g. the OS refused to spawn
+                # another thread), abort the barrier so the already-submitted
+                # workers raise instead of waiting on threads that will never
+                # start
+                if len(futures) < max_workers and pass_barrier:
+                    barrier.abort()
             for f in futures:
                 f.result()
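+
+# A usage sketch (mirroring test_multithreading.py in this diff): with
+# pass_barrier=True every worker blocks on the shared barrier before doing
+# any work, so all max_workers threads reach the contended code at once.
+# If thread creation fails partway, the barrier is aborted above and the
+# already-running workers see BrokenBarrierError instead of hanging:
+#
+#     def closure(b):
+#         b.wait()
+#         np.full((10, 10), 1, _rational_tests.rational)
+#
+#     run_threaded(closure, 250, pass_barrier=True)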
+
+
+def get_stringdtype_dtype(na_object, coerce=True):
+    # explicit `is` check for pd_NA because `!=` with pd_NA returns pd_NA
+    if na_object is pd_NA or na_object != "unset":
+        return np.dtypes.StringDType(na_object=na_object, coerce=coerce)
+    else:
+        return np.dtypes.StringDType(coerce=coerce)
diff --git a/requirements/test_requirements.txt b/requirements/test_requirements.txt
index 7ea464dadc40..93e441f61310 100644
--- a/requirements/test_requirements.txt
+++ b/requirements/test_requirements.txt
@@ -9,6 +9,7 @@ pytest-cov==4.1.0
 meson
 ninja; sys_platform != "emscripten"
 pytest-xdist
+pytest-timeout
 # for numpy.random.test.test_extending
 cffi; python_version < '3.10'
 # For testing types. Notes on the restrictions: