[WIP] Spectral-Grassmann OT#792
Conversation
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #792 +/- ##
==========================================
- Coverage 96.87% 96.86% -0.02%
==========================================
Files 113 115 +2
Lines 23062 23339 +277
==========================================
+ Hits 22342 22608 +266
- Misses 720 731 +11 🚀 New features to boost your workflow:
|
rflamary
left a comment
There was a problem hiding this comment.
Hello @osheasienna and @thibaut-germain this is a nice first step.
Here are below a few comments that we can discuss together
rflamary
left a comment
There was a problem hiding this comment.
A few comments from talking together
| C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) | ||
|
|
||
| C2 = eta * C_lambda + (1.0 - eta) * C_grass | ||
| C = C2 ** (p / 2.0) |
There was a problem hiding this comment.
| C = C2 ** (p / 2.0) | |
| C = nx.real(C2) ** (p / 2.0) |
|
|
||
|
|
||
| @pytest.mark.parametrize("backend_name", ["numpy", "torch", "jax"]) | ||
| def test_cost_backend_consistency(backend_name): |
There was a problem hiding this comment.
| def test_cost_backend_consistency(backend_name): | |
| def test_cost_backend_consistency(nx): |
| else: | ||
| real_scale, imag_scale = eigen_scaling[0], eigen_scaling[1] | ||
|
|
||
| Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale |
There was a problem hiding this comment.
| Dsn = nx.real(Ds) * real_scale + 1j * nx.imag(Ds) * imag_scale | |
| C_real = nx.real(Dsn)[:,None] - nx.real(Dtn)[None,:] | |
| C_real = C_real**2 | |
| C_imag = nx.imag(Dsn)[:,None] - nx.imag(Dtn)[None,:] | |
| C_imag = C_imag**2 | |
| prod = C_real + C_imag | |
| return prod ** (q / 2) |
| A_norm: array-like, shape (d, n) | ||
| Column-normalized array. | ||
| """ | ||
| nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True)) |
There was a problem hiding this comment.
| nrm = nx.sqrt(nx.sum(A * nx.conj(A), axis=0, keepdims=True)) | |
| nrm = nx.norm(A, axis=0, keepdims=True) |
You can replace it with the function nx.norm which manages the case of complex number
| return delta | ||
|
|
||
|
|
||
| def _grassmann_distance_squared(delta, grassman_metric="chordal", nx=None, eps=1e-300): |
There was a problem hiding this comment.
Epsilon is too small for the machine precision, you can set it to 1e-12 for instance.
| if nx is None: | ||
| nx = get_backend(delta) | ||
|
|
||
| delta = nx.clip(delta, 0.0, 1.0) |
There was a problem hiding this comment.
If delta is not in [0,1] it should raise an error, this is an issue in the computation of delta outside of this function.
| # information-geometric interpretation in Germain et al. (2025). | ||
| delta2 = nx.maximum(delta**2, eps) | ||
| return -nx.log(delta2) | ||
| raise ValueError(f"Unknown grassman_metric: {grassman_metric}") |
There was a problem hiding this comment.
In this function the power q should also be a parameter:
for any distance you can set:
result = square_ditance(delta)
then
return nx.real(result)**(q/2)
Set by default q to the same value as for eigenvalue cost
| C_lambda = eigenvalue_cost_matrix(Ds, Dt, q=q, eigen_scaling=eigen_scaling, nx=nx) | ||
|
|
||
| delta = _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=nx) | ||
| C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) |
There was a problem hiding this comment.
the power parameter q should also affect the Grassmann cost
| C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) | ||
|
|
||
| C2 = eta * C_lambda + (1.0 - eta) * C_grass | ||
| C = nx.real(C2) ** (p / 2.0) |
There was a problem hiding this comment.
you your cost function already return a real no need for nx.real here
| return C | ||
|
|
||
|
|
||
| def _validate_sgot_metric_inputs(Ds, Dt): |
There was a problem hiding this comment.
You can add verifications you wrote in line 272-290 in this function and also add verifications than source and target have the same shapes.
rflamary
left a comment
There was a problem hiding this comment.
Looking good, just a few more things to do
| representation varies under rotation. The SGOT cost and metric are used to | ||
| compare the reference and rotated systems. | ||
|
|
||
| [1] T. Germain; R. Flamary; V. R. Kostic; K. Lounici, A Spectral-Grassmann Wasserstein Metric for Operator Representations of Dynamical Systems, arXiv preprint arXiv:2509.24920, 2025. |
There was a problem hiding this comment.
Add the reference at the end of the readme file and use the same reference number here
| # Comparison across Grassmannian metrics for SGOT distance versus rotation angle | ||
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
|
||
| thetas = np.linspace(0, np.pi / 2, 10) |
There was a problem hiding this comment.
can we have something more detailed such as
| thetas = np.linspace(0, np.pi / 2, 10) | |
| thetas = np.linspace(0, np.pi / 2, 50) |
| if nx is None: | ||
| nx = get_backend(delta) | ||
|
|
||
| delta = nx.clip(delta, 0.0, 1.0) |
| - :math:`p` is the exponent used in the OT ground cost and the inner | ||
| Wasserstein root, | ||
| - :math:`r` is an additional outer root applied to the Wasserstein objective. | ||
| """ |
There was a problem hiding this comment.
please ad the references with all authors and number form readme here
|
|
||
| References | ||
| ---------- | ||
| Germain et al., *Spectral-Grassmann Optimal Transport* (SGOT). |
There was a problem hiding this comment.
detail reference here (with all informations
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Types of changes
Adding sgot file in the ot folder.
Motivation and context / Related issue
Keep track of SGOT implementation in POT.
How has this been tested (if it applies)
Not tested yet.
PR checklist