From 549b95b5736b42f3fe74daf9805303a08b1ae01d Mon Sep 17 00:00:00 2001 From: tvayer Date: Tue, 28 May 2019 16:08:41 +0200 Subject: [PATCH 01/19] FGW+gromov changes --- README.md | 2 + examples/plot_fgw.py | 152 +++++++++++++++++++++ ot/bregman.py | 2 +- ot/gromov.py | 310 ++++++++++++++++++++++++++++++++++++++++--- ot/optim.py | 102 +++++++++++++- 5 files changed, 546 insertions(+), 22 deletions(-) create mode 100644 examples/plot_fgw.py diff --git a/README.md b/README.md index a22306d8e..be88f6571 100644 --- a/README.md +++ b/README.md @@ -219,3 +219,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). + +[18] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML). diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py new file mode 100644 index 000000000..5c2d0e190 --- /dev/null +++ b/examples/plot_fgw.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +============================== +Plot Fused-gromov-Wasserstein +============================== + +This example illustrates the computation of FGW for 1D measures[18]. + +.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + +""" + +# Author: Titouan Vayer +# +# License: MIT License + +import matplotlib.pyplot as pl +import numpy as np +import ot +from ot.gromov import gromov_wasserstein,fused_gromov_wasserstein + +#%% parameters +# We create two 1D random measures +n=20 +n2=30 +sig=1 +sig2=0.1 + +np.random.seed(0) + +phi=np.arange(n)[:,None] +xs=phi+sig*np.random.randn(n,1) +ys=np.vstack((np.ones((n//2,1)),0*np.ones((n//2,1))))+sig2*np.random.randn(n,1) + +phi2=np.arange(n2)[:,None] +xt=phi2+sig*np.random.randn(n2,1) +yt=np.vstack((np.ones((n2//2,1)),0*np.ones((n2//2,1))))+sig2*np.random.randn(n2,1) +yt= yt[::-1,:] + +p=ot.unif(n) +q=ot.unif(n2) + +#%% plot the distributions + +pl.close(10) +pl.figure(10,(7,7)) + +pl.subplot(2,1,1) + +pl.scatter(ys,xs,c=phi,s=70) +pl.ylabel('Feature value a',fontsize=20) +pl.title('$\mu=\sum_i \delta_{x_i,a_i}$',fontsize=25, usetex=True, y=1) +pl.xticks(()) +pl.yticks(()) +pl.subplot(2,1,2) +pl.scatter(yt,xt,c=phi2,s=70) +pl.xlabel('coordinates x/y',fontsize=25) +pl.ylabel('Feature value b',fontsize=20) +pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$',fontsize=25, usetex=True, y=1) +pl.yticks(()) +pl.tight_layout() +pl.show() + + +#%% Structure matrices and across-features distance matrix +C1=ot.dist(xs) +C2=ot.dist(xt).T +M=ot.dist(ys,yt) +w1=ot.unif(C1.shape[0]) +w2=ot.unif(C2.shape[0]) +Got=ot.emd([],[],M) + +#%% +cmap='Reds' +pl.close(10) +pl.figure(10,(5,5)) +fs=15 +l_x=[0,5,10,15] +l_y=[0,5,10,15,20,25] +gs = pl.GridSpec(5, 5) + +ax1=pl.subplot(gs[3:,:2]) + +pl.imshow(C1,cmap=cmap,interpolation='nearest') +pl.title("$C_1$",fontsize=fs) +pl.xlabel("$k$",fontsize=fs) +pl.ylabel("$i$",fontsize=fs) +pl.xticks(l_x) +pl.yticks(l_x) + +ax2=pl.subplot(gs[:3,2:]) + +pl.imshow(C2,cmap=cmap,interpolation='nearest') +pl.title("$C_2$",fontsize=fs) +pl.ylabel("$l$",fontsize=fs) +#pl.ylabel("$l$",fontsize=fs) +pl.xticks(()) +pl.yticks(l_y) +ax2.set_aspect('auto') + +ax3=pl.subplot(gs[3:,2:],sharex=ax2,sharey=ax1) +pl.imshow(M,cmap=cmap,interpolation='nearest') +pl.yticks(l_x) +pl.xticks(l_y) +pl.ylabel("$i$",fontsize=fs) +pl.title("$M_{AB}$",fontsize=fs) +pl.xlabel("$j$",fontsize=fs) +pl.tight_layout() +ax3.set_aspect('auto') +pl.show() + + +#%% Computing FGW and GW +alpha=1e-3 + +ot.tic() +Gwg,logw=fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=alpha,verbose=True,log=True) +ot.toc() + +#%reload_ext WGW +Gg,log=gromov_wasserstein(C1,C2,p,q,loss_fun='square_loss',verbose=True,log=True) + +#%% visu OT matrix +cmap='Blues' +fs=15 +pl.figure(2,(13,5)) +pl.clf() +pl.subplot(1,3,1) +pl.imshow(Got,cmap=cmap,interpolation='nearest') +#pl.xlabel("$y$",fontsize=fs) +pl.ylabel("$i$",fontsize=fs) +pl.xticks(()) + +pl.title('Wasserstein ($M$ only)') + +pl.subplot(1,3,2) +pl.imshow(Gg,cmap=cmap,interpolation='nearest') +pl.title('Gromov ($C_1,C_2$ only)') +pl.xticks(()) +pl.subplot(1,3,3) +pl.imshow(Gwg,cmap=cmap,interpolation='nearest') +pl.title('FGW ($M+C_1,C_2$)') + +pl.xlabel("$j$",fontsize=fs) +pl.ylabel("$i$",fontsize=fs) + +pl.tight_layout() +pl.show() \ No newline at end of file diff --git a/ot/bregman.py b/ot/bregman.py index b017c1a85..90404295e 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -5,7 +5,7 @@ # Author: Remi Flamary # Nicolas Courty -# +# Titouan Vayer # License: MIT License import numpy as np diff --git a/ot/gromov.py b/ot/gromov.py index 0278e9921..7491664b2 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -9,17 +9,18 @@ # Author: Erwan Vautier # Nicolas Courty # Rémi Flamary -# +# Titouan Vayer # License: MIT License import numpy as np + from .bregman import sinkhorn from .utils import dist from .optim import cg -def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'): +def init_matrix(C1, C2, p, q, loss_fun='square_loss'): """ Return loss matrices and tensors for Gromov-Wasserstein fast computation Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss @@ -77,16 +78,16 @@ def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'): if loss_fun == 'square_loss': def f1(a): - return (a**2) / 2 + return (a**2) def f2(b): - return (b**2) / 2 + return (b**2) def h1(a): return a def h2(b): - return b + return 2*b elif loss_fun == 'kl_loss': def f1(a): return a * np.log(a + 1e-15) - a @@ -268,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs): return np.exp(np.divide(tmpsum, ppt)) -def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): +def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs): """ Returns the gromov-wasserstein transport between (C1,p) and (C2,q) @@ -306,6 +307,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): Print information along iterations log : bool, optional record log if True + amijo : bool, optional + If True the steps of the line-search is found via an amijo research. Else closed form is used. + If there is convergence issues use False. **kwargs : dict parameters can be directly pased to the ot.optim.cg solver @@ -329,9 +333,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): """ - T = np.eye(len(p), len(q)) - - constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -342,14 +344,79 @@ def df(G): return gwggrad(constC, hC1, hC2, G) if log: - res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs) + res, log = cg(p, q, 0, 1, f, df, G0,log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs) log['gw_dist'] = gwloss(constC, hC1, hC2, res) return res, log else: - return cg(p, q, 0, 1, f, df, G0, **kwargs) + return cg(p, q, 0, 1, f, df, G0,amijo=amijo, **kwargs) + +def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=False,**kwargs): + """ + Computes the FGW distance between two graphs see [3] + .. math:: + \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + s.t. \gamma 1 = p + \gamma^T 1= q + \gamma\geq 0 + where : + - M is the (ns,nt) metric cost matrix + - :math:`f` is the regularization term ( and df is its gradient) + - a and b are source and target weights (sum to 1) + The algorithm used for solving the problem is conditional gradient as discussed in [1]_ + Parameters + ---------- + M : ndarray, shape (ns, nt) + Metric cost matrix between features across domains + C1 : ndarray, shape (ns, ns) + Metric cost matrix respresentative of the structure in the source space + C2 : ndarray, shape (nt, nt) + Metric cost matrix espresentative of the structure in the target space + p : ndarray, shape (ns,) + distribution in the source space + q : ndarray, shape (nt,) + distribution in the target space + loss_fun : string,optionnal + loss function used for the solver + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + amijo : bool, optional + If True the steps of the line-search is found via an amijo research. Else closed form is used. + If there is convergence issues use False. + **kwargs : dict + parameters can be directly pased to the ot.optim.cg solver + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + References + ---------- + .. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + + constC,hC1,hC2=init_matrix(C1,C2,p,q,loss_fun) + + G0=p[:,None]*q[None,:] + + def f(G): + return gwloss(constC,hC1,hC2,G) + def df(G): + return gwggrad(constC,hC1,hC2,G) + + return cg(p,q,M,alpha,f,df,G0,amijo=amijo,C1=C1,C2=C2,constC=constC,**kwargs) -def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs): +def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs): """ Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) @@ -387,7 +454,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs): Print information along iterations log : bool, optional record log if True - + amijo : bool, optional + If True the steps of the line-search is found via an amijo research. Else closed form is used. + If there is convergence issues use False. Returns ------- gw_dist : float @@ -407,9 +476,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs): """ - T = np.eye(len(p), len(q)) - - constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -418,7 +485,7 @@ def f(G): def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs) + res, log = cg(p, q, 0, 1, f, df, G0, log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs) log['gw_dist'] = gwloss(constC, hC1, hC2, res) log['T'] = res if log: @@ -495,7 +562,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, T = np.outer(p, q) # Initialization - constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) cpt = 0 err = 1 @@ -815,3 +882,210 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, cpt += 1 return C + +def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False,p=None,loss_fun='square_loss', + max_iter=100, tol=1e-9,verbose=False,log=True,init_C=None,init_X=None): + + """ + Compute the fgw barycenter as presented eq (5) in [3]. + ---------- + N : integer + Desired number of samples of the target barycenter + Ys: list of ndarray, each element has shape (ns,d) + Features of all samples + Cs : list of ndarray, each element has shape (ns,ns) + Structure matrices of all samples + ps : list of ndarray, each element has shape (ns,) + masses of all samples + lambdas : list of float + list of the S spaces' weights + alpha : float + Alpha parameter for the fgw distance + fixed_structure : bool + Wether to fix the structure of the barycenter during the updates + fixed_features : bool + Wether to fix the feature of the barycenter during the updates + init_C : ndarray, shape (N,N), optional + initialization for the barycenters' structure matrix. If not set random init + init_X : ndarray, shape (N,d), optional + initialization for the barycenters' features. If not set random init + Returns + ---------- + X : ndarray, shape (N,d) + Barycenters' features + C : ndarray, shape (N,N) + Barycenters' structure matrix + log_: + T : list of (N,ns) transport matrices + Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns) + References + ---------- + .. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + S = len(Cs) + d = Ys[0].shape[1] #dimension on the node features + if p is None: + p = np.ones(N)/N + + Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] + Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)] + + lambdas = np.asarray(lambdas, dtype=np.float64) + + if fixed_structure: + if init_C is None: + C=Cs[0] + else: + C=init_C + else: + if init_C is None: + xalea = np.random.randn(N, 2) + C = dist(xalea, xalea) + else: + C = init_C + + if fixed_features: + if init_X is None: + X=Ys[0] + else : + X= init_X + else: + if init_X is None: + X=np.zeros((N,d)) + else: + X = init_X + + T=[np.outer(p,q) for q in ps] + + # X is N,d + # Ys is ns,d + Ms = [np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))] + # Ms is N,ns + + cpt = 0 + err_feature = 1 + err_structure = 1 + + if log: + log_={} + log_['err_feature']=[] + log_['err_structure']=[] + log_['Ts_iter']=[] + + while((err_feature > tol or err_structure > tol) and cpt < max_iter): + Cprev = C + Xprev = X + + if not fixed_features: + Ys_temp=[y.T for y in Ys] + X=update_feature_matrix(lambdas,Ys_temp,T,p) + + # X must be N,d + # Ys must be ns,d + Ms=[np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))] + + if not fixed_structure: + if loss_fun == 'square_loss': + # T must be ns,N + # Cs must be ns,ns + # p must be N,1 + T_temp=[t.T for t in T] + C = update_sructure_matrix(p, lambdas, T_temp, Cs) + + # Ys must be d,ns + # Ts must be N,ns + # p must be N,1 + # Ms is N,ns + # C is N,N + # Cs is ns,ns + # p is N,1 + # ps is ns,1 + + T = [fused_gromov_wasserstein((1-alpha)*Ms[s],C,Cs[s],p,ps[s],loss_fun,alpha,numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] + + # T is N,ns + + log_['Ts_iter'].append(T) + err_feature = np.linalg.norm(X - Xprev.reshape(d,N)) + err_structure = np.linalg.norm(C - Cprev) + + if log: + log_['err_feature'].append(err_feature) + log_['err_structure'].append(err_structure) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err_structure)) + print('{:5d}|{:8e}|'.format(cpt, err_feature)) + + cpt += 1 + log_['T']=T # ce sont les matrices du barycentre de la target vers les Ys + log_['p']=p + log_['Ms']=Ms #Ms sont de tailles N,ns + + return X.T,C,log_ + + +def update_sructure_matrix(p, lambdas, T, Cs): + """ + Updates C according to the L2 Loss kernel with the S Ts couplings + calculated at each iteration + Parameters + ---------- + p : ndarray, shape (N,) + masses in the targeted barycenter + lambdas : list of float + list of the S spaces' weights + T : list of S np.ndarray(ns,N) + the S Ts couplings calculated at each iteration + Cs : list of S ndarray, shape(ns,ns) + Metric cost matrices + Returns + ---------- + C : ndarray, shape (nt,nt) + updated C matrix + """ + tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))]) + ppt = np.outer(p, p) + + return np.divide(tmpsum, ppt) + +def update_feature_matrix(lambdas,Ys,Ts,p): + + """ + Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3] + calculated at each iteration + Parameters + ---------- + p : ndarray, shape (N,) + masses in the targeted barycenter + lambdas : list of float + list of the S spaces' weights + Ts : list of S np.ndarray(ns,N) + the S Ts couplings calculated at each iteration + Ys : list of S ndarray, shape(d,ns) + The features + Returns + ---------- + X : ndarray, shape (d,N) + + References + ---------- + .. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + + p=np.diag(np.array(1/p).reshape(-1,)) + + tmpsum = sum([lambdas[s] * np.dot(Ys[s],Ts[s].T).dot(p) for s in range(len(Ts))]) + + return tmpsum + + diff --git a/ot/optim.py b/ot/optim.py index f31fae2d1..a7748652b 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -4,7 +4,7 @@ """ # Author: Remi Flamary -# +# Titouan Vayer # License: MIT License import numpy as np @@ -71,9 +71,70 @@ def phi(alpha1): return alpha, fc[0], phi1 +def do_linesearch(cost,G,deltaG,Mi,f_val, + amijo=False,C1=None,C2=None,reg=None,Gc=None,constC=None,M=None): + """ + Solve the linesearch in the FW iterations + Parameters + ---------- + cost : method + The FGW cost + G : ndarray, shape(ns,nt) + The transport map at a given iteration of the FW + deltaG : ndarray (ns,nt) + Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration + Mi : ndarray (ns,nt) + Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost + f_val : float + Value of the cost at G + amijo : bool, optionnal + If True the steps of the line-search is found via an amijo research. Else closed form is used. + If there is convergence issues use False. + C1 : ndarray (ns,ns), optionnal + Structure matrix in the source domain. Only used when amijo=False + C2 : ndarray (nt,nt), optionnal + Structure matrix in the target domain. Only used when amijo=False + reg : float, optionnal + Regularization parameter. Corresponds to the alpha parameter of FGW. Only used when amijo=False + Gc : ndarray (ns,nt) + Optimal map found by linearization in the FW algorithm. Only used when amijo=False + constC : ndarray (ns,nt) + Constant for the gromov cost. See [3]. Only used when amijo=False + M : ndarray (ns,nt), optionnal + Cost matrix between the features. Only used when amijo=False + Returns + ------- + alpha : float + The optimal step size of the FW + fc : int + nb of function call. Useless here + f_val : float + The value of the cost for the next iteration + References + ---------- + .. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + if amijo: + alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) + else: # requires symetric matrices + dot1=np.dot(C1,deltaG) + dot12=dot1.dot(C2) + a=-2*reg*np.sum(dot12*deltaG) + b=np.sum((M+reg*constC)*deltaG)-2*reg*(np.sum(dot12*G)+np.sum(np.dot(C1,G).dot(C2)*deltaG)) + c=cost(G) + + alpha=solve_1d_linesearch_quad_funct(a,b,c) + fc=None + f_val=cost(G+alpha*deltaG) + + return alpha,fc,f_val + def cg(a, b, M, reg, f, df, G0=None, numItermax=200, - stopThr=1e-9, verbose=False, log=False): + stopThr=1e-9, verbose=False, log=False,**kwargs): """ Solve the general regularized OT problem with conditional gradient @@ -116,6 +177,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, Print information along iterations log : bool, optional record log if True + kwargs : dict + Parameters for linesearch Returns ------- @@ -177,7 +240,7 @@ def cost(G): deltaG = Gc - G # line search - alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) + alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,**kwargs) G = G + alpha * deltaG @@ -339,3 +402,36 @@ def cost(G): return G, log else: return G + +def solve_1d_linesearch_quad_funct(a,b,c): + """ + Solve on 0,1 the following problem: + .. math:: + \min f(x)=a*x^{2}+b*x+c + + Parameters + ---------- + a,b,c : float + The coefficients of the quadratic function + + Returns + ------- + x : float + The optimal value which leads to the minimal cost + + """ + f0=c + df0=b + f1=a+f0+df0 + + if a>0: # convex + minimum=min(1,max(0,-b/(2*a))) + #print('entrelesdeux') + return minimum + else: # non convexe donc sur les coins + if f0>f1: + #print('sur1 f(1)={}'.format(f(1))) + return 1 + else: + #print('sur0 f(0)={}'.format(f(0))) + return 0 From b1b514f5d9de009e63bd407dfd9c0a0cf6128876 Mon Sep 17 00:00:00 2001 From: tvayer Date: Tue, 28 May 2019 16:50:00 +0200 Subject: [PATCH 02/19] bary fgw --- examples/plot_barycenter_fgw.py | 172 ++++++++++++++++++++++++++++++++ examples/plot_fgw.py | 1 - ot/gromov.py | 15 +-- 3 files changed, 180 insertions(+), 8 deletions(-) create mode 100644 examples/plot_barycenter_fgw.py diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py new file mode 100644 index 000000000..f416629d4 --- /dev/null +++ b/examples/plot_barycenter_fgw.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +""" +================================= +Plot graphs' barycenter using FGW +================================= + +This example illustrates the computation barycenter of labeled graphs using FGW + +Requires networkx >=2 + +.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + +""" + +# Author: Titouan Vayer +# +# License: MIT License + +#%% load libraries +import numpy as np +import matplotlib.pyplot as plt +import networkx as nx +import math +from scipy.sparse.csgraph import shortest_path +import matplotlib.colors as mcol +from matplotlib import cm +from ot.gromov import fgw_barycenters +#%% Graph functions + +def find_thresh(C,inf=0.5,sup=3,step=10): + """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected + Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested. + The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix + and the original matrix. + Parameters + ---------- + C : ndarray, shape (n_nodes,n_nodes) + The structure matrix to threshold + inf : float + The beginning of the linesearch + sup : float + The end of the linesearch + step : integer + Number of thresholds tested + """ + dist=[] + search=np.linspace(inf,sup,step) + for thresh in search: + Cprime=sp_to_adjency(C,0,thresh) + SC=shortest_path(Cprime,method='D') + SC[SC==float('inf')]=100 + dist.append(np.linalg.norm(SC-C)) + return search[np.argmin(dist)],dist + +def sp_to_adjency(C,threshinf=0.2,threshsup=1.8): + """ Thresholds the structure matrix in order to compute an adjency matrix. + All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0 + Parameters + ---------- + C : ndarray, shape (n_nodes,n_nodes) + The structure matrix to threshold + threshinf : float + The minimum value of distance from which the new value is set to 1 + threshsup : float + The maximum value of distance from which the new value is set to 1 + Returns + ------- + C : ndarray, shape (n_nodes,n_nodes) + The threshold matrix. Each element is in {0,1} + """ + H=np.zeros_like(C) + np.fill_diagonal(H,np.diagonal(C)) + C=C-H + C=np.minimum(np.maximum(C,threshinf),threshsup) + C[C==threshsup]=0 + C[C!=0]=1 + + return C + +def build_noisy_circular_graph(N=20,mu=0,sigma=0.3,with_noise=False,structure_noise=False,p=None): + """ Create a noisy circular graph + """ + g=nx.Graph() + g.add_nodes_from(list(range(N))) + for i in range(N): + noise=float(np.random.normal(mu,sigma,1)) + if with_noise: + g.add_node(i,attr_name=math.sin((2*i*math.pi/N))+noise) + else: + g.add_node(i,attr_name=math.sin(2*i*math.pi/N)) + g.add_edge(i,i+1) + if structure_noise: + randomint=np.random.randint(0,p) + if randomint==0: + if i<=N-3: + g.add_edge(i,i+2) + if i==N-2: + g.add_edge(i,0) + if i==N-1: + g.add_edge(i,1) + g.add_edge(N,0) + noise=float(np.random.normal(mu,sigma,1)) + if with_noise: + g.add_node(N,attr_name=math.sin((2*N*math.pi/N))+noise) + else: + g.add_node(N,attr_name=math.sin(2*N*math.pi/N)) + return g + +def graph_colors(nx_graph,vmin=0,vmax=7): + cnorm = mcol.Normalize(vmin=vmin,vmax=vmax) + cpick = cm.ScalarMappable(norm=cnorm,cmap='viridis') + cpick.set_array([]) + val_map = {} + for k,v in nx.get_node_attributes(nx_graph,'attr_name').items(): + val_map[k]=cpick.to_rgba(v) + colors=[] + for node in nx_graph.nodes(): + colors.append(val_map[node]) + return colors + +#%% create dataset +# We build a dataset of noisy circular graphs. +# Noise is added on the structures by random connections and on the features by gaussian noise. + +np.random.seed(30) +X0=[] +for k in range(9): + X0.append(build_noisy_circular_graph(np.random.randint(15,25),with_noise=True,structure_noise=True,p=3)) + +#%% Plot dataset + +plt.figure(figsize=(8,10)) +for i in range(len(X0)): + plt.subplot(3,3,i+1) + g=X0[i] + pos=nx.kamada_kawai_layout(g) + nx.draw(g,pos=pos,node_color = graph_colors(g,vmin=-1,vmax=1),with_labels=False,node_size=100) +plt.suptitle('Dataset of noisy graphs. Color indicates the label',fontsize=20) +plt.show() + + + +#%% +# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph +# Features distances are the euclidean distances +Cs=[shortest_path(nx.adjacency_matrix(x)) for x in X0] +ps=[np.ones(len(x.nodes()))/len(x.nodes()) for x in X0] +Ys=[np.array([v for (k,v) in nx.get_node_attributes(x,'attr_name').items()]).reshape(-1,1) for x in X0] +lambdas=np.array([np.ones(len(Ys))/len(Ys)]).ravel() +sizebary=15 # we choose a barycenter with 15 nodes + +#%% + +A,C,log=fgw_barycenters(sizebary,Ys,Cs,ps,lambdas,alpha=0.95) + +#%% +bary=nx.from_numpy_matrix(sp_to_adjency(C,threshinf=0,threshsup=find_thresh(C,sup=100,step=100)[0])) +for i in range(len(A.ravel())): + bary.add_node(i,attr_name=float(A.ravel()[i])) + +#%% +pos = nx.kamada_kawai_layout(bary) +nx.draw(bary,pos=pos,node_color = graph_colors(bary,vmin=-1,vmax=1),with_labels=False) +plt.suptitle('Barycenter',fontsize=20) +plt.show() + + + + diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py index 5c2d0e190..bfa7fb45a 100644 --- a/examples/plot_fgw.py +++ b/examples/plot_fgw.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ ============================== diff --git a/ot/gromov.py b/ot/gromov.py index 7491664b2..31bd657d4 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -883,8 +883,9 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, return C -def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False,p=None,loss_fun='square_loss', - max_iter=100, tol=1e-9,verbose=False,log=True,init_C=None,init_X=None): +def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False, + p=None,loss_fun='square_loss',max_iter=100, tol=1e-9, + verbose=False,log=True,init_C=None,init_X=None): """ Compute the fgw barycenter as presented eq (5) in [3]. @@ -957,7 +958,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature X=np.zeros((N,d)) else: X = init_X - + T=[np.outer(p,q) for q in ps] # X is N,d @@ -981,7 +982,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature if not fixed_features: Ys_temp=[y.T for y in Ys] - X=update_feature_matrix(lambdas,Ys_temp,T,p) + X=update_feature_matrix(lambdas,Ys_temp,T,p).T # X must be N,d # Ys must be ns,d @@ -1024,11 +1025,11 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature print('{:5d}|{:8e}|'.format(cpt, err_feature)) cpt += 1 - log_['T']=T # ce sont les matrices du barycentre de la target vers les Ys + log_['T']=T # from target to Ys log_['p']=p - log_['Ms']=Ms #Ms sont de tailles N,ns + log_['Ms']=Ms #Ms are N,ns - return X.T,C,log_ + return X,C,log_ def update_sructure_matrix(p, lambdas, T, Cs): From cd4b98c34f885176f33db3fab16530622f29ab42 Mon Sep 17 00:00:00 2001 From: tvayer Date: Tue, 28 May 2019 17:13:21 +0200 Subject: [PATCH 03/19] solve conlict --- README.md | 15 ++++++++++++++- ot/bregman.py | 1 + ot/gromov.py | 6 +++--- ot/optim.py | 2 +- 4 files changed, 19 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index be88f6571..13e1013d6 100644 --- a/README.md +++ b/README.md @@ -220,4 +220,17 @@ You can also post bug reports and feature requests in Github issues. Make sure t [17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). -[18] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML). +[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems (2016). + +[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018) + +[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning + +[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66. + +[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31 + +[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 + +[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML). + diff --git a/ot/bregman.py b/ot/bregman.py index 90404295e..7be67b814 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -5,6 +5,7 @@ # Author: Remi Flamary # Nicolas Courty +# Kilian Fatras # Titouan Vayer # License: MIT License diff --git a/ot/gromov.py b/ot/gromov.py index 31bd657d4..ad68a1cf9 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -398,7 +398,7 @@ def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo= log dictionary return only if log==True in parameters References ---------- - .. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. @@ -921,7 +921,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns) References ---------- - .. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. @@ -1077,7 +1077,7 @@ def update_feature_matrix(lambdas,Ys,Ts,p): References ---------- - .. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. diff --git a/ot/optim.py b/ot/optim.py index a7748652b..9fce21e9e 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -112,7 +112,7 @@ def do_linesearch(cost,G,deltaG,Mi,f_val, The value of the cost for the next iteration References ---------- - .. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. From 11c2c26ff897e5763e714546e7021cffa8d673a7 Mon Sep 17 00:00:00 2001 From: tvayer Date: Tue, 28 May 2019 17:19:40 +0200 Subject: [PATCH 04/19] solve 2 --- README.md | 20 +++++++++++--------- ot/bregman.py | 1 + 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 13e1013d6..995177375 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ [![Anaconda Cloud](https://anaconda.org/conda-forge/pot/badges/version.svg)](https://anaconda.org/conda-forge/pot) [![Build Status](https://travis-ci.org/rflamary/POT.svg?branch=master)](https://travis-ci.org/rflamary/POT) [![Documentation Status](https://readthedocs.org/projects/pot/badge/?version=latest)](http://pot.readthedocs.io/en/latest/?badge=latest) +[![Downloads](https://pepy.tech/badge/pot)](https://pepy.tech/project/pot) [![Anaconda downloads](https://anaconda.org/conda-forge/pot/badges/downloads.svg)](https://anaconda.org/conda-forge/pot) [![License](https://anaconda.org/conda-forge/pot/badges/license.svg)](https://github.com/rflamary/POT/blob/master/LICENSE) @@ -14,15 +15,18 @@ This open source Python library provide several solvers for optimization problem It provides the following solvers: * OT Network Flow solver for the linear program/ Earth Movers Distance [1]. -* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat). +* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2], stabilized version [9][10] and greedy Sinkhorn [22] with optional GPU implementation (requires cupy). +* Sinkhorn divergence [23] and entropic regularization OT from empirical data. * Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17]. * Non regularized Wasserstein barycenters [16] with LP solver (only small scale). -* Bregman projections for Wasserstein barycenter [3] and unmixing [4]. +* Bregman projections for Wasserstein barycenter [3], convolutional barycenter [21] and unmixing [4]. * Optimal transport for domain adaptation with group lasso regularization [5] * Conditional gradient [6] and Generalized conditional gradient for regularized OT [7]. * Linear OT [14] and Joint OT matrix and mapping estimation [8]. * Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt). * Gromov-Wasserstein distances and barycenters ([13] and regularized [12]) +* Stochastic Optimization for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) +* Non regularized free support Wasserstein barycenters [20]. Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. @@ -77,16 +81,12 @@ Note that for easier access the module is name ot instead of pot. Some sub-modules require additional dependences which are discussed below -* **ot.dr** (Wasserstein dimensionality rediuction) depends on autograd and pymanopt that can be installed with: +* **ot.dr** (Wasserstein dimensionality reduction) depends on autograd and pymanopt that can be installed with: ``` pip install pymanopt autograd ``` -* **ot.gpu** (GPU accelerated OT) depends on cudamat that have to be installed with: -``` -git clone https://github.com/cudamat/cudamat.git -cd cudamat -python setup.py install --user # for user install (no root) -``` +* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). + obviously you need CUDA installed and a compatible GPU. @@ -162,6 +162,8 @@ The contributors to this library are: * [Stanislas Chambon](https://slasnista.github.io/) * [Antoine Rolet](https://arolet.github.io/) * Erwan Vautier (Gromov-Wasserstein) +* [Kilian Fatras](https://kilianfatras.github.io/) +* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): diff --git a/ot/bregman.py b/ot/bregman.py index 7be67b814..ffa62022c 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -7,6 +7,7 @@ # Nicolas Courty # Kilian Fatras # Titouan Vayer +# # License: MIT License import numpy as np From 6484c9ea301fc15ae53b4afe134941909f581ffe Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 14:11:48 +0200 Subject: [PATCH 05/19] Tests + contributions --- README.md | 1 + ot/gromov.py | 12 +++++--- test/test_gromov.py | 75 +++++++++++++++++++++++++++++++++++++++++++++ test/test_optim.py | 5 +++ 4 files changed, 89 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 995177375..9692344ca 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,7 @@ The contributors to this library are: * Erwan Vautier (Gromov-Wasserstein) * [Kilian Fatras](https://kilianfatras.github.io/) * [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) +* [Vayer Titouan](https://tvayer.github.io/) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): diff --git a/ot/gromov.py b/ot/gromov.py index ad68a1cf9..297b1942b 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -926,6 +926,10 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ + + class UndefinedParameter(Exception): + pass + S = len(Cs) d = Ys[0].shape[1] #dimension on the node features if p is None: @@ -938,7 +942,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature if fixed_structure: if init_C is None: - C=Cs[0] + raise UndefinedParameter('If C is fixed it must be initialized') else: C=init_C else: @@ -950,7 +954,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature if fixed_features: if init_X is None: - X=Ys[0] + raise UndefinedParameter('If X is fixed it must be initialized') else : X= init_X else: @@ -1004,13 +1008,13 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature # Cs is ns,ns # p is N,1 # ps is ns,1 - + T = [fused_gromov_wasserstein((1-alpha)*Ms[s],C,Cs[s],p,ps[s],loss_fun,alpha,numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] # T is N,ns log_['Ts_iter'].append(T) - err_feature = np.linalg.norm(X - Xprev.reshape(d,N)) + err_feature = np.linalg.norm(X - Xprev.reshape(N,d)) err_structure = np.linalg.norm(C - Cprev) if log: diff --git a/test/test_gromov.py b/test/test_gromov.py index fb86274a1..07cd87455 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -143,3 +143,78 @@ def test_gromov_entropic_barycenter(): 'kl_loss', 2e-3, max_iter=100, tol=1e-3) np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) + +def test_fgw(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) + + xt = xs[::-1].copy() + + ys = np.random.randn(xs.shape[0],2) + yt= ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M=ot.dist(ys,yt) + M/=M.max() + + G = ot.gromov.fused_gromov_wasserstein(M,C1, C2, p, q, 'square_loss',alpha=0.5) + + # check constratints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence fgw + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence fgw + + +def test_fgw_barycenter(): + + ns = 50 + nt = 60 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt) + + ys = np.random.randn(Xs.shape[0],2) + yt= np.random.randn(Xt.shape[0],2) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + + n_samples = 3 + X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5, + fixed_structure=False,fixed_features=False, + p=ot.unif(n_samples),loss_fun='square_loss', + max_iter=100, tol=1e-3) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + + xalea = np.random.randn(n_samples, 2) + init_C = ot.dist(xalea, xalea) + + X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],ps=[ot.unif(ns), ot.unif(nt)],lambdas=[.5, .5],alpha=0.5, + fixed_structure=True,init_C=init_C,fixed_features=False, + p=ot.unif(n_samples),loss_fun='square_loss', + max_iter=100, tol=1e-3) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + + init_X=np.random.randn(n_samples,ys.shape[1]) + + X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5, + fixed_structure=False,fixed_features=True, init_X=init_X, + p=ot.unif(n_samples),loss_fun='square_loss', + max_iter=100, tol=1e-3) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) diff --git a/test/test_optim.py b/test/test_optim.py index dfefe597a..1188ef601 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -65,3 +65,8 @@ def df(G): np.testing.assert_allclose(a, G.sum(1), atol=1e-05) np.testing.assert_allclose(b, G.sum(0), atol=1e-05) + +def test_solve_1d_linesearch_quad_funct(): + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1,-1,0),0.5) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,5,0),0) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,0.5,0),1) From f70aabfcc11f92181e0dc987b341bad8ec030d75 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 14:16:23 +0200 Subject: [PATCH 06/19] pep8 --- ot/gromov.py | 124 +++++++++++++++++++++++++-------------------------- ot/optim.py | 59 ++++++++++++------------ 2 files changed, 91 insertions(+), 92 deletions(-) diff --git a/ot/gromov.py b/ot/gromov.py index 297b1942b..fe4fc159b 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -78,16 +78,16 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): if loss_fun == 'square_loss': def f1(a): - return (a**2) + return (a**2) def f2(b): - return (b**2) + return (b**2) def h1(a): return a def h2(b): - return 2*b + return 2 * b elif loss_fun == 'kl_loss': def f1(a): return a * np.log(a + 1e-15) - a @@ -269,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs): return np.exp(np.divide(tmpsum, ppt)) -def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs): +def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs): """ Returns the gromov-wasserstein transport between (C1,p) and (C2,q) @@ -344,13 +344,14 @@ def df(G): return gwggrad(constC, hC1, hC2, G) if log: - res, log = cg(p, q, 0, 1, f, df, G0,log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs) + res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs) log['gw_dist'] = gwloss(constC, hC1, hC2, res) return res, log else: - return cg(p, q, 0, 1, f, df, G0,amijo=amijo, **kwargs) + return cg(p, q, 0, 1, f, df, G0, amijo=amijo, **kwargs) -def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=False,**kwargs): + +def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, amijo=False, **kwargs): """ Computes the FGW distance between two graphs see [3] .. math:: @@ -376,7 +377,7 @@ def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo= q : ndarray, shape (nt,) distribution in the target space loss_fun : string,optionnal - loss function used for the solver + loss function used for the solver max_iter : int, optional Max number of iterations tol : float, optional @@ -404,19 +405,20 @@ def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo= International Conference on Machine Learning (ICML). 2019. """ - constC,hC1,hC2=init_matrix(C1,C2,p,q,loss_fun) - - G0=p[:,None]*q[None,:] - + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) + + G0 = p[:, None] * q[None, :] + def f(G): - return gwloss(constC,hC1,hC2,G) + return gwloss(constC, hC1, hC2, G) + def df(G): - return gwggrad(constC,hC1,hC2,G) - - return cg(p,q,M,alpha,f,df,G0,amijo=amijo,C1=C1,C2=C2,constC=constC,**kwargs) + return gwggrad(constC, hC1, hC2, G) + + return cg(p, q, M, alpha, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs) -def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs): +def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs): """ Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) @@ -485,7 +487,7 @@ def f(G): def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, 0, 1, f, df, G0, log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs) + res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs) log['gw_dist'] = gwloss(constC, hC1, hC2, res) log['T'] = res if log: @@ -883,14 +885,14 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, return C -def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False, - p=None,loss_fun='square_loss',max_iter=100, tol=1e-9, - verbose=False,log=True,init_C=None,init_X=None): - + +def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, + p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, + verbose=False, log=True, init_C=None, init_X=None): """ Compute the fgw barycenter as presented eq (5) in [3]. ---------- - N : integer + N : integer Desired number of samples of the target barycenter Ys: list of ndarray, each element has shape (ns,d) Features of all samples @@ -906,9 +908,9 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature Wether to fix the structure of the barycenter during the updates fixed_features : bool Wether to fix the feature of the barycenter during the updates - init_C : ndarray, shape (N,N), optional + init_C : ndarray, shape (N,N), optional initialization for the barycenters' structure matrix. If not set random init - init_X : ndarray, shape (N,d), optional + init_X : ndarray, shape (N,d), optional initialization for the barycenters' features. If not set random init Returns ---------- @@ -926,14 +928,14 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ - + class UndefinedParameter(Exception): pass - + S = len(Cs) - d = Ys[0].shape[1] #dimension on the node features + d = Ys[0].shape[1] # dimension on the node features if p is None: - p = np.ones(N)/N + p = np.ones(N) / N Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)] @@ -944,7 +946,7 @@ class UndefinedParameter(Exception): if init_C is None: raise UndefinedParameter('If C is fixed it must be initialized') else: - C=init_C + C = init_C else: if init_C is None: xalea = np.random.randn(N, 2) @@ -954,20 +956,20 @@ class UndefinedParameter(Exception): if fixed_features: if init_X is None: - raise UndefinedParameter('If X is fixed it must be initialized') - else : - X= init_X + raise UndefinedParameter('If X is fixed it must be initialized') + else: + X = init_X else: - if init_X is None: - X=np.zeros((N,d)) + if init_X is None: + X = np.zeros((N, d)) else: X = init_X - - T=[np.outer(p,q) for q in ps] + + T = [np.outer(p, q) for q in ps] # X is N,d # Ys is ns,d - Ms = [np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))] + Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns cpt = 0 @@ -975,46 +977,46 @@ class UndefinedParameter(Exception): err_structure = 1 if log: - log_={} - log_['err_feature']=[] - log_['err_structure']=[] - log_['Ts_iter']=[] + log_ = {} + log_['err_feature'] = [] + log_['err_structure'] = [] + log_['Ts_iter'] = [] while((err_feature > tol or err_structure > tol) and cpt < max_iter): Cprev = C Xprev = X if not fixed_features: - Ys_temp=[y.T for y in Ys] - X=update_feature_matrix(lambdas,Ys_temp,T,p).T + Ys_temp = [y.T for y in Ys] + X = update_feature_matrix(lambdas, Ys_temp, T, p).T # X must be N,d # Ys must be ns,d - Ms=[np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))] + Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] if not fixed_structure: if loss_fun == 'square_loss': # T must be ns,N # Cs must be ns,ns # p must be N,1 - T_temp=[t.T for t in T] + T_temp = [t.T for t in T] C = update_sructure_matrix(p, lambdas, T_temp, Cs) # Ys must be d,ns # Ts must be N,ns # p must be N,1 # Ms is N,ns - # C is N,N + # C is N,N # Cs is ns,ns # p is N,1 # ps is ns,1 - - T = [fused_gromov_wasserstein((1-alpha)*Ms[s],C,Cs[s],p,ps[s],loss_fun,alpha,numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] - # T is N,ns + T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] + + # T is N,ns log_['Ts_iter'].append(T) - err_feature = np.linalg.norm(X - Xprev.reshape(N,d)) + err_feature = np.linalg.norm(X - Xprev.reshape(N, d)) err_structure = np.linalg.norm(C - Cprev) if log: @@ -1029,11 +1031,11 @@ class UndefinedParameter(Exception): print('{:5d}|{:8e}|'.format(cpt, err_feature)) cpt += 1 - log_['T']=T # from target to Ys - log_['p']=p - log_['Ms']=Ms #Ms are N,ns + log_['T'] = T # from target to Ys + log_['p'] = p + log_['Ms'] = Ms # Ms are N,ns - return X,C,log_ + return X, C, log_ def update_sructure_matrix(p, lambdas, T, Cs): @@ -1060,8 +1062,8 @@ def update_sructure_matrix(p, lambdas, T, Cs): return np.divide(tmpsum, ppt) -def update_feature_matrix(lambdas,Ys,Ts,p): - + +def update_feature_matrix(lambdas, Ys, Ts, p): """ Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3] calculated at each iteration @@ -1078,7 +1080,7 @@ def update_feature_matrix(lambdas,Ys,Ts,p): Returns ---------- X : ndarray, shape (d,N) - + References ---------- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain @@ -1087,10 +1089,8 @@ def update_feature_matrix(lambdas,Ys,Ts,p): International Conference on Machine Learning (ICML). 2019. """ - p=np.diag(np.array(1/p).reshape(-1,)) + p = np.diag(np.array(1 / p).reshape(-1,)) - tmpsum = sum([lambdas[s] * np.dot(Ys[s],Ts[s].T).dot(p) for s in range(len(Ts))]) + tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T).dot(p) for s in range(len(Ts))]) return tmpsum - - diff --git a/ot/optim.py b/ot/optim.py index 9fce21e9e..cbfb187b8 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -71,8 +71,9 @@ def phi(alpha1): return alpha, fc[0], phi1 -def do_linesearch(cost,G,deltaG,Mi,f_val, - amijo=False,C1=None,C2=None,reg=None,Gc=None,constC=None,M=None): + +def do_linesearch(cost, G, deltaG, Mi, f_val, + amijo=False, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): """ Solve the linesearch in the FW iterations Parameters @@ -119,22 +120,22 @@ def do_linesearch(cost,G,deltaG,Mi,f_val, """ if amijo: alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) - else: # requires symetric matrices - dot1=np.dot(C1,deltaG) - dot12=dot1.dot(C2) - a=-2*reg*np.sum(dot12*deltaG) - b=np.sum((M+reg*constC)*deltaG)-2*reg*(np.sum(dot12*G)+np.sum(np.dot(C1,G).dot(C2)*deltaG)) - c=cost(G) + else: # requires symetric matrices + dot1 = np.dot(C1, deltaG) + dot12 = dot1.dot(C2) + a = -2 * reg * np.sum(dot12 * deltaG) + b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) + c = cost(G) + + alpha = solve_1d_linesearch_quad_funct(a, b, c) + fc = None + f_val = cost(G + alpha * deltaG) - alpha=solve_1d_linesearch_quad_funct(a,b,c) - fc=None - f_val=cost(G+alpha*deltaG) - - return alpha,fc,f_val + return alpha, fc, f_val def cg(a, b, M, reg, f, df, G0=None, numItermax=200, - stopThr=1e-9, verbose=False, log=False,**kwargs): + stopThr=1e-9, verbose=False, log=False, **kwargs): """ Solve the general regularized OT problem with conditional gradient @@ -240,7 +241,7 @@ def cost(G): deltaG = Gc - G # line search - alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,**kwargs) + alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) G = G + alpha * deltaG @@ -403,11 +404,12 @@ def cost(G): else: return G -def solve_1d_linesearch_quad_funct(a,b,c): + +def solve_1d_linesearch_quad_funct(a, b, c): """ - Solve on 0,1 the following problem: + Solve on 0,1 the following problem: .. math:: - \min f(x)=a*x^{2}+b*x+c + \min f(x)=a*x^{2}+b*x+c Parameters ---------- @@ -416,22 +418,19 @@ def solve_1d_linesearch_quad_funct(a,b,c): Returns ------- - x : float + x : float The optimal value which leads to the minimal cost - + """ - f0=c - df0=b - f1=a+f0+df0 + f0 = c + df0 = b + f1 = a + f0 + df0 - if a>0: # convex - minimum=min(1,max(0,-b/(2*a))) - #print('entrelesdeux') + if a > 0: # convex + minimum = min(1, max(0, -b / (2 * a))) return minimum - else: # non convexe donc sur les coins - if f0>f1: - #print('sur1 f(1)={}'.format(f(1))) + else: # non convex + if f0 > f1: return 1 else: - #print('sur0 f(0)={}'.format(f(0))) return 0 From fa989062c17f87bd96aa58ad764fd3791ea11e22 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 15:00:50 +0200 Subject: [PATCH 07/19] Reame +pep8 --- README.md | 14 +++ examples/plot_barycenter_fgw.py | 150 ++++++++++++++++---------------- examples/plot_fgw.py | 138 ++++++++++++++--------------- test/test_gromov.py | 53 +++++------ test/test_optim.py | 9 +- 5 files changed, 190 insertions(+), 174 deletions(-) diff --git a/README.md b/README.md index fd27f9d6f..b6b215cb4 100644 --- a/README.md +++ b/README.md @@ -222,3 +222,17 @@ You can also post bug reports and feature requests in Github issues. Make sure t [16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). + +[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems (2016). + +[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018) + +[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning + +[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66. + +[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31 + +[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 + +[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML). diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py index f416629d4..9eea03613 100644 --- a/examples/plot_barycenter_fgw.py +++ b/examples/plot_barycenter_fgw.py @@ -30,10 +30,11 @@ from ot.gromov import fgw_barycenters #%% Graph functions -def find_thresh(C,inf=0.5,sup=3,step=10): + +def find_thresh(C, inf=0.5, sup=3, step=10): """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected - Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested. - The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix + Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested. + The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix and the original matrix. Parameters ---------- @@ -43,21 +44,22 @@ def find_thresh(C,inf=0.5,sup=3,step=10): The beginning of the linesearch sup : float The end of the linesearch - step : integer - Number of thresholds tested + step : integer + Number of thresholds tested """ - dist=[] - search=np.linspace(inf,sup,step) + dist = [] + search = np.linspace(inf, sup, step) for thresh in search: - Cprime=sp_to_adjency(C,0,thresh) - SC=shortest_path(Cprime,method='D') - SC[SC==float('inf')]=100 - dist.append(np.linalg.norm(SC-C)) - return search[np.argmin(dist)],dist - -def sp_to_adjency(C,threshinf=0.2,threshsup=1.8): - """ Thresholds the structure matrix in order to compute an adjency matrix. - All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0 + Cprime = sp_to_adjency(C, 0, thresh) + SC = shortest_path(Cprime, method='D') + SC[SC == float('inf')] = 100 + dist.append(np.linalg.norm(SC - C)) + return search[np.argmin(dist)], dist + + +def sp_to_adjency(C, threshinf=0.2, threshsup=1.8): + """ Thresholds the structure matrix in order to compute an adjency matrix. + All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0 Parameters ---------- C : ndarray, shape (n_nodes,n_nodes) @@ -71,102 +73,100 @@ def sp_to_adjency(C,threshinf=0.2,threshsup=1.8): C : ndarray, shape (n_nodes,n_nodes) The threshold matrix. Each element is in {0,1} """ - H=np.zeros_like(C) - np.fill_diagonal(H,np.diagonal(C)) - C=C-H - C=np.minimum(np.maximum(C,threshinf),threshsup) - C[C==threshsup]=0 - C[C!=0]=1 - - return C - -def build_noisy_circular_graph(N=20,mu=0,sigma=0.3,with_noise=False,structure_noise=False,p=None): + H = np.zeros_like(C) + np.fill_diagonal(H, np.diagonal(C)) + C = C - H + C = np.minimum(np.maximum(C, threshinf), threshsup) + C[C == threshsup] = 0 + C[C != 0] = 1 + + return C + + +def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None): """ Create a noisy circular graph """ - g=nx.Graph() + g = nx.Graph() g.add_nodes_from(list(range(N))) for i in range(N): - noise=float(np.random.normal(mu,sigma,1)) + noise = float(np.random.normal(mu, sigma, 1)) if with_noise: - g.add_node(i,attr_name=math.sin((2*i*math.pi/N))+noise) + g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise) else: - g.add_node(i,attr_name=math.sin(2*i*math.pi/N)) - g.add_edge(i,i+1) + g.add_node(i, attr_name=math.sin(2 * i * math.pi / N)) + g.add_edge(i, i + 1) if structure_noise: - randomint=np.random.randint(0,p) - if randomint==0: - if i<=N-3: - g.add_edge(i,i+2) - if i==N-2: - g.add_edge(i,0) - if i==N-1: - g.add_edge(i,1) - g.add_edge(N,0) - noise=float(np.random.normal(mu,sigma,1)) + randomint = np.random.randint(0, p) + if randomint == 0: + if i <= N - 3: + g.add_edge(i, i + 2) + if i == N - 2: + g.add_edge(i, 0) + if i == N - 1: + g.add_edge(i, 1) + g.add_edge(N, 0) + noise = float(np.random.normal(mu, sigma, 1)) if with_noise: - g.add_node(N,attr_name=math.sin((2*N*math.pi/N))+noise) + g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise) else: - g.add_node(N,attr_name=math.sin(2*N*math.pi/N)) + g.add_node(N, attr_name=math.sin(2 * N * math.pi / N)) return g -def graph_colors(nx_graph,vmin=0,vmax=7): - cnorm = mcol.Normalize(vmin=vmin,vmax=vmax) - cpick = cm.ScalarMappable(norm=cnorm,cmap='viridis') + +def graph_colors(nx_graph, vmin=0, vmax=7): + cnorm = mcol.Normalize(vmin=vmin, vmax=vmax) + cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis') cpick.set_array([]) val_map = {} - for k,v in nx.get_node_attributes(nx_graph,'attr_name').items(): - val_map[k]=cpick.to_rgba(v) - colors=[] + for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items(): + val_map[k] = cpick.to_rgba(v) + colors = [] for node in nx_graph.nodes(): colors.append(val_map[node]) return colors - + #%% create dataset # We build a dataset of noisy circular graphs. # Noise is added on the structures by random connections and on the features by gaussian noise. + np.random.seed(30) -X0=[] +X0 = [] for k in range(9): - X0.append(build_noisy_circular_graph(np.random.randint(15,25),with_noise=True,structure_noise=True,p=3)) - + X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3)) + #%% Plot dataset -plt.figure(figsize=(8,10)) +plt.figure(figsize=(8, 10)) for i in range(len(X0)): - plt.subplot(3,3,i+1) - g=X0[i] - pos=nx.kamada_kawai_layout(g) - nx.draw(g,pos=pos,node_color = graph_colors(g,vmin=-1,vmax=1),with_labels=False,node_size=100) -plt.suptitle('Dataset of noisy graphs. Color indicates the label',fontsize=20) + plt.subplot(3, 3, i + 1) + g = X0[i] + pos = nx.kamada_kawai_layout(g) + nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100) +plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20) plt.show() - #%% # We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph # Features distances are the euclidean distances -Cs=[shortest_path(nx.adjacency_matrix(x)) for x in X0] -ps=[np.ones(len(x.nodes()))/len(x.nodes()) for x in X0] -Ys=[np.array([v for (k,v) in nx.get_node_attributes(x,'attr_name').items()]).reshape(-1,1) for x in X0] -lambdas=np.array([np.ones(len(Ys))/len(Ys)]).ravel() -sizebary=15 # we choose a barycenter with 15 nodes +Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0] +ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0] +Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0] +lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel() +sizebary = 15 # we choose a barycenter with 15 nodes #%% -A,C,log=fgw_barycenters(sizebary,Ys,Cs,ps,lambdas,alpha=0.95) +A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95) #%% -bary=nx.from_numpy_matrix(sp_to_adjency(C,threshinf=0,threshsup=find_thresh(C,sup=100,step=100)[0])) +bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) for i in range(len(A.ravel())): - bary.add_node(i,attr_name=float(A.ravel()[i])) - + bary.add_node(i, attr_name=float(A.ravel()[i])) + #%% pos = nx.kamada_kawai_layout(bary) -nx.draw(bary,pos=pos,node_color = graph_colors(bary,vmin=-1,vmax=1),with_labels=False) -plt.suptitle('Barycenter',fontsize=20) +nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False) +plt.suptitle('Barycenter', fontsize=20) plt.show() - - - - diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py index bfa7fb45a..ae3c487a3 100644 --- a/examples/plot_fgw.py +++ b/examples/plot_fgw.py @@ -20,132 +20,132 @@ import matplotlib.pyplot as pl import numpy as np import ot -from ot.gromov import gromov_wasserstein,fused_gromov_wasserstein +from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein #%% parameters -# We create two 1D random measures -n=20 -n2=30 -sig=1 -sig2=0.1 +# We create two 1D random measures +n = 20 +n2 = 30 +sig = 1 +sig2 = 0.1 np.random.seed(0) -phi=np.arange(n)[:,None] -xs=phi+sig*np.random.randn(n,1) -ys=np.vstack((np.ones((n//2,1)),0*np.ones((n//2,1))))+sig2*np.random.randn(n,1) +phi = np.arange(n)[:, None] +xs = phi + sig * np.random.randn(n, 1) +ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1) -phi2=np.arange(n2)[:,None] -xt=phi2+sig*np.random.randn(n2,1) -yt=np.vstack((np.ones((n2//2,1)),0*np.ones((n2//2,1))))+sig2*np.random.randn(n2,1) -yt= yt[::-1,:] +phi2 = np.arange(n2)[:, None] +xt = phi2 + sig * np.random.randn(n2, 1) +yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1) +yt = yt[::-1, :] -p=ot.unif(n) -q=ot.unif(n2) +p = ot.unif(n) +q = ot.unif(n2) #%% plot the distributions pl.close(10) -pl.figure(10,(7,7)) +pl.figure(10, (7, 7)) -pl.subplot(2,1,1) +pl.subplot(2, 1, 1) -pl.scatter(ys,xs,c=phi,s=70) -pl.ylabel('Feature value a',fontsize=20) -pl.title('$\mu=\sum_i \delta_{x_i,a_i}$',fontsize=25, usetex=True, y=1) +pl.scatter(ys, xs, c=phi, s=70) +pl.ylabel('Feature value a', fontsize=20) +pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1) pl.xticks(()) pl.yticks(()) -pl.subplot(2,1,2) -pl.scatter(yt,xt,c=phi2,s=70) -pl.xlabel('coordinates x/y',fontsize=25) -pl.ylabel('Feature value b',fontsize=20) -pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$',fontsize=25, usetex=True, y=1) +pl.subplot(2, 1, 2) +pl.scatter(yt, xt, c=phi2, s=70) +pl.xlabel('coordinates x/y', fontsize=25) +pl.ylabel('Feature value b', fontsize=20) +pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1) pl.yticks(()) pl.tight_layout() pl.show() #%% Structure matrices and across-features distance matrix -C1=ot.dist(xs) -C2=ot.dist(xt).T -M=ot.dist(ys,yt) -w1=ot.unif(C1.shape[0]) -w2=ot.unif(C2.shape[0]) -Got=ot.emd([],[],M) +C1 = ot.dist(xs) +C2 = ot.dist(xt).T +M = ot.dist(ys, yt) +w1 = ot.unif(C1.shape[0]) +w2 = ot.unif(C2.shape[0]) +Got = ot.emd([], [], M) #%% -cmap='Reds' +cmap = 'Reds' pl.close(10) -pl.figure(10,(5,5)) -fs=15 -l_x=[0,5,10,15] -l_y=[0,5,10,15,20,25] +pl.figure(10, (5, 5)) +fs = 15 +l_x = [0, 5, 10, 15] +l_y = [0, 5, 10, 15, 20, 25] gs = pl.GridSpec(5, 5) -ax1=pl.subplot(gs[3:,:2]) +ax1 = pl.subplot(gs[3:, :2]) -pl.imshow(C1,cmap=cmap,interpolation='nearest') -pl.title("$C_1$",fontsize=fs) -pl.xlabel("$k$",fontsize=fs) -pl.ylabel("$i$",fontsize=fs) +pl.imshow(C1, cmap=cmap, interpolation='nearest') +pl.title("$C_1$", fontsize=fs) +pl.xlabel("$k$", fontsize=fs) +pl.ylabel("$i$", fontsize=fs) pl.xticks(l_x) pl.yticks(l_x) -ax2=pl.subplot(gs[:3,2:]) +ax2 = pl.subplot(gs[:3, 2:]) -pl.imshow(C2,cmap=cmap,interpolation='nearest') -pl.title("$C_2$",fontsize=fs) -pl.ylabel("$l$",fontsize=fs) +pl.imshow(C2, cmap=cmap, interpolation='nearest') +pl.title("$C_2$", fontsize=fs) +pl.ylabel("$l$", fontsize=fs) #pl.ylabel("$l$",fontsize=fs) pl.xticks(()) pl.yticks(l_y) ax2.set_aspect('auto') -ax3=pl.subplot(gs[3:,2:],sharex=ax2,sharey=ax1) -pl.imshow(M,cmap=cmap,interpolation='nearest') +ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1) +pl.imshow(M, cmap=cmap, interpolation='nearest') pl.yticks(l_x) pl.xticks(l_y) -pl.ylabel("$i$",fontsize=fs) -pl.title("$M_{AB}$",fontsize=fs) -pl.xlabel("$j$",fontsize=fs) +pl.ylabel("$i$", fontsize=fs) +pl.title("$M_{AB}$", fontsize=fs) +pl.xlabel("$j$", fontsize=fs) pl.tight_layout() ax3.set_aspect('auto') pl.show() #%% Computing FGW and GW -alpha=1e-3 - +alpha = 1e-3 + ot.tic() -Gwg,logw=fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=alpha,verbose=True,log=True) +Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True) ot.toc() -#%reload_ext WGW -Gg,log=gromov_wasserstein(C1,C2,p,q,loss_fun='square_loss',verbose=True,log=True) - +#%reload_ext WGW +Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True) + #%% visu OT matrix -cmap='Blues' -fs=15 -pl.figure(2,(13,5)) +cmap = 'Blues' +fs = 15 +pl.figure(2, (13, 5)) pl.clf() -pl.subplot(1,3,1) -pl.imshow(Got,cmap=cmap,interpolation='nearest') +pl.subplot(1, 3, 1) +pl.imshow(Got, cmap=cmap, interpolation='nearest') #pl.xlabel("$y$",fontsize=fs) -pl.ylabel("$i$",fontsize=fs) +pl.ylabel("$i$", fontsize=fs) pl.xticks(()) pl.title('Wasserstein ($M$ only)') -pl.subplot(1,3,2) -pl.imshow(Gg,cmap=cmap,interpolation='nearest') +pl.subplot(1, 3, 2) +pl.imshow(Gg, cmap=cmap, interpolation='nearest') pl.title('Gromov ($C_1,C_2$ only)') pl.xticks(()) -pl.subplot(1,3,3) -pl.imshow(Gwg,cmap=cmap,interpolation='nearest') +pl.subplot(1, 3, 3) +pl.imshow(Gwg, cmap=cmap, interpolation='nearest') pl.title('FGW ($M+C_1,C_2$)') -pl.xlabel("$j$",fontsize=fs) -pl.ylabel("$i$",fontsize=fs) +pl.xlabel("$j$", fontsize=fs) +pl.ylabel("$i$", fontsize=fs) pl.tight_layout() -pl.show() \ No newline at end of file +pl.show() diff --git a/test/test_gromov.py b/test/test_gromov.py index 43b63e154..cd180d481 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -145,7 +145,8 @@ def test_gromov_entropic_barycenter(): 'kl_loss', 2e-3, max_iter=100, tol=1e-3) np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) - + + def test_fgw(): n_samples = 50 # nb samples @@ -155,9 +156,9 @@ def test_fgw(): xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) xt = xs[::-1].copy() - - ys = np.random.randn(xs.shape[0],2) - yt= ys[::-1].copy() + + ys = np.random.randn(xs.shape[0], 2) + yt = ys[::-1].copy() p = ot.unif(n_samples) q = ot.unif(n_samples) @@ -167,11 +168,11 @@ def test_fgw(): C1 /= C1.max() C2 /= C2.max() - - M=ot.dist(ys,yt) - M/=M.max() - G = ot.gromov.fused_gromov_wasserstein(M,C1, C2, p, q, 'square_loss',alpha=0.5) + M = ot.dist(ys, yt) + M /= M.max() + + G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5) # check constratints np.testing.assert_allclose( @@ -187,36 +188,36 @@ def test_fgw_barycenter(): Xs, ys = ot.datasets.make_data_classif('3gauss', ns) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt) - - ys = np.random.randn(Xs.shape[0],2) - yt= np.random.randn(Xt.shape[0],2) + + ys = np.random.randn(Xs.shape[0], 2) + yt = np.random.randn(Xt.shape[0], 2) C1 = ot.dist(Xs) C2 = ot.dist(Xt) n_samples = 3 - X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5, - fixed_structure=False,fixed_features=False, - p=ot.unif(n_samples),loss_fun='square_loss', - max_iter=100, tol=1e-3) + X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) xalea = np.random.randn(n_samples, 2) init_C = ot.dist(xalea, xalea) - - X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],ps=[ot.unif(ns), ot.unif(nt)],lambdas=[.5, .5],alpha=0.5, - fixed_structure=True,init_C=init_C,fixed_features=False, - p=ot.unif(n_samples),loss_fun='square_loss', - max_iter=100, tol=1e-3) + + X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5, + fixed_structure=True, init_C=init_C, fixed_features=False, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) - - init_X=np.random.randn(n_samples,ys.shape[1]) - X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5, - fixed_structure=False,fixed_features=True, init_X=init_X, - p=ot.unif(n_samples),loss_fun='square_loss', - max_iter=100, tol=1e-3) + init_X = np.random.randn(n_samples, ys.shape[1]) + + X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_X, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) diff --git a/test/test_optim.py b/test/test_optim.py index 1188ef601..e7ba32a59 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -65,8 +65,9 @@ def df(G): np.testing.assert_allclose(a, G.sum(1), atol=1e-05) np.testing.assert_allclose(b, G.sum(0), atol=1e-05) - + + def test_solve_1d_linesearch_quad_funct(): - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1,-1,0),0.5) - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,5,0),0) - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,0.5,0),1) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1, -1, 0), 0.5) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 5, 0), 0) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 0.5, 0), 1) From 103dfe0ee76e110bb9e0d1e36e3dd86109db3fce Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 15:10:37 +0200 Subject: [PATCH 08/19] test check --- ot/gromov.py | 2 +- ot/optim.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/gromov.py b/ot/gromov.py index 33134a266..44248d149 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -348,7 +348,7 @@ def df(G): log['gw_dist'] = gwloss(constC, hC1, hC2, res) return res, log else: - return cg(p, q, 0, 1, f, df, G0, amijo=amijo, **kwargs) + return cg(p, q, 0, 1, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs) def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, amijo=False, **kwargs): diff --git a/ot/optim.py b/ot/optim.py index cbfb187b8..2170c7ec3 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -73,7 +73,7 @@ def phi(alpha1): def do_linesearch(cost, G, deltaG, Mi, f_val, - amijo=False, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): + amijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): """ Solve the linesearch in the FW iterations Parameters From 915d5fa4c4020536f2d41c21353b1477befa8af3 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 15:19:41 +0200 Subject: [PATCH 09/19] python2 divide problem --- ot/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/optim.py b/ot/optim.py index 2170c7ec3..282b30db6 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -427,7 +427,7 @@ def solve_1d_linesearch_quad_funct(a, b, c): f1 = a + f0 + df0 if a > 0: # convex - minimum = min(1, max(0, -b / (2 * a))) + minimum = min(1, max(0, np.divide(-b, 2 * a))) return minimum else: # non convex if f0 > f1: From 94d2fe5fd0b07060426e9449de0331b88ab53df4 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 15:32:03 +0200 Subject: [PATCH 10/19] wizard stuff --- ot/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/optim.py b/ot/optim.py index 282b30db6..b96d92095 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -427,7 +427,7 @@ def solve_1d_linesearch_quad_funct(a, b, c): f1 = a + f0 + df0 if a > 0: # convex - minimum = min(1, max(0, np.divide(-b, 2 * a))) + minimum = min(1, max(0, np.divide(-b, 2.0 * a))) return minimum else: # non convex if f0 > f1: From 9421dddd8890d4c575b593d678eb7bdf5f933f83 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 15:51:57 +0200 Subject: [PATCH 11/19] Doc+armijo --- ot/gromov.py | 39 ++++++++++++++++++++------------------- ot/optim.py | 22 +++++++++++----------- 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/ot/gromov.py b/ot/gromov.py index 44248d149..5a57dc8da 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -33,12 +33,12 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): * C2 : Metric cost matrix in the target space * T : A coupling between those two spaces - The square-loss function L(a,b)=(1/2)*|a-b|^2 is read as : + The square-loss function L(a,b)=|a-b|^2 is read as : L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with : - * f1(a)=(a^2)/2 - * f2(b)=(b^2)/2 + * f1(a)=(a^2) + * f2(b)=(b^2) * h1(a)=a - * h2(b)=b + * h2(b)=2*b The kl-loss function L(a,b)=a*log(a/b)-a+b is read as : L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with : @@ -269,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs): return np.exp(np.divide(tmpsum, ppt)) -def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs): +def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): """ Returns the gromov-wasserstein transport between (C1,p) and (C2,q) @@ -307,8 +307,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs) Print information along iterations log : bool, optional record log if True - amijo : bool, optional - If True the steps of the line-search is found via an amijo research. Else closed form is used. + armijo : bool, optional + If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. **kwargs : dict parameters can be directly pased to the ot.optim.cg solver @@ -344,14 +344,14 @@ def df(G): return gwggrad(constC, hC1, hC2, G) if log: - res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs) + res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) log['gw_dist'] = gwloss(constC, hC1, hC2, res) return res, log else: - return cg(p, q, 0, 1, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs) + return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) -def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, amijo=False, **kwargs): +def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, **kwargs): """ Computes the FGW distance between two graphs see [3] .. math:: @@ -363,6 +363,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, - M is the (ns,nt) metric cost matrix - :math:`f` is the regularization term ( and df is its gradient) - a and b are source and target weights (sum to 1) + - L is a loss function to account for the misfit between the similarity matrices The algorithm used for solving the problem is conditional gradient as discussed in [1]_ Parameters ---------- @@ -386,8 +387,8 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, Print information along iterations log : bool, optional record log if True - amijo : bool, optional - If True the steps of the line-search is found via an amijo research. Else closed form is used. + armijo : bool, optional + If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. **kwargs : dict parameters can be directly pased to the ot.optim.cg solver @@ -415,10 +416,10 @@ def f(G): def df(G): return gwggrad(constC, hC1, hC2, G) - return cg(p, q, M, alpha, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs) + return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) -def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs): +def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): """ Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) @@ -456,8 +457,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs Print information along iterations log : bool, optional record log if True - amijo : bool, optional - If True the steps of the line-search is found via an amijo research. Else closed form is used. + armijo : bool, optional + If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. Returns ------- @@ -487,7 +488,7 @@ def f(G): def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs) + res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) log['gw_dist'] = gwloss(constC, hC1, hC2, res) log['T'] = res if log: @@ -890,7 +891,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=True, init_C=None, init_X=None): """ - Compute the fgw barycenter as presented eq (5) in [3]. + Compute the fgw barycenter as presented eq (5) in [24]. ---------- N : integer Desired number of samples of the target barycenter @@ -1065,7 +1066,7 @@ def update_sructure_matrix(p, lambdas, T, Cs): def update_feature_matrix(lambdas, Ys, Ts, p): """ - Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3] + Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [24] calculated at each iteration Parameters ---------- diff --git a/ot/optim.py b/ot/optim.py index b96d92095..82a91bf2a 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -73,13 +73,13 @@ def phi(alpha1): def do_linesearch(cost, G, deltaG, Mi, f_val, - amijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): + armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): """ Solve the linesearch in the FW iterations Parameters ---------- cost : method - The FGW cost + Cost in the FW for the linesearch G : ndarray, shape(ns,nt) The transport map at a given iteration of the FW deltaG : ndarray (ns,nt) @@ -88,21 +88,21 @@ def do_linesearch(cost, G, deltaG, Mi, f_val, Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost f_val : float Value of the cost at G - amijo : bool, optionnal - If True the steps of the line-search is found via an amijo research. Else closed form is used. + armijo : bool, optionnal + If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. C1 : ndarray (ns,ns), optionnal - Structure matrix in the source domain. Only used when amijo=False + Structure matrix in the source domain. Only used when armijo=False C2 : ndarray (nt,nt), optionnal - Structure matrix in the target domain. Only used when amijo=False + Structure matrix in the target domain. Only used when armijo=False reg : float, optionnal - Regularization parameter. Corresponds to the alpha parameter of FGW. Only used when amijo=False + Regularization parameter. Only used when armijo=False Gc : ndarray (ns,nt) - Optimal map found by linearization in the FW algorithm. Only used when amijo=False + Optimal map found by linearization in the FW algorithm. Only used when armijo=False constC : ndarray (ns,nt) - Constant for the gromov cost. See [3]. Only used when amijo=False + Constant for the gromov cost. See [24]. Only used when armijo=False M : ndarray (ns,nt), optionnal - Cost matrix between the features. Only used when amijo=False + Cost matrix between the features. Only used when armijo=False Returns ------- alpha : float @@ -118,7 +118,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val, "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ - if amijo: + if armijo: alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) else: # requires symetric matrices dot1 = np.dot(C1, deltaG) From d4320382fa8873d15dcaec7adca3a4723c142515 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 16:10:26 +0200 Subject: [PATCH 12/19] relative+absolute loss --- ot/optim.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/ot/optim.py b/ot/optim.py index 82a91bf2a..7d103e2fe 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -135,7 +135,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val, def cg(a, b, M, reg, f, df, G0=None, numItermax=200, - stopThr=1e-9, verbose=False, log=False, **kwargs): + stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): """ Solve the general regularized OT problem with conditional gradient @@ -173,7 +173,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshol on the relative variation (>0) + stopThr2 : float, optional + Stop threshol on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -249,8 +251,9 @@ def cost(G): if it >= numItermax: loop = 0 - delta_fval = (f_val - old_fval) / abs(f_val) - if abs(delta_fval) < stopThr: + abs_delta_fval = abs(f_val - old_fval) + relative_delta_fval = abs_delta_fval / abs(f_val) + if relative_delta_fval < stopThr and abs_delta_fval < stopThr2: loop = 0 if log: @@ -259,8 +262,8 @@ def cost(G): if verbose: if it % 20 == 0: print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval)) + 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: return G, log @@ -269,7 +272,7 @@ def cost(G): def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, - numInnerItermax=200, stopThr=1e-9, verbose=False, log=False): + numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False): """ Solve the general regularized OT problem with the generalized conditional gradient @@ -312,7 +315,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax : int, optional Max number of iterations of Sinkhorn stopThr : float, optional - Stop threshol on error (>0) + Stop threshol on the relative variation (>0) + stopThr2 : float, optional + Stop threshol on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -386,8 +391,10 @@ def cost(G): if it >= numItermax: loop = 0 - delta_fval = (f_val - old_fval) / abs(f_val) - if abs(delta_fval) < stopThr: + abs_delta_fval = abs(f_val - old_fval) + relative_delta_fval = abs_delta_fval / abs(f_val) + + if relative_delta_fval < stopThr and abs_delta_fval < stopThr2: loop = 0 if log: @@ -396,8 +403,8 @@ def cost(G): if verbose: if it % 20 == 0: print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval)) + 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: return G, log From e1bd94bb7e85a0d2fd0fcd7642b06da12c1db6db Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 17:05:38 +0200 Subject: [PATCH 13/19] code review1 --- examples/plot_barycenter_fgw.py | 30 ++++++--- examples/plot_fgw.py | 32 ++++++++-- ot/gromov.py | 108 ++++++++++++++++++++++++++++---- ot/optim.py | 31 ++++----- test/test_gromov.py | 57 +++++++++++++---- 5 files changed, 204 insertions(+), 54 deletions(-) diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py index 9eea03613..e4be447e5 100644 --- a/examples/plot_barycenter_fgw.py +++ b/examples/plot_barycenter_fgw.py @@ -125,7 +125,11 @@ def graph_colors(nx_graph, vmin=0, vmax=7): colors.append(val_map[node]) return colors -#%% create dataset +############################################################################## +# Generate data +# ------------- + +#%% circular dataset # We build a dataset of noisy circular graphs. # Noise is added on the structures by random connections and on the features by gaussian noise. @@ -135,7 +139,11 @@ def graph_colors(nx_graph, vmin=0, vmax=7): for k in range(9): X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3)) -#%% Plot dataset +############################################################################## +# Plot data +# --------- + +#%% Plot graphs plt.figure(figsize=(8, 10)) for i in range(len(X0)): @@ -146,9 +154,11 @@ def graph_colors(nx_graph, vmin=0, vmax=7): plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20) plt.show() +############################################################################## +# Barycenter computation +# ---------------------- -#%% -# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph +#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph # Features distances are the euclidean distances Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0] ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0] @@ -156,14 +166,16 @@ def graph_colors(nx_graph, vmin=0, vmax=7): lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel() sizebary = 15 # we choose a barycenter with 15 nodes -#%% - A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95) -#%% +############################################################################## +# Plot Barycenter +# ------------------------- + +#%% Create the barycenter bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) -for i in range(len(A.ravel())): - bary.add_node(i, attr_name=float(A.ravel()[i])) +for i, v in enumerate(A.ravel()): + bary.add_node(i, attr_name=v) #%% pos = nx.kamada_kawai_layout(bary) diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py index ae3c487a3..43efc94be 100644 --- a/examples/plot_fgw.py +++ b/examples/plot_fgw.py @@ -22,12 +22,16 @@ import ot from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein +############################################################################## +# Generate data +# --------- + #%% parameters # We create two 1D random measures -n = 20 -n2 = 30 -sig = 1 -sig2 = 0.1 +n = 20 # number of points in the first distribution +n2 = 30 # number of points in the second distribution +sig = 1 # std of first distribution +sig2 = 0.1 # std of second distribution np.random.seed(0) @@ -43,6 +47,10 @@ p = ot.unif(n) q = ot.unif(n2) +############################################################################## +# Plot data +# --------- + #%% plot the distributions pl.close(10) @@ -64,15 +72,22 @@ pl.tight_layout() pl.show() +############################################################################## +# Create structure matrices and across-feature distance matrix +# --------- #%% Structure matrices and across-features distance matrix C1 = ot.dist(xs) -C2 = ot.dist(xt).T +C2 = ot.dist(xt) M = ot.dist(ys, yt) w1 = ot.unif(C1.shape[0]) w2 = ot.unif(C2.shape[0]) Got = ot.emd([], [], M) +############################################################################## +# Plot matrices +# --------- + #%% cmap = 'Reds' pl.close(10) @@ -112,6 +127,9 @@ ax3.set_aspect('auto') pl.show() +############################################################################## +# Compute FGW/GW +# --------- #%% Computing FGW and GW alpha = 1e-3 @@ -123,6 +141,10 @@ #%reload_ext WGW Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True) +############################################################################## +# Visualize transport matrices +# --------- + #%% visu OT matrix cmap = 'Blues' fs = 15 diff --git a/ot/gromov.py b/ot/gromov.py index 5a57dc8da..53349b73e 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -10,6 +10,7 @@ # Nicolas Courty # Rémi Flamary # Titouan Vayer +# # License: MIT License import numpy as np @@ -351,9 +352,9 @@ def df(G): return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) -def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, **kwargs): +def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): """ - Computes the FGW distance between two graphs see [3] + Computes the FGW transport between two graphs see [24] .. math:: \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} s.t. \gamma 1 = p @@ -377,7 +378,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, distribution in the source space q : ndarray, shape (nt,) distribution in the target space - loss_fun : string,optionnal + loss_fun : string,optional loss function used for the solver max_iter : int, optional Max number of iterations @@ -416,7 +417,86 @@ def f(G): def df(G): return gwggrad(constC, hC1, hC2, G) - return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + if log: + res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + log['fgw_dist'] = log['loss'][::-1][0] + return res, log + else: + return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + + +def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): + """ + Computes the FGW distance between two graphs see [24] + .. math:: + \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + s.t. \gamma 1 = p + \gamma^T 1= q + \gamma\geq 0 + where : + - M is the (ns,nt) metric cost matrix + - :math:`f` is the regularization term ( and df is its gradient) + - a and b are source and target weights (sum to 1) + - L is a loss function to account for the misfit between the similarity matrices + The algorithm used for solving the problem is conditional gradient as discussed in [1]_ + Parameters + ---------- + M : ndarray, shape (ns, nt) + Metric cost matrix between features across domains + C1 : ndarray, shape (ns, ns) + Metric cost matrix respresentative of the structure in the source space + C2 : ndarray, shape (nt, nt) + Metric cost matrix espresentative of the structure in the target space + p : ndarray, shape (ns,) + distribution in the source space + q : ndarray, shape (nt,) + distribution in the target space + loss_fun : string,optional + loss function used for the solver + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + armijo : bool, optional + If True the steps of the line-search is found via an armijo research. Else closed form is used. + If there is convergence issues use False. + **kwargs : dict + parameters can be directly pased to the ot.optim.cg solver + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) + + G0 = p[:, None] * q[None, :] + + def f(G): + return gwloss(constC, hC1, hC2, G) + + def df(G): + return gwggrad(constC, hC1, hC2, G) + + res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + if log: + log['fgw_dist'] = log['loss'][::-1][0] + log['T'] = res + return log['fgw_dist'], log + else: + return log['fgw_dist'] def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): @@ -889,7 +969,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, - verbose=False, log=True, init_C=None, init_X=None): + verbose=False, log=False, init_C=None, init_X=None): """ Compute the fgw barycenter as presented eq (5) in [24]. ---------- @@ -919,7 +999,8 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Barycenters' features C : ndarray, shape (N,N) Barycenters' structure matrix - log_: + log_: dictionary + Only returned when log=True T : list of (N,ns) transport matrices Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns) References @@ -1015,14 +1096,13 @@ class UndefinedParameter(Exception): T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] # T is N,ns - - log_['Ts_iter'].append(T) err_feature = np.linalg.norm(X - Xprev.reshape(N, d)) err_structure = np.linalg.norm(C - Cprev) if log: log_['err_feature'].append(err_feature) log_['err_structure'].append(err_structure) + log_['Ts_iter'].append(T) if verbose: if cpt % 200 == 0: @@ -1032,11 +1112,15 @@ class UndefinedParameter(Exception): print('{:5d}|{:8e}|'.format(cpt, err_feature)) cpt += 1 - log_['T'] = T # from target to Ys - log_['p'] = p - log_['Ms'] = Ms # Ms are N,ns + if log: + log_['T'] = T # from target to Ys + log_['p'] = p + log_['Ms'] = Ms # Ms are N,ns - return X, C, log_ + if log: + return X, C, log_ + else: + return X, C def update_sructure_matrix(p, lambdas, T, Cs): diff --git a/ot/optim.py b/ot/optim.py index 7d103e2fe..4d428d9ed 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -5,6 +5,7 @@ # Author: Remi Flamary # Titouan Vayer +# # License: MIT License import numpy as np @@ -88,20 +89,20 @@ def do_linesearch(cost, G, deltaG, Mi, f_val, Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost f_val : float Value of the cost at G - armijo : bool, optionnal + armijo : bool, optional If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. - C1 : ndarray (ns,ns), optionnal + C1 : ndarray (ns,ns), optional Structure matrix in the source domain. Only used when armijo=False - C2 : ndarray (nt,nt), optionnal + C2 : ndarray (nt,nt), optional Structure matrix in the target domain. Only used when armijo=False - reg : float, optionnal + reg : float, optional Regularization parameter. Only used when armijo=False Gc : ndarray (ns,nt) Optimal map found by linearization in the FW algorithm. Only used when armijo=False constC : ndarray (ns,nt) Constant for the gromov cost. See [24]. Only used when armijo=False - M : ndarray (ns,nt), optionnal + M : ndarray (ns,nt), optional Cost matrix between the features. Only used when armijo=False Returns ------- @@ -223,9 +224,9 @@ def cost(G): it = 0 if verbose: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0)) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) while loop: @@ -261,8 +262,8 @@ def cost(G): if verbose: if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: @@ -363,9 +364,9 @@ def cost(G): it = 0 if verbose: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0)) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) while loop: @@ -402,8 +403,8 @@ def cost(G): if verbose: if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: diff --git a/test/test_gromov.py b/test/test_gromov.py index cd180d481..ec85abf5a 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -2,6 +2,7 @@ # Author: Erwan Vautier # Nicolas Courty +# Titouan Vayer # # License: MIT License @@ -10,6 +11,8 @@ def test_gromov(): + np.random.seed(42) + n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -36,6 +39,11 @@ def test_gromov(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov + Id = (1 / n_samples) * np.eye(n_samples, n_samples) + + np.testing.assert_allclose( + G, np.flipud(Id), atol=1e-04) + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True) G = log['T'] @@ -50,6 +58,8 @@ def test_gromov(): def test_entropic_gromov(): + np.random.seed(42) + n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -92,6 +102,7 @@ def test_entropic_gromov(): def test_gromov_barycenter(): + np.random.seed(42) ns = 50 nt = 60 @@ -120,7 +131,7 @@ def test_gromov_barycenter(): def test_gromov_entropic_barycenter(): - + np.random.seed(42) ns = 50 nt = 60 @@ -148,6 +159,8 @@ def test_gromov_entropic_barycenter(): def test_fgw(): + np.random.seed(42) + n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -180,8 +193,26 @@ def test_fgw(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence fgw + Id = (1 / n_samples) * np.eye(n_samples, n_samples) + + np.testing.assert_allclose( + G, np.flipud(Id), atol=1e-04) # cf convergence gromov + + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) + + G = log['T'] + + np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + + # check constratints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence gromov + def test_fgw_barycenter(): + np.random.seed(42) ns = 50 nt = 60 @@ -196,28 +227,28 @@ def test_fgw_barycenter(): C2 = ot.dist(Xt) n_samples = 3 - X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) + X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) xalea = np.random.randn(n_samples, 2) init_C = ot.dist(xalea, xalea) - X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5, - fixed_structure=True, init_C=init_C, fixed_features=False, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) + X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5, + fixed_structure=True, init_C=init_C, fixed_features=False, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) init_X = np.random.randn(n_samples, ys.shape[1]) - X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_X, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) + X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_X, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) From 28059eb5e0aad715823ee4f6509d6a9e3d6e5db0 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 17:11:41 +0200 Subject: [PATCH 14/19] py2 error --- test/test_gromov.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_gromov.py b/test/test_gromov.py index ec85abf5a..3ca184b87 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -13,7 +13,7 @@ def test_gromov(): np.random.seed(42) - n_samples = 50 # nb samples + n_samples = 50.0 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -161,7 +161,7 @@ def test_gromov_entropic_barycenter(): def test_fgw(): np.random.seed(42) - n_samples = 50 # nb samples + n_samples = 50.0 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) From 63093cef7af3350228251aa930872c6f30789432 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 17:19:13 +0200 Subject: [PATCH 15/19] n_samples float --- test/test_gromov.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_gromov.py b/test/test_gromov.py index 3ca184b87..d7a12f31b 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -13,7 +13,7 @@ def test_gromov(): np.random.seed(42) - n_samples = 50.0 # nb samples + n_samples = 50 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -39,7 +39,7 @@ def test_gromov(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov - Id = (1 / n_samples) * np.eye(n_samples, n_samples) + Id = (1 / float(n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose( G, np.flipud(Id), atol=1e-04) @@ -161,7 +161,7 @@ def test_gromov_entropic_barycenter(): def test_fgw(): np.random.seed(42) - n_samples = 50.0 # nb samples + n_samples = 50 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -193,7 +193,7 @@ def test_fgw(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence fgw - Id = (1 / n_samples) * np.eye(n_samples, n_samples) + Id = (1 / float(n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose( G, np.flipud(Id), atol=1e-04) # cf convergence gromov From 9bb7d40b563f42bf2875efca860bf0c579307161 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 17:52:20 +0200 Subject: [PATCH 16/19] .0 --- test/test_gromov.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_gromov.py b/test/test_gromov.py index d7a12f31b..b7ede95b3 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -39,7 +39,7 @@ def test_gromov(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov - Id = (1 / float(n_samples)) * np.eye(n_samples, n_samples) + Id = (1 / 1.0*n_samples) * np.eye(n_samples, n_samples) np.testing.assert_allclose( G, np.flipud(Id), atol=1e-04) @@ -193,7 +193,7 @@ def test_fgw(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence fgw - Id = (1 / float(n_samples)) * np.eye(n_samples, n_samples) + Id = (1 / 1.0*n_samples) * np.eye(n_samples, n_samples) np.testing.assert_allclose( G, np.flipud(Id), atol=1e-04) # cf convergence gromov From 89a2e0aee4353a051d924de0457f8976c26fa5d7 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 18:02:27 +0200 Subject: [PATCH 17/19] pep8 + err --- test/test_gromov.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_gromov.py b/test/test_gromov.py index b7ede95b3..f218b74a7 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -39,7 +39,7 @@ def test_gromov(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov - Id = (1 / 1.0*n_samples) * np.eye(n_samples, n_samples) + Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose( G, np.flipud(Id), atol=1e-04) @@ -193,7 +193,7 @@ def test_fgw(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence fgw - Id = (1 / 1.0*n_samples) * np.eye(n_samples, n_samples) + Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose( G, np.flipud(Id), atol=1e-04) # cf convergence gromov From ad450b0a5bb63ee9731e88d4a8e7423b16f1abd8 Mon Sep 17 00:00:00 2001 From: tvayer Date: Tue, 4 Jun 2019 10:32:30 +0200 Subject: [PATCH 18/19] changes forgotten coments --- ot/gromov.py | 26 +++----------------------- ot/optim.py | 32 ++++++++++++++++---------------- ot/utils.py | 8 ++++++++ test/test_optim.py | 6 +++--- 4 files changed, 30 insertions(+), 42 deletions(-) diff --git a/ot/gromov.py b/ot/gromov.py index 53349b73e..ca96b31d9 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -17,7 +17,7 @@ from .bregman import sinkhorn -from .utils import dist +from .utils import dist, UndefinedParameter from .optim import cg @@ -1011,9 +1011,6 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ International Conference on Machine Learning (ICML). 2019. """ - class UndefinedParameter(Exception): - pass - S = len(Cs) d = Ys[0].shape[1] # dimension on the node features if p is None: @@ -1049,10 +1046,7 @@ class UndefinedParameter(Exception): T = [np.outer(p, q) for q in ps] - # X is N,d - # Ys is ns,d - Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] - # Ms is N,ns + Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns cpt = 0 err_feature = 1 @@ -1072,27 +1066,13 @@ class UndefinedParameter(Exception): Ys_temp = [y.T for y in Ys] X = update_feature_matrix(lambdas, Ys_temp, T, p).T - # X must be N,d - # Ys must be ns,d Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] if not fixed_structure: if loss_fun == 'square_loss': - # T must be ns,N - # Cs must be ns,ns - # p must be N,1 T_temp = [t.T for t in T] C = update_sructure_matrix(p, lambdas, T_temp, Cs) - # Ys must be d,ns - # Ts must be N,ns - # p must be N,1 - # Ms is N,ns - # C is N,N - # Cs is ns,ns - # p is N,1 - # ps is ns,1 - T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] # T is N,ns @@ -1115,7 +1095,7 @@ class UndefinedParameter(Exception): if log: log_['T'] = T # from target to Ys log_['p'] = p - log_['Ms'] = Ms # Ms are N,ns + log_['Ms'] = Ms if log: return X, C, log_ diff --git a/ot/optim.py b/ot/optim.py index 4d428d9ed..f94acebb4 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -73,8 +73,8 @@ def phi(alpha1): return alpha, fc[0], phi1 -def do_linesearch(cost, G, deltaG, Mi, f_val, - armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): +def solve_linesearch(cost, G, deltaG, Mi, f_val, + armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): """ Solve the linesearch in the FW iterations Parameters @@ -93,17 +93,17 @@ def do_linesearch(cost, G, deltaG, Mi, f_val, If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. C1 : ndarray (ns,ns), optional - Structure matrix in the source domain. Only used when armijo=False + Structure matrix in the source domain. Only used and necessary when armijo=False C2 : ndarray (nt,nt), optional - Structure matrix in the target domain. Only used when armijo=False + Structure matrix in the target domain. Only used and necessary when armijo=False reg : float, optional - Regularization parameter. Only used when armijo=False + Regularization parameter. Only used and necessary when armijo=False Gc : ndarray (ns,nt) - Optimal map found by linearization in the FW algorithm. Only used when armijo=False + Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False constC : ndarray (ns,nt) - Constant for the gromov cost. See [24]. Only used when armijo=False + Constant for the gromov cost. See [24]. Only used and necessary when armijo=False M : ndarray (ns,nt), optional - Cost matrix between the features. Only used when armijo=False + Cost matrix between the features. Only used and necessary when armijo=False Returns ------- alpha : float @@ -128,7 +128,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val, b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) c = cost(G) - alpha = solve_1d_linesearch_quad_funct(a, b, c) + alpha = solve_1d_linesearch_quad(a, b, c) fc = None f_val = cost(G + alpha * deltaG) @@ -181,7 +181,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, Print information along iterations log : bool, optional record log if True - kwargs : dict + **kwargs : dict Parameters for linesearch Returns @@ -244,7 +244,7 @@ def cost(G): deltaG = Gc - G # line search - alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) + alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) G = G + alpha * deltaG @@ -254,7 +254,7 @@ def cost(G): abs_delta_fval = abs(f_val - old_fval) relative_delta_fval = abs_delta_fval / abs(f_val) - if relative_delta_fval < stopThr and abs_delta_fval < stopThr2: + if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: loop = 0 if log: @@ -395,7 +395,7 @@ def cost(G): abs_delta_fval = abs(f_val - old_fval) relative_delta_fval = abs_delta_fval / abs(f_val) - if relative_delta_fval < stopThr and abs_delta_fval < stopThr2: + if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: loop = 0 if log: @@ -413,11 +413,11 @@ def cost(G): return G -def solve_1d_linesearch_quad_funct(a, b, c): +def solve_1d_linesearch_quad(a, b, c): """ - Solve on 0,1 the following problem: + For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem: .. math:: - \min f(x)=a*x^{2}+b*x+c + \argmin f(x)=a*x^{2}+b*x+c Parameters ---------- diff --git a/ot/utils.py b/ot/utils.py index bb21b3887..efd1288db 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -487,3 +487,11 @@ def set_params(self, **params): (key, self.__class__.__name__)) setattr(self, key, value) return self + + +class UndefinedParameter(Exception): + """ + Aim at raising an Exception when a undefined parameter is called + + """ + pass diff --git a/test/test_optim.py b/test/test_optim.py index e7ba32a59..ae31e1f6c 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -68,6 +68,6 @@ def df(G): def test_solve_1d_linesearch_quad_funct(): - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1, -1, 0), 0.5) - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 5, 0), 0) - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 0.5, 0), 1) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1) From 788a6506c9bf3b862a9652d74f65f8d07851e653 Mon Sep 17 00:00:00 2001 From: tvayer Date: Tue, 4 Jun 2019 11:34:46 +0200 Subject: [PATCH 19/19] seed --- test/test_gromov.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/test/test_gromov.py b/test/test_gromov.py index f218b74a7..70fa83f10 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -11,14 +11,12 @@ def test_gromov(): - np.random.seed(42) - n_samples = 50 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() @@ -58,14 +56,12 @@ def test_gromov(): def test_entropic_gromov(): - np.random.seed(42) - n_samples = 50 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) xt = xs[::-1].copy() @@ -102,13 +98,11 @@ def test_entropic_gromov(): def test_gromov_barycenter(): - np.random.seed(42) - ns = 50 nt = 60 - Xs, ys = ot.datasets.make_data_classif('3gauss', ns) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt) + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) C1 = ot.dist(Xs) C2 = ot.dist(Xt) @@ -131,12 +125,11 @@ def test_gromov_barycenter(): def test_gromov_entropic_barycenter(): - np.random.seed(42) ns = 50 nt = 60 - Xs, ys = ot.datasets.make_data_classif('3gauss', ns) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt) + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) C1 = ot.dist(Xs) C2 = ot.dist(Xt) @@ -159,14 +152,13 @@ def test_gromov_entropic_barycenter(): def test_fgw(): - np.random.seed(42) n_samples = 50 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) xt = xs[::-1].copy() @@ -217,8 +209,8 @@ def test_fgw_barycenter(): ns = 50 nt = 60 - Xs, ys = ot.datasets.make_data_classif('3gauss', ns) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt) + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) ys = np.random.randn(Xs.shape[0], 2) yt = np.random.randn(Xt.shape[0], 2)