Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/X86/X86ISelLowering.cpp | 483
1 file changed, 321 insertions(+), 162 deletions(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 53c00affd70e..b98ac635e00d 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -949,6 +949,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::MULHU, MVT::v8i16, Legal);
setOperationAction(ISD::MULHS, MVT::v8i16, Legal);
setOperationAction(ISD::MUL, MVT::v8i16, Legal);
+ setOperationAction(ISD::AVGCEILU, MVT::v16i8, Legal);
+ setOperationAction(ISD::AVGCEILU, MVT::v8i16, Legal);
setOperationAction(ISD::SMULO, MVT::v16i8, Custom);
setOperationAction(ISD::UMULO, MVT::v16i8, Custom);
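Note: ISD::AVGCEILU is the generic rounding unsigned-average node, ceil((a + b) / 2), which maps directly onto the x86 PAVGB/PAVGW instructions; marking it Legal here lets instruction selection pick them without a dedicated X86 node. A minimal scalar sketch of the lane-wise semantics (illustrative, not LLVM code):

    #include <cstdint>
    // Rounding unsigned average as PAVGB performs it per i8 lane; the sum
    // is formed in a wider type, so the +1 cannot overflow.
    static inline uint8_t avgceilu8(uint8_t a, uint8_t b) {
      return (uint8_t)(((unsigned)a + (unsigned)b + 1u) >> 1);
    }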
@@ -1285,13 +1287,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
if (VT == MVT::v4i64) continue;
setOperationAction(ISD::ROTL, VT, Custom);
setOperationAction(ISD::ROTR, VT, Custom);
+ setOperationAction(ISD::FSHL, VT, Custom);
+ setOperationAction(ISD::FSHR, VT, Custom);
}
- setOperationAction(ISD::FSHL, MVT::v32i8, Custom);
- setOperationAction(ISD::FSHR, MVT::v32i8, Custom);
- setOperationAction(ISD::FSHL, MVT::v8i32, Custom);
- setOperationAction(ISD::FSHR, MVT::v8i32, Custom);
-
// These types need custom splitting if their input is a 128-bit vector.
setOperationAction(ISD::SIGN_EXTEND, MVT::v8i64, Custom);
setOperationAction(ISD::SIGN_EXTEND, MVT::v16i32, Custom);
@@ -1353,6 +1352,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::MULHS, MVT::v16i16, HasInt256 ? Legal : Custom);
setOperationAction(ISD::MULHU, MVT::v32i8, Custom);
setOperationAction(ISD::MULHS, MVT::v32i8, Custom);
+ setOperationAction(ISD::AVGCEILU, MVT::v16i16, HasInt256 ? Legal : Custom);
+ setOperationAction(ISD::AVGCEILU, MVT::v32i8, HasInt256 ? Legal : Custom);
setOperationAction(ISD::SMULO, MVT::v32i8, Custom);
setOperationAction(ISD::UMULO, MVT::v32i8, Custom);
@@ -1652,6 +1653,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::MULHU, MVT::v32i16, HasBWI ? Legal : Custom);
setOperationAction(ISD::MULHS, MVT::v64i8, Custom);
setOperationAction(ISD::MULHU, MVT::v64i8, Custom);
+ setOperationAction(ISD::AVGCEILU, MVT::v32i16, HasBWI ? Legal : Custom);
+ setOperationAction(ISD::AVGCEILU, MVT::v64i8, HasBWI ? Legal : Custom);
setOperationAction(ISD::SMULO, MVT::v64i8, Custom);
setOperationAction(ISD::UMULO, MVT::v64i8, Custom);
@@ -25700,6 +25703,89 @@ static SDValue getTargetVShiftByConstNode(unsigned Opc, const SDLoc &dl, MVT VT,
DAG.getTargetConstant(ShiftAmt, dl, MVT::i8));
}
+/// Handle vector element shifts by a splat shift amount
+static SDValue getTargetVShiftNode(unsigned Opc, const SDLoc &dl, MVT VT,
+ SDValue SrcOp, SDValue ShAmt, int ShAmtIdx,
+ const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ MVT AmtVT = ShAmt.getSimpleValueType();
+ assert(AmtVT.isVector() && "Vector shift type mismatch");
+ assert(0 <= ShAmtIdx && ShAmtIdx < (int)AmtVT.getVectorNumElements() &&
+ "Illegal vector splat index");
+
+ // Move the splat element to the bottom element.
+ if (ShAmtIdx != 0) {
+ SmallVector<int> Mask(AmtVT.getVectorNumElements(), -1);
+ Mask[0] = ShAmtIdx;
+ ShAmt = DAG.getVectorShuffle(AmtVT, dl, ShAmt, DAG.getUNDEF(AmtVT), Mask);
+ }
+
+ // See if we can mask off the upper elements using the existing source node.
+ // The shift uses the entire lower 64-bits of the amount vector, so no need to
+ // do this for vXi64 types.
+ bool IsMasked = false;
+ if (AmtVT.getScalarSizeInBits() < 64) {
+ if (ShAmt.getOpcode() == ISD::BUILD_VECTOR ||
+ ShAmt.getOpcode() == ISD::SCALAR_TO_VECTOR) {
+ // If the shift amount has come from a scalar, then zero-extend the scalar
+ // before moving to the vector.
+ ShAmt = DAG.getZExtOrTrunc(ShAmt.getOperand(0), dl, MVT::i32);
+ ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4i32, ShAmt);
+ ShAmt = DAG.getNode(X86ISD::VZEXT_MOVL, dl, MVT::v4i32, ShAmt);
+ AmtVT = MVT::v4i32;
+ IsMasked = true;
+ } else if (ShAmt.getOpcode() == ISD::AND) {
+ // See if the shift amount is already masked (e.g. for rotation modulo),
+ // then we can zero-extend it by setting all the other mask elements to
+ // zero.
+ SmallVector<SDValue> MaskElts(
+ AmtVT.getVectorNumElements(),
+ DAG.getConstant(0, dl, AmtVT.getScalarType()));
+ MaskElts[0] = DAG.getAllOnesConstant(dl, AmtVT.getScalarType());
+ SDValue Mask = DAG.getBuildVector(AmtVT, dl, MaskElts);
+ if ((Mask = DAG.FoldConstantArithmetic(ISD::AND, dl, AmtVT,
+ {ShAmt.getOperand(1), Mask}))) {
+ ShAmt = DAG.getNode(ISD::AND, dl, AmtVT, ShAmt.getOperand(0), Mask);
+ IsMasked = true;
+ }
+ }
+ }
+
+ // Extract if the shift amount vector is larger than 128-bits.
+ if (AmtVT.getSizeInBits() > 128) {
+ ShAmt = extract128BitVector(ShAmt, 0, DAG, dl);
+ AmtVT = ShAmt.getSimpleValueType();
+ }
+
+ // Zero-extend bottom element to v2i64 vector type, either by extension or
+ // shuffle masking.
+ if (!IsMasked && AmtVT.getScalarSizeInBits() < 64) {
+ if (Subtarget.hasSSE41())
+ ShAmt = DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, SDLoc(ShAmt),
+ MVT::v2i64, ShAmt);
+ else {
+ SDValue ByteShift = DAG.getTargetConstant(
+ (128 - AmtVT.getScalarSizeInBits()) / 8, SDLoc(ShAmt), MVT::i8);
+ ShAmt = DAG.getBitcast(MVT::v16i8, ShAmt);
+ ShAmt = DAG.getNode(X86ISD::VSHLDQ, SDLoc(ShAmt), MVT::v16i8, ShAmt,
+ ByteShift);
+ ShAmt = DAG.getNode(X86ISD::VSRLDQ, SDLoc(ShAmt), MVT::v16i8, ShAmt,
+ ByteShift);
+ }
+ }
+
+ // Change opcode to non-immediate version.
+ Opc = getTargetVShiftUniformOpcode(Opc, true);
+
+ // The return type has to be a 128-bit type with the same element
+ // type as the input type.
+ MVT EltVT = VT.getVectorElementType();
+ MVT ShVT = MVT::getVectorVT(EltVT, 128 / EltVT.getSizeInBits());
+
+ ShAmt = DAG.getBitcast(ShVT, ShAmt);
+ return DAG.getNode(Opc, dl, VT, SrcOp, ShAmt);
+}
+
/// Handle vector element shifts where the shift amount may or may not be a
/// constant. Takes immediate version of shift as input.
/// TODO: Replace with vector + (splat) idx to avoid extract_element nodes.
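Note on the pre-SSE4.1 path above: without PMOVZX, the bottom shift-amount element is zero-extended by a PSLLDQ/PSRLDQ byte-shift pair. For a v8i16 amount, ByteShift = (128 - 16) / 8 = 14, so shifting the 128-bit register left by 14 bytes and back right by 14 bytes clears everything except the low two bytes, leaving the amount zero-extended in the low 64 bits that the non-immediate VSHL/VSRL forms read. A hedged little-endian scalar model of that trick:

    #include <cstdint>
    #include <cstring>
    // Illustrative only: model the byte-shift pair that isolates the low
    // element of a 16-byte register, zero-extended into the low 64 bits.
    static uint64_t isolateLowElt(const uint8_t Reg[16], unsigned EltBytes) {
      uint8_t Tmp[16] = {0};
      std::memcpy(Tmp, Reg, EltBytes); // only these bytes survive <<14, >>14
      uint64_t Lo = 0;
      std::memcpy(&Lo, Tmp, sizeof(Lo));
      return Lo;
    }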
@@ -26444,6 +26530,8 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
case VSHIFT: {
SDValue SrcOp = Op.getOperand(1);
SDValue ShAmt = Op.getOperand(2);
+ assert(ShAmt.getValueType() == MVT::i32 &&
+ "Unexpected VSHIFT amount type");
// Catch shift-by-constant.
if (auto *CShAmt = dyn_cast<ConstantSDNode>(ShAmt))
@@ -26451,8 +26539,9 @@ SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
Op.getSimpleValueType(), SrcOp,
CShAmt->getZExtValue(), DAG);
+ ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4i32, ShAmt);
return getTargetVShiftNode(IntrData->Opc0, dl, Op.getSimpleValueType(),
- SrcOp, ShAmt, Subtarget, DAG);
+ SrcOp, ShAmt, 0, Subtarget, DAG);
}
case COMPRESS_EXPAND_IN_REG: {
SDValue Mask = Op.getOperand(3);
@@ -28394,6 +28483,21 @@ static SDValue LowerABS(SDValue Op, const X86Subtarget &Subtarget,
return SDValue();
}
+static SDValue LowerAVG(SDValue Op, const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ MVT VT = Op.getSimpleValueType();
+
+ // For AVX1 cases, split to use legal ops (everything but v4i64).
+ if (VT.is256BitVector() && !Subtarget.hasInt256())
+ return splitVectorIntBinary(Op, DAG);
+
+ if (VT == MVT::v32i16 || VT == MVT::v64i8)
+ return splitVectorIntBinary(Op, DAG);
+
+ // Default to expand.
+ return SDValue();
+}
+
static SDValue LowerMINMAX(SDValue Op, SelectionDAG &DAG) {
MVT VT = Op.getSimpleValueType();
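Note: splitVectorIntBinary halves a too-wide vector op so each half lands on a legal instruction; on AVX1, for example, a v32i8 AVGCEILU becomes two 128-bit PAVGB ops. What the split amounts to, sketched with intrinsics (an illustration of the resulting codegen, not the helper itself):

    #include <immintrin.h>
    // Requires AVX; the per-half average is plain SSE2 _mm_avg_epu8.
    static __m256i avg32u8_avx1(__m256i A, __m256i B) {
      __m128i Lo = _mm_avg_epu8(_mm256_castsi256_si128(A),
                                _mm256_castsi256_si128(B));
      __m128i Hi = _mm_avg_epu8(_mm256_extractf128_si256(A, 1),
                                _mm256_extractf128_si256(B, 1));
      return _mm256_set_m128i(Hi, Lo);
    }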
@@ -29843,8 +29947,8 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
{Op0, Op1, Amt}, DAG, Subtarget);
}
assert((VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8 ||
- VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v8i32 ||
- VT == MVT::v16i32) &&
+ VT == MVT::v8i16 || VT == MVT::v16i16 || VT == MVT::v4i32 ||
+ VT == MVT::v8i32 || VT == MVT::v16i32) &&
"Unexpected funnel shift type!");
// fshl(x,y,z) -> (unpack(y,x) << (z & (bw-1))) >> bw.
@@ -29867,7 +29971,7 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
// Split 256-bit integers on XOP/pre-AVX2 targets.
// Split 512-bit integers on non 512-bit BWI targets.
- if ((VT.is256BitVector() && ((Subtarget.hasXOP() && EltSizeInBits < 32) ||
+ if ((VT.is256BitVector() && ((Subtarget.hasXOP() && EltSizeInBits < 16) ||
!Subtarget.hasAVX2())) ||
(VT.is512BitVector() && !Subtarget.useBWIRegs() &&
EltSizeInBits < 32)) {
@@ -29878,18 +29982,18 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
// Attempt to fold scalar shift as unpack(y,x) << zext(splat(z))
if (supportedVectorShiftWithBaseAmnt(ExtVT, Subtarget, ShiftOpc)) {
- if (SDValue ScalarAmt = DAG.getSplatValue(AmtMod)) {
+ int ScalarAmtIdx = -1;
+ if (SDValue ScalarAmt = DAG.getSplatSourceVector(AmtMod, ScalarAmtIdx)) {
// Uniform vXi16 funnel shifts can be efficiently handled by default.
if (EltSizeInBits == 16)
return SDValue();
SDValue Lo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, Op1, Op0));
SDValue Hi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, Op1, Op0));
- ScalarAmt = DAG.getZExtOrTrunc(ScalarAmt, DL, MVT::i32);
- Lo = getTargetVShiftNode(ShiftOpc, DL, ExtVT, Lo, ScalarAmt, Subtarget,
- DAG);
- Hi = getTargetVShiftNode(ShiftOpc, DL, ExtVT, Hi, ScalarAmt, Subtarget,
- DAG);
+ Lo = getTargetVShiftNode(ShiftOpc, DL, ExtVT, Lo, ScalarAmt,
+ ScalarAmtIdx, Subtarget, DAG);
+ Hi = getTargetVShiftNode(ShiftOpc, DL, ExtVT, Hi, ScalarAmt,
+ ScalarAmtIdx, Subtarget, DAG);
return getPack(DAG, Subtarget, DL, VT, Lo, Hi, !IsFSHR);
}
}
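Note: the unpack-based lowering referenced above widens each lane to 2*bw bits by interleaving y and x, performs a single plain shift, and packs back the relevant half. In scalar form for i8 lanes (a sketch of the idea, not the DAG code):

    #include <cstdint>
    // fshl(x, y, z): concatenate x:y into 16 bits, shift left by z & 7,
    // and keep the high 8 bits - one shift replaces the shl/srl/or trio.
    static uint8_t fshl8(uint8_t x, uint8_t y, unsigned z) {
      unsigned Wide = ((unsigned)x << 8) | y; // unpack(y, x)
      return (uint8_t)((Wide << (z & 7)) >> 8);
    }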
@@ -30082,15 +30186,15 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
// TODO: Handle vXi16 cases on all targets.
if (EltSizeInBits == 8 || EltSizeInBits == 32 ||
(IsROTL && EltSizeInBits == 16 && !Subtarget.hasAVX())) {
- if (SDValue BaseRotAmt = DAG.getSplatValue(AmtMod)) {
+ int BaseRotAmtIdx = -1;
+ if (SDValue BaseRotAmt = DAG.getSplatSourceVector(AmtMod, BaseRotAmtIdx)) {
unsigned ShiftX86Opc = IsROTL ? X86ISD::VSHLI : X86ISD::VSRLI;
SDValue Lo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, R, R));
SDValue Hi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, R, R));
- BaseRotAmt = DAG.getZExtOrTrunc(BaseRotAmt, DL, MVT::i32);
Lo = getTargetVShiftNode(ShiftX86Opc, DL, ExtVT, Lo, BaseRotAmt,
- Subtarget, DAG);
+ BaseRotAmtIdx, Subtarget, DAG);
Hi = getTargetVShiftNode(ShiftX86Opc, DL, ExtVT, Hi, BaseRotAmt,
- Subtarget, DAG);
+ BaseRotAmtIdx, Subtarget, DAG);
return getPack(DAG, Subtarget, DL, VT, Lo, Hi, IsROTL);
}
}
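Note: rotates reuse the same unpack trick with the register interleaved with itself, so the bits that wrap around come from the duplicate copy:

    #include <cstdint>
    // rotl8 via the doubled lane r:r - bits shifted past the top of the
    // high half are refilled from the low copy.
    static uint8_t rotl8(uint8_t r, unsigned z) {
      unsigned Wide = ((unsigned)r << 8) | r; // unpack(r, r)
      return (uint8_t)((Wide << (z & 7)) >> 8);
    }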
@@ -31712,6 +31816,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::UMAX:
case ISD::UMIN: return LowerMINMAX(Op, DAG);
case ISD::ABS: return LowerABS(Op, Subtarget, DAG);
+ case ISD::AVGCEILU: return LowerAVG(Op, Subtarget, DAG);
case ISD::FSINCOS: return LowerFSINCOS(Op, Subtarget, DAG);
case ISD::MLOAD: return LowerMLOAD(Op, Subtarget, DAG);
case ISD::MSTORE: return LowerMSTORE(Op, Subtarget, DAG);
@@ -31807,9 +31912,8 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(Res);
return;
}
- case X86ISD::VPMADDWD:
- case X86ISD::AVG: {
- // Legalize types for X86ISD::AVG/VPMADDWD by widening.
+ case X86ISD::VPMADDWD: {
+ // Legalize types for X86ISD::VPMADDWD by widening.
assert(Subtarget.hasSSE2() && "Requires at least SSE2!");
EVT VT = N->getValueType(0);
@@ -33041,7 +33145,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(SCALEF_RND)
NODE_NAME_CASE(SCALEFS)
NODE_NAME_CASE(SCALEFS_RND)
- NODE_NAME_CASE(AVG)
NODE_NAME_CASE(MULHRS)
NODE_NAME_CASE(SINT_TO_FP_RND)
NODE_NAME_CASE(UINT_TO_FP_RND)
@@ -33222,7 +33325,6 @@ bool X86TargetLowering::isBinOp(unsigned Opcode) const {
bool X86TargetLowering::isCommutativeBinOp(unsigned Opcode) const {
switch (Opcode) {
// TODO: Add more X86ISD opcodes once we have test coverage.
- case X86ISD::AVG:
case X86ISD::PCMPEQ:
case X86ISD::PMULDQ:
case X86ISD::PMULUDQ:
@@ -40632,7 +40734,6 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
case X86ISD::UNPCKH:
case X86ISD::BLENDI:
// Integer ops.
- case X86ISD::AVG:
case X86ISD::PACKSS:
case X86ISD::PACKUS:
// Horizontal Ops.
@@ -43123,6 +43224,104 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
+// This is more or less the reverse of combineBitcastvxi1.
+static SDValue combineToExtendBoolVectorInReg(
+ unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N0, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) {
+ if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND &&
+ Opcode != ISD::ANY_EXTEND)
+ return SDValue();
+ if (!DCI.isBeforeLegalizeOps())
+ return SDValue();
+ if (!Subtarget.hasSSE2() || Subtarget.hasAVX512())
+ return SDValue();
+
+ EVT SVT = VT.getScalarType();
+ EVT InSVT = N0.getValueType().getScalarType();
+ unsigned EltSizeInBits = SVT.getSizeInBits();
+
+ // Input type must be extending a bool vector (bit-casted from a scalar
+ // integer) to legal integer types.
+ if (!VT.isVector())
+ return SDValue();
+ if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16 && SVT != MVT::i8)
+ return SDValue();
+ if (InSVT != MVT::i1 || N0.getOpcode() != ISD::BITCAST)
+ return SDValue();
+
+ SDValue N00 = N0.getOperand(0);
+ EVT SclVT = N00.getValueType();
+ if (!SclVT.isScalarInteger())
+ return SDValue();
+
+ SDValue Vec;
+ SmallVector<int> ShuffleMask;
+ unsigned NumElts = VT.getVectorNumElements();
+ assert(NumElts == SclVT.getSizeInBits() && "Unexpected bool vector size");
+
+ // Broadcast the scalar integer to the vector elements.
+ if (NumElts > EltSizeInBits) {
+ // If the scalar integer is greater than the vector element size, then we
+ // must split it down into sub-sections for broadcasting. For example:
+ // i16 -> v16i8 (i16 -> v8i16 -> v16i8) with 2 sub-sections.
+ // i32 -> v32i8 (i32 -> v8i32 -> v32i8) with 4 sub-sections.
+ assert((NumElts % EltSizeInBits) == 0 && "Unexpected integer scale");
+ unsigned Scale = NumElts / EltSizeInBits;
+ EVT BroadcastVT = EVT::getVectorVT(*DAG.getContext(), SclVT, EltSizeInBits);
+ Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
+ Vec = DAG.getBitcast(VT, Vec);
+
+ for (unsigned i = 0; i != Scale; ++i)
+ ShuffleMask.append(EltSizeInBits, i);
+ Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
+ } else if (Subtarget.hasAVX2() && NumElts < EltSizeInBits &&
+ (SclVT == MVT::i8 || SclVT == MVT::i16 || SclVT == MVT::i32)) {
+ // If we have register broadcast instructions, use the scalar size as the
+ // element type for the shuffle. Then cast to the wider element type. The
+ // widened bits won't be used, and this might allow the use of a broadcast
+ // load.
+ assert((EltSizeInBits % NumElts) == 0 && "Unexpected integer scale");
+ unsigned Scale = EltSizeInBits / NumElts;
+ EVT BroadcastVT =
+ EVT::getVectorVT(*DAG.getContext(), SclVT, NumElts * Scale);
+ Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
+ ShuffleMask.append(NumElts * Scale, 0);
+ Vec = DAG.getVectorShuffle(BroadcastVT, DL, Vec, Vec, ShuffleMask);
+ Vec = DAG.getBitcast(VT, Vec);
+ } else {
+ // For smaller scalar integers, we can simply any-extend it to the vector
+ // element size (we don't care about the upper bits) and broadcast it to all
+ // elements.
+ SDValue Scl = DAG.getAnyExtOrTrunc(N00, DL, SVT);
+ Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Scl);
+ ShuffleMask.append(NumElts, 0);
+ Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
+ }
+
+ // Now, mask the relevant bit in each element.
+ SmallVector<SDValue, 32> Bits;
+ for (unsigned i = 0; i != NumElts; ++i) {
+ int BitIdx = (i % EltSizeInBits);
+ APInt Bit = APInt::getBitsSet(EltSizeInBits, BitIdx, BitIdx + 1);
+ Bits.push_back(DAG.getConstant(Bit, DL, SVT));
+ }
+ SDValue BitMask = DAG.getBuildVector(VT, DL, Bits);
+ Vec = DAG.getNode(ISD::AND, DL, VT, Vec, BitMask);
+
+ // Compare against the bitmask and extend the result.
+ EVT CCVT = VT.changeVectorElementType(MVT::i1);
+ Vec = DAG.getSetCC(DL, CCVT, Vec, BitMask, ISD::SETEQ);
+ Vec = DAG.getSExtOrTrunc(Vec, DL, VT);
+
+ // For SEXT, this is now done, otherwise shift the result down for
+ // zero-extension.
+ if (Opcode == ISD::SIGN_EXTEND)
+ return Vec;
+ return DAG.getNode(ISD::SRL, DL, VT, Vec,
+ DAG.getConstant(EltSizeInBits - 1, DL, VT));
+}
+
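Note: the transform above materializes an iX scalar "bool vector" as a real integer vector in four steps: broadcast the scalar to every lane, AND each lane with its own bit, compare for equality, and sign-extend the i1 result (with an extra SRL by EltSize-1 for zero-extension). A scalar model for the i8 -> v8i32 case (illustrative):

    #include <cstdint>
    // Lane I of the result is bit I of Mask, sign-extended to 32 bits.
    static void extendBool8(uint8_t Mask, int32_t Out[8]) {
      for (int I = 0; I != 8; ++I) {
        uint32_t Bit = 1u << I;              // per-lane bit of the mask
        uint32_t And = (uint32_t)Mask & Bit; // broadcast + AND
        Out[I] = (And == Bit) ? -1 : 0;      // SETEQ, then sign-extend
      }
    }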
/// If a vector select has an operand that is -1 or 0, try to simplify the
/// select to a bitwise logic operation.
/// TODO: Move to DAGCombiner, possibly using TargetLowering::hasAndNot()?
@@ -43340,19 +43539,17 @@ static SDValue combineSelectOfTwoConstants(SDNode *N, SelectionDAG &DAG) {
/// This function will also call SimplifyDemandedBits on already created
/// BLENDV to perform additional simplifications.
static SDValue combineVSelectToBLENDV(SDNode *N, SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI,
- const X86Subtarget &Subtarget) {
+ TargetLowering::DAGCombinerInfo &DCI,
+ const X86Subtarget &Subtarget) {
SDValue Cond = N->getOperand(0);
if ((N->getOpcode() != ISD::VSELECT &&
N->getOpcode() != X86ISD::BLENDV) ||
ISD::isBuildVectorOfConstantSDNodes(Cond.getNode()))
return SDValue();
- // Don't optimize before the condition has been transformed to a legal type
- // and don't ever optimize vector selects that map to AVX512 mask-registers.
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
unsigned BitWidth = Cond.getScalarValueSizeInBits();
- if (BitWidth < 8 || BitWidth > 64)
- return SDValue();
+ EVT VT = N->getValueType(0);
// We can only handle the cases where VSELECT is directly legal on the
// subtarget. We custom lower VSELECT nodes with constant conditions and
@@ -43364,8 +43561,6 @@ static SDValue combineVSelectToBLENDV(SDNode *N, SelectionDAG &DAG,
// Potentially, we should combine constant-condition vselect nodes
// pre-legalization into shuffles and not mark as many types as custom
// lowered.
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- EVT VT = N->getValueType(0);
if (!TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
return SDValue();
// FIXME: We don't support i16-element blends currently. We could and
@@ -43383,6 +43578,22 @@ static SDValue combineVSelectToBLENDV(SDNode *N, SelectionDAG &DAG,
if (VT.is512BitVector())
return SDValue();
+ // PreAVX512, without mask-registers, attempt to sign-extend bool vectors to
+ // allow us to use BLENDV.
+ if (!Subtarget.hasAVX512() && BitWidth == 1) {
+ EVT CondVT = VT.changeVectorElementTypeToInteger();
+ if (SDValue ExtCond = combineToExtendBoolVectorInReg(
+ ISD::SIGN_EXTEND, SDLoc(N), CondVT, Cond, DAG, DCI, Subtarget)) {
+ return DAG.getNode(X86ISD::BLENDV, SDLoc(N), VT, ExtCond,
+ N->getOperand(1), N->getOperand(2));
+ }
+ }
+
+ // Don't optimize before the condition has been transformed to a legal type
+ // and don't ever optimize vector selects that map to AVX512 mask-registers.
+ if (BitWidth < 8 || BitWidth > 64)
+ return SDValue();
+
auto OnlyUsedAsSelectCond = [](SDValue Cond) {
for (SDNode::use_iterator UI = Cond->use_begin(), UE = Cond->use_end();
UI != UE; ++UI)
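Note: BLENDV selects per lane on the sign bit of the condition, which is why sign-extending a vXi1 condition into all-ones/all-zeros lanes is sufficient pre-AVX512. A hedged intrinsics illustration of that contract (SSE4.1):

    #include <immintrin.h>
    // Lanes of Cond must be 0 or -1; pblendvb picks its second operand
    // wherever the mask byte's top bit is set.
    static __m128i select_epi32(__m128i Cond, __m128i A, __m128i B) {
      return _mm_blendv_epi8(B, A, Cond); // Cond lane set -> A, else B
    }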
@@ -46876,30 +47087,44 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
// If either operand is a constant mask, then only the elements that aren't
// zero are actually demanded by the other operand.
- auto SimplifyUndemandedElts = [&](SDValue Op, SDValue OtherOp) {
+ auto GetDemandedMasks = [&](SDValue Op) {
APInt UndefElts;
SmallVector<APInt> EltBits;
int NumElts = VT.getVectorNumElements();
int EltSizeInBits = VT.getScalarSizeInBits();
- if (!getTargetConstantBitsFromNode(Op, EltSizeInBits, UndefElts, EltBits))
- return false;
-
- APInt DemandedBits = APInt::getZero(EltSizeInBits);
- APInt DemandedElts = APInt::getZero(NumElts);
- for (int I = 0; I != NumElts; ++I)
- if (!EltBits[I].isZero()) {
- DemandedBits |= EltBits[I];
- DemandedElts.setBit(I);
- }
-
- return TLI.SimplifyDemandedVectorElts(OtherOp, DemandedElts, DCI) ||
- TLI.SimplifyDemandedBits(OtherOp, DemandedBits, DemandedElts, DCI);
+ APInt DemandedBits = APInt::getAllOnes(EltSizeInBits);
+ APInt DemandedElts = APInt::getAllOnes(NumElts);
+ if (getTargetConstantBitsFromNode(Op, EltSizeInBits, UndefElts,
+ EltBits)) {
+ DemandedBits.clearAllBits();
+ DemandedElts.clearAllBits();
+ for (int I = 0; I != NumElts; ++I)
+ if (!EltBits[I].isZero()) {
+ DemandedBits |= EltBits[I];
+ DemandedElts.setBit(I);
+ }
+ }
+ return std::make_pair(DemandedBits, DemandedElts);
};
- if (SimplifyUndemandedElts(N0, N1) || SimplifyUndemandedElts(N1, N0)) {
+ std::pair<APInt, APInt> Demand0 = GetDemandedMasks(N1);
+ std::pair<APInt, APInt> Demand1 = GetDemandedMasks(N0);
+
+ if (TLI.SimplifyDemandedVectorElts(N0, Demand0.second, DCI) ||
+ TLI.SimplifyDemandedVectorElts(N1, Demand1.second, DCI) ||
+ TLI.SimplifyDemandedBits(N0, Demand0.first, Demand0.second, DCI) ||
+ TLI.SimplifyDemandedBits(N1, Demand1.first, Demand1.second, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
return SDValue(N, 0);
}
+
+ SDValue NewN0 = TLI.SimplifyMultipleUseDemandedBits(N0, Demand0.first,
+ Demand0.second, DAG);
+ SDValue NewN1 = TLI.SimplifyMultipleUseDemandedBits(N1, Demand1.first,
+ Demand1.second, DAG);
+ if (NewN0 || NewN1)
+ return DAG.getNode(ISD::AND, dl, VT, NewN0 ? NewN0 : N0,
+ NewN1 ? NewN1 : N1);
}
// Attempt to combine a scalar bitmask AND with an extracted shuffle.
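Note: the rewrite above computes, for each AND operand, which bits and elements the other side actually demands; a constant-mask element that is zero demands nothing. Returning the masks (instead of simplifying in place) also enables SimplifyMultipleUseDemandedBits, which can substitute a simpler operand even when the original value has other uses. A hypothetical scalar mirror of GetDemandedMasks:

    #include <cstdint>
    #include <utility>
    #include <vector>
    // Hypothetical helper mirroring GetDemandedMasks for i16 elements:
    // zero mask elements demand neither their bits nor their lane.
    static std::pair<uint16_t, uint32_t>
    demandedFromMask(const std::vector<uint16_t> &MaskElts) {
      uint16_t DemandedBits = 0;
      uint32_t DemandedElts = 0;
      for (unsigned I = 0; I != MaskElts.size(); ++I)
        if (MaskElts[I] != 0) {
          DemandedBits |= MaskElts[I];
          DemandedElts |= 1u << I;
        }
      return {DemandedBits, DemandedElts};
    }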
@@ -47679,7 +47904,7 @@ static SDValue combineTruncateWithSat(SDValue In, EVT VT, const SDLoc &DL,
/// This function detects the AVG pattern between vectors of unsigned i8/i16,
/// which is c = (a + b + 1) / 2, and replaces this operation with the efficient
-/// X86ISD::AVG instruction.
+/// ISD::AVGCEILU (AVG) instruction.
static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
const X86Subtarget &Subtarget,
const SDLoc &DL) {
@@ -47742,7 +47967,7 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
auto AVGBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
ArrayRef<SDValue> Ops) {
- return DAG.getNode(X86ISD::AVG, DL, Ops[0].getValueType(), Ops);
+ return DAG.getNode(ISD::AVGCEILU, DL, Ops[0].getValueType(), Ops);
};
auto AVGSplitter = [&](std::array<SDValue, 2> Ops) {
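Note: the pattern this combine recognizes computes the average in a wider type so the +1 carry is safe, then truncates; that is exactly the rounding average ISD::AVGCEILU/PAVG defines. In scalar form (sketch):

    #include <cstdint>
    // Source-level shape detectAVGPattern matches for i8:
    static uint8_t avgPattern(uint8_t a, uint8_t b) {
      uint16_t Wide = (uint16_t)a + (uint16_t)b + 1; // zext, add, add 1
      return (uint8_t)(Wide >> 1);                   // lshr, then trunc
    }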
@@ -50113,26 +50338,62 @@ static SDValue combineCVTP2I_CVTTP2I(SDNode *N, SelectionDAG &DAG,
static SDValue combineAndnp(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
MVT VT = N->getSimpleValueType(0);
// ANDNP(0, x) -> x
- if (ISD::isBuildVectorAllZeros(N->getOperand(0).getNode()))
- return N->getOperand(1);
+ if (ISD::isBuildVectorAllZeros(N0.getNode()))
+ return N1;
// ANDNP(x, 0) -> 0
- if (ISD::isBuildVectorAllZeros(N->getOperand(1).getNode()))
+ if (ISD::isBuildVectorAllZeros(N1.getNode()))
return DAG.getConstant(0, SDLoc(N), VT);
// Turn ANDNP back to AND if input is inverted.
- if (SDValue Not = IsNOT(N->getOperand(0), DAG))
- return DAG.getNode(ISD::AND, SDLoc(N), VT, DAG.getBitcast(VT, Not),
- N->getOperand(1));
+ if (SDValue Not = IsNOT(N0, DAG))
+ return DAG.getNode(ISD::AND, SDLoc(N), VT, DAG.getBitcast(VT, Not), N1);
// Attempt to recursively combine a bitmask ANDNP with shuffles.
if (VT.isVector() && (VT.getScalarSizeInBits() % 8) == 0) {
SDValue Op(N, 0);
if (SDValue Res = combineX86ShufflesRecursively(Op, DAG, Subtarget))
return Res;
+
+ // If either operand is a constant mask, then only the elements that aren't
+ // zero are actually demanded by the other operand.
+ auto GetDemandedMasks = [&](SDValue Op, bool Invert = false) {
+ APInt UndefElts;
+ SmallVector<APInt> EltBits;
+ int NumElts = VT.getVectorNumElements();
+ int EltSizeInBits = VT.getScalarSizeInBits();
+ APInt DemandedBits = APInt::getAllOnes(EltSizeInBits);
+ APInt DemandedElts = APInt::getAllOnes(NumElts);
+ if (getTargetConstantBitsFromNode(Op, EltSizeInBits, UndefElts,
+ EltBits)) {
+ DemandedBits.clearAllBits();
+ DemandedElts.clearAllBits();
+ for (int I = 0; I != NumElts; ++I)
+ if ((Invert && !EltBits[I].isAllOnes()) ||
+ (!Invert && !EltBits[I].isZero())) {
+ DemandedBits |= Invert ? ~EltBits[I] : EltBits[I];
+ DemandedElts.setBit(I);
+ }
+ }
+ return std::make_pair(DemandedBits, DemandedElts);
+ };
+ std::pair<APInt, APInt> Demand0 = GetDemandedMasks(N1);
+ std::pair<APInt, APInt> Demand1 = GetDemandedMasks(N0, true);
+
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (TLI.SimplifyDemandedVectorElts(N0, Demand0.second, DCI) ||
+ TLI.SimplifyDemandedVectorElts(N1, Demand1.second, DCI) ||
+ TLI.SimplifyDemandedBits(N0, Demand0.first, Demand0.second, DCI) ||
+ TLI.SimplifyDemandedBits(N1, Demand1.first, Demand1.second, DCI)) {
+ if (N->getOpcode() != ISD::DELETED_NODE)
+ DCI.AddToWorklist(N);
+ return SDValue(N, 0);
+ }
}
return SDValue();
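Note: ANDNP computes (~a & b), so the demanded-elements logic runs inverted for operand 0: a constant element that is all-ones (rather than zero) kills its lane, and the useful bits are those of ~EltBits. In scalar terms:

    #include <cstdint>
    // ANDNP lane: an all-ones element of A zeroes the lane exactly as a
    // zero element of a plain AND mask would.
    static uint16_t andnpLane(uint16_t A, uint16_t B) {
      return (uint16_t)(~A & B);
    }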
@@ -50420,110 +50681,6 @@ static SDValue combineToExtendCMOV(SDNode *Extend, SelectionDAG &DAG) {
return Res;
}
-// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
-// This is more or less the reverse of combineBitcastvxi1.
-static SDValue
-combineToExtendBoolVectorInReg(SDNode *N, SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI,
- const X86Subtarget &Subtarget) {
- unsigned Opcode = N->getOpcode();
- if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND &&
- Opcode != ISD::ANY_EXTEND)
- return SDValue();
- if (!DCI.isBeforeLegalizeOps())
- return SDValue();
- if (!Subtarget.hasSSE2() || Subtarget.hasAVX512())
- return SDValue();
-
- SDValue N0 = N->getOperand(0);
- EVT VT = N->getValueType(0);
- EVT SVT = VT.getScalarType();
- EVT InSVT = N0.getValueType().getScalarType();
- unsigned EltSizeInBits = SVT.getSizeInBits();
-
- // Input type must be extending a bool vector (bit-casted from a scalar
- // integer) to legal integer types.
- if (!VT.isVector())
- return SDValue();
- if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16 && SVT != MVT::i8)
- return SDValue();
- if (InSVT != MVT::i1 || N0.getOpcode() != ISD::BITCAST)
- return SDValue();
-
- SDValue N00 = N0.getOperand(0);
- EVT SclVT = N0.getOperand(0).getValueType();
- if (!SclVT.isScalarInteger())
- return SDValue();
-
- SDLoc DL(N);
- SDValue Vec;
- SmallVector<int, 32> ShuffleMask;
- unsigned NumElts = VT.getVectorNumElements();
- assert(NumElts == SclVT.getSizeInBits() && "Unexpected bool vector size");
-
- // Broadcast the scalar integer to the vector elements.
- if (NumElts > EltSizeInBits) {
- // If the scalar integer is greater than the vector element size, then we
- // must split it down into sub-sections for broadcasting. For example:
- // i16 -> v16i8 (i16 -> v8i16 -> v16i8) with 2 sub-sections.
- // i32 -> v32i8 (i32 -> v8i32 -> v32i8) with 4 sub-sections.
- assert((NumElts % EltSizeInBits) == 0 && "Unexpected integer scale");
- unsigned Scale = NumElts / EltSizeInBits;
- EVT BroadcastVT =
- EVT::getVectorVT(*DAG.getContext(), SclVT, EltSizeInBits);
- Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
- Vec = DAG.getBitcast(VT, Vec);
-
- for (unsigned i = 0; i != Scale; ++i)
- ShuffleMask.append(EltSizeInBits, i);
- Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
- } else if (Subtarget.hasAVX2() && NumElts < EltSizeInBits &&
- (SclVT == MVT::i8 || SclVT == MVT::i16 || SclVT == MVT::i32)) {
- // If we have register broadcast instructions, use the scalar size as the
- // element type for the shuffle. Then cast to the wider element type. The
- // widened bits won't be used, and this might allow the use of a broadcast
- // load.
- assert((EltSizeInBits % NumElts) == 0 && "Unexpected integer scale");
- unsigned Scale = EltSizeInBits / NumElts;
- EVT BroadcastVT =
- EVT::getVectorVT(*DAG.getContext(), SclVT, NumElts * Scale);
- Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
- ShuffleMask.append(NumElts * Scale, 0);
- Vec = DAG.getVectorShuffle(BroadcastVT, DL, Vec, Vec, ShuffleMask);
- Vec = DAG.getBitcast(VT, Vec);
- } else {
- // For smaller scalar integers, we can simply any-extend it to the vector
- // element size (we don't care about the upper bits) and broadcast it to all
- // elements.
- SDValue Scl = DAG.getAnyExtOrTrunc(N00, DL, SVT);
- Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Scl);
- ShuffleMask.append(NumElts, 0);
- Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
- }
-
- // Now, mask the relevant bit in each element.
- SmallVector<SDValue, 32> Bits;
- for (unsigned i = 0; i != NumElts; ++i) {
- int BitIdx = (i % EltSizeInBits);
- APInt Bit = APInt::getBitsSet(EltSizeInBits, BitIdx, BitIdx + 1);
- Bits.push_back(DAG.getConstant(Bit, DL, SVT));
- }
- SDValue BitMask = DAG.getBuildVector(VT, DL, Bits);
- Vec = DAG.getNode(ISD::AND, DL, VT, Vec, BitMask);
-
- // Compare against the bitmask and extend the result.
- EVT CCVT = VT.changeVectorElementType(MVT::i1);
- Vec = DAG.getSetCC(DL, CCVT, Vec, BitMask, ISD::SETEQ);
- Vec = DAG.getSExtOrTrunc(Vec, DL, VT);
-
- // For SEXT, this is now done, otherwise shift the result down for
- // zero-extension.
- if (Opcode == ISD::SIGN_EXTEND)
- return Vec;
- return DAG.getNode(ISD::SRL, DL, VT, Vec,
- DAG.getConstant(EltSizeInBits - 1, DL, VT));
-}
-
// Attempt to combine a (sext/zext (setcc)) to a setcc with a xmm/ymm/zmm
// result type.
static SDValue combineExtSetcc(SDNode *N, SelectionDAG &DAG,
@@ -50603,7 +50760,8 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG,
if (SDValue V = combineExtSetcc(N, DAG, Subtarget))
return V;
- if (SDValue V = combineToExtendBoolVectorInReg(N, DAG, DCI, Subtarget))
+ if (SDValue V = combineToExtendBoolVectorInReg(N->getOpcode(), DL, VT, N0,
+ DAG, DCI, Subtarget))
return V;
if (VT.isVector()) {
@@ -50757,7 +50915,8 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG,
if (SDValue V = combineExtSetcc(N, DAG, Subtarget))
return V;
- if (SDValue V = combineToExtendBoolVectorInReg(N, DAG, DCI, Subtarget))
+ if (SDValue V = combineToExtendBoolVectorInReg(N->getOpcode(), dl, VT, N0,
+ DAG, DCI, Subtarget))
return V;
if (VT.isVector())