Skip to content

Commit

Permalink
fix name check for file existence
Browse files Browse the repository at this point in the history
  • Loading branch information
PhaneeshB committed Aug 5, 2023
1 parent fd1c4db commit 872bd72
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions shark/shark_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,22 +111,20 @@ def download_public_file(
def check_dir_exists(model_name, frontend="torch", dynamic=""):
model_dir = os.path.join(WORKDIR, model_name)

# Remove the _tf keyword from end.
if frontend in ["tf", "tensorflow"]:
model_name = model_name[:-3]
elif frontend in ["tflite"]:
model_name = model_name[:-7]
elif frontend in ["torch", "pytorch"]:
model_name = model_name[:-6]
# Remove the _tf keyword from end only for non-SD models.
if not any(model in model_name for model in ["clip", "unet", "vae"]):
if frontend in ["tf", "tensorflow"]:
model_name = model_name[:-3]
elif frontend in ["tflite"]:
model_name = model_name[:-7]
elif frontend in ["torch", "pytorch"]:
model_name = model_name[:-6]

model_mlir_file_name = f"{model_name}{dynamic}_{frontend}.mlir"

if os.path.isdir(model_dir):
if (
os.path.isfile(
os.path.join(
model_dir,
model_name + dynamic + "_" + str(frontend) + ".mlir",
)
)
os.path.isfile(os.path.join(model_dir, model_mlir_file_name))
and os.path.isfile(os.path.join(model_dir, "function_name.npy"))
and os.path.isfile(os.path.join(model_dir, "inputs.npz"))
and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
Expand Down

0 comments on commit 872bd72

Please sign in to comment.