Merge branch 'main' into metal_for_shark
Ranvirsv committed Jun 21, 2023
2 parents 51f98b9 + 88cc242 commit 9556e69
Showing 2 changed files with 45 additions and 30 deletions.
70 changes: 40 additions & 30 deletions apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -10,7 +10,7 @@
from io import BytesIO
from pathlib import Path
from shark.shark_downloader import download_public_file
-from shark.shark_importer import import_with_fx
+from shark.shark_importer import import_with_fx, get_f16_inputs
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM

@@ -78,10 +78,10 @@ def compile_first_vicuna(self):
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
if self.precision in ["fp32", "fp16"]:
# download MLIR from shark_tank for fp32/fp16
download_public_file(
"gs://shark_tank/vicuna/unsharded/mlir/first_vicuna.mlir",
f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
self.first_vicuna_mlir_path.absolute(),
single_file=True,
)
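
The download URL is now keyed off the local MLIR filename instead of the hard-coded first_vicuna.mlir, so the same code path can fetch either an fp32 or an fp16 artifact. A minimal sketch of the idea, assuming the path name encodes the precision (the exact naming scheme is not shown in this diff):

    from pathlib import Path

    def mlir_tank_url(local_path: Path) -> str:
        # Mirror the local filename into the shark_tank bucket so one code path
        # serves both fp32 and fp16 artifacts.
        return f"gs://shark_tank/vicuna/unsharded/mlir/{local_path.name}"

    # Hypothetical example: an fp16 pipeline would request an fp16-specific file.
    print(mlir_tank_url(Path("first_vicuna_fp16.mlir")))
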
@@ -96,7 +96,7 @@
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
f"Only fp32 and fp16 mlir added to tank, generating {self.precision} mlir on device."
)

if not mlir_generated:
@@ -220,10 +220,10 @@ def compile_second_vicuna(self):
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision == "fp32":
# download MLIR from shark_tank for fp32
if self.precision in ["fp32", "fp16"]:
# download MLIR from shark_tank for fp32/fp16
download_public_file(
"gs://shark_tank/vicuna/unsharded/mlir/second_vicuna.mlir",
f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
self.second_vicuna_mlir_path.absolute(),
single_file=True,
)
@@ -253,9 +253,15 @@
model,
secondVicunaCompileInput,
is_f16=self.precision == "fp16",
-f16_input_mask=[False, False],
+f16_input_mask=[False] + [True] * 64,
mlir_type="torchscript",
)
if self.precision == "fp16":
secondVicunaCompileInput = get_f16_inputs(
secondVicunaCompileInput,
True,
f16_input_mask=[False] + [True] * 64,
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
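
Here the mask [False] + [True] * 64 keeps the first input (the token ids) in its integer type and marks the remaining 64 tensors, presumably the two past key/value tensors for each of the model's 32 layers, for conversion to half precision; get_f16_inputs from shark.shark_importer then applies the same mask to the example inputs before compilation. A conceptual sketch of what such a masked cast does (the real helper lives in shark.shark_importer and is not shown in this diff):

    import torch

    def cast_masked_inputs_to_f16(inputs, f16_input_mask):
        # Illustrative only: cast the masked floating-point tensors to half
        # precision and leave the rest (e.g. integer token ids) untouched.
        return tuple(
            t.to(torch.float16) if use_f16 and t.is_floating_point() else t
            for t, use_f16 in zip(inputs, f16_input_mask)
        )
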
@@ -307,7 +313,7 @@ def remove_constant_dim(line):
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128xf32>"
f"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128x{'f16' if self.precision == 'fp16' else 'f32'}>"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
@@ -365,41 +371,45 @@ def compile(self):
# download vmfbs for A100
if (
not self.first_vicuna_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
and self.device in ["cuda", "cpu"]
and self.precision in ["fp32", "fp16"]
):
-download_public_file(
-    "gs://shark_tank/vicuna/unsharded/first_vicuna.vmfb",
-    self.first_vicuna_vmfb_path.absolute(),
-    single_file=True,
-)
+# combinations that are still in the works
+if not (self.device == "cuda" and self.precision == "fp16"):
+    # Will generate vmfb on device
+    pass
+else:
+    download_public_file(
+        f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
+        self.first_vicuna_vmfb_path.absolute(),
+        single_file=True,
+    )
else:
# get first vic
# TODO: Remove after testing to avoid memory overload
# fvic_shark_model = self.compile_first_vicuna()
pass
if (
not self.second_vicuna_vmfb_path.exists()
and self.device == "cuda"
and self.precision == "fp32"
and self.device in ["cuda", "cpu"]
and self.precision in ["fp32", "fp16"]
):
-download_public_file(
-    "gs://shark_tank/vicuna/unsharded/second_vicuna.vmfb",
-    self.second_vicuna_vmfb_path.absolute(),
-    single_file=True,
-)
+# combinations that are still in the works
+if not (self.device == "cuda" and self.precision == "fp16"):
+    # Will generate vmfb on device
+    pass
+else:
+    download_public_file(
+        f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
+        self.second_vicuna_vmfb_path.absolute(),
+        single_file=True,
+    )
else:
# get second vic
# TODO: Remove after testing to avoid memory overload
# svic_shark_model = self.compile_second_vicuna()
pass

-# get first vic
-# fvic_shark_model = self.compile_first_vicuna()
-# get second vic
-# svic_shark_model = self.compile_second_vicuna()
-# return tuple of shark_modules
-# return fvic_shark_model, svic_shark_model
return None
+# return tuple of shark_modules once mem is supported
+# return fvic_shark_model, svic_shark_model
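
Net effect of the reworked gating above: prebuilt vmfbs are only fetched for the cuda + fp16 combination, while every other supported device/precision pair is left to on-device generation (per the in-line comments). A condensed restatement of that decision (the helper name is illustrative, not part of the pipeline):

    def should_download_prebuilt_vmfb(vmfb_path, device, precision) -> bool:
        # Only the cuda + fp16 vmfbs appear to be published to shark_tank at
        # this point; anything else is expected to be compiled locally.
        return (
            not vmfb_path.exists()
            and device == "cuda"
            and precision == "fp16"
        )
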
5 changes: 5 additions & 0 deletions apps/stable_diffusion/scripts/tuner.py
@@ -17,6 +17,10 @@


def load_mlir_module():
if "upscaler" in args.hf_model_id:
is_upscaler = True
else:
is_upscaler = False
sd_model = SharkifyStableDiffusionModel(
args.hf_model_id,
args.ckpt_loc,
@@ -27,6 +31,7 @@ def load_mlir_module():
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
+is_upscaler=is_upscaler,
use_tuned=False,
low_cpu_mem_usage=args.low_cpu_mem_usage,
return_mlir=True,
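
The four-line flag derivation added at the top of load_mlir_module could equally be written as a single boolean expression; an equivalent one-liner (an illustrative simplification, not part of this commit):

    is_upscaler = "upscaler" in args.hf_model_id
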