Skip to content

Commit

Permalink
download all mlirs (#1727)
Browse files Browse the repository at this point in the history
Co-authored-by: Elias Joseph <[email protected]>
  • Loading branch information
Eliasj42 and Elias Joseph authored Aug 4, 2023
1 parent 759664b commit fd1c4db
Showing 1 changed file with 51 additions and 44 deletions.
95 changes: 51 additions & 44 deletions apps/language_models/scripts/vicuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from pathlib import Path
from tqdm import tqdm
from typing import List, Tuple

import subprocess

import torch
import torch_mlir
from torch_mlir import TensorPlaceholder
Expand All @@ -27,7 +27,7 @@
VicunaNorm,
VicunaNormCompiled,
)
from apps.language_models.src.model_wrappers.vicuna4 import(
from apps.language_models.src.model_wrappers.vicuna4 import (
LlamaModel,
EightLayerLayerSV,
EightLayerLayerFV,
Expand Down Expand Up @@ -478,9 +478,8 @@ def __init__(
self.tokenizer = self.get_tokenizer()
self.config = config_json
self.weight_group_size = weight_group_size
self.compressed=compressed
self.compressed = compressed
self.shark_model = self.compile(device=device)


def get_tokenizer(self):
kwargs = {}
Expand Down Expand Up @@ -678,18 +677,23 @@ def compile_lmhead(
hidden_states, dynamic_axes=[1]
)

module = torch_mlir.compile(
lmh,
(hidden_states,),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
# module = torch_mlir.compile(
# lmh,
# (hidden_states,),
# torch_mlir.OutputType.LINALG_ON_TENSORS,
# use_tracing=False,
# verbose=False,
# )
# bytecode_stream = BytesIO()
# module.operation.write_bytecode(bytecode_stream)
# bytecode = bytecode_stream.getvalue()
# f_ = open(mlir_path, "wb")
# f_.write(bytecode)
# f_.close()
command = f"gsutil cp gs://shark_tank/elias/compressed_sv/lmhead.mlir lmhead.mlir"
subprocess.check_call(command.split())
f_ = open(f"lmhead.mlir", "rb")
bytecode = f_.read()
f_.close()

shark_module = SharkInference(
Expand Down Expand Up @@ -721,18 +725,17 @@ def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None):
hidden_states, dynamic_axes=[1]
)

module = torch_mlir.compile(
fvn,
(hidden_states,),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
# module = torch_mlir.compile(
# fvn,
# (hidden_states,),
# torch_mlir.OutputType.LINALG_ON_TENSORS,
# use_tracing=False,
# verbose=False,
# )
command = f"gsutil cp gs://shark_tank/elias/compressed_sv/norm.mlir norm.mlir"
subprocess.check_call(command.split())
f_ = open(f"norm.mlir", "rb")
bytecode = f_.read()
f_.close()

shark_module = SharkInference(
Expand Down Expand Up @@ -763,18 +766,23 @@ def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None):
input_ids = torch_mlir.TensorPlaceholder.like(
input_ids, dynamic_axes=[1]
)
module = torch_mlir.compile(
fve,
(input_ids,),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
# module = torch_mlir.compile(
# fve,
# (input_ids,),
# torch_mlir.OutputType.LINALG_ON_TENSORS,
# use_tracing=False,
# verbose=False,
# )
# bytecode_stream = BytesIO()
# module.operation.write_bytecode(bytecode_stream)
# bytecode = bytecode_stream.getvalue()
# f_ = open(mlir_path, "wb")
# f_.write(bytecode)
# f_.close()
command = f"gsutil cp gs://shark_tank/elias/compressed_sv/embedding.mlir embedding.mlir"
subprocess.check_call(command.split())
f_ = open(f"embedding.mlir", "rb")
bytecode = f_.read()
f_.close()

shark_module = SharkInference(
Expand Down Expand Up @@ -987,8 +995,6 @@ def compile_to_vmfb_one_model4(
f_.close()
mlirs.append(bytecode)



if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
device_idx = self.get_device_index(
Expand Down Expand Up @@ -1125,7 +1131,6 @@ def get_sharded_model(self, device="cpu", compressed=False):
)

if not compressed:

layers0 = [
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
]
Expand Down Expand Up @@ -1169,7 +1174,9 @@ def get_sharded_model(self, device="cpu", compressed=False):
return sharded_model

def compile(self, device="cpu"):
return self.get_sharded_model(device=device, compressed=self.compressed)
return self.get_sharded_model(
device=device, compressed=self.compressed
)

def generate(self, prompt, cli=False):
# TODO: refactor for cleaner integration
Expand Down

0 comments on commit fd1c4db

Please sign in to comment.