fix some mmap and vicuna bugs (#1576)
dan-garvey authored Jun 22, 2023
1 parent 045f2bb commit 0ca3b9f
Showing 3 changed files with 33 additions and 19 deletions.
12 changes: 4 additions & 8 deletions apps/language_models/scripts/vicuna.py
@@ -61,26 +61,22 @@
     vic = None
     if not args.sharded:
         first_vic_mlir_path = (
-            Path(f"first_vicuna_{args.precision}.mlir")
+            None
             if args.first_vicuna_mlir_path is None
             else Path(args.first_vicuna_mlir_path)
         )
         second_vic_mlir_path = (
-            Path(f"second_vicuna_{args.precision}.mlir")
+            None
             if args.second_vicuna_mlir_path is None
             else Path(args.second_vicuna_mlir_path)
         )
         first_vic_vmfb_path = (
-            Path(
-                f"first_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
-            )
+            None
             if args.first_vicuna_vmfb_path is None
             else Path(args.first_vicuna_vmfb_path)
         )
         second_vic_vmfb_path = (
-            Path(
-                f"second_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb"
-            )
+            None
            if args.second_vicuna_vmfb_path is None
             else Path(args.second_vicuna_vmfb_path)
         )
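Note: the script-side defaults above are dropped in favour of None, so the pipeline's own get_model_path() naming (fixed in the next file) is used instead of a filename guessed by the script. A minimal sketch of the resulting argument handling; the helper name normalize_path is illustrative, not repo code:

    from pathlib import Path

    def normalize_path(arg):
        # Keep None so the pipeline can derive its own default later;
        # otherwise turn the user-supplied string into a Path.
        return None if arg is None else Path(arg)

    first_vic_mlir_path = normalize_path(None)         # -> None, pipeline picks the name
    first_vic_vmfb_path = normalize_path("my.vmfb")     # -> Path("my.vmfb")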
36 changes: 25 additions & 11 deletions apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -33,6 +33,7 @@ def __init__(
         first_vicuna_vmfb_path=None,
         second_vicuna_vmfb_path=None,
         load_mlir_from_shark_tank=True,
+        low_device_memory=False,
     ) -> None:
         super().__init__(model_name, hf_model_path, max_num_tokens)
         self.max_sequence_length = 256
@@ -43,6 +44,9 @@ def __init__(
         self.first_vicuna_mlir_path = first_vicuna_mlir_path
         self.second_vicuna_mlir_path = second_vicuna_mlir_path
         self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
+        self.low_device_memory = low_device_memory
+        self.first_vic = None
+        self.second_vic = None
         if self.first_vicuna_mlir_path == None:
             self.first_vicuna_mlir_path = self.get_model_path()
         if self.second_vicuna_mlir_path == None:
@@ -61,7 +65,7 @@ def get_model_path(self, model_number="first", suffix="mlir"):
         if suffix == "mlir":
             return Path(f"{model_number}_vicuna_{self.precision}.{suffix}")
         return Path(
-            f"{model_number}_vicuna_{safe_device}_{self.precision}.{suffix}"
+            f"{model_number}_vicuna_{self.precision}_{safe_device}.{suffix}"
         )

     def get_tokenizer(self):
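The one-line change above swaps the segment order so the default vmfb name produced here follows the same precision-first pattern used elsewhere in this commit (compare the removed defaults in vicuna.py). A standalone sketch of the naming convention; the safe_device sanitization is assumed to mirror the device.replace('://', '_') seen in the script hunk:

    from pathlib import Path

    def default_artifact_path(model_number, precision, device, suffix):
        # Device strings like "cuda://0" are not filesystem-friendly.
        safe_device = device.replace("://", "_")
        if suffix == "mlir":
            # MLIR artifacts carry no device in their name.
            return Path(f"{model_number}_vicuna_{precision}.{suffix}")
        # Compiled modules are device-specific: precision first, then device.
        return Path(f"{model_number}_vicuna_{precision}_{safe_device}.{suffix}")

    default_artifact_path("first", "fp16", "cuda://0", "vmfb")
    # -> Path("first_vicuna_fp16_cuda_0.vmfb")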
@@ -87,7 +91,7 @@ def compile_first_vicuna(self):
         # Compilation path needs some more work before it is functional

         print(
-            f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with"
+            f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with\n"
             f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}"
         )
         if self.first_vicuna_mlir_path.exists():
@@ -436,12 +440,19 @@ def generate(self, prompt, cli=False):
         # TODO: refactor for cleaner integration
         import gc

+        if not self.low_device_memory:
+            if self.first_vic == None:
+                self.first_vic = self.compile_first_vicuna()
+            if self.second_vic == None:
+                self.second_vic = self.compile_second_vicuna()
         res = []
         res_tokens = []
         params = {
             "prompt": prompt,
             "is_first": True,
-            "fv": self.compile_first_vicuna(),
+            "fv": self.compile_first_vicuna()
+            if self.first_vic == None
+            else self.first_vic,
         }

         generated_token_op = self.generate_new_token(params=params)
@@ -457,18 +468,20 @@
             print(f"Assistant: {detok}", end=" ", flush=True)

         # Clear First Vic from Memory (main and cuda)
-        del params
-        torch.cuda.empty_cache()
-        gc.collect()
+        if self.low_device_memory:
+            del params
+            torch.cuda.empty_cache()
+            gc.collect()

-        sec_vic = self.compile_second_vicuna()
         for _ in range(self.max_num_tokens - 2):
             params = {
                 "prompt": None,
                 "is_first": False,
                 "logits": logits,
                 "pkv": pkv,
-                "sv": sec_vic,
+                "sv": self.compile_second_vicuna()
+                if self.second_vic == None
+                else self.second_vic,
             }

             generated_token_op = self.generate_new_token(params=params)
@@ -489,9 +502,10 @@
             res.append(detok)
             if cli:
                 print(f"{detok}", end=" ", flush=True)
-        del sec_vic, pkv, logits
-        torch.cuda.empty_cache()
-        gc.collect()
+        if self.device == "cuda":
+            del sec_vic, pkv, logits
+            torch.cuda.empty_cache()
+            gc.collect()

         for i in range(len(res_tokens)):
             if type(res_tokens[i]) != int:
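Taken together, the generate() hunks above add a compile-once cache keyed on the new low_device_memory flag: unless that flag is set, the first and second Vicuna modules are compiled once and kept on the pipeline object, and the aggressive del/empty_cache/gc cleanup only runs in the low-memory path. A condensed sketch of the pattern, assuming compile_fn stands in for compile_first_vicuna or compile_second_vicuna (the class name and structure are illustrative, not repo code):

    import gc
    import torch

    class CachedModule:
        """Illustrative stand-in for the caching behaviour added to the pipeline."""

        def __init__(self, compile_fn, low_device_memory=False):
            self.compile_fn = compile_fn          # e.g. pipeline.compile_first_vicuna
            self.low_device_memory = low_device_memory
            self.module = None

        def get(self):
            # Reuse the cached module when available; otherwise compile on demand.
            if self.module is not None:
                return self.module
            module = self.compile_fn()
            if not self.low_device_memory:
                self.module = module              # keep it resident for later tokens
            return module

        def release(self):
            # Only tear the module down when memory pressure is the priority.
            if self.low_device_memory:
                self.module = None
                torch.cuda.empty_cache()
                gc.collect()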
4 changes: 4 additions & 0 deletions shark/iree_utils/compile_utils.py
@@ -20,6 +20,7 @@
 import os
 import re
 import tempfile
+from pathlib import Path


 # Get the iree-compile arguments given device.
@@ -355,6 +356,9 @@ def load_vmfb_using_mmap(
     # OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
     # (This would arise if we're invoking `compile` from a SharkInference obj)
     temp_file_to_unlink = None
+
+    if isinstance(flatbuffer_blob_or_path, Path):
+        flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
     if (
         isinstance(flatbuffer_blob_or_path, str)
         and ".vmfb" in flatbuffer_blob_or_path
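The compile_utils.py addition normalizes pathlib.Path inputs to plain strings before the existing isinstance(..., str) / ".vmfb" guard, so a Path passed in from the pipeline still takes the mmap loading branch instead of being treated as an in-memory flatbuffer blob. A small standalone illustration of the guard's behaviour (the file name is made up):

    from pathlib import Path

    blob_or_path = Path("first_vicuna_fp16_cuda_0.vmfb")

    # Before the fix: the guard only matches plain strings, so a Path object
    # short-circuits to False and the mmap branch is skipped.
    print(isinstance(blob_or_path, str) and ".vmfb" in blob_or_path)   # False

    # After the fix: Path inputs are converted up front and the same guard passes.
    if isinstance(blob_or_path, Path):
        blob_or_path = str(blob_or_path)
    print(isinstance(blob_or_path, str) and ".vmfb" in blob_or_path)   # True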
