diff options
author | Christian Sigg <csigg@google.com> | 2022-01-11 18:00:44 +0100 |
---|---|---|
committer | Christian Sigg <csigg@google.com> | 2022-01-12 20:56:40 +0100 |
commit | be1aeb818cd9d4f329428a035604bebdd0c2f6e1 (patch) | |
tree | 8c0ddfd4b71a71661516695b4e1b650033f54f92 | |
parent | bf9c8636f2cd1c5e6338402b67de06f9ce74cdd9 (diff) |
Remove NaN constant from arith.minf, arith.maxf expansion
If any of the operands is NaN, return the operand instead of a new constant.
When the rhs operand is a constant, the second arith.cmpf+select ops will be folded away.
https://reviews.llvm.org/D117010 marks the two ops commutative, which will place the constant on the rhs.
Reviewed By: herhut
Differential Revision: https://reviews.llvm.org/D117011
-rw-r--r-- | mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp | 19 | ||||
-rw-r--r-- | mlir/test/Dialect/Arithmetic/expand-ops.mlir | 22 |
2 files changed, 17 insertions, 24 deletions
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp index d06c3043664d..d836ae5c84f5 100644 --- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp @@ -156,19 +156,16 @@ public: Value rhs = op.getRhs(); Location loc = op.getLoc(); + // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs'). + static_assert(pred == arith::CmpFPredicate::UGT || + pred == arith::CmpFPredicate::ULT); Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs); Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs); - auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>(); + // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'. Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO, - lhs, rhs); - - Value nan = rewriter.create<arith::ConstantFloatOp>( - loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType); - if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>()) - nan = rewriter.create<SplatOp>(loc, vectorType, nan); - - rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select); + rhs, rhs); + rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, rhs, select); return success(); } }; @@ -226,8 +223,8 @@ void mlir::arith::populateArithmeticExpandOpsPatterns( CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter, - MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::OGT>, - MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::OLT>, + MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>, + MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>, MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>, MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>, MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>, diff --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir index 2f14178e88f2..f4a557a02b20 100644 --- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir +++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir @@ -154,11 +154,10 @@ func @maxf(%a: f32, %b: f32) -> f32 { return %result : f32 } // CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32 -// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32 +// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32 // CHECK-NEXT: return %[[RESULT]] : f32 // ----- @@ -169,12 +168,10 @@ func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> { return %result : vector<4xf16> } // CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : vector<4xf16> +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : vector<4xf16> // CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : vector<4xf16> -// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7E00 : f16 -// CHECK-NEXT: %[[SPLAT_NAN:.*]] = splat %[[NAN]] : vector<4xf16> -// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[SPLAT_NAN]], %[[SELECT]] +// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : vector<4xf16> +// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]] // CHECK-NEXT: return %[[RESULT]] : vector<4xf16> // ----- @@ -185,11 +182,10 @@ func @minf(%a: f32, %b: f32) -> f32 { return %result : f32 } // CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[LHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ult, %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32 -// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32 +// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32 // CHECK-NEXT: return %[[RESULT]] : f32 |