Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions crates/burn-train/src/checkpoint/file.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::path::{Path, PathBuf};

use super::{Checkpointer, CheckpointerError};
use burn_core::{
record::{FileRecorder, Record},
Expand All @@ -6,7 +8,7 @@ use burn_core::{

/// The file checkpointer.
pub struct FileCheckpointer<FR> {
directory: String,
directory: PathBuf,
name: String,
recorder: FR,
}
Expand All @@ -19,17 +21,19 @@ impl<FR> FileCheckpointer<FR> {
/// * `recorder` - The file recorder.
/// * `directory` - The directory to save the checkpoints.
/// * `name` - The name of the checkpoint.
pub fn new(recorder: FR, directory: &str, name: &str) -> Self {
pub fn new(recorder: FR, directory: impl AsRef<Path>, name: &str) -> Self {
let directory = directory.as_ref();
std::fs::create_dir_all(directory).ok();

Self {
directory: directory.to_string(),
directory: directory.to_path_buf(),
name: name.to_string(),
recorder,
}
}
fn path_for_epoch(&self, epoch: usize) -> String {
format!("{}/{}-{}", self.directory, self.name, epoch)

fn path_for_epoch(&self, epoch: usize) -> PathBuf {
self.directory.join(format!("{}-{}", self.name, epoch))
}
}

Expand All @@ -41,28 +45,36 @@ where
{
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
let file_path = self.path_for_epoch(epoch);
log::info!("Saving checkpoint {} to {}", epoch, file_path);
log::info!("Saving checkpoint {} to {}", epoch, file_path.display());

self.recorder
.record(record, file_path.into())
.record(record, file_path)
.map_err(CheckpointerError::RecorderError)?;

Ok(())
}

fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
let file_path = self.path_for_epoch(epoch);
log::info!("Restoring checkpoint {} from {}", epoch, file_path);
log::info!(
"Restoring checkpoint {} from {}",
epoch,
file_path.display()
);
let record = self
.recorder
.load(file_path.into(), device)
.load(file_path, device)
.map_err(CheckpointerError::RecorderError)?;

Ok(record)
}

fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
let file_to_remove = format!("{}.{}", self.path_for_epoch(epoch), FR::file_extension(),);
let file_to_remove = format!(
"{}.{}",
self.path_for_epoch(epoch).display(),
FR::file_extension(),
);

if std::path::Path::new(&file_to_remove).exists() {
log::info!("Removing checkpoint {}", file_to_remove);
Expand Down
18 changes: 10 additions & 8 deletions crates/burn-train/src/learner/application_logger.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::path::Path;
use std::path::{Path, PathBuf};
use tracing_core::{Level, LevelFilter};
use tracing_subscriber::filter::filter_fn;
use tracing_subscriber::prelude::*;
Expand All @@ -12,14 +12,14 @@ pub trait ApplicationLoggerInstaller {

/// This struct is used to install a local file application logger to output logs to a given file path.
pub struct FileApplicationLoggerInstaller {
path: String,
path: PathBuf,
}

impl FileApplicationLoggerInstaller {
/// Create a new file application logger.
pub fn new(path: &str) -> Self {
pub fn new(path: impl AsRef<Path>) -> Self {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it better to have AsRef<Path> or Into<PathBuf>? If &str also implements Into<PathBuf>, I think that might be more flexible; otherwise, maybe AsRef<Path> is better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good question. I don't know which one is better. In this situation I think either one will work just fine.

Copy link
Collaborator

@Luni-4 Luni-4 Jun 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It depends on how the path is used: if the path only needs to be read, AsRef<Path> is better; otherwise, Into<PathBuf> is preferred. In my opinion, we should look at how the path is treated internally.

Copy link
Contributor Author

@varonroy varonroy Jun 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that this PR is still open. Are there any changes you would like me to make before merging it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use AsRef<Path> as the default choice for most file system operations in the codebase. This provides flexibility and efficiency for reading paths. We should only use Into<PathBuf> when we specifically need to own or modify the path data.

Self {
path: path.to_string(),
path: path.as_ref().to_path_buf(),
}
}
}
Expand All @@ -29,8 +29,9 @@ impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
let path = Path::new(&self.path);
let writer = tracing_appender::rolling::never(
path.parent().unwrap_or_else(|| Path::new(".")),
path.file_name()
.unwrap_or_else(|| panic!("The path '{}' to point to a file.", self.path)),
path.file_name().unwrap_or_else(|| {
panic!("The path '{}' to point to a file.", self.path.display())
}),
);
let layer = tracing_subscriber::fmt::layer()
.with_ansi(false)
Expand All @@ -51,13 +52,14 @@ impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
}

let hook = std::panic::take_hook();
let file_path: String = self.path.to_owned();
let file_path = self.path.to_owned();

std::panic::set_hook(Box::new(move |info| {
log::error!("PANIC => {}", info.to_string());
eprintln!(
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
'{file_path}'\n============="
'{}'\n=============",
file_path.display()
);
hook(info);
}));
Expand Down
41 changes: 15 additions & 26 deletions crates/burn-train/src/learner/builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::rc::Rc;

use super::Learner;
Expand Down Expand Up @@ -45,7 +46,7 @@ where
)>,
num_epochs: usize,
checkpoint: Option<usize>,
directory: String,
directory: PathBuf,
grad_accumulation: Option<usize>,
devices: Vec<B::Device>,
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
Expand Down Expand Up @@ -74,20 +75,22 @@ where
/// # Arguments
///
/// * `directory` - The directory to save the checkpoints.
pub fn new(directory: &str) -> Self {
pub fn new(directory: impl AsRef<Path>) -> Self {
let directory = directory.as_ref().to_path_buf();
let experiment_log_file = directory.join("experiment.log");
Self {
num_epochs: 1,
checkpoint: None,
checkpointers: None,
directory: directory.to_string(),
directory,
grad_accumulation: None,
devices: vec![B::Device::default()],
metrics: Metrics::default(),
event_store: LogEventStore::default(),
renderer: None,
interrupter: TrainingInterrupter::new(),
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
format!("{}/experiment.log", directory).as_str(),
experiment_log_file,
))),
num_loggers: 0,
checkpointer_strategy: Box::new(
Expand Down Expand Up @@ -256,21 +259,12 @@ where
M::Record: 'static,
S::Record: 'static,
{
let checkpointer_model = FileCheckpointer::new(
recorder.clone(),
format!("{}/checkpoint", self.directory).as_str(),
"model",
);
let checkpointer_optimizer = FileCheckpointer::new(
recorder.clone(),
format!("{}/checkpoint", self.directory).as_str(),
"optim",
);
let checkpointer_scheduler: FileCheckpointer<FR> = FileCheckpointer::new(
recorder,
format!("{}/checkpoint", self.directory).as_str(),
"scheduler",
);
let checkpoint_dir = self.directory.join("checkpoint");
let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
let checkpointer_optimizer =
FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
let checkpointer_scheduler: FileCheckpointer<FR> =
FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");

self.checkpointers = Some((
AsyncCheckpointer::new(checkpointer_model),
Expand Down Expand Up @@ -325,17 +319,12 @@ where
let renderer = self.renderer.unwrap_or_else(|| {
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
});
let directory = &self.directory;

if self.num_loggers == 0 {
self.event_store
.register_logger_train(FileMetricLogger::new(
format!("{directory}/train").as_str(),
));
.register_logger_train(FileMetricLogger::new(self.directory.join("train")));
self.event_store
.register_logger_valid(FileMetricLogger::new(
format!("{directory}/valid").as_str(),
));
.register_logger_valid(FileMetricLogger::new(self.directory.join("valid")));
}

let event_store = Rc::new(EventStoreClient::new(self.event_store));
Expand Down
25 changes: 16 additions & 9 deletions crates/burn-train/src/learner/summary.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use core::cmp::Ordering;
use std::{fmt::Display, path::Path};
use std::{
fmt::Display,
path::{Path, PathBuf},
};

use crate::{
logger::FileMetricLogger,
Expand Down Expand Up @@ -73,16 +76,20 @@ impl LearnerSummary {
///
/// * `directory` - The directory containing the training artifacts (checkpoints and logs).
/// * `metrics` - The list of metrics to collect for the summary.
pub fn new<S: AsRef<str>>(directory: &str, metrics: &[S]) -> Result<Self, String> {
let directory_path = Path::new(directory);
if !directory_path.exists() {
return Err(format!("Artifact directory does not exist at: {directory}"));
pub fn new<S: AsRef<str>>(directory: impl AsRef<Path>, metrics: &[S]) -> Result<Self, String> {
let directory = directory.as_ref();
if !directory.exists() {
return Err(format!(
"Artifact directory does not exist at: {}",
directory.display()
));
}
let train_dir = directory_path.join("train");
let valid_dir = directory_path.join("valid");
let train_dir = directory.join("train");
let valid_dir = directory.join("valid");
if !train_dir.exists() & !valid_dir.exists() {
return Err(format!(
"No training or validation artifacts found at: {directory}"
"No training or validation artifacts found at: {}",
directory.display()
));
}

Expand Down Expand Up @@ -219,7 +226,7 @@ impl Display for LearnerSummary {
}

pub(crate) struct LearnerSummaryConfig {
pub(crate) directory: String,
pub(crate) directory: PathBuf,
pub(crate) metrics: Vec<String>,
}

Expand Down
12 changes: 9 additions & 3 deletions crates/burn-train/src/logger/file.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::Logger;
use std::{fs::File, io::Write};
use std::{fs::File, io::Write, path::Path};

/// File logger.
pub struct FileLogger {
Expand All @@ -16,14 +16,20 @@ impl FileLogger {
/// # Returns
///
/// The file logger.
pub fn new(path: &str) -> Self {
pub fn new(path: &Path) -> Self {
let mut options = std::fs::File::options();
let file = options
.write(true)
.truncate(true)
.create(true)
.open(path)
.unwrap_or_else(|err| panic!("Should be able to create the new file '{path}': {err}"));
.unwrap_or_else(|err| {
panic!(
"Should be able to create the new file '{}': {}",
path.display(),
err
)
});

Self { file }
}
Expand Down
25 changes: 16 additions & 9 deletions crates/burn-train/src/logger/metric.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
use crate::metric::{MetricEntry, NumericEntry};
use std::{collections::HashMap, fs};
use std::{
collections::HashMap,
fs,
path::{Path, PathBuf},
};

const EPOCH_PREFIX: &str = "epoch-";

Expand All @@ -27,7 +31,7 @@ pub trait MetricLogger: Send {
/// The file metric logger.
pub struct FileMetricLogger {
loggers: HashMap<String, AsyncLogger<String>>,
directory: String,
directory: PathBuf,
epoch: usize,
}

Expand All @@ -41,10 +45,10 @@ impl FileMetricLogger {
/// # Returns
///
/// The file metric logger.
pub fn new(directory: &str) -> Self {
pub fn new(directory: impl AsRef<Path>) -> Self {
Self {
loggers: HashMap::new(),
directory: directory.to_string(),
directory: directory.as_ref().to_path_buf(),
epoch: 1,
}
}
Expand Down Expand Up @@ -76,15 +80,18 @@ impl FileMetricLogger {
max_epoch
}

fn epoch_directory(&self, epoch: usize) -> String {
format!("{}/{}{}", self.directory, EPOCH_PREFIX, epoch)
fn epoch_directory(&self, epoch: usize) -> PathBuf {
let name = format!("{}{}", EPOCH_PREFIX, epoch);
self.directory.join(name)
}
fn file_path(&self, name: &str, epoch: usize) -> String {

fn file_path(&self, name: &str, epoch: usize) -> PathBuf {
let directory = self.epoch_directory(epoch);
let name = name.replace(' ', "_");

format!("{directory}/{name}.log")
let name = format!("{name}.log");
directory.join(name)
}

fn create_directory(&self, epoch: usize) {
let directory = self.epoch_directory(epoch);
std::fs::create_dir_all(directory).ok();
Expand Down