
Commit

changed method of compiling vicuna to remove first and second vicuna (#1611)

Co-authored-by: Elias Joseph <[email protected]>
Co-authored-by: powderluv <[email protected]>
3 people committed Jul 3, 2023
1 parent d63ce76 commit 4015793
Showing 2 changed files with 318 additions and 266 deletions.
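The diff below folds the two per-pass wrappers (CompiledFirstVicunaLayer, CompiledSecondVicunaLayer) into a single CompiledVicunaLayer that picks its compiled entry point based on whether a past_key_value is supplied, and ShardedVicunaModel now takes one layer list instead of two. As a minimal before/after wiring sketch, assuming the compiled SHARK layer modules, base model, and head/embedding/norm wrappers come from the existing pipeline (the helper name and its arguments are illustrative, not part of this commit):

# Illustrative sketch only; the argument objects are assumed to come from the
# existing SHARK compilation pipeline and are not defined in this commit.
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
    CompiledVicunaLayer,
    ShardedVicunaModel,
)


def build_sharded_vicuna(base_model, shark_layer_modules, lm_head, embedding, norm):
    # Before this commit, callers built two parallel lists:
    #   layers0 = [CompiledFirstVicunaLayer(m) for m in shark_layer_modules]
    #   layers1 = [CompiledSecondVicunaLayer(m) for m in shark_layer_modules]
    #   return ShardedVicunaModel(base_model, layers0, layers1, lm_head, embedding, norm)
    # After, a single wrapper handles both passes:
    layers = [CompiledVicunaLayer(m) for m in shark_layer_modules]
    return ShardedVicunaModel(base_model, layers, lm_head, embedding, norm)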
180 changes: 79 additions & 101 deletions apps/language_models/src/model_wrappers/vicuna_sharded_model.py
@@ -62,104 +62,23 @@ def forward(
        )


class CompiledFirstVicunaLayer(torch.nn.Module):
    def __init__(self, shark_module):
        super().__init__()
        self.model = shark_module

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value=None,
        output_attentions=False,
        use_cache=True,
    ):
        hidden_states = hidden_states.detach()
        attention_mask = attention_mask.detach()
        position_ids = position_ids.detach()
        output = self.model(
            "forward",
            (
                hidden_states,
                attention_mask,
                position_ids,
            ),
        )

        output0 = torch.tensor(output[0])
        output1 = torch.tensor(output[1])
        output2 = torch.tensor(output[2])

        return (
            output0,
            (
                output1,
                output2,
            ),
        )


class CompiledSecondVicunaLayer(torch.nn.Module):
    def __init__(self, shark_module):
        super().__init__()
        self.model = shark_module

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions=False,
        use_cache=True,
    ):
        hidden_states = hidden_states.detach()
        attention_mask = attention_mask.detach()
        position_ids = position_ids.detach()
        pkv0 = past_key_value[0].detach()
        pkv1 = past_key_value[1].detach()
        output = self.model(
            "forward",
            (
                hidden_states,
                attention_mask,
                position_ids,
                pkv0,
                pkv1,
            ),
        )

        output0 = torch.tensor(output[0])
        output1 = torch.tensor(output[1])
        output2 = torch.tensor(output[2])

        return (
            output0,
            (
                output1,
                output2,
            ),
        )


class ShardedVicunaModel(torch.nn.Module):
    def __init__(self, model, layers0, layers1, lmhead, embedding, norm):
    def __init__(self, model, layers, lmhead, embedding, norm):
        super().__init__()
        self.model = model
        assert len(layers0) == len(model.model.layers)
        # self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
        assert len(layers) == len(model.model.layers)
        self.model.model.config.use_cache = True
        self.model.model.config.output_attentions = False
        self.layers0 = layers0
        self.layers1 = layers1
        self.layers = layers
        self.norm = norm
        self.embedding = embedding
        self.lmhead = lmhead
        self.model.model.norm = self.norm
        self.model.model.embed_tokens = self.embedding
        self.model.lm_head = self.lmhead
        self.model.model.layers = torch.nn.modules.container.ModuleList(
            self.layers
        )

    def forward(
        self,
@@ -168,20 +87,11 @@ def forward(
        past_key_values=None,
        attention_mask=None,
    ):
        if is_first:
            self.model.model.layers = torch.nn.modules.container.ModuleList(
                self.layers0
            )
            return self.model.forward(input_ids, attention_mask=attention_mask)
        else:
            self.model.model.layers = torch.nn.modules.container.ModuleList(
                self.layers1
            )
            return self.model.forward(
                input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
        return self.model.forward(
            input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )


class LMHead(torch.nn.Module):
@@ -248,3 +158,71 @@ def forward(self, input_ids):
        output = self.model("forward", (input_ids,))
        output = torch.tensor(output)
        return output


class CompiledVicunaLayer(torch.nn.Module):
    def __init__(self, shark_module):
        super().__init__()
        self.model = shark_module

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value=None,
        output_attentions=False,
        use_cache=True,
    ):
        if past_key_value is None:
            # No cached key/values supplied: call the compiled
            # "first_vicuna_forward" entry point.
            hidden_states = hidden_states.detach()
            attention_mask = attention_mask.detach()
            position_ids = position_ids.detach()
            output = self.model(
                "first_vicuna_forward",
                (
                    hidden_states,
                    attention_mask,
                    position_ids,
                ),
            )

            output0 = torch.tensor(output[0])
            output1 = torch.tensor(output[1])
            output2 = torch.tensor(output[2])

            return (
                output0,
                (
                    output1,
                    output2,
                ),
            )
        else:
            # Cached key/values supplied: call the compiled
            # "second_vicuna_forward" entry point with the two cache tensors.
            hidden_states = hidden_states.detach()
            attention_mask = attention_mask.detach()
            position_ids = position_ids.detach()
            pkv0 = past_key_value[0].detach()
            pkv1 = past_key_value[1].detach()
            output = self.model(
                "second_vicuna_forward",
                (
                    hidden_states,
                    attention_mask,
                    position_ids,
                    pkv0,
                    pkv1,
                ),
            )

            output0 = torch.tensor(output[0])
            output1 = torch.tensor(output[1])
            output2 = torch.tensor(output[2])

            return (
                output0,
                (
                    output1,
                    output2,
                ),
            )
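
For reference, a hedged sketch of how the merged layer is driven, consistent with the forward() above: the first call carries no cache and routes to "first_vicuna_forward"; later calls pass the returned key/value pair and route to "second_vicuna_forward". The compiled_module object and all tensor shapes below are illustrative assumptions, not taken from this commit.

import torch


def run_prompt_then_token(compiled_module):
    # compiled_module: assumed to be a compiled SHARK module exposing the
    # "first_vicuna_forward" and "second_vicuna_forward" entry points used above.
    layer = CompiledVicunaLayer(compiled_module)

    # Prompt pass: no past_key_value, so forward() calls "first_vicuna_forward".
    hidden_states = torch.zeros(1, 7, 4096)    # (batch, prompt_len, hidden) - illustrative
    attention_mask = torch.zeros(1, 1, 7, 7)   # illustrative mask shape
    position_ids = torch.arange(7).unsqueeze(0)
    hidden_out, (k, v) = layer(hidden_states, attention_mask, position_ids)

    # Next-token pass: the cached (k, v) pair is supplied, so forward() calls
    # "second_vicuna_forward" with the two cache tensors.
    next_hidden = torch.zeros(1, 1, 4096)
    next_mask = torch.zeros(1, 1, 1, 8)
    next_pos = torch.tensor([[7]])
    hidden_out, (k, v) = layer(next_hidden, next_mask, next_pos, past_key_value=(k, v))
    return hidden_out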
