summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp')
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp11
1 files changed, 5 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 261a107c0e56..ee5a34bcc83b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -293,14 +293,13 @@ FailureOr<Value> BufferizationState::getBuffer(
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
Operation *op,
ValueRange values) {
+ assert(values.size() == op->getNumResults() &&
+ "expected one value per OpResult");
OpBuilder::InsertionGuard g(rewriter);
// Replace all OpResults with the given values.
+ SmallVector<Value> replacements;
for (OpResult opResult : op->getOpResults()) {
- // Skip OpResult if it has no uses.
- if (opResult.getUses().empty())
- continue;
-
Value replacement = values[opResult.getResultNumber()];
if (opResult.getType().isa<TensorType>()) {
// The OpResult is a tensor. Such values are replaced with memrefs during
@@ -315,10 +314,10 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
replacement = rewriter.create<bufferization::ToTensorOp>(
replacement.getLoc(), replacement);
}
- opResult.replaceAllUsesWith(replacement);
+ replacements.push_back(replacement);
}
- rewriter.eraseOp(op);
+ rewriter.replaceOp(op, replacements);
}
AlwaysCopyBufferizationState::AlwaysCopyBufferizationState(