Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- llvm/lib/Target/X86/X86ISelLowering.cpp | 483
1 file changed, 321 insertions, 162 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 53c00affd70e..b98ac635e00d 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -949,6 +949,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::MULHU, MVT::v8i16, Legal);
     setOperationAction(ISD::MULHS, MVT::v8i16, Legal);
     setOperationAction(ISD::MUL, MVT::v8i16, Legal);
+    setOperationAction(ISD::AVGCEILU, MVT::v16i8, Legal);
+    setOperationAction(ISD::AVGCEILU, MVT::v8i16, Legal);
 
     setOperationAction(ISD::SMULO, MVT::v16i8, Custom);
     setOperationAction(ISD::UMULO, MVT::v16i8, Custom);
@@ -1285,13 +1287,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
       if (VT == MVT::v4i64)
         continue;
       setOperationAction(ISD::ROTL, VT, Custom);
       setOperationAction(ISD::ROTR, VT, Custom);
+      setOperationAction(ISD::FSHL, VT, Custom);
+      setOperationAction(ISD::FSHR, VT, Custom);
     }
-    setOperationAction(ISD::FSHL, MVT::v32i8, Custom);
-    setOperationAction(ISD::FSHR, MVT::v32i8, Custom);
-    setOperationAction(ISD::FSHL, MVT::v8i32, Custom);
-    setOperationAction(ISD::FSHR, MVT::v8i32, Custom);
-
     // These types need custom splitting if their input is a 128-bit vector.
     setOperationAction(ISD::SIGN_EXTEND, MVT::v8i64, Custom);
     setOperationAction(ISD::SIGN_EXTEND, MVT::v16i32, Custom);
@@ -1353,6 +1352,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::MULHS, MVT::v16i16, HasInt256 ? Legal : Custom);
     setOperationAction(ISD::MULHU, MVT::v32i8, Custom);
     setOperationAction(ISD::MULHS, MVT::v32i8, Custom);
+    setOperationAction(ISD::AVGCEILU, MVT::v16i16, HasInt256 ? Legal : Custom);
+    setOperationAction(ISD::AVGCEILU, MVT::v32i8, HasInt256 ? Legal : Custom);
 
     setOperationAction(ISD::SMULO, MVT::v32i8, Custom);
     setOperationAction(ISD::UMULO, MVT::v32i8, Custom);
@@ -1652,6 +1653,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::MULHU, MVT::v32i16, HasBWI ? Legal : Custom);
     setOperationAction(ISD::MULHS, MVT::v64i8, Custom);
     setOperationAction(ISD::MULHU, MVT::v64i8, Custom);
+    setOperationAction(ISD::AVGCEILU, MVT::v32i16, HasBWI ? Legal : Custom);
+    setOperationAction(ISD::AVGCEILU, MVT::v64i8, HasBWI ? Legal : Custom);
 
     setOperationAction(ISD::SMULO, MVT::v64i8, Custom);
     setOperationAction(ISD::UMULO, MVT::v64i8, Custom);
@@ -25700,6 +25703,89 @@ static SDValue getTargetVShiftByConstNode(unsigned Opc, const SDLoc &dl, MVT VT,
                      DAG.getTargetConstant(ShiftAmt, dl, MVT::i8));
 }
 
+/// Handle vector element shifts by a splat shift amount
+static SDValue getTargetVShiftNode(unsigned Opc, const SDLoc &dl, MVT VT,
+                                   SDValue SrcOp, SDValue ShAmt, int ShAmtIdx,
+                                   const X86Subtarget &Subtarget,
+                                   SelectionDAG &DAG) {
+  MVT AmtVT = ShAmt.getSimpleValueType();
+  assert(AmtVT.isVector() && "Vector shift type mismatch");
+  assert(0 <= ShAmtIdx && ShAmtIdx < (int)AmtVT.getVectorNumElements() &&
+         "Illegal vector splat index");
+
+  // Move the splat element to the bottom element.
+  if (ShAmtIdx != 0) {
+    SmallVector<int> Mask(AmtVT.getVectorNumElements(), -1);
+    Mask[0] = ShAmtIdx;
+    ShAmt = DAG.getVectorShuffle(AmtVT, dl, ShAmt, DAG.getUNDEF(AmtVT), Mask);
+  }
+
+  // See if we can mask off the upper elements using the existing source node.
+  // The shift uses the entire lower 64-bits of the amount vector, so no need to
+  // do this for vXi64 types.
+  bool IsMasked = false;
+  if (AmtVT.getScalarSizeInBits() < 64) {
+    if (ShAmt.getOpcode() == ISD::BUILD_VECTOR ||
+        ShAmt.getOpcode() == ISD::SCALAR_TO_VECTOR) {
+      // If the shift amount has come from a scalar, then zero-extend the scalar
+      // before moving to the vector.
+      ShAmt = DAG.getZExtOrTrunc(ShAmt.getOperand(0), dl, MVT::i32);
+      ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4i32, ShAmt);
+      ShAmt = DAG.getNode(X86ISD::VZEXT_MOVL, dl, MVT::v4i32, ShAmt);
+      AmtVT = MVT::v4i32;
+      IsMasked = true;
+    } else if (ShAmt.getOpcode() == ISD::AND) {
+      // See if the shift amount is already masked (e.g. for rotation modulo),
+      // then we can zero-extend it by setting all the other mask elements to
+      // zero.
+      SmallVector<SDValue> MaskElts(
+          AmtVT.getVectorNumElements(),
+          DAG.getConstant(0, dl, AmtVT.getScalarType()));
+      MaskElts[0] = DAG.getAllOnesConstant(dl, AmtVT.getScalarType());
+      SDValue Mask = DAG.getBuildVector(AmtVT, dl, MaskElts);
+      if ((Mask = DAG.FoldConstantArithmetic(ISD::AND, dl, AmtVT,
+                                             {ShAmt.getOperand(1), Mask}))) {
+        ShAmt = DAG.getNode(ISD::AND, dl, AmtVT, ShAmt.getOperand(0), Mask);
+        IsMasked = true;
+      }
+    }
+  }
+
+  // Extract if the shift amount vector is larger than 128-bits.
+  if (AmtVT.getSizeInBits() > 128) {
+    ShAmt = extract128BitVector(ShAmt, 0, DAG, dl);
+    AmtVT = ShAmt.getSimpleValueType();
+  }
+
+  // Zero-extend bottom element to v2i64 vector type, either by extension or
+  // shuffle masking.
+  if (!IsMasked && AmtVT.getScalarSizeInBits() < 64) {
+    if (Subtarget.hasSSE41())
+      ShAmt = DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, SDLoc(ShAmt),
+                          MVT::v2i64, ShAmt);
+    else {
+      SDValue ByteShift = DAG.getTargetConstant(
+          (128 - AmtVT.getScalarSizeInBits()) / 8, SDLoc(ShAmt), MVT::i8);
+      ShAmt = DAG.getBitcast(MVT::v16i8, ShAmt);
+      ShAmt = DAG.getNode(X86ISD::VSHLDQ, SDLoc(ShAmt), MVT::v16i8, ShAmt,
+                          ByteShift);
+      ShAmt = DAG.getNode(X86ISD::VSRLDQ, SDLoc(ShAmt), MVT::v16i8, ShAmt,
+                          ByteShift);
+    }
+  }
+
+  // Change opcode to non-immediate version.
+  Opc = getTargetVShiftUniformOpcode(Opc, true);
+
+  // The return type has to be a 128-bit type with the same element
+  // type as the input type.
+  MVT EltVT = VT.getVectorElementType();
+  MVT ShVT = MVT::getVectorVT(EltVT, 128 / EltVT.getSizeInBits());
+
+  ShAmt = DAG.getBitcast(ShVT, ShAmt);
+  return DAG.getNode(Opc, dl, VT, SrcOp, ShAmt);
+}
+
 /// Handle vector element shifts where the shift amount may or may not be a
 /// constant. Takes immediate version of shift as input.
 /// TODO: Replace with vector + (splat) idx to avoid extract_element nodes.
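
[Note on the hunk above] The helper exists because x86 per-vector shifts such as PSLLW/PSRLD read one shift amount from the entire low 64 bits of the amount register, so a splatted amount must be moved to lane 0 and the remaining low bits zeroed. A scalar sketch of the semantics being targeted, illustrative only and not LLVM code (the function name is invented):

    #include <cstdint>

    // Model of an x86 per-vector shift (e.g. PSLLW): every lane shifts by the
    // single amount held in the low 64 bits of the amount register. Amounts
    // >= the element width zero the lane, which is why getTargetVShiftNode
    // must zero-extend the splatted element rather than leave junk above it.
    static void psllw_model(uint16_t v[8], uint64_t amt64) {
      for (int i = 0; i != 8; ++i)
        v[i] = amt64 < 16 ? uint16_t(v[i] << amt64) : 0;
    }
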
@@ -26444,6 +26530,8 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     case VSHIFT: {
       SDValue SrcOp = Op.getOperand(1);
       SDValue ShAmt = Op.getOperand(2);
+      assert(ShAmt.getValueType() == MVT::i32 &&
+             "Unexpected VSHIFT amount type");
 
       // Catch shift-by-constant.
       if (auto *CShAmt = dyn_cast<ConstantSDNode>(ShAmt))
@@ -26451,8 +26539,9 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
                                           Op.getSimpleValueType(), SrcOp,
                                           CShAmt->getZExtValue(), DAG);
 
+      ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4i32, ShAmt);
       return getTargetVShiftNode(IntrData->Opc0, dl, Op.getSimpleValueType(),
-                                 SrcOp, ShAmt, Subtarget, DAG);
+                                 SrcOp, ShAmt, 0, Subtarget, DAG);
     }
     case COMPRESS_EXPAND_IN_REG: {
       SDValue Mask = Op.getOperand(3);
@@ -28394,6 +28483,21 @@ static SDValue LowerABS(SDValue Op, const X86Subtarget &Subtarget,
   return SDValue();
 }
 
+static SDValue LowerAVG(SDValue Op, const X86Subtarget &Subtarget,
+                        SelectionDAG &DAG) {
+  MVT VT = Op.getSimpleValueType();
+
+  // For AVX1 cases, split to use legal ops (everything but v4i64).
+  if (VT.is256BitVector() && !Subtarget.hasInt256())
+    return splitVectorIntBinary(Op, DAG);
+
+  if (VT == MVT::v32i16 || VT == MVT::v64i8)
+    return splitVectorIntBinary(Op, DAG);
+
+  // Default to expand.
+  return SDValue();
+}
+
 static SDValue LowerMINMAX(SDValue Op, SelectionDAG &DAG) {
   MVT VT = Op.getSimpleValueType();
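
[Note on the hunk above] ISD::AVGCEILU is the rounding-up unsigned average c = (a + b + 1) / 2 that PAVGB/PAVGW compute without widening, which is why the legal 128-bit cases map directly to those instructions and LowerAVG only has to split the wider types. A scalar sketch of an overflow-free formulation, illustrative only (the function name is invented):

    #include <cstdint>

    // (a | b) - ((a ^ b) >> 1) equals (a + b + 1) >> 1 computed in infinite
    // precision, but never overflows the 8-bit element type.
    static uint8_t avg_ceil_u8(uint8_t a, uint8_t b) {
      return uint8_t((a | b) - ((a ^ b) >> 1));
    }
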
@@ -29843,8 +29947,8 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
                              {Op0, Op1, Amt}, DAG, Subtarget);
     }
     assert((VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8 ||
-            VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v8i32 ||
-            VT == MVT::v16i32) &&
+            VT == MVT::v8i16 || VT == MVT::v16i16 || VT == MVT::v4i32 ||
+            VT == MVT::v8i32 || VT == MVT::v16i32) &&
            "Unexpected funnel shift type!");
 
     // fshl(x,y,z) -> unpack(y,x) << (z & (bw-1))) >> bw.
@@ -29867,7 +29971,7 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
 
     // Split 256-bit integers on XOP/pre-AVX2 targets.
     // Split 512-bit integers on non 512-bit BWI targets.
-    if ((VT.is256BitVector() && ((Subtarget.hasXOP() && EltSizeInBits < 32) ||
+    if ((VT.is256BitVector() && ((Subtarget.hasXOP() && EltSizeInBits < 16) ||
                                  !Subtarget.hasAVX2())) ||
         (VT.is512BitVector() && !Subtarget.useBWIRegs() &&
          EltSizeInBits < 32)) {
@@ -29878,18 +29982,18 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
 
     // Attempt to fold scalar shift as unpack(y,x) << zext(splat(z))
     if (supportedVectorShiftWithBaseAmnt(ExtVT, Subtarget, ShiftOpc)) {
-      if (SDValue ScalarAmt = DAG.getSplatValue(AmtMod)) {
+      int ScalarAmtIdx = -1;
+      if (SDValue ScalarAmt = DAG.getSplatSourceVector(AmtMod, ScalarAmtIdx)) {
         // Uniform vXi16 funnel shifts can be efficiently handled by default.
         if (EltSizeInBits == 16)
           return SDValue();
 
         SDValue Lo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, Op1, Op0));
         SDValue Hi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, Op1, Op0));
-        ScalarAmt = DAG.getZExtOrTrunc(ScalarAmt, DL, MVT::i32);
-        Lo = getTargetVShiftNode(ShiftOpc, DL, ExtVT, Lo, ScalarAmt, Subtarget,
-                                 DAG);
-        Hi = getTargetVShiftNode(ShiftOpc, DL, ExtVT, Hi, ScalarAmt, Subtarget,
-                                 DAG);
+        Lo = getTargetVShiftNode(ShiftOpc, DL, ExtVT, Lo, ScalarAmt,
+                                 ScalarAmtIdx, Subtarget, DAG);
+        Hi = getTargetVShiftNode(ShiftOpc, DL, ExtVT, Hi, ScalarAmt,
+                                 ScalarAmtIdx, Subtarget, DAG);
         return getPack(DAG, Subtarget, DL, VT, Lo, Hi, !IsFSHR);
       }
     }
@@ -30082,15 +30186,15 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
   // TODO: Handle vXi16 cases on all targets.
   if (EltSizeInBits == 8 || EltSizeInBits == 32 ||
       (IsROTL && EltSizeInBits == 16 && !Subtarget.hasAVX())) {
-    if (SDValue BaseRotAmt = DAG.getSplatValue(AmtMod)) {
+    int BaseRotAmtIdx = -1;
+    if (SDValue BaseRotAmt = DAG.getSplatSourceVector(AmtMod, BaseRotAmtIdx)) {
       unsigned ShiftX86Opc = IsROTL ? X86ISD::VSHLI : X86ISD::VSRLI;
       SDValue Lo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, R, R));
       SDValue Hi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, R, R));
-      BaseRotAmt = DAG.getZExtOrTrunc(BaseRotAmt, DL, MVT::i32);
       Lo = getTargetVShiftNode(ShiftX86Opc, DL, ExtVT, Lo, BaseRotAmt,
-                               Subtarget, DAG);
+                               BaseRotAmtIdx, Subtarget, DAG);
       Hi = getTargetVShiftNode(ShiftX86Opc, DL, ExtVT, Hi, BaseRotAmt,
-                               Subtarget, DAG);
+                               BaseRotAmtIdx, Subtarget, DAG);
       return getPack(DAG, Subtarget, DL, VT, Lo, Hi, IsROTL);
     }
   }
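
[Note on the hunks above] The funnel-shift and rotate paths share the trick named in the "fshl(x,y,z) -> unpack(y,x) << (z & (bw-1))) >> bw" comment: unpack pairs of lanes into double-width lanes, do one wide shift by the splatted amount, then pack the high halves back. A one-lane scalar sketch, illustrative only (the function name is invented):

    #include <cstdint>

    // fshl on a single i8 lane via a 16-bit intermediate: concatenate x
    // (high byte) and y (low byte), shift left by z mod 8, keep the high byte.
    static uint8_t fshl_u8(uint8_t x, uint8_t y, unsigned z) {
      uint16_t wide = uint16_t((uint16_t(x) << 8) | y); // unpack(y, x)
      return uint8_t((wide << (z & 7)) >> 8);
    }
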
@@ -31712,6 +31816,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::UMAX:
   case ISD::UMIN:           return LowerMINMAX(Op, DAG);
   case ISD::ABS:            return LowerABS(Op, Subtarget, DAG);
+  case ISD::AVGCEILU:       return LowerAVG(Op, Subtarget, DAG);
   case ISD::FSINCOS:        return LowerFSINCOS(Op, Subtarget, DAG);
   case ISD::MLOAD:          return LowerMLOAD(Op, Subtarget, DAG);
   case ISD::MSTORE:         return LowerMSTORE(Op, Subtarget, DAG);
@@ -31807,9 +31912,8 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
     Results.push_back(Res);
     return;
   }
-  case X86ISD::VPMADDWD:
-  case X86ISD::AVG: {
-    // Legalize types for X86ISD::AVG/VPMADDWD by widening.
+  case X86ISD::VPMADDWD: {
+    // Legalize types for X86ISD::VPMADDWD by widening.
     assert(Subtarget.hasSSE2() && "Requires at least SSE2!");
 
     EVT VT = N->getValueType(0);
@@ -33041,7 +33145,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(SCALEF_RND)
   NODE_NAME_CASE(SCALEFS)
   NODE_NAME_CASE(SCALEFS_RND)
-  NODE_NAME_CASE(AVG)
   NODE_NAME_CASE(MULHRS)
   NODE_NAME_CASE(SINT_TO_FP_RND)
   NODE_NAME_CASE(UINT_TO_FP_RND)
@@ -33222,7 +33325,6 @@ bool X86TargetLowering::isBinOp(unsigned Opcode) const {
 bool X86TargetLowering::isCommutativeBinOp(unsigned Opcode) const {
   switch (Opcode) {
   // TODO: Add more X86ISD opcodes once we have test coverage.
-  case X86ISD::AVG:
   case X86ISD::PCMPEQ:
   case X86ISD::PMULDQ:
   case X86ISD::PMULUDQ:
@@ -40632,7 +40734,6 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
   case X86ISD::UNPCKH:
   case X86ISD::BLENDI:
     // Integer ops.
-  case X86ISD::AVG:
   case X86ISD::PACKSS:
   case X86ISD::PACKUS:
     // Horizontal Ops.
@@ -43123,6 +43224,104 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
+// This is more or less the reverse of combineBitcastvxi1.
+static SDValue combineToExtendBoolVectorInReg(
+    unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N0, SelectionDAG &DAG,
+    TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) {
+  if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND &&
+      Opcode != ISD::ANY_EXTEND)
+    return SDValue();
+  if (!DCI.isBeforeLegalizeOps())
+    return SDValue();
+  if (!Subtarget.hasSSE2() || Subtarget.hasAVX512())
+    return SDValue();
+
+  EVT SVT = VT.getScalarType();
+  EVT InSVT = N0.getValueType().getScalarType();
+  unsigned EltSizeInBits = SVT.getSizeInBits();
+
+  // Input type must be extending a bool vector (bit-casted from a scalar
+  // integer) to legal integer types.
+  if (!VT.isVector())
+    return SDValue();
+  if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16 && SVT != MVT::i8)
+    return SDValue();
+  if (InSVT != MVT::i1 || N0.getOpcode() != ISD::BITCAST)
+    return SDValue();
+
+  SDValue N00 = N0.getOperand(0);
+  EVT SclVT = N00.getValueType();
+  if (!SclVT.isScalarInteger())
+    return SDValue();
+
+  SDValue Vec;
+  SmallVector<int> ShuffleMask;
+  unsigned NumElts = VT.getVectorNumElements();
+  assert(NumElts == SclVT.getSizeInBits() && "Unexpected bool vector size");
+
+  // Broadcast the scalar integer to the vector elements.
+  if (NumElts > EltSizeInBits) {
+    // If the scalar integer is greater than the vector element size, then we
+    // must split it down into sub-sections for broadcasting. For example:
+    //   i16 -> v16i8 (i16 -> v8i16 -> v16i8) with 2 sub-sections.
+    //   i32 -> v32i8 (i32 -> v8i32 -> v32i8) with 4 sub-sections.
+    assert((NumElts % EltSizeInBits) == 0 && "Unexpected integer scale");
+    unsigned Scale = NumElts / EltSizeInBits;
+    EVT BroadcastVT = EVT::getVectorVT(*DAG.getContext(), SclVT, EltSizeInBits);
+    Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
+    Vec = DAG.getBitcast(VT, Vec);
+
+    for (unsigned i = 0; i != Scale; ++i)
+      ShuffleMask.append(EltSizeInBits, i);
+    Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
+  } else if (Subtarget.hasAVX2() && NumElts < EltSizeInBits &&
+             (SclVT == MVT::i8 || SclVT == MVT::i16 || SclVT == MVT::i32)) {
+    // If we have register broadcast instructions, use the scalar size as the
+    // element type for the shuffle. Then cast to the wider element type. The
+    // widened bits won't be used, and this might allow the use of a broadcast
+    // load.
+    assert((EltSizeInBits % NumElts) == 0 && "Unexpected integer scale");
+    unsigned Scale = EltSizeInBits / NumElts;
+    EVT BroadcastVT =
+        EVT::getVectorVT(*DAG.getContext(), SclVT, NumElts * Scale);
+    Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
+    ShuffleMask.append(NumElts * Scale, 0);
+    Vec = DAG.getVectorShuffle(BroadcastVT, DL, Vec, Vec, ShuffleMask);
+    Vec = DAG.getBitcast(VT, Vec);
+  } else {
+    // For smaller scalar integers, we can simply any-extend it to the vector
+    // element size (we don't care about the upper bits) and broadcast it to all
+    // elements.
+    SDValue Scl = DAG.getAnyExtOrTrunc(N00, DL, SVT);
+    Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Scl);
+    ShuffleMask.append(NumElts, 0);
+    Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
+  }
+
+  // Now, mask the relevant bit in each element.
+  SmallVector<SDValue, 32> Bits;
+  for (unsigned i = 0; i != NumElts; ++i) {
+    int BitIdx = (i % EltSizeInBits);
+    APInt Bit = APInt::getBitsSet(EltSizeInBits, BitIdx, BitIdx + 1);
+    Bits.push_back(DAG.getConstant(Bit, DL, SVT));
+  }
+  SDValue BitMask = DAG.getBuildVector(VT, DL, Bits);
+  Vec = DAG.getNode(ISD::AND, DL, VT, Vec, BitMask);
+
+  // Compare against the bitmask and extend the result.
+  EVT CCVT = VT.changeVectorElementType(MVT::i1);
+  Vec = DAG.getSetCC(DL, CCVT, Vec, BitMask, ISD::SETEQ);
+  Vec = DAG.getSExtOrTrunc(Vec, DL, VT);
+
+  // For SEXT, this is now done, otherwise shift the result down for
+  // zero-extension.
+  if (Opcode == ISD::SIGN_EXTEND)
+    return Vec;
+  return DAG.getNode(ISD::SRL, DL, VT, Vec,
+                     DAG.getConstant(EltSizeInBits - 1, DL, VT));
+}
+
 /// If a vector select has an operand that is -1 or 0, try to simplify the
 /// select to a bitwise logic operation.
 /// TODO: Move to DAGCombiner, possibly using TargetLowering::hasAndNot()?
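
[Note on the hunk above] As a scalar model of the broadcast/AND/SETEQ sequence: lane i of the result is bit i of the scalar mask, sign-extended to the lane width (for zero-extension the combine then shifts that -1/0 value right by EltSizeInBits - 1). Sketch only; the names are invented:

    #include <cstdint>

    // Sign-extend each bit of an i8 mask into a v8i16 lane: broadcast the
    // mask, AND with a per-lane single-bit constant, then compare for
    // equality against that same constant (the SETEQ in the combine).
    static void sext_v8i1_to_v8i16(uint8_t mask, int16_t out[8]) {
      for (int i = 0; i != 8; ++i) {
        uint16_t bit = uint16_t(1) << i; // BitMask element for lane i
        out[i] = ((uint16_t(mask) & bit) == bit) ? int16_t(-1) : int16_t(0);
      }
    }
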
@@ -43340,19 +43539,17 @@ static SDValue combineSelectOfTwoConstants(SDNode *N, SelectionDAG &DAG) {
 /// This function will also call SimplifyDemandedBits on already created
 /// BLENDV to perform additional simplifications.
 static SDValue combineVSelectToBLENDV(SDNode *N, SelectionDAG &DAG,
-                                       TargetLowering::DAGCombinerInfo &DCI,
-                                       const X86Subtarget &Subtarget) {
+                                      TargetLowering::DAGCombinerInfo &DCI,
+                                      const X86Subtarget &Subtarget) {
   SDValue Cond = N->getOperand(0);
   if ((N->getOpcode() != ISD::VSELECT &&
        N->getOpcode() != X86ISD::BLENDV) ||
       ISD::isBuildVectorOfConstantSDNodes(Cond.getNode()))
     return SDValue();
 
-  // Don't optimize before the condition has been transformed to a legal type
-  // and don't ever optimize vector selects that map to AVX512 mask-registers.
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   unsigned BitWidth = Cond.getScalarValueSizeInBits();
-  if (BitWidth < 8 || BitWidth > 64)
-    return SDValue();
+  EVT VT = N->getValueType(0);
 
   // We can only handle the cases where VSELECT is directly legal on the
   // subtarget. We custom lower VSELECT nodes with constant conditions and
@@ -43364,8 +43561,6 @@ static SDValue combineVSelectToBLENDV(SDNode *N, SelectionDAG &DAG,
   // Potentially, we should combine constant-condition vselect nodes
   // pre-legalization into shuffles and not mark as many types as custom
   // lowered.
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  EVT VT = N->getValueType(0);
   if (!TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
     return SDValue();
   // FIXME: We don't support i16-element blends currently. We could and
@@ -43383,6 +43578,22 @@ static SDValue combineVSelectToBLENDV(SDNode *N, SelectionDAG &DAG,
   if (VT.is512BitVector())
     return SDValue();
 
+  // PreAVX512, without mask-registers, attempt to sign-extend bool vectors to
+  // allow us to use BLENDV.
+  if (!Subtarget.hasAVX512() && BitWidth == 1) {
+    EVT CondVT = VT.changeVectorElementTypeToInteger();
+    if (SDValue ExtCond = combineToExtendBoolVectorInReg(
+            ISD::SIGN_EXTEND, SDLoc(N), CondVT, Cond, DAG, DCI, Subtarget)) {
+      return DAG.getNode(X86ISD::BLENDV, SDLoc(N), VT, ExtCond,
+                         N->getOperand(1), N->getOperand(2));
+    }
+  }
+
+  // Don't optimize before the condition has been transformed to a legal type
+  // and don't ever optimize vector selects that map to AVX512 mask-registers.
+  if (BitWidth < 8 || BitWidth > 64)
+    return SDValue();
+
   auto OnlyUsedAsSelectCond = [](SDValue Cond) {
     for (SDNode::use_iterator UI = Cond->use_begin(), UE = Cond->use_end();
          UI != UE; ++UI)
@@ -46876,30 +47087,44 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
 
   // If either operand is a constant mask, then only the elements that aren't
   // zero are actually demanded by the other operand.
-  auto SimplifyUndemandedElts = [&](SDValue Op, SDValue OtherOp) {
+  auto GetDemandedMasks = [&](SDValue Op) {
     APInt UndefElts;
     SmallVector<APInt> EltBits;
     int NumElts = VT.getVectorNumElements();
     int EltSizeInBits = VT.getScalarSizeInBits();
-    if (!getTargetConstantBitsFromNode(Op, EltSizeInBits, UndefElts, EltBits))
-      return false;
-
-    APInt DemandedBits = APInt::getZero(EltSizeInBits);
-    APInt DemandedElts = APInt::getZero(NumElts);
-    for (int I = 0; I != NumElts; ++I)
-      if (!EltBits[I].isZero()) {
-        DemandedBits |= EltBits[I];
-        DemandedElts.setBit(I);
-      }
-
-    return TLI.SimplifyDemandedVectorElts(OtherOp, DemandedElts, DCI) ||
-           TLI.SimplifyDemandedBits(OtherOp, DemandedBits, DemandedElts, DCI);
+    APInt DemandedBits = APInt::getAllOnes(EltSizeInBits);
+    APInt DemandedElts = APInt::getAllOnes(NumElts);
+    if (getTargetConstantBitsFromNode(Op, EltSizeInBits, UndefElts,
+                                      EltBits)) {
+      DemandedBits.clearAllBits();
+      DemandedElts.clearAllBits();
+      for (int I = 0; I != NumElts; ++I)
+        if (!EltBits[I].isZero()) {
+          DemandedBits |= EltBits[I];
+          DemandedElts.setBit(I);
+        }
+    }
+    return std::make_pair(DemandedBits, DemandedElts);
   };
-  if (SimplifyUndemandedElts(N0, N1) || SimplifyUndemandedElts(N1, N0)) {
+  std::pair<APInt, APInt> Demand0 = GetDemandedMasks(N1);
+  std::pair<APInt, APInt> Demand1 = GetDemandedMasks(N0);
+
+  if (TLI.SimplifyDemandedVectorElts(N0, Demand0.second, DCI) ||
+      TLI.SimplifyDemandedVectorElts(N1, Demand1.second, DCI) ||
+      TLI.SimplifyDemandedBits(N0, Demand0.first, Demand0.second, DCI) ||
+      TLI.SimplifyDemandedBits(N1, Demand1.first, Demand1.second, DCI)) {
     if (N->getOpcode() != ISD::DELETED_NODE)
       DCI.AddToWorklist(N);
     return SDValue(N, 0);
   }
+
+  SDValue NewN0 = TLI.SimplifyMultipleUseDemandedBits(N0, Demand0.first,
+                                                      Demand0.second, DAG);
+  SDValue NewN1 = TLI.SimplifyMultipleUseDemandedBits(N1, Demand1.first,
+                                                      Demand1.second, DAG);
+  if (NewN0 || NewN1)
+    return DAG.getNode(ISD::AND, dl, VT, NewN0 ? NewN0 : N0,
+                       NewN1 ? NewN1 : N1);
 }
 
 // Attempt to combine a scalar bitmask AND with an extracted shuffle.
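
[Note on the hunk above] The rework rests on one observation: for and(x, C) with a constant mask C, the other operand only has to be correct in lanes where C is nonzero, and only in bit positions some lane of C can keep; non-constant operands demand everything. A small scalar sketch of deriving those masks, illustrative only (the struct and function are invented):

    #include <cstdint>

    struct Demanded {
      uint64_t bits; // union of mask bits any surviving lane may keep
      uint64_t elts; // one bit per lane that survives the AND
    };

    // Mirror of the GetDemandedMasks idea over 4 x i16 lanes.
    static Demanded demandedFromMask(const uint16_t mask[4], bool isConstant) {
      if (!isConstant) // unknown operand: demand all bits and all lanes
        return {0xFFFFull, 0xFull};
      Demanded d{0, 0};
      for (int i = 0; i != 4; ++i)
        if (mask[i] != 0) {
          d.bits |= mask[i];
          d.elts |= uint64_t(1) << i;
        }
      return d;
    }
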
@@ -47679,7 +47904,7 @@ static SDValue combineTruncateWithSat(SDValue In, EVT VT, const SDLoc &DL,
 
 /// This function detects the AVG pattern between vectors of unsigned i8/i16,
 /// which is c = (a + b + 1) / 2, and replace this operation with the efficient
-/// X86ISD::AVG instruction.
+/// ISD::AVGCEILU (AVG) instruction.
 static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
                                 const X86Subtarget &Subtarget,
                                 const SDLoc &DL) {
@@ -47742,7 +47967,7 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
 
   auto AVGBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
                        ArrayRef<SDValue> Ops) {
-    return DAG.getNode(X86ISD::AVG, DL, Ops[0].getValueType(), Ops);
+    return DAG.getNode(ISD::AVGCEILU, DL, Ops[0].getValueType(), Ops);
   };
 
   auto AVGSplitter = [&](std::array<SDValue, 2> Ops) {
@@ -50113,26 +50338,62 @@ static SDValue combineCVTP2I_CVTTP2I(SDNode *N, SelectionDAG &DAG,
 static SDValue combineAndnp(SDNode *N, SelectionDAG &DAG,
                             TargetLowering::DAGCombinerInfo &DCI,
                             const X86Subtarget &Subtarget) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
   MVT VT = N->getSimpleValueType(0);
 
   // ANDNP(0, x) -> x
-  if (ISD::isBuildVectorAllZeros(N->getOperand(0).getNode()))
-    return N->getOperand(1);
+  if (ISD::isBuildVectorAllZeros(N0.getNode()))
+    return N1;
 
   // ANDNP(x, 0) -> 0
-  if (ISD::isBuildVectorAllZeros(N->getOperand(1).getNode()))
+  if (ISD::isBuildVectorAllZeros(N1.getNode()))
     return DAG.getConstant(0, SDLoc(N), VT);
 
   // Turn ANDNP back to AND if input is inverted.
-  if (SDValue Not = IsNOT(N->getOperand(0), DAG))
-    return DAG.getNode(ISD::AND, SDLoc(N), VT, DAG.getBitcast(VT, Not),
-                       N->getOperand(1));
+  if (SDValue Not = IsNOT(N0, DAG))
+    return DAG.getNode(ISD::AND, SDLoc(N), VT, DAG.getBitcast(VT, Not), N1);
 
   // Attempt to recursively combine a bitmask ANDNP with shuffles.
   if (VT.isVector() && (VT.getScalarSizeInBits() % 8) == 0) {
     SDValue Op(N, 0);
     if (SDValue Res = combineX86ShufflesRecursively(Op, DAG, Subtarget))
       return Res;
+
+    // If either operand is a constant mask, then only the elements that aren't
+    // zero are actually demanded by the other operand.
+    auto GetDemandedMasks = [&](SDValue Op, bool Invert = false) {
+      APInt UndefElts;
+      SmallVector<APInt> EltBits;
+      int NumElts = VT.getVectorNumElements();
+      int EltSizeInBits = VT.getScalarSizeInBits();
+      APInt DemandedBits = APInt::getAllOnes(EltSizeInBits);
+      APInt DemandedElts = APInt::getAllOnes(NumElts);
+      if (getTargetConstantBitsFromNode(Op, EltSizeInBits, UndefElts,
+                                        EltBits)) {
+        DemandedBits.clearAllBits();
+        DemandedElts.clearAllBits();
+        for (int I = 0; I != NumElts; ++I)
+          if ((Invert && !EltBits[I].isAllOnes()) ||
+              (!Invert && !EltBits[I].isZero())) {
+            DemandedBits |= Invert ? ~EltBits[I] : EltBits[I];
+            DemandedElts.setBit(I);
+          }
+      }
+      return std::make_pair(DemandedBits, DemandedElts);
+    };
+    std::pair<APInt, APInt> Demand0 = GetDemandedMasks(N1);
+    std::pair<APInt, APInt> Demand1 = GetDemandedMasks(N0, true);
+
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+    if (TLI.SimplifyDemandedVectorElts(N0, Demand0.second, DCI) ||
+        TLI.SimplifyDemandedVectorElts(N1, Demand1.second, DCI) ||
+        TLI.SimplifyDemandedBits(N0, Demand0.first, Demand0.second, DCI) ||
+        TLI.SimplifyDemandedBits(N1, Demand1.first, Demand1.second, DCI)) {
+      if (N->getOpcode() != ISD::DELETED_NODE)
+        DCI.AddToWorklist(N);
+      return SDValue(N, 0);
+    }
   }
 
   return SDValue();
@@ -50420,110 +50681,6 @@ static SDValue combineToExtendCMOV(SDNode *Extend, SelectionDAG &DAG) {
   return Res;
 }
 
-// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
-// This is more or less the reverse of combineBitcastvxi1.
-static SDValue
-combineToExtendBoolVectorInReg(SDNode *N, SelectionDAG &DAG,
-                               TargetLowering::DAGCombinerInfo &DCI,
-                               const X86Subtarget &Subtarget) {
-  unsigned Opcode = N->getOpcode();
-  if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND &&
-      Opcode != ISD::ANY_EXTEND)
-    return SDValue();
-  if (!DCI.isBeforeLegalizeOps())
-    return SDValue();
-  if (!Subtarget.hasSSE2() || Subtarget.hasAVX512())
-    return SDValue();
-
-  SDValue N0 = N->getOperand(0);
-  EVT VT = N->getValueType(0);
-  EVT SVT = VT.getScalarType();
-  EVT InSVT = N0.getValueType().getScalarType();
-  unsigned EltSizeInBits = SVT.getSizeInBits();
-
-  // Input type must be extending a bool vector (bit-casted from a scalar
-  // integer) to legal integer types.
-  if (!VT.isVector())
-    return SDValue();
-  if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16 && SVT != MVT::i8)
-    return SDValue();
-  if (InSVT != MVT::i1 || N0.getOpcode() != ISD::BITCAST)
-    return SDValue();
-
-  SDValue N00 = N0.getOperand(0);
-  EVT SclVT = N0.getOperand(0).getValueType();
-  if (!SclVT.isScalarInteger())
-    return SDValue();
-
-  SDLoc DL(N);
-  SDValue Vec;
-  SmallVector<int, 32> ShuffleMask;
-  unsigned NumElts = VT.getVectorNumElements();
-  assert(NumElts == SclVT.getSizeInBits() && "Unexpected bool vector size");
-
-  // Broadcast the scalar integer to the vector elements.
-  if (NumElts > EltSizeInBits) {
-    // If the scalar integer is greater than the vector element size, then we
-    // must split it down into sub-sections for broadcasting. For example:
-    //   i16 -> v16i8 (i16 -> v8i16 -> v16i8) with 2 sub-sections.
-    //   i32 -> v32i8 (i32 -> v8i32 -> v32i8) with 4 sub-sections.
-    assert((NumElts % EltSizeInBits) == 0 && "Unexpected integer scale");
-    unsigned Scale = NumElts / EltSizeInBits;
-    EVT BroadcastVT =
-        EVT::getVectorVT(*DAG.getContext(), SclVT, EltSizeInBits);
-    Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
-    Vec = DAG.getBitcast(VT, Vec);
-
-    for (unsigned i = 0; i != Scale; ++i)
-      ShuffleMask.append(EltSizeInBits, i);
-    Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
-  } else if (Subtarget.hasAVX2() && NumElts < EltSizeInBits &&
-             (SclVT == MVT::i8 || SclVT == MVT::i16 || SclVT == MVT::i32)) {
-    // If we have register broadcast instructions, use the scalar size as the
-    // element type for the shuffle. Then cast to the wider element type. The
-    // widened bits won't be used, and this might allow the use of a broadcast
-    // load.
-    assert((EltSizeInBits % NumElts) == 0 && "Unexpected integer scale");
-    unsigned Scale = EltSizeInBits / NumElts;
-    EVT BroadcastVT =
-        EVT::getVectorVT(*DAG.getContext(), SclVT, NumElts * Scale);
-    Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
-    ShuffleMask.append(NumElts * Scale, 0);
-    Vec = DAG.getVectorShuffle(BroadcastVT, DL, Vec, Vec, ShuffleMask);
-    Vec = DAG.getBitcast(VT, Vec);
-  } else {
-    // For smaller scalar integers, we can simply any-extend it to the vector
-    // element size (we don't care about the upper bits) and broadcast it to all
-    // elements.
-    SDValue Scl = DAG.getAnyExtOrTrunc(N00, DL, SVT);
-    Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Scl);
-    ShuffleMask.append(NumElts, 0);
-    Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
-  }
-
-  // Now, mask the relevant bit in each element.
-  SmallVector<SDValue, 32> Bits;
-  for (unsigned i = 0; i != NumElts; ++i) {
-    int BitIdx = (i % EltSizeInBits);
-    APInt Bit = APInt::getBitsSet(EltSizeInBits, BitIdx, BitIdx + 1);
-    Bits.push_back(DAG.getConstant(Bit, DL, SVT));
-  }
-  SDValue BitMask = DAG.getBuildVector(VT, DL, Bits);
-  Vec = DAG.getNode(ISD::AND, DL, VT, Vec, BitMask);
-
-  // Compare against the bitmask and extend the result.
-  EVT CCVT = VT.changeVectorElementType(MVT::i1);
-  Vec = DAG.getSetCC(DL, CCVT, Vec, BitMask, ISD::SETEQ);
-  Vec = DAG.getSExtOrTrunc(Vec, DL, VT);
-
-  // For SEXT, this is now done, otherwise shift the result down for
-  // zero-extension.
-  if (Opcode == ISD::SIGN_EXTEND)
-    return Vec;
-  return DAG.getNode(ISD::SRL, DL, VT, Vec,
-                     DAG.getConstant(EltSizeInBits - 1, DL, VT));
-}
-
 // Attempt to combine a (sext/zext (setcc)) to a setcc with a xmm/ymm/zmm
 // result type.
 static SDValue combineExtSetcc(SDNode *N, SelectionDAG &DAG,
@@ -50603,7 +50760,8 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineExtSetcc(N, DAG, Subtarget))
     return V;
 
-  if (SDValue V = combineToExtendBoolVectorInReg(N, DAG, DCI, Subtarget))
+  if (SDValue V = combineToExtendBoolVectorInReg(N->getOpcode(), DL, VT, N0,
+                                                 DAG, DCI, Subtarget))
     return V;
 
   if (VT.isVector()) {
@@ -50757,7 +50915,8 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineExtSetcc(N, DAG, Subtarget))
     return V;
 
-  if (SDValue V = combineToExtendBoolVectorInReg(N, DAG, DCI, Subtarget))
+  if (SDValue V = combineToExtendBoolVectorInReg(N->getOpcode(), dl, VT, N0,
+                                                 DAG, DCI, Subtarget))
     return V;
 
   if (VT.isVector())