Commit

Merge branch 'main' into ean-tm-pin
monorimet committed Jun 24, 2023
2 parents 64a0b35 + 8cdb384 commit cdf6bff
Showing 21 changed files with 933 additions and 94 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -159,7 +159,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

# vscode related
.vscode
32 changes: 22 additions & 10 deletions apps/language_models/scripts/vicuna.py
Expand Up @@ -3,6 +3,10 @@
from apps.language_models.src.pipelines import vicuna_pipeline as vp
from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp
import torch
import json

if __name__ == "__main__":
import gc


parser = argparse.ArgumentParser(
@@ -55,35 +59,38 @@
help="Run model in cli mode",
)

parser.add_argument(
"--config",
default=None,
help="configuration file",
)

if __name__ == "__main__":
args, unknown = parser.parse_known_args()

vic = None
if not args.sharded:
first_vic_mlir_path = (
Path(f"first_vicuna_{args.precision}.mlir")
None
if args.first_vicuna_mlir_path is None
else Path(args.first_vicuna_mlir_path)
)
second_vic_mlir_path = (
Path(f"second_vicuna_{args.precision}.mlir")
None
if args.second_vicuna_mlir_path is None
else Path(args.second_vicuna_mlir_path)
)
first_vic_vmfb_path = (
Path(
f"first_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
)
None
if args.first_vicuna_vmfb_path is None
else Path(args.first_vicuna_vmfb_path)
)
second_vic_vmfb_path = (
Path(
f"second_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
)
None
if args.second_vicuna_vmfb_path is None
else Path(args.second_vicuna_vmfb_path)
)

vic = vp.Vicuna(
"vicuna",
device=args.device,
@@ -95,16 +102,21 @@
load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
)
else:
if args.config is not None:
config_file = open(args.config)
config_json = json.load(config_file)
config_file.close()
else:
config_json = None
vic = vsp.Vicuna(
"vicuna",
device=args.device,
precision=args.precision,
config_json=config_json,
)
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
prologue_prompt = "ASSISTANT:\n"

import gc

while True:
# TODO: Add break condition from user input
user_prompt = input("User: ")
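A note on the vicuna.py changes above: the new --config flag lets the sharded path read a JSON configuration and hand it to the pipeline. The sketch below restates that flow in isolation, assuming only what the diff shows; the expected keys of the JSON file are not defined in this hunk, and the file name used here is a placeholder.

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--config", default=None, help="configuration file")
args, unknown = parser.parse_known_args()

# e.g. python vicuna.py --config shard_config.json   (invocation and file name are illustrative)
if args.config is not None:
    with open(args.config) as config_file:  # the diff uses open()/json.load()/close(); a context manager is equivalent
        config_json = json.load(config_file)
else:
    config_json = None

# config_json is then forwarded to vicuna_sharded_pipeline.Vicuna(..., config_json=config_json)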
74 changes: 73 additions & 1 deletion apps/language_models/src/model_wrappers/vicuna_sharded_model.py
@@ -145,7 +145,7 @@ def forward(


class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers0, layers1):
def __init__(self, model, layers0, layers1, lmhead, embedding, norm):
super().__init__()
self.model = model
assert len(layers0) == len(model.model.layers)
@@ -154,6 +154,12 @@ def __init__(self, model, layers0, layers1):
self.model.model.config.output_attentions = False
self.layers0 = layers0
self.layers1 = layers1
self.norm = norm
self.embedding = embedding
self.lmhead = lmhead
self.model.model.norm = self.norm
self.model.model.embed_tokens = self.embedding
self.model.lm_head = self.lmhead

def forward(
self,
@@ -176,3 +182,69 @@ def forward(
attention_mask=attention_mask,
past_key_values=past_key_values,
)


class LMHead(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, hidden_states):
output = self.model(hidden_states)
return output


class LMHeadCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module

def forward(self, hidden_states):
hidden_states = hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output


class VicunaNorm(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, hidden_states):
output = self.model(hidden_states)
return output


class VicunaNormCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module

def forward(self, hidden_states):
hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output


class VicunaEmbedding(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, input_ids):
output = self.model(input_ids)
return output


class VicunaEmbeddingCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module

def forward(self, input_ids):
input_ids.detach()
output = self.model("forward", (input_ids,))
output = torch.tensor(output)
return output
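The wrapper classes added above come in eager/compiled pairs with identical forward() signatures, so ShardedVicunaModel can accept either flavor for the LM head, embedding, and final norm. A rough sketch of how the pieces are meant to fit together; hf_model, shark_lmhead, shark_norm, and shark_embedding are hypothetical stand-ins rather than names from this commit, and layers0/layers1 are the per-shard layer lists the constructor asserts against model.model.layers.

# Eager wrappers around the original torch submodules:
lmhead = LMHead(hf_model.lm_head)
norm = VicunaNorm(hf_model.model.norm)
embedding = VicunaEmbedding(hf_model.model.embed_tokens)

# ...or, once those submodules are compiled with SHARK, the drop-in replacements,
# which invoke the compiled artifact as shark_module("forward", (tensor,)):
lmhead = LMHeadCompiled(shark_lmhead)
norm = VicunaNormCompiled(shark_norm)
embedding = VicunaEmbeddingCompiled(shark_embedding)

# ShardedVicunaModel re-attaches whichever versions it is handed:
#   self.model.model.norm = norm
#   self.model.model.embed_tokens = embedding
#   self.model.lm_head = lmhead
sharded = ShardedVicunaModel(hf_model, layers0, layers1, lmhead, embedding, norm)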
39 changes: 28 additions & 11 deletions apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -33,16 +33,23 @@ def __init__(
first_vicuna_vmfb_path=None,
second_vicuna_vmfb_path=None,
load_mlir_from_shark_tank=True,
low_device_memory=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
if precision in ["int4", "int8"]:
print("int4 and int8 are not supported yet, using fp32")
precision = "fp32"
self.precision = precision
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
self.first_vicuna_mlir_path = first_vicuna_mlir_path
self.second_vicuna_mlir_path = second_vicuna_mlir_path
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
self.low_device_memory = low_device_memory
self.first_vic = None
self.second_vic = None
if self.first_vicuna_mlir_path == None:
self.first_vicuna_mlir_path = self.get_model_path()
if self.second_vicuna_mlir_path == None:
@@ -61,7 +68,7 @@ def get_model_path(self, model_number="first", suffix="mlir"):
if suffix == "mlir":
return Path(f"{model_number}_vicuna_{self.precision}.{suffix}")
return Path(
f"{model_number}_vicuna_{safe_device}_{self.precision}.{suffix}"
f"{model_number}_vicuna_{self.precision}_{safe_device}.{suffix}"
)

def get_tokenizer(self):
@@ -87,7 +94,7 @@ def compile_first_vicuna(self):
# Compilation path needs some more work before it is functional

print(
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with"
f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with\n"
f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}"
)
if self.first_vicuna_mlir_path.exists():
@@ -436,12 +443,19 @@ def generate(self, prompt, cli=False):
# TODO: refactor for cleaner integration
import gc

if not self.low_device_memory:
if self.first_vic == None:
self.first_vic = self.compile_first_vicuna()
if self.second_vic == None:
self.second_vic = self.compile_second_vicuna()
res = []
res_tokens = []
params = {
"prompt": prompt,
"is_first": True,
"fv": self.compile_first_vicuna(),
"fv": self.compile_first_vicuna()
if self.first_vic == None
else self.first_vic,
}

generated_token_op = self.generate_new_token(params=params)
@@ -457,18 +471,20 @@
print(f"Assistant: {detok}", end=" ", flush=True)

# Clear First Vic from Memory (main and cuda)
del params
torch.cuda.empty_cache()
gc.collect()
if self.low_device_memory:
del params
torch.cuda.empty_cache()
gc.collect()

sec_vic = self.compile_second_vicuna()
for _ in range(self.max_num_tokens - 2):
params = {
"prompt": None,
"is_first": False,
"logits": logits,
"pkv": pkv,
"sv": sec_vic,
"sv": self.compile_second_vicuna()
if self.second_vic == None
else self.second_vic,
}

generated_token_op = self.generate_new_token(params=params)
@@ -489,9 +505,10 @@
res.append(detok)
if cli:
print(f"{detok}", end=" ", flush=True)
del sec_vic, pkv, logits
torch.cuda.empty_cache()
gc.collect()
if self.device == "cuda":
del sec_vic, pkv, logits
torch.cuda.empty_cache()
gc.collect()

for i in range(len(res_tokens)):
if type(res_tokens[i]) != int:
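The vicuna_pipeline.py changes above add a compile-once cache (self.first_vic / self.second_vic) and gate the aggressive cleanup behind the new low_device_memory flag, with the final CUDA teardown now guarded by a device check. Below is a hypothetical usage sketch that uses only parameters visible in this diff; other constructor defaults (e.g. hf_model_path, max_num_tokens) and the return handling of generate() fall outside the shown hunks.

from apps.language_models.src.pipelines import vicuna_pipeline as vp

vic = vp.Vicuna(
    "vicuna",
    device="cuda",            # placeholder device string
    precision="fp32",         # per the diff, int4/int8 currently fall back to fp32
    low_device_memory=False,  # False: keep both compiled Vicuna halves cached across generate() calls
)

# With low_device_memory=True the compiled modules are not cached and per-prompt
# state is freed more aggressively instead of being kept resident.
vic.generate("What does model sharding mean?", cli=True)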
