diff --git a/examples/05_stable_diffusion/scripts/compile.py b/examples/05_stable_diffusion/scripts/compile.py index 8c7a5be98..4a38d3bc4 100644 --- a/examples/05_stable_diffusion/scripts/compile.py +++ b/examples/05_stable_diffusion/scripts/compile.py @@ -19,9 +19,13 @@ import torch from aitemplate.testing import detect_target +from aitemplate.utils.import_path import import_parent from diffusers import StableDiffusionPipeline +if __name__ == "__main__": + import_parent(filepath=__file__, level=1) + from src.compile_lib.compile_clip import compile_clip from src.compile_lib.compile_unet import compile_unet from src.compile_lib.compile_vae import compile_vae diff --git a/examples/05_stable_diffusion/scripts/demo.py b/examples/05_stable_diffusion/scripts/demo.py index 77d58cde2..d4f5dbb99 100644 --- a/examples/05_stable_diffusion/scripts/demo.py +++ b/examples/05_stable_diffusion/scripts/demo.py @@ -12,11 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # + import click import torch from aitemplate.testing.benchmark_pt import benchmark_torch_function +from aitemplate.utils.import_path import import_parent from diffusers import EulerDiscreteScheduler + +if __name__ == "__main__": + import_parent(filepath=__file__, level=1) + from src.pipeline_stable_diffusion_ait import StableDiffusionAITPipeline diff --git a/examples/05_stable_diffusion/scripts/demo_img2img.py b/examples/05_stable_diffusion/scripts/demo_img2img.py index 46c53cfd9..e4d96d865 100644 --- a/examples/05_stable_diffusion/scripts/demo_img2img.py +++ b/examples/05_stable_diffusion/scripts/demo_img2img.py @@ -19,7 +19,12 @@ import torch from aitemplate.testing.benchmark_pt import benchmark_torch_function +from aitemplate.utils.import_path import import_parent from PIL import Image + +if __name__ == "__main__": + import_parent(filepath=__file__, level=1) + from src.pipeline_stable_diffusion_img2img_ait import StableDiffusionImg2ImgAITPipeline diff --git a/python/aitemplate/utils/__init__.py b/python/aitemplate/utils/__init__.py index 44c1a6b98..6f57327ed 100644 --- a/python/aitemplate/utils/__init__.py +++ b/python/aitemplate/utils/__init__.py @@ -18,6 +18,7 @@ from . import ( alignment, graph_utils, + import_path, markdown_table, misc, shape_utils, diff --git a/python/aitemplate/utils/import_path.py b/python/aitemplate/utils/import_path.py new file mode 100644 index 000000000..caaccd9f2 --- /dev/null +++ b/python/aitemplate/utils/import_path.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pathlib import Path + + +def import_parent(filepath: str, level: int) -> None: + r_filepath = Path(filepath).resolve() + parent, top = r_filepath.parent, r_filepath.parents[level] + + sys.path.append(str(top)) + try: + sys.path.remove(str(parent)) + except ValueError: # Already removed + pass diff --git a/python/setup.py b/python/setup.py index df01212e3..53eaa8063 100644 --- a/python/setup.py +++ b/python/setup.py @@ -79,7 +79,11 @@ def gen_cutlass_list(): "aitemplate/3rdparty/cutlass/examples", "aitemplate/3rdparty/cutlass/tools/util/include", ] - f_cond = lambda x: True if x.endswith(".h") or x.endswith(".cuh") else False + f_cond = ( + lambda x: True + if x.endswith(".h") or x.endswith(".cuh") or x.endswith(".hpp") + else False + ) return gen_file_list(srcs, f_cond) @@ -128,7 +132,7 @@ def gen_utils_file_list(): def gen_backend_common_file_list(): - srcs = ["aitemplate/backend/common"] + srcs = ["aitemplate/backend"] f_cond = lambda x: True if x.endswith(".py") or x.endswith(".cuh") else False return gen_file_list(srcs, f_cond)