@@ -40,8 +40,13 @@ def __init__(
         self.model = model
         self.cell_graph = cell_graph
         self.inverse_transform = inverse_transform
-        self.current_accumulation_steps = 1
         self.loss_func = loss_func
+
+        # Initialize gradient accumulation
+        self.current_accumulation_steps = 1
+        if self.hparams.grad_accumulation_schedule is not None:
+            # Get the accumulation steps for epoch 0
+            self.current_accumulation_steps = self.hparams.grad_accumulation_schedule.get(0, 1)
 
         reg_metrics = MetricCollection(
             {
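This hunk treats grad_accumulation_schedule as an optional dict mapping epoch thresholds to accumulation steps. A minimal sketch of the epoch-0 lookup it performs, with a hypothetical schedule (the values below are illustrative, not from this change):

# Minimal sketch of the __init__ lookup above, assuming the schedule is a
# plain dict of {epoch_threshold: accumulation_steps}. Values are hypothetical.
grad_accumulation_schedule = {0: 1, 10: 2, 20: 4}

# Fall back to 1 accumulation step when no schedule is given, or when the
# schedule does not define epoch 0, mirroring .get(0, 1) in the diff.
if grad_accumulation_schedule is not None:
    current_accumulation_steps = grad_accumulation_schedule.get(0, 1)
else:
    current_accumulation_steps = 1

print(current_accumulation_steps)  # -> 1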
@@ -315,6 +320,15 @@ def training_step(self, batch, batch_idx):
             batch_size=batch["gene"].x.size(0),
             sync_dist=True,
         )
+        # Log effective batch size when using gradient accumulation
+        if self.hparams.grad_accumulation_schedule is not None:
+            effective_batch_size = batch["gene"].x.size(0) * self.current_accumulation_steps
+            self.log(
+                "effective_batch_size",
+                effective_batch_size,
+                batch_size=batch["gene"].x.size(0),
+                sync_dist=True,
+            )
         # print(f"Loss: {loss}")
         return loss
 
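The logged quantity is just the per-step batch size times the current accumulation steps, i.e. how many samples contribute to each optimizer update. A quick worked example with hypothetical numbers standing in for batch["gene"].x.size(0) and self.current_accumulation_steps:

# Worked example of the effective-batch-size arithmetic; both numbers
# are hypothetical.
per_step_batch_size = 32   # samples in one forward/backward pass
accumulation_steps = 4     # gradients accumulated before each optimizer step
effective_batch_size = per_step_batch_size * accumulation_steps
print(effective_batch_size)  # -> 128 samples per optimizer update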
@@ -446,6 +460,13 @@ def on_train_epoch_end(self):
             sch.step()
 
     def on_train_epoch_start(self):
+        # Update gradient accumulation steps based on current epoch
+        if self.hparams.grad_accumulation_schedule is not None:
+            for epoch_threshold in sorted(self.hparams.grad_accumulation_schedule.keys()):
+                if self.current_epoch >= epoch_threshold:
+                    self.current_accumulation_steps = self.hparams.grad_accumulation_schedule[epoch_threshold]
+                    print(f"Epoch {self.current_epoch}: Using gradient accumulation steps = {self.current_accumulation_steps}")
+
         # Clear sample containers at the start of epochs where we'll collect samples
         if (self.current_epoch + 1) % self.hparams.plot_every_n_epochs == 0:
            self.train_samples = {"true_values": [], "predictions": [], "latents": {}}
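The threshold scan walks the sorted keys and keeps overwriting, so the value that sticks is the one for the largest threshold the current epoch has reached. A self-contained sketch of that lookup (the function name and default are assumptions for illustration):

def resolve_accumulation_steps(schedule, current_epoch, default=1):
    # Mirror the on_train_epoch_start scan: iterate thresholds in
    # ascending order and keep the last one that current_epoch has reached.
    steps = default
    if schedule is not None:
        for epoch_threshold in sorted(schedule):
            if current_epoch >= epoch_threshold:
                steps = schedule[epoch_threshold]
    return steps

# Hypothetical schedule: 1 step until epoch 10, then 2, then 4 from epoch 20.
schedule = {0: 1, 10: 2, 20: 4}
assert resolve_accumulation_steps(schedule, 0) == 1
assert resolve_accumulation_steps(schedule, 12) == 2
assert resolve_accumulation_steps(schedule, 25) == 4
assert resolve_accumulation_steps(None, 5) == 1  # no schedule -> default

For comparison, PyTorch Lightning also ships a GradientAccumulationScheduler callback that takes a similar {epoch: factor} dict; tracking the steps manually, as here, is presumably what makes the effective batch size available for logging in training_step.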