Skip to content

Commit

Permalink
added ability to use config file to shard vicuna (#1565)
Browse files Browse the repository at this point in the history
Co-authored-by: Elias Joseph <[email protected]>
  • Loading branch information
Eliasj42 and Elias Joseph authored Jun 22, 2023
1 parent 0ca3b9f commit 8822b9a
Show file tree
Hide file tree
Showing 3 changed files with 311 additions and 13 deletions.
20 changes: 18 additions & 2 deletions apps/language_models/scripts/vicuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from apps.language_models.src.pipelines import vicuna_pipeline as vp
from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp
import torch
import json

if __name__ == "__main__":
import gc


parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -55,6 +59,12 @@
help="Run model in cli mode",
)

parser.add_argument(
"--config",
default=None,
help="configuration file",
)

if __name__ == "__main__":
args, unknown = parser.parse_known_args()

Expand All @@ -80,6 +90,7 @@
if args.second_vicuna_vmfb_path is None
else Path(args.second_vicuna_vmfb_path)
)

vic = vp.Vicuna(
"vicuna",
device=args.device,
Expand All @@ -91,16 +102,21 @@
load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
)
else:
# Load the optional sharding configuration given via --config.
# A context manager guarantees the file handle is released even when
# json.load raises on malformed input (the manual open/close did not).
if args.config is not None:
    with open(args.config) as config_file:
        config_json = json.load(config_file)
else:
    # No configuration supplied; None is passed on unchanged.
    config_json = None
vic = vsp.Vicuna(
"vicuna",
device=args.device,
precision=args.precision,
config_json=config_json,
)
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
prologue_prompt = "ASSISTANT:\n"

import gc

while True:
# TODO: Add break condition from user input
user_prompt = input("User: ")
Expand Down
74 changes: 73 additions & 1 deletion apps/language_models/src/model_wrappers/vicuna_sharded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def forward(


class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers0, layers1):
def __init__(self, model, layers0, layers1, lmhead, embedding, norm):
super().__init__()
self.model = model
assert len(layers0) == len(model.model.layers)
Expand All @@ -154,6 +154,12 @@ def __init__(self, model, layers0, layers1):
self.model.model.config.output_attentions = False
self.layers0 = layers0
self.layers1 = layers1
self.norm = norm
self.embedding = embedding
self.lmhead = lmhead
self.model.model.norm = self.norm
self.model.model.embed_tokens = self.embedding
self.model.lm_head = self.lmhead

def forward(
self,
Expand All @@ -176,3 +182,69 @@ def forward(
attention_mask=attention_mask,
past_key_values=past_key_values,
)


class LMHead(torch.nn.Module):
    """Pass-through module that delegates ``forward`` to a wrapped callable.

    NOTE(review): presumably wraps the language-model head of the Vicuna
    model; the code itself only requires ``model`` to be callable.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, hidden_states):
        """Return ``self.model(hidden_states)`` unchanged."""
        return self.model(hidden_states)


class LMHeadCompiled(torch.nn.Module):
    """Adapter that runs a compiled module behind a ``torch.nn.Module`` API.

    ``shark_module`` is invoked as ``shark_module("forward", (tensor,))``.
    NOTE(review): presumably a SHARK-compiled LM head -- confirm with caller.
    """

    def __init__(self, shark_module):
        super().__init__()
        self.model = shark_module

    def forward(self, hidden_states):
        """Detach the input, run the compiled "forward", return a tensor."""
        detached = hidden_states.detach()
        result = self.model("forward", (detached,))
        # The compiled module's result is converted back into a torch tensor.
        return torch.tensor(result)


class VicunaNorm(torch.nn.Module):
    """Pass-through module that delegates ``forward`` to a wrapped callable.

    NOTE(review): presumably wraps the model's final normalization layer;
    the code itself only requires ``model`` to be callable.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, hidden_states):
        """Return ``self.model(hidden_states)`` unchanged."""
        return self.model(hidden_states)


class VicunaNormCompiled(torch.nn.Module):
    """Adapter that runs a compiled module behind a ``torch.nn.Module`` API.

    ``shark_module`` is invoked as ``shark_module("forward", (tensor,))``.
    NOTE(review): presumably a SHARK-compiled norm layer -- confirm with caller.
    """

    def __init__(self, shark_module):
        super().__init__()
        self.model = shark_module

    def forward(self, hidden_states):
        """Detach the input, run the compiled "forward", return a tensor."""
        # Tensor.detach() is not in-place: it returns a new tensor, so the
        # result must be rebound. The original discarded it and passed the
        # still-attached tensor to the compiled module.
        hidden_states = hidden_states.detach()
        output = self.model("forward", (hidden_states,))
        # The compiled module's result is converted back into a torch tensor.
        return torch.tensor(output)


class VicunaEmbedding(torch.nn.Module):
    """Pass-through module that delegates ``forward`` to a wrapped callable.

    NOTE(review): presumably wraps the model's token-embedding layer;
    the code itself only requires ``model`` to be callable.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        """Return ``self.model(input_ids)`` unchanged."""
        return self.model(input_ids)


class VicunaEmbeddingCompiled(torch.nn.Module):
    """Adapter that runs a compiled module behind a ``torch.nn.Module`` API.

    ``shark_module`` is invoked as ``shark_module("forward", (tensor,))``.
    NOTE(review): presumably a SHARK-compiled embedding -- confirm with caller.
    """

    def __init__(self, shark_module):
        super().__init__()
        self.model = shark_module

    def forward(self, input_ids):
        """Detach the input, run the compiled "forward", return a tensor."""
        # Tensor.detach() is not in-place: it returns a new tensor, so the
        # result must be rebound. The original discarded it, making the
        # call a no-op.
        input_ids = input_ids.detach()
        output = self.model("forward", (input_ids,))
        # The compiled module's result is converted back into a torch tensor.
        return torch.tensor(output)
Loading

0 comments on commit 8822b9a

Please sign in to comment.