@@ -1048,9 +1048,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
1048
1048
MVT::v32i32, MVT::v64i32, MVT::v128i32},
1049
1049
Custom);
1050
1050
1051
- setOperationAction (ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
1052
- // Enable custom lowering for the i128 bit operand with clusterlaunchcontrol
1053
- setOperationAction (ISD::INTRINSIC_WO_CHAIN, MVT::i128 , Custom);
1051
+ // Enable custom lowering for the following:
1052
+ // * MVT::i128 - clusterlaunchcontrol
1053
+ // * MVT::i32 - prmt
1054
+ // * MVT::Other - internal.addrspace.wrap
1055
+ setOperationAction (ISD::INTRINSIC_WO_CHAIN, {MVT::i32 , MVT::i128 , MVT::Other},
1056
+ Custom);
1054
1057
}
1055
1058
1056
1059
const char *NVPTXTargetLowering::getTargetNodeName (unsigned Opcode) const {
@@ -2060,6 +2063,13 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2060
2063
return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
2061
2064
}
2062
2065
2066
+ static SDValue getPRMT (SDValue A, SDValue B, SDValue Selector, SDLoc DL,
2067
+ SelectionDAG &DAG,
2068
+ unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
2069
+ return DAG.getNode (NVPTXISD::PRMT, DL, MVT::i32 ,
2070
+ {A, B, Selector, DAG.getConstant (Mode, DL, MVT::i32 )});
2071
+ }
2072
+
2063
2073
SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
2064
2074
// Handle bitcasting from v2i8 without hitting the default promotion
2065
2075
// strategy which goes through stack memory.
@@ -2111,15 +2121,13 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2111
2121
L = DAG.getAnyExtOrTrunc (L, DL, MVT::i32 );
2112
2122
R = DAG.getAnyExtOrTrunc (R, DL, MVT::i32 );
2113
2123
}
2114
- return DAG.getNode (
2115
- NVPTXISD::PRMT, DL, MVT::v4i8,
2116
- {L, R, DAG.getConstant (SelectionValue, DL, MVT::i32 ),
2117
- DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 )});
2124
+ return getPRMT (L, R, DAG.getConstant (SelectionValue, DL, MVT::i32 ), DL,
2125
+ DAG);
2118
2126
};
2119
2127
auto PRMT__10 = GetPRMT (Op->getOperand (0 ), Op->getOperand (1 ), true , 0x3340 );
2120
2128
auto PRMT__32 = GetPRMT (Op->getOperand (2 ), Op->getOperand (3 ), true , 0x3340 );
2121
2129
auto PRMT3210 = GetPRMT (PRMT__10, PRMT__32, false , 0x5410 );
2122
- return DAG.getNode (ISD::BITCAST, DL, VT, PRMT3210);
2130
+ return DAG.getBitcast ( VT, PRMT3210);
2123
2131
}
2124
2132
2125
2133
// Get value or the Nth operand as an APInt(32). Undef values treated as 0.
@@ -2176,11 +2184,14 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2176
2184
SDValue Selector = DAG.getNode (ISD::OR, DL, MVT::i32 ,
2177
2185
DAG.getZExtOrTrunc (Index, DL, MVT::i32 ),
2178
2186
DAG.getConstant (0x7770 , DL, MVT::i32 ));
2179
- SDValue PRMT = DAG.getNode (
2180
- NVPTXISD::PRMT, DL, MVT::i32 ,
2181
- {DAG.getBitcast (MVT::i32 , Vector), DAG.getConstant (0 , DL, MVT::i32 ),
2182
- Selector, DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 )});
2183
- return DAG.getAnyExtOrTrunc (PRMT, DL, Op->getValueType (0 ));
2187
+ SDValue PRMT = getPRMT (DAG.getBitcast (MVT::i32 , Vector),
2188
+ DAG.getConstant (0 , DL, MVT::i32 ), Selector, DL, DAG);
2189
+ SDValue Ext = DAG.getAnyExtOrTrunc (PRMT, DL, Op->getValueType (0 ));
2190
+ SDNodeFlags Flags;
2191
+ Flags.setNoSignedWrap (Ext.getScalarValueSizeInBits () > 8 );
2192
+ Flags.setNoUnsignedWrap (Ext.getScalarValueSizeInBits () >= 8 );
2193
+ Ext->setFlags (Flags);
2194
+ return Ext;
2184
2195
}
2185
2196
2186
2197
// Constant index will be matched by tablegen.
@@ -2242,9 +2253,10 @@ SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2242
2253
}
2243
2254
2244
2255
SDLoc DL (Op);
2245
- return DAG.getNode (NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
2246
- DAG.getConstant (Selector, DL, MVT::i32 ),
2247
- DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 ));
2256
+ SDValue PRMT =
2257
+ getPRMT (DAG.getBitcast (MVT::i32 , V1), DAG.getBitcast (MVT::i32 , V2),
2258
+ DAG.getConstant (Selector, DL, MVT::i32 ), DL, DAG);
2259
+ return DAG.getBitcast (Op.getValueType (), PRMT);
2248
2260
}
2249
2261
// / LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
2250
2262
// / 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
@@ -2729,10 +2741,46 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
2729
2741
{TryCancelResponse0, TryCancelResponse1});
2730
2742
}
2731
2743
2744
+ static SDValue lowerPrmtIntrinsic (SDValue Op, SelectionDAG &DAG) {
2745
+ const unsigned Mode = [&]() {
2746
+ switch (Op->getConstantOperandVal (0 )) {
2747
+ case Intrinsic::nvvm_prmt:
2748
+ return NVPTX::PTXPrmtMode::NONE;
2749
+ case Intrinsic::nvvm_prmt_b4e:
2750
+ return NVPTX::PTXPrmtMode::B4E;
2751
+ case Intrinsic::nvvm_prmt_ecl:
2752
+ return NVPTX::PTXPrmtMode::ECL;
2753
+ case Intrinsic::nvvm_prmt_ecr:
2754
+ return NVPTX::PTXPrmtMode::ECR;
2755
+ case Intrinsic::nvvm_prmt_f4e:
2756
+ return NVPTX::PTXPrmtMode::F4E;
2757
+ case Intrinsic::nvvm_prmt_rc16:
2758
+ return NVPTX::PTXPrmtMode::RC16;
2759
+ case Intrinsic::nvvm_prmt_rc8:
2760
+ return NVPTX::PTXPrmtMode::RC8;
2761
+ default :
2762
+ llvm_unreachable (" unsupported/unhandled intrinsic" );
2763
+ }
2764
+ }();
2765
+ SDLoc DL (Op);
2766
+ SDValue A = Op->getOperand (1 );
2767
+ SDValue B = Op.getNumOperands () == 4 ? Op.getOperand (2 )
2768
+ : DAG.getConstant (0 , DL, MVT::i32 );
2769
+ SDValue Selector = (Op->op_end () - 1 )->get ();
2770
+ return getPRMT (A, B, Selector, DL, DAG, Mode);
2771
+ }
2732
2772
static SDValue lowerIntrinsicWOChain (SDValue Op, SelectionDAG &DAG) {
2733
2773
switch (Op->getConstantOperandVal (0 )) {
2734
2774
default :
2735
2775
return Op;
2776
+ case Intrinsic::nvvm_prmt:
2777
+ case Intrinsic::nvvm_prmt_b4e:
2778
+ case Intrinsic::nvvm_prmt_ecl:
2779
+ case Intrinsic::nvvm_prmt_ecr:
2780
+ case Intrinsic::nvvm_prmt_f4e:
2781
+ case Intrinsic::nvvm_prmt_rc16:
2782
+ case Intrinsic::nvvm_prmt_rc8:
2783
+ return lowerPrmtIntrinsic (Op, DAG);
2736
2784
case Intrinsic::nvvm_internal_addrspace_wrap:
2737
2785
return Op.getOperand (1 );
2738
2786
case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
@@ -5775,11 +5823,10 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5775
5823
SDLoc DL (N);
5776
5824
auto &DAG = DCI.DAG ;
5777
5825
5778
- auto PRMT = DAG.getNode (
5779
- NVPTXISD::PRMT, DL, MVT::v4i8,
5780
- {Op0, Op1, DAG.getConstant ((Op1Bytes << 8 ) | Op0Bytes, DL, MVT::i32 ),
5781
- DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 )});
5782
- return DAG.getNode (ISD::BITCAST, DL, VT, PRMT);
5826
+ auto PRMT = getPRMT (
5827
+ DAG.getBitcast (MVT::i32 , Op0), DAG.getBitcast (MVT::i32 , Op1),
5828
+ DAG.getConstant ((Op1Bytes << 8 ) | Op0Bytes, DL, MVT::i32 ), DL, DAG);
5829
+ return DAG.getBitcast (VT, PRMT);
5783
5830
}
5784
5831
5785
5832
static SDValue combineADDRSPACECAST (SDNode *N,
@@ -5797,47 +5844,116 @@ static SDValue combineADDRSPACECAST(SDNode *N,
5797
5844
return SDValue ();
5798
5845
}
5799
5846
5847
+ static APInt getPRMTSelector (APInt Selector, unsigned Mode) {
5848
+ if (Mode == NVPTX::PTXPrmtMode::NONE)
5849
+ return Selector;
5850
+
5851
+ unsigned V = Selector.trunc (2 ).getZExtValue ();
5852
+
5853
+ const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
5854
+ unsigned S3) {
5855
+ return APInt (32 , S0 | (S1 << 4 ) | (S2 << 8 ) | (S3 << 12 ));
5856
+ };
5857
+
5858
+ switch (Mode) {
5859
+ case NVPTX::PTXPrmtMode::F4E:
5860
+ return GetSelector (V, V + 1 , V + 2 , V + 3 );
5861
+ case NVPTX::PTXPrmtMode::B4E:
5862
+ return GetSelector (V, (V - 1 ) & 7 , (V - 2 ) & 7 , (V - 3 ) & 7 );
5863
+ case NVPTX::PTXPrmtMode::RC8:
5864
+ return GetSelector (V, V, V, V);
5865
+ case NVPTX::PTXPrmtMode::ECL:
5866
+ return GetSelector (V, std::max (V, 1U ), std::max (V, 2U ), 3U );
5867
+ case NVPTX::PTXPrmtMode::ECR:
5868
+ return GetSelector (0 , std::min (V, 1U ), std::min (V, 2U ), V);
5869
+ case NVPTX::PTXPrmtMode::RC16: {
5870
+ unsigned V1 = (V & 1 ) << 1 ;
5871
+ return GetSelector (V1, V1 + 1 , V1, V1 + 1 );
5872
+ }
5873
+ default :
5874
+ llvm_unreachable (" Invalid PRMT mode" );
5875
+ }
5876
+ }
5877
+
5878
+ static APInt computePRMT (APInt A, APInt B, APInt Selector, unsigned Mode) {
5879
+ // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
5880
+ APInt BitField = B.concat (A);
5881
+ APInt SelectorVal = getPRMTSelector (Selector, Mode);
5882
+ APInt Result (32 , 0 );
5883
+ for (unsigned I : llvm::seq (4U )) {
5884
+ APInt Sel = SelectorVal.extractBits (4 , I * 4 );
5885
+ unsigned Idx = Sel.getLoBits (3 ).getZExtValue ();
5886
+ unsigned Sign = Sel.getHiBits (1 ).getZExtValue ();
5887
+ APInt Byte = BitField.extractBits (8 , Idx * 8 );
5888
+ if (Sign)
5889
+ Byte = Byte.ashr (8 );
5890
+ Result.insertBits (Byte, I * 8 );
5891
+ }
5892
+ return Result;
5893
+ }
5894
+
5895
+ static SDValue combinePRMT (SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
5896
+ CodeGenOptLevel OptLevel) {
5897
+ if (OptLevel == CodeGenOptLevel::None)
5898
+ return SDValue ();
5899
+
5900
+ // Constant fold PRMT
5901
+ if (isa<ConstantSDNode>(N->getOperand (0 )) &&
5902
+ isa<ConstantSDNode>(N->getOperand (1 )) &&
5903
+ isa<ConstantSDNode>(N->getOperand (2 )))
5904
+ return DCI.DAG .getConstant (computePRMT (N->getConstantOperandAPInt (0 ),
5905
+ N->getConstantOperandAPInt (1 ),
5906
+ N->getConstantOperandAPInt (2 ),
5907
+ N->getConstantOperandVal (3 )),
5908
+ SDLoc (N), N->getValueType (0 ));
5909
+
5910
+ return SDValue ();
5911
+ }
5912
+
5800
5913
/// Dispatch \p N to the NVPTX-specific combine routine for its opcode.
///
/// Returns whatever the selected routine returns, or an empty SDValue when
/// the opcode has no routine listed here.
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                               DAGCombinerInfo &DCI) const {
  const CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();

  switch (N->getOpcode()) {
  // Integer and floating-point arithmetic.
  case ISD::ADD:
    return PerformADDCombine(N, DCI, OptLevel);
  case ISD::FADD:
    return PerformFADDCombine(N, DCI, OptLevel);
  case ISD::MUL:
    return PerformMULCombine(N, DCI, OptLevel);
  case ISD::SHL:
    return PerformSHLCombine(N, DCI, OptLevel);
  case ISD::AND:
    return PerformANDCombine(N, DCI);
  case ISD::UREM:
  case ISD::SREM:
    return PerformREMCombine(N, DCI, OptLevel);
  case ISD::SETCC:
    return PerformSETCCCombine(N, DCI, STI.getSmVersion());

  // Loads and stores.
  case ISD::LOAD:
  case NVPTXISD::LoadParamV2:
  case NVPTXISD::LoadV2:
  case NVPTXISD::LoadV4:
    return combineUnpackingMovIntoLoad(N, DCI);
  case NVPTXISD::StoreParam:
  case NVPTXISD::StoreParamV2:
  case NVPTXISD::StoreParamV4:
    return PerformStoreParamCombine(N, DCI);
  case ISD::STORE:
  case NVPTXISD::StoreV2:
  case NVPTXISD::StoreV4:
    return PerformStoreCombine(N, DCI);

  // Vector operations.
  case ISD::EXTRACT_VECTOR_ELT:
    return PerformEXTRACTCombine(N, DCI);
  case ISD::VSELECT:
    return PerformVSELECTCombine(N, DCI);
  case ISD::BUILD_VECTOR:
    return PerformBUILD_VECTORCombine(N, DCI);

  // Remaining opcodes.
  case ISD::ADDRSPACECAST:
    return combineADDRSPACECAST(N, DCI);
  case NVPTXISD::PRMT:
    return combinePRMT(N, DCI, OptLevel);

  default:
    break;
  }
  return SDValue();
}
@@ -6385,7 +6501,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
6385
6501
ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand (2 ));
6386
6502
unsigned Mode = Op.getConstantOperandVal (3 );
6387
6503
6388
- if (Mode != NVPTX::PTXPrmtMode::NONE || !Selector)
6504
+ if (!Selector)
6389
6505
return ;
6390
6506
6391
6507
KnownBits AKnown = DAG.computeKnownBits (A, Depth);
@@ -6394,7 +6510,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
6394
6510
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
6395
6511
KnownBits BitField = BKnown.concat (AKnown);
6396
6512
6397
- APInt SelectorVal = Selector->getAPIntValue ();
6513
+ APInt SelectorVal = getPRMTSelector ( Selector->getAPIntValue (), Mode );
6398
6514
for (unsigned I : llvm::seq (std::min (4U , Known.getBitWidth () / 8 ))) {
6399
6515
APInt Sel = SelectorVal.extractBits (4 , I * 4 );
6400
6516
unsigned Idx = Sel.getLoBits (3 ).getZExtValue ();
0 commit comments