forked from exo-explore/exo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
partitioning_strategy.py
40 lines (30 loc) · 1.17 KB
/
partitioning_strategy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from abc import ABC, abstractmethod
from typing import List
from dataclasses import dataclass
from .topology import Topology
from exo.inference.shard import Shard
@dataclass
class Partition:
    """A contiguous slice of shard-space assigned to one node.

    Shard-space is the unit interval. A partition covers the half-open range
    ``[start, end)`` of it, with both bounds expressed as floating point
    values between 0 and 1.
    """

    # Identifier of the node this slice belongs to.
    node_id: str
    # Inclusive lower bound of the slice, as a fraction of shard-space.
    start: float
    # Exclusive upper bound of the slice, as a fraction of shard-space.
    end: float
class PartitioningStrategy(ABC):
    """Abstract policy that divides shard-space among the nodes of a topology.

    Concrete strategies subclass this and implement :meth:`partition`.
    The class cannot be instantiated until that method is overridden.
    """

    @abstractmethod
    def partition(self, topology: Topology) -> List[Partition]:
        """Return the list of ``Partition`` objects computed for ``topology``.

        Subclasses must override this; the base implementation does nothing.
        """
def _layer_boundary(fraction: float, num_layers: int) -> int:
    """Convert a fractional position in shard-space to a layer boundary index.

    The product ``fraction * num_layers`` is truncated with ``int`` as before.
    It is first rounded to 9 decimal places so that floating-point
    representation error cannot drag a boundary that is mathematically an
    integer down by one. For example, ``0.58 * 100`` evaluates to
    ``57.99999999999999``, which would otherwise truncate to 57 instead of 58
    and hand a layer to the wrong partition.
    """
    return int(round(fraction * num_layers, 9))


def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
    """Translate fractional partitions of shard-space into inclusive layer ranges.

    A partition ``[start, end)`` is mapped to layers
    ``int(start * num_layers)`` through ``int(end * num_layers) - 1``
    (inclusive), computed with tolerance for floating-point noise. The last
    partition is always extended to ``num_layers - 1``.

    NOTE(review): assumes ``partitions`` are ordered and contiguous (each
    ``start`` equals the previous ``end``); this function does not check it.

    Args:
        partitions: Fractional slices of shard-space, one per node.
        num_layers: Total number of layers in the model.
        model_id: Model identifier stored on every produced shard.

    Returns:
        One ``Shard`` per partition that covers at least one layer. Partitions
        that map to zero layers are skipped, so the result can be shorter than
        ``partitions`` and its indices need not line up with them.
    """
    shards = []
    last_index = len(partitions) - 1
    for i, partition in enumerate(partitions):
        start_layer = _layer_boundary(partition.start, num_layers)
        end_layer = _layer_boundary(partition.end, num_layers) - 1

        # Ensure the last partition covers up to num_layers - 1, even if its
        # end bound falls short of 1.0.
        if i == last_index:
            end_layer = num_layers - 1

        # Ensure no empty shards: a partition too narrow to own a whole layer
        # is dropped instead of producing start_layer > end_layer.
        if start_layer <= end_layer:
            shards.append(Shard(model_id, start_layer, end_layer, num_layers))

    # Ensure full coverage: if the final partition was dropped as empty,
    # stretch the last emitted shard so the final layer is still assigned.
    if shards and shards[-1].end_layer < num_layers - 1:
        shards[-1] = Shard(model_id, shards[-1].start_layer, num_layers - 1, num_layers)

    return shards