import asyncio
import gc
import os
import sys
import threading
import pytest
import numpy as np
from numpy.testing import extbuild, assert_warns, IS_WASM, IS_EDITABLE
from numpy._core.multiarray import get_handler_name
@pytest.fixture
def get_module(tmp_path):
    """ Add a memory policy that returns a false pointer 64 bytes into the
    actual allocation, and fill the prefix with some text. Then check at each
    memory manipulation that the prefix exists, to make sure all alloc/realloc/
    free/calloc go via the functions here.

    Returns the ``mem_policy`` extension module. An already importable
    ``mem_policy`` is reused; otherwise it is compiled into ``tmp_path``.
    The test is skipped where the extension cannot be built (Cygwin, WASM,
    editable installs).
    """
    if sys.platform.startswith('cygwin'):
        pytest.skip('link fails on cygwin')
    if IS_WASM:
        pytest.skip("Can't build module inside Wasm")
    if IS_EDITABLE:
        pytest.skip("Can't build module for editable install")

    functions = [
        ("get_default_policy", "METH_NOARGS", """
             Py_INCREF(PyDataMem_DefaultHandler);
             return PyDataMem_DefaultHandler;
         """),
        ("set_secret_data_policy", "METH_NOARGS", """
             PyObject *secret_data =
                 PyCapsule_New(&secret_data_handler, "mem_handler", NULL);
             if (secret_data == NULL) {
                 return NULL;
             }
             PyObject *old = PyDataMem_SetHandler(secret_data);
             Py_DECREF(secret_data);
             return old;
         """),
        ("set_wrong_capsule_name_data_policy", "METH_NOARGS", """
             PyObject *wrong_name_capsule =
                 PyCapsule_New(&secret_data_handler, "not_mem_handler", NULL);
             if (wrong_name_capsule == NULL) {
                 return NULL;
             }
             PyObject *old = PyDataMem_SetHandler(wrong_name_capsule);
             Py_DECREF(wrong_name_capsule);
             return old;
         """),
        ("set_old_policy", "METH_O", """
             PyObject *old;
             if (args != NULL && PyCapsule_CheckExact(args)) {
                 old = PyDataMem_SetHandler(args);
             }
             else {
                 old = PyDataMem_SetHandler(NULL);
             }
             return old;
         """),
        ("get_array", "METH_NOARGS", """
             char *buf = (char *)malloc(20);
             if (buf == NULL) {
                 return PyErr_NoMemory();
             }
             npy_intp dims[1];
             dims[0] = 20;
             PyArray_Descr *descr = PyArray_DescrNewFromType(NPY_UINT8);
             if (descr == NULL) {
                 free(buf);
                 return NULL;
             }
             /* descr is stolen by PyArray_NewFromDescr, even on failure */
             PyObject *arr = PyArray_NewFromDescr(&PyArray_Type, descr, 1,
                                                  dims, NULL, buf,
                                                  NPY_ARRAY_WRITEABLE, NULL);
             if (arr == NULL) {
                 /* the array never took ownership of buf */
                 free(buf);
             }
             return arr;
         """),
        ("set_own", "METH_O", """
             if (!PyArray_Check(args)) {
                 PyErr_SetString(PyExc_ValueError,
                              "need an ndarray");
                 return NULL;
             }
             PyArray_ENABLEFLAGS((PyArrayObject*)args, NPY_ARRAY_OWNDATA);
             // Maybe try this too?
             // PyArray_BASE(PyArrayObject *)args) = NULL;
             Py_RETURN_NONE;
         """),
        ("get_array_with_base", "METH_NOARGS", """
             char *buf = (char *)malloc(20);
             if (buf == NULL) {
                 return PyErr_NoMemory();
             }
             npy_intp dims[1];
             dims[0] = 20;
             PyArray_Descr *descr = PyArray_DescrNewFromType(NPY_UINT8);
             if (descr == NULL) {
                 free(buf);
                 return NULL;
             }
             PyObject *arr = PyArray_NewFromDescr(&PyArray_Type, descr, 1,
                                                  dims, NULL, buf,
                                                  NPY_ARRAY_WRITEABLE, NULL);
             if (arr == NULL) {
                 free(buf);
                 return NULL;
             }
             PyObject *obj = PyCapsule_New(buf, "buf capsule",
                                 (PyCapsule_Destructor)&warn_on_free);
             if (obj == NULL) {
                 /* arr does not own buf, so release it explicitly */
                 Py_DECREF(arr);
                 free(buf);
                 return NULL;
             }
             /*
              * PyArray_SetBaseObject steals the reference to obj even when
              * it fails, so obj must not be decref'd again here (doing so
              * would be a double decref). Destroying obj frees buf.
              */
             if (PyArray_SetBaseObject((PyArrayObject *)arr, obj) < 0) {
                 Py_DECREF(arr);
                 return NULL;
             }
             return arr;
         """),
    ]
    prologue = '''
        #define NPY_TARGET_VERSION NPY_1_22_API_VERSION
        #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
        #include <numpy/arrayobject.h>
        /*
         * This struct allows the dynamic configuration of the allocator funcs
         * of the `secret_data_allocator`. It is provided here for
         * demonstration purposes, as a valid `ctx` use-case scenario.
         */
        typedef struct {
            void *(*malloc)(size_t);
            void *(*calloc)(size_t, size_t);
            void *(*realloc)(void *, size_t);
            void (*free)(void *);
        } SecretDataAllocatorFuncs;

        NPY_NO_EXPORT void *
        shift_alloc(void *ctx, size_t sz) {
            SecretDataAllocatorFuncs *funcs = (SecretDataAllocatorFuncs *)ctx;
            char *real = (char *)funcs->malloc(sz + 64);
            if (real == NULL) {
                return NULL;
            }
            snprintf(real, 64, "originally allocated %lu", (unsigned long)sz);
            return (void *)(real + 64);
        }
        NPY_NO_EXPORT void *
        shift_zero(void *ctx, size_t sz, size_t cnt) {
            SecretDataAllocatorFuncs *funcs = (SecretDataAllocatorFuncs *)ctx;
            char *real = (char *)funcs->calloc(sz + 64, cnt);
            if (real == NULL) {
                return NULL;
            }
            snprintf(real, 64, "originally allocated %lu via zero",
                     (unsigned long)sz);
            return (void *)(real + 64);
        }
        NPY_NO_EXPORT void
        shift_free(void *ctx, void * p, npy_uintp sz) {
            SecretDataAllocatorFuncs *funcs = (SecretDataAllocatorFuncs *)ctx;
            if (p == NULL) {
                return ;
            }
            char *real = (char *)p - 64;
            if (strncmp(real, "originally allocated", 20) != 0) {
                fprintf(stdout, "uh-oh, unmatched shift_free, "
                        "no appropriate prefix\\n");
                /* Make C runtime crash by calling free on the wrong address */
                funcs->free((char *)p + 10);
                /* funcs->free(real); */
            }
            else {
                /* strtoul, unlike atoi, does not overflow past INT_MAX */
                npy_uintp i = (npy_uintp)strtoul(real + 20, NULL, 10);
                if (i != sz) {
                    fprintf(stderr, "uh-oh, unmatched shift_free"
                            "(ptr, %lu) but allocated %lu\\n",
                            (unsigned long)sz, (unsigned long)i);
                    /* This happens in some places, only print */
                }
                funcs->free(real);
            }
        }
        NPY_NO_EXPORT void *
        shift_realloc(void *ctx, void * p, npy_uintp sz) {
            SecretDataAllocatorFuncs *funcs = (SecretDataAllocatorFuncs *)ctx;
            if (p != NULL) {
                char *real = (char *)p - 64;
                if (strncmp(real, "originally allocated", 20) != 0) {
                    fprintf(stdout, "uh-oh, unmatched shift_realloc\\n");
                    return funcs->realloc(p, sz);
                }
                char *new_real = (char *)funcs->realloc(real, sz + 64);
                if (new_real == NULL) {
                    /* the old block is untouched; must not return NULL+64 */
                    return NULL;
                }
                return (void *)(new_real + 64);
            }
            else {
                char *real = (char *)funcs->realloc(p, sz + 64);
                if (real == NULL) {
                    return NULL;
                }
                snprintf(real, 64, "originally allocated "
                         "%lu via realloc", (unsigned long)sz);
                return (void *)(real + 64);
            }
        }
        /* As an example, we use the standard {m|c|re}alloc/free funcs. */
        static SecretDataAllocatorFuncs secret_data_handler_ctx = {
            malloc,
            calloc,
            realloc,
            free
        };
        static PyDataMem_Handler secret_data_handler = {
            "secret_data_allocator",
            1,
            {
                &secret_data_handler_ctx, /* ctx */
                shift_alloc,              /* malloc */
                shift_zero,               /* calloc */
                shift_realloc,            /* realloc */
                shift_free                /* free */
            }
        };
        void warn_on_free(void *capsule) {
            PyErr_WarnEx(PyExc_UserWarning, "in warn_on_free", 1);
            void * obj = PyCapsule_GetPointer(capsule,
                                              PyCapsule_GetName(capsule));
            free(obj);
        }
        '''
    more_init = "import_array();"
    try:
        import mem_policy
        return mem_policy
    except ImportError:
        pass
    # if it does not exist, build and load it
    return extbuild.build_and_import_extension('mem_policy',
                                               functions,
                                               prologue=prologue,
                                               include_dirs=[np.get_include()],
                                               build_dir=tmp_path,
                                               more_init=more_init)
def test_set_policy(get_module):
    """Check that PyDataMem_SetHandler switches the policy for new arrays.

    Views report no handler (``None``); only the array owning the data
    reports the active policy. The original policy is restored in a
    ``finally`` so a failing assertion cannot leak the secret-data allocator
    into later tests. Finally, a capsule with the wrong name is rejected.
    """
    get_handler_name = np._core.multiarray.get_handler_name
    get_handler_version = np._core.multiarray.get_handler_version
    orig_policy_name = get_handler_name()

    a = np.arange(10).reshape((2, 5))  # a doesn't own its own data
    assert get_handler_name(a) is None
    assert get_handler_version(a) is None
    assert get_handler_name(a.base) == orig_policy_name
    assert get_handler_version(a.base) == 1

    orig_policy = get_module.set_secret_data_policy()
    try:
        b = np.arange(10).reshape((2, 5))  # b doesn't own its own data
        assert get_handler_name(b) is None
        assert get_handler_version(b) is None
        assert get_handler_name(b.base) == 'secret_data_allocator'
        assert get_handler_version(b.base) == 1
    finally:
        if orig_policy_name == 'default_allocator':
            # tests PyDataMem_SetHandler(NULL)
            get_module.set_old_policy(None)
        else:
            get_module.set_old_policy(orig_policy)
    # In the NULL branch orig_policy_name is 'default_allocator', so a single
    # check covers both restore paths.
    assert get_handler_name() == orig_policy_name

    with pytest.raises(ValueError,
                       match="Capsule must be named 'mem_handler'"):
        get_module.set_wrong_capsule_name_data_policy()
def test_default_policy_singleton(get_module):
    """Check that the default memory handler capsule is a singleton.

    Setting the policy to NULL twice must hand back the very same default
    capsule object that ``get_default_policy`` returns. The original policy
    is restored in a ``finally`` so a failure cannot leak into other tests.
    """
    get_handler_name = np._core.multiarray.get_handler_name

    # set the policy to default
    orig_policy = get_module.set_old_policy(None)
    try:
        assert get_handler_name() == 'default_allocator'
        # re-set the policy to default
        def_policy_1 = get_module.set_old_policy(None)
        assert get_handler_name() == 'default_allocator'
    finally:
        # set the policy to original
        def_policy_2 = get_module.set_old_policy(orig_policy)

    # since default policy is a singleton,
    # these should be the same object
    assert def_policy_1 is def_policy_2 is get_module.get_default_policy()
def test_policy_propagation(get_module):
    """Check that only the data-owning array reports a memory handler.

    The memory policy goes hand-in-hand with ``flags.owndata``. This builds
    a reshape of a subclass view of ``np.arange``. Both view layers report
    ``None`` and do not own data. The root array owns its buffer and reports
    the current policy.
    """
    class MyArr(np.ndarray):
        pass

    handler_name = np._core.multiarray.get_handler_name
    expected_policy = handler_name()

    arr = np.arange(10).view(MyArr).reshape((2, 5))
    middle = arr.base

    assert handler_name(arr) is None
    assert arr.flags.owndata is False

    assert handler_name(middle) is None
    assert middle.flags.owndata is False

    root = middle.base
    assert handler_name(root) == expected_policy
    assert root.flags.owndata is True
async def concurrent_context1(get_module, orig_policy_name, event):
if orig_policy_name == 'default_allocator':
get_module.set_secret_data_policy()
assert get_handler_name() == 'secret_data_allocator'
else:
get_module.set_old_policy(None)
assert get_handler_name() == 'default_allocator'
event.set()
async def concurrent_context2(get_module, orig_policy_name, event):
await event.wait()
# the policy is not affected by changes in parallel contexts
assert get_handler_name() == orig_policy_name
# change policy in the child context
if orig_policy_name == 'default_allocator':
get_module.set_secret_data_policy()
assert get_handler_name() == 'secret_data_allocator'
else:
get_module.set_old_policy(None)
assert get_handler_name() == 'default_allocator'
async def async_test_context_locality(get_module):
    """Run two sibling tasks that each change the policy in their own context.

    Tasks are created in order (``concurrent_context1`` first) and awaited in
    that same order. Afterwards the parent context must still report the
    policy it started with.
    """
    multiarray = np._core.multiarray
    orig_policy_name = multiarray.get_handler_name()
    event = asyncio.Event()

    # the child contexts inherit the parent policy: create_task runs each
    # coroutine in a copy of the current context
    tasks = [
        asyncio.create_task(coro(get_module, orig_policy_name, event))
        for coro in (concurrent_context1, concurrent_context2)
    ]
    for task in tasks:
        await task

    # the parent context is not affected by child policy changes
    assert multiarray.get_handler_name() == orig_policy_name
def test_context_locality(get_module):
    """Drive the asyncio context-locality check inside a fresh event loop.

    Skipped on PyPy releases older than 7.3.6, which lack context-local
    support for the memory policy.
    """
    if sys.implementation.name == 'pypy':
        if sys.pypy_version_info[:3] < (7, 3, 6):
            pytest.skip('no context-locality support in PyPy < 7.3.6')
    asyncio.run(async_test_context_locality(get_module))
def concurrent_thread1(get_module, event):
get_module.set_secret_data_policy()
assert np._core.multiarray.get_handler_name() == 'secret_data_allocator'
event.set()
def concurrent_thread2(get_module, event):
event.wait()
# the policy is not affected by changes in parallel threads
assert np._core.multiarray.get_handler_name() == 'default_allocator'
# change policy in the child thread
get_module.set_secret_data_policy()
def test_thread_locality(get_module):
    """Check that the memory policy is thread-local.

    Two child threads run ``concurrent_thread1`` and ``concurrent_thread2``.
    The second asserts it still sees the default allocator after the first
    changed its own policy. Neither child's change may leak back into the
    parent thread.

    ``Thread.join()`` does not propagate exceptions raised in the child, so
    failures (including failed asserts) are collected and re-raised here.
    Otherwise the test would pass even when a child thread fails.
    """
    orig_policy_name = np._core.multiarray.get_handler_name()

    event = threading.Event()
    errors = []

    def run(target):
        # Record any failure so the main thread can re-raise it.
        try:
            target(get_module, event)
        except BaseException as exc:
            errors.append(exc)
        finally:
            # If concurrent_thread1 fails before signalling, do not leave
            # concurrent_thread2 blocked forever on event.wait().
            event.set()

    # the child threads do not inherit the parent policy
    threads = [threading.Thread(target=run, args=(target,))
               for target in (concurrent_thread1, concurrent_thread2)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    if errors:
        raise errors[0]
    # the parent thread is not affected by child policy changes
    assert np._core.multiarray.get_handler_name() == orig_policy_name
@pytest.mark.skip(reason="too slow, see gh-23975")
def test_new_policy(get_module):
    """Run the full core and ``ma`` test suites under the secret-data policy.

    The original policy is restored in a ``finally`` so that a failing
    sub-suite cannot leave the secret-data allocator installed for later
    tests.
    """
    a = np.arange(10)
    orig_policy_name = np._core.multiarray.get_handler_name(a)

    orig_policy = get_module.set_secret_data_policy()
    try:
        b = np.arange(10)
        assert np._core.multiarray.get_handler_name(b) == 'secret_data_allocator'

        # test array manipulation. This is slow
        if orig_policy_name == 'default_allocator':
            # when the np._core.test tests recurse into this test, the
            # policy will be set so this "if" will be false, preventing
            # infinite recursion
            #
            # if needed, debug this by
            # - running tests with -- -s (to not capture stdout/stderr
            # - setting verbose=2
            # - setting extra_argv=['-vv'] here
            assert np._core.test('full', verbose=1, extra_argv=[])
            # also try the ma tests, the pickling test is quite tricky
            assert np.ma.test('full', verbose=1, extra_argv=[])
    finally:
        get_module.set_old_policy(orig_policy)

    c = np.arange(10)
    assert np._core.multiarray.get_handler_name(c) == orig_policy_name
@pytest.mark.xfail(sys.implementation.name == "pypy",
                   reason=("bad interaction between getenv and "
                           "os.environ inside pytest"))
@pytest.mark.parametrize("policy", ["0", "1", None])
def test_switch_owner(get_module, policy):
    """Free an array whose OWNDATA flag was forced on without a handler.

    ``get_array`` returns an array with no memory handler. ``set_own`` then
    marks it as owning its data, so deallocation has to fall back to plain
    ``free``. When the warn-if-no-mem-policy switch is on, a RuntimeWarning
    is expected. With ``policy=None`` the switch is left untouched and its
    expected value is read from ``NUMPY_WARN_IF_NO_MEM_POLICY``. Otherwise
    the switch is set for the test and restored afterwards.
    """
    a = get_module.get_array()
    assert np._core.multiarray.get_handler_name(a) is None
    get_module.set_own(a)

    if policy is None:
        # See what we expect to be set based on the env variable
        policy = os.getenv("NUMPY_WARN_IF_NO_MEM_POLICY", "0") == "1"
        oldval = None
    else:
        policy = policy == "1"
        oldval = np._core._multiarray_umath._set_numpy_warn_if_no_mem_policy(
            policy)
    try:
        # The policy should be NULL, so we have to assume we can call
        # "free".  A warning is given if the policy == "1"
        if policy:
            with assert_warns(RuntimeWarning):
                del a
                gc.collect()
        else:
            del a
            gc.collect()
    finally:
        if oldval is not None:
            np._core._multiarray_umath._set_numpy_warn_if_no_mem_policy(oldval)
def test_owner_is_base(get_module):
    """Check that freeing an array whose buffer belongs to its base capsule
    runs the capsule destructor.

    The destructor emits a UserWarning ("in warn_on_free").
    """
    arr = get_module.get_array_with_base()
    expect_free_warning = pytest.warns(UserWarning, match='warn_on_free')
    with expect_free_warning:
        del arr
        gc.collect()
    gc.collect()