Skip to content

Commit d6a1028

Browse files
committed
more tests + reviews
1 parent c4d0c9c commit d6a1028

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

src/csrc/dtype.c

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,10 @@ quadprec_argmax(char *data, npy_intp n, npy_intp *max_ind, void *arr)
472472
if (descr->backend == BACKEND_SLEEF) {
473473
// Find first non-NaN value as initial max
474474
npy_intp start = 0;
475+
Sleef_quad max_val;
475476
for (start = 0; start < n; start++) {
476-
Sleef_quad val = *(Sleef_quad *)(data + start * elsize);
477-
if (!Sleef_iunordq1(val, val)) {
477+
max_val = *(Sleef_quad *)(data + start * elsize);
478+
if (!Sleef_iunordq1(max_val, max_val)) {
478479
*max_ind = start;
479480
break;
480481
}
@@ -489,24 +490,25 @@ quadprec_argmax(char *data, npy_intp n, npy_intp *max_ind, void *arr)
489490
// Find maximum
490491
for (npy_intp i = start + 1; i < n; i++) {
491492
Sleef_quad val = *(Sleef_quad *)(data + i * elsize);
492-
Sleef_quad max_val = *(Sleef_quad *)(data + (*max_ind) * elsize);
493493

494494
// Skip NaN values
495495
if (Sleef_iunordq1(val, val)) {
496496
continue;
497497
}
498498

499499
if (Sleef_icmpgtq1(val, max_val)) {
500+
max_val = val;
500501
*max_ind = i;
501502
}
502503
}
503504
}
504505
else {
505506
// Find first non-NaN value as initial max
506507
npy_intp start = 0;
508+
long double max_val;
507509
for (start = 0; start < n; start++) {
508-
long double val = *(long double *)(data + start * elsize);
509-
if (!isnan(val)) {
510+
max_val = *(long double *)(data + start * elsize);
511+
if (!isnan(max_val)) {
510512
*max_ind = start;
511513
break;
512514
}
@@ -521,14 +523,14 @@ quadprec_argmax(char *data, npy_intp n, npy_intp *max_ind, void *arr)
521523
// Find maximum
522524
for (npy_intp i = start + 1; i < n; i++) {
523525
long double val = *(long double *)(data + i * elsize);
524-
long double max_val = *(long double *)(data + (*max_ind) * elsize);
525526

526527
// Skip NaN values
527528
if (isnan(val)) {
528529
continue;
529530
}
530531

531532
if (val > max_val) {
533+
max_val = val;
532534
*max_ind = i;
533535
}
534536
}
@@ -554,9 +556,10 @@ quadprec_argmin(char *data, npy_intp n, npy_intp *min_ind, void *arr)
554556
if (descr->backend == BACKEND_SLEEF) {
555557
// Find first non-NaN value as initial min
556558
npy_intp start = 0;
559+
Sleef_quad min_val;
557560
for (start = 0; start < n; start++) {
558-
Sleef_quad val = *(Sleef_quad *)(data + start * elsize);
559-
if (!Sleef_iunordq1(val, val)) {
561+
min_val = *(Sleef_quad *)(data + start * elsize);
562+
if (!Sleef_iunordq1(min_val, min_val)) {
560563
*min_ind = start;
561564
break;
562565
}
@@ -571,24 +574,25 @@ quadprec_argmin(char *data, npy_intp n, npy_intp *min_ind, void *arr)
571574
// Find minimum
572575
for (npy_intp i = start + 1; i < n; i++) {
573576
Sleef_quad val = *(Sleef_quad *)(data + i * elsize);
574-
Sleef_quad min_val = *(Sleef_quad *)(data + (*min_ind) * elsize);
575577

576578
// Skip NaN values
577579
if (Sleef_iunordq1(val, val)) {
578580
continue;
579581
}
580582

581583
if (Sleef_icmpltq1(val, min_val)) {
584+
min_val = val;
582585
*min_ind = i;
583586
}
584587
}
585588
}
586589
else {
587590
// Find first non-NaN value as initial min
588591
npy_intp start = 0;
592+
long double min_val;
589593
for (start = 0; start < n; start++) {
590-
long double val = *(long double *)(data + start * elsize);
591-
if (!isnan(val)) {
594+
min_val = *(long double *)(data + start * elsize);
595+
if (!isnan(min_val)) {
592596
*min_ind = start;
593597
break;
594598
}
@@ -603,14 +607,14 @@ quadprec_argmin(char *data, npy_intp n, npy_intp *min_ind, void *arr)
603607
// Find minimum
604608
for (npy_intp i = start + 1; i < n; i++) {
605609
long double val = *(long double *)(data + i * elsize);
606-
long double min_val = *(long double *)(data + (*min_ind) * elsize);
607610

608611
// Skip NaN values
609612
if (isnan(val)) {
610613
continue;
611614
}
612615

613616
if (val < min_val) {
617+
min_val = val;
614618
*min_ind = i;
615619
}
616620
}

tests/test_quaddtype.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5772,5 +5772,15 @@ def test_argmax_argmin(backend):
57725772
# 2D with axis
57735773
x = np.array([[1, 5, 3], [4, 2, 6]], dtype=QuadPrecDType(backend=backend))
57745774
assert np.argmax(x) == 5 # flattened
5775+
assert np.argmin(x) == 0 # flattened
57755776
np.testing.assert_array_equal(np.argmax(x, axis=0), [1, 0, 1])
5776-
np.testing.assert_array_equal(np.argmax(x, axis=1), [1, 2])
5777+
np.testing.assert_array_equal(np.argmin(x, axis=0), [0, 1, 0])
5778+
np.testing.assert_array_equal(np.argmax(x, axis=1), [1, 2])
5779+
np.testing.assert_array_equal(np.argmin(x, axis=1), [0, 1])
5780+
5781+
# Empty array raises ValueError
5782+
x = np.array([], dtype=QuadPrecDType(backend=backend))
5783+
with pytest.raises(ValueError):
5784+
np.argmax(x)
5785+
with pytest.raises(ValueError):
5786+
np.argmin(x)

0 commit comments

Comments
 (0)