summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp')
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp24
1 files changed, 11 insertions, 13 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 7604d14eb7d1..86ef1210c747 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -15,7 +15,7 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/TypeUtilities.h"
-#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallBitVector.h"
using namespace mlir;
using namespace mlir::linalg;
@@ -484,15 +484,15 @@ SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor
: public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
- HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions)
- : positions(positions) {}
+ HasAffineDimExprVisitor(llvm::SmallBitVector positions)
+ : positions(std::move(positions)) {}
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
}
bool visitDimExpr(AffineDimExpr dimExpr) {
- return positions.count(dimExpr.getPosition());
+ return positions.test(dimExpr.getPosition());
}
bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
@@ -500,7 +500,7 @@ struct HasAffineDimExprVisitor
bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
private:
- llvm::SmallSet<unsigned, 4> positions;
+ llvm::SmallBitVector positions;
};
LogicalResult
@@ -523,19 +523,17 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
/// From loopsToShapesMap extract the submap that represents the shape of the
/// (resultIdx, dim) needed.
- SmallVector<unsigned, 4> resultPosRange =
- llvm::to_vector<4>(llvm::seq<unsigned>(resultShapesSubMapPos.first,
- resultShapesSubMapPos.second));
- AffineMap loopToResultsShapeMap = loopsToShapesMap.getSubMap(resultPosRange);
+ AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
+ resultShapesSubMapPos.first,
+ resultShapesSubMapPos.second - resultShapesSubMapPos.first);
AffineMap resultShapesFromInputShapesMap =
loopToResultsShapeMap.compose(getShapesToLoopsMap());
// Check that the result dim map does not contain the positions corresponding
// to the outputs.
- llvm::SmallSet<unsigned, 4> outputDims;
- llvm::for_each(resultPosRange,
- [&outputDims](unsigned dim) { outputDims.insert(dim); });
- HasAffineDimExprVisitor checkDimExpr(outputDims);
+ llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
+ outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
+ HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
Location loc = getOperation()->getLoc();
auto allResultDimValues =
applyMapToValues(b, loc, resultShapesFromInputShapesMap,