
Commit 7e547c4

Replace some powf->powi (#3152)
* Replace some powf->powi
* Fix
1 parent 3b7f6d3 commit 7e547c4

File tree

17 files changed: +35 -40 lines changed
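The pattern throughout is mechanical: calls like powf_scalar(2.0) or float_powf_scalar(t, -1.0) become powi_scalar(2), float_powi_scalar(t, 2.elem()), or the dedicated reciprocal recip/float_recip, likely so backends can dispatch to cheaper integer-power and reciprocal kernels. A minimal sketch of the user-facing pattern, assuming burn's public Tensor API; the function name is hypothetical:

use burn::tensor::{Tensor, backend::Backend};

// Illustrative only: integer exponents move from powf_scalar to powi_scalar,
// and x^(-1) becomes the dedicated reciprocal.
fn inverse_square<B: Backend, const D: usize>(t: Tensor<B, D>) -> Tensor<B, D> {
    // Before this commit's style: t.powf_scalar(2.0).powf_scalar(-1.0)
    t.powi_scalar(2).recip()
}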

crates/burn-autodiff/src/ops/tensor.rs

Lines changed: 8 additions & 7 deletions
@@ -375,15 +375,16 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
             grads,
             |grad| {
                 let rhs = rhs_4lhs.unwrap();
-                let value = B::float_powf_scalar(rhs, -1.0);
+                let value = B::float_recip(rhs);
                 let grad = B::float_mul(grad, value);
 
                 broadcast.backward_lhs::<B>(grad)
             },
             |grad| {
                 let rhs = rhs_4rhs.unwrap();
                 let lhs = lhs.unwrap();
-                let value = B::float_div(B::float_neg(lhs), B::float_powf_scalar(rhs, 2.0));
+                let value =
+                    B::float_div(B::float_neg(lhs), B::float_powi_scalar(rhs, 2.elem()));
                 let grad = B::float_mul(grad, value);
 
                 broadcast.backward_rhs::<B>(grad)
@@ -644,7 +645,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         ) {
             let tensor = checkpointer.retrieve_node_output(ops.state);
             unary::<B, _>(ops.parents, ops.node, grads, |grad| {
-                let tmp = B::float_powf_scalar(tensor, -2.0);
+                let tmp = B::float_powi_scalar(tensor, (-2).elem());
                 let value = B::float_neg(tmp);
 
                 B::float_mul(grad, value)
@@ -1631,7 +1632,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         ) {
             let input = checkpointer.retrieve_node_output(ops.state);
             unary::<B, _>(ops.parents, ops.node, grads, |grad| {
-                let value = B::float_powf_scalar(input, -1.0);
+                let value = B::float_recip(input);
                 B::float_mul(grad, value)
             });
         }
@@ -1670,7 +1671,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
             let input = checkpointer.retrieve_node_output(ops.state);
             unary::<B, _>(ops.parents, ops.node, grads, |grad| {
                 let value = B::float_add_scalar(input, 1.elem());
-                let value = B::float_powf_scalar(value, -1.0);
+                let value = B::float_recip(value);
 
                 B::float_mul(grad, value)
             });
@@ -1920,7 +1921,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
             let state = B::float_tanh(input);
             unary::<B, _>(ops.parents, ops.node, grads, |grad| {
                 let value = B::float_add_scalar(
-                    B::float_neg(B::float_powf_scalar(state, 2.0)),
+                    B::float_neg(B::float_powi_scalar(state, 2.elem())),
                     1.elem(),
                 );
                 B::float_mul(grad, value)
@@ -2068,7 +2069,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         ) {
             unary::<B, _>(ops.parents, ops.node, grads, |grad| {
                 let ops = checkpointer.retrieve_node_output(ops.state);
-                let exponent = B::float_neg(B::float_powf_scalar(ops, 2.0));
+                let exponent = B::float_neg(B::float_powi_scalar(ops, 2.elem()));
                 let numerator = B::float_mul_scalar(B::float_exp(exponent), 2.0.elem());
                 let denominator = core::f64::consts::PI.sqrt().elem();
                 let value = B::float_div_scalar(numerator, denominator);
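For orientation, the backward closures touched here implement the usual derivative identities (a sketch of the underlying math, not part of the diff):

$$\partial_a\!\left(\frac{a}{b}\right)=\frac{1}{b},\qquad \partial_b\!\left(\frac{a}{b}\right)=-\frac{a}{b^{2}},\qquad \frac{d}{dx}x^{-1}=-x^{-2},\qquad \frac{d}{dx}\ln x=\frac{1}{x},$$

$$\frac{d}{dx}\ln(1+x)=\frac{1}{1+x},\qquad \frac{d}{dx}\tanh x=1-\tanh^{2}x,\qquad \frac{d}{dx}\operatorname{erf}(x)=\frac{2}{\sqrt{\pi}}e^{-x^{2}}.$$

The gradients themselves are unchanged; only the exponentiation routine differs.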

crates/burn-core/src/grad_clipping/base.rs

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ impl GradientClipping {
     }
 
     fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {
-        let squared = tensor.powf_scalar(2.0);
+        let squared = tensor.powi_scalar(2);
         let sum = squared.sum();
         sum.sqrt()
     }
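The helper computes the Euclidean norm used for clipping; squaring with an integer power leaves the value unchanged:

$$\lVert x\rVert_{2}=\Bigl(\sum_i x_i^{2}\Bigr)^{1/2}$$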

crates/burn-core/src/nn/loss/huber.rs

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ impl HuberLoss {
         // Moreover |r| = sign(r) * r
         let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias);
 
-        let inside = residuals.powf_scalar(2.).mul_scalar(0.5);
+        let inside = residuals.powi_scalar(2).mul_scalar(0.5);
         inside.mask_where(is_large, outside)
     }
 }
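For reference, inside is the quadratic branch of the standard Huber loss with threshold δ; the usual linear branch is δ|r| − δ²/2, which lin_bias presumably encodes (not shown in this excerpt):

$$L_\delta(r)=\begin{cases}\frac{1}{2}r^{2} & |r|\le\delta\\[2pt] \delta\,|r|-\frac{\delta^{2}}{2} & |r|>\delta\end{cases}$$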

crates/burn-core/src/nn/loss/mse.rs

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ impl MseLoss {
         logits: Tensor<B, D>,
         targets: Tensor<B, D>,
     ) -> Tensor<B, D> {
-        logits.sub(targets).powf_scalar(2.0)
+        logits.sub(targets).powi_scalar(2)
     }
 }
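This method returns the unreduced elementwise squared error; any mean reduction is applied elsewhere:

$$\ell_i=(\hat{y}_i-y_i)^{2}$$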

crates/burn-core/src/nn/norm/batch.rs

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ impl<const D: usize, B: Backend> BatchNorm<B, D> {
         let var = input
             .clone()
             .sub(mean.clone())
-            .powf_scalar(2.0)
+            .powi_scalar(2)
             .swap_dims(0, 1)
             .reshape([channels, flatten_size])
             .mean_dim(1)
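The touched expression is the biased per-channel variance over all m flattened batch and spatial positions:

$$\sigma_c^{2}=\frac{1}{m}\sum_{j=1}^{m}\bigl(x_{c,j}-\mu_c\bigr)^{2}$$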

crates/burn-core/src/nn/norm/group.rs

Lines changed: 1 addition & 1 deletion
@@ -170,7 +170,7 @@ pub(crate) fn group_norm<B: Backend, const D: usize>(
     let mean = input.clone().sum_dim(2) / hidden_size as f64;
     let input = input.sub(mean);
 
-    let var = input.clone().powf_scalar(2.).sum_dim(2) / hidden_size as f64;
+    let var = input.clone().powi_scalar(2).sum_dim(2) / hidden_size as f64;
     let input_normalized = input.div(var.add_scalar(epsilon).sqrt());
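Together with the preceding mean, this is the standard group-norm normalization, with statistics taken per group over its H = hidden_size entries:

$$\hat{x}=\frac{x-\mu_g}{\sqrt{\sigma_g^{2}+\epsilon}},\qquad \sigma_g^{2}=\frac{1}{H}\sum_{h=1}^{H}\bigl(x_{g,h}-\mu_g\bigr)^{2}$$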

crates/burn-core/src/nn/norm/rms.rs

Lines changed: 1 addition & 2 deletions
@@ -71,8 +71,7 @@ impl<B: Backend> RmsNorm<B> {
     pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
         // Calculate the root-mean-square norm of the input tensor along the last dimension
         let dtype = x.dtype();
-        let rms =
-            (x.clone().cast(DType::F32).powf_scalar(2.0).mean_dim(D - 1) + self.epsilon).sqrt();
+        let rms = (x.clone().cast(DType::F32).powi_scalar(2).mean_dim(D - 1) + self.epsilon).sqrt();
         (x / rms.cast(dtype)) * self.gamma.val().unsqueeze()
     }
 }
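In formula form, the forward pass computes, per position along the last dimension of size d:

$$y=\frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d}x_i^{2}+\epsilon}}\odot\gamma$$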

crates/burn-core/src/nn/rope_encoding.rs

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ impl RotaryEncodingConfig {
         // Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))`
         // This is done since burn doesn't support exponentiation of scalar to tensor
         let theta_i = exponent.mul_scalar(self.theta.ln()).exp();
-        let theta_i = theta_i.powf_scalar(-1.0);
+        let theta_i = theta_i.recip();
 
         let theta_i = scaling(theta_i);
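The exp/ln detour followed by a reciprocal yields the usual RoPE inverse frequencies; recip simply states the final step directly:

$$\theta_i=\Bigl(e^{(2i/d_{\text{model}})\ln\theta}\Bigr)^{-1}=\theta^{-2i/d_{\text{model}}}$$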

crates/burn-core/src/optim/adagrad.rs

Lines changed: 2 additions & 2 deletions
@@ -117,11 +117,11 @@ impl LrDecay {
         lr_decay_state: Option<LrDecayState<B, D>>,
     ) -> (Tensor<B, D>, LrDecayState<B, D>) {
         let state = if let Some(mut state) = lr_decay_state {
-            state.sum = state.sum.add(grad.clone().powf_scalar(2.));
+            state.sum = state.sum.add(grad.clone().powi_scalar(2));
             state.time += 1;
             state
         } else {
-            LrDecayState::new(1, grad.clone().powf_scalar(2.))
+            LrDecayState::new(1, grad.clone().powi_scalar(2))
         };
 
         let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);
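The accumulator update and decayed learning rate visible in this hunk are, writing λ for lr_decay:

$$G_t=G_{t-1}+g_t^{2},\qquad \eta_t=\frac{\eta}{1+(t-1)\lambda}$$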

crates/burn-core/src/optim/adam.rs

Lines changed: 4 additions & 4 deletions
@@ -10,7 +10,7 @@ use super::{
 use crate::config::Config;
 use crate::optim::adaptor::OptimizerAdaptor;
 use crate::tensor::{Tensor, backend::AutodiffBackend};
-use burn_tensor::{ElementConversion, backend::Backend, ops::Device};
+use burn_tensor::{backend::Backend, ops::Device};
 
 #[cfg(not(feature = "std"))]
 use num_traits::Float;
@@ -140,7 +140,7 @@ impl AdaptiveMomentum {
         state.moment_2 = state
             .moment_2
             .mul_scalar(self.beta_2)
-            .add(grad.powf_scalar(2.0).mul_scalar(factor));
+            .add(grad.powi_scalar(2).mul_scalar(factor));
 
         state.time += 1;
 
@@ -150,12 +150,12 @@ impl AdaptiveMomentum {
             let moment_1 = grad.clone().mul_scalar(factor);
 
             let factor = 1.0 - self.beta_2;
-            let moment_2 = grad.powf_scalar(2.0).mul_scalar(factor);
+            let moment_2 = grad.powi_scalar(2).mul_scalar(factor);
 
             AdaptiveMomentumState::new(1, moment_1, moment_2)
         };
 
-        let time = (state.time as i32).elem();
+        let time = state.time as i32;
         let moment_1_corrected = state
             .moment_1
             .clone()
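The second-moment update shown here is the standard Adam rule; for orientation, the bias correction that the surrounding code applies with time is:

$$m_{2,t}=\beta_2\,m_{2,t-1}+(1-\beta_2)\,g_t^{2},\qquad \hat{m}_{2,t}=\frac{m_{2,t}}{1-\beta_2^{\,t}}$$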
