Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2568,6 +2568,22 @@ where

pointee_info
}

fn is_adt(this: TyAndLayout<'tcx>) -> bool {
matches!(this.ty.kind(), ty::Adt(..))
}

fn is_never(this: TyAndLayout<'tcx>) -> bool {
this.ty.kind() == &ty::Never
}

fn is_tuple(this: TyAndLayout<'tcx>) -> bool {
matches!(this.ty.kind(), ty::Tuple(..))
}

fn is_unit(this: TyAndLayout<'tcx>) -> bool {
matches!(this.ty.kind(), ty::Tuple(list) if list.len() == 0)
}
}

impl<'tcx> ty::Instance<'tcx> {
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/ty/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ impl<T> List<T> {
static EMPTY_SLICE: InOrder<usize, MaxAlign> = InOrder(0, MaxAlign);
unsafe { &*(&EMPTY_SLICE as *const _ as *const List<T>) }
}

pub fn len(&self) -> usize {
self.len
}
}

impl<T: Copy> List<T> {
Expand Down
8 changes: 7 additions & 1 deletion compiler/rustc_target/src/abi/call/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,13 @@ impl<'a, Ty> FnAbi<'a, Ty> {
"sparc" => sparc::compute_abi_info(cx, self),
"sparc64" => sparc64::compute_abi_info(cx, self),
"nvptx" => nvptx::compute_abi_info(self),
"nvptx64" => nvptx64::compute_abi_info(self),
"nvptx64" => {
if cx.target_spec().adjust_abi(abi) == spec::abi::Abi::PtxKernel {
nvptx64::compute_ptx_kernel_abi_info(cx, self)
} else {
nvptx64::compute_abi_info(self)
}
}
"hexagon" => hexagon::compute_abi_info(self),
"riscv32" | "riscv64" => riscv::compute_abi_info(cx, self),
"wasm32" | "wasm64" => {
Expand Down
47 changes: 39 additions & 8 deletions compiler/rustc_target/src/abi/call/nvptx64.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,35 @@
// Reference: PTX Writer's Guide to Interoperability
// https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability

use crate::abi::call::{ArgAbi, FnAbi};
use crate::abi::call::{ArgAbi, FnAbi, PassMode, Reg, Size, Uniform};
use crate::abi::{HasDataLayout, TyAbiInterface};

fn classify_ret<Ty>(ret: &mut ArgAbi<'_, Ty>) {
if ret.layout.is_aggregate() && ret.layout.size.bits() > 64 {
ret.make_indirect();
} else {
ret.extend_integer_width_to(64);
}
}

fn classify_arg<Ty>(arg: &mut ArgAbi<'_, Ty>) {
if arg.layout.is_aggregate() && arg.layout.size.bits() > 64 {
arg.make_indirect();
} else {
arg.extend_integer_width_to(64);
}
}

fn classify_arg_kernel<'a, Ty, C>(_cx: &C, arg: &mut ArgAbi<'a, Ty>)
where
Ty: TyAbiInterface<'a, C> + Copy,
C: HasDataLayout,
{
if matches!(arg.mode, PassMode::Pair(..)) && (arg.layout.is_adt() || arg.layout.is_tuple()) {
let align_bytes = arg.layout.align.abi.bytes();

let unit = match align_bytes {
1 => Reg::i8(),
2 => Reg::i16(),
4 => Reg::i32(),
8 => Reg::i64(),
16 => Reg::i128(),
_ => unreachable!("Align is given as power of 2 no larger than 16 bytes"),
};
arg.cast_to(Uniform { unit, total: Size::from_bytes(2 * align_bytes) });
}
}

Expand All @@ -31,3 +45,20 @@ pub fn compute_abi_info<Ty>(fn_abi: &mut FnAbi<'_, Ty>) {
classify_arg(arg);
}
}

pub fn compute_ptx_kernel_abi_info<'a, Ty, C>(cx: &C, fn_abi: &mut FnAbi<'a, Ty>)
where
Ty: TyAbiInterface<'a, C> + Copy,
C: HasDataLayout,
{
if !fn_abi.ret.layout.is_unit() && !fn_abi.ret.layout.is_never() {
panic!("Kernels should not return anything other than () or !");
}

for arg in &mut fn_abi.args {
if arg.is_ignore() {
continue;
}
classify_arg_kernel(cx, arg);
}
}
32 changes: 32 additions & 0 deletions compiler/rustc_target/src/abi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,10 @@ pub trait TyAbiInterface<'a, C>: Sized {
cx: &C,
offset: Size,
) -> Option<PointeeInfo>;
fn is_adt(this: TyAndLayout<'a, Self>) -> bool;
fn is_never(this: TyAndLayout<'a, Self>) -> bool;
fn is_tuple(this: TyAndLayout<'a, Self>) -> bool;
fn is_unit(this: TyAndLayout<'a, Self>) -> bool;
}

impl<'a, Ty> TyAndLayout<'a, Ty> {
Expand Down Expand Up @@ -1291,6 +1295,34 @@ impl<'a, Ty> TyAndLayout<'a, Ty> {
_ => false,
}
}

pub fn is_adt<C>(self) -> bool
where
Ty: TyAbiInterface<'a, C>,
{
Ty::is_adt(self)
}

pub fn is_never<C>(self) -> bool
where
Ty: TyAbiInterface<'a, C>,
{
Ty::is_never(self)
}

pub fn is_tuple<C>(self) -> bool
where
Ty: TyAbiInterface<'a, C>,
{
Ty::is_tuple(self)
}

pub fn is_unit<C>(self) -> bool
where
Ty: TyAbiInterface<'a, C>,
{
Ty::is_unit(self)
}
}

impl<'a, Ty> TyAndLayout<'a, Ty> {
Expand Down
Loading