Add burn::linalg::{vector_norm,l2_norm} #3131
Merged

Changes from all commits (21 commits)
Commits (all by crutcher):

07b3ef8  Add burn::linalg::{vector_norm,l2_norm}
8092f63  fmt
fa64bf3  doc
593c8a0  freshen up inf
799620a  l1 norm
f51b5cd  aliases
753667b  docs
139788b  close for l2
aa11290  approx_eq
596e1a7  long
159e975  improve tolerance
e0ba924  review
7db3d57  vector_normalize
f09a8f1  Refactor for review
d4d2cfe  typing; make .sqrt() work for l2_norm
9474a6a  eps type conversion
b94ab10  test integer tensors, workaround new bug: https://github.com/tracel-a…
1271c18  generalize
86cf09f  remove empty where clause
20318f4  Remove workaround for #3119
8622f98  Update vector_norm.rs
New file (module declaration, 2 lines):

mod vector_norm;
pub use vector_norm::*;
vector_norm.rs (new file, 225 lines):
use crate::backend::Backend;
use crate::tensor::{BasicOps, Tensor};
use crate::{ElementConversion, Numeric};

/// Specifies the type of norm to compute.
#[derive(Debug, Clone, Copy)]
pub enum Norm {
    /// L0 norm (count of non-zero elements)
    L0,

    /// L1 norm (sum of absolute values)
    L1,

    /// L2 norm (Euclidean norm)
    L2,

    /// L:INFINITY norm (maximum absolute value)
    LInf,

    /// L:NEG_INFINITY norm (minimum absolute value)
    LNegInf,

    /// Lp norm (generalized norm)
    Lp(f64),
}

impl From<i32> for Norm {
    fn from(value: i32) -> Self {
        match value {
            0 => Norm::L0,
            1 => Norm::L1,
            2 => Norm::L2,
            _ => Norm::Lp(value as f64),
        }
    }
}

impl From<f32> for Norm {
    fn from(value: f32) -> Self {
        match value {
            0.0 => Norm::L0,
            1.0 => Norm::L1,
            2.0 => Norm::L2,
            f32::INFINITY => Norm::LInf,
            f32::NEG_INFINITY => Norm::LNegInf,
            _ => Norm::Lp(value as f64),
        }
    }
}

impl From<f64> for Norm {
    fn from(value: f64) -> Self {
        match value {
            0.0 => Norm::L0,
            1.0 => Norm::L1,
            2.0 => Norm::L2,
            f64::INFINITY => Norm::LInf,
            f64::NEG_INFINITY => Norm::LNegInf,
            _ => Norm::Lp(value),
        }
    }
}

/// Computes the vector norm of a tensor along a specified dimension.
///
/// Generic dispatch wrapper over specialized / optimized norms.
///
/// See:
/// - https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html
/// - https://numpy.org/doc/stable/reference/generated/numpy.linalg.vector_norm.html
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `norm` - The selected norm.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The vector norm of the input tensor.
pub fn vector_norm<B: Backend, const D: usize>(
    x: Tensor<B, D>,
    norm: impl Into<Norm>,
    dim: usize,
) -> Tensor<B, D> {
    let norm = norm.into();
    match norm {
        Norm::L0 => l0_norm(x, dim),
        Norm::L1 => l1_norm(x, dim),
        Norm::L2 => l2_norm(x, dim),
        Norm::LInf => max_abs_norm(x, dim),
        Norm::LNegInf => min_abs_norm(x, dim),
        Norm::Lp(p) => lp_norm(x, p, dim),
    }
}

/// Normalize a tensor versus its `vector_norm`.
///
/// Equivalent to ``x.clone() / vector_norm(x, norm, dim).clamp_min(eps)``.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `norm` - The selected norm.
/// * `dim` - The dimension to compute the norm over.
/// * `eps` - The epsilon for the norm.
///
/// # Returns
///
/// The normalized tensor.
pub fn vector_normalize<B: Backend, const D: usize, E: ElementConversion>(
    x: Tensor<B, D>,
    norm: impl Into<Norm>,
    dim: usize,
    eps: E,
) -> Tensor<B, D> {
    x.clone() / vector_norm(x, norm, dim).clamp_min(eps)
}

/// Computes the L0 norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L0 norm of the input tensor.
pub fn l0_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
where
    K: BasicOps<B> + Numeric<B>,
{
    x.zeros_like()
        .mask_fill(x.not_equal_elem(0), 1)
        .sum_dim(dim)
}

/// Computes the L1 norm of a tensor along a specified dimension.
///
/// Equivalent to `vector_norm` with `Norm::L1`.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L1 norm of the input tensor.
pub fn l1_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>
where
    K: BasicOps<B> + Numeric<B>,
{
    x.abs().sum_dim(dim)
}

/// Computes the L2 norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L2 norm of the input tensor.
pub fn l2_norm<B: Backend, const D: usize>(x: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    x.abs().powi_scalar(2).sum_dim(dim).sqrt()
}

/// Computes the general ``L(p)`` norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `p` - The exponent of the Lp norm.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The ``L(p)`` norm of the input tensor.
pub fn lp_norm<B: Backend, const D: usize>(x: Tensor<B, D>, p: f64, dim: usize) -> Tensor<B, D> {
    x.abs().powf_scalar(p).sum_dim(dim).powf_scalar(1. / p)
}

/// Computes the L:INFINITY norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L:INFINITY norm of the input tensor.
pub fn max_abs_norm<B: Backend, const D: usize, K>(
    x: Tensor<B, D, K>,
    dim: usize,
) -> Tensor<B, D, K>
where
    K: BasicOps<B> + Numeric<B>,
{
    x.max_abs_dim(dim)
}

/// Computes the L:NEG_INFINITY norm of a tensor along a specified dimension.
///
/// # Arguments
///
/// * `x` - The input tensor.
/// * `dim` - The dimension to compute the norm over.
///
/// # Returns
///
/// The L:NEG_INFINITY norm of the input tensor.
pub fn min_abs_norm<B: Backend, const D: usize, K>(
    x: Tensor<B, D, K>,
    dim: usize,
) -> Tensor<B, D, K>
where
    K: BasicOps<B> + Numeric<B>,
{
    x.abs().min_dim(dim)
}
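For reference, the math each of these reductions computes can be sketched without burn at all. The following is a minimal, dependency-free illustration on plain `f64` slices; the function names mirror the PR's API but these are not the burn signatures (burn operates on `Tensor<B, D>` along a `dim`):

```rust
// Dependency-free sketch of the norms in this PR, on a flat f64 slice.
// Illustrative only; the real functions reduce a tensor along one dimension.

fn l1_norm(x: &[f64]) -> f64 {
    // sum of absolute values
    x.iter().map(|v| v.abs()).sum()
}

fn l2_norm(x: &[f64]) -> f64 {
    // sqrt of sum of squares (Euclidean norm)
    x.iter().map(|v| v * v).sum::<f64>().sqrt()
}

fn lp_norm(x: &[f64], p: f64) -> f64 {
    // (sum |x|^p)^(1/p)
    x.iter().map(|v| v.abs().powf(p)).sum::<f64>().powf(1.0 / p)
}

fn max_abs_norm(x: &[f64]) -> f64 {
    // L-infinity: largest absolute value
    x.iter().fold(0.0_f64, |m, v| m.max(v.abs()))
}

fn main() {
    let x = [3.0, -4.0];
    assert_eq!(l1_norm(&x), 7.0); // |3| + |-4|
    assert_eq!(l2_norm(&x), 5.0); // sqrt(9 + 16)
    assert_eq!(max_abs_norm(&x), 4.0);
    // Lp with p = 2 agrees with the dedicated L2 path.
    assert!((lp_norm(&x, 2.0) - l2_norm(&x)).abs() < 1e-12);
}
```

The dedicated `l1_norm`/`l2_norm`/`max_abs_norm` paths exist because they avoid the generic `powf`/`powf(1/p)` round trip of `lp_norm`.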
New file (module declaration, 1 line):

pub(crate) mod vector_norm;
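The `vector_normalize` helper above divides by the norm clamped to `eps`, which keeps zero vectors from producing NaNs. A small plain-Rust sketch of that idea (again on `f64` slices, not the burn API; `normalize` is a hypothetical stand-in):

```rust
// Sketch of the vector_normalize math: x / max(norm(x), eps).
// Clamping the norm to eps keeps a zero vector finite (all zeros, no NaN).

fn l2_norm(x: &[f64]) -> f64 {
    x.iter().map(|v| v * v).sum::<f64>().sqrt()
}

fn normalize(x: &[f64], eps: f64) -> Vec<f64> {
    let n = l2_norm(x).max(eps); // clamp_min(eps) in the burn version
    x.iter().map(|v| v / n).collect()
}

fn main() {
    let u = normalize(&[3.0, 4.0], 1e-12);
    assert!((u[0] - 0.6).abs() < 1e-12);
    assert!((u[1] - 0.8).abs() < 1e-12);
    // Zero vector: division by eps instead of 0, so the result is all zeros.
    assert!(normalize(&[0.0, 0.0], 1e-12).iter().all(|v| *v == 0.0));
}
```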
Much better naming than l_inf_norm 😅
Is it called the max abs norm or just max norm typically? My linear algebra terminology is a little dusty.
Not really, it really is called the l-infinity-norm: https://mathworld.wolfram.com/L-Infinity-Norm.html
Yes of course, but there are often other names used interchangeably (e.g., Euclidean norm or L2 norm).
Max norm seems to also be used, but not sure if it is that common. Anyway, either `max_abs_norm`, `max_norm`, or `infinity_norm` (without the `L` prefix) should be suitable.
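Whatever the name, `max_abs_norm` and the L-infinity norm are the same quantity: the Lp norm converges to the max absolute value as p grows, which is why the `Norm::LInf` case dispatches to a max-abs reduction rather than the generic `powf`/sum/root path. A quick plain-Rust check of that limit (not burn code):

```rust
// As p grows, (sum |x|^p)^(1/p) approaches max |x|:
// the largest term dominates the sum, and the p-th root scales it back.

fn lp_norm(x: &[f64], p: f64) -> f64 {
    x.iter().map(|v| v.abs().powf(p)).sum::<f64>().powf(1.0 / p)
}

fn max_abs_norm(x: &[f64]) -> f64 {
    x.iter().fold(0.0_f64, |m, v| m.max(v.abs()))
}

fn main() {
    let x = [1.0, -2.0, 3.0];
    let linf = max_abs_norm(&x);
    assert_eq!(linf, 3.0);
    // By p = 64 the Lp norm is already within 1e-2 of the max-abs value.
    assert!((lp_norm(&x, 64.0) - linf).abs() < 1e-2);
}
```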