diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml index f688ea27..9ad67e93 100644 --- a/.github/workflows/publish-to-pypi.yml +++ b/.github/workflows/publish-to-pypi.yml @@ -1,5 +1,5 @@ # This workflow will upload a Python Package using Twine when a release is created. -# For more information see: +# For more information see: # * https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries # * https://github.com/pypa/gh-action-pypi-publish @@ -13,27 +13,34 @@ on: jobs: deploy: - runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - with: - submodules: recursive - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.7' - #cache: 'pip' - - name: Install build dependencies - run: | - python -m pip install --upgrade pip - pip install setuptools wheel twine build - - name: Build package - run: | - python -m build - - - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - password: ${{ secrets.PYPI_API_TOKEN }} + - uses: actions/checkout@v3 + with: + submodules: recursive + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: "3.8" + #cache: 'pip' + - name: Install build dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel twine build + - name: Build package + run: | + python -m build + - name: Install package + run: | + python setup.py install + - name: Test package + env: + STABILITY_KEY: ${{ secrets.STABILITY_KEY }} + run: | + python -m stability_sdk A beautiful painting of a happy robot + python -m stability_sdk.client A nice drawing of a horse + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/setup.py b/setup.py index 79087524..30d63aa8 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name='stability-sdk', - version='0.3.0', + version='0.3.1', author='Wes Brown', author_email='wesbrown18@gmail.com', maintainer='David Marx', diff --git a/src/stability_sdk/__main__.py b/src/stability_sdk/__main__.py index 56ee4488..03ce53ae 100644 --- a/src/stability_sdk/__main__.py +++ b/src/stability_sdk/__main__.py @@ -34,6 +34,7 @@ from stability_sdk.client import ( StabilityInference, + process_artifacts_from_answers, ) from stability_sdk.utils import ( SAMPLERS, @@ -98,11 +99,11 @@ "--sampler", "-A", type=str, - default="k_lms", - help="[k_lms] (" + ", ".join(SAMPLERS.keys()) + ")", + default=None, + help="[auto] (" + ", ".join(SAMPLERS.keys()) + ")", ) parser.add_argument( - "--steps", "-s", type=int, default=50, help="[50] number of steps" + "--steps", "-s", type=int, default=None, help="[auto] number of steps" ) parser.add_argument("--seed", "-S", type=int, default=0, help="random seed to use") parser.add_argument( @@ -155,19 +156,20 @@ args.mask_image = Image.open(args.mask_image) request = { - "height": cli_args.height, - "width": cli_args.width, - "start_schedule": cli_args.start_schedule, - "end_schedule": cli_args.end_schedule, - "cfg_scale": cli_args.cfg_scale, - "sampler": get_sampler_from_str(cli_args.sampler), - "steps": cli_args.steps, - "seed": cli_args.seed, - "samples": cli_args.num_samples, - "init_image": cli_args.init_image, - "mask_image": cli_args.mask_image, + "height": args.height, + "width": args.width, + "start_schedule": args.start_schedule, + "end_schedule": args.end_schedule, + "cfg_scale": args.cfg_scale, + "samples": args.num_samples, + "init_image": args.init_image, + "mask_image": args.mask_image, } +if args.sampler: + request["sampler"] = get_sampler_from_str(args.sampler) +if args.seed and args.seed > 0: + request["seed"] = args.seed stability_api = StabilityInference( STABILITY_HOST, STABILITY_KEY, engine=args.engine, verbose=True diff --git a/src/stability_sdk/client.py b/src/stability_sdk/client.py index bd49ce39..28483853 100644 --- a/src/stability_sdk/client.py +++ b/src/stability_sdk/client.py @@ -168,8 +168,8 @@ def generate( start_schedule: float = 1.0, end_schedule: float = 0.01, cfg_scale: float = 7.0, - sampler: generation.DiffusionSampler = generation.SAMPLER_K_LMS, - steps: int = 50, + sampler: generation.DiffusionSampler = None, + steps: Optional[int] = None, seed: Union[Sequence[int], int] = 0, samples: int = 1, safety: bool = True, @@ -233,7 +233,7 @@ def generate( scaled_step=0, sampler=generation.SamplerParameters(cfg_scale=cfg_scale), ) - + # NB: Specifying schedule when there's no init image causes washed out results if init_image is not None: step_parameters['schedule'] = generation.ScheduleParameters( @@ -245,7 +245,7 @@ def generate( if mask_image is not None: prompts += [image_to_prompt(mask_image, mask=True)] - + if guidance_prompt: if isinstance(guidance_prompt, str): guidance_prompt = generation.Prompt(text=guidance_prompt) @@ -254,7 +254,7 @@ def generate( if guidance_strength == 0.0: guidance_strength = None - + # Build our CLIP parameters if guidance_preset is not generation.GUIDANCE_PRESET_NONE: # to do: make it so user can override this @@ -282,8 +282,12 @@ def generate( ], ) + transform=None + if sampler: + transform=generation.TransformType(diffusion=sampler) + image_parameters=generation.ImageParameters( - transform=generation.TransformType(diffusion=sampler), + transform=transform, height=height, width=width, seed=seed, @@ -292,9 +296,10 @@ def generate( parameters=[generation.StepParameter(**step_parameters)], ) + return self.emit_request(prompt=prompts, image_parameters=image_parameters) - + # The motivation here is to facilitate constructing requests by passing protobuf objects directly. def emit_request( self, @@ -307,14 +312,14 @@ def emit_request( request_id = str(uuid.uuid4()) if not engine_id: engine_id = self.engine - + rq = generation.Request( engine_id=engine_id, request_id=request_id, prompt=prompt, image=image_parameters ) - + if self.verbose: logger.info("Sending request.") @@ -357,7 +362,7 @@ def emit_request( "[Deprecation Warning] instead do this:" "[Deprecation Warning] $ python -m stability_sdk ... " ) - + STABILITY_HOST = os.getenv("STABILITY_HOST", "grpc.stability.ai:443") STABILITY_KEY = os.getenv("STABILITY_KEY", "") diff --git a/src/stability_sdk/utils.py b/src/stability_sdk/utils.py index a0dbc6bc..96b75a97 100644 --- a/src/stability_sdk/utils.py +++ b/src/stability_sdk/utils.py @@ -15,6 +15,9 @@ import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation import stability_sdk.interfaces.gooseai.generation.generation_pb2_grpc as generation_grpc +logger = logging.getLogger(__name__) +logger.setLevel(level=logging.INFO) + SAMPLERS: Dict[str, int] = { "ddim": generation.SAMPLER_DDIM, "plms": generation.SAMPLER_DDPM, @@ -27,7 +30,6 @@ "k_dpmpp_2m": generation.SAMPLER_K_DPMPP_2M, "k_dpmpp_2s_ancestral": generation.SAMPLER_K_DPMPP_2S_ANCESTRAL } - MAX_FILENAME_SZ = int(os.getenv("MAX_FILENAME_SZ", 200))