Skip to content

Commit 350397b

Browse files
committed
add test for pack/expand round-trip
1 parent 4e897da commit 350397b

File tree

3 files changed

+97
-2
lines changed

3 files changed

+97
-2
lines changed

src/fitpack_core.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18837,7 +18837,7 @@ pure subroutine FP_REAL_COMM_PACK_2D(array, buffer)
1883718837
integer(FP_SIZE), parameter :: header = 2
1883818838

1883918839
if (allocated(array)) then
18840-
forall (d = 1:2) bnd(:, d) = [lbound(array, d, FP_SIZE), ubound(array, d, FP_SIZE)]
18840+
do d=1,rank(array); bnd(:,d) = [lbound(array, dim=d, kind=FP_SIZE), ubound(array, dim=d, kind=FP_SIZE)]; end do
1884118841
else
1884218842
bnd = FP_NOT_ALLOC
1884318843
end if

test/fitpack_curve_tests.f90

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ module fitpack_curve_tests
4141
public :: test_gridded_sphere
4242
public :: test_parametric_surface
4343
public :: test_fpknot_crash
44+
public :: test_curve_comm_roundtrip
4445

4546

4647
contains
@@ -501,7 +502,7 @@ logical function test_zeros() result(success)
501502

502503
success = .false.
503504

504-
! Try f(x) = x3 Ð 3x2 + 2x = x(x Ð 1)(x Ð 2), with real roots x = 0, x = 1, and x = 2.
505+
! Try f(x) = x3 3x2 + 2x = x(x 1)(x 2), with real roots x = 0, x = 1, and x = 2.
505506
x = linspace(-10.0_FP_REAL,10.0_FP_REAL,20)
506507
y = x**3-3*x**2+2*x
507508

@@ -1244,6 +1245,99 @@ logical function test_parametric_surface(iunit) result(success)
12441245

12451246
end function test_parametric_surface
12461247

1248+
!> Test that pack/expand round-trip preserves curve data
1249+
logical function test_curve_comm_roundtrip() result(success)
1250+
1251+
integer, parameter :: N = 50
1252+
type(fitpack_curve) :: curve1, curve2
1253+
real(FP_REAL) :: x(N), y(N), xtest(10), y1(10), y2(10)
1254+
real(FP_COMM), allocatable :: buffer(:)
1255+
integer(FP_SIZE) :: buf_size
1256+
integer :: ierr, i
1257+
1258+
success = .false.
1259+
1260+
! Generate test data: a simple polynomial
1261+
x = linspace(zero, pi2, N)
1262+
y = sin(x) + half * cos(2 * x)
1263+
1264+
! Create an interpolating curve
1265+
ierr = curve1%new_fit(x, y, smoothing=zero)
1266+
if (.not.FITPACK_SUCCESS(ierr)) then
1267+
print *, '[test_curve_comm_roundtrip] error creating curve: ', FITPACK_MESSAGE(ierr)
1268+
return
1269+
end if
1270+
1271+
! Get buffer size and allocate
1272+
buf_size = curve1%comm_size()
1273+
allocate(buffer(buf_size))
1274+
1275+
! Pack curve into buffer
1276+
call curve1%comm_pack(buffer)
1277+
1278+
! Expand buffer into new curve
1279+
call curve2%comm_expand(buffer)
1280+
1281+
! Generate test points (different from fitting points)
1282+
xtest = linspace(0.1_FP_REAL, pi2 - 0.1_FP_REAL, 10)
1283+
1284+
! Evaluate both curves
1285+
y1 = curve1%eval(xtest)
1286+
y2 = curve2%eval(xtest)
1287+
1288+
! Check that evaluations match
1289+
if (maxval(abs(y1 - y2)) > epsilon(one)) then
1290+
print *, '[test_curve_comm_roundtrip] evaluation mismatch after round-trip'
1291+
print *, ' max difference: ', maxval(abs(y1 - y2))
1292+
return
1293+
end if
1294+
1295+
! Check scalar members
1296+
if (curve1%m /= curve2%m) then
1297+
print *, '[test_curve_comm_roundtrip] m mismatch: ', curve1%m, ' vs ', curve2%m
1298+
return
1299+
end if
1300+
if (curve1%order /= curve2%order) then
1301+
print *, '[test_curve_comm_roundtrip] order mismatch'
1302+
return
1303+
end if
1304+
if (curve1%knots /= curve2%knots) then
1305+
print *, '[test_curve_comm_roundtrip] knots mismatch'
1306+
return
1307+
end if
1308+
if (abs(curve1%smoothing - curve2%smoothing) > epsilon(one)) then
1309+
print *, '[test_curve_comm_roundtrip] smoothing mismatch'
1310+
return
1311+
end if
1312+
if (abs(curve1%fp - curve2%fp) > epsilon(one)) then
1313+
print *, '[test_curve_comm_roundtrip] fp mismatch'
1314+
return
1315+
end if
1316+
1317+
! Check array sizes match
1318+
if (size(curve1%t) /= size(curve2%t)) then
1319+
print *, '[test_curve_comm_roundtrip] t array size mismatch'
1320+
return
1321+
end if
1322+
if (size(curve1%c) /= size(curve2%c)) then
1323+
print *, '[test_curve_comm_roundtrip] c array size mismatch'
1324+
return
1325+
end if
1326+
1327+
! Check knot and coefficient values
1328+
if (maxval(abs(curve1%t - curve2%t)) > epsilon(one)) then
1329+
print *, '[test_curve_comm_roundtrip] t values mismatch'
1330+
return
1331+
end if
1332+
if (maxval(abs(curve1%c - curve2%c)) > epsilon(one)) then
1333+
print *, '[test_curve_comm_roundtrip] c values mismatch'
1334+
return
1335+
end if
1336+
1337+
success = .true.
1338+
1339+
end function test_curve_comm_roundtrip
1340+
12471341
! ODE-style reciprocal error weight
12481342
elemental real(FP_REAL) function rewt(RTOL,ATOL,x)
12491343
real(FP_REAL), intent(in) :: RTOL,ATOL,x

test/test.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ subroutine run_interface_tests()
5656
call add_test(test_gridded_sphere())
5757
call add_test(test_parametric_surface())
5858
call add_test(test_fpknot_crash())
59+
call add_test(test_curve_comm_roundtrip())
5960

6061
end subroutine run_interface_tests
6162

0 commit comments

Comments
 (0)