-
Notifications
You must be signed in to change notification settings - Fork 528
[MRG] Gromov-Wasserstein closed form for linesearch and integration of Fused Gromov-Wasserstein #86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
549b95b
b1b514f
cd4b98c
11c2c26
6484c9e
f70aabf
63bbeb3
Merge remote-tracking branch 'rflamary/master'
tvayer fa98906
103dfe0
915d5fa
94d2fe5
9421ddd
d432038
e1bd94b
code review1
tvayer 28059eb
63093ce
9bb7d40
89a2e0a
ad450b0
changes forgotten coments
tvayer 788a650
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
================================= | ||
Plot graphs' barycenter using FGW | ||
================================= | ||
|
||
This example illustrates the computation barycenter of labeled graphs using FGW | ||
|
||
Requires networkx >=2 | ||
|
||
.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain | ||
and Courty Nicolas | ||
"Optimal Transport for structured data with application on graphs" | ||
International Conference on Machine Learning (ICML). 2019. | ||
|
||
""" | ||
|
||
# Author: Titouan Vayer <titouan.vayer@irisa.fr> | ||
# | ||
# License: MIT License | ||
|
||
#%% load libraries | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import networkx as nx | ||
import math | ||
from scipy.sparse.csgraph import shortest_path | ||
import matplotlib.colors as mcol | ||
from matplotlib import cm | ||
from ot.gromov import fgw_barycenters | ||
#%% Graph functions | ||
|
||
|
||
def find_thresh(C, inf=0.5, sup=3, step=10): | ||
""" Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected | ||
Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested. | ||
The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix | ||
and the original matrix. | ||
Parameters | ||
---------- | ||
C : ndarray, shape (n_nodes,n_nodes) | ||
The structure matrix to threshold | ||
inf : float | ||
The beginning of the linesearch | ||
sup : float | ||
The end of the linesearch | ||
step : integer | ||
Number of thresholds tested | ||
""" | ||
dist = [] | ||
search = np.linspace(inf, sup, step) | ||
for thresh in search: | ||
Cprime = sp_to_adjency(C, 0, thresh) | ||
SC = shortest_path(Cprime, method='D') | ||
SC[SC == float('inf')] = 100 | ||
dist.append(np.linalg.norm(SC - C)) | ||
return search[np.argmin(dist)], dist | ||
|
||
|
||
def sp_to_adjency(C, threshinf=0.2, threshsup=1.8): | ||
""" Thresholds the structure matrix in order to compute an adjency matrix. | ||
All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0 | ||
Parameters | ||
---------- | ||
C : ndarray, shape (n_nodes,n_nodes) | ||
The structure matrix to threshold | ||
threshinf : float | ||
The minimum value of distance from which the new value is set to 1 | ||
threshsup : float | ||
The maximum value of distance from which the new value is set to 1 | ||
Returns | ||
------- | ||
C : ndarray, shape (n_nodes,n_nodes) | ||
The threshold matrix. Each element is in {0,1} | ||
""" | ||
H = np.zeros_like(C) | ||
np.fill_diagonal(H, np.diagonal(C)) | ||
C = C - H | ||
C = np.minimum(np.maximum(C, threshinf), threshsup) | ||
C[C == threshsup] = 0 | ||
C[C != 0] = 1 | ||
|
||
return C | ||
|
||
|
||
def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None): | ||
""" Create a noisy circular graph | ||
""" | ||
g = nx.Graph() | ||
g.add_nodes_from(list(range(N))) | ||
for i in range(N): | ||
noise = float(np.random.normal(mu, sigma, 1)) | ||
if with_noise: | ||
g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise) | ||
else: | ||
g.add_node(i, attr_name=math.sin(2 * i * math.pi / N)) | ||
g.add_edge(i, i + 1) | ||
if structure_noise: | ||
randomint = np.random.randint(0, p) | ||
if randomint == 0: | ||
if i <= N - 3: | ||
g.add_edge(i, i + 2) | ||
if i == N - 2: | ||
g.add_edge(i, 0) | ||
if i == N - 1: | ||
g.add_edge(i, 1) | ||
g.add_edge(N, 0) | ||
noise = float(np.random.normal(mu, sigma, 1)) | ||
if with_noise: | ||
g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise) | ||
else: | ||
g.add_node(N, attr_name=math.sin(2 * N * math.pi / N)) | ||
return g | ||
|
||
|
||
def graph_colors(nx_graph, vmin=0, vmax=7): | ||
cnorm = mcol.Normalize(vmin=vmin, vmax=vmax) | ||
cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis') | ||
cpick.set_array([]) | ||
val_map = {} | ||
for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items(): | ||
val_map[k] = cpick.to_rgba(v) | ||
colors = [] | ||
for node in nx_graph.nodes(): | ||
colors.append(val_map[node]) | ||
return colors | ||
|
||
############################################################################## | ||
# Generate data | ||
# ------------- | ||
|
||
#%% circular dataset | ||
# We build a dataset of noisy circular graphs. | ||
# Noise is added on the structures by random connections and on the features by gaussian noise. | ||
|
||
|
||
np.random.seed(30) | ||
X0 = [] | ||
for k in range(9): | ||
X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3)) | ||
|
||
############################################################################## | ||
# Plot data | ||
# --------- | ||
|
||
#%% Plot graphs | ||
|
||
plt.figure(figsize=(8, 10)) | ||
for i in range(len(X0)): | ||
plt.subplot(3, 3, i + 1) | ||
g = X0[i] | ||
pos = nx.kamada_kawai_layout(g) | ||
nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100) | ||
plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20) | ||
plt.show() | ||
|
||
############################################################################## | ||
# Barycenter computation | ||
# ---------------------- | ||
|
||
#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph | ||
# Features distances are the euclidean distances | ||
Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0] | ||
ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0] | ||
Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0] | ||
lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel() | ||
sizebary = 15 # we choose a barycenter with 15 nodes | ||
|
||
A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95) | ||
|
||
############################################################################## | ||
# Plot Barycenter | ||
# ------------------------- | ||
|
||
#%% Create the barycenter | ||
bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) | ||
for i, v in enumerate(A.ravel()): | ||
bary.add_node(i, attr_name=v) | ||
|
||
#%% | ||
pos = nx.kamada_kawai_layout(bary) | ||
nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False) | ||
plt.suptitle('Barycenter', fontsize=20) | ||
plt.show() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
============================== | ||
Plot Fused-gromov-Wasserstein | ||
============================== | ||
|
||
This example illustrates the computation of FGW for 1D measures[18]. | ||
|
||
.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain | ||
and Courty Nicolas | ||
"Optimal Transport for structured data with application on graphs" | ||
International Conference on Machine Learning (ICML). 2019. | ||
|
||
""" | ||
|
||
# Author: Titouan Vayer <titouan.vayer@irisa.fr> | ||
# | ||
# License: MIT License | ||
|
||
import matplotlib.pyplot as pl | ||
import numpy as np | ||
import ot | ||
from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein | ||
|
||
############################################################################## | ||
# Generate data | ||
# --------- | ||
|
||
#%% parameters | ||
# We create two 1D random measures | ||
n = 20 # number of points in the first distribution | ||
n2 = 30 # number of points in the second distribution | ||
sig = 1 # std of first distribution | ||
sig2 = 0.1 # std of second distribution | ||
|
||
np.random.seed(0) | ||
|
||
phi = np.arange(n)[:, None] | ||
xs = phi + sig * np.random.randn(n, 1) | ||
ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1) | ||
|
||
phi2 = np.arange(n2)[:, None] | ||
xt = phi2 + sig * np.random.randn(n2, 1) | ||
yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1) | ||
yt = yt[::-1, :] | ||
|
||
p = ot.unif(n) | ||
q = ot.unif(n2) | ||
|
||
############################################################################## | ||
# Plot data | ||
# --------- | ||
|
||
#%% plot the distributions | ||
|
||
pl.close(10) | ||
pl.figure(10, (7, 7)) | ||
|
||
pl.subplot(2, 1, 1) | ||
|
||
pl.scatter(ys, xs, c=phi, s=70) | ||
pl.ylabel('Feature value a', fontsize=20) | ||
pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1) | ||
pl.xticks(()) | ||
pl.yticks(()) | ||
pl.subplot(2, 1, 2) | ||
pl.scatter(yt, xt, c=phi2, s=70) | ||
pl.xlabel('coordinates x/y', fontsize=25) | ||
pl.ylabel('Feature value b', fontsize=20) | ||
pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1) | ||
pl.yticks(()) | ||
pl.tight_layout() | ||
pl.show() | ||
|
||
############################################################################## | ||
# Create structure matrices and across-feature distance matrix | ||
# --------- | ||
|
||
#%% Structure matrices and across-features distance matrix | ||
C1 = ot.dist(xs) | ||
C2 = ot.dist(xt) | ||
M = ot.dist(ys, yt) | ||
w1 = ot.unif(C1.shape[0]) | ||
w2 = ot.unif(C2.shape[0]) | ||
Got = ot.emd([], [], M) | ||
|
||
############################################################################## | ||
# Plot matrices | ||
# --------- | ||
|
||
#%% | ||
cmap = 'Reds' | ||
pl.close(10) | ||
pl.figure(10, (5, 5)) | ||
fs = 15 | ||
l_x = [0, 5, 10, 15] | ||
l_y = [0, 5, 10, 15, 20, 25] | ||
gs = pl.GridSpec(5, 5) | ||
|
||
ax1 = pl.subplot(gs[3:, :2]) | ||
|
||
pl.imshow(C1, cmap=cmap, interpolation='nearest') | ||
pl.title("$C_1$", fontsize=fs) | ||
pl.xlabel("$k$", fontsize=fs) | ||
pl.ylabel("$i$", fontsize=fs) | ||
pl.xticks(l_x) | ||
pl.yticks(l_x) | ||
|
||
ax2 = pl.subplot(gs[:3, 2:]) | ||
|
||
pl.imshow(C2, cmap=cmap, interpolation='nearest') | ||
pl.title("$C_2$", fontsize=fs) | ||
pl.ylabel("$l$", fontsize=fs) | ||
#pl.ylabel("$l$",fontsize=fs) | ||
pl.xticks(()) | ||
pl.yticks(l_y) | ||
ax2.set_aspect('auto') | ||
|
||
ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1) | ||
pl.imshow(M, cmap=cmap, interpolation='nearest') | ||
pl.yticks(l_x) | ||
pl.xticks(l_y) | ||
pl.ylabel("$i$", fontsize=fs) | ||
pl.title("$M_{AB}$", fontsize=fs) | ||
pl.xlabel("$j$", fontsize=fs) | ||
pl.tight_layout() | ||
ax3.set_aspect('auto') | ||
pl.show() | ||
|
||
############################################################################## | ||
# Compute FGW/GW | ||
# --------- | ||
|
||
#%% Computing FGW and GW | ||
alpha = 1e-3 | ||
|
||
ot.tic() | ||
Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True) | ||
ot.toc() | ||
|
||
#%reload_ext WGW | ||
Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True) | ||
|
||
############################################################################## | ||
# Visualize transport matrices | ||
# --------- | ||
|
||
#%% visu OT matrix | ||
cmap = 'Blues' | ||
fs = 15 | ||
pl.figure(2, (13, 5)) | ||
pl.clf() | ||
pl.subplot(1, 3, 1) | ||
pl.imshow(Got, cmap=cmap, interpolation='nearest') | ||
#pl.xlabel("$y$",fontsize=fs) | ||
pl.ylabel("$i$", fontsize=fs) | ||
pl.xticks(()) | ||
|
||
pl.title('Wasserstein ($M$ only)') | ||
|
||
pl.subplot(1, 3, 2) | ||
pl.imshow(Gg, cmap=cmap, interpolation='nearest') | ||
pl.title('Gromov ($C_1,C_2$ only)') | ||
pl.xticks(()) | ||
pl.subplot(1, 3, 3) | ||
pl.imshow(Gwg, cmap=cmap, interpolation='nearest') | ||
pl.title('FGW ($M+C_1,C_2$)') | ||
|
||
pl.xlabel("$j$", fontsize=fs) | ||
pl.ylabel("$i$", fontsize=fs) | ||
|
||
pl.tight_layout() | ||
pl.show() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.