Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/scripts/SPM_compare_particle_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Compare different discretisations in the particle
#
import argparse
from typing import Any
import numpy as np
import pybamm
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -48,7 +49,7 @@
disc.process_model(model)

# solve model
solutions = [None] * len(models)
solutions: list[Any] = [None] * len(models)
t_eval = np.linspace(0, 3600, 100)
for i, model in enumerate(models):
solutions[i] = model.default_solver.solve(model, t_eval)
Expand Down
21 changes: 11 additions & 10 deletions examples/scripts/SPMe_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@
time += dt

# plot
time_in_seconds = solution["Time [s]"].entries
step_time_in_seconds = step_solution["Time [s]"].entries
voltage = solution["Voltage [V]"].entries
step_voltage = step_solution["Voltage [V]"].entries
plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)")
plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)")
plt.xlabel(r"$t$")
plt.ylabel("Voltage [V]")
plt.legend()
plt.show()
if step_solution is not None:
time_in_seconds = solution["Time [s]"].entries
step_time_in_seconds = step_solution["Time [s]"].entries
voltage = solution["Voltage [V]"].entries
step_voltage = step_solution["Voltage [V]"].entries
plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)")
plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)")
plt.xlabel(r"$t$")
plt.ylabel("Voltage [V]")
plt.legend()
plt.show()
3 changes: 2 additions & 1 deletion examples/scripts/heat_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pybamm
import numpy as np
import matplotlib.pyplot as plt
import numpy.typing as npt

# Numerical solution ----------------------------------------------------------

Expand Down Expand Up @@ -106,7 +107,7 @@ def T_exact(x, t):
# Plot ------------------------------------------------------------------------
x_nodes = mesh["rod"].nodes # numerical gridpoints
xx = np.linspace(0, 2, 101) # fine mesh to plot exact solution
plot_times = np.linspace(0, 1, 5)
plot_times: npt.NDArray = np.linspace(0, 1, 5)

plt.figure(figsize=(15, 8))
cmap = plt.get_cmap("inferno")
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/minimal_example_of_lookup_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def process_2D(name, data):
D_s_n_data = process_2D("Negative particle diffusivity [m2.s-1]", df)


def D_s_n(sto, T):
def D_s_n_func(sto, T):
name, (x, y) = D_s_n_data
return pybamm.Interpolant(x, y, [T, sto], name)


parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n
parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n_func

k_n = parameter_values["Negative electrode exchange-current density [A.m-2]"]

Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ concurrency = ["multiprocessing"]
ignore_missing_imports = true
allow_redefinition = true
disable_error_code = ["call-overload", "operator"]
strict = false
warn_unreachable = true
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
exclude = "^(build/|docs/conf\\.py)$"

[[tool.mypy.overrides]]
module = [
Expand Down
4 changes: 3 additions & 1 deletion src/pybamm/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class Experiment:

def __init__(
self,
operating_conditions: list[str | tuple[str] | BaseStep],
operating_conditions: list[
str | tuple[str, ...] | tuple[str | BaseStep] | BaseStep
],
period: str | None = None,
temperature: float | None = None,
termination: list[str] | None = None,
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/experiment/step/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
self.value = pybamm.Interpolant(
t,
y,
pybamm.t - pybamm.InputParameter("start time"),
pybamm.t - pybamm.InputParameter("start time"), # type: ignore[arg-type]
name="Drive Cycle",
)
self.period = np.diff(t).min()
Expand Down
47 changes: 43 additions & 4 deletions src/pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Binary operator classes
#
from __future__ import annotations
import numbers

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -35,7 +34,7 @@ def _preprocess_binary(
right = pybamm.Vector(right)

# Check both left and right are pybamm Symbols
if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)):
if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)): # type: ignore[redundant-expr]
raise NotImplementedError(
f"BinaryOperator not implemented for symbols of type {type(left)} and {type(right)}"
)
Expand Down Expand Up @@ -114,6 +113,9 @@ def __str__(self):
right_str = f"{self.right!s}"
return f"{left_str} {self.name} {right_str}"

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return self.__class__(self.name, left, right) # pragma: no cover

def create_copy(
self,
new_children: list[pybamm.Symbol] | None = None,
Expand All @@ -128,7 +130,7 @@ def create_copy(
children = self._children_for_copying(new_children)

if not perform_simplifications:
out = self.__class__(children[0], children[1])
out = self._new_instance(children[0], children[1])
else:
# creates a new instance using the overloaded binary operator to perform
# additional simplifications, rather than just calling the constructor
Expand Down Expand Up @@ -225,6 +227,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("**", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Power(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply chain rule and power rule
Expand Down Expand Up @@ -274,6 +279,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("+", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Addition(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
return self.left.diff(variable) + self.right.diff(variable)
Expand Down Expand Up @@ -301,6 +309,9 @@ def __init__(

super().__init__("-", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Subtraction(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
return self.left.diff(variable) - self.right.diff(variable)
Expand Down Expand Up @@ -330,6 +341,9 @@ def __init__(

super().__init__("*", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Multiplication(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply product rule
Expand Down Expand Up @@ -370,6 +384,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("@", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return MatrixMultiplication(left, right) # pragma: no cover

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
# We shouldn't need this
Expand Down Expand Up @@ -419,6 +436,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("/", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Division(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply quotient rule
Expand Down Expand Up @@ -467,6 +487,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("inner product", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Inner(left, right) # pragma: no cover

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply product rule
Expand Down Expand Up @@ -544,6 +567,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("==", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Equality(left, right)

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
# Equality should always be multiplied by something else so hopefully don't
Expand Down Expand Up @@ -601,6 +627,10 @@ def __init__(
):
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__(name, left, right)
self.name = name

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return _Heaviside(self.name, left, right) # pragma: no cover

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
Expand Down Expand Up @@ -679,6 +709,9 @@ def __init__(
):
super().__init__("%", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Modulo(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply chain rule and power rule
Expand Down Expand Up @@ -721,6 +754,9 @@ def __init__(
):
super().__init__("minimum", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Minimum(left, right)

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
return f"minimum({self.left!s}, {self.right!s})"
Expand Down Expand Up @@ -765,6 +801,9 @@ def __init__(
):
super().__init__("maximum", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Maximum(left, right)

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
return f"maximum({self.left!s}, {self.right!s})"
Expand Down Expand Up @@ -1539,7 +1578,7 @@ def source(
corresponding to a source term in the bulk.
"""
# Broadcast if left is number
if isinstance(left, numbers.Number):
if isinstance(left, (int, float)):
left = pybamm.PrimaryBroadcast(left, "current collector")

# force type cast for mypy
Expand Down
17 changes: 14 additions & 3 deletions src/pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ def _from_json(cls, snippet):
)

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)
pass # pragma: no cover


class PrimaryBroadcast(Broadcast):
Expand Down Expand Up @@ -191,6 +190,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
return self.orphans[0]

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class PrimaryBroadcastToEdges(PrimaryBroadcast):
"""A primary broadcast onto the edges of the domain."""
Expand Down Expand Up @@ -321,6 +324,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
return self.orphans[0]

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class SecondaryBroadcastToEdges(SecondaryBroadcast):
"""A secondary broadcast onto the edges of a domain."""
Expand Down Expand Up @@ -438,6 +445,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
raise NotImplementedError

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class TertiaryBroadcastToEdges(TertiaryBroadcast):
"""A tertiary broadcast onto the edges of a domain."""
Expand All @@ -463,7 +474,7 @@ def __init__(
self,
child_input: Numeric | pybamm.Symbol,
broadcast_domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
auxiliary_domains: AuxiliaryDomainType | str = None,
broadcast_domains: DomainsType = None,
name: str | None = None,
):
Expand Down
6 changes: 3 additions & 3 deletions src/pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def __init__(self, *children, name: Optional[str] = None):
if name is None:
# Name is the intersection of the children names (should usually make sense
# if the children have been named consistently)
name = intersect(children[0].name, children[1].name)
name = intersect(children[0].name, children[1].name) or ""
for child in children[2:]:
name = intersect(name, child.name)
if len(name) == 0:
Expand Down Expand Up @@ -515,7 +515,7 @@ def substrings(s: str):
yield s[i : j + 1]


def intersect(s1: str, s2: str):
def intersect(s1: str, s2: str) -> str:
# find all the common strings between two strings
all_intersects = set(substrings(s1)) & set(substrings(s2))
# intersect is the longest such intercept
Expand All @@ -526,7 +526,7 @@ def intersect(s1: str, s2: str):
return intersect.lstrip().rstrip()


def simplified_concatenation(*children, name: Optional[str] = None):
def simplified_concatenation(*children, name=None):
"""Perform simplifications on a concatenation."""
# remove children that are None
children = list(filter(lambda x: x is not None, children))
Expand Down
5 changes: 3 additions & 2 deletions src/pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy.typing as npt
from scipy import special
import sympy
from typing import Callable
from typing import Callable, cast
from collections.abc import Sequence
from typing_extensions import TypeVar

Expand All @@ -33,7 +33,7 @@ class Function(pybamm.Symbol):
def __init__(
self,
function: Callable,
*children: pybamm.Symbol,
*children: pybamm.Symbol | float | int,
name: str | None = None,
differentiated_function: Callable | None = None,
):
Expand All @@ -43,6 +43,7 @@ def __init__(
if isinstance(child, (float, int, np.number)):
children[idx] = pybamm.Scalar(child)

children = cast(Sequence[pybamm.Symbol], children)
if name is not None:
self.name = name
else:
Expand Down
Loading
Loading