2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
@@ -170,4 +170,4 @@ jobs:

- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
uses: actions/deploy-pages@v4
53 changes: 50 additions & 3 deletions docs/source/conf.py
@@ -486,12 +486,59 @@
def linkcode_resolve(domain, info):
    """Resolve GitHub links for source code.

    This function is required by sphinx.ext.linkcode.
    This function is required by sphinx.ext.linkcode. It uses inspect
    to find the actual source file and line numbers, handling __init__.py
    packages and generating correct URLs without an erroneous src/ prefix.
    """
    import importlib
    import inspect
    import os

    if domain != "py":
        return None
    if not info["module"]:
        return None

    filename = info["module"].replace(".", "/")
    return f"https://github.com/{github_user}/{github_repo}/blob/{github_version}/src/{filename}.py"
    modname = info["module"]
    fullname = info["fullname"]

    try:
        mod = importlib.import_module(modname)
    except ImportError:
        return None

    # Resolve the object from its fully qualified name
    obj = mod
    for part in fullname.split("."):
        try:
            obj = getattr(obj, part)
        except AttributeError:
            return None

    # Unwrap decorated objects to get the original function
    obj = inspect.unwrap(obj)

    try:
        sourcefile = inspect.getsourcefile(obj)
    except TypeError:
        return None

    if sourcefile is None:
        return None

    # Get path relative to the repository root
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    try:
        relpath = os.path.relpath(sourcefile, repo_root)
    except ValueError:
        return None

    # Build the line number anchor
    linespec = ""
    try:
        source, lineno = inspect.getsourcelines(obj)
        linespec = f"#L{lineno}-L{lineno + len(source) - 1}"
    except (OSError, TypeError):
        pass

    return f"https://github.com/{github_user}/{github_repo}/blob/{github_version}/{relpath}{linespec}"
6 changes: 3 additions & 3 deletions docs/source/index.rst
@@ -463,9 +463,9 @@
from spd_learn.models import SPDNet

model = SPDNet(
n_chans=22, # EEG channels
n_outputs=4, # Number of classes
subspacedim=16 # SPD subspace dimension
n_chans=22, # EEG channels
n_outputs=4, # Number of classes
subspacedim=16, # SPD subspace dimension
)

.. raw:: html
Expand Down
11 changes: 11 additions & 0 deletions docs/source/references.bib
@@ -1048,3 +1048,14 @@ @article{yair2019parallel
  year={2019},
  publisher={IEEE}
}
@article{thanwerdas2023,
  title = {O(n)-invariant Riemannian metrics on SPD matrices},
  journal = {Linear Algebra and its Applications},
  volume = {661},
  pages = {163-201},
  year = {2023},
  issn = {0024-3795},
  doi = {10.1016/j.laa.2022.12.009},
  url = {https://www.sciencedirect.com/science/article/pii/S0024379522004360},
  author = {Yann Thanwerdas and Xavier Pennec},
}
134 changes: 83 additions & 51 deletions examples/applied_examples/plot_parallel_transport.py
@@ -112,8 +112,9 @@ def make_symmetric(n: int, batch_size: int | None = None) -> torch.Tensor:
#
# **Key insight**: On flat spaces (Euclidean), vectors don't change during
# transport. On curved manifolds like SPD with AIRM, the vector *rotates*
# as it moves. LEM and Log-Cholesky flatten the manifold, so transport
# becomes trivial (identity).
# as it moves. LEM and Log-Cholesky flatten the manifold via coordinate
# maps, so transport is trivial *in those coordinates* — but non-trivial
# when expressed in the ambient SPD tangent space.
#

######################################################################
@@ -221,60 +222,82 @@ def airm_inner_product(U, V, P):
print(f"\nReconstruction error: {torch.norm(V0 - V_recovered):.2e}")

######################################################################
# Why LEM and Log-Cholesky Have Identity Transport
# ------------------------------------------------
# LEM and Log-Cholesky: Non-trivial Ambient Transport
# ---------------------------------------------------
#
# Under the Log-Euclidean Metric (LEM), the SPD manifold becomes **flat**
# via the matrix logarithm diffeomorphism. In that log-domain (flat space),
# parallel transport is the identity under the canonical identification of
# tangent spaces. :cite:p:`arsigny2007geometric`
# via the matrix logarithm diffeomorphism. Transport is trivial *in the
# log-domain*, but when we work with tangent vectors in the ambient SPD
# space (symmetric matrices), the coordinate change introduces a non-trivial
# map. :cite:p:`arsigny2007geometric,thanwerdas2023`
#
# The LEM transport formula is:
#
# .. math::
#
# \Gamma_{P \to Q}^{LEM}(V) = V
# \Gamma_{P \to Q}^{LEM}(V) = D\exp(\log Q)\bigl[D\log(P)[V]\bigr]
#
# The same applies to the Log-Cholesky metric, which uses the Cholesky
# decomposition to create a flat geometry in the Cholesky-log coordinates.
# :cite:p:`lin2019riemannian`
# where :math:`D\log(P)` and :math:`D\exp(\log Q)` are Fréchet derivatives.
# This maps the tangent vector into the flat log-space (where transport is
# identity), then maps back to the ambient tangent space at Q.
#
# This is computationally efficient (O(1) transport) but means these
# metrics don't capture the same geometric structure as AIRM.
# Similarly, the Log-Cholesky metric uses the Cholesky decomposition and
# log-diagonal map to create flat coordinates.
# :cite:p:`lin2019riemannian`
# Transport is identity in those coordinates, but the pullback/pushforward
# through the Cholesky factorization makes it non-trivial in the ambient space.
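#
# As a rough, self-contained illustration of the LEM formula above (a sketch
# only, not the library's implementation; the package also exports
# ``frechet_derivative_log`` / ``frechet_derivative_exp`` helpers for this,
# whose signatures are not shown here), the two Frechet derivatives can be
# written with the Daleckii-Krein formula through an eigendecomposition:


def _divided_differences(f_vals, f_prime, vals):
    """First divided differences of a spectral function (Daleckii-Krein)."""
    den = vals[:, None] - vals[None, :]
    num = f_vals[:, None] - f_vals[None, :]
    safe_den = torch.where(den.abs() > 1e-12, den, torch.ones_like(den))
    # Off-diagonal: (f(a_i) - f(a_j)) / (a_i - a_j); (near-)equal eigenvalues: f'(a_i)
    return torch.where(den.abs() > 1e-12, num / safe_den, f_prime[:, None])


def dlog_spd_sketch(P_mat, V_tan):
    """Frechet derivative of the matrix log at SPD P_mat, applied to V_tan."""
    lam, U = torch.linalg.eigh(P_mat)
    coeffs = _divided_differences(torch.log(lam), 1.0 / lam, lam)
    return U @ (coeffs * (U.T @ V_tan @ U)) @ U.T


def dexp_sym_sketch(S_mat, W_tan):
    """Frechet derivative of the matrix exp at symmetric S_mat, applied to W_tan."""
    mu, U = torch.linalg.eigh(S_mat)
    coeffs = _divided_differences(torch.exp(mu), torch.exp(mu), mu)
    return U @ (coeffs * (U.T @ W_tan @ U)) @ U.T


def lem_transport_sketch(V_tan, P_mat, Q_mat):
    """Gamma_{P->Q}^{LEM}(V) = D exp(log Q)[ D log(P)[V] ]."""
    lam_q, U_q = torch.linalg.eigh(Q_mat)
    log_q = U_q @ torch.diag(torch.log(lam_q)) @ U_q.T
    return dexp_sym_sketch(log_q, dlog_spd_sketch(P_mat, V_tan))


# The sketch should agree with the library's closed-form LEM transport up to
# numerical error.
print(
    "Sketch vs parallel_transport_lem difference: "
    f"{torch.norm(lem_transport_sketch(V0, P, Q) - parallel_transport_lem(V0, P, Q)):.2e}"
)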

# LEM transport is identity
# All three metrics give non-trivial transport for P != Q
V_lem = parallel_transport_lem(V0, P, Q)
print("LEM transport:")
print(f" Original V == Transported V: {torch.allclose(V0, V_lem)}")

# Log-Cholesky transport is also identity
V_chol = parallel_transport_log_cholesky(V0, P, Q)
print("Log-Cholesky transport:")
print(f" Original V == Transported V: {torch.allclose(V0, V_chol)}")

# Compare with AIRM (non-trivial transport)
V_airm = parallel_transport_airm(V0, P, Q)
print("AIRM transport:")
print(f" Original V == Transported V: {torch.allclose(V0, V_airm)}")
print(f" Transport difference norm: {torch.norm(V0 - V_airm):.4f}")

print("Transport results (all non-trivial for P != Q):")
print(f" AIRM - change from V: {torch.norm(V0 - V_airm):.4f}")
print(f" LEM - change from V: {torch.norm(V0 - V_lem):.4f}")
print(f" L-Chol - change from V: {torch.norm(V0 - V_chol):.4f}")

# Self-transport (P -> P) should be identity for all metrics
V_lem_self = parallel_transport_lem(V0, P, P)
V_chol_self = parallel_transport_log_cholesky(V0, P, P)
print("\nSelf-transport (P -> P) is identity:")
print(f" LEM: {torch.allclose(V0, V_lem_self, atol=1e-5)}")
print(f" Log-Cholesky: {torch.allclose(V0, V_chol_self, atol=1e-5)}")

######################################################################
# Comparing Transport Methods
# ---------------------------
#
# SPD Learn provides several transport methods with different trade-offs:
#
# +----------------+------------------+-----------+---------------------------+
# | Method | Formula | Complexity| Notes |
# +================+==================+===========+===========================+
# | AIRM | :math:`EVE^T` | O(n³) | Exact, preserves geometry |
# +----------------+------------------+-----------+---------------------------+
# | LEM | :math:`V` | O(1) | Identity (flat geometry) |
# +----------------+------------------+-----------+---------------------------+
# | Log-Cholesky | :math:`V` | O(1) | Identity (flat geometry) |
# +----------------+------------------+-----------+---------------------------+
# | Schild's ladder| Iterative | O(k·n³) | ~O(1/k²) (small steps) |
# +----------------+------------------+-----------+---------------------------+
# | Pole ladder | Single iteration | O(n³) | O(h²) (small distance) |
# +----------------+------------------+-----------+---------------------------+
# .. list-table::
# :header-rows: 1
# :widths: 16 20 11 27
#
# * - Method
# - Formula
# - Complexity
# - Notes
# * - AIRM
# - :math:`EVE^T`
# - O(n³)
# - Exact, preserves geometry
# * - LEM
# - Fréchet derivatives
# - O(n³)
# - Via D_exp(log Q)[D_log(P)]
# * - Log-Cholesky
# - Cholesky pull/push
# - O(n³)
# - Via log-diagonal map
# * - Schild's ladder
# - Iterative
# - O(k·n³)
# - ~O(1/k²) (small steps)
# * - Pole ladder
# - Single iteration
# - O(n³)
# - O(h²) (small distance)
#
# The ``transport_tangent_vector`` function provides a unified interface:

@@ -283,9 +306,10 @@ def airm_inner_product(U, V, P):
V_lem = transport_tangent_vector(V0, P, Q, metric="lem")
V_chol = transport_tangent_vector(V0, P, Q, metric="log_cholesky")

print("Transport results by metric:")
print(f" AIRM vs LEM difference: {torch.norm(V_airm - V_lem):.4f}")
print(f" LEM vs Log-Cholesky difference: {torch.norm(V_lem - V_chol):.4f}")
print("Transport results by metric (all differ from each other):")
print(f" AIRM vs LEM difference: {torch.norm(V_airm - V_lem):.4f}")
print(f" AIRM vs Log-Cholesky difference: {torch.norm(V_airm - V_chol):.4f}")
print(f" LEM vs Log-Cholesky difference: {torch.norm(V_lem - V_chol):.4f}")

######################################################################
# Numerical Approximations: Schild's and Pole Ladder
@@ -425,10 +449,10 @@ def benchmark(func, *args, n_runs=n_trials):
time_pole = benchmark(pole_ladder, V_time, P_time, Q_time)

print(f"\nTiming comparison ({n_timing}x{n_timing} matrices, {n_trials} runs):")
print(f" AIRM (exact): {time_airm:.3f} ms")
print(f" LEM (identity): {time_lem:.3f} ms")
print(f" Schild's (10 steps): {time_schild_10:.3f} ms")
print(f" Pole ladder: {time_pole:.3f} ms")
print(f" AIRM (congruence): {time_airm:.3f} ms")
print(f" LEM (Frechet derivs): {time_lem:.3f} ms")
print(f" Schild's (10 steps): {time_schild_10:.3f} ms")
print(f" Pole ladder: {time_pole:.3f} ms")

######################################################################
# Application: Cross-Subject EEG Transfer
@@ -523,14 +547,15 @@ def benchmark(func, *args, n_runs=n_trials):
#
# 1. **Need affine invariance?** → Use AIRM transport
#
# 2. **Speed critical?** → Use LEM (identity transport, O(1))
# 2. **Need consistency with pyRiemann?** → LEM and Log-Cholesky now
# match pyRiemann's non-trivial transport conventions
#
# 3. **No closed-form available?** → Use pole ladder for small distances
#
# 4. **High accuracy needed?** → Use Schild's ladder with many steps
#
# 5. **Need gradients through reference points?** → Use functional AIRM
# transport (``parallel_transport_airm``)
# 5. **Need gradients through reference points?** → All three closed-form
# methods (AIRM, LEM, Log-Cholesky) support autograd

# Gradient flow demonstration
P_grad = make_spd(n)
@@ -616,12 +641,16 @@ def benchmark(func, *args, n_runs=n_trials):
# In this tutorial, we covered:
#
# - **Parallel transport** moves tangent vectors while preserving geometry
# - **AIRM** has non-trivial transport (:math:`EVE^T`); LEM/Log-Cholesky
# have identity transport due to flat geometry
# - **AIRM** uses the congruence formula :math:`EVE^T`
# - **LEM** uses Fréchet derivatives of matrix log/exp
# :cite:p:`thanwerdas2023`
# - **Log-Cholesky** uses Cholesky pullback/pushforward
# :cite:p:`lin2019riemannian`
# - All three are **non-trivial** in the ambient SPD tangent space (flat
# geometry only applies in their respective coordinate systems)
# - **Numerical methods** (Schild's and pole ladder) approximate transport
# when closed-form solutions are unavailable or expensive
# - **Cross-subject transfer** is a key application for BCI domain adaptation
# - Choose your method based on accuracy, speed, and invariance requirements
#
# See Also
# --------
@@ -630,6 +659,9 @@
#
# - :func:`spd_learn.functional.parallel_transport_airm`
# - :func:`spd_learn.functional.parallel_transport_lem`
# - :func:`spd_learn.functional.parallel_transport_log_cholesky`
# - :func:`spd_learn.functional.frechet_derivative_log`
# - :func:`spd_learn.functional.frechet_derivative_exp`
# - :func:`spd_learn.functional.schild_ladder`
# - :func:`spd_learn.functional.pole_ladder`
# - :func:`spd_learn.functional.transport_tangent_vector`
4 changes: 4 additions & 0 deletions spd_learn/functional/__init__.py
@@ -41,6 +41,7 @@
)
from .covariance import covariance, cross_covariance, real_covariance, sample_covariance
from .dropout import dropout_spd
from .frechet import frechet_derivative_exp, frechet_derivative_log
from .metrics import (
    # AIRM metric
    airm_distance,
@@ -165,6 +166,9 @@
"compute_gabor_wavelet",
# Dropout
"dropout_spd",
# Frechet derivatives
"frechet_derivative_log",
"frechet_derivative_exp",
# Autograd utilities
"modeig_backward",
"modeig_forward",