Skip to content

Commit 3e3fb86

Browse files
author
mjvolk3
committed
Update for gradient accumulation
1 parent 35acc1c commit 3e3fb86

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

torchcell/trainers/int_hetero_cell.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,13 @@ def __init__(
4040
self.model = model
4141
self.cell_graph = cell_graph
4242
self.inverse_transform = inverse_transform
43-
self.current_accumulation_steps = 1
4443
self.loss_func = loss_func
44+
45+
# Initialize gradient accumulation
46+
self.current_accumulation_steps = 1
47+
if self.hparams.grad_accumulation_schedule is not None:
48+
# Get the accumulation steps for epoch 0
49+
self.current_accumulation_steps = self.hparams.grad_accumulation_schedule.get(0, 1)
4550

4651
reg_metrics = MetricCollection(
4752
{
@@ -315,6 +320,15 @@ def training_step(self, batch, batch_idx):
315320
batch_size=batch["gene"].x.size(0),
316321
sync_dist=True,
317322
)
323+
# Log effective batch size when using gradient accumulation
324+
if self.hparams.grad_accumulation_schedule is not None:
325+
effective_batch_size = batch["gene"].x.size(0) * self.current_accumulation_steps
326+
self.log(
327+
"effective_batch_size",
328+
effective_batch_size,
329+
batch_size=batch["gene"].x.size(0),
330+
sync_dist=True,
331+
)
318332
# print(f"Loss: {loss}")
319333
return loss
320334

@@ -446,6 +460,13 @@ def on_train_epoch_end(self):
446460
sch.step()
447461

448462
def on_train_epoch_start(self):
463+
# Update gradient accumulation steps based on current epoch
464+
if self.hparams.grad_accumulation_schedule is not None:
465+
for epoch_threshold in sorted(self.hparams.grad_accumulation_schedule.keys()):
466+
if self.current_epoch >= epoch_threshold:
467+
self.current_accumulation_steps = self.hparams.grad_accumulation_schedule[epoch_threshold]
468+
print(f"Epoch {self.current_epoch}: Using gradient accumulation steps = {self.current_accumulation_steps}")
469+
449470
# Clear sample containers at the start of epochs where we'll collect samples
450471
if (self.current_epoch + 1) % self.hparams.plot_every_n_epochs == 0:
451472
self.train_samples = {"true_values": [], "predictions": [], "latents": {}}

0 commit comments

Comments
 (0)