
Commit

cleanup NITs
Signed-off-by: Stanley Winata <[email protected]>
raikonenfnu committed Nov 16, 2024
1 parent f728545 commit 6a99c5d
Showing 4 changed files with 43 additions and 71 deletions.
iree/turbine/kernel/wave/codegen.py (2 changes: 1 addition & 1 deletion)
@@ -1320,7 +1320,6 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):
     except ValueError as e:
         raise ValidationError("Malformed arguments") from e
     custom = get_custom(node)
-    innermost_dim = custom.type.symbolic_shape[-1]

     # Determine whether to extract or combine.
     if len(args) > 1:
@@ -1348,6 +1347,7 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):
     # actual offset, we need to multiply by the size. The size is obtained by
     # computing the number of partitions using the source and target vector shapes
     # and dividing the incoming vector shape by the number of partitions.
+    innermost_dim = custom.type.symbolic_shape[-1]
     offset = custom.expanded_dims[innermost_dim]
     num_partitions = (
         target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim]
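The comment in the hunk above compresses a few steps. As a minimal standalone sketch of the same arithmetic (not part of the commit; every value below is assumed for illustration):

    # Hypothetical values along innermost_dim; the real ones come from the
    # kernel's vector shapes at this point in handle_reshape.
    target_vector_shape = 32   # assumed target_vector_shapes[innermost_dim]
    source_vector_shape = 8    # assumed custom.vector_shapes[innermost_dim]
    incoming_vector_size = 8   # assumed size of the incoming vector

    num_partitions = target_vector_shape // source_vector_shape  # 32 // 8 == 4
    size = incoming_vector_size // num_partitions                # 8 // 4 == 2
    expanded_offset = 3  # assumed custom.expanded_dims[innermost_dim]
    offset = expanded_offset * size  # 3 * 2 == 6, the actual element offset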
iree/turbine/kernel/wave/constraints.py (76 changes: 28 additions & 48 deletions)
@@ -196,7 +196,7 @@ def apply(
                     1,  # N
                     1,  # K
                 ]
-            case MMAType.F32_16x16x32_F8:
+            case MMAType.F32_16x16x32_F8 | MMAType.F32_16x16x32_K4_F8:
                 offset = [
                     Piecewise(
                         (lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC)
@@ -214,7 +214,17 @@
                     1,  # N
                     1,  # K
                 ]
-            case MMAType.F32_32x32x16_F8:
+                if self.mma_type == MMAType.F32_16x16x32_K4_F8:
+                    offset = [
+                        Piecewise(
+                            (lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC)
+                        ),  # M
+                        lane % 16,  # N
+                        (16 * floor(GPR_NUM / 4))
+                        + 4 * floor(lane / 16)
+                        + (GPR_NUM % 4),  # K
+                    ]
+            case MMAType.F32_32x32x16_F8 | MMAType.F32_32x32x16_K4_F8:
                 offset = [
                     Piecewise(
                         (lane % 32, ~MMA_ACC),
@@ -238,52 +248,22 @@
                     1,  # N
                     1,  # K
                 ]
-            case MMAType.F32_16x16x32_K4_F8:
-                offset = [
-                    Piecewise(
-                        (lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC)
-                    ),  # M
-                    lane % 16,  # N
-                    (16 * floor(GPR_NUM / 4))
-                    + 4 * floor(lane / 16)
-                    + (GPR_NUM % 4),  # K
-                ]
-                size = [
-                    Piecewise((1, ~MMA_ACC), (4, MMA_ACC)),  # M
-                    1,  # N
-                    8,  # K
-                ]
-                stride = [
-                    Piecewise((1, ~MMA_ACC), (16, MMA_ACC)),  # M
-                    1,  # N
-                    16,  # K
-                ]
-            case MMAType.F32_32x32x16_K4_F8:
-                offset = [
-                    Piecewise(
-                        (lane % 32, ~MMA_ACC),
-                        (
-                            (8 * floor(GPR_NUM / 4) % 32)
-                            + 4 * floor(lane / 32)
-                            + (GPR_NUM % 4),
-                            MMA_ACC,
-                        ),
-                    ),  # M
-                    lane % 32,  # N
-                    (8 * floor(GPR_NUM / 4))
-                    + 4 * floor(lane / 32)
-                    + (GPR_NUM % 4),  # K
-                ]
-                size = [
-                    Piecewise((1, ~MMA_ACC), (16, MMA_ACC)),  # M
-                    1,  # N
-                    8,  # K
-                ]
-                stride = [
-                    Piecewise((1, ~MMA_ACC), (32, MMA_ACC)),  # M
-                    1,  # N
-                    1,  # K
-                ]
+                if self.mma_type == MMAType.F32_32x32x16_K4_F8:
+                    offset = [
+                        Piecewise(
+                            (lane % 32, ~MMA_ACC),
+                            (
+                                (8 * floor(GPR_NUM / 4) % 32)
+                                + 4 * floor(lane / 32)
+                                + (GPR_NUM % 4),
+                                MMA_ACC,
+                            ),
+                        ),  # M
+                        lane % 32,  # N
+                        (8 * floor(GPR_NUM / 4))
+                        + 4 * floor(lane / 32)
+                        + (GPR_NUM % 4),  # K
+                    ]
             case _:
                 raise ValueError("Unsupported MMA type")
         assert isinstance(
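As a sanity check on the K4 index expressions above, a small standalone snippet (not part of the commit) that evaluates the F32_16x16x32_K4_F8 K offset for one arbitrary lane and register element; the lane and GPR_NUM values are made up:

    from math import floor

    lane, GPR_NUM = 20, 5  # sample values: thread lane id and register element index

    # K offset from the F32_16x16x32_K4_F8 arm above.
    k_offset = 16 * floor(GPR_NUM / 4) + 4 * floor(lane / 16) + (GPR_NUM % 4)
    assert k_offset == 21  # 16 * 1 + 4 * 1 + 1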
iree/turbine/kernel/wave/index_sequence_analysis.py (20 changes: 10 additions & 10 deletions)
@@ -84,6 +84,7 @@ def has_strided_access(node: fx.Node) -> bool:
            dim: simplify_index(custom.index.get(dim, custom.index[dim]))
            for dim in custom.index
        }
+
        shape = get_vector_shape(
            custom.vector_shapes, custom.register_type.symbolic_shape
        )
@@ -129,7 +130,7 @@ def partition_ops_with_gpr_offsets(trace: CapturedTrace, constraints: list[Const
     e.g a vector<16xf16> may be owned by lane 0, and lane 16 in this layout:
     [0, 0, 0, 0, 16, 16, 16, 16, 0, 0, 0, 0, 16, 16, 16, 16].
-    With our current gloassary, this means we have 2 VGPR "chunks".
+    With our current glossary, this means we have 2 VGPR "chunks".
     [0:4) and [8:12) for lane0, and [4:8) and [12:16) for lane16.
     To the lane it should just look like vector<8xf16>.
     Hence for this example, we'd need two reads of vector<4xf16> and a couple
@@ -145,15 +146,14 @@ def has_gpr_offsets(node: fx.Node) -> bool:
        custom = get_custom(node)
        if not isinstance(custom, (Read, Write)):
            return False
-        dims_with_gpr_offset = [
-            v.start for k, v in custom.index.items() if v.start.has(GPR_NUM)
-        ]
-        if not dims_with_gpr_offset:
+        num_dims_with_gpr = sum(
+            1 for v in custom.index.values() if v.start.has(GPR_NUM)
+        )
+        if num_dims_with_gpr == 1:
+            return True
+        elif num_dims_with_gpr == 0:
            return False
-        num_dims_with_gpr_offsets = len(dims_with_gpr_offset)
-        if num_dims_with_gpr_offsets > 1:
-            raise NotImplementedError("Currently only handle 1 dim with gpr offset.")
-        return True
+        raise NotImplementedError("Currently only handles 1 dim with GPR offset.")

    strided_operators = trace.walk(has_gpr_offsets)
    for operator in strided_operators:
@@ -164,7 +164,7 @@ def has_gpr_offsets(node: fx.Node) -> bool:
        }
        elements_per_thread = subs_idxc(custom.elements_per_thread)
        gpr_offsets = [
-            v.start for k, v in simplified_index.items() if v.start.has(GPR_NUM)
+            v.start for v in simplified_index.values() if v.start.has(GPR_NUM)
        ]
        assert len(gpr_offsets) == 1, "Expected only 1-Dim has gpr offsets"
        gpr_offset_expr = gpr_offsets[0]
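The chunk layout described in the docstring above can be reproduced mechanically. A standalone sketch (not part of the commit) that recovers each lane's VGPR chunks from the example layout, where contiguous runs owned by one lane form a chunk:

    # Owner lane of each element of the vector<16xf16> from the docstring.
    layout = [0, 0, 0, 0, 16, 16, 16, 16, 0, 0, 0, 0, 16, 16, 16, 16]

    chunks = {}  # lane id -> list of half-open (start, end) element ranges
    start = 0
    for i in range(1, len(layout) + 1):
        if i == len(layout) or layout[i] != layout[start]:
            chunks.setdefault(layout[start], []).append((start, i))
            start = i

    # Matches the docstring: lane 0 owns [0:4) and [8:12), lane 16 owns
    # [4:8) and [12:16), i.e. two vector<4xf16> reads per lane.
    assert chunks == {0: [(0, 4), (8, 12)], 16: [(4, 8), (12, 16)]}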
iree/turbine/kernel/wave/utils.py (16 changes: 4 additions & 12 deletions)
@@ -833,13 +833,9 @@ def get_mfma_load_elems_per_thread(mfma_variant: MMAType) -> int:
            return 4
        case MMAType.F32_32x32x8_F16:
            return 4
-        case MMAType.F32_16x16x32_F8:
+        case MMAType.F32_16x16x32_F8 | MMAType.F32_16x16x32_K4_F8:
            return 8
-        case MMAType.F32_32x32x16_F8:
-            return 8
-        case MMAType.F32_16x16x32_K4_F8:
-            return 8
-        case MMAType.F32_32x32x16_K4_F8:
+        case MMAType.F32_32x32x16_F8 | MMAType.F32_32x32x16_K4_F8:
            return 8


@@ -849,13 +845,9 @@ def get_mfma_store_elems_per_thread(mfma_variant: MMAType) -> int:
            return 4
        case MMAType.F32_32x32x8_F16:
            return 16
-        case MMAType.F32_16x16x32_F8:
-            return 4
-        case MMAType.F32_32x32x16_F8:
-            return 16
-        case MMAType.F32_16x16x32_K4_F8:
+        case MMAType.F32_16x16x32_F8 | MMAType.F32_16x16x32_K4_F8:
            return 4
-        case MMAType.F32_32x32x16_K4_F8:
+        case MMAType.F32_32x32x16_F8 | MMAType.F32_32x32x16_K4_F8:
            return 16


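The consolidation in this file relies on Python 3.10 match-statement or-patterns: a single case arm joined with | matches any of the listed variants, which is what lets the four F8 arms collapse into two. A minimal standalone illustration with a stand-in enum (names assumed, not iree.turbine's actual definitions):

    from enum import Enum, auto

    class MMAType(Enum):  # stand-in for iree.turbine's MMAType
        F32_16x16x32_F8 = auto()
        F32_16x16x32_K4_F8 = auto()

    def load_elems_per_thread(v: MMAType) -> int:
        match v:
            case MMAType.F32_16x16x32_F8 | MMAType.F32_16x16x32_K4_F8:
                return 8  # both variants load 8 elements per thread
        raise ValueError("Unsupported MMA type")

    assert load_elems_per_thread(MMAType.F32_16x16x32_K4_F8) == 8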
