Skip to content

Commit 982510e

Browse files
authored
[MRG] Update example GAN to avoid the 10 minute CircleCI limit (#258)
* shortened example GAN * pep8 and typo * awesome animation * small eror pep8 * add animation to doc * better timing animation * tune step
1 parent 221e04b commit 982510e

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

docs/source/conf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,14 +337,16 @@ def __getattr__(cls, name):
337337
intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
338338
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
339339
'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None),
340-
'matplotlib': ('http://matplotlib.org/', None)}
340+
'matplotlib': ('http://matplotlib.org/', None),
341+
'torch': ('https://pytorch.org/docs/stable/', None)}
341342

342343
sphinx_gallery_conf = {
343344
'examples_dirs': ['../../examples', '../../examples/da'],
344345
'gallery_dirs': 'auto_examples',
345346
'backreferences_dir': 'gen_modules/backreferences',
346347
'inspect_global_variables' : True,
347348
'doc_module' : ('ot','numpy','scipy','pylab'),
349+
'matplotlib_animations': True,
348350
'reference_url': {
349351
'ot': None}
350352
}

examples/backends/plot_wass2_gan_torch.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
import numpy as np
5252
import matplotlib.pyplot as pl
53+
import matplotlib.animation as animation
5354
import torch
5455
from torch import nn
5556
import ot
@@ -112,10 +113,10 @@ def forward(self, x):
112113

113114

114115
G = Generator()
115-
optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001)
116+
optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5)
116117

117118
# number of iteration and size of the batches
118-
n_iter = 500
119+
n_iter = 200 # set to 200 for doc build but 1000 is better ;)
119120
size_batch = 500
120121

121122
# generate statis samples to see their trajectory along training
@@ -167,7 +168,7 @@ def forward(self, x):
167168

168169
pl.figure(3, (10, 10))
169170

170-
ivisu = [0, 10, 50, 100, 150, 200, 300, 400, 499]
171+
ivisu = [0, 10, 25, 50, 75, 125, 15, 175, 199]
171172

172173
for i in range(9):
173174
pl.subplot(3, 3, i + 1)
@@ -179,6 +180,37 @@ def forward(self, x):
179180
if i == 0:
180181
pl.legend()
181182

183+
# %%
184+
# Animate trajectories of generated samples along iteration
185+
# -------------------------------------------------------
186+
187+
pl.figure(4, (8, 8))
188+
189+
190+
def _update_plot(i):
191+
pl.clf()
192+
pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
193+
pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
194+
pl.xticks(())
195+
pl.yticks(())
196+
pl.xlim((-1.5, 1.5))
197+
pl.ylim((-1.5, 1.5))
198+
pl.title('Iter. {}'.format(i))
199+
return 1
200+
201+
202+
i = 0
203+
pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1)
204+
pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
205+
pl.xticks(())
206+
pl.yticks(())
207+
pl.xlim((-1.5, 1.5))
208+
pl.ylim((-1.5, 1.5))
209+
pl.title('Iter. {}'.format(ivisu[i]))
210+
211+
212+
ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000)
213+
182214
# %%
183215
# Generate and visualize data
184216
# ---------------------------
@@ -188,7 +220,7 @@ def forward(self, x):
188220
xn = torch.randn(size_batch, 2)
189221
x = G(xn).detach().numpy()
190222

191-
pl.figure(4)
223+
pl.figure(5)
192224
pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5)
193225
pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5)
194226
pl.title('Sources and Target distributions')

0 commit comments

Comments
 (0)