50
50
51
51
import numpy as np
52
52
import matplotlib .pyplot as pl
53
+ import matplotlib .animation as animation
53
54
import torch
54
55
from torch import nn
55
56
import ot
@@ -112,10 +113,10 @@ def forward(self, x):
112
113
113
114
114
115
G = Generator ()
115
- optimizer = torch .optim .RMSprop (G .parameters (), lr = 0.001 )
116
+ optimizer = torch .optim .RMSprop (G .parameters (), lr = 0.00019 , eps = 1e-5 )
116
117
117
118
# number of iteration and size of the batches
118
- n_iter = 500
119
+ n_iter = 200 # set to 200 for doc build but 1000 is better ;)
119
120
size_batch = 500
120
121
121
122
# generate statis samples to see their trajectory along training
@@ -167,7 +168,7 @@ def forward(self, x):
167
168
168
169
pl .figure (3 , (10 , 10 ))
169
170
170
- ivisu = [0 , 10 , 50 , 100 , 150 , 200 , 300 , 400 , 499 ]
171
+ ivisu = [0 , 10 , 25 , 50 , 75 , 125 , 15 , 175 , 199 ]
171
172
172
173
for i in range (9 ):
173
174
pl .subplot (3 , 3 , i + 1 )
@@ -179,6 +180,37 @@ def forward(self, x):
179
180
if i == 0 :
180
181
pl .legend ()
181
182
183
+ # %%
184
+ # Animate trajectories of generated samples along iteration
185
+ # -------------------------------------------------------
186
+
187
+ pl .figure (4 , (8 , 8 ))
188
+
189
+
190
+ def _update_plot (i ):
191
+ pl .clf ()
192
+ pl .scatter (xd [:, 0 ], xd [:, 1 ], label = 'Data samples from $\mu_d$' , alpha = 0.1 )
193
+ pl .scatter (xvisu [i , :, 0 ], xvisu [i , :, 1 ], label = 'Data samples from $G\#\mu_n$' , alpha = 0.5 )
194
+ pl .xticks (())
195
+ pl .yticks (())
196
+ pl .xlim ((- 1.5 , 1.5 ))
197
+ pl .ylim ((- 1.5 , 1.5 ))
198
+ pl .title ('Iter. {}' .format (i ))
199
+ return 1
200
+
201
+
202
+ i = 0
203
+ pl .scatter (xd [:, 0 ], xd [:, 1 ], label = 'Data samples from $\mu_d$' , alpha = 0.1 )
204
+ pl .scatter (xvisu [i , :, 0 ], xvisu [i , :, 1 ], label = 'Data samples from $G\#\mu_n$' , alpha = 0.5 )
205
+ pl .xticks (())
206
+ pl .yticks (())
207
+ pl .xlim ((- 1.5 , 1.5 ))
208
+ pl .ylim ((- 1.5 , 1.5 ))
209
+ pl .title ('Iter. {}' .format (ivisu [i ]))
210
+
211
+
212
+ ani = animation .FuncAnimation (pl .gcf (), _update_plot , n_iter , interval = 100 , repeat_delay = 2000 )
213
+
182
214
# %%
183
215
# Generate and visualize data
184
216
# ---------------------------
@@ -188,7 +220,7 @@ def forward(self, x):
188
220
xn = torch .randn (size_batch , 2 )
189
221
x = G (xn ).detach ().numpy ()
190
222
191
- pl .figure (4 )
223
+ pl .figure (5 )
192
224
pl .scatter (xd [:, 0 ], xd [:, 1 ], label = 'Data samples from $\mu_d$' , alpha = 0.5 )
193
225
pl .scatter (x [:, 0 ], x [:, 1 ], label = 'Data samples from $G\#\mu_n$' , alpha = 0.5 )
194
226
pl .title ('Sources and Target distributions' )
0 commit comments