```rust
match build_optimizer(self.config.optimizer(), &self.model) {
    AnyOptimizer::Sgd(optim) => {
        if self.config.metric().is_some() {
            self.train_with_sgd_config(train_loader, val_loader, scheduler, device, optim);
        } else {
            self.train_with_sgd(train_loader, val_loader, scheduler, device, optim);
        }
    }
    AnyOptimizer::Adam(optim) => {
        if self.config.metric().is_some() {
            self.train_with_adam_config(train_loader, val_loader, scheduler, device, optim);
        } else {
            self.train_with_adam(train_loader, val_loader, scheduler, device, optim);
        }
    }
    AnyOptimizer::AdamW(optim) => {
        if self.config.metric().is_some() {
            self.train_with_adamw_config(train_loader, val_loader, scheduler, device, optim);
        } else {
            self.train_with_adamw(train_loader, val_loader, scheduler, device, optim);
        }
    }
    AnyOptimizer::AdaGrad(optim) => {
        if self.config.metric().is_some() {
            self.train_with_adagrad_config(train_loader, val_loader, scheduler, device, optim);
        } else {
            self.train_with_adagrad(train_loader, val_loader, scheduler, device, optim);
        }
    }
}
```
The above is my code. I would like to make it more configurable so it can adapt flexibly to optimizers such as SGD and Adam. Because of the LearnerBuilder type problem, do I need to repeat the LearnerBuilder::new(self.directory.clone()) setup for each optimizer?
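If the framework exposes a common optimizer trait (Burn does define an `Optimizer` trait, though its exact generic bounds over the model and backend types are what cause the "type problem" here), the usual Rust pattern for collapsing N near-identical `train_with_*` methods is to write one method that is generic over the optimizer and keep the `match` only for unpacking the enum. Below is a minimal self-contained sketch of that pattern; the `Optimizer` trait, `Trainer`, and `has_metric` field are stand-ins for illustration, not Burn's real API.

```rust
// Hypothetical stand-in for the framework's optimizer trait.
trait Optimizer {
    fn step(&mut self);
}

struct Sgd;
struct Adam;

impl Optimizer for Sgd {
    fn step(&mut self) { /* SGD update rule would go here */ }
}
impl Optimizer for Adam {
    fn step(&mut self) { /* Adam update rule would go here */ }
}

// The enum still erases the concrete optimizer type at the call site.
enum AnyOptimizer {
    Sgd(Sgd),
    Adam(Adam),
}

struct Trainer {
    has_metric: bool, // stand-in for self.config.metric().is_some()
}

impl Trainer {
    // One generic method replaces train_with_sgd / train_with_adam / ...
    // The metric check moves inside it, so each match arm shrinks to a
    // single call. The compiler monomorphizes this per optimizer type,
    // so the builder setup would be written once even though separate
    // machine code is generated for each concrete optimizer.
    fn train_with<O: Optimizer>(&self, mut optim: O) {
        if self.has_metric {
            // configure metrics on the (hypothetical) builder here
        }
        optim.step();
    }

    fn train(&self, optimizer: AnyOptimizer) {
        // The match now only unpacks the enum; all shared logic is generic.
        match optimizer {
            AnyOptimizer::Sgd(o) => self.train_with(o),
            AnyOptimizer::Adam(o) => self.train_with(o),
        }
    }
}

fn main() {
    let trainer = Trainer { has_metric: true };
    trainer.train(AnyOptimizer::Sgd(Sgd));
    trainer.train(AnyOptimizer::Adam(Adam));
}
```

Under this pattern you would not need to duplicate the LearnerBuilder::new(self.directory.clone()) call per optimizer in source code: it lives once inside the generic method, and monomorphization handles the rest. Whether this compiles directly against Burn depends on the trait bounds its builder requires for the optimizer parameter.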