// Source path: root/llvm/lib/Target/AMDGPU/AMDGPUReplaceLDSUseWithPointer.cpp
// Source blob: 2475b44b42a3d3baecacc126581192c70daf9d1b
//===-- AMDGPUReplaceLDSUseWithPointer.cpp --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces all uses of LDS within non-kernel functions with their
// corresponding pointer counterparts.
//
// The main motivation behind this pass is to prevent the subsequent LDS
// lowering pass from directly packing LDS (assume large LDS) into a struct
// type, which would otherwise require allocating a huge amount of memory for
// the struct instance within every kernel.
//
// Brief sketch of the algorithm implemented in this pass is as below:
//
//   1. Collect all the LDS defined in the module which qualify for pointer
//      replacement, say it is, LDSGlobals set.
//
//   2. Collect all the reachable callees for each kernel defined in the module,
//      say it is, KernelToCallees map.
//
//   3. FOR (each global GV from LDSGlobals set) DO
//        LDSUsedNonKernels = Collect all non-kernel functions which use GV.
//        FOR (each kernel K in KernelToCallees map) DO
//           ReachableCallees = KernelToCallees[K]
//           ReachableAndLDSUsedCallees =
//              SetIntersect(LDSUsedNonKernels, ReachableCallees)
//           IF (ReachableAndLDSUsedCallees is not empty) THEN
//             Pointer = Create a pointer to point-to GV if not created.
//             Initialize Pointer to point-to GV within kernel K.
//           ENDIF
//        ENDFOR
//        Replace all uses of GV within non kernel functions by Pointer.
//      ENDFOR
//
// LLVM IR example:
//
//    Input IR:
//
//    @lds = internal addrspace(3) global [4 x i32] undef, align 16
//
//    define internal void @f0() {
//    entry:
//      %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* @lds,
//             i32 0, i32 0
//      ret void
//    }
//
//    define protected amdgpu_kernel void @k0() {
//    entry:
//      call void @f0()
//      ret void
//    }
//
//    Output IR:
//
//    @lds = internal addrspace(3) global [4 x i32] undef, align 16
//    @lds.ptr = internal unnamed_addr addrspace(3) global i16 undef, align 2
//
//    define internal void @f0() {
//    entry:
//      %0 = load i16, i16 addrspace(3)* @lds.ptr, align 2
//      %1 = getelementptr i8, i8 addrspace(3)* null, i16 %0
//      %2 = bitcast i8 addrspace(3)* %1 to [4 x i32] addrspace(3)*
//      %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* %2,
//             i32 0, i32 0
//      ret void
//    }
//
//    define protected amdgpu_kernel void @k0() {
//    entry:
//      store i16 ptrtoint ([4 x i32] addrspace(3)* @lds to i16),
//            i16 addrspace(3)* @lds.ptr, align 2
//      call void @f0()
//      ret void
//    }
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "GCNSubtarget.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "Utils/AMDGPULDSUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <algorithm>
#include <vector>

#define DEBUG_TYPE "amdgpu-replace-lds-use-with-pointer"

using namespace llvm;

namespace {

// Forward declarations of file-local helpers which are defined further below
// and used by ReplaceLDSUseImpl.
namespace AMDGPU {
/// Collect all the instructions where user \p U belongs to. \p U could be
/// instruction itself or it could be a constant expression which is used within
/// an instruction. If \p CollectKernelInsts is true, collect instructions only
/// from kernels, otherwise collect instructions only from non-kernel functions.
///
/// \returns a map from each qualifying function to the set of its instructions
/// which (directly, or through constant users) use \p U.
DenseMap<Function *, SmallPtrSet<Instruction *, 8>>
getFunctionToInstsMap(User *U, bool CollectKernelInsts);

/// Collect all the non-kernel functions within which LDS global \p GV is used,
/// either directly by an instruction or through constant users.
///
/// \returns the set of such non-kernel functions; the set is empty when no
/// non-kernel function uses \p GV, or when a global value is found among the
/// (transitive) users of \p GV.
SmallPtrSet<Function *, 8> collectNonKernelAccessorsOfLDS(GlobalVariable *GV);

} // namespace AMDGPU

class ReplaceLDSUseImpl {
  Module &M;
  LLVMContext &Ctx;
  const DataLayout &DL;
  Constant *LDSMemBaseAddr;

  DenseMap<GlobalVariable *, GlobalVariable *> LDSToPointer;
  DenseMap<GlobalVariable *, SmallPtrSet<Function *, 8>> LDSToNonKernels;
  DenseMap<Function *, SmallPtrSet<Function *, 8>> KernelToCallees;
  DenseMap<Function *, SmallPtrSet<GlobalVariable *, 8>> KernelToLDSPointers;
  DenseMap<Function *, BasicBlock *> KernelToInitBB;
  DenseMap<Function *, DenseMap<GlobalVariable *, Value *>>
      FunctionToLDSToReplaceInst;

  // Collect LDS which requires their uses to be replaced by pointer.
  std::vector<GlobalVariable *> collectLDSRequiringPointerReplace() {
    // Collect LDS which requires module lowering.
    std::vector<GlobalVariable *> LDSGlobals =
        llvm::AMDGPU::findVariablesToLower(M);

    // Remove LDS which don't qualify for replacement.
    llvm::erase_if(LDSGlobals, [&](GlobalVariable *GV) {
      return shouldIgnorePointerReplacement(GV);
    });

    return LDSGlobals;
  }

  // Returns true if uses of given LDS global within non-kernel functions should
  // be keep as it is without pointer replacement.
  bool shouldIgnorePointerReplacement(GlobalVariable *GV) {
    // LDS whose size is very small and doesn't exceed pointer size is not worth
    // replacing.
    if (DL.getTypeAllocSize(GV->getValueType()) <= 2)
      return true;

    // LDS which is not used from non-kernel function scope or it is used from
    // global scope does not qualify for replacement.
    LDSToNonKernels[GV] = AMDGPU::collectNonKernelAccessorsOfLDS(GV);
    return LDSToNonKernels[GV].empty();

    // FIXME: When GV is used within all (or within most of the kernels), then
    // it does not make sense to create a pointer for it.
  }

  // Insert new global LDS pointer which points to LDS.
  GlobalVariable *createLDSPointer(GlobalVariable *GV) {
    // LDS pointer which points to LDS is already created? Return it.
    auto PointerEntry = LDSToPointer.insert(std::make_pair(GV, nullptr));
    if (!PointerEntry.second)
      return PointerEntry.first->second;

    // We need to create new LDS pointer which points to LDS.
    //
    // Each CU owns at max 64K of LDS memory, so LDS address ranges from 0 to
    // 2^16 - 1. Hence 16 bit pointer is enough to hold the LDS address.
    auto *I16Ty = Type::getInt16Ty(Ctx);
    GlobalVariable *LDSPointer = new GlobalVariable(
        M, I16Ty, false, GlobalValue::InternalLinkage, UndefValue::get(I16Ty),
        GV->getName() + Twine(".ptr"), nullptr, GlobalVariable::NotThreadLocal,
        AMDGPUAS::LOCAL_ADDRESS);

    LDSPointer->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
    LDSPointer->setAlignment(llvm::AMDGPU::getAlign(DL, LDSPointer));

    // Mark that an associated LDS pointer is created for LDS.
    LDSToPointer[GV] = LDSPointer;

    return LDSPointer;
  }

  // Split entry basic block in such a way that only lane 0 of each wave does
  // the LDS pointer initialization, and return newly created basic block.
  BasicBlock *activateLaneZero(Function *K) {
    // If the entry basic block of kernel K is already split, then return
    // newly created basic block.
    auto BasicBlockEntry = KernelToInitBB.insert(std::make_pair(K, nullptr));
    if (!BasicBlockEntry.second)
      return BasicBlockEntry.first->second;

    // Split entry basic block of kernel K.
    auto *EI = &(*(K->getEntryBlock().getFirstInsertionPt()));
    IRBuilder<> Builder(EI);

    Value *Mbcnt =
        Builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {},
                                {Builder.getInt32(-1), Builder.getInt32(0)});
    Value *Cond = Builder.CreateICmpEQ(Mbcnt, Builder.getInt32(0));
    Instruction *WB = cast<Instruction>(
        Builder.CreateIntrinsic(Intrinsic::amdgcn_wave_barrier, {}, {}));

    BasicBlock *NBB = SplitBlockAndInsertIfThen(Cond, WB, false)->getParent();

    // Mark that the entry basic block of kernel K is split.
    KernelToInitBB[K] = NBB;

    return NBB;
  }

  // Within given kernel, initialize given LDS pointer to point to given LDS.
  void initializeLDSPointer(Function *K, GlobalVariable *GV,
                            GlobalVariable *LDSPointer) {
    // If LDS pointer is already initialized within K, then nothing to do.
    auto PointerEntry = KernelToLDSPointers.insert(
        std::make_pair(K, SmallPtrSet<GlobalVariable *, 8>()));
    if (!PointerEntry.second)
      if (PointerEntry.first->second.contains(LDSPointer))
        return;

    // Insert instructions at EI which initialize LDS pointer to point-to LDS
    // within kernel K.
    //
    // That is, convert pointer type of GV to i16, and then store this converted
    // i16 value within LDSPointer which is of type i16*.
    auto *EI = &(*(activateLaneZero(K)->getFirstInsertionPt()));
    IRBuilder<> Builder(EI);
    Builder.CreateStore(Builder.CreatePtrToInt(GV, Type::getInt16Ty(Ctx)),
                        LDSPointer);

    // Mark that LDS pointer is initialized within kernel K.
    KernelToLDSPointers[K].insert(LDSPointer);
  }

  // We have created an LDS pointer for LDS, and initialized it to point-to LDS
  // within all relevant kernels. Now replace all the uses of LDS within
  // non-kernel functions by LDS pointer.
  void replaceLDSUseByPointer(GlobalVariable *GV, GlobalVariable *LDSPointer) {
    SmallVector<User *, 8> LDSUsers(GV->users());
    for (auto *U : LDSUsers) {
      // When `U` is a constant expression, it is possible that same constant
      // expression exists within multiple instructions, and within multiple
      // non-kernel functions. Collect all those non-kernel functions and all
      // those instructions within which `U` exist.
      auto FunctionToInsts =
          AMDGPU::getFunctionToInstsMap(U, false /*=CollectKernelInsts*/);

      for (const auto &FunctionToInst : FunctionToInsts) {
        Function *F = FunctionToInst.first;
        auto &Insts = FunctionToInst.second;
        for (auto *I : Insts) {
          // If `U` is a constant expression, then we need to break the
          // associated instruction into a set of separate instructions by
          // converting constant expressions into instructions.
          SmallPtrSet<Instruction *, 8> UserInsts;

          if (U == I) {
            // `U` is an instruction, conversion from constant expression to
            // set of instructions is *not* required.
            UserInsts.insert(I);
          } else {
            // `U` is a constant expression, convert it into corresponding set
            // of instructions.
            auto *CE = cast<ConstantExpr>(U);
            convertConstantExprsToInstructions(I, CE, &UserInsts);
          }

          // Go through all the user instructions, if LDS exist within them as
          // an operand, then replace it by replace instruction.
          for (auto *II : UserInsts) {
            auto *ReplaceInst = getReplacementInst(F, GV, LDSPointer);
            II->replaceUsesOfWith(GV, ReplaceInst);
          }
        }
      }
    }
  }

  // Create a set of replacement instructions which together replace LDS within
  // non-kernel function F by accessing LDS indirectly using LDS pointer.
  Value *getReplacementInst(Function *F, GlobalVariable *GV,
                            GlobalVariable *LDSPointer) {
    // If the instruction which replaces LDS within F is already created, then
    // return it.
    auto LDSEntry = FunctionToLDSToReplaceInst.insert(
        std::make_pair(F, DenseMap<GlobalVariable *, Value *>()));
    if (!LDSEntry.second) {
      auto ReplaceInstEntry =
          LDSEntry.first->second.insert(std::make_pair(GV, nullptr));
      if (!ReplaceInstEntry.second)
        return ReplaceInstEntry.first->second;
    }

    // Get the instruction insertion point within the beginning of the entry
    // block of current non-kernel function.
    auto *EI = &(*(F->getEntryBlock().getFirstInsertionPt()));
    IRBuilder<> Builder(EI);

    // Insert required set of instructions which replace LDS within F.
    auto *V = Builder.CreateBitCast(
        Builder.CreateGEP(
            Builder.getInt8Ty(), LDSMemBaseAddr,
            Builder.CreateLoad(LDSPointer->getValueType(), LDSPointer)),
        GV->getType());

    // Mark that the replacement instruction which replace LDS within F is
    // created.
    FunctionToLDSToReplaceInst[F][GV] = V;

    return V;
  }

public:
  ReplaceLDSUseImpl(Module &M)
      : M(M), Ctx(M.getContext()), DL(M.getDataLayout()) {
    LDSMemBaseAddr = Constant::getIntegerValue(
        PointerType::get(Type::getInt8Ty(M.getContext()),
                         AMDGPUAS::LOCAL_ADDRESS),
        APInt(32, 0));
  }

  // Entry-point function which interface ReplaceLDSUseImpl with outside of the
  // class.
  bool replaceLDSUse();

private:
  // For a given LDS from collected LDS globals set, replace its non-kernel
  // function scope uses by pointer.
  bool replaceLDSUse(GlobalVariable *GV);
};

// Replace the non-kernel function scope uses of one LDS global, taken from the
// collected LDS globals set, by pointer. Returns true if an LDS pointer had to
// be created for GV, that is, if the module is modified.
bool ReplaceLDSUseImpl::replaceLDSUse(GlobalVariable *GV) {
  // Non-kernel functions within which GV is being accessed.
  const SmallPtrSet<Function *, 8> &Accessors = LDSToNonKernels[GV];
  auto AccessesLDS = [&Accessors](Function *Callee) {
    return Accessors.contains(Callee);
  };

  // Stays null until some kernel turns out to need the pointer.
  GlobalVariable *LDSPointer = nullptr;

  for (const auto &Entry : KernelToCallees) {
    // A kernel needs the pointer to be initialized only when at least one of
    // its reachable callees accesses GV.
    if (llvm::none_of(Entry.second, AccessesLDS))
      continue;

    // The pointer is created on the first request, and the very same pointer
    // is returned for all the later requests.
    LDSPointer = createLDSPointer(GV);
    initializeLDSPointer(Entry.first, GV, LDSPointer);
  }

  if (LDSPointer) {
    // The pointer is initialized within all relevant kernels, now redirect
    // the uses of GV within non-kernel functions through it.
    replaceLDSUseByPointer(GV, LDSPointer);
    return true;
  }

  // None of the kernels reach a non-kernel function which accesses GV.
  return false;
}

namespace AMDGPU {

// A helper class for collecting all reachable callees for each kernel defined
// within the module.
class CollectReachableCallees {
  Module &M;
  CallGraph CG;
  // Call graph nodes of the functions which are treated as potential targets
  // of indirect calls.
  SmallPtrSet<CallGraphNode *, 8> AddressTakenFunctions;

  // Collect all address taken functions within the module. These are taken
  // from the callees of the external calling node of the call graph, keeping
  // only the non-kernel function definitions.
  void collectAddressTakenFunctions() {
    auto *ECNode = CG.getExternalCallingNode();

    for (const auto &GI : *ECNode) {
      auto *CGN = GI.second;
      auto *F = CGN->getFunction();
      if (!F || F->isDeclaration() || llvm::AMDGPU::isKernelCC(F))
        continue;
      AddressTakenFunctions.insert(CGN);
    }
  }

  // For given kernel, collect all its reachable non-kernel functions.
  SmallPtrSet<Function *, 8> collectReachableCallees(Function *K) {
    SmallPtrSet<Function *, 8> ReachableCallees;

    // Call graph node which represents this kernel.
    auto *KCGN = CG[K];

    // Go through all call graph nodes reachable from the node representing this
    // kernel, visit all their call sites, if the call site is direct, add
    // corresponding callee to reachable callee set, if it is indirect, resolve
    // the indirect call site to potential reachable callees, add them to
    // reachable callee set, and repeat the process for the newly added
    // potential callee nodes.
    //
    // FIXME: Need to handle bit-casted function pointers.
    //
    SmallVector<CallGraphNode *, 8> CGNStack(depth_first(KCGN));
    SmallPtrSet<CallGraphNode *, 8> VisitedCGNodes;
    while (!CGNStack.empty()) {
      auto *CGN = CGNStack.pop_back_val();

      if (!VisitedCGNodes.insert(CGN).second)
        continue;

      // Ignore call graph node which does not have associated function or
      // associated function is not a definition.
      if (!CGN->getFunction() || CGN->getFunction()->isDeclaration())
        continue;

      for (const auto &GI : *CGN) {
        auto *RCGN = GI.second;

        // The callee is known. Handle this case before looking at the call
        // site, because the call site is absent for reference edges.
        if (auto *DCallee = RCGN->getFunction()) {
          ReachableCallees.insert(DCallee);
          continue;
        }

        // An edge without a call site cannot be an indirect call, and calling
        // getValue() on its empty optional is invalid. Skip it.
        if (!GI.first)
          continue;

        auto *RCB = cast<CallBase>(GI.first.getValue());
        if (!RCB->isIndirectCall())
          continue;

        // Resolve the indirect call site to every address taken function
        // whose type matches the type of the call site.
        auto *RCBFTy = RCB->getFunctionType();
        for (auto *ACGN : AddressTakenFunctions) {
          auto *ACallee = ACGN->getFunction();
          if (ACallee->getFunctionType() == RCBFTy) {
            ReachableCallees.insert(ACallee);
            // Everything reachable from the potential callee is reachable
            // from the kernel as well.
            CGNStack.append(df_begin(ACGN), df_end(ACGN));
          }
        }
      }
    }

    return ReachableCallees;
  }

public:
  explicit CollectReachableCallees(Module &M) : M(M), CG(CallGraph(M)) {
    // Collect address taken functions.
    collectAddressTakenFunctions();
  }

  // Fill \p KernelToCallees with one entry per kernel of the module, mapping
  // the kernel to the set of callees which are reachable from it.
  void collectReachableCallees(
      DenseMap<Function *, SmallPtrSet<Function *, 8>> &KernelToCallees) {
    // Collect reachable callee set for each kernel defined in the module.
    for (Function &F : M.functions()) {
      if (!llvm::AMDGPU::isKernelCC(&F))
        continue;
      Function *K = &F;
      KernelToCallees[K] = collectReachableCallees(K);
    }
  }
};

/// Compute, for every kernel defined in the module \p M, the set of callees
/// which are reachable from it, and record the result within
/// \p KernelToCallees.
void collectReachableCallees(
    Module &M,
    DenseMap<Function *, SmallPtrSet<Function *, 8>> &KernelToCallees) {
  // The collector is needed only for the duration of this single query.
  CollectReachableCallees(M).collectReachableCallees(KernelToCallees);
}

/// Walk all the (transitive) users of the LDS global \p GV, and return the set
/// of non-kernel functions which contain an instruction using \p GV. An empty
/// set is returned as soon as a global value is found among the users.
SmallPtrSet<Function *, 8> collectNonKernelAccessorsOfLDS(GlobalVariable *GV) {
  SmallPtrSet<User *, 8> Seen;
  SmallVector<User *, 8> Worklist(GV->users());
  SmallPtrSet<Function *, 8> NonKernelUsers;

  while (!Worklist.empty()) {
    User *Cur = Worklist.pop_back_val();

    // Each user is processed only once.
    const bool FirstVisit = Seen.insert(Cur).second;
    if (!FirstVisit)
      continue;

    // The LDS is used to initialize a global value. Give up on this LDS.
    if (isa<GlobalValue>(Cur))
      return SmallPtrSet<Function *, 8>();

    if (!isa<Constant>(Cur)) {
      // Anything which is not a constant is expected to be an instruction.
      // Record its parent function, unless that is a kernel.
      Function *Parent = cast<Instruction>(Cur)->getFunction();
      if (!llvm::AMDGPU::isKernelCC(Parent))
        NonKernelUsers.insert(Parent);
      continue;
    }

    // A constant user, keep exploring through its own users.
    append_range(Worklist, Cur->users());
  }

  return NonKernelUsers;
}

// Starting from user U, walk through constant users until instructions are
// reached, and group the reached instructions by their parent function. Only
// the instructions from kernels are kept when CollectKernelInsts is true, and
// only the instructions from non-kernel functions are kept otherwise.
DenseMap<Function *, SmallPtrSet<Instruction *, 8>>
getFunctionToInstsMap(User *U, bool CollectKernelInsts) {
  DenseMap<Function *, SmallPtrSet<Instruction *, 8>> InstsPerFunction;
  SmallPtrSet<User *, 8> Seen;
  SmallVector<User *, 8> Worklist;

  for (Worklist.push_back(U); !Worklist.empty();) {
    User *Cur = Worklist.pop_back_val();

    // Each user is processed only once.
    if (!Seen.insert(Cur).second)
      continue;

    // A use from a global value does not belong to any function.
    if (isa<GlobalValue>(Cur))
      continue;

    if (!isa<Constant>(Cur)) {
      // Anything which is not a constant is expected to be an instruction.
      auto *Inst = cast<Instruction>(Cur);
      Function *Parent = Inst->getFunction();

      // Keep the instruction only when the kind of its parent function
      // (kernel or non-kernel) is the requested one.
      if (llvm::AMDGPU::isKernelCC(Parent) == CollectKernelInsts)
        InstsPerFunction[Parent].insert(Inst);
      continue;
    }

    // A constant user, keep exploring through its own users.
    append_range(Worklist, Cur->users());
  }

  return InstsPerFunction;
}

} // namespace AMDGPU

// Entry-point function which interface ReplaceLDSUseImpl with outside of the
// class.
bool ReplaceLDSUseImpl::replaceLDSUse() {
  // Collect LDS which requires their uses to be replaced by pointer.
  std::vector<GlobalVariable *> LDSGlobals =
      collectLDSRequiringPointerReplace();

  // No LDS to pointer-replace. Nothing to do.
  if (LDSGlobals.empty())
    return false;

  // Collect reachable callee set for each kernel defined in the module.
  AMDGPU::collectReachableCallees(M, KernelToCallees);

  if (KernelToCallees.empty()) {
    // Either module does not have any kernel definitions, or none of the kernel
    // has a call to non-kernel functions, or we could not resolve any of the
    // call sites to proper non-kernel functions, because of the situations like
    // inline asm calls. Nothing to replace.
    return false;
  }

  // For every LDS from collected LDS globals set, replace its non-kernel
  // function scope use by pointer.
  bool Changed = false;
  for (auto *GV : LDSGlobals)
    Changed |= replaceLDSUse(GV);

  return Changed;
}

class AMDGPUReplaceLDSUseWithPointer : public ModulePass {
public:
  static char ID;

  AMDGPUReplaceLDSUseWithPointer() : ModulePass(ID) {
    initializeAMDGPUReplaceLDSUseWithPointerPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
  }
};

} // namespace

// Definition of the static pass ID which is handed to the ModulePass
// constructor, and a reference which exposes the very same ID under the llvm
// namespace.
char AMDGPUReplaceLDSUseWithPointer::ID = 0;
char &llvm::AMDGPUReplaceLDSUseWithPointerID =
    AMDGPUReplaceLDSUseWithPointer::ID;

// Legacy pass manager registration: the pass is registered under the
// DEBUG_TYPE name with the description given below, as a pass which is
// neither CFG-only nor an analysis, and which depends on TargetPassConfig.
INITIALIZE_PASS_BEGIN(
    AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
    "Replace within non-kernel function use of LDS with pointer",
    false /*only look at the cfg*/, false /*analysis pass*/)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(
    AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
    "Replace within non-kernel function use of LDS with pointer",
    false /*only look at the cfg*/, false /*analysis pass*/)

// Legacy pass manager entry point: run the replacement on module M and report
// whether the module is modified.
bool AMDGPUReplaceLDSUseWithPointer::runOnModule(Module &M) {
  return ReplaceLDSUseImpl(M).replaceLDSUse();
}

// Factory for the legacy pass manager version of the pass. The caller receives
// a newly allocated pass object.
ModulePass *llvm::createAMDGPUReplaceLDSUseWithPointerPass() {
  ModulePass *LegacyPass = new AMDGPUReplaceLDSUseWithPointer();
  return LegacyPass;
}

// New pass manager entry point: run the replacement on module M.
//
// When the module is modified (new LDS pointer globals are added, kernel entry
// blocks are split, and loads/stores are inserted), no analysis can be assumed
// to be preserved, so report that nothing is preserved. Everything is
// preserved only when the module is left untouched.
PreservedAnalyses
AMDGPUReplaceLDSUseWithPointerPass::run(Module &M, ModuleAnalysisManager &AM) {
  ReplaceLDSUseImpl LDSUseReplacer{M};
  const bool Changed = LDSUseReplacer.replaceLDSUse();
  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}