Skip to content

[MRG] fix gpu compatibility of srGW solvers #596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 14, 2024
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -354,4 +354,4 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil

[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems.

[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).
[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).
4 changes: 2 additions & 2 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#### Closed issues
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593)

- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)

## 0.9.2
*December 2023*
Expand Down Expand Up @@ -671,4 +671,4 @@ It provides the following solvers:
* Optimal transport for domain adaptation with group lasso regularization
* Conditional gradient and Generalized conditional gradient for regularized OT.

Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
10 changes: 5 additions & 5 deletions ot/gromov/_semirelaxed.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme
else:
q = nx.sum(G0, 0)
# Check first marginal of G0
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)

constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)

Expand Down Expand Up @@ -363,8 +363,8 @@ def semirelaxed_fused_gromov_wasserstein(
G0 = nx.outer(p, q)
else:
q = nx.sum(G0, 0)
# Check marginals of G0
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
# Check first marginal of G0
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)

constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)

Expand Down Expand Up @@ -703,7 +703,7 @@ def entropic_semirelaxed_gromov_wasserstein(
else:
q = nx.sum(G0, 0)
# Check first marginal of G0
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)

constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)

Expand Down Expand Up @@ -951,7 +951,7 @@ def entropic_semirelaxed_fused_gromov_wasserstein(
else:
q = nx.sum(G0, 0)
# Check first marginal of G0
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)

constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)

Expand Down