diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py index b679586301..5f7ac0bc65 100644 --- a/apps/stable_diffusion/web/index.py +++ b/apps/stable_diffusion/web/index.py @@ -2,6 +2,7 @@ import os import sys import logging +import warnings import apps.stable_diffusion.web.utils.app as app if sys.platform == "darwin": @@ -22,6 +23,34 @@ clear_all() +# This function is intended to clean up MEI folders +def cleanup_mei_folders(): + # Determine the operating system + if sys.platform.startswith("win"): + temp_dir = os.path.join(os.environ["LOCALAPPDATA"], "Temp") + + # For potential extension to support Linux or macOS systems: + # NOTE: Before enabling, ensure compatibility and testing. + # elif sys.platform.startswith('linux') or sys.platform == 'darwin': + # temp_dir = '/tmp' + + else: + warnings.warn( + "Temporary files weren't deleted due to an unsupported OS;" + + " program functionality is unaffected." + ) + return + + prefix = "_MEI" + + # Iterate through the items in the temporary directory + for item in os.listdir(temp_dir): + if item.startswith(prefix): + path = os.path.join(temp_dir, item) + if os.path.isdir(path): + shutil.rmtree(path, ignore_errors=True) + + if __name__ == "__main__": if args.debug: logging.basicConfig(level=logging.DEBUG) @@ -478,3 +507,4 @@ def register_outputgallery_sendto_editor_button( server_port=actual_port, favicon_path=nodicon_loc, ) + cleanup_mei_folders() diff --git a/shark/tests/test_index.py b/shark/tests/test_index.py new file mode 100644 index 0000000000..50e87fa359 --- /dev/null +++ b/shark/tests/test_index.py @@ -0,0 +1,34 @@ +import os +import pytest +import tempfile +from unittest import mock +from apps.stable_diffusion.web.index import cleanup_mei_folders + + +# Test for removing temporary _MEI folders on windows +def test_cleanup_mei_folders_windows(): + # Setting up the test environment for Windows + with mock.patch("sys.platform", "win32"): + with tempfile.TemporaryDirectory() as temp_dir: + temp_temp_dir = os.path.join(temp_dir, "Temp") + os.makedirs(temp_temp_dir) + + # Creating a fictitious _MEI directory + with mock.patch.dict("os.environ", {"LOCALAPPDATA": temp_dir}): + mei_folder = os.path.join(temp_temp_dir, "_MEI12345") + os.makedirs(mei_folder) + + cleanup_mei_folders() + assert not os.path.exists(mei_folder) + + +# Test for removing temporary folders at unsupported OS +def test_cleanup_mei_folders_unsupported_os(): + with mock.patch("sys.platform", "unsupported_os"): + with pytest.warns(UserWarning) as record: + cleanup_mei_folders() + + assert ( + "Temporary files weren't deleted due to an unsupported OS" + in str(record.list[0].message) + )