summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
blob: 4297a83005fe5ac51bcff5628474cb96f6e47195 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
//===- FusionOnTensors.cpp - Implementation of linalg Fusion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements linalg fusion on tensors
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace linalg;

//===----------------------------------------------------------------------===//
// StructuredOp specific helpers.
//===----------------------------------------------------------------------===//

/// Returns the tiled slice dimensions given the tiled consumer loop dimensions.
/// The slice defines a hyper rectangular iteration space and fusing the
/// producer is always possible. However, depending on the consumer indexing
/// map, not all slice elements may be consumed and the tiles may overlap. In
/// these cases, fusion introduces redundant computation.
static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
                                              ArrayRef<int64_t> tiledLoopDims) {
  // Get the consumer operand indexing map.
  LinalgOp consumerOp = consumerOperand->getOwner();
  AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand);

  // Search the slice dimensions tiled by a tile loop dimension.
  DenseSet<int64_t> tiledSliceDimIndices;
  for (const auto &en : enumerate(indexingMap.getResults())) {
    for (auto tiledLoopDim : tiledLoopDims) {
      if (en.value().isFunctionOfDim(tiledLoopDim))
        tiledSliceDimIndices.insert(en.index());
    }
  }
  return {tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()};
}

/// Given a vector of `tiledSliceDimIndices` that represent the tiled dimensions
/// of the producer result slice returns the tiled producer loop dimensions.
/// Example:
/// ```
/// %res = linalg.fill(%cst, %input)
/// scf.for %i
///   scf.for %j
///     %slice = tensor.extract_slice %res[%i, %j]
/// ```
/// getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1].
static SmallVector<int64_t>
getTiledProducerLoops(OpResult producerResult,
                      ArrayRef<int64_t> tiledSliceDimIndices) {
  LinalgOp producerOp = producerResult.getOwner();

  // Get the indexing map of the `producerOp` output operand that matches
  // ´producerResult´.
  AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
      producerOp.getOutputOperand(producerResult.getResultNumber()));

  // Keep only the tiled result slice dimensions of `producerIndexingMap`.
  AffineMap tiledProducerIndexingSubMap =
      producerIndexingMap.getSubMap(SmallVector<unsigned>(
          tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()));

  // Compute the producer loop indices mapped to the tiled result slice
  // dimensions. As the output indexing map of structured operations are
  // projected permutations, `tiledProducerIndexingSubMap` has to be a
  // projected permutation as well. We can thus obtain the producer loop indices
  // by getting the positions of the result dimensions.
  // Example:
  // (d0, d1, d2) -> (d0, d2) has the result positions [0, 2].
  assert(tiledProducerIndexingSubMap.isProjectedPermutation() &&
         "expect slice and producer loop dimensions map one-to-one");
  SmallVector<int64_t> tiledProducerLoopIndices;
  transform(llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
            std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
              return tiledProducerIndexingSubMap.getDimPosition(idx);
            });

  return tiledProducerLoopIndices;
}

/// Returns the producer fused in place of `sliceOp`. Tile the producer operands
/// along the `tiledSliceDimIndices` and clone the producer. Consider the case
/// of fusion of an output tensor:
/// ```
/// %1 = producer ins(...) outs(%0)
/// %2 = consumer ins(...) outs(%1)
/// ```
/// When consumer is tiled, %1 appears in the loop iter_args:
/// ```
/// %1 = producer ins(...) outs(%0)
/// %2 = scf.for ... iter_args(%1) .. (%bbarg) {
///   %t1 = tensor.extract_slice %bbarg[..]
///   %t2 = consumer ins(...) outs(%t1)
///   %r = tensor.insert_slice %t2, %bbarg[...]
/// }
/// ```
/// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0):
/// ```
/// %2 = scf.for ... iter_args(%0) .. (%bbarg) {
///   %t0 = tensor.extract_slice %bbarg[..]
///   %t1 = producer ins(...) outs(%t0)
///   %t2 = consumer ins(...) outs(%t1)
///   %r = tensor.insert_slice %t2, %bbarg[...]
/// }
/// ```
/// This transformation is only valid if %bbarg is exclusively used by the
/// output ExtractSliceOp / InsertSliceOp pair, which is checked by the
/// `fuseProducer` method.
/// TODO: instead of check and failure, insert new iter_args each time a
/// producer is fused into a consumer and fold away unused iter_args.
static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
                                 tensor::ExtractSliceOp sliceOp,
                                 ArrayRef<int64_t> tiledSliceDimIndices,
                                 ArrayRef<int64_t> tiledProducerLoopIndices,
                                 OpOperand *iterArg) {
  // Clone the producer after `sliceOp` since the slice may be reused to pass in
  // the producer result.
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointAfter(sliceOp);

  // Get the producer.
  LinalgOp producerOp = producerResult.getOwner();
  Location loc = producerOp.getLoc();

  // Obtain the `producerOp` loop bounds and the `sliceOp` ranges.
  SmallVector<Value> producerLoopBounds;
  transform(producerOp.createLoopRanges(b, loc),
            std::back_inserter(producerLoopBounds),
            [](Range range) { return range.size; });
  SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);

  // Tile the producer operands given the `sliceOp` ranges. Iterate the
  // `tiledSliceDimIndices` and store the tile offset and size for the tiled
  // slice dimension.
  auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr);
  SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero);
  SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr);
  for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) {
    int64_t tiledSliceDim = std::get<0>(it);
    int64_t tiledProducerLoop = std::get<1>(it);
    tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset;
    tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size;
    allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
  }
  erase_value(tileIvs, nullptr);
  SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
  tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
                                  tileSizes, producerLoopBounds);

  // Output fusion has to update the iteration arguments of the tile loop nest.
  // In particular, the iteration argument of the outermost tile loop needs to
  // be set to the producer output instead of the producer result and `clonedOp`
  // shall use the existing `sliceOp` result instead of the tiled producer
  // output operand.
  if (iterArg) {
    OpOperand *outputOperand =
        producerOp.getOutputOperand(producerResult.getResultNumber());
    iterArg->set(outputOperand->get());
    tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult();
  }

  // Clone the producer using the tiled producer operands.
  TypeRange resultTypes = ValueRange(tiledOperands)
                              .take_back(producerOp.getNumOutputs())
                              .getTypes();
  LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);

  // Shift all IndexOp results by the tile offset.
  addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);

  return clonedOp;
}

//===----------------------------------------------------------------------===//
// TileLoopNest specific helpers.
//===----------------------------------------------------------------------===//

bool TileLoopNest::isEmpty() { return tileLoopOps.empty(); }

bool TileLoopNest::isValid() {
  // Check if `rootOp` has been tiled at least once.
  if (isEmpty() || tiledRootAndFusedOpsLoops.count(rootOp) == 0)
    return false;

  // Check if the number of loop operations and dimensions match.
  if (tileLoopOps.size() != tiledRootAndFusedOpsLoops[rootOp].size())
    return false;

  // Check if the innermost tile loop is the parent of `tiledOp`.
  if (rootOp->getParentOp() != tileLoopOps.back())
    return false;

  // Check if the tile loops are directly nested.
  return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(),
                            [](Operation *op1, Operation *op2) {
                              return op1 != op2->getParentOp();
                            }) == tileLoopOps.end();
}

SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
  assert(bbArg && "expect the block argument to be non-zero");
  SmallVector<BlockArgument> bbArgs;

  // Search all tile loop block arguments from inner to outer.
  for (auto tileLoop : reverse(tileLoopOps)) {
    if (bbArg.getOwner()->getParentOp() != tileLoop)
      return {};
    bbArgs.push_back(bbArg);
    OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg);
    bbArg = iterArg->get().dyn_cast<BlockArgument>();
  }

  // Reverse the block arguments to order them from outer to inner.
  return {bbArgs.rbegin(), bbArgs.rend()};
}

OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) {
  // Search all block arguments and return the matching iteration argument.
  SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
  if (bbArgs.size() != tileLoopOps.size())
    return nullptr;
  return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
}

bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
                                tensor::ExtractSliceOp sliceOp) {
  // Check the innermost block argument is either used by the ExtractSliceOp
  // `sliceOp`, the matching InsertSliceOp, or by a DimOp. Handle other uses
  // conservatively.
  for (Operation *op : bbArg.getUsers()) {
    if (!isa<tensor::DimOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(op))
      return false;
    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
      if (extractSliceOp != sliceOp)
        return false;
    }
    if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) {
      SetVector<Operation *> backwardSlice;
      getBackwardSlice(insertSliceOp.source(), &backwardSlice,
                       [](Operation *op) {
                         return isa<LinalgOp, tensor::InsertSliceOp>(op);
                       });
      if (backwardSlice.empty() || backwardSlice.front() != sliceOp)
        return false;
    }
  }

  // Check the block arguments, except for the innermost one, have one use.
  SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
  return !all_of(bbArgs, [&](BlockArgument bbArg) {
    return bbArg.hasOneUse() || bbArg == bbArgs.back();
  });
}

LogicalResult TileLoopNest::tileRootOp(
    OpBuilder &b, ArrayRef<int64_t> tileSizes,
    ArrayRef<int64_t> tileInterchange,
    Optional<LinalgLoopDistributionOptions> tileDistribution) {
  // Exit if all tile sizes are zero.
  if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0)))
    return success();

  // Tile the root operation.
  LinalgTilingOptions tilingOptions;
  tilingOptions = tilingOptions
                      .setInterchange(SmallVector<unsigned>(
                          tileInterchange.begin(), tileInterchange.end()))
                      .setTileSizes(tileSizes)
                      .setLoopType(LinalgTilingLoopType::Loops);
  if (tileDistribution)
    tilingOptions =
        tilingOptions.setDistributionOptions(tileDistribution.getValue());

  // TODO: Propagate RewriterBase everywhere.
  IRRewriter rewriter(b);
  FailureOr<TiledLinalgOp> tiledRootOp =
      tileLinalgOp(rewriter, rootOp, tilingOptions);

  // Exit if tiling the root operation fails.
  if (failed(tiledRootOp))
    return failure();

  // Replace all uses of the root operation if it has been tiled before. All
  // uses of the original untiled root operation are updated by the calling pass
  // or pattern.
  if (!isEmpty())
    rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);

  // Transfer the stored `rootOp` loop dimensions if it has been tiled before.
  if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) {
    tiledRootAndFusedOpsLoops[tiledRootOp->op] =
        tiledRootAndFusedOpsLoops[rootOp];
  }

  // Update the root operation and append the loops and tile loop dimensions.
  rootOp = tiledRootOp->op;
  tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
  for (const auto &en : enumerate(tileSizes)) {
    // Copy only the tiled loop dimensions with non-zero tile size.
    if (en.value() == 0)
      continue;
    tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]);
  }
  assert(isValid() && "expect tile loop nest to be valid after tiling");
  return success();
}

FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
                                               OpOperand *consumerOpOperand) {
  // Check if the consumer has been tiled before. For example, it may not have
  // been tiled if the outermost tile loop is a reduction loop.
  if (tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) == 0)
    return failure();

  assert(this->isValid() &&
         "expect the tile loop nest to satisfy all invariants");

  // Check the tile loop nest is non-empty.
  if (isEmpty())
    return failure();

  // Check `consumerOpOperand` is defined by an ExtractSliceOp.
  auto sliceOp =
      consumerOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure();

  // Check `sliceOp` and `consumerOp` are in the same block.
  LinalgOp consumerOp = consumerOpOperand->getOwner();
  if (sliceOp->getBlock() != rootOp->getBlock() ||
      consumerOp->getBlock() != rootOp->getBlock())
    return failure();

  // Check if the producer is a LinalgOp possibly passed by iteration argument.
  OpOperand *iterArg = nullptr;
  auto producerResult = sliceOp.source().dyn_cast<OpResult>();
  if (auto bbArg = sliceOp.source().dyn_cast<BlockArgument>()) {
    iterArg = getTiedIterArg(bbArg);
    // Check the iteration argument may be used to pass in the producer output.
    if (!iterArg || hasOtherUses(bbArg, sliceOp))
      return failure();
    producerResult = iterArg->get().dyn_cast<OpResult>();
  }
  if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
    return failure();

  // Compute the tiled producer slice dimensions given the tiled consumer loops.
  SmallVector<int64_t> tiledSliceDimIndices = getTiledSliceDims(
      consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]);
  if (tiledSliceDimIndices.empty())
    return failure();

  // Compute the tiled producer loop indices.
  SmallVector<int64_t> tiledProducerLoopIndices =
      getTiledProducerLoops(producerResult, tiledSliceDimIndices);

  // Tile the producer operands and clone the producer in place of `sliceOp`.
  LinalgOp clonedOp =
      getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices,
                       tiledProducerLoopIndices, iterArg);
  tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices;

  // Cast the `clonedOp` result to gap type mismatches before canonicalization.
  Type consumerOperandType = consumerOpOperand->get().getType();
  Value newResult = clonedOp->getResult(producerResult.getResultNumber());
  if (newResult.getType() != consumerOperandType) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(clonedOp);
    newResult = b.create<tensor::CastOp>(producerResult.getLoc(),
                                         consumerOperandType, newResult);
  }

  // Replace the `sliceOp` uses except for the `clonedOp` output uses.
  sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp);
  return clonedOp;
}

ValueRange TileLoopNest::getRootOpReplacementResults() {
  assert(!isEmpty() && "expect tile loop nest to be non-empty");
  return tileLoopOps.front()->getOpResults();
}

SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
  SmallVector<LinalgOp> result;
  for (const auto &kvp : tiledRootAndFusedOpsLoops) {
    auto linalgOp = dyn_cast<LinalgOp>(kvp.getFirst());
    assert(linalgOp &&
           "expect all tiled and fused operations are linalg operations");
    result.push_back(linalgOp);
  }
  return result;
}

//===----------------------------------------------------------------------===//
// Tile and fuse entry-points.
//===----------------------------------------------------------------------===//

FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
    OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
    ArrayRef<int64_t> tileInterchange,
    const Optional<LinalgLoopDistributionOptions> &tileDistribution) {
  assert(tileSizes.size() == tileInterchange.size() &&
         "expect the number of tile sizes and interchange dims to match");
  assert(isPermutation(tileInterchange) &&
         "expect tile interchange is a permutation");

  // Create an empty tile loop nest.
  TileLoopNest tileLoopNest(consumerOp);

  // Search the number of outer parallel loops to separate them from possible
  // inner reduction dimensions.
  SmallVector<StringAttr> iterTypes =
      llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
  applyPermutationToVector(iterTypes, tileInterchange);
  auto *it = find_if(iterTypes, [&](StringAttr iterType) {
    return !isParallelIterator(iterType);
  });
  int64_t split = std::distance(iterTypes.begin(), it);

  // Helper to fuse the producers greedily using a queue of fusion candidates.
  auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) {
    SmallVector<OpOperand *> candidates(operands.begin(), operands.end());
    while (!candidates.empty()) {
      FailureOr<LinalgOp> fusedProducer =
          tileLoopNest.fuseProducer(b, candidates.pop_back_val());
      if (failed(fusedProducer))
        continue;
      candidates.append(fusedProducer->getInputAndOutputOperands());
    }
  };

  // Tile the outer parallel loops and fuse the output operands.
  SmallVector<int64_t> outerTileSizes;
  outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
  outerTileSizes.append(tileSizes.size() - split, 0);
  if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
                                     tileDistribution)))
    return failure();
  fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());

  // Tile the remaining loops and fuse the input operands.
  SmallVector<int64_t> innerTileSizes;
  innerTileSizes.append(split, 0);
  innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
  if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange,
                                     tileDistribution)))
    return failure();
  fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());

  // Exit if the tile loop nest is empty since all tile sizes are zero.
  if (tileLoopNest.isEmpty())
    return failure();

  return tileLoopNest;
}