Skip to content

[MRG] Gromov-Wasserstein closed form for linesearch and integration of Fused Gromov-Wasserstein #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jun 4, 2019
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ The contributors to this library are:
* Erwan Vautier (Gromov-Wasserstein)
* [Kilian Fatras](https://kilianfatras.github.io/)
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
* [Vayer Titouan](https://tvayer.github.io/)

This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):

Expand Down Expand Up @@ -233,3 +234,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31

[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018

[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML).
184 changes: 184 additions & 0 deletions examples/plot_barycenter_fgw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# -*- coding: utf-8 -*-
"""
=================================
Plot graphs' barycenter using FGW
=================================

This example illustrates the computation barycenter of labeled graphs using FGW

Requires networkx >=2

.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.

"""

# Author: Titouan Vayer <titouan.vayer@irisa.fr>
#
# License: MIT License

#%% load libraries
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import math
from scipy.sparse.csgraph import shortest_path
import matplotlib.colors as mcol
from matplotlib import cm
from ot.gromov import fgw_barycenters
#%% Graph functions


def find_thresh(C, inf=0.5, sup=3, step=10):
""" Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected
Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested.
The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix
and the original matrix.
Parameters
----------
C : ndarray, shape (n_nodes,n_nodes)
The structure matrix to threshold
inf : float
The beginning of the linesearch
sup : float
The end of the linesearch
step : integer
Number of thresholds tested
"""
dist = []
search = np.linspace(inf, sup, step)
for thresh in search:
Cprime = sp_to_adjency(C, 0, thresh)
SC = shortest_path(Cprime, method='D')
SC[SC == float('inf')] = 100
dist.append(np.linalg.norm(SC - C))
return search[np.argmin(dist)], dist


def sp_to_adjency(C, threshinf=0.2, threshsup=1.8):
""" Thresholds the structure matrix in order to compute an adjency matrix.
All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0
Parameters
----------
C : ndarray, shape (n_nodes,n_nodes)
The structure matrix to threshold
threshinf : float
The minimum value of distance from which the new value is set to 1
threshsup : float
The maximum value of distance from which the new value is set to 1
Returns
-------
C : ndarray, shape (n_nodes,n_nodes)
The threshold matrix. Each element is in {0,1}
"""
H = np.zeros_like(C)
np.fill_diagonal(H, np.diagonal(C))
C = C - H
C = np.minimum(np.maximum(C, threshinf), threshsup)
C[C == threshsup] = 0
C[C != 0] = 1

return C


def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None):
""" Create a noisy circular graph
"""
g = nx.Graph()
g.add_nodes_from(list(range(N)))
for i in range(N):
noise = float(np.random.normal(mu, sigma, 1))
if with_noise:
g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise)
else:
g.add_node(i, attr_name=math.sin(2 * i * math.pi / N))
g.add_edge(i, i + 1)
if structure_noise:
randomint = np.random.randint(0, p)
if randomint == 0:
if i <= N - 3:
g.add_edge(i, i + 2)
if i == N - 2:
g.add_edge(i, 0)
if i == N - 1:
g.add_edge(i, 1)
g.add_edge(N, 0)
noise = float(np.random.normal(mu, sigma, 1))
if with_noise:
g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise)
else:
g.add_node(N, attr_name=math.sin(2 * N * math.pi / N))
return g


def graph_colors(nx_graph, vmin=0, vmax=7):
cnorm = mcol.Normalize(vmin=vmin, vmax=vmax)
cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis')
cpick.set_array([])
val_map = {}
for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items():
val_map[k] = cpick.to_rgba(v)
colors = []
for node in nx_graph.nodes():
colors.append(val_map[node])
return colors

##############################################################################
# Generate data
# -------------

#%% circular dataset
# We build a dataset of noisy circular graphs.
# Noise is added on the structures by random connections and on the features by gaussian noise.


np.random.seed(30)
X0 = []
for k in range(9):
X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))

##############################################################################
# Plot data
# ---------

#%% Plot graphs

plt.figure(figsize=(8, 10))
for i in range(len(X0)):
plt.subplot(3, 3, i + 1)
g = X0[i]
pos = nx.kamada_kawai_layout(g)
nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100)
plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
plt.show()

##############################################################################
# Barycenter computation
# ----------------------

#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
# Features distances are the euclidean distances
Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]
ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]
Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0]
lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()
sizebary = 15 # we choose a barycenter with 15 nodes

A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95)

##############################################################################
# Plot Barycenter
# -------------------------

#%% Create the barycenter
bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
for i, v in enumerate(A.ravel()):
bary.add_node(i, attr_name=v)

#%%
pos = nx.kamada_kawai_layout(bary)
nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False)
plt.suptitle('Barycenter', fontsize=20)
plt.show()
173 changes: 173 additions & 0 deletions examples/plot_fgw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
"""
==============================
Plot Fused-gromov-Wasserstein
==============================

This example illustrates the computation of FGW for 1D measures[18].

.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.

"""

# Author: Titouan Vayer <titouan.vayer@irisa.fr>
#
# License: MIT License

import matplotlib.pyplot as pl
import numpy as np
import ot
from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein

##############################################################################
# Generate data
# ---------

#%% parameters
# We create two 1D random measures
n = 20 # number of points in the first distribution
n2 = 30 # number of points in the second distribution
sig = 1 # std of first distribution
sig2 = 0.1 # std of second distribution

np.random.seed(0)

phi = np.arange(n)[:, None]
xs = phi + sig * np.random.randn(n, 1)
ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1)

phi2 = np.arange(n2)[:, None]
xt = phi2 + sig * np.random.randn(n2, 1)
yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1)
yt = yt[::-1, :]

p = ot.unif(n)
q = ot.unif(n2)

##############################################################################
# Plot data
# ---------

#%% plot the distributions

pl.close(10)
pl.figure(10, (7, 7))

pl.subplot(2, 1, 1)

pl.scatter(ys, xs, c=phi, s=70)
pl.ylabel('Feature value a', fontsize=20)
pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1)
pl.xticks(())
pl.yticks(())
pl.subplot(2, 1, 2)
pl.scatter(yt, xt, c=phi2, s=70)
pl.xlabel('coordinates x/y', fontsize=25)
pl.ylabel('Feature value b', fontsize=20)
pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1)
pl.yticks(())
pl.tight_layout()
pl.show()

##############################################################################
# Create structure matrices and across-feature distance matrix
# ---------

#%% Structure matrices and across-features distance matrix
C1 = ot.dist(xs)
C2 = ot.dist(xt)
M = ot.dist(ys, yt)
w1 = ot.unif(C1.shape[0])
w2 = ot.unif(C2.shape[0])
Got = ot.emd([], [], M)

##############################################################################
# Plot matrices
# ---------

#%%
cmap = 'Reds'
pl.close(10)
pl.figure(10, (5, 5))
fs = 15
l_x = [0, 5, 10, 15]
l_y = [0, 5, 10, 15, 20, 25]
gs = pl.GridSpec(5, 5)

ax1 = pl.subplot(gs[3:, :2])

pl.imshow(C1, cmap=cmap, interpolation='nearest')
pl.title("$C_1$", fontsize=fs)
pl.xlabel("$k$", fontsize=fs)
pl.ylabel("$i$", fontsize=fs)
pl.xticks(l_x)
pl.yticks(l_x)

ax2 = pl.subplot(gs[:3, 2:])

pl.imshow(C2, cmap=cmap, interpolation='nearest')
pl.title("$C_2$", fontsize=fs)
pl.ylabel("$l$", fontsize=fs)
#pl.ylabel("$l$",fontsize=fs)
pl.xticks(())
pl.yticks(l_y)
ax2.set_aspect('auto')

ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1)
pl.imshow(M, cmap=cmap, interpolation='nearest')
pl.yticks(l_x)
pl.xticks(l_y)
pl.ylabel("$i$", fontsize=fs)
pl.title("$M_{AB}$", fontsize=fs)
pl.xlabel("$j$", fontsize=fs)
pl.tight_layout()
ax3.set_aspect('auto')
pl.show()

##############################################################################
# Compute FGW/GW
# ---------

#%% Computing FGW and GW
alpha = 1e-3

ot.tic()
Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)
ot.toc()

#%reload_ext WGW
Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)

##############################################################################
# Visualize transport matrices
# ---------

#%% visu OT matrix
cmap = 'Blues'
fs = 15
pl.figure(2, (13, 5))
pl.clf()
pl.subplot(1, 3, 1)
pl.imshow(Got, cmap=cmap, interpolation='nearest')
#pl.xlabel("$y$",fontsize=fs)
pl.ylabel("$i$", fontsize=fs)
pl.xticks(())

pl.title('Wasserstein ($M$ only)')

pl.subplot(1, 3, 2)
pl.imshow(Gg, cmap=cmap, interpolation='nearest')
pl.title('Gromov ($C_1,C_2$ only)')
pl.xticks(())
pl.subplot(1, 3, 3)
pl.imshow(Gwg, cmap=cmap, interpolation='nearest')
pl.title('FGW ($M+C_1,C_2$)')

pl.xlabel("$j$", fontsize=fs)
pl.ylabel("$i$", fontsize=fs)

pl.tight_layout()
pl.show()
1 change: 1 addition & 0 deletions ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
#
# License: MIT License

Expand Down
Loading