diff options
Diffstat (limited to 'mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp')
-rw-r--r-- | mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 261a107c0e56..ee5a34bcc83b 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -293,14 +293,13 @@ FailureOr<Value> BufferizationState::getBuffer( void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values) { + assert(values.size() == op->getNumResults() && + "expected one value per OpResult"); OpBuilder::InsertionGuard g(rewriter); // Replace all OpResults with the given values. + SmallVector<Value> replacements; for (OpResult opResult : op->getOpResults()) { - // Skip OpResult if it has no uses. - if (opResult.getUses().empty()) - continue; - Value replacement = values[opResult.getResultNumber()]; if (opResult.getType().isa<TensorType>()) { // The OpResult is a tensor. Such values are replaced with memrefs during @@ -315,10 +314,10 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, replacement = rewriter.create<bufferization::ToTensorOp>( replacement.getLoc(), replacement); } - opResult.replaceAllUsesWith(replacement); + replacements.push_back(replacement); } - rewriter.eraseOp(op); + rewriter.replaceOp(op, replacements); } AlwaysCopyBufferizationState::AlwaysCopyBufferizationState( |