From 4015793f84ff43686d53ee5043aa4f571ac000b0 Mon Sep 17 00:00:00 2001 From: Eliasj42 <46754803+Eliasj42@users.noreply.github.com> Date: Mon, 3 Jul 2023 12:12:43 -0700 Subject: [PATCH] changed method of compiling vicuna to remove first and second vicuna (#1611) Co-authored-by: Elias Joseph Co-authored-by: powderluv --- .../model_wrappers/vicuna_sharded_model.py | 180 ++++---- .../src/pipelines/vicuna_sharded_pipeline.py | 404 +++++++++++------- 2 files changed, 318 insertions(+), 266 deletions(-) diff --git a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py b/apps/language_models/src/model_wrappers/vicuna_sharded_model.py index cba0b0f952..1a46c3b50f 100644 --- a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py +++ b/apps/language_models/src/model_wrappers/vicuna_sharded_model.py @@ -62,104 +62,23 @@ def forward( ) -class CompiledFirstVicunaLayer(torch.nn.Module): - def __init__(self, shark_module): - super().__init__() - self.model = shark_module - - def forward( - self, - hidden_states, - attention_mask, - position_ids, - past_key_value=None, - output_attentions=False, - use_cache=True, - ): - hidden_states = hidden_states.detach() - attention_mask = attention_mask.detach() - position_ids = position_ids.detach() - output = self.model( - "forward", - ( - hidden_states, - attention_mask, - position_ids, - ), - ) - - output0 = torch.tensor(output[0]) - output1 = torch.tensor(output[1]) - output2 = torch.tensor(output[2]) - - return ( - output0, - ( - output1, - output2, - ), - ) - - -class CompiledSecondVicunaLayer(torch.nn.Module): - def __init__(self, shark_module): - super().__init__() - self.model = shark_module - - def forward( - self, - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions=False, - use_cache=True, - ): - hidden_states = hidden_states.detach() - attention_mask = attention_mask.detach() - position_ids = position_ids.detach() - pkv0 = past_key_value[0].detach() - pkv1 = past_key_value[1].detach() - output = self.model( - "forward", - ( - hidden_states, - attention_mask, - position_ids, - pkv0, - pkv1, - ), - ) - - output0 = torch.tensor(output[0]) - output1 = torch.tensor(output[1]) - output2 = torch.tensor(output[2]) - - return ( - output0, - ( - output1, - output2, - ), - ) - - class ShardedVicunaModel(torch.nn.Module): - def __init__(self, model, layers0, layers1, lmhead, embedding, norm): + def __init__(self, model, layers, lmhead, embedding, norm): super().__init__() self.model = model - assert len(layers0) == len(model.model.layers) - # self.model.model.layers = torch.nn.modules.container.ModuleList(layers0) + assert len(layers) == len(model.model.layers) self.model.model.config.use_cache = True self.model.model.config.output_attentions = False - self.layers0 = layers0 - self.layers1 = layers1 + self.layers = layers self.norm = norm self.embedding = embedding self.lmhead = lmhead self.model.model.norm = self.norm self.model.model.embed_tokens = self.embedding self.model.lm_head = self.lmhead + self.model.model.layers = torch.nn.modules.container.ModuleList( + self.layers + ) def forward( self, @@ -168,20 +87,11 @@ def forward( past_key_values=None, attention_mask=None, ): - if is_first: - self.model.model.layers = torch.nn.modules.container.ModuleList( - self.layers0 - ) - return self.model.forward(input_ids, attention_mask=attention_mask) - else: - self.model.model.layers = torch.nn.modules.container.ModuleList( - self.layers1 - ) - return self.model.forward( - input_ids, - 
attention_mask=attention_mask, - past_key_values=past_key_values, - ) + return self.model.forward( + input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) class LMHead(torch.nn.Module): @@ -248,3 +158,71 @@ def forward(self, input_ids): output = self.model("forward", (input_ids,)) output = torch.tensor(output) return output + + +class CompiledVicunaLayer(torch.nn.Module): + def __init__(self, shark_module): + super().__init__() + self.model = shark_module + + def forward( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value=None, + output_attentions=False, + use_cache=True, + ): + if past_key_value is None: + hidden_states = hidden_states.detach() + attention_mask = attention_mask.detach() + position_ids = position_ids.detach() + output = self.model( + "first_vicuna_forward", + ( + hidden_states, + attention_mask, + position_ids, + ), + ) + + output0 = torch.tensor(output[0]) + output1 = torch.tensor(output[1]) + output2 = torch.tensor(output[2]) + + return ( + output0, + ( + output1, + output2, + ), + ) + else: + hidden_states = hidden_states.detach() + attention_mask = attention_mask.detach() + position_ids = position_ids.detach() + pkv0 = past_key_value[0].detach() + pkv1 = past_key_value[1].detach() + output = self.model( + "second_vicuna_forward", + ( + hidden_states, + attention_mask, + position_ids, + pkv0, + pkv1, + ), + ) + + output0 = torch.tensor(output[0]) + output1 = torch.tensor(output[1]) + output2 = torch.tensor(output[2]) + + return ( + output0, + ( + output1, + output2, + ), + ) diff --git a/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py b/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py index 9cb3a428ea..cc88ce7bb3 100644 --- a/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py +++ b/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py @@ -1,8 +1,7 @@ from apps.language_models.src.model_wrappers.vicuna_sharded_model import ( FirstVicunaLayer, SecondVicunaLayer, - CompiledFirstVicunaLayer, - CompiledSecondVicunaLayer, + CompiledVicunaLayer, ShardedVicunaModel, LMHead, LMHeadCompiled, @@ -95,6 +94,7 @@ def write_in_dynamic_inputs1(self, module, dynamic_input_size): ) continue line = re.sub(f"{dynamic_input_size}x", "?x", line) + line = re.sub(f"%c{dynamic_input_size}_i64", "%dim_42_i64", line) if "?x" in line: line = re.sub( "tensor.empty\(\)", "tensor.empty(%dim_42)", line @@ -112,6 +112,124 @@ def write_in_dynamic_inputs1(self, module, dynamic_input_size): new_module = "\n".join(new_lines) return new_module + def combine_mlir_scripts( + self, first_vicuna_mlir, second_vicuna_mlir, output_name + ): + maps1 = [] + maps2 = [] + constants = set() + f1 = [] + f2 = [] + + for line in first_vicuna_mlir.splitlines(): + if re.search("#map\d*\s*=", line): + maps1.append(line) + elif re.search("arith.constant", line): + constants.add(line) + elif not re.search("module", line): + line = re.sub("forward", "first_vicuna_forward", line) + f1.append(line) + f1 = f1[:-1] + + for i, map_line in enumerate(maps1): + map_var = map_line.split(" ")[0] + map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line) + maps1[i] = map_line + f1 = [ + re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line) + for func_line in f1 + ] + + for line in second_vicuna_mlir.splitlines(): + if re.search("#map\d*\s*=", line): + maps2.append(line) + elif "global_seed" in line: + continue + elif re.search("arith.constant", line): + constants.add(line) + elif not re.search("module", line): + line = 
re.sub("forward", "second_vicuna_forward", line) + f2.append(line) + f2 = f2[:-1] + + for i, map_line in enumerate(maps2): + map_var = map_line.split(" ")[0] + map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line) + maps2[i] = map_line + f2 = [ + re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line) + for func_line in f2 + ] + + module_start = ( + 'module attributes {torch.debug_module_name = "_lambda"} {' + ) + module_end = "}" + + global_vars = [] + vnames = [] + vdtypes = [] + global_var_loading1 = [] + global_var_loading2 = [] + + for constant in list(constants): + vname, vbody = constant.split("=") + vname = re.sub("%", "", vname) + vname = vname.strip() + vbody = re.sub("arith.constant", "", vbody) + vbody = vbody.strip() + vdtype = vbody.split(":")[1].strip() + fixed_vdtype = vdtype + vdtypes.append(vdtype) + vdtype = re.sub("\d{1,}x", "?x", vdtype) + vnames.append(vname) + global_vars.append( + f"ml_program.global public @{vname}({vbody}) : {fixed_vdtype}" + ) + global_var_loading1.append( + f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}" + ) + global_var_loading2.append( + f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}" + ) + + new_f1, new_f2 = [], [] + + for line in f1: + if "func.func" in line: + new_f1.append(line) + for global_var in global_var_loading1: + new_f1.append(global_var) + else: + new_f1.append(line) + + for line in f2: + if "func.func" in line: + new_f2.append(line) + for global_var in global_var_loading1: + new_f2.append(global_var) + else: + new_f2.append(line) + + f1 = new_f1 + f2 = new_f2 + + whole_string = "\n".join( + maps1 + + maps2 + + [module_start] + + global_vars + + f1 + + f2 + + [module_end] + ) + + f_ = open(output_name, "w+") + f_.write(whole_string) + f_.close() + + return whole_string + def compile_vicuna_layer( self, vicuna_layer, @@ -193,6 +311,7 @@ def compile_lmhead( device=device, mlir_dialect="tm_tensor", device_idx=device_idx, + mmap=False, ) if vmfb_path.exists(): shark_module.load_module(vmfb_path) @@ -235,6 +354,7 @@ def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None): device=device, mlir_dialect="tm_tensor", device_idx=device_idx, + mmap=False, ) if vmfb_path.exists(): shark_module.load_module(vmfb_path) @@ -276,6 +396,7 @@ def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None): device=device, mlir_dialect="tm_tensor", device_idx=device_idx, + mmap=False, ) if vmfb_path.exists(): shark_module.load_module(vmfb_path) @@ -286,171 +407,127 @@ def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None): return compiled_module - def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True): - # compile all layers for vmfb - # this needs to be run seperatley for first and second vicuna + def compile_to_vmfb_one_model( + self, inputs0, layers0, inputs1, layers1, device="cpu" + ): mlirs, modules = [], [] - for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"): - if is_first: - mlir_path = Path(f"{idx}_0.mlir") - vmfb_path = Path(f"{idx}_0.vmfb") - else: - mlir_path = Path(f"{idx}_1.mlir") - vmfb_path = Path(f"{idx}_1.vmfb") - if vmfb_path.exists(): - continue + assert len(layers0) == len(layers1) + for layer0, layer1, idx in zip(layers0, layers1, range(len(layers0))): + mlir_path = Path(f"{idx}_full.mlir") + vmfb_path = Path(f"{idx}_full.vmfb") + # if vmfb_path.exists(): + # continue if mlir_path.exists(): # print(f"Found layer {idx} mlir") f_ = open(mlir_path, "rb") bytecode = f_.read() f_.close() + 
mlirs.append(bytecode) else: - hidden_states_placeholder = TensorPlaceholder.like( - inputs[0], dynamic_axes=[1] + hidden_states_placeholder0 = TensorPlaceholder.like( + inputs0[0], dynamic_axes=[1] ) - attention_mask_placeholder = TensorPlaceholder.like( - inputs[1], dynamic_axes=[3] + attention_mask_placeholder0 = TensorPlaceholder.like( + inputs0[1], dynamic_axes=[3] ) - position_ids_placeholder = TensorPlaceholder.like( - inputs[2], dynamic_axes=[1] + position_ids_placeholder0 = TensorPlaceholder.like( + inputs0[2], dynamic_axes=[1] + ) + hidden_states_placeholder1 = TensorPlaceholder.like( + inputs1[0], dynamic_axes=[1] + ) + attention_mask_placeholder1 = TensorPlaceholder.like( + inputs1[1], dynamic_axes=[3] + ) + position_ids_placeholder1 = TensorPlaceholder.like( + inputs1[2], dynamic_axes=[1] + ) + pkv0_placeholder = TensorPlaceholder.like( + inputs1[3], dynamic_axes=[2] + ) + pkv1_placeholder = TensorPlaceholder.like( + inputs1[4], dynamic_axes=[2] ) - if not is_first: - pkv0_placeholder = TensorPlaceholder.like( - inputs[3], dynamic_axes=[2] - ) - pkv1_placeholder = TensorPlaceholder.like( - inputs[4], dynamic_axes=[2] - ) print(f"Compiling layer {idx} mlir") - if is_first: - ts_g = self.compile_vicuna_layer( - layer, inputs[0], inputs[1], inputs[2] - ) - module = torch_mlir.compile( - ts_g, - ( - hidden_states_placeholder, - inputs[1], - inputs[2], - ), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - else: - ts_g = self.compile_vicuna_layer( - layer, - inputs[0], - inputs[1], - inputs[2], - inputs[3], - inputs[4], - ) - module = torch_mlir.compile( - ts_g, - ( - inputs[0], - attention_mask_placeholder, - inputs[2], - pkv0_placeholder, - pkv1_placeholder, - ), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - - if is_first: - module = self.write_in_dynamic_inputs0(str(module), 137) - bytecode = module.encode("UTF-8") - bytecode_stream = BytesIO(bytecode) - bytecode = bytecode_stream.read() - - else: - module = self.write_in_dynamic_inputs1(str(module), 138) + ts_g = self.compile_vicuna_layer( + layer0, inputs0[0], inputs0[1], inputs0[2] + ) + module0 = torch_mlir.compile( + ts_g, + ( + hidden_states_placeholder0, + inputs0[1], + inputs0[2], + ), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + module0 = self.write_in_dynamic_inputs0(str(module0), 137) + + ts_g = self.compile_vicuna_layer( + layer1, + inputs1[0], + inputs1[1], + inputs1[2], + inputs1[3], + inputs1[4], + ) + module1 = torch_mlir.compile( + ts_g, + ( + inputs1[0], + attention_mask_placeholder1, + inputs1[2], + pkv0_placeholder, + pkv1_placeholder, + ), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + module1 = self.write_in_dynamic_inputs1(str(module1), 138) - bytecode = module.encode("UTF-8") - bytecode_stream = BytesIO(bytecode) - bytecode = bytecode_stream.read() + module_combined = self.combine_mlir_scripts( + module0, module1, f"{idx}_full.mlir" + ) + mlirs.append(module_combined) - f_ = open(mlir_path, "wb") - f_.write(bytecode) - f_.close() - mlirs.append(bytecode) - - for idx, layer in tqdm(enumerate(layers), desc="compiling modules"): - if is_first: - vmfb_path = Path(f"{idx}_0.vmfb") - if vmfb_path.exists(): - device_idx = self.get_device_index( - f"first_vicuna.model.model.layers.{idx}[\s.$]" - ) - module = SharkInference( - None, - device=device, - device_idx=device_idx, - mlir_dialect="tm_tensor", - ) - module.load_module(vmfb_path) - else: - 
print(f"Compiling layer {idx} vmfb") - device_idx = self.get_device_index( - f"first_vicuna.model.model.layers.{idx}[\s.$]" - ) - module = SharkInference( - mlirs[idx], - device=device, - device_idx=device_idx, - mlir_dialect="tm_tensor", - ) - module.save_module( - module_name=f"{idx}_0", - extra_args=[ - "--iree-hal-dump-executable-sources-to=ies", - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ], - ) - module.load_module(vmfb_path) - modules.append(module) + if vmfb_path.exists(): + # print(f"Found layer {idx} vmfb") + device_idx = self.get_device_index( + f"first_vicuna.model.model.layers.{idx}[\s.$]" + ) + module = SharkInference( + None, + device=device, + device_idx=idx % 4, + mlir_dialect="tm_tensor", + mmap=False, + ) + module.load_module(vmfb_path) else: - vmfb_path = Path(f"{idx}_1.vmfb") - if vmfb_path.exists(): - # print(f"Found layer {idx} vmfb") - device_idx = self.get_device_index( - f"second_vicuna.model.model.layers.{idx}[\s.$]" - ) - module = SharkInference( - None, - device=device, - device_idx=device_idx, - mlir_dialect="tm_tensor", - ) - module.load_module(vmfb_path) - else: - print(f"Compiling layer {idx} vmfb") - device_idx = self.get_device_index( - f"second_vicuna.model.model.layers.{idx}[\s.$]" - ) - module = SharkInference( - mlirs[idx], - device=device, - device_idx=device_idx, - mlir_dialect="tm_tensor", - ) - module.save_module( - module_name=f"{idx}_1", - extra_args=[ - "--iree-hal-dump-executable-sources-to=ies", - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ], - ) - module.load_module(vmfb_path) - modules.append(module) - + print(f"Compiling layer {idx} vmfb") + device_idx = self.get_device_index( + f"first_vicuna.model.model.layers.{idx}[\s.$]" + ) + module = SharkInference( + mlirs[idx], + device=device, + device_idx=idx % 4, + mlir_dialect="tm_tensor", + mmap=False, + ) + module.save_module( + module_name=f"{idx}_full", + extra_args=[ + "--iree-vm-target-truncate-unsupported-floats", + "--iree-codegen-check-ir-before-llvm-conversion=false", + "--iree-vm-bytecode-module-output-format=flatbuffer-binary", + ], + ) + module.load_module(vmfb_path) + modules.append(module) return mlirs, modules def get_sharded_model(self, device="cpu"): @@ -511,26 +588,23 @@ def get_sharded_model(self, device="cpu"): layers0 = [ FirstVicunaLayer(layer) for layer in vicuna_model.model.layers ] - _, modules0 = self.compile_to_vmfb( - placeholder_input0, - layers0, - is_first=True, - device=device, - ) - shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0] layers1 = [ SecondVicunaLayer(layer) for layer in vicuna_model.model.layers ] - _, modules1 = self.compile_to_vmfb( - placeholder_input1, layers1, is_first=False, device=device + + _, modules = self.compile_to_vmfb_one_model( + placeholder_input0, + layers0, + placeholder_input1, + layers1, + device=device, ) - shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1] + shark_layers = [CompiledVicunaLayer(m) for m in modules] sharded_model = ShardedVicunaModel( vicuna_model, - shark_layers0, - shark_layers1, + shark_layers, lmhead, embeddings, norm,