summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Diego Caballero <diegocaballero@google.com> 2022-03-10 20:44:24 +0000
committer Diego Caballero <diegocaballero@google.com> 2022-03-10 22:33:14 +0000
commit f71f9958b9845878909e005c67970e48b300f991 (patch)
tree a76b5edd0ac690363b2ebdc80cfce1a6fad1a407
parent 3c9e8499435a4ecd197dcae23e9de6da914057d2 (diff)
[mlir][Vector] Modernize default lowering of vector transpose
This patch removes an old recursive implementation to lower vector.transpose to extract/insert operations and replaces it with an iterative approach that leverages newer linearization/delinearization utilities. The patch should be NFC except for the order in which the extract/insert ops are generated. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D121321
-rw-r--r-- mlir/include/mlir/Dialect/Linalg/Utils/Utils.h 13
-rw-r--r-- mlir/include/mlir/Dialect/Utils/IndexingUtils.h 13
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp 1
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp 1
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp 1
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp 1
-rw-r--r-- mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp 85
-rw-r--r-- mlir/test/Dialect/Vector/vector-transpose-lowering.mlir 16
8 files changed, 63 insertions, 68 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index e53b25cc22f2..4f991588d1d4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -32,19 +32,6 @@ class LinalgDependenceGraph;
/// `[0, permutation.size())`.
bool isPermutation(ArrayRef<int64_t> permutation);
-/// Apply the permutation defined by `permutation` to `inVec`.
-/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
-/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
-/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
-template <typename T, unsigned N>
-void applyPermutationToVector(SmallVector<T, N> &inVec,
- ArrayRef<int64_t> permutation) {
- SmallVector<T, N> auxVec(inVec.size());
- for (const auto &en : enumerate(permutation))
- auxVec[en.index()] = inVec[en.value()];
- inVec = auxVec;
-}
-
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 4678ce064815..3f2dd00c696f 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -30,6 +30,19 @@ int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> strides,
int64_t linearIndex);
+/// Apply the permutation defined by `permutation` to `inVec`.
+/// Location `i` of the result receives element `permutation[i]` of `inVec`.
+/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
+/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
+template <typename T, unsigned N>
+void applyPermutationToVector(SmallVector<T, N> &inVec,
+ ArrayRef<int64_t> permutation) {
+ SmallVector<T, N> auxVec(inVec.size());
+ for (const auto &en : enumerate(permutation))
+ auxVec[en.index()] = inVec[en.value()];
+ inVec = auxVec;
+}
+
/// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
unsigned dropFront = 0,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 4297a83005fe..1d46657018b3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LLVM.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index f1cd988d1fd7..907ffa3be4b9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AsmState.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index 452bf9d30ee4..4ce38530fe1e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 2e1418c529a2..4f863298ba42 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 2b22412d6fc3..ab353326093a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -300,16 +301,18 @@ public:
}
};
-/// Return the number of leftmost dimensions from the first rightmost transposed
-/// dimension found in 'transpose'.
-size_t getNumDimsFromFirstTransposedDim(ArrayRef<int64_t> transpose) {
+/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
+/// transposed.
+void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
+ SmallVectorImpl<int64_t> &result) {
size_t numTransposedDims = transpose.size();
for (size_t transpDim : llvm::reverse(transpose)) {
if (transpDim != numTransposedDims - 1)
break;
numTransposedDims--;
}
- return numTransposedDims;
+
+ result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}
/// Progressive lowering of TransposeOp.
@@ -334,6 +337,8 @@ public:
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
+ Value input = op.vector();
+ VectorType inputType = op.getVectorType();
VectorType resType = op.getResultType();
// Set up convenience transposition table.
@@ -354,7 +359,7 @@ public:
Type flattenedType =
VectorType::get(resType.getNumElements(), resType.getElementType());
auto matrix =
- rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
+ rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
Value trans = rewriter.create<vector::FlatTransposeOp>(
@@ -365,54 +370,40 @@ public:
// Generate unrolled extract/insert ops. We do not unroll the rightmost
// (i.e., highest-order) dimensions that are not transposed and leave them
- // in vector form to improve performance.
- size_t numLeftmostTransposedDims = getNumDimsFromFirstTransposedDim(transp);
-
- // The type of the extract operation will be scalar if all the dimensions
- // are unrolled. Otherwise, it will be a vector with the shape of the
- // dimensions that are not transposed.
- Type extractType =
- numLeftmostTransposedDims == transp.size()
- ? resType.getElementType()
- : VectorType::Builder(resType).setShape(
- resType.getShape().drop_front(numLeftmostTransposedDims));
-
+ // in vector form to improve performance. Therefore, we prune those
+ // dimensions from the shape/transpose data structures used to generate the
+ // extract/insert ops.
+ SmallVector<int64_t, 4> prunedTransp;
+ pruneNonTransposedDims(transp, prunedTransp);
+ size_t numPrunedDims = transp.size() - prunedTransp.size();
+ auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
+ SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
+ auto prunedInStrides = computeStrides(prunedInShape, ones);
+
+ // Generates the extract/insert operations for every scalar/vector element
+ // of the leftmost transposed dimensions. We traverse every transpose
+ // element using a linearized index that we delinearize to generate the
+ // appropriate indices for the extract/insert operations.
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
- SmallVector<int64_t, 4> lhs(numLeftmostTransposedDims, 0);
- SmallVector<int64_t, 4> rhs(numLeftmostTransposedDims, 0);
- rewriter.replaceOp(op, expandIndices(loc, resType, extractType, 0,
- numLeftmostTransposedDims, transp, lhs,
- rhs, op.vector(), result, rewriter));
+ int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
+
+ for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
+ ++linearIdx) {
+ auto extractIdxs = delinearize(prunedInStrides, linearIdx);
+ SmallVector<int64_t, 4> insertIdxs(extractIdxs);
+ applyPermutationToVector(insertIdxs, prunedTransp);
+ Value extractOp =
+ rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
+ result =
+ rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
+ }
+
+ rewriter.replaceOp(op, result);
return success();
}
private:
- // Builds the indices arrays for the lhs and rhs. Generates the extract/insert
- // operations when all the ranks go over the last dimension being transposed.
- Value expandIndices(Location loc, VectorType resType, Type extractType,
- int64_t pos, int64_t numLeftmostTransposedDims,
- SmallVector<int64_t, 4> &transp,
- SmallVector<int64_t, 4> &lhs,
- SmallVector<int64_t, 4> &rhs, Value input, Value result,
- PatternRewriter &rewriter) const {
- if (pos >= numLeftmostTransposedDims) {
- auto ridx = rewriter.getI64ArrayAttr(rhs);
- auto lidx = rewriter.getI64ArrayAttr(lhs);
- Value e =
- rewriter.create<vector::ExtractOp>(loc, extractType, input, ridx);
- return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
- }
- for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
- lhs[pos] = d;
- rhs[transp[pos]] = d;
- result = expandIndices(loc, resType, extractType, pos + 1,
- numLeftmostTransposedDims, transp, lhs, rhs, input,
- result, rewriter);
- }
- return result;
- }
-
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
};
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 4087eab4a586..245e40cda8ee 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -8,14 +8,14 @@
// ELTWISE: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// ELTWISE: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
// ELTWISE: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
-// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
-// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
-// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
-// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
-// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
-// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
-// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
-// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
+// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
+// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : f32 into vector<3x2xf32>
+// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
+// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2, 0] : f32 into vector<3x2xf32>
+// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
+// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [0, 1] : f32 into vector<3x2xf32>
+// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
+// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<3x2xf32>
// ELTWISE: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
// ELTWISE: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
// ELTWISE: return %[[T11]] : vector<3x2xf32>