added ability to use config file to shard vicuna
Elias Joseph committed Jun 21, 2023
1 parent d61b664 commit 43a6a4a
Showing 2 changed files with 51 additions and 6 deletions.
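The new --config option points the script at a JSON file describing how to place the model's layers across devices. The commit does not document the schema, but from the way get_device_index consumes it in the second file below (keys matched against qualified layer names via re.search, values carrying a "gpu" index), a config could plausibly be generated as in this sketch; the layer names, module suffixes, and filename are illustrative assumptions, not taken from the commit.

```python
# Sketch of a sharding config consistent with how get_device_index reads it:
# each key is a qualified layer name, each value names the GPU to vote for.
# Layer names, suffixes, and the filename are illustrative assumptions.
import json

config = {}
for i in range(4):
    # Alternate layers between GPU 0 and GPU 1, for both vicuna stages.
    for stage in ("first_vicuna", "second_vicuna"):
        config[f"{stage}.model.model.layers.{i}.self_attn"] = {"gpu": i % 2}
        config[f"{stage}.model.model.layers.{i}.mlp"] = {"gpu": i % 2}

with open("shard_config.json", "w") as f:
    json.dump(config, f, indent=2)
```

The script would then be invoked with --config shard_config.json.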
21 changes: 19 additions & 2 deletions apps/language_models/scripts/vicuna.py
@@ -3,6 +3,10 @@
 from apps.language_models.src.pipelines import vicuna_pipeline as vp
 from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp
 import torch
+import json
+
+if __name__ == "__main__":
+    import gc
 
 
 parser = argparse.ArgumentParser(
@@ -55,6 +59,12 @@
     help="Run model in cli mode",
 )
 
+parser.add_argument(
+    "--config",
+    default=None,
+    help="configuration file",
+)
+
 if __name__ == "__main__":
     args, unknown = parser.parse_known_args()

@@ -84,6 +94,7 @@
             if args.second_vicuna_vmfb_path is None
             else Path(args.second_vicuna_vmfb_path)
         )
+
         vic = vp.Vicuna(
             "vicuna",
             device=args.device,
@@ -95,16 +106,22 @@
             load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
         )
     else:
+        if args.config is not None:
+            config_file = open(args.config)
+            config_json = json.load(config_file)
+            config_file.close()
+        else:
+            print("No Json Found")
+            config_json = None
         vic = vsp.Vicuna(
             "vicuna",
             device=args.device,
             precision=args.precision,
+            config_json=config_json,
         )
     prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
     prologue_prompt = "ASSISTANT:\n"
 
-    import gc
-
     while True:
         # TODO: Add break condition from user input
         user_prompt = input("User: ")
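For reference, the new sharded path can also be exercised directly; a minimal sketch, assuming the hypothetical shard_config.json from the earlier example and the constructor defaults that vicuna.py already relies on (hf_model_path and max_num_tokens are left at their defaults):

```python
# Minimal sketch of driving the new code path programmatically,
# mirroring the else-branch of vicuna.py above. "shard_config.json"
# is the hypothetical file from the earlier example.
import json

from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp

with open("shard_config.json") as f:
    config_json = json.load(f)

vic = vsp.Vicuna(
    "vicuna",
    device="cuda",
    precision="fp32",
    config_json=config_json,  # None makes get_device_index return None for every layer
)
```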
36 changes: 32 additions & 4 deletions apps/language_models/src/pipelines/vicuna_sharded_pipeline.py
@@ -19,6 +19,7 @@
 import torch
 import torch_mlir
 import os
+import json
 
 
 class Vicuna(SharkLLMBase):
@@ -29,12 +30,14 @@ def __init__(
         max_num_tokens=512,
         device="cuda",
         precision="fp32",
+        config_json=None,
     ) -> None:
         super().__init__(model_name, hf_model_path, max_num_tokens)
         self.max_sequence_length = 256
         self.device = device
         self.precision = precision
         self.tokenizer = self.get_tokenizer()
+        self.config = config_json
         self.shark_model = self.compile(device=device)
 
     def get_tokenizer(self):
@@ -220,24 +223,43 @@ def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True):
                 f_.close()
                 mlirs.append(bytecode)
 
+        def get_device_index(layer_string):
+            if self.config is None:
+                return None
+            idx_votes = {}
+            for key in self.config.keys():
+                if re.search(layer_string, key):
+                    if int(self.config[key]["gpu"]) in idx_votes.keys():
+                        idx_votes[int(self.config[key]["gpu"])] += 1
+                    else:
+                        idx_votes[int(self.config[key]["gpu"])] = 1
+            device_idx = max(idx_votes, key=idx_votes.get)
+            return device_idx
+
         for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
             if is_first:
                 vmfb_path = Path(f"{idx}_0.vmfb")
                 if vmfb_path.exists():
                     # print(f"Found layer {idx} vmfb")
+                    device_idx = get_device_index(
+                        f"first_vicuna.model.model.layers.{idx}[\s.]"
+                    )
                     module = SharkInference(
                         None,
                         device=device,
-                        device_idx=idx % 1,
+                        device_idx=device_idx,
                         mlir_dialect="tm_tensor",
                     )
                     module.load_module(vmfb_path)
                 else:
                     print(f"Compiling layer {idx} vmfb")
+                    device_idx = get_device_index(
+                        f"first_vicuna.model.model.layers.{idx}[\s.]"
+                    )
                     module = SharkInference(
                         mlirs[idx],
                         device=device,
-                        device_idx=idx % 1,
+                        device_idx=device_idx,
                         mlir_dialect="tm_tensor",
                     )
                     module.save_module(
@@ -255,19 +277,25 @@ def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True):
                 vmfb_path = Path(f"{idx}_1.vmfb")
                 if vmfb_path.exists():
                     # print(f"Found layer {idx} vmfb")
+                    device_idx = get_device_index(
+                        f"second_vicuna.model.model.layers.{idx}[\s.]"
+                    )
                     module = SharkInference(
                         None,
                         device=device,
-                        device_idx=idx % 1,
+                        device_idx=device_idx,
                         mlir_dialect="tm_tensor",
                     )
                     module.load_module(vmfb_path)
                 else:
                     print(f"Compiling layer {idx} vmfb")
+                    device_idx = get_device_index(
+                        f"second_vicuna.model.model.layers.{idx}[\s.]"
+                    )
                     module = SharkInference(
                         mlirs[idx],
                         device=device,
-                        device_idx=idx % 1,
+                        device_idx=device_idx,
                         mlir_dialect="tm_tensor",
                     )
                     module.save_module(
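The core of the change is get_device_index: each config entry whose key matches a layer's qualified name casts a vote for its "gpu" index, and the layer is compiled for the device with the most votes (replacing the hard-coded device_idx=idx % 1, which always evaluated to 0). The self-contained sketch below mirrors that voting logic, with two deliberate deviations flagged in comments: dots in the pattern are escaped, and the no-match case returns None instead of letting max() raise on an empty dict. It also shows why the trailing [\s.] matters: without it, the pattern for layer 1 would also match layers 10, 11, and so on.

```python
# Standalone sketch of the majority-vote device selection used above.
import re
from collections import Counter

def get_device_index(config, layer_string):
    """Majority vote over the "gpu" fields of config entries matching layer_string."""
    if config is None:
        return None
    votes = Counter(
        int(entry["gpu"])
        for key, entry in config.items()
        if re.search(layer_string, key)
    )
    # Deviation from the committed code: return None when nothing matches
    # rather than calling max() on an empty dict.
    return votes.most_common(1)[0][0] if votes else None

config = {
    "first_vicuna.model.model.layers.1.self_attn": {"gpu": "0"},
    "first_vicuna.model.model.layers.1.mlp": {"gpu": "1"},
    "first_vicuna.model.model.layers.1.input_layernorm": {"gpu": "1"},
    "first_vicuna.model.model.layers.10.self_attn": {"gpu": "0"},
}

# The trailing [\s.] stops "layers.1" from also matching "layers.10":
# layer 1 collects votes {0: 1, 1: 2} and is assigned device 1.
# (Dots are escaped here; the committed pattern leaves them as wildcards.)
print(get_device_index(config, r"first_vicuna\.model\.model\.layers\.1[\s.]"))  # -> 1
```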
