Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added ability to use config file to shard vicuna #1565

Merged
merged 1 commit into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions apps/language_models/scripts/vicuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from apps.language_models.src.pipelines import vicuna_pipeline as vp
from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp
import torch
import json

if __name__ == "__main__":
import gc
dan-garvey marked this conversation as resolved.
Show resolved Hide resolved


parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -55,6 +59,12 @@
help="Run model in cli mode",
)

parser.add_argument(
"--config",
default=None,
help="configuration file",
)

dan-garvey marked this conversation as resolved.
Show resolved Hide resolved
if __name__ == "__main__":
args, unknown = parser.parse_known_args()

Expand Down Expand Up @@ -84,6 +94,7 @@
if args.second_vicuna_vmfb_path is None
else Path(args.second_vicuna_vmfb_path)
)

vic = vp.Vicuna(
"vicuna",
device=args.device,
Expand All @@ -95,16 +106,21 @@
load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
)
else:
if args.config is not None:
config_file = open(args.config)
config_json = json.load(config_file)
config_file.close()
else:
config_json = None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load config if given

vic = vsp.Vicuna(
"vicuna",
device=args.device,
precision=args.precision,
config_json=config_json,
)
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
prologue_prompt = "ASSISTANT:\n"

import gc

while True:
# TODO: Add break condition from user input
user_prompt = input("User: ")
Expand Down
74 changes: 73 additions & 1 deletion apps/language_models/src/model_wrappers/vicuna_sharded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def forward(


class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers0, layers1):
def __init__(self, model, layers0, layers1, lmhead, embedding, norm):
super().__init__()
self.model = model
assert len(layers0) == len(model.model.layers)
Expand All @@ -154,6 +154,12 @@ def __init__(self, model, layers0, layers1):
self.model.model.config.output_attentions = False
self.layers0 = layers0
self.layers1 = layers1
self.norm = norm
self.embedding = embedding
self.lmhead = lmhead
self.model.model.norm = self.norm
self.model.model.embed_tokens = self.embedding
self.model.lm_head = self.lmhead

def forward(
self,
Expand All @@ -176,3 +182,69 @@ def forward(
attention_mask=attention_mask,
past_key_values=past_key_values,
)


class LMHead(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, hidden_states):
output = self.model(hidden_states)
return output


class LMHeadCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module

def forward(self, hidden_states):
hidden_states = hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output


class VicunaNorm(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, hidden_states):
output = self.model(hidden_states)
return output


class VicunaNormCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module

def forward(self, hidden_states):
hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output


class VicunaEmbedding(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, input_ids):
output = self.model(input_ids)
return output


class VicunaEmbeddingCompiled(torch.nn.Module):
def __init__(self, shark_module):
super().__init__()
self.model = shark_module

def forward(self, input_ids):
input_ids.detach()
output = self.model("forward", (input_ids,))
output = torch.tensor(output)
return output
Loading