Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions burn-book/src/building-blocks/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,12 @@ strategies.
| `grid::meshgrid(tensors, GridIndexing::Matrix)` | `torch.meshgrid(tensors, indexing="ij") |
| `grid::meshgrid(tensors, GridIndexing::Cartesian)` | `torch.meshgrid(tensors, indexing="xy") |

## Linalg Functions

| Burn API | PyTorch Equivalent |
|----------------------------------------|-------------------------------------------|
| `linalg::vector_norm(tensors, p, dim)` | `torch.linalg.vector_norm(tensor, p, dim) |

## Displaying Tensor Details

Burn provides flexible options for displaying tensor information, allowing you to control the level
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-tensor/src/tensor/linalg/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
mod vector_norm;
pub use vector_norm::*;
225 changes: 225 additions & 0 deletions crates/burn-tensor/src/tensor/linalg/vector_norm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
use crate::backend::Backend;
use crate::tensor::{BasicOps, Tensor};
use crate::{ElementConversion, Numeric};

/// Specifies the type of norm to compute.
#[derive(Debug, Clone, Copy)]
pub enum Norm {
/// L0 norm (count of non-zero elements)
L0,

/// L1 norm (sum of absolute values)
L1,

/// L2 norm (Euclidean norm)
L2,

/// L:INFINITY norm (maximum absolute value)
LInf,

/// L:NEG_INFINITY norm (minimum absolute value)
LNegInf,

/// Lp norm (generalized norm)
Lp(f64),
}

impl From<i32> for Norm {
fn from(value: i32) -> Self {
match value {
0 => Norm::L0,
1 => Norm::L1,
2 => Norm::L2,
_ => Norm::Lp(value as f64),
}
}
}

impl From<f32> for Norm {
fn from(value: f32) -> Self {
match value {
0.0 => Norm::L0,
1.0 => Norm::L1,
2.0 => Norm::L2,
f32::INFINITY => Norm::LInf,
f32::NEG_INFINITY => Norm::LNegInf,
_ => Norm::Lp(value as f64),
}
}
}

impl From<f64> for Norm {
fn from(value: f64) -> Self {
match value {
0.0 => Norm::L0,
1.0 => Norm::L1,
2.0 => Norm::L2,
f64::INFINITY => Norm::LInf,
f64::NEG_INFINITY => Norm::LNegInf,
_ => Norm::Lp(value),
}
}
}

/// Computes the vector norm of a tensor along a specified dimension.
///
/// Generic dispatch wrapper over specialized / optimized norms.
///
/// See:
/// - https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html
/// - https://numpy.org/doc/stable/reference/generated/numpy.linalg.vector_norm.html
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `norm` - The selected norm.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The vector norm of the input tensor.
pub fn vector_norm<B: Backend, const D: usize>(
x: Tensor<B, D>,
norm: impl Into<Norm>,
dim: usize,
) -> Tensor<B, D> {
let norm = norm.into();
match norm {
Norm::L0 => l0_norm(x, dim),
Norm::L1 => l1_norm(x, dim),
Norm::L2 => l2_norm(x, dim),
Norm::LInf => max_abs_norm(x, dim),
Norm::LNegInf => min_abs_norm(x, dim),
Norm::Lp(p) => lp_norm(x, p, dim),
}
}

/// Normalize a tensor versus its `vector_norm`.
///
/// Equivalent to ``x.clone() / vector_norm(x, norm, dim).clamp_min(eps)``.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `norm` - The selected norm.
/// * `dim` - The dimension to compute the norm over.
/// * `eps` - The epsilon for the norm.
///
/// # Returns
///
/// The normalized tensor.
pub fn vector_normalize<B: Backend, const D: usize, E: ElementConversion>(
x: Tensor<B, D>,
norm: impl Into<Norm>,
dim: usize,
eps: E,
) -> Tensor<B, D> {
x.clone() / vector_norm(x, norm, dim).clamp_min(eps)
}

/// Computes the L0 norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L0 norm of the input tensor.
pub fn l0_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
where
K: BasicOps<B> + Numeric<B>,
{
x.zeros_like()
.mask_fill(x.not_equal_elem(0), 1)
.sum_dim(dim)
}

/// Computes the L1 norm of a tensor along a specified dimension.
///
/// This is a convenience function that wraps `vector_norm` with `p = 1.0`.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L1 norm of the input tensor.
pub fn l1_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
where
K: BasicOps<B> + Numeric<B>,
{
x.abs().sum_dim(dim)
}

/// Computes the L2 norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L2 norm of the input tensor.
pub fn l2_norm<B: Backend, const D: usize>(x: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
x.abs().powi_scalar(2).sum_dim(dim).sqrt()
}

/// Computes the general ``L(p)`` norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `p` - The exponent of the Lp norm.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The ``L(p)`` norm of the input tensor.
pub fn lp_norm<B: Backend, const D: usize>(x: Tensor<B, D>, p: f64, dim: usize) -> Tensor<B, D> {
x.abs().powf_scalar(p).sum_dim(dim).powf_scalar(1. / p)
}

/// Computes the L:INFINITY norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L:INFINITY norm of the input tensor.
pub fn max_abs_norm<B: Backend, const D: usize, K>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much better naming than l_inf_norm 😅

Is it called the max abs norm or just max norm typically? My linear algebra terminology is a little dusty.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really, it really is called the l-infinity-norm:
https://mathworld.wolfram.com/L-Infinity-Norm.html

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes of course but there are often other names used interchangeably (e.g., euclidean norm or l2 norm).

Max norm seems to also be used, but not sure if it is that common. Anyway, either max_abs_norm, max_norm or infinity_norm (without the L prefix) should be suitable.

x: Tensor<B, D, K>,
dim: usize,
) -> Tensor<B, D, K>
where
K: BasicOps<B> + Numeric<B>,
{
x.max_abs_dim(dim)
}

/// Computes the L:NEG_INFINITY norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L:NEG_INFINITY norm of the input tensor.
pub fn min_abs_norm<B: Backend, const D: usize, K>(
x: Tensor<B, D, K>,
dim: usize,
) -> Tensor<B, D, K>
where
K: BasicOps<B> + Numeric<B>,
{
x.abs().min_dim(dim)
}
3 changes: 3 additions & 0 deletions crates/burn-tensor/src/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ pub mod container;
/// The grid module.
pub mod grid;

/// The linalg module.
pub mod linalg;

/// The loss module.
pub mod loss;

Expand Down
1 change: 1 addition & 0 deletions crates/burn-tensor/src/tests/linalg/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(crate) mod vector_norm;
Loading
Loading