Skip to content

[RISCV] Refactor GPRF64 register class to make it usable for Zacas. #77408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,7 @@ unsigned RISCVAsmParser::checkTargetMatchPredicate(MCInst &Inst) {
const MCInstrDesc &MCID = MII.get(Inst.getOpcode());

for (unsigned I = 0; I < MCID.NumOperands; ++I) {
if (MCID.operands()[I].RegClass == RISCV::GPRPF64RegClassID) {
if (MCID.operands()[I].RegClass == RISCV::GPRPairRegClassID) {
const auto &Op = Inst.getOperand(I);
assert(Op.isReg());

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/RISCV/Disassembler/RISCVDisassembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ static DecodeStatus DecodeGPRCRegisterClass(MCInst &Inst, uint32_t RegNo,
return MCDisassembler::Success;
}

static DecodeStatus DecodeGPRPF64RegisterClass(MCInst &Inst, uint32_t RegNo,
static DecodeStatus DecodeGPRPairRegisterClass(MCInst &Inst, uint32_t RegNo,
uint64_t Address,
const MCDisassembler *Decoder) {
if (RegNo >= 32 || RegNo & 1)
Expand Down
12 changes: 8 additions & 4 deletions llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,10 @@ bool RISCVExpandPseudo::expandRV32ZdinxStore(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI) {
DebugLoc DL = MBBI->getDebugLoc();
const TargetRegisterInfo *TRI = STI->getRegisterInfo();
Register Lo = TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_32);
Register Hi = TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_32_hi);
Register Lo =
TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_gpr_even);
Register Hi =
TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_gpr_odd);
BuildMI(MBB, MBBI, DL, TII->get(RISCV::SW))
.addReg(Lo, getKillRegState(MBBI->getOperand(0).isKill()))
.addReg(MBBI->getOperand(1).getReg())
Expand Down Expand Up @@ -334,8 +336,10 @@ bool RISCVExpandPseudo::expandRV32ZdinxLoad(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI) {
DebugLoc DL = MBBI->getDebugLoc();
const TargetRegisterInfo *TRI = STI->getRegisterInfo();
Register Lo = TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_32);
Register Hi = TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_32_hi);
Register Lo =
TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_gpr_even);
Register Hi =
TRI->getSubReg(MBBI->getOperand(0).getReg(), RISCV::sub_gpr_odd);

// If the register of operand 1 is equal to the Lo register, then swap the
// order of loading the Lo and Hi statements.
Expand Down
10 changes: 5 additions & 5 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.is64Bit())
addRegisterClass(MVT::f64, &RISCV::GPRRegClass);
else
addRegisterClass(MVT::f64, &RISCV::GPRPF64RegClass);
addRegisterClass(MVT::f64, &RISCV::GPRPairRegClass);
}

static const MVT::SimpleValueType BoolVecVTs[] = {
Expand Down Expand Up @@ -16345,7 +16345,7 @@ static MachineBasicBlock *emitSplitF64Pseudo(MachineInstr &MI,
Register SrcReg = MI.getOperand(2).getReg();

const TargetRegisterClass *SrcRC = MI.getOpcode() == RISCV::SplitF64Pseudo_INX
? &RISCV::GPRPF64RegClass
? &RISCV::GPRPairRegClass
: &RISCV::FPR64RegClass;
int FI = MF.getInfo<RISCVMachineFunctionInfo>()->getMoveF64FrameIndex(MF);

Expand Down Expand Up @@ -16384,7 +16384,7 @@ static MachineBasicBlock *emitBuildPairF64Pseudo(MachineInstr &MI,
Register HiReg = MI.getOperand(2).getReg();

const TargetRegisterClass *DstRC =
MI.getOpcode() == RISCV::BuildPairF64Pseudo_INX ? &RISCV::GPRPF64RegClass
MI.getOpcode() == RISCV::BuildPairF64Pseudo_INX ? &RISCV::GPRPairRegClass
: &RISCV::FPR64RegClass;
int FI = MF.getInfo<RISCVMachineFunctionInfo>()->getMoveF64FrameIndex(MF);

Expand Down Expand Up @@ -18751,7 +18751,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
return std::make_pair(0U, &RISCV::GPRF32RegClass);
if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRPF64RegClass);
return std::make_pair(0U, &RISCV::GPRPairRegClass);
return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
case 'f':
if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16)
Expand Down Expand Up @@ -18933,7 +18933,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
// Subtarget into account.
if (Res.second == &RISCV::GPRF16RegClass ||
Res.second == &RISCV::GPRF32RegClass ||
Res.second == &RISCV::GPRPF64RegClass)
Res.second == &RISCV::GPRPairRegClass)
return std::make_pair(Res.first, &RISCV::GPRRegClass);

return Res;
Expand Down
17 changes: 9 additions & 8 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,15 +414,16 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
return;
}

if (RISCV::GPRPF64RegClass.contains(DstReg, SrcReg)) {
// Emit an ADDI for both parts of GPRPF64.
if (RISCV::GPRPairRegClass.contains(DstReg, SrcReg)) {
// Emit an ADDI for both parts of GPRPair.
BuildMI(MBB, MBBI, DL, get(RISCV::ADDI),
TRI->getSubReg(DstReg, RISCV::sub_32))
.addReg(TRI->getSubReg(SrcReg, RISCV::sub_32), getKillRegState(KillSrc))
TRI->getSubReg(DstReg, RISCV::sub_gpr_even))
.addReg(TRI->getSubReg(SrcReg, RISCV::sub_gpr_even),
getKillRegState(KillSrc))
.addImm(0);
BuildMI(MBB, MBBI, DL, get(RISCV::ADDI),
TRI->getSubReg(DstReg, RISCV::sub_32_hi))
.addReg(TRI->getSubReg(SrcReg, RISCV::sub_32_hi),
TRI->getSubReg(DstReg, RISCV::sub_gpr_odd))
.addReg(TRI->getSubReg(SrcReg, RISCV::sub_gpr_odd),
getKillRegState(KillSrc))
.addImm(0);
return;
Expand Down Expand Up @@ -607,7 +608,7 @@ void RISCVInstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
Opcode = TRI->getRegSizeInBits(RISCV::GPRRegClass) == 32 ?
RISCV::SW : RISCV::SD;
IsScalableVector = false;
} else if (RISCV::GPRPF64RegClass.hasSubClassEq(RC)) {
} else if (RISCV::GPRPairRegClass.hasSubClassEq(RC)) {
Opcode = RISCV::PseudoRV32ZdinxSD;
IsScalableVector = false;
} else if (RISCV::FPR16RegClass.hasSubClassEq(RC)) {
Expand Down Expand Up @@ -690,7 +691,7 @@ void RISCVInstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
Opcode = TRI->getRegSizeInBits(RISCV::GPRRegClass) == 32 ?
RISCV::LW : RISCV::LD;
IsScalableVector = false;
} else if (RISCV::GPRPF64RegClass.hasSubClassEq(RC)) {
} else if (RISCV::GPRPairRegClass.hasSubClassEq(RC)) {
Opcode = RISCV::PseudoRV32ZdinxLD;
IsScalableVector = false;
} else if (RISCV::FPR16RegClass.hasSubClassEq(RC)) {
Expand Down
16 changes: 8 additions & 8 deletions llvm/lib/Target/RISCV/RISCVInstrInfoD.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def AddrRegImmINX : ComplexPattern<iPTR, 2, "SelectAddrRegImmINX">;

// Zdinx

def GPRPF64AsFPR : AsmOperandClass {
let Name = "GPRPF64AsFPR";
def GPRPairAsFPR : AsmOperandClass {
let Name = "GPRPairAsFPR";
let ParserMethod = "parseGPRAsFPR";
let PredicateMethod = "isGPRAsFPR";
let RenderMethod = "addRegOperands";
Expand All @@ -52,8 +52,8 @@ def FPR64INX : RegisterOperand<GPR> {
let DecoderMethod = "DecodeGPRRegisterClass";
}

def FPR64IN32X : RegisterOperand<GPRPF64> {
let ParserMatchClass = GPRPF64AsFPR;
def FPR64IN32X : RegisterOperand<GPRPair> {
let ParserMatchClass = GPRPairAsFPR;
}

def DExt : ExtInfo<"", "", [HasStdExtD], f64, FPR64, FPR32, FPR64, ?>;
Expand Down Expand Up @@ -515,15 +515,15 @@ def PseudoFROUND_D_IN32X : PseudoFROUND<FPR64IN32X, f64>;

/// Loads
let isCall = 0, mayLoad = 1, mayStore = 0, Size = 8, isCodeGenOnly = 1 in
def PseudoRV32ZdinxLD : Pseudo<(outs GPRPF64:$dst), (ins GPR:$rs1, simm12:$imm12), []>;
def PseudoRV32ZdinxLD : Pseudo<(outs GPRPair:$dst), (ins GPR:$rs1, simm12:$imm12), []>;
def : Pat<(f64 (load (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12))),
(PseudoRV32ZdinxLD GPR:$rs1, simm12:$imm12)>;

/// Stores
let isCall = 0, mayLoad = 0, mayStore = 1, Size = 8, isCodeGenOnly = 1 in
def PseudoRV32ZdinxSD : Pseudo<(outs), (ins GPRPF64:$rs2, GPRNoX0:$rs1, simm12:$imm12), []>;
def : Pat<(store (f64 GPRPF64:$rs2), (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12)),
(PseudoRV32ZdinxSD GPRPF64:$rs2, GPR:$rs1, simm12:$imm12)>;
def PseudoRV32ZdinxSD : Pseudo<(outs), (ins GPRPair:$rs2, GPRNoX0:$rs1, simm12:$imm12), []>;
def : Pat<(store (f64 GPRPair:$rs2), (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12)),
(PseudoRV32ZdinxSD GPRPair:$rs2, GPR:$rs1, simm12:$imm12)>;

/// Pseudo-instructions needed for the soft-float ABI with RV32D

Expand Down
45 changes: 27 additions & 18 deletions llvm/lib/Target/RISCV/RISCVRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def sub_vrm1_5 : ComposedSubRegIndex<sub_vrm2_2, sub_vrm1_1>;
def sub_vrm1_6 : ComposedSubRegIndex<sub_vrm2_3, sub_vrm1_0>;
def sub_vrm1_7 : ComposedSubRegIndex<sub_vrm2_3, sub_vrm1_1>;

def sub_32_hi : SubRegIndex<32, 32>;
// GPR sizes change with HwMode.
// FIXME: Support HwMode in SubRegIndex?
def sub_gpr_even : SubRegIndex<-1>;
def sub_gpr_odd : SubRegIndex<-1, -1>;
} // Namespace = "RISCV"

// Integer registers
Expand Down Expand Up @@ -118,6 +121,8 @@ def XLenVT : ValueTypeByHwMode<[RV32, RV64],
// Allow f64 in GPR for ZDINX on RV64.
def XLenFVT : ValueTypeByHwMode<[RV64],
[f64]>;
def XLenPairFVT : ValueTypeByHwMode<[RV32],
[f64]>;
def XLenRI : RegInfoByHwMode<
[RV32, RV64],
[RegInfo<32,32,32>, RegInfo<64,64,64>]>;
Expand Down Expand Up @@ -546,33 +551,37 @@ def DUMMY_REG_PAIR_WITH_X0 : RISCVReg<0, "0">;
def GPRAll : GPRRegisterClass<(add GPR, DUMMY_REG_PAIR_WITH_X0)>;

let RegAltNameIndices = [ABIRegAltName] in {
def X0_PD : RISCVRegWithSubRegs<0, X0.AsmName,
[X0, DUMMY_REG_PAIR_WITH_X0],
X0.AltNames> {
let SubRegIndices = [sub_32, sub_32_hi];
def X0_Pair : RISCVRegWithSubRegs<0, X0.AsmName,
[X0, DUMMY_REG_PAIR_WITH_X0],
X0.AltNames> {
let SubRegIndices = [sub_gpr_even, sub_gpr_odd];
let CoveredBySubRegs = 1;
}
foreach I = 1-15 in {
defvar Index = !shl(I, 1);
defvar IndexP1 = !add(Index, 1);
defvar Reg = !cast<Register>("X"#Index);
defvar RegP1 = !cast<Register>("X"#!add(Index,1));
def X#Index#_PD : RISCVRegWithSubRegs<Index, Reg.AsmName,
[Reg, RegP1],
Reg.AltNames> {
let SubRegIndices = [sub_32, sub_32_hi];
defvar RegP1 = !cast<Register>("X"#IndexP1);
def "X" # Index #"_X" # IndexP1 : RISCVRegWithSubRegs<Index,
Reg.AsmName,
[Reg, RegP1],
Reg.AltNames> {
let SubRegIndices = [sub_gpr_even, sub_gpr_odd];
let CoveredBySubRegs = 1;
}
}
}

let RegInfos = RegInfoByHwMode<[RV64], [RegInfo<64, 64, 64>]> in
def GPRPF64 : RegisterClass<"RISCV", [f64], 64, (add
X10_PD, X12_PD, X14_PD, X16_PD,
X6_PD,
X28_PD, X30_PD,
X8_PD,
X18_PD, X20_PD, X22_PD, X24_PD, X26_PD,
X0_PD, X2_PD, X4_PD
let RegInfos = RegInfoByHwMode<[RV32, RV64],
[RegInfo<64, 64, 64>, RegInfo<128, 128, 128>]>,
DecoderMethod = "DecodeGPRPairRegisterClass" in
def GPRPair : RegisterClass<"RISCV", [XLenPairFVT], 64, (add
X10_X11, X12_X13, X14_X15, X16_X17,
X6_X7,
X28_X29, X30_X31,
X8_X9,
X18_X19, X20_X21, X22_X23, X24_X25, X26_X27,
X0_Pair, X2_X3, X4_X5
)>;

// The register class is added for inline assembly for vector mask types.
Expand Down