From 0f96ab0dc62c120691f38a3bb876b0e1b007474e Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Fri, 11 Jun 2021 11:27:23 +0200 Subject: [PATCH 1/7] Add batch implementation of Sinkhorn --- ot/bregman.py | 121 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 100 insertions(+), 21 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index b10effd91..b0bc9decc 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -18,6 +18,7 @@ import numpy as np from scipy.optimize import fmin_l_bfgs_b +from scipy.special import logsumexp from ot.utils import unif, dist, list_to_array from .backend import get_backend @@ -1684,7 +1685,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, verbose=False, + numIterMax=10000, stopThr=1e-9, lazyEvaluation=False, numMaxEntries=1e8, verbose=False, log=False, **kwargs): r''' Solve the entropic regularization optimal transport problem and return the @@ -1723,6 +1724,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', Max number of iterations stopThr : float, optional Stop threshol on error (>0) + lazyEvaluation: boolean, optional + If True, then only calculate the cost matrix by block and return the dual vectors + If False, calculate full cost matrix and return outputs of sinkhorn function. + numMaxEntries: int, optional + Maximum number of entries in the coupling, when lazyEvaluation=True verbose : bool, optional Print information along iterations log : bool, optional @@ -1758,24 +1764,69 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' - + ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = unif(np.shape(X_s)[0]) + a = unif(ns) if b is None: - b = unif(np.shape(X_t)[0]) + b = unif(nt) - M = dist(X_s, X_t, metric=metric) + if not lazyEvaluation: + M = dist(X_s, X_t, metric=metric) + + if log: + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) + return pi, log + else: + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) + return pi - if log: - pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) - return pi, log else: - pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) - return pi + if log: + log_err = [] + + log_a, log_b = np.log(a), np.log(b) + f, g = np.zeros(ns), np.zeros(nt) + + bs, bt = min(int(numMaxEntries / nt), ns), min(int(numMaxEntries / ns), nt) # batch size + range_s, range_t = range(0, ns, bs), range(0, nt, bt) + + lse_f = np.zeros(ns) + lse_g = np.zeros(nt) + + for i_ot in range(numIterMax): + + for i in range_s: + M = dist(X_s[i:i+bs,:], X_t, metric=metric) + lse_f[i:i+bs] = logsumexp(g[None,:] - M / reg, axis=1) + f = log_a - lse_f + + for j in range_t: + M = dist(X_s, X_t[j:j+bt,:], metric=metric) + lse_g[j:j+bt] = logsumexp(f[:,None] - M / reg, axis=0) + g = log_b - lse_g + + if (i_ot + 1) % 10 == 0: + m1 = np.zeros_like(a) + for i in range_s: + M = dist(X_s[i:i+bs,:], X_t, metric=metric) + m1[i:i+bs] = np.exp(f[i:i+bs,None] + g[None,:] - M / reg).sum(1) + err = np.abs(m1 - a).sum() + if log: + log_err.append(err) + if verbose and (i_ot+1) % 100 == 0: + print("Error in marginal at iteration {} = {}".format(i_ot+1, err)) + + if err <= stopThr: + break + + if log: + return (f, g, log_err) + else: + return (f, g) def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, - verbose=False, log=False, **kwargs): + lazyEvaluation=False, numMaxEntries=1e8, verbose=False, log=False, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -1814,6 +1865,11 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Max number of iterations stopThr : float, optional Stop threshol on error (>0) + lazyEvaluation: boolean, optional + If True, then only calculate the cost matrix by block and return the dual vectors + If False, calculate full cost matrix and return outputs of sinkhorn function. + numMaxEntries: int, optional + Maximum number of entries in the coupling, when lazyEvaluation=True verbose : bool, optional Print information along iterations log : bool, optional @@ -1850,22 +1906,45 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' + ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = unif(np.shape(X_s)[0]) + a = unif(ns) if b is None: - b = unif(np.shape(X_t)[0]) + b = unif(nt) - M = dist(X_s, X_t, metric=metric) + if not lazyEvaluation: + M = dist(X_s, X_t, metric=metric) - if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - return sinkhorn_loss, log + if log: + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + return sinkhorn_loss, log + else: + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + return sinkhorn_loss + else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - return sinkhorn_loss + bs = min(int(numMaxEntries / nt), ns) # batch size + range_s = range(0, ns, bs) + if log: + f, g, log_error = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, \ + lazyEvaluation=lazyEvaluation, numMaxEntries=numMaxEntries, verbose=verbose, log=log) + else: + f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, \ + lazyEvaluation=lazyEvaluation, numMaxEntries=numMaxEntries, verbose=verbose, log=log) + + loss = 0 + for i in range_s: + M_block = dist(X_s[i:i+bs,:], X_t, metric=metric) + pi_block = np.exp(f[i:i+bs,None] + g[None,:] - M_block / reg) + loss += np.sum(M_block * pi_block) + + if log: + return loss, log_error + else: + return loss def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): From d36979d5fab625149ba0afef740644d116299660 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Fri, 11 Jun 2021 12:34:53 +0200 Subject: [PATCH 2/7] Reformat to pep8 and modify parameter --- ot/bregman.py | 72 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 43 insertions(+), 29 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index b0bc9decc..903e2910e 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1685,7 +1685,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, lazyEvaluation=False, numMaxEntries=1e8, verbose=False, + numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, log=False, **kwargs): r''' Solve the entropic regularization optimal transport problem and return the @@ -1724,11 +1724,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', Max number of iterations stopThr : float, optional Stop threshol on error (>0) - lazyEvaluation: boolean, optional + isLazy: boolean, optional If True, then only calculate the cost matrix by block and return the dual vectors If False, calculate full cost matrix and return outputs of sinkhorn function. - numMaxEntries: int, optional - Maximum number of entries in the coupling, when lazyEvaluation=True + batchSize: int or tuple of 2 int, optional + Shape of the block of cost matrix, when isLazy=True verbose : bool, optional Print information along iterations log : bool, optional @@ -1770,7 +1770,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if b is None: b = unif(nt) - if not lazyEvaluation: + if not isLazy: M = dist(X_s, X_t, metric=metric) if log: @@ -1787,7 +1787,13 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', log_a, log_b = np.log(a), np.log(b) f, g = np.zeros(ns), np.zeros(nt) - bs, bt = min(int(numMaxEntries / nt), ns), min(int(numMaxEntries / ns), nt) # batch size + if isinstance(batchSize, int): + bs, bt = batchSize, batchSize + elif isinstance(batchSize, tuple) and len(batchSize) == 2: + bs, bt = batchSize[0], batchSize[1] + else: + raise ValueError("Batch size must be in integer or a tuple of two integers") + range_s, range_t = range(0, ns, bs), range(0, nt, bt) lse_f = np.zeros(ns) @@ -1796,26 +1802,26 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', for i_ot in range(numIterMax): for i in range_s: - M = dist(X_s[i:i+bs,:], X_t, metric=metric) - lse_f[i:i+bs] = logsumexp(g[None,:] - M / reg, axis=1) + M = dist(X_s[i:i + bs, :], X_t, metric=metric) + lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1) f = log_a - lse_f for j in range_t: - M = dist(X_s, X_t[j:j+bt,:], metric=metric) - lse_g[j:j+bt] = logsumexp(f[:,None] - M / reg, axis=0) + M = dist(X_s, X_t[j:j + bt, :], metric=metric) + lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0) g = log_b - lse_g if (i_ot + 1) % 10 == 0: m1 = np.zeros_like(a) for i in range_s: - M = dist(X_s[i:i+bs,:], X_t, metric=metric) - m1[i:i+bs] = np.exp(f[i:i+bs,None] + g[None,:] - M / reg).sum(1) + M = dist(X_s[i:i + bs, :], X_t, metric=metric) + m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1) err = np.abs(m1 - a).sum() if log: log_err.append(err) - if verbose and (i_ot+1) % 100 == 0: - print("Error in marginal at iteration {} = {}".format(i_ot+1, err)) + if verbose and (i_ot + 1) % 100 == 0: + print("Error in marginal at iteration {} = {}".format(i_ot + 1, err)) if err <= stopThr: break @@ -1825,8 +1831,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', else: return (f, g) + def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, - lazyEvaluation=False, numMaxEntries=1e8, verbose=False, log=False, **kwargs): + isLazy=False, batchSize=100, verbose=False, log=False, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -1865,11 +1872,11 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Max number of iterations stopThr : float, optional Stop threshol on error (>0) - lazyEvaluation: boolean, optional + isLazy: boolean, optional If True, then only calculate the cost matrix by block and return the dual vectors If False, calculate full cost matrix and return outputs of sinkhorn function. - numMaxEntries: int, optional - Maximum number of entries in the coupling, when lazyEvaluation=True + batchSize: int or tuple of 2 int, optional + Shape of the block of cost matrix, when isLazy=True verbose : bool, optional Print information along iterations log : bool, optional @@ -1912,33 +1919,39 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num if b is None: b = unif(nt) - if not lazyEvaluation: + if not isLazy: M = dist(X_s, X_t, metric=metric) if log: sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) + **kwargs) return sinkhorn_loss, log else: sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) + **kwargs) return sinkhorn_loss - + else: - bs = min(int(numMaxEntries / nt), ns) # batch size + if isinstance(batchSize, int): + bs = batchSize, batchSize + elif isinstance(batchSize, tuple) and len(batchSize) == 2: + bs = batchSize[0] + else: + raise ValueError("Batch size must be in integer or a tuple of two integers") + range_s = range(0, ns, bs) if log: - f, g, log_error = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, \ - lazyEvaluation=lazyEvaluation, numMaxEntries=numMaxEntries, verbose=verbose, log=log) + f, g, log_error = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, + isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) else: - f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, \ - lazyEvaluation=lazyEvaluation, numMaxEntries=numMaxEntries, verbose=verbose, log=log) + f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, + isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) loss = 0 for i in range_s: - M_block = dist(X_s[i:i+bs,:], X_t, metric=metric) - pi_block = np.exp(f[i:i+bs,None] + g[None,:] - M_block / reg) + M_block = dist(X_s[i:i + bs, :], X_t, metric=metric) + pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg) loss += np.sum(M_block * pi_block) if log: @@ -1946,6 +1959,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num else: return loss + def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): r''' From 5ac7712a2e1327d8ee38d70c2cd5f0b1826c9300 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Fri, 11 Jun 2021 12:39:00 +0200 Subject: [PATCH 3/7] Fix error in batch size --- ot/bregman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/bregman.py b/ot/bregman.py index 903e2910e..196e6b5af 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1933,7 +1933,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num else: if isinstance(batchSize, int): - bs = batchSize, batchSize + bs = batchSize elif isinstance(batchSize, tuple) and len(batchSize) == 2: bs = batchSize[0] else: From e1ed50b7e43d4ea84d27e1b2ac5b4a9a2b344e51 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Fri, 11 Jun 2021 15:20:26 +0200 Subject: [PATCH 4/7] Code review and add test --- ot/bregman.py | 84 +++++++++++++++++++++----------------------- test/test_bregman.py | 45 +++++++++++++++++++++++- 2 files changed, 85 insertions(+), 44 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 196e6b5af..5d37cf400 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1725,10 +1725,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', stopThr : float, optional Stop threshol on error (>0) isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return the dual vectors + If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) If False, calculate full cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional - Shape of the block of cost matrix, when isLazy=True + Size of the batcheses used to compute the sinkhorn update without memory overhead. + When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations log : bool, optional @@ -1770,19 +1771,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if b is None: b = unif(nt) - if not isLazy: - M = dist(X_s, X_t, metric=metric) - - if log: - pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) - return pi, log - else: - pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) - return pi - - else: + if isLazy: if log: - log_err = [] + dict_log = {"err": []} log_a, log_b = np.log(a), np.log(b) f, g = np.zeros(ns), np.zeros(nt) @@ -1818,7 +1809,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1) err = np.abs(m1 - a).sum() if log: - log_err.append(err) + dict_log["err"].append(err) if verbose and (i_ot + 1) % 100 == 0: print("Error in marginal at iteration {} = {}".format(i_ot + 1, err)) @@ -1827,10 +1818,22 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', break if log: - return (f, g, log_err) + dict_log["u"] = f + dict_log["v"] = g + return (f, g, dict_log) else: return (f, g) + else: + M = dist(X_s, X_t, metric=metric) + + if log: + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) + return pi, log + else: + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) + return pi + def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, log=False, **kwargs): @@ -1873,10 +1876,11 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num stopThr : float, optional Stop threshol on error (>0) isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return the dual vectors + If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) If False, calculate full cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional - Shape of the block of cost matrix, when isLazy=True + Size of the batcheses used to compute the sinkhorn update without memory overhead. + When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations log : bool, optional @@ -1919,35 +1923,17 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num if b is None: b = unif(nt) - if not isLazy: - M = dist(X_s, X_t, metric=metric) - + if isLazy: if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - return sinkhorn_loss, log - else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - return sinkhorn_loss - - else: - if isinstance(batchSize, int): - bs = batchSize - elif isinstance(batchSize, tuple) and len(batchSize) == 2: - bs = batchSize[0] - else: - raise ValueError("Batch size must be in integer or a tuple of two integers") - - range_s = range(0, ns, bs) - - if log: - f, g, log_error = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, - isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) + f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, + isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) else: f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) + bs = batchSize if isinstance(batchSize, int) else batchSize[0] + range_s = range(0, ns, bs) + loss = 0 for i in range_s: M_block = dist(X_s[i:i + bs, :], X_t, metric=metric) @@ -1955,10 +1941,22 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num loss += np.sum(M_block * pi_block) if log: - return loss, log_error + return loss, dict_log else: return loss + else: + M = dist(X_s, X_t, metric=metric) + + if log: + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + return sinkhorn_loss, log + else: + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + return sinkhorn_loss + def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): diff --git a/test/test_bregman.py b/test/test_bregman.py index 7c5162a9b..56c422ca9 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -301,7 +301,7 @@ def test_empirical_sinkhorn(): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='minkowski') - G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1) + G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, ) sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True) @@ -329,6 +329,49 @@ def test_empirical_sinkhorn(): np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) +def test_lazy_empirical_sinkhorn(): + # test sinkhorn + n = 100 + a = ot.unif(n) + b = ot.unif(n) + numIterMax = 1000 + + X_s = np.reshape(np.arange(n), (n, 1)) + X_t = np.reshape(np.arange(0, n), (n, 1)) + M = ot.dist(X_s, X_t) + M_m = ot.dist(X_s, X_t, metric='minkowski') + + f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1)) + G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) + sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) + + f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True, numIterMax=numIterMax, isLazy=True, batchSize=1) + G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) + sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + + f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1) + G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) + sinkhorn_m = ot.sinkhorn(a, b, M_m, 1) + + loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1) + loss_sinkhorn = ot.sinkhorn2(a, b, M, 1) + + # check constratints + np.testing.assert_allclose( + sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian + np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) + + def test_empirical_sinkhorn_divergence(): # Test sinkhorn divergence n = 10 From 22324fcac16f9d1eef57de5a290d206be501036f Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Fri, 11 Jun 2021 15:23:30 +0200 Subject: [PATCH 5/7] Fix accidental typo in test_empirical_sinkhorn --- test/test_bregman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 56c422ca9..8c49d2abf 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -301,7 +301,7 @@ def test_empirical_sinkhorn(): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='minkowski') - G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, ) + G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1) sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True) From 35aa1e5396040f219bf14e3a97f8d37867d3ac3e Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Fri, 11 Jun 2021 16:07:47 +0200 Subject: [PATCH 6/7] Remove whitespace --- ot/bregman.py | 5 +++-- test/test_bregman.py | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 5d37cf400..105b38be6 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -11,6 +11,7 @@ # Mokhtar Z. Alaya # Alexander Tong # Ievgen Redko +# Quang Huy Tran # # License: MIT License @@ -1728,7 +1729,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) If False, calculate full cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional - Size of the batcheses used to compute the sinkhorn update without memory overhead. + Size of the batcheses used to compute the sinkhorn update without memory overhead. When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations @@ -1879,7 +1880,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) If False, calculate full cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional - Size of the batcheses used to compute the sinkhorn update without memory overhead. + Size of the batcheses used to compute the sinkhorn update without memory overhead. When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations diff --git a/test/test_bregman.py b/test/test_bregman.py index 8c49d2abf..9665229b7 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -2,6 +2,7 @@ # Author: Remi Flamary # Kilian Fatras +# Quang Huy Tran # # License: MIT License @@ -341,11 +342,11 @@ def test_lazy_empirical_sinkhorn(): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='minkowski') - f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1)) + f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) - f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True, numIterMax=numIterMax, isLazy=True, batchSize=1) + f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) @@ -353,7 +354,7 @@ def test_lazy_empirical_sinkhorn(): G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) sinkhorn_m = ot.sinkhorn(a, b, M_m, 1) - loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1) + loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) loss_sinkhorn = ot.sinkhorn2(a, b, M, 1) # check constratints From 28dd2fe2c461ac2287ae51e464f0b002d61f5f31 Mon Sep 17 00:00:00 2001 From: 6Ulm Date: Fri, 11 Jun 2021 16:32:34 +0200 Subject: [PATCH 7/7] Edit config.yml --- .circleci/config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 29c9a0716..e4c71dde1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -73,6 +73,7 @@ jobs: command: | cd docs; make html; + no_output_timeout: 30m # Save the outputs - store_artifacts: