From 9ee3a6a34777cd51555e727bf614d60813b3f710 Mon Sep 17 00:00:00 2001 From: HexuanLiu Date: Tue, 5 May 2020 16:14:51 -0700 Subject: [PATCH 1/4] edit dr.py --- ot/dr.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/ot/dr.py b/ot/dr.py index 11d2e101f..6b24af175 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -162,6 +162,16 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None): xc = split_classes(X, y) # compute uniform weighs wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc] + + # pre-compute reg_c,c' + if P0 is not None: + regmean = np.zeros((len(xc), len(xc))) + for i, xi in enumerate(xc): + xi = np.dot(xi, P0) + for j, xj in enumerate(xc[i:]): + xj = np.dot(xj, P0) + M = dist(xi, xj) + regmean[i,j] = np.sum(M)/(len(xi)*len(xj)) def cost(P): # wda loss @@ -173,7 +183,7 @@ def cost(P): for j, xj in enumerate(xc[i:]): xj = np.dot(xj, P) M = dist(xi, xj) - G = sinkhorn(wc[i], wc[j + i], M, reg, k) + G = sinkhorn(wc[i], wc[j + i], M, reg/regmean[i,j], k) if j == 0: loss_w += np.sum(G * M) else: From 57566706be75f579c8b026752fa3c7573e06cbd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 29 Oct 2021 16:55:51 +0200 Subject: [PATCH 2/4] Correct normalization + optional parameter --- ot/dr.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ot/dr.py b/ot/dr.py index 56b28cfaf..e6faa934b 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -108,7 +108,7 @@ def proj(X): return Popt, proj -def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None): +def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize = False): r""" Wasserstein Discriminant Analysis [11]_ @@ -138,6 +138,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None): else should be a pymanopt.solvers P0 : ndarray, shape (d, p) Initial starting point for projection. 
+ normalize : bool, optional + Normalize the Wasserstein distance by the average distance on P0 (default: False) verbose : int, optional Print information along iterations. @@ -164,7 +166,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None): wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc] # pre-compute reg_c,c' - if P0 is not None: + if P0 is not None and normalize: regmean = np.zeros((len(xc), len(xc))) for i, xi in enumerate(xc): xi = np.dot(xi, P0) @@ -172,6 +174,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no xj = np.dot(xj, P0) M = dist(xi, xj) regmean[i,j] = np.sum(M)/(len(xi)*len(xj)) + else: + regmean = np.ones((len(xc), len(xc))) def cost(P): # wda loss @@ -183,7 +187,7 @@ def cost(P): for j, xj in enumerate(xc[i:]): xj = np.dot(xj, P) M = dist(xi, xj) - G = sinkhorn(wc[i], wc[j + i], M, reg/regmean[i,j], k) + G = sinkhorn(wc[i], wc[j + i], M, reg*regmean[i,j], k) if j == 0: loss_w += np.sum(G * M) else: From 578cfc992026db1f56bf94657b4c6c71001cbae0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 29 Oct 2021 17:00:48 +0200 Subject: [PATCH 3/4] pep8? 
--- ot/dr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ot/dr.py b/ot/dr.py index 6855b6f9a..96399f3c5 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -109,7 +109,7 @@ def proj(X): return Popt, proj -def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize = False): +def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False): r""" Wasserstein Discriminant Analysis [11]_ @@ -165,7 +165,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no xc = split_classes(X, y) # compute uniform weighs wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc] - + # pre-compute reg_c,c' if P0 is not None and normalize: regmean = np.zeros((len(xc), len(xc))) @@ -174,7 +174,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no for j, xj in enumerate(xc[i:]): xj = np.dot(xj, P0) M = dist(xi, xj) - regmean[i,j] = np.sum(M)/(len(xi)*len(xj)) + regmean[i, j] = np.sum(M) / (len(xi) * len(xj)) else: regmean = np.ones((len(xc), len(xc))) @@ -188,7 +188,7 @@ def cost(P): for j, xj in enumerate(xc[i:]): xj = np.dot(xj, P) M = dist(xi, xj) - G = sinkhorn(wc[i], wc[j + i], M, reg*regmean[i,j], k) + G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i,j], k) if j == 0: loss_w += np.sum(G * M) else: From 39107b6e5b8b6ac23189439d74b032c105736582 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 29 Oct 2021 17:03:34 +0200 Subject: [PATCH 4/4] final! --- ot/dr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/dr.py b/ot/dr.py index 96399f3c5..74692707f 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -188,7 +188,7 @@ def cost(P): for j, xj in enumerate(xc[i:]): xj = np.dot(xj, P) M = dist(xi, xj) - G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i,j], k) + G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k) if j == 0: loss_w += np.sum(G * M) else: