Skip to content

[VectorCombine] Generalize foldBitOpOfBitcasts to support more cast operations #148350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Status: Merged (8 commits, merged on Jul 21, 2025)
107 changes: 77 additions & 30 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class VectorCombine {
bool foldInsExtFNeg(Instruction &I);
bool foldInsExtBinop(Instruction &I);
bool foldInsExtVectorToShuffle(Instruction &I);
bool foldBitOpOfBitcasts(Instruction &I);
bool foldBitOpOfCastops(Instruction &I);
bool foldBitcastShuffle(Instruction &I);
bool scalarizeOpOrCmp(Instruction &I);
bool scalarizeVPIntrinsic(Instruction &I);
Expand Down Expand Up @@ -808,48 +808,87 @@ bool VectorCombine::foldInsExtBinop(Instruction &I) {
return true;
}

bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
// Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
Value *LHSSrc, *RHSSrc;
if (!match(&I, m_BitwiseLogic(m_BitCast(m_Value(LHSSrc)),
m_BitCast(m_Value(RHSSrc)))))
/// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
/// Supports: bitcast, trunc, sext, zext
bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
// Check if this is a bitwise logic operation
auto *BinOp = dyn_cast<BinaryOperator>(&I);
if (!BinOp || !BinOp->isBitwiseLogicOp())
return false;

// Get the cast instructions
auto *LHSCast = dyn_cast<CastInst>(BinOp->getOperand(0));
auto *RHSCast = dyn_cast<CastInst>(BinOp->getOperand(1));
if (!LHSCast || !RHSCast) {
LLVM_DEBUG(dbgs() << " One or both operands are not cast instructions\n");
return false;
}

// Both casts must be the same type
Instruction::CastOps CastOpcode = LHSCast->getOpcode();
if (CastOpcode != RHSCast->getOpcode())
return false;

// Only handle supported cast operations
switch (CastOpcode) {
case Instruction::BitCast:
case Instruction::Trunc:
case Instruction::SExt:
case Instruction::ZExt:
break;
default:
return false;
}

Value *LHSSrc = LHSCast->getOperand(0);
Value *RHSSrc = RHSCast->getOperand(0);

// Source types must match
if (LHSSrc->getType() != RHSSrc->getType())
return false;
if (!LHSSrc->getType()->getScalarType()->isIntegerTy())
return false;

// Only handle vector types
// Only handle vector types with integer elements
auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
if (!SrcVecTy || !DstVecTy)
return false;

// Same total bit width
assert(SrcVecTy->getPrimitiveSizeInBits() ==
DstVecTy->getPrimitiveSizeInBits() &&
"Bitcast should preserve total bit width");
if (!SrcVecTy->getScalarType()->isIntegerTy() ||
!DstVecTy->getScalarType()->isIntegerTy())
return false;

// Cost Check :
// OldCost = bitlogic + 2*bitcasts
// NewCost = bitlogic + bitcast
auto *BinOp = cast<BinaryOperator>(&I);
// OldCost = bitlogic + 2*casts
// NewCost = bitlogic + cast

// Calculate specific costs for each cast with instruction context
InstructionCost LHSCastCost =
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None, CostKind, LHSCast);
InstructionCost RHSCastCost =
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None, CostKind, RHSCast);

InstructionCost OldCost =
TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy) +
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, LHSSrc->getType(),
TTI::CastContextHint::None) +
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, RHSSrc->getType(),
TTI::CastContextHint::None);
TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy, CostKind) +
LHSCastCost + RHSCastCost;

// For new cost, we can't provide an instruction (it doesn't exist yet)
InstructionCost GenericCastCost = TTI.getCastInstrCost(
CastOpcode, DstVecTy, SrcVecTy, TTI::CastContextHint::None, CostKind);

InstructionCost NewCost =
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy,
TTI::CastContextHint::None);
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy, CostKind) +
GenericCastCost;

LLVM_DEBUG(dbgs() << "Found a bitwise logic op of bitcasted values: " << I
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
<< "\n");
// Account for multi-use casts using specific costs
if (!LHSCast->hasOneUse())
NewCost += LHSCastCost;
if (!RHSCast->hasOneUse())
NewCost += RHSCastCost;

LLVM_DEBUG(dbgs() << "foldBitOpOfCastops: OldCost=" << OldCost
<< " NewCost=" << NewCost << "\n");

if (NewCost > OldCost)
return false;
Expand All @@ -862,8 +901,16 @@ bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {

Worklist.pushValue(NewOp);

// Bitcast the result back
Value *Result = Builder.CreateBitCast(NewOp, I.getType());
// Create the cast operation directly to ensure we get a new instruction
Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType());

// Preserve cast instruction flags
NewCast->copyIRFlags(LHSCast);
NewCast->andIRFlags(RHSCast);

// Insert the new instruction
Value *Result = Builder.Insert(NewCast);

replaceValue(I, *Result);
return true;
}
Expand Down Expand Up @@ -3773,7 +3820,7 @@ bool VectorCombine::run() {
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
MadeChange |= foldBitOpOfBitcasts(I);
MadeChange |= foldBitOpOfCastops(I);
break;
default:
MadeChange |= shrinkType(I);
Expand Down
Loading