Skip to content

Commit b523901

Browse files
BUG: fix replace on non-contiguous array views (#513)
1 parent 7113aee commit b523901

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

bottleneck/src/nonreduce_template.c

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,20 @@ replace_DTYPE0(PyArrayObject *a, double old, double new) {
2828
BN_BEGIN_ALLOW_THREADS
2929
const npy_DTYPE0 oldf = (npy_DTYPE0)old;
3030
const npy_DTYPE0 newf = (npy_DTYPE0)new;
31+
const npy_intp stride = it.stride;
3132
if (old == old) {
3233
WHILE {
3334
npy_DTYPE0* array = PA(DTYPE0);
3435
FOR {
35-
array[it.i] = array[it.i] == oldf ? newf : array[it.i];
36+
array[it.i * stride] = array[it.i * stride] == oldf ? newf : array[it.i * stride];
3637
}
3738
NEXT
3839
}
3940
} else {
4041
WHILE {
4142
npy_DTYPE0* array = PA(DTYPE0);
4243
FOR {
43-
array[it.i] = array[it.i] != array[it.i] ? newf : array[it.i];
44+
array[it.i * stride] = array[it.i * stride] != array[it.i * stride] ? newf : array[it.i * stride];
4445
}
4546
NEXT
4647
}
@@ -69,12 +70,13 @@ replace_DTYPE0(PyArrayObject *a, double old, double new) {
6970
return NULL;
7071
}
7172
BN_BEGIN_ALLOW_THREADS
73+
const npy_intp stride = it.stride;
7274
WHILE {
7375
npy_DTYPE0* array = (npy_DTYPE0 *)it.pa;
7476
npy_intp i;
7577
// clang has a large perf regression when using the FOR macro here
7678
for (i=0; i < it.length; i++) {
77-
array[i] = array[i] == oldint ? newint : array[i];
79+
array[i * stride] = array[i * stride] == oldint ? newint : array[i * stride];
7880
}
7981
NEXT
8082
}

bottleneck/tests/nonreduce_test.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import bottleneck as bn
1010

11-
from .util import DTYPES, INT_DTYPES, array_order, arrays
11+
from .util import DTYPES, FLOAT_DTYPES, INT_DTYPES, array_order, arrays
1212

1313

1414
@pytest.mark.parametrize(
@@ -138,3 +138,31 @@ def test_replace_newaxis(dtype):
138138
array = np.ones((2, 2), dtype=dtype)[..., np.newaxis]
139139
result = bn.replace(array, 1, 2)
140140
assert (result == 2).all().all()
141+
142+
143+
@pytest.mark.parametrize("dtype", DTYPES)
144+
def test_replace_view(dtype):
145+
"""Test replace on non-contiguous view"""
146+
expected_array = np.arange(20, dtype=dtype)
147+
expected_view = expected_array[::2]
148+
bn.slow.replace(expected_view, 10, -1)
149+
array = np.arange(20, dtype=dtype)
150+
view = array[::2]
151+
bn.replace(view, 10, -1)
152+
assert_array_equal(view, expected_view)
153+
assert_array_equal(array, expected_array)
154+
155+
156+
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
157+
def test_replace_nan_view(dtype):
158+
"""Test replace NaN on non-contiguous view"""
159+
expected_array = np.ones((4, 3, 2), dtype=dtype)
160+
expected_array[::2, :, 0] = np.nan
161+
expected_view = expected_array[:, :, 0]
162+
bn.slow.replace(expected_view, np.nan, 0)
163+
array = np.ones((4, 3, 2), dtype=dtype)
164+
array[::2, :, 0] = np.nan
165+
view = array[:, :, 0]
166+
bn.replace(view, np.nan, 0)
167+
assert_array_equal(view, expected_view)
168+
assert_array_equal(array, expected_array)

0 commit comments

Comments
 (0)