diff --git a/CITATION.cff b/CITATION.cff index 5fe2a31ab..513398eb7 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -81,10 +81,17 @@ authors: - given-names: David family-names: Coeurjolly affiliation: CNRS, LIRIS + - given-names: Thibaut + family-names: Germain + affiliation: Ecole Polytechnique + - given-names: Sienna + family-names: O'Shea + affiliation: Ecole Polytechnique - given-names: Marco family-names: Corneli affiliation: Université Côte d'Azur - - given-names: Ferdinand Genans + - given-names: Ferdinand + family-names: Genans affiliation: Sorbonne Université, LPSM, CNRS identifiers: - type: url diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index aa28dd7ab..08f481b33 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -59,6 +59,8 @@ The contributors to this library are: * [Julie Delon](https://judelo.github.io/) (GMM OT) * [Samuel Boïté](https://samuelbx.github.io/) (GMM OT) * [Nathan Neike](https://github.com/nathanneike) (Sparse EMD solver) +* [Thibaut Germain](https://thibaut-germain.github.io) (SGOT) +* Sienna O'Shea (SGOT) ## Acknowledgments diff --git a/RELEASES.md b/RELEASES.md index f10d7b62b..eea94981d 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -38,6 +38,7 @@ This new release adds support for sparse cost matrices and a new lazy exact OT s - Update the geomloss wrapper to the new version and API (PR #826) - Fix docstrings for `lowrank_gromov_wasserstein_samples` and `lowrank_sinkhorn` (PR #823) - Reorganize all tests per backend (PR #828) +- Update sgot cost function and example (PR #830) #### Closed issues diff --git a/examples/plot_sgot.py b/examples/others/plot_sgot.py similarity index 74% rename from examples/plot_sgot.py rename to examples/others/plot_sgot.py index 0f4ee1dee..4dc3b6a6e 100644 --- a/examples/plot_sgot.py +++ b/examples/others/plot_sgot.py @@ -66,6 +66,10 @@ theta_0 = np.pi / 4 +def rotation_matrix(theta): + return np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + + def generate_data(time, tau, freq, theta): t_ = np.sin(2 * np.pi * freq[None, :] * time[:, None]) * np.exp( -tau[None, :] * time[:, None] @@ -73,26 +77,24 @@ def generate_data(time, tau, freq, theta): t_ = t_.sum(axis=1) traj_0 = np.zeros((t_.shape[0], 2)) traj_0[:, 0] = t_ - rotation_matrix = np.array( - [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] - ) - traj_0 = traj_0 @ rotation_matrix.T + R_ = rotation_matrix(theta) + traj_0 = traj_0 @ R_.T return traj_0 traj_0 = generate_data(time, tau_0, freq_0, theta_0) +traj_0_proj = traj_0 @ rotation_matrix(theta_0)[:, 0] # plot the observed signal components and their sum plt.figure(figsize=(10, 4)) -plt.plot(time, traj_0, label="base trajectory", linewidth=2) +plt.plot(time, traj_0_proj, label="projected trajectory", linewidth=2) plt.xlabel("time") plt.ylabel("amplitude") plt.legend() plt.title(r"Observed scalar signal along $\vec{e}(\theta)$") plt.show() - # %% # 2. Interpret the signal as coming from a continuous linear dynamical system # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -274,12 +276,8 @@ def augment(traj, window_length=2): # Processing Systems, 35, pp.4017-4031. -def estimator(X, Y, rank=4): - # X: (n_samples, n_features) - # Y: (n_samples, n_features) - - # estimate operator - cxx = X.T @ X +def estimator(X, Y, rank=4, eps=1e-8): + cxx = X.T @ X + eps * np.eye(X.shape[1]) U, S, Vt = np.linalg.svd(cxx) S_inv = np.divide(1, S, out=np.zeros_like(S), where=S != 0) cxx_inv_half = Vt.T @ np.diag(np.sqrt(S_inv)) @ U.T @@ -416,6 +414,24 @@ def estimator(X, Y, rank=4): # spectral atoms, taking into account both the location of eigenvalues and the # relative geometry of their eigenspaces. +# %% +# A wider delay window for the SGOT experiments below +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The window of length 4 used above is enough to identify a single reference +# operator, but the experiments below also probe signals whose two modes +# nearly coincide in frequency (e.g. :math:`\omega_2'\to\omega_1`). Telling +# such near-degenerate modes apart requires the delay embedding to span +# enough time to "see" their differing decay, so we re-embed the reference +# signal with a longer window before running the sweeps. + +sgot_window = 10 +Z = augment(traj_0, sgot_window) +_, B_0_spec_sgot = estimator(Z[:-1], Z[1:]) +D_0_sgot = np.log(B_0_spec_sgot["eig_val"]) * fs +L_0_sgot = B_0_spec_sgot["eig_vec_left"] +R_0_sgot = B_0_spec_sgot["eig_vec_right"] + # %% # SGOT distance versus rotation angle # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -434,52 +450,66 @@ def estimator(X, Y, rank=4): # this experiment isolates the effect of rotating the underlying one-dimensional # subspace in the observation plane. -thetas = np.linspace(0, np.pi / 2, 50) -lst = [] -for i, theta in enumerate(thetas): - traj = generate_data(time, tau_0, freq_0, theta) - Z = augment(traj, 4) - X = Z[:-1] - Y = Z[1:] - B, B_spec = estimator(X, Y, rank=4) - D, R, L = B_spec["eig_val"], B_spec["eig_vec_right"], B_spec["eig_vec_left"] - D = np.log(D) * fs - lst.append(sgot_metric(D_0, R_0, L_0, D, R, L, eta=0.01)) - -plt.figure(figsize=(8, 5)) -plt.plot(thetas, lst) -plt.xlabel("theta") -plt.ylabel("SGOT distance") -plt.title("SGOT distance vs rotation angle") +thetas = np.linspace(0, np.pi / 2, 51) +rotation_scores = [] + +for theta in thetas: + Z = augment(generate_data(time, tau_0, freq_0, theta), sgot_window) + B, B_spec = estimator(Z[:-1], Z[1:]) + D = np.log(B_spec["eig_val"]) * fs + L = B_spec["eig_vec_left"] + R = B_spec["eig_vec_right"] + rotation_scores.append( + sgot_metric( + D_0_sgot, R_0_sgot, L_0_sgot, D, R, L, eta=0.9, grassmann_metric="chordal" + ) + ) + +fig, ax = plt.subplots(figsize=(7, 4)) +ax.plot(thetas, rotation_scores, linewidth=1.8) +ax.axvline(theta_0, color="gray", linestyle="--", linewidth=0.8) +ax.set_xlabel(r"Rotation angle $\theta$ (rad)") +ax.set_ylabel(r"$d_S$") +ax.set_title("SGOT distance vs. rotation angle") +fig.tight_layout() plt.show() # %% # Comparison across Grassmannian metrics for SGOT distance versus rotation angle # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -thetas = np.linspace(0, np.pi / 2, 50) -lst = [] -for i, theta in enumerate(thetas): - traj = generate_data(time, tau_0, freq_0, theta) - Z = augment(traj, 4) - X = Z[:-1] - Y = Z[1:] - B, B_spec = estimator(X, Y, rank=4) - D, R, L = B_spec["eig_val"], B_spec["eig_vec_right"], B_spec["eig_vec_left"] - D = np.log(D) * fs - lst1 = [] - for name in ["chordal", "martin", "geodesic", "procrustes"]: - lst1.append(sgot_metric(D_0, R_0, L_0, D, R, L, eta=0.9, grassmann_metric=name)) - lst.append(lst1) -lst2 = np.array(lst) -plt.figure(figsize=(8, 5)) -for i, name in enumerate(["chordal", "martin", "geodesic", "procrustes"]): - plt.plot(thetas, lst2[:, i], label=name) - -plt.xlabel("theta") -plt.ylabel("SGOT distance") -plt.title("SGOT distance vs rotation angle") -plt.legend() +metrics = ["chordal", "geodesic", "procrustes", "martin"] +styles = {"chordal": "-", "geodesic": "--", "procrustes": "-.", "martin": ":"} +rotation_results = {m: [] for m in metrics} + +for theta in thetas: + Z = augment(generate_data(time, tau_0, freq_0, theta), sgot_window) + B, B_spec = estimator(Z[:-1], Z[1:]) + D = np.log(B_spec["eig_val"]) * fs + L = B_spec["eig_vec_left"] + R = B_spec["eig_vec_right"] + for m in metrics: + rotation_results[m].append( + sgot_metric( + D_0_sgot, R_0_sgot, L_0_sgot, D, R, L, eta=0.9, grassmann_metric=m + ) + ) + +fig, ax = plt.subplots(figsize=(7, 4)) +for m in metrics: + ax.plot(thetas, rotation_results[m], styles[m], label=m, linewidth=1.8) +ax.axvline( + theta_0, + color="gray", + linestyle="--", + linewidth=0.8, + label=r"$\theta_0 = \pi/4$ (reference)", +) +ax.set_xlabel(r"Rotation angle $\theta$ (rad)") +ax.set_ylabel(r"$d_S$") +ax.set_title("SGOT distance vs. rotation angle across Grassmannian metrics") +ax.legend() +fig.tight_layout() plt.show() # %% @@ -501,38 +531,38 @@ def estimator(X, Y, rank=4): # distance changes as a function of the perturbed frequency :math:`\omega_2'`. omegas = np.linspace(0.5, 3.0, 21) -methods = ["chordal", "martin", "geodesic", "procrustes"] -scores_omega = [] -theta = theta_0 +frequency_scores = {m: [] for m in metrics} -eta_fixed = 0.9 for omega in omegas: - freq_1 = np.array([freq_0[0], omega]) - traj = generate_data(time, tau_0, freq_1, theta) - Z = augment(traj, 4) - X = Z[:-1] - Y = Z[1:] - - B, B_spec = estimator(X, Y, rank=4) - D, R, L = B_spec["eig_val"], B_spec["eig_vec_right"], B_spec["eig_vec_left"] - D = np.log(D) * fs - - row = [] - for name in methods: - row.append( - sgot_metric(D_0, R_0, L_0, D, R, L, eta=eta_fixed, grassmann_metric=name) + Z = augment( + generate_data(time, tau_0, np.array([freq_0[0], omega]), theta_0), sgot_window + ) + B, B_spec = estimator(Z[:-1], Z[1:]) + D = np.log(B_spec["eig_val"]) * fs + L = B_spec["eig_vec_left"] + R = B_spec["eig_vec_right"] + for m in metrics: + frequency_scores[m].append( + sgot_metric( + D_0_sgot, R_0_sgot, L_0_sgot, D, R, L, eta=0.9, grassmann_metric=m + ) ) - scores_omega.append(row) - -scores_omega = np.array(scores_omega) -plt.figure(figsize=(8, 5)) -for i, name in enumerate(methods): - plt.plot(omegas, scores_omega[:, i], label=name) -plt.xlabel("omega") -plt.ylabel("SGOT distance") -plt.title("SGOT distance vs omega") -plt.legend() +fig, ax = plt.subplots(figsize=(7, 4)) +for m in metrics: + ax.plot(omegas, frequency_scores[m], styles[m], label=m, linewidth=1.8) +ax.axvline( + freq_0[1], + color="gray", + linestyle="--", + linewidth=0.8, + label=r"$\omega_2 = 2.0$ Hz (reference)", +) +ax.set_xlabel(r"Frequency $\omega_2'$ (Hz)") +ax.set_ylabel(r"$d_S$") +ax.set_title("SGOT distance vs. frequency across Grassmannian metrics") +ax.legend() +fig.tight_layout() plt.show() # %% @@ -553,47 +583,28 @@ def estimator(X, Y, rank=4): # In this way, both modes share the same modified decay parameter # :math:`\tau`, allowing us to isolate the influence of dissipation on the SGOT # distance. -decays = np.linspace(0.1, 3.0, 20) # adjust range as needed -methods = ["chordal", "martin", "geodesic", "procrustes"] -scores_decay = [] -theta = theta_0 - -for tau in decays: - freq_1 = np.array([freq_0[0], recovered_freqs[1]]) - tau_1 = np.array([tau, tau]) # or whatever structure your generator expects - - traj = generate_data(time, tau_1, freq_1, theta) - Z = augment(traj, 4) - X = Z[:-1] - Y = Z[1:] - - B, B_spec = estimator(X, Y, rank=4) - D, R, L = B_spec["eig_val"], B_spec["eig_vec_right"], B_spec["eig_vec_left"] - D = np.log(D) * fs - - row = [] - for name in methods: - row.append( +taus = np.linspace(0.1, 3.0, 21) +decay_scores = {m: [] for m in metrics} + +for tau in taus: + Z = augment(generate_data(time, np.array([tau, tau]), freq_0, theta_0), sgot_window) + B, B_spec = estimator(Z[:-1], Z[1:]) + D = np.log(B_spec["eig_val"]) * fs + L = B_spec["eig_vec_left"] + R = B_spec["eig_vec_right"] + for m in metrics: + decay_scores[m].append( sgot_metric( - D_0, - R_0, - L_0, - D, - R, - L, - eta=0.9, # keep eta fixed here - grassmann_metric=name, + D_0_sgot, R_0_sgot, L_0_sgot, D, R, L, eta=0.9, grassmann_metric=m ) ) - scores_decay.append(row) -scores_decay = np.array(scores_decay) -plt.figure(figsize=(8, 5)) -for i, name in enumerate(methods): - plt.plot(decays, scores_decay[:, i], label=name) - -plt.xlabel("decay") -plt.ylabel("SGOT distance") -plt.title("SGOT distance vs decay") -plt.legend() +fig, ax = plt.subplots(figsize=(7, 4)) +for m in metrics: + ax.plot(taus, decay_scores[m], styles[m], label=m, linewidth=1.8) +ax.set_xlabel(r"Decay rate $\tau$") +ax.set_ylabel(r"$d_S$") +ax.set_title("SGOT distance vs. decay across Grassmannian metrics") +ax.legend() +fig.tight_layout() plt.show() diff --git a/ot/sgot.py b/ot/sgot.py index ff67192c5..22ce5798d 100644 --- a/ot/sgot.py +++ b/ot/sgot.py @@ -124,7 +124,7 @@ def _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=None, eps=1e-12): Ltn = _normalize_columns(Lt, nx=nx, eps=eps) Cr = nx.dot(nx.conj(Rsn).T, Rtn) - Cl = nx.dot(nx.conj(Lsn).T, Ltn) + Cl = nx.dot(Lsn.T, nx.conj(Ltn)) delta = nx.abs(Cr * Cl) delta = nx.clip(delta, 0.0, 1.0)