Skip to content

Commit 49cdbd9

Browse files
committed
deps: SPEC-0 set py3.12+
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 1f72f0e commit 49cdbd9

File tree

9 files changed

+437
-857
lines changed

9 files changed

+437
-857
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
strategy:
2424
fail-fast: false
2525
matrix:
26-
python-version: ["3.11", "3.12", "3.13"]
26+
python-version: ["3.12", "3.13", "3.14"]
2727
runs-on: [ubuntu-latest, macos-latest, windows-latest]
2828

2929
steps:
@@ -50,7 +50,7 @@ jobs:
5050
strategy:
5151
fail-fast: false
5252
matrix:
53-
python-version: ["3.11"]
53+
python-version: ["3.12"]
5454
runs-on: [ubuntu-latest]
5555

5656
steps:

.python-version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.11
1+
3.12

pyproject.toml

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "quaxed"
33
dynamic = ["version"]
44
description = "Pre-quaxed libraries for multiple dispatch over abstract array types in JAX"
55
readme = "README.md"
6-
requires-python = ">=3.11"
6+
requires-python = ">=3.12"
77
authors = [
88
{ name = "Nathaniel Starkman", email = "nstarman@users.noreply.github.com" },
99
]
@@ -16,9 +16,9 @@ classifiers = [
1616
"Programming Language :: Python",
1717
"Programming Language :: Python :: 3",
1818
"Programming Language :: Python :: 3 :: Only",
19-
"Programming Language :: Python :: 3.11",
2019
"Programming Language :: Python :: 3.12",
2120
"Programming Language :: Python :: 3.13",
21+
"Programming Language :: Python :: 3.14",
2222
"Topic :: Scientific/Engineering",
2323
"Typing :: Typed",
2424
]
@@ -76,7 +76,7 @@ nox = [
7676
test = [
7777
"optional-dependencies>=0.4.0",
7878
"pytest>=8.3",
79-
"pytest-cov>= 6.2.1",
79+
"pytest-cov>=6.2.1",
8080
"pytest-env>=1.1.5",
8181
"pytest-github-actions-annotate-failures>=0.3.0", # only applies to GHA
8282
"sybil[pytest]>=9.2.0",
@@ -110,7 +110,7 @@ port.exclude_lines = [
110110

111111
[tool.mypy]
112112
files = ["src"]
113-
python_version = "3.11"
113+
python_version = "3.12"
114114
warn_unused_configs = true
115115
strict = true
116116
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
@@ -130,7 +130,7 @@ warn_return_any = false
130130

131131

132132
[tool.pylint]
133-
py-version = "3.11"
133+
py-version = "3.12"
134134
ignore-paths = [".*/_version.py"]
135135
reports.output-format = "colorized"
136136
similarities.ignore-imports = "yes"
@@ -208,14 +208,22 @@ convention = "numpy"
208208
[tool.uv]
209209
constraint-dependencies = [
210210
"appnope>=0.1.2",
211+
"backcall>=0.2.0",
211212
"bleach>6.0",
212213
"cffi>=1.14",
213214
"decorator>=5.1.1",
215+
"future>=1.0.0",
216+
"iniconfig>=2.0.0",
214217
"matplotlib>=3.7.1",
215218
"matplotlib-inline>=0.1.6",
219+
"nest-asyncio>=1.5.0",
216220
"opt-einsum>=3.2.1",
217221
"pickleshare>=0.7.5",
222+
"ply>3.11",
218223
"psutil>=5.9.0",
224+
"pycparser>=2.20",
219225
"pyparsing>=3.0.0",
220-
"pyzmq>=25.0",
226+
"pyzmq>=26.0",
227+
"scipy>=1.11.2",
228+
"wcwidth>=0.2.0"
221229
]

src/quaxed/_jax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
)
1111

1212
from collections.abc import Callable, Hashable
13-
from typing import Any, TypeAlias
13+
from typing import Any
1414

1515
import jax
1616
from quax import quaxify
1717

18-
AxisName: TypeAlias = Hashable
18+
type AxisName = Hashable
1919

2020

2121
# =============================================================================

src/quaxed/_utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
"""Utility functions for quaxed."""
22

3-
from typing import TypeVar
4-
53
import quax
64

7-
T = TypeVar("T")
8-
95

10-
def quaxify(func: T, *, filter_spec: bool | tuple[bool, ...] = True) -> T:
6+
def quaxify[T](func: T, *, filter_spec: bool | tuple[bool, ...] = True) -> T:
117
"""Quaxify, but makes mypy happy."""
128
return quax.quaxify(func, filter_spec=filter_spec)

src/quaxed/numpy/_higher_order.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import functools
66
import warnings
77
from collections.abc import Callable, Collection
8-
from typing import Any, TypeVar
8+
from typing import Any
99

1010
import equinox as eqx
1111
import jax
@@ -19,10 +19,8 @@
1919

2020
from ._core import asarray, expand_dims as _expand_dims, squeeze
2121

22-
T = TypeVar("T")
2322

24-
25-
def expand_dims(a: T, axis: int | tuple[int, ...]) -> T:
23+
def expand_dims[T](a: T, axis: int | tuple[int, ...]) -> T:
2624
dynamic, static = eqx.partition(a, eqx.is_array_like)
2725
expanded_dynamic = jax.tree.map(lambda x: _expand_dims(x, axis), dynamic)
2826
return eqx.combine(expanded_dynamic, static)

tests/myarray.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Sequence
44
from dataclasses import replace
5-
from typing import Any, Self
5+
from typing import Any, Final, Self
66

77
import equinox as eqx
88
import jax
@@ -18,6 +18,7 @@
1818
from quaxed._types import DType
1919

2020
JAX_VERSION = packaging.version.parse(jax.__version__)
21+
JAX_VERSION_LT_8: Final = packaging.version.Version("0.8.0") > JAX_VERSION
2122

2223

2324
class MyArray(ArrayValue):
@@ -1163,8 +1164,8 @@ def reduce_prod_p(x: MyArray, /, **kw) -> MyArray:
11631164

11641165

11651166
@register(lax.reduce_sum_p)
1166-
def reduce_sum_p(x: MyArray, *, axes: tuple[int, ...]) -> MyArray:
1167-
return replace(x, array=lax.reduce_sum_p.bind(x.array, axes=axes))
1167+
def reduce_sum_p(x: MyArray, **kw) -> MyArray:
1168+
return replace(x, array=lax.reduce_sum_p.bind(x.array, **kw))
11681169

11691170

11701171
# ==============================================================================

tests/test_lax/test_myarray.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Test with JAX inputs."""
22

3-
from typing import TypeAlias
4-
53
import jax.numpy as jnp
64
import jax.tree as jtu
75
import pytest
@@ -13,7 +11,7 @@
1311
from ..conftest import OptDeps
1412
from ..myarray import MyArray
1513

16-
AnyTuple: TypeAlias = tuple[object, ...]
14+
type AnyTuple = tuple[object, ...]
1715

1816
mark_todo = pytest.mark.skip(reason="TODO")
1917
mark_deprecated_jax7 = (

0 commit comments

Comments
 (0)