summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAbhishek Varma <abhishek.varma@cerebras.net>2022-02-03 17:09:51 +0100
committerAlex Zinenko <zinenko@google.com>2022-02-03 17:13:25 +0100
commit59b23c4aecccd8a0687c71b3afc55fd233b935c2 (patch)
tree3d0b8ea44b9386dcdc30ac8aa77d2a6b7d42ba84
parent42fc05e09c38460d149e8097a3cb1e1f481e7ac2 (diff)
[MLIR][SCF] Remove loop invariant arguments of scf.while
-- This commit adds a canonicalization pattern on scf.while to remove the loop invariant arguments. -- An argument is considered loop invariant if the iteration argument value is the same as the corresponding one being yielded (at the same position) in both the before/after block of scf.while. -- For the arguments removed, their use within scf.while and their corresponding scf.while's result are replaced with their corresponding initial value. Signed-off-by: Abhishek Varma <abhishek.varma@polymagelabs.com> Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D116923
-rw-r--r--mlir/lib/Dialect/SCF/SCF.cpp296
-rw-r--r--mlir/test/Dialect/SCF/canonicalize.mlir68
2 files changed, 362 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 28c50ba6382f..0e2759fef3bb 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -2343,6 +2343,297 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
}
};
+/// Remove loop invariant arguments from `before` block of scf.while.
+/// A before block argument is considered loop invariant if :-
+/// 1. i-th yield operand is equal to the i-th while operand.
+/// 2. i-th yield operand is k-th after block argument which is (k+1)-th
+/// condition operand AND this (k+1)-th condition operand is equal to i-th
+/// iter argument/while operand.
+/// For the arguments which are removed, their uses inside scf.while
+/// are replaced with their corresponding initial value.
+///
+/// Eg:
+/// INPUT :-
+/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
+/// ..., %argN_before = %N)
+/// {
+/// ...
+/// scf.condition(%cond) %arg1_before, %arg0_before,
+/// %arg2_before, %arg0_before, ...
+/// } do {
+/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
+/// ..., %argK_after):
+/// ...
+/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
+/// }
+///
+/// OUTPUT :-
+/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
+/// %N)
+/// {
+/// ...
+/// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
+/// } do {
+/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
+/// ..., %argK_after):
+/// ...
+/// scf.yield %arg1_after, ..., %argN
+/// }
+///
+/// EXPLANATION:
+/// We iterate over each yield operand.
+/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
+/// %arg0_before, which in turn is the 0-th iter argument. So we
+/// remove 0-th before block argument and yield operand, and replace
+/// all uses of the 0-th before block argument with its initial value
+/// %a.
+/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
+/// value. So we remove this operand and the corresponding before
+/// block argument and replace all uses of 1-th before block argument
+/// with %b.
+struct RemoveLoopInvariantArgsFromBeforeBlock
+ : public OpRewritePattern<WhileOp> {
+ using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(WhileOp op,
+ PatternRewriter &rewriter) const override {
+ Block &afterBlock = op.getAfter().front();
+ Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
+ ConditionOp condOp = op.getConditionOp();
+ OperandRange condOpArgs = condOp.getArgs();
+ Operation *yieldOp = afterBlock.getTerminator();
+ ValueRange yieldOpArgs = yieldOp->getOperands();
+
+ bool canSimplify = false;
+ for (auto it : llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
+ auto index = static_cast<unsigned>(it.index());
+ Value initVal, yieldOpArg;
+ std::tie(initVal, yieldOpArg) = it.value();
+ // If i-th yield operand is equal to the i-th operand of the scf.while,
+ // the i-th before block argument is a loop invariant.
+ if (yieldOpArg == initVal) {
+ canSimplify = true;
+ break;
+ }
+ // If the i-th yield operand is k-th after block argument, then we check
+ // if the (k+1)-th condition op operand is equal to either the i-th before
+ // block argument or the initial value of i-th before block argument. If
+ // the comparison results `true`, i-th before block argument is a loop
+ // invariant.
+ auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
+ if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
+ Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
+ if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
+ canSimplify = true;
+ break;
+ }
+ }
+ }
+
+ if (!canSimplify)
+ return failure();
+
+ SmallVector<Value> newInitArgs, newYieldOpArgs;
+ DenseMap<unsigned, Value> beforeBlockInitValMap;
+ SmallVector<Location> newBeforeBlockArgLocs;
+ for (auto it : llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
+ auto index = static_cast<unsigned>(it.index());
+ Value initVal, yieldOpArg;
+ std::tie(initVal, yieldOpArg) = it.value();
+
+ // If i-th yield operand is equal to the i-th operand of the scf.while,
+ // the i-th before block argument is a loop invariant.
+ if (yieldOpArg == initVal) {
+ beforeBlockInitValMap.insert({index, initVal});
+ continue;
+ } else {
+ // If the i-th yield operand is k-th after block argument, then we check
+ // if the (k+1)-th condition op operand is equal to either the i-th
+ // before block argument or the initial value of i-th before block
+ // argument. If the comparison results `true`, i-th before block
+ // argument is a loop invariant.
+ auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
+ if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
+ Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
+ if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
+ beforeBlockInitValMap.insert({index, initVal});
+ continue;
+ }
+ }
+ }
+ newInitArgs.emplace_back(initVal);
+ newYieldOpArgs.emplace_back(yieldOpArg);
+ newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
+ }
+
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
+ }
+
+ auto newWhile =
+ rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
+
+ Block &newBeforeBlock = *rewriter.createBlock(
+ &newWhile.getBefore(), /*insertPt*/ {},
+ ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
+
+ Block &beforeBlock = op.getBefore().front();
+ SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
+ // For each i-th before block argument we find it's replacement value as :-
+ // 1. If i-th before block argument is a loop invariant, we fetch it's
+ // initial value from `beforeBlockInitValMap` by querying for key `i`.
+ // 2. Else we fetch j-th new before block argument as the replacement
+ // value of i-th before block argument.
+ for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
+ // If the index 'i' argument was a loop invariant we fetch it's initial
+ // value from `beforeBlockInitValMap`.
+ if (beforeBlockInitValMap.count(i) != 0)
+ newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
+ else
+ newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
+ }
+
+ rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
+ rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
+ newWhile.getAfter().begin());
+
+ rewriter.replaceOp(op, newWhile.getResults());
+ return success();
+ }
+};
+
+/// Remove loop invariant value from result (condition op) of scf.while.
+/// A value is considered loop invariant if the final value yielded by
+/// scf.condition is defined outside of the `before` block. We remove the
+/// corresponding argument in `after` block and replace the use with the value.
+/// We also replace the use of the corresponding result of scf.while with the
+/// value.
+///
+/// Eg:
+/// INPUT :-
+/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
+/// %argN_before = %N) {
+/// ...
+/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
+/// } do {
+/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
+/// ...
+/// some_func(%arg1_after)
+/// ...
+/// scf.yield %arg0_after, %arg2_after, ..., %argN_after
+/// }
+///
+/// OUTPUT :-
+/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
+/// ...
+/// scf.condition(%cond) %arg0, %arg1, ..., %argM
+/// } do {
+/// ^bb0(%arg0, %arg3, ..., %argM):
+/// ...
+/// some_func(%a)
+/// ...
+/// scf.yield %arg0, %b, ..., %argN
+/// }
+///
+/// EXPLANATION:
+/// 1. The 1-th and 2-th operand of scf.condition are defined outside the
+/// before block of scf.while, so they get removed.
+/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
+/// replaced by %b.
+/// 3. The corresponding after block argument %arg1_after's uses are
+/// replaced by %a and %arg2_after's uses are replaced by %b.
+struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
+ using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(WhileOp op,
+ PatternRewriter &rewriter) const override {
+ Block &beforeBlock = op.getBefore().front();
+ ConditionOp condOp = op.getConditionOp();
+ OperandRange condOpArgs = condOp.getArgs();
+
+ bool canSimplify = false;
+ for (Value condOpArg : condOpArgs) {
+ // Those values not defined within `before` block will be considered as
+ // loop invariant values. We map the corresponding `index` with their
+ // value.
+ if (condOpArg.getParentBlock() != &beforeBlock) {
+ canSimplify = true;
+ break;
+ }
+ }
+
+ if (!canSimplify)
+ return failure();
+
+ Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
+
+ SmallVector<Value> newCondOpArgs;
+ SmallVector<Type> newAfterBlockType;
+ DenseMap<unsigned, Value> condOpInitValMap;
+ SmallVector<Location> newAfterBlockArgLocs;
+ for (auto it : llvm::enumerate(condOpArgs)) {
+ auto index = static_cast<unsigned>(it.index());
+ Value condOpArg = it.value();
+ // Those values not defined within `before` block will be considered as
+ // loop invariant values. We map the corresponding `index` with their
+ // value.
+ if (condOpArg.getParentBlock() != &beforeBlock) {
+ condOpInitValMap.insert({index, condOpArg});
+ } else {
+ newCondOpArgs.emplace_back(condOpArg);
+ newAfterBlockType.emplace_back(condOpArg.getType());
+ newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
+ }
+ }
+
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(condOp);
+ rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
+ newCondOpArgs);
+ }
+
+ auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
+ op.getOperands());
+
+ Block &newAfterBlock =
+ *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
+ newAfterBlockType, newAfterBlockArgLocs);
+
+ Block &afterBlock = op.getAfter().front();
+ // Since a new scf.condition op was created, we need to fetch the new
+ // `after` block arguments which will be used while replacing operations of
+ // previous scf.while's `after` blocks. We'd also be fetching new result
+ // values too.
+ SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
+ SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
+ for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
+ Value afterBlockArg, result;
+ // If index 'i' argument was loop invariant we fetch it's value from the
+ // `condOpInitMap` map.
+ if (condOpInitValMap.count(i) != 0) {
+ afterBlockArg = condOpInitValMap[i];
+ result = afterBlockArg;
+ } else {
+ afterBlockArg = newAfterBlock.getArgument(j);
+ result = newWhile.getResult(j);
+ j++;
+ }
+ newAfterBlockArgs[i] = afterBlockArg;
+ newWhileResults[i] = result;
+ }
+
+ rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
+ rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
+ newWhile.getBefore().begin());
+
+ rewriter.replaceOp(op, newWhileResults);
+ return success();
+ }
+};
+
/// Remove WhileOp results that are also unused in 'after' block.
///
/// %0:2 = scf.while () : () -> (i32, i64) {
@@ -2552,8 +2843,9 @@ struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond,
- WhileUnusedArg>(context);
+ results.insert<RemoveLoopInvariantArgsFromBeforeBlock,
+ RemoveLoopInvariantValueYielded, WhileConditionTruth,
+ WhileCmpCond, WhileUnusedResult>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index a9ad591d31ef..1563349ac24b 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -870,6 +870,74 @@ func @while_unused_arg(%x : i32, %y : f64) -> i32 {
// -----
+// CHECK-LABEL: @invariant_loop_args_in_same_order
+// CHECK-SAME: (%[[FUNC_ARG0:.*]]: tensor<i32>)
+func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
+ %cst_0 = arith.constant dense<0> : tensor<i32>
+ %cst_1 = arith.constant dense<1> : tensor<i32>
+ %cst_42 = arith.constant dense<42> : tensor<i32>
+
+ %0:5 = scf.while (%arg0 = %cst_0, %arg1 = %f_arg0, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
+ %1 = arith.cmpi slt, %arg0, %cst_42 : tensor<i32>
+ %2 = tensor.extract %1[] : tensor<i1>
+ scf.condition(%2) %arg0, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+ } do {
+ ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
+ // %arg1 here will get replaced by %cst_1
+ %1 = arith.addi %arg0, %arg1 : tensor<i32>
+ %2 = arith.addi %arg2, %arg3 : tensor<i32>
+ scf.yield %1, %arg1, %2, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+ }
+ return %0#0, %0#1, %0#2, %0#3, %0#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+}
+// CHECK: %[[CST42:.*]] = arith.constant dense<42>
+// CHECK: %[[ONE:.*]] = arith.constant dense<1>
+// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
+// CHECK: %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]])
+// CHECK: arith.cmpi slt, %[[ARG0]], %{{.*}}
+// CHECK: tensor.extract %{{.*}}[]
+// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]], %[[ARG3]]
+// CHECK: } do {
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i32>):
+// CHECK: %[[VAL0:.*]] = arith.addi %[[ARG0]], %[[FUNC_ARG0]]
+// CHECK: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG3]]
+// CHECK: scf.yield %[[VAL0]], %[[VAL1]], %[[VAL1]]
+// CHECK: }
+// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]]
+
+// CHECK-LABEL: @while_loop_invariant_argument_different_order
+func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
+ %cst_0 = arith.constant dense<0> : tensor<i32>
+ %cst_1 = arith.constant dense<1> : tensor<i32>
+ %cst_42 = arith.constant dense<42> : tensor<i32>
+
+ %0:6 = scf.while (%arg0 = %cst_0, %arg1 = %cst_1, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
+ %1 = arith.cmpi slt, %arg0, %cst_42 : tensor<i32>
+ %2 = tensor.extract %1[] : tensor<i1>
+ scf.condition(%2) %arg1, %arg0, %arg2, %arg0, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+ } do {
+ ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>): // no predecessors
+ %1 = arith.addi %arg0, %cst_1 : tensor<i32>
+ %2 = arith.addi %arg2, %arg3 : tensor<i32>
+ scf.yield %arg3, %arg1, %2, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+ }
+ return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+}
+// CHECK: %[[CST42:.*]] = arith.constant dense<42>
+// CHECK: %[[ONE:.*]] = arith.constant dense<1>
+// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
+// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
+// CHECK: arith.cmpi slt, %[[ZERO]], %[[CST42]]
+// CHECK: tensor.extract %{{.*}}[]
+// CHECK: scf.condition(%{{.*}}) %[[ARG1]], %[[ARG4]]
+// CHECK: } do {
+// CHECK: ^{{.*}}(%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
+// CHECK: scf.yield %[[ZERO]], %[[ONE]]
+// CHECK: }
+// CHECK: return %[[WHILE]]#0, %[[ZERO]], %[[ONE]], %[[ZERO]], %[[ONE]], %[[WHILE]]#1
+
+// -----
+
// CHECK-LABEL: @while_unused_result
func @while_unused_result() -> i32 {
%0:2 = scf.while () : () -> (i32, i64) {