-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ObjectFifo] Create a new pass to split L2 buffers
-- This commit introduces a new pass `--iree-amdaie-split-buffers` to split L2 buffers for dealing with Matmul+Elementwise. -- It addresses sub-action 2 as well from #644 Signed-off-by: Abhishek Varma <[email protected]>
- Loading branch information
1 parent
883ee07
commit b2d08c7
Showing
7 changed files
with
253 additions
and
1 deletion.
There are no files selected for viewing
145 changes: 145 additions & 0 deletions
145
compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIESplitBuffers.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
#include "iree-amd-aie/IR/AMDAIEOps.h" | ||
#include "iree-amd-aie/Transforms/AMDAIEDmaUtils.h" | ||
#include "iree-amd-aie/Transforms/Passes.h" | ||
#include "iree-amd-aie/Transforms/Transforms.h" | ||
#include "iree/compiler/Codegen/TransformStrategies/GPU/Common.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/SCF/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/IR/Iterators.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" | ||
|
||
#define DEBUG_TYPE "iree-amdaie-split-buffers" | ||
|
||
namespace mlir::iree_compiler::AMDAIE { | ||
|
||
namespace { | ||
|
||
class AMDAIESplitBuffersPass | ||
: public impl::AMDAIESplitBuffersBase<AMDAIESplitBuffersPass> { | ||
public: | ||
using AMDAIESplitBuffersBase::AMDAIESplitBuffersBase; | ||
|
||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<AMDAIEDialect>(); | ||
} | ||
void runOnOperation() override; | ||
}; | ||
|
||
void AMDAIESplitBuffersPass::runOnOperation() { | ||
ModuleOp moduleOp = getOperation(); | ||
IRRewriter rewriter(moduleOp.getContext()); | ||
|
||
SmallVector<AMDAIE::DmaCpyNdOp> l2ToL1DmaOps; | ||
// We are currently walking through CoreOps gathering 3rd Input DmaOp (if | ||
// applicable) from them. | ||
// TODO: We will generalize this later. | ||
moduleOp.walk([&](AMDAIE::CoreOp coreOp) { | ||
SmallVector<Value> inputDmas = coreOp.getInputDmas(); | ||
if (inputDmas.size() < 3) return WalkResult::skip(); | ||
l2ToL1DmaOps.push_back(inputDmas[2].getDefiningOp<AMDAIE::DmaCpyNdOp>()); | ||
return WalkResult::advance(); | ||
}); | ||
|
||
DenseSet<Operation *> toBeErased; | ||
for (AMDAIE::DmaCpyNdOp l2ToL1DmaOp : l2ToL1DmaOps) { | ||
LogicalObjectFifoFromMemrefOp sourceObjectFifo = | ||
l2ToL1DmaOp.getSourceObjectFifo(); | ||
auto sourceAllocOp = | ||
sourceObjectFifo.getMemref().getDefiningOp<memref::AllocOp>(); | ||
uint64_t sourceMemrefSpace = sourceObjectFifo.getMemorySpaceAsUInt(); | ||
if (!sourceAllocOp || sourceMemrefSpace != 1) continue; | ||
LogicalObjectFifoFromMemrefOp targetObjectFifo = | ||
l2ToL1DmaOp.getTargetObjectFifo(); | ||
Value targetAllocOp = targetObjectFifo.getMemref(); | ||
|
||
// Now we'll create a narrowed L2 buffer. | ||
rewriter.setInsertionPoint(sourceAllocOp); | ||
auto oldSourceMemRefType = cast<MemRefType>(sourceAllocOp.getType()); | ||
auto targetMemRefType = cast<MemRefType>(targetAllocOp.getType()); | ||
MemRefType newAllocType = MemRefType::get( | ||
targetMemRefType.getNumElements(), targetMemRefType.getElementType(), | ||
MemRefLayoutAttrInterface{}, oldSourceMemRefType.getMemorySpace()); | ||
auto newAllocOp = rewriter.create<memref::AllocOp>(rewriter.getUnknownLoc(), | ||
newAllocType); | ||
auto newDeallocOp = rewriter.create<memref::DeallocOp>( | ||
rewriter.getUnknownLoc(), newAllocOp); | ||
newDeallocOp->moveBefore(&newAllocOp->getBlock()->back()); | ||
|
||
// Fetch the L3 -> L2 Dma Op corresponding to the L2 buffer as target. | ||
AMDAIE::DmaCpyNdOp l3ToL2DmaOp; | ||
for (Operation *objFifoUserOp : sourceObjectFifo->getUsers()) { | ||
if (auto dmaOp = dyn_cast<AMDAIE::DmaCpyNdOp>(objFifoUserOp); | ||
dmaOp.getTargetObjectFifo() == sourceObjectFifo) { | ||
l3ToL2DmaOp = dmaOp; | ||
toBeErased.insert(dmaOp); | ||
break; | ||
} | ||
} | ||
toBeErased.insert(sourceAllocOp); | ||
toBeErased.insert(sourceObjectFifo); | ||
|
||
auto type = cast<MemRefType>(newAllocOp.getType()); | ||
// Create new logicalobjectfifo.from_memref for the newly created L2 buffer. | ||
rewriter.setInsertionPoint(l2ToL1DmaOp.getSourceObjectFifo()); | ||
auto source = rewriter.create<AMDAIE::LogicalObjectFifoFromMemrefOp>( | ||
rewriter.getUnknownLoc(), LogicalObjectFifoType::get(type), | ||
newAllocOp.getResult(), sourceObjectFifo.getTiles()); | ||
|
||
// Create new L3 -> L2 Dma Op. | ||
rewriter.setInsertionPoint(l3ToL2DmaOp); | ||
rewriter.create<AMDAIE::DmaCpyNdOp>( | ||
l3ToL2DmaOp.getLoc(), source, l3ToL2DmaOp.getTargetMixedOffsets(), | ||
l3ToL2DmaOp.getTargetMixedSizes(), l3ToL2DmaOp.getTargetMixedStrides(), | ||
l3ToL2DmaOp.getSource(), l3ToL2DmaOp.getSourceMixedOffsets(), | ||
l3ToL2DmaOp.getSourceMixedSizes(), l3ToL2DmaOp.getSourceMixedStrides()); | ||
|
||
// Create new L2 -> L1 Input DmaOp. | ||
rewriter.setInsertionPoint(l2ToL1DmaOp); | ||
auto newL2ToL1DmaOp = rewriter.create<AMDAIE::DmaCpyNdOp>( | ||
l2ToL1DmaOp.getLoc(), l2ToL1DmaOp.getTarget(), | ||
l2ToL1DmaOp.getTargetMixedOffsets(), l2ToL1DmaOp.getTargetMixedSizes(), | ||
l2ToL1DmaOp.getTargetMixedStrides(), source, | ||
l2ToL1DmaOp.getSourceMixedOffsets(), l2ToL1DmaOp.getSourceMixedSizes(), | ||
l2ToL1DmaOp.getSourceMixedStrides()); | ||
rewriter.replaceOp(l2ToL1DmaOp, newL2ToL1DmaOp); | ||
// We have to discard non-zero offsets as subview has been replaced by a | ||
// dedicated allocated memref. | ||
SmallVector<int64_t> allocShape(type.getShape()); | ||
(void)discardAllNonZeroOffsets<CopyOpOperateOn::Source>( | ||
rewriter, | ||
cast<AMDAIE::DoublyStridedOpInterface>(newL2ToL1DmaOp.getOperation()), | ||
allocShape); | ||
|
||
// Remove old dealloc. | ||
memref::DeallocOp oldDeallocOp; | ||
for (Operation *userOp : sourceAllocOp->getUsers()) { | ||
if (auto deallocUser = dyn_cast<memref::DeallocOp>(userOp)) { | ||
oldDeallocOp = deallocUser; | ||
} | ||
} | ||
if (oldDeallocOp) { | ||
rewriter.eraseOp(oldDeallocOp); | ||
} | ||
} | ||
|
||
for (Operation *op : toBeErased) { | ||
op->dropAllUses(); | ||
rewriter.eraseOp(op); | ||
} | ||
} | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<Pass> createAMDAIESplitBuffersPass() { | ||
return std::make_unique<AMDAIESplitBuffersPass>(); | ||
} | ||
|
||
} // namespace mlir::iree_compiler::AMDAIE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
95 changes: 95 additions & 0 deletions
95
compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/split_buffers.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
// RUN: iree-opt --pass-pipeline="builtin.module(iree-amdaie-split-buffers,cse)" --split-input-file --verify-diagnostics %s | FileCheck %s | ||
|
||
// CHECK-LABEL: @split_l2_buffer | ||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index | ||
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index | ||
// CHECK-DAG: %[[L3_ALLOC:.*]] = memref.alloc() : memref<128x128xi32> | ||
// CHECK-DAG: %[[L2_ALLOC:.*]] = memref.alloc() : memref<1024xi32, 1 : i32> | ||
// CHECK-DAG: %[[L1_ALLOC:.*]] = memref.alloc() : memref<1x1x8x8x4x4xi32, 2 : i32> | ||
// CHECK: %[[TILE:.*]] = amdaie.tile(%[[C1]], %[[C3]]) | ||
// CHECK: %[[L2_OBJECTFIFO:.*]] = amdaie.logicalobjectfifo.from_memref %[[L2_ALLOC]], {%[[TILE]]} : | ||
// CHECK-SAME: memref<1024xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<1024xi32, 1 : i32>> | ||
// CHECK: %[[L3_OBJECTFIFO:.*]] = amdaie.logicalobjectfifo.from_memref %[[L3_ALLOC]], {%[[TILE]]} : | ||
// CHECK-SAME: memref<128x128xi32> -> !amdaie.logicalobjectfifo<memref<128x128xi32>> | ||
// CHECK: scf.forall | ||
// CHECK: %[[DMA_CPY_ND_L3_TO_L2:.*]] = amdaie.dma_cpy_nd(%[[L2_OBJECTFIFO]] | ||
// CHECK-SAME: %[[L3_OBJECTFIFO]] | ||
// CHECK: amdaie.logicalobjectfifo.from_memref | ||
// CHECK: amdaie.logicalobjectfifo.from_memref | ||
// CHECK: amdaie.dma_cpy_nd | ||
// CHECK: amdaie.dma_cpy_nd | ||
// CHECK: %[[L1_OBJECTFIFO:.*]] = amdaie.logicalobjectfifo.from_memref %[[L1_ALLOC]] | ||
// CHECK: %[[DMA_CPY_ND_L2_TO_L1:.*]] = amdaie.dma_cpy_nd(%[[L1_OBJECTFIFO]] | ||
// CHECK-SAME: %[[L2_OBJECTFIFO]] | ||
// CHECK: amdaie.core(%[[TILE]], in : [%{{.*}}, %{{.*}}, %[[DMA_CPY_ND_L2_TO_L1]]], out : | ||
// CHECK: linalg.generic | ||
// CHECK: } | ||
// CHECK: memref.dealloc %[[L2_ALLOC]] : memref<1024xi32, 1 : i32> | ||
#map = affine_map<(d0) -> (d0 * 64)> | ||
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)> | ||
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)> | ||
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)> | ||
#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> | ||
module { | ||
func.func @split_l2_buffer(%arg0: !amdaie.logicalobjectfifo<memref<1x1x4x8x4x8xi32, 2 : i32>>, %arg1: !amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>>, %arg2: !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>>, %arg3: !amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>>) { | ||
%c3 = arith.constant 3 : index | ||
%c16 = arith.constant 16 : index | ||
%c8 = arith.constant 8 : index | ||
%c4 = arith.constant 4 : index | ||
%c128 = arith.constant 128 : index | ||
%c2048 = arith.constant 2048 : index | ||
%c256 = arith.constant 256 : index | ||
%c1024 = arith.constant 1024 : index | ||
%c4096 = arith.constant 4096 : index | ||
%c32 = arith.constant 32 : index | ||
%c2 = arith.constant 2 : index | ||
%c1 = arith.constant 1 : index | ||
%c0 = arith.constant 0 : index | ||
%alloc = memref.alloc() : memref<2x1x32x32xi32, 1 : i32> | ||
%alloc_0 = memref.alloc() : memref<1x2x32x32xi32, 1 : i32> | ||
%alloc_1 = memref.alloc() : memref<2x2x32x32xi32, 1 : i32> | ||
%alloc_2 = memref.alloc() : memref<128x128xi32> | ||
%alloc_3 = memref.alloc() : memref<1x1x8x8x4x4xi32, 2 : i32> | ||
%tile = amdaie.tile(%c1, %c3) | ||
%0 = amdaie.logicalobjectfifo.from_memref %alloc_1, {%tile} : memref<2x2x32x32xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>> | ||
%1 = amdaie.logicalobjectfifo.from_memref %alloc_2, {%tile} : memref<128x128xi32> -> !amdaie.logicalobjectfifo<memref<128x128xi32>> | ||
scf.forall (%arg4, %arg5) in (2, 2) { | ||
%2 = affine.apply #map(%arg5) | ||
%3 = affine.apply #map(%arg4) | ||
%4 = amdaie.dma_cpy_nd(%0[%c0, %c0, %c0, %c0] [%c2, %c2, %c32, %c32] [%c2048, %c1024, %c32, %c1], %1[%c0, %c0, %3, %2] [%c2, %c2, %c32, %c32] [%c4096, %c32, %c128, %c1]) : (!amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>>, !amdaie.logicalobjectfifo<memref<128x128xi32>>) | ||
%tile_4 = amdaie.tile(%c1, %c3) | ||
%5 = amdaie.logicalobjectfifo.from_memref %alloc, {%tile} : memref<2x1x32x32xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<2x1x32x32xi32, 1 : i32>> | ||
%6 = amdaie.logicalobjectfifo.from_memref %alloc_0, {%tile} : memref<1x2x32x32xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<1x2x32x32xi32, 1 : i32>> | ||
%7 = amdaie.dma_cpy_nd(%arg0[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c8, %c4, %c8] [%c1024, %c1024, %c256, %c32, %c8, %c1], %5[%c1, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c8, %c4, %c8] [%c1024, %c1024, %c8, %c128, %c32, %c1]) : (!amdaie.logicalobjectfifo<memref<1x1x4x8x4x8xi32, 2 : i32>>, !amdaie.logicalobjectfifo<memref<2x1x32x32xi32, 1 : i32>>) | ||
%8 = amdaie.dma_cpy_nd(%arg1[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c1024, %c1024, %c128, %c32, %c4, %c1], %6[%c0, %c1, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c1024, %c4, %c256, %c32, %c1]) : (!amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>>, !amdaie.logicalobjectfifo<memref<1x2x32x32xi32, 1 : i32>>) | ||
%9 = amdaie.logicalobjectfifo.from_memref %alloc_3, {%tile} : memref<1x1x8x8x4x4xi32, 2 : i32> -> !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>> | ||
%10 = amdaie.dma_cpy_nd(%9[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c8, %c4, %c4] [%c1024, %c1024, %c128, %c16, %c4, %c1], %0[%c1, %c1, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c8, %c4, %c4] [%c2048, %c1024, %c4, %c128, %c32, %c1]) : (!amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>>, !amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>>) | ||
%11 = amdaie.dma_cpy_nd(%arg3[%c1, %c1, %c0, %c0] [%c1, %c1, %c32, %c32] [%c2048, %c1024, %c32, %c1], %arg2[%c0, %c0, %c0, %c0] [%c8, %c4, %c8, %c4] [%c16, %c4, %c128, %c1]) : (!amdaie.logicalobjectfifo<memref<2x2x32x32xi32, 1 : i32>>, !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>>) | ||
%12 = amdaie.core(%tile_4, in : [%7, %8, %10], out : [%11]) { | ||
%13 = amdaie.logicalobjectfifo.access(%arg0, Read) : !amdaie.logicalobjectfifo<memref<1x1x4x8x4x8xi32, 2 : i32>> -> memref<1x1x4x8x4x8xi32, 2 : i32> | ||
%14 = amdaie.logicalobjectfifo.access(%arg1, Read) : !amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>> -> memref<1x1x8x4x8x4xi32, 2 : i32> | ||
%15 = amdaie.logicalobjectfifo.access(%arg2, None) : !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>> -> memref<1x1x8x8x4x4xi32, 2 : i32> | ||
linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%13, %14 : memref<1x1x4x8x4x8xi32, 2 : i32>, memref<1x1x8x4x8x4xi32, 2 : i32>) outs(%15 : memref<1x1x8x8x4x4xi32, 2 : i32>) { | ||
^bb0(%in: i32, %in_5: i32, %out: i32): | ||
%18 = arith.muli %in, %in_5 : i32 | ||
%19 = arith.addi %out, %18 : i32 | ||
linalg.yield %19 : i32 | ||
} | ||
%16 = amdaie.logicalobjectfifo.access(%arg2, Read) : !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>> -> memref<1x1x8x8x4x4xi32, 2 : i32> | ||
%17 = amdaie.logicalobjectfifo.access(%arg2, Write) : !amdaie.logicalobjectfifo<memref<1x1x8x8x4x4xi32, 2 : i32>> -> memref<1x1x8x8x4x4xi32, 2 : i32> | ||
linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%15, %16 : memref<1x1x8x8x4x4xi32, 2 : i32>, memref<1x1x8x8x4x4xi32, 2 : i32>) outs(%17 : memref<1x1x8x8x4x4xi32, 2 : i32>) { | ||
^bb0(%in: i32, %in_5: i32, %out: i32): | ||
%18 = arith.addi %in, %in_5 : i32 | ||
linalg.yield %18 : i32 | ||
} | ||
amdaie.end | ||
} | ||
} {mapping = [#gpu.block<y>, #gpu.block<x>]} | ||
memref.dealloc %alloc : memref<2x1x32x32xi32, 1 : i32> | ||
memref.dealloc %alloc_3 : memref<1x1x8x8x4x4xi32, 2 : i32> | ||
memref.dealloc %alloc_0 : memref<1x2x32x32xi32, 1 : i32> | ||
memref.dealloc %alloc_1 : memref<2x2x32x32xi32, 1 : i32> | ||
memref.dealloc %alloc_2 : memref<128x128xi32> | ||
return | ||
} | ||
} |