diff --git a/shark/shark_downloader.py b/shark/shark_downloader.py index 73a8054955..633d752ea1 100644 --- a/shark/shark_downloader.py +++ b/shark/shark_downloader.py @@ -111,22 +111,20 @@ def download_public_file( def check_dir_exists(model_name, frontend="torch", dynamic=""): model_dir = os.path.join(WORKDIR, model_name) - # Remove the _tf keyword from end. - if frontend in ["tf", "tensorflow"]: - model_name = model_name[:-3] - elif frontend in ["tflite"]: - model_name = model_name[:-7] - elif frontend in ["torch", "pytorch"]: - model_name = model_name[:-6] + # Remove the _tf keyword from end only for non-SD models. + if not any(model in model_name for model in ["clip", "unet", "vae"]): + if frontend in ["tf", "tensorflow"]: + model_name = model_name[:-3] + elif frontend in ["tflite"]: + model_name = model_name[:-7] + elif frontend in ["torch", "pytorch"]: + model_name = model_name[:-6] + + model_mlir_file_name = f"{model_name}{dynamic}_{frontend}.mlir" if os.path.isdir(model_dir): if ( - os.path.isfile( - os.path.join( - model_dir, - model_name + dynamic + "_" + str(frontend) + ".mlir", - ) - ) + os.path.isfile(os.path.join(model_dir, model_mlir_file_name)) and os.path.isfile(os.path.join(model_dir, "function_name.npy")) and os.path.isfile(os.path.join(model_dir, "inputs.npz")) and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))