Skip to content

Commit 1eb9eee

Browse files
committed
Remove inlining for autodiff handling
1 parent c621097 commit 1eb9eee

File tree

10 files changed

+37
-43
lines changed

10 files changed

+37
-43
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,6 @@ mod llvm_enzyme {
193193
/// which becomes expanded to:
194194
/// ```
195195
/// #[rustc_autodiff]
196-
/// #[inline(never)]
197196
/// fn sin(x: &Box<f32>) -> f32 {
198197
/// f32::sin(**x)
199198
/// }
@@ -369,7 +368,7 @@ mod llvm_enzyme {
369368
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
370369
let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
371370

372-
// We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
371+
// We're avoid duplicating the attribute `#[rustc_autodiff]`.
373372
fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
374373
match (attr, item) {
375374
(ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
@@ -382,23 +381,25 @@ mod llvm_enzyme {
382381
}
383382
}
384383

384+
let mut has_inline_never = false;
385+
385386
// Don't add it multiple times:
386387
let orig_annotatable: Annotatable = match item {
387388
Annotatable::Item(ref mut iitem) => {
388389
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
389390
iitem.attrs.push(attr);
390391
}
391-
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
392-
iitem.attrs.push(inline_never.clone());
392+
if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
393+
has_inline_never = true;
393394
}
394395
Annotatable::Item(iitem.clone())
395396
}
396397
Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
397398
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
398399
assoc_item.attrs.push(attr);
399400
}
400-
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
401-
assoc_item.attrs.push(inline_never.clone());
401+
if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
402+
has_inline_never = true;
402403
}
403404
Annotatable::AssocItem(assoc_item.clone(), i)
404405
}
@@ -408,9 +409,8 @@ mod llvm_enzyme {
408409
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
409410
iitem.attrs.push(attr);
410411
}
411-
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
412-
{
413-
iitem.attrs.push(inline_never.clone());
412+
if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
413+
has_inline_never = true;
414414
}
415415
}
416416
_ => unreachable!("stmt kind checked previously"),
@@ -431,11 +431,19 @@ mod llvm_enzyme {
431431

432432
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
433433
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
434+
435+
// If the source function has the `#[inline(never)]` attribute, we'll also add it to the diff function
436+
let mut d_attrs = thin_vec![d_attr];
437+
438+
if has_inline_never {
439+
d_attrs.push(inline_never);
440+
}
441+
434442
let d_annotatable = match &item {
435443
Annotatable::AssocItem(_, _) => {
436444
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);
437445
let d_fn = P(ast::AssocItem {
438-
attrs: thin_vec![d_attr],
446+
attrs: d_attrs,
439447
id: ast::DUMMY_NODE_ID,
440448
span,
441449
vis,
@@ -445,13 +453,13 @@ mod llvm_enzyme {
445453
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
446454
}
447455
Annotatable::Item(_) => {
448-
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
456+
let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
449457
d_fn.vis = vis;
450458

451459
Annotatable::Item(d_fn)
452460
}
453461
Annotatable::Stmt(_) => {
454-
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
462+
let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
455463
d_fn.vis = vis;
456464

457465
Annotatable::Stmt(P(ast::Stmt {

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@ use tracing::debug;
1010
use crate::builder::{Builder, PlaceRef, UNNAMED};
1111
use crate::context::SimpleCx;
1212
use crate::declare::declare_simple_fn;
13-
use crate::llvm::AttributePlace::Function;
13+
use crate::llvm;
1414
use crate::llvm::{Metadata, True, Type};
1515
use crate::value::Value;
16-
use crate::{attributes, llvm};
1716

1817
pub(crate) fn adjust_activity_to_abi<'tcx>(
1918
tcx: TyCtxt<'tcx>,
@@ -308,11 +307,6 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
308307
enzyme_ty,
309308
);
310309

311-
// Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
312-
// do it's work.
313-
let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
314-
attributes::apply_to_llfn(ad_fn, Function, &[attr]);
315-
316310
let num_args = llvm::LLVMCountParams(&fn_to_diff);
317311
let mut args = Vec::with_capacity(num_args as usize + 1);
318312
args.push(fn_to_diff);

tests/codegen-llvm/autodiff/batched.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use std::autodiff::autodiff_forward;
1717
#[autodiff_forward(d_square2, 4, Dual, DualOnly)]
1818
#[autodiff_forward(d_square1, 4, Dual, Dual)]
1919
#[no_mangle]
20+
#[inline(never)]
2021
fn square(x: &f32) -> f32 {
2122
x * x
2223
}

tests/codegen-llvm/autodiff/generic.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
use std::autodiff::autodiff_reverse;
77

88
#[autodiff_reverse(d_square, Duplicated, Active)]
9+
#[inline(never)]
910
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
1011
*x * *x
1112
}

tests/codegen-llvm/autodiff/identical_fnc.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
use std::autodiff::autodiff_reverse;
1515

1616
#[autodiff_reverse(d_square, Duplicated, Active)]
17+
#[inline(never)]
1718
fn square(x: &f64) -> f64 {
1819
x * x
1920
}
2021

2122
#[autodiff_reverse(d_square2, Duplicated, Active)]
23+
#[inline(never)]
2224
fn square2(x: &f64) -> f64 {
2325
x * x
2426
}
@@ -29,8 +31,10 @@ fn square2(x: &f64) -> f64 {
2931
// CHECK-NEXT:start:
3032
// CHECK-NOT:br
3133
// CHECK-NOT:ret
32-
// CHECK:call fastcc void @diffe_ZN13identical_fnc6square17h67c6eccd3051fb4cE(double %x.val, ptr %dx1)
33-
// CHECK-NEXT:call fastcc void @diffe_ZN13identical_fnc6square17h67c6eccd3051fb4cE(double %x.val, ptr %dx2)
34+
// CHECK:; call identical_fnc::d_square
35+
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx1)
36+
// CHECK:; call identical_fnc::d_square
37+
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx2)
3438

3539
fn main() {
3640
let x = std::hint::black_box(3.0);

tests/codegen-llvm/autodiff/scalar.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::autodiff::autodiff_reverse;
77

88
#[autodiff_reverse(d_square, Duplicated, Active)]
99
#[no_mangle]
10+
#[inline(never)]
1011
fn square(x: &f64) -> f64 {
1112
x * x
1213
}

tests/codegen-llvm/autodiff/sret.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use std::autodiff::autodiff_reverse;
1313

1414
#[no_mangle]
1515
#[autodiff_reverse(df, Active, Active, Active)]
16+
#[inline(never)]
1617
fn primal(x: f32, y: f32) -> f64 {
1718
(x * x * y) as f64
1819
}

tests/pretty/autodiff/autodiff_forward.pp

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
//@ needs-enzyme
44

55
#![feature(autodiff)]
6-
#[prelude_import]
7-
use ::std::prelude::rust_2015::*;
86
#[macro_use]
97
extern crate std;
8+
#[prelude_import]
9+
use ::std::prelude::rust_2015::*;
1010
//@ pretty-mode:expanded
1111
//@ pretty-compare-only
1212
//@ pp-exact:autodiff_forward.pp
@@ -16,7 +16,6 @@
1616
use std::autodiff::{autodiff_forward, autodiff_reverse};
1717

1818
#[rustc_autodiff]
19-
#[inline(never)]
2019
pub fn f1(x: &[f64], y: f64) -> f64 {
2120

2221

@@ -40,7 +39,6 @@
4039
::core::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, bx_0, y))
4140
}
4241
#[rustc_autodiff]
43-
#[inline(never)]
4442
pub fn f2(x: &[f64], y: f64) -> f64 {
4543
::core::panicking::panic("not implemented")
4644
}
@@ -49,7 +47,6 @@
4947
::core::intrinsics::enzyme_autodiff(f2::<>, df2::<>, (x, bx_0, y))
5048
}
5149
#[rustc_autodiff]
52-
#[inline(never)]
5350
pub fn f3(x: &[f64], y: f64) -> f64 {
5451
::core::panicking::panic("not implemented")
5552
}
@@ -58,14 +55,12 @@
5855
::core::intrinsics::enzyme_autodiff(f3::<>, df3::<>, (x, bx_0, y))
5956
}
6057
#[rustc_autodiff]
61-
#[inline(never)]
6258
pub fn f4() {}
6359
#[rustc_autodiff(Forward, 1, None)]
6460
pub fn df4() -> () {
6561
::core::intrinsics::enzyme_autodiff(f4::<>, df4::<>, ())
6662
}
6763
#[rustc_autodiff]
68-
#[inline(never)]
6964
pub fn f5(x: &[f64], y: f64) -> f64 {
7065
::core::panicking::panic("not implemented")
7166
}
@@ -84,7 +79,6 @@
8479
}
8580
struct DoesNotImplDefault;
8681
#[rustc_autodiff]
87-
#[inline(never)]
8882
pub fn f6() -> DoesNotImplDefault {
8983
::core::panicking::panic("not implemented")
9084
}
@@ -93,15 +87,13 @@
9387
::core::intrinsics::enzyme_autodiff(f6::<>, df6::<>, ())
9488
}
9589
#[rustc_autodiff]
96-
#[inline(never)]
9790
pub fn f7(x: f32) -> () {}
9891
#[rustc_autodiff(Forward, 1, Const, None)]
9992
pub fn df7(x: f32) -> () {
10093
::core::intrinsics::enzyme_autodiff(f7::<>, df7::<>, (x,))
10194
}
10295
#[no_mangle]
10396
#[rustc_autodiff]
104-
#[inline(never)]
10597
fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
10698
#[rustc_autodiff(Forward, 4, Dual, Dual)]
10799
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
@@ -121,7 +113,6 @@
121113
}
122114
pub fn f9() {
123115
#[rustc_autodiff]
124-
#[inline(never)]
125116
fn inner(x: f32) -> f32 { x * x }
126117
#[rustc_autodiff(Forward, 1, Dual, Dual)]
127118
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
@@ -135,7 +126,6 @@
135126
}
136127
}
137128
#[rustc_autodiff]
138-
#[inline(never)]
139129
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
140130
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
141131
pub fn d_square<T: std::ops::Mul<Output = T> +

tests/pretty/autodiff/autodiff_reverse.pp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
//@ needs-enzyme
44

55
#![feature(autodiff)]
6-
#[prelude_import]
7-
use ::std::prelude::rust_2015::*;
86
#[macro_use]
97
extern crate std;
8+
#[prelude_import]
9+
use ::std::prelude::rust_2015::*;
1010
//@ pretty-mode:expanded
1111
//@ pretty-compare-only
1212
//@ pp-exact:autodiff_reverse.pp
@@ -16,7 +16,6 @@
1616
use std::autodiff::autodiff_reverse;
1717

1818
#[rustc_autodiff]
19-
#[inline(never)]
2019
pub fn f1(x: &[f64], y: f64) -> f64 {
2120

2221
// Not the most interesting derivative, but who are we to judge
@@ -33,12 +32,10 @@
3332
::core::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, dx_0, y, dret))
3433
}
3534
#[rustc_autodiff]
36-
#[inline(never)]
3735
pub fn f2() {}
3836
#[rustc_autodiff(Reverse, 1, None)]
3937
pub fn df2() { ::core::intrinsics::enzyme_autodiff(f2::<>, df2::<>, ()) }
4038
#[rustc_autodiff]
41-
#[inline(never)]
4239
pub fn f3(x: &[f64], y: f64) -> f64 {
4340
::core::panicking::panic("not implemented")
4441
}
@@ -49,14 +46,12 @@
4946
enum Foo { Reverse, }
5047
use Foo::Reverse;
5148
#[rustc_autodiff]
52-
#[inline(never)]
5349
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
5450
#[rustc_autodiff(Reverse, 1, Const, None)]
5551
pub fn df4(x: f32) {
5652
::core::intrinsics::enzyme_autodiff(f4::<>, df4::<>, (x,))
5753
}
5854
#[rustc_autodiff]
59-
#[inline(never)]
6055
pub fn f5(x: *const f32, y: &f32) {
6156
::core::panicking::panic("not implemented")
6257
}

tests/pretty/autodiff/inherent_impl.pp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
//@ needs-enzyme
44

55
#![feature(autodiff)]
6-
#[prelude_import]
7-
use ::std::prelude::rust_2015::*;
86
#[macro_use]
97
extern crate std;
8+
#[prelude_import]
9+
use ::std::prelude::rust_2015::*;
1010
//@ pretty-mode:expanded
1111
//@ pretty-compare-only
1212
//@ pp-exact:inherent_impl.pp
@@ -26,7 +26,6 @@
2626

2727
impl MyTrait for Foo {
2828
#[rustc_autodiff]
29-
#[inline(never)]
3029
fn f(&self, x: f64) -> f64 {
3130
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
3231
}

0 commit comments

Comments
 (0)