diff options
author | Diego Caballero <diegocaballero@google.com> | 2022-03-10 20:44:24 +0000 |
---|---|---|
committer | Diego Caballero <diegocaballero@google.com> | 2022-03-10 22:33:14 +0000 |
commit | f71f9958b9845878909e005c67970e48b300f991 (patch) | |
tree | a76b5edd0ac690363b2ebdc80cfce1a6fad1a407 | |
parent | 3c9e8499435a4ecd197dcae23e9de6da914057d2 (diff) |
[mlir][Vector] Modernize default lowering of vector transpose
This patch removes an old recursive implementation to lower vector.transpose to extract/insert operations
and replaces it with a iterative approach that leverages newer linearization/delinearization utilities.
The patch should be NFC except by the order in which the extract/insert ops are generated.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D121321
8 files changed, 63 insertions, 68 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index e53b25cc22f2..4f991588d1d4 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -32,19 +32,6 @@ class LinalgDependenceGraph; /// `[0, permutation.size())`. bool isPermutation(ArrayRef<int64_t> permutation); -/// Apply the permutation defined by `permutation` to `inVec`. -/// Element `i` in `inVec` is mapped to location `j = permutation[i]`. -/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector -/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`. -template <typename T, unsigned N> -void applyPermutationToVector(SmallVector<T, N> &inVec, - ArrayRef<int64_t> permutation) { - SmallVector<T, N> auxVec(inVec.size()); - for (const auto &en : enumerate(permutation)) - auxVec[en.index()] = inVec[en.value()]; - inVec = auxVec; -} - /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim); diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h index 4678ce064815..3f2dd00c696f 100644 --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -30,6 +30,19 @@ int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis); SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> strides, int64_t linearIndex); +/// Apply the permutation defined by `permutation` to `inVec`. +/// Element `i` in `inVec` is mapped to location `j = permutation[i]`. +/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector +/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`. +template <typename T, unsigned N> +void applyPermutationToVector(SmallVector<T, N> &inVec, + ArrayRef<int64_t> permutation) { + SmallVector<T, N> auxVec(inVec.size()); + for (const auto &en : enumerate(permutation)) + auxVec[en.index()] = inVec[en.value()]; + inVec = auxVec; +} + /// Helper that returns a subset of `arrayAttr` as a vector of int64_t. SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0, diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index 4297a83005fe..1d46657018b3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index f1cd988d1fd7..907ffa3be4b9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/AsmState.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp index 452bf9d30ee4..4ce38530fe1e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 2e1418c529a2..4f863298ba42 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/Transforms/FoldUtils.h" diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 2b22412d6fc3..ab353326093a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/ImplicitLocOpBuilder.h" @@ -300,16 +301,18 @@ public: } }; -/// Return the number of leftmost dimensions from the first rightmost transposed -/// dimension found in 'transpose'. -size_t getNumDimsFromFirstTransposedDim(ArrayRef<int64_t> transpose) { +/// Given a 'transpose' pattern, prune the rightmost dimensions that are not +/// transposed. +void pruneNonTransposedDims(ArrayRef<int64_t> transpose, + SmallVectorImpl<int64_t> &result) { size_t numTransposedDims = transpose.size(); for (size_t transpDim : llvm::reverse(transpose)) { if (transpDim != numTransposedDims - 1) break; numTransposedDims--; } - return numTransposedDims; + + result.append(transpose.begin(), transpose.begin() + numTransposedDims); } /// Progressive lowering of TransposeOp. @@ -334,6 +337,8 @@ public: PatternRewriter &rewriter) const override { auto loc = op.getLoc(); + Value input = op.vector(); + VectorType inputType = op.getVectorType(); VectorType resType = op.getResultType(); // Set up convenience transposition table. @@ -354,7 +359,7 @@ public: Type flattenedType = VectorType::get(resType.getNumElements(), resType.getElementType()); auto matrix = - rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector()); + rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input); auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); Value trans = rewriter.create<vector::FlatTransposeOp>( @@ -365,54 +370,40 @@ public: // Generate unrolled extract/insert ops. We do not unroll the rightmost // (i.e., highest-order) dimensions that are not transposed and leave them - // in vector form to improve performance. - size_t numLeftmostTransposedDims = getNumDimsFromFirstTransposedDim(transp); - - // The type of the extract operation will be scalar if all the dimensions - // are unrolled. Otherwise, it will be a vector with the shape of the - // dimensions that are not transposed. - Type extractType = - numLeftmostTransposedDims == transp.size() - ? resType.getElementType() - : VectorType::Builder(resType).setShape( - resType.getShape().drop_front(numLeftmostTransposedDims)); - + // in vector form to improve performance. Therefore, we prune those + // dimensions from the shape/transpose data structures used to generate the + // extract/insert ops. + SmallVector<int64_t, 4> prunedTransp; + pruneNonTransposedDims(transp, prunedTransp); + size_t numPrunedDims = transp.size() - prunedTransp.size(); + auto prunedInShape = inputType.getShape().drop_back(numPrunedDims); + SmallVector<int64_t, 4> ones(prunedInShape.size(), 1); + auto prunedInStrides = computeStrides(prunedInShape, ones); + + // Generates the extract/insert operations for every scalar/vector element + // of the leftmost transposed dimensions. We traverse every transpose + // element using a linearized index that we delinearize to generate the + // appropriate indices for the extract/insert operations. Value result = rewriter.create<arith::ConstantOp>( loc, resType, rewriter.getZeroAttr(resType)); - SmallVector<int64_t, 4> lhs(numLeftmostTransposedDims, 0); - SmallVector<int64_t, 4> rhs(numLeftmostTransposedDims, 0); - rewriter.replaceOp(op, expandIndices(loc, resType, extractType, 0, - numLeftmostTransposedDims, transp, lhs, - rhs, op.vector(), result, rewriter)); + int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape); + + for (int64_t linearIdx = 0; linearIdx < numTransposedElements; + ++linearIdx) { + auto extractIdxs = delinearize(prunedInStrides, linearIdx); + SmallVector<int64_t, 4> insertIdxs(extractIdxs); + applyPermutationToVector(insertIdxs, prunedTransp); + Value extractOp = + rewriter.create<vector::ExtractOp>(loc, input, extractIdxs); + result = + rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs); + } + + rewriter.replaceOp(op, result); return success(); } private: - // Builds the indices arrays for the lhs and rhs. Generates the extract/insert - // operations when all the ranks go over the last dimension being transposed. - Value expandIndices(Location loc, VectorType resType, Type extractType, - int64_t pos, int64_t numLeftmostTransposedDims, - SmallVector<int64_t, 4> &transp, - SmallVector<int64_t, 4> &lhs, - SmallVector<int64_t, 4> &rhs, Value input, Value result, - PatternRewriter &rewriter) const { - if (pos >= numLeftmostTransposedDims) { - auto ridx = rewriter.getI64ArrayAttr(rhs); - auto lidx = rewriter.getI64ArrayAttr(lhs); - Value e = - rewriter.create<vector::ExtractOp>(loc, extractType, input, ridx); - return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx); - } - for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) { - lhs[pos] = d; - rhs[transp[pos]] = d; - result = expandIndices(loc, resType, extractType, pos + 1, - numLeftmostTransposedDims, transp, lhs, rhs, input, - result, rewriter); - } - return result; - } - /// Options to control the vector patterns. vector::VectorTransformsOptions vectorTransformOptions; }; diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir index 4087eab4a586..245e40cda8ee 100644 --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -8,14 +8,14 @@ // ELTWISE: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> // ELTWISE: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32> // ELTWISE: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32> -// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32> -// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32> -// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32> -// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32> -// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32> -// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32> -// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32> -// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32> +// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32> +// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : f32 into vector<3x2xf32> +// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32> +// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2, 0] : f32 into vector<3x2xf32> +// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32> +// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [0, 1] : f32 into vector<3x2xf32> +// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32> +// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<3x2xf32> // ELTWISE: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32> // ELTWISE: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32> // ELTWISE: return %[[T11]] : vector<3x2xf32> |