Skip to content

Commit b775654

Browse files
committed
Add a simply unit test with a GH workflow
1 parent f4d4027 commit b775654

File tree

5 files changed

+116
-2
lines changed

5 files changed

+116
-2
lines changed

.github/requirements-old.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Ensure changes to these dependencies are reflected in pyproject.toml
2+
attrs==24.2.0
3+
cattrs==24.1.1
4+
numpy==2.3.1
5+
scipy==1.16.0
6+
matplotlib==3.10.3

.github/workflows/pytest.yaml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
name: pytest
2+
on:
3+
push:
4+
branches:
5+
# Run tests for change on the main branch ...
6+
- main
7+
tags-ignore:
8+
# ... but not for tags (avoids duplicate work).
9+
- '**'
10+
pull_request:
11+
branches:
12+
# Run tests for PRs targeting the main branch.
13+
- main
14+
15+
jobs:
16+
tests:
17+
strategy:
18+
matrix:
19+
os: [ubuntu-latest, macos-latest, windows-latest]
20+
python-version: ["3.13", "3.14"]
21+
runs-on: ${{ matrix.os }}
22+
steps:
23+
- uses: actions/checkout@v4
24+
- name: Set up Python ${{ matrix.python-version }}
25+
uses: actions/setup-python@v5
26+
with:
27+
python-version: ${{ matrix.python-version }}
28+
- name: Install oldest versions of supported dependencies
29+
if: ${{ matrix.python-version == '3.13'}}
30+
run: pip install -r .github/requirements-old.txt
31+
- name: Install package with test dependencies
32+
run: pip install -e .[tests]
33+
- name: Run pytest
34+
run: pytest -vv

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ classifiers = [
2424
"Topic :: Scientific/Engineering :: Physics",
2525
]
2626
dependencies = [
27+
# Ensure changes to these dependencies are reflected in .github/requirements-old.txt
2728
"attrs>=24.2.0",
2829
"cattrs>=24.1.1",
2930
"numpy>=2.3.1",
@@ -32,6 +33,11 @@ dependencies = [
3233
]
3334
dynamic = ["version"]
3435

36+
[project.optional-dependencies]
37+
tests = [
38+
"pytest",
39+
]
40+
3541
[project.urls]
3642
Documentation = "https://github.com/molmod/soapboxslide/"
3743
Issues = "https://github.com/molmod/soapboxslide/issues"

soapboxslide.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -692,9 +692,9 @@ def to_file(self, path_npz):
692692
data = attrs.asdict(self)
693693
data["end_state"] = self.end_state.value
694694
data = {key: value for key, value in data.items() if value is not None}
695-
np.savez_compressed(path_npz, **data, allow_pickle=False)
695+
np.savez_compressed(str(path_npz), **data, allow_pickle=False)
696696

697697
@classmethod
698698
def from_file(cls, path_npz):
699-
data = np.load(path_npz)
699+
data = np.load(str(path_npz))
700700
return cls(**data)

test_soapboxslide.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Soap Box Slide is a computational take on soapbox racing.
2+
# © 2025 Toon Verstraelen
3+
#
4+
# This file is part of Soap Box Slide.
5+
#
6+
# Soap Box Slide is free software; you can redistribute it and/or
7+
# modify it under the terms of the GNU General Public License
8+
# as published by the Free Software Foundation; either version 3
9+
# of the License, or (at your option) any later version.
10+
#
11+
# Soap Box Slide is distributed in the hope that it will be useful,
12+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14+
# GNU General Public License for more details.
15+
#
16+
# You should have received a copy of the GNU General Public License
17+
# along with this program; if not, see <http://www.gnu.org/licenses/>
18+
#
19+
# --
20+
"""Unit tests for Soap Box Slide."""
21+
22+
import attrs
23+
import numpy as np
24+
25+
from soapboxslide import EndState, Trajectory
26+
27+
28+
def test_npz_trajectory(tmpdir):
29+
traj = Trajectory(
30+
time=[0, 1, 2],
31+
mass=[1, 1.2],
32+
gamma=[0.3, 0.0],
33+
pos=[
34+
[[0, 0, 0], [1, 0, 0]],
35+
[[0, 1, 0], [1, 1, 0]],
36+
[[0, 2, 0], [1, 2, 0]],
37+
],
38+
vel=[
39+
[[0, 0, 0], [0, 0, 0]],
40+
[[1, 0, 0], [1, 0, 0]],
41+
[[1, 1, 0], [1, 1, 0]],
42+
],
43+
grad=[
44+
[[0.5, 0.1], [1.2, 0.3]],
45+
[[0.6, 0.2], [1.3, 0.4]],
46+
[[0.7, 0.3], [1.4, 0.5]],
47+
],
48+
hess=[
49+
[[0, 0, 0], [0, 0, 0]],
50+
[[1, 0, 0], [1, 0, 0]],
51+
[[1, 1, 0], [1, 1, 0]],
52+
],
53+
spring_idx=[[0, 1]],
54+
spring_par=[[100, 0.5, 1.2]],
55+
end_state=EndState.STOP,
56+
stop_time=30.0,
57+
stop_pos=[[10, 0, 0], [10, 1, 0]],
58+
stop_vel=[[0, 0, 0], [0, 0, 0]],
59+
)
60+
path = tmpdir.join("trajectory.npz")
61+
traj.to_file(path)
62+
print(path)
63+
loaded_traj = Trajectory.from_file(path)
64+
65+
for attr in attrs.fields(Trajectory):
66+
val_orig = getattr(traj, attr.name)
67+
val_loaded = getattr(loaded_traj, attr.name)
68+
assert np.array_equal(val_orig, val_loaded)

0 commit comments

Comments
 (0)