Skip to content

[WIP] Spectral-Grassmann OT#792

Open
thibaut-germain wants to merge 31 commits into
PythonOT:masterfrom
thibaut-germain:sgot
Open

[WIP] Spectral-Grassmann OT#792
thibaut-germain wants to merge 31 commits into
PythonOT:masterfrom
thibaut-germain:sgot

Conversation

@thibaut-germain
Copy link
Copy Markdown

Types of changes

Adding sgot file in the ot folder.

Motivation and context / Related issue

Keep track of SGOT implementation in POT.

How has this been tested (if it applies)

Not tested yet.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • [] All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@rflamary rflamary changed the title Sgot [WIP] Spactral-Gromov OT Feb 9, 2026
@rflamary rflamary changed the title [WIP] Spactral-Gromov OT [WIP] Spectral-Grassman OT Feb 9, 2026
@rflamary rflamary changed the title [WIP] Spectral-Grassman OT [WIP] Spectral-Grassmann OT Feb 9, 2026
@codecov
Copy link
Copy Markdown

codecov Bot commented Feb 11, 2026

Codecov Report

❌ Patch coverage is 96.08541% with 11 lines in your changes missing coverage. Please review.
✅ Project coverage is 96.86%. Comparing base (41a4d57) to head (6711f1f).

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #792      +/-   ##
==========================================
- Coverage   96.87%   96.86%   -0.02%     
==========================================
  Files         113      115       +2     
  Lines       23062    23339     +277     
==========================================
+ Hits        22342    22608     +266     
- Misses        720      731      +11     
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @osheasienna and @thibaut-germain this is a nice first step.

Here are below a few comments that we can discuss together

Comment thread ot/sgot.py Outdated
Comment thread ot/sgot.py Outdated
Comment thread ot/sgot.py Outdated
Comment thread ot/sgot.py Outdated
Comment thread ot/sgot.py Outdated
Comment thread test/test_sgot.py Outdated
Comment thread test/test_sgot.py Outdated
Comment thread test/test_sgot.py Outdated
Comment thread RELEASES.md
Comment thread RELEASES.md Outdated
Copy link
Copy Markdown
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few comments from talking together

Comment thread ot/sgot.py Outdated
Comment thread ot/sgot.py Outdated
C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx)

C2 = eta * C_lambda + (1.0 - eta) * C_grass
C = C2 ** (p / 2.0)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
C = C2 ** (p / 2.0)
C = nx.real(C2) ** (p / 2.0)

Comment thread ot/sgot.py Outdated
Comment thread ot/sgot.py
Comment thread test/test_sgot.py Outdated
Comment thread test/test_sgot.py Outdated
Comment thread test/test_sgot.py Outdated
Comment thread test/test_sgot.py Outdated


@pytest.mark.parametrize("backend_name", ["numpy", "torch", "jax"])
def test_cost_backend_consistency(backend_name):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_cost_backend_consistency(backend_name):
def test_cost_backend_consistency(nx):

Comment thread test/test_sgot.py Outdated
Comment thread ot/sgot.py Outdated
Comment thread ot/sgot.py
else:
real_scale, imag_scale = eigen_scaling[0], eigen_scaling[1]

Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale
C_real = nx.real(Dsn)[:,None] - nx.real(Dtn)[None,:]
C_real = C_real**2
C_imag = nx.imag(Dsn)[:,None] - nx.imag(Dtn)[None,:]
C_imag = C_imag**2
prod = C_real + C_imag
return prod ** (q / 2)

Comment thread ot/sgot.py Outdated
A_norm: array-like, shape (d, n)
Column-normalized array.
"""
nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True))
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True))
nrm = nx.norm(A, axis=0, keepdims=True)

You can replace it with the function nx.norm which manages the case of complex number

Comment thread ot/sgot.py Outdated
return delta


def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1e-300):
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Epsilon is too small for the machine precision, you can set it to 1e-12 for instance.

Comment thread ot/sgot.py Outdated
if nx is None:
nx = get_backend(delta)

delta = nx.clip(delta, 0.0, 1.0)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If delta is not in [0,1] it should raise an error, this is an issue in the computation of delta outside of this function.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@osheasienna can you handle this?

Comment thread ot/sgot.py
Copy link
Copy Markdown
Author

@thibaut-germain thibaut-germain left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thibaut's review

Comment thread ot/sgot.py Outdated
# information-geometric interpretation in Germain et al. (2025).
delta2 = nx.maximum(delta**2, eps)
return -nx.log(delta2)
raise ValueError(f"Unknown grassman_metric: {grassman_metric}")
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this function the power q should also be a parameter:
for any distance you can set:
result = square_ditance(delta)
then
return nx.real(result)**(q/2)
Set by default q to the same value as for eigenvalue cost

Comment thread ot/sgot.py Outdated
C_lambda = eigenvalue_cost_matrix(Ds, Dt, q=q, eigen_scaling=eigen_scaling, nx=nx)

delta = _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=nx)
C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the power parameter q should also affect the Grassmann cost

Comment thread ot/sgot.py Outdated
C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx)

C2 = eta * C_lambda + (1.0 - eta) * C_grass
C = nx.real(C2) ** (p / 2.0)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you your cost function already return a real no need for nx.real here

Comment thread ot/sgot.py Outdated
return C


def _validate_sgot_metric_inputs(Ds, Dt):
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add verifications you wrote in line 272-290 in this function and also add verifications than source and target have the same shapes.

Comment thread ot/sgot.py
Copy link
Copy Markdown
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good, just a few more things to do

Comment thread examples/plot_sgot.py Outdated
representation varies under rotation. The SGOT cost and metric are used to
compare the reference and rotated systems.

[1] T. Germain; R. Flamary; V. R. Kostic; K. Lounici, A Spectral-Grassmann Wasserstein Metric for Operator Representations of Dynamical Systems, arXiv preprint arXiv:2509.24920, 2025.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the reference at the end of the readme file and use the same reference number here

Comment thread examples/plot_sgot.py Outdated
Comment thread examples/plot_sgot.py
Comment thread examples/plot_sgot.py Outdated
# Comparison across Grassmannian metrics for SGOT distance versus rotation angle
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

thetas = np.linspace(0, np.pi / 2, 10)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have something more detailed such as

Suggested change
thetas = np.linspace(0, np.pi / 2, 10)
thetas = np.linspace(0, np.pi / 2, 50)

Comment thread ot/sgot.py Outdated
if nx is None:
nx = get_backend(delta)

delta = nx.clip(delta, 0.0, 1.0)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@osheasienna can you handle this?

Comment thread ot/sgot.py
- :math:`p` is the exponent used in the OT ground cost and the inner
Wasserstein root,
- :math:`r` is an additional outer root applied to the Wasserstein objective.
"""
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please ad the references with all authors and number form readme here

Comment thread ot/sgot.py Outdated

References
----------
Germain et al., *Spectral-Grassmann Optimal Transport* (SGOT).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

detail reference here (with all informations

osheasienna and others added 4 commits May 14, 2026 18:35
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
@github-actions github-actions Bot added the CI label May 15, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants