Skip to content

Commit

Permalink
Fixing args and testing packages (Stability-AI#150)
Browse files Browse the repository at this point in the history
* Fixing args and testing packages

* fix/add test

* Fix show function

* Update test (may be a better way)

* formatting

* Bump python-version to 3.8
  • Loading branch information
palp authored Dec 14, 2022
1 parent c5d145b commit d8f140f
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 48 deletions.
51 changes: 29 additions & 22 deletions .github/workflows/publish-to-pypi.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This workflow will upload a Python Package using Twine when a release is created.
# For more information see:
# For more information see:
# * https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
# * https://github.com/pypa/gh-action-pypi-publish

Expand All @@ -13,27 +13,34 @@ on:

jobs:
deploy:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
with:
submodules: recursive
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.7'
#cache: 'pip'
- name: Install build dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine build
- name: Build package
run: |
python -m build
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}
- uses: actions/checkout@v3
with:
submodules: recursive
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.8"
#cache: 'pip'
- name: Install build dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine build
- name: Build package
run: |
python -m build
- name: Install package
run: |
python setup.py install
- name: Test package
env:
STABILITY_KEY: ${{ secrets.STABILITY_KEY }}
run: |
python -m stability_sdk A beautiful painting of a happy robot
python -m stability_sdk.client A nice drawing of a horse
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name='stability-sdk',
version='0.3.0',
version='0.3.1',
author='Wes Brown',
author_email='[email protected]',
maintainer='David Marx',
Expand Down
30 changes: 16 additions & 14 deletions src/stability_sdk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from stability_sdk.client import (
StabilityInference,
process_artifacts_from_answers,
)
from stability_sdk.utils import (
SAMPLERS,
Expand Down Expand Up @@ -98,11 +99,11 @@
"--sampler",
"-A",
type=str,
default="k_lms",
help="[k_lms] (" + ", ".join(SAMPLERS.keys()) + ")",
default=None,
help="[auto] (" + ", ".join(SAMPLERS.keys()) + ")",
)
parser.add_argument(
"--steps", "-s", type=int, default=50, help="[50] number of steps"
"--steps", "-s", type=int, default=None, help="[auto] number of steps"
)
parser.add_argument("--seed", "-S", type=int, default=0, help="random seed to use")
parser.add_argument(
Expand Down Expand Up @@ -155,19 +156,20 @@
args.mask_image = Image.open(args.mask_image)

request = {
"height": cli_args.height,
"width": cli_args.width,
"start_schedule": cli_args.start_schedule,
"end_schedule": cli_args.end_schedule,
"cfg_scale": cli_args.cfg_scale,
"sampler": get_sampler_from_str(cli_args.sampler),
"steps": cli_args.steps,
"seed": cli_args.seed,
"samples": cli_args.num_samples,
"init_image": cli_args.init_image,
"mask_image": cli_args.mask_image,
"height": args.height,
"width": args.width,
"start_schedule": args.start_schedule,
"end_schedule": args.end_schedule,
"cfg_scale": args.cfg_scale,
"samples": args.num_samples,
"init_image": args.init_image,
"mask_image": args.mask_image,
}

if args.sampler:
request["sampler"] = get_sampler_from_str(args.sampler)
if args.seed and args.seed > 0:
request["seed"] = args.seed

stability_api = StabilityInference(
STABILITY_HOST, STABILITY_KEY, engine=args.engine, verbose=True
Expand Down
25 changes: 15 additions & 10 deletions src/stability_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def generate(
start_schedule: float = 1.0,
end_schedule: float = 0.01,
cfg_scale: float = 7.0,
sampler: generation.DiffusionSampler = generation.SAMPLER_K_LMS,
steps: int = 50,
sampler: generation.DiffusionSampler = None,
steps: Optional[int] = None,
seed: Union[Sequence[int], int] = 0,
samples: int = 1,
safety: bool = True,
Expand Down Expand Up @@ -233,7 +233,7 @@ def generate(
scaled_step=0,
sampler=generation.SamplerParameters(cfg_scale=cfg_scale),
)

# NB: Specifying schedule when there's no init image causes washed out results
if init_image is not None:
step_parameters['schedule'] = generation.ScheduleParameters(
Expand All @@ -245,7 +245,7 @@ def generate(
if mask_image is not None:
prompts += [image_to_prompt(mask_image, mask=True)]


if guidance_prompt:
if isinstance(guidance_prompt, str):
guidance_prompt = generation.Prompt(text=guidance_prompt)
Expand All @@ -254,7 +254,7 @@ def generate(
if guidance_strength == 0.0:
guidance_strength = None


# Build our CLIP parameters
if guidance_preset is not generation.GUIDANCE_PRESET_NONE:
# to do: make it so user can override this
Expand Down Expand Up @@ -282,8 +282,12 @@ def generate(
],
)

transform=None
if sampler:
transform=generation.TransformType(diffusion=sampler)

image_parameters=generation.ImageParameters(
transform=generation.TransformType(diffusion=sampler),
transform=transform,
height=height,
width=width,
seed=seed,
Expand All @@ -292,9 +296,10 @@ def generate(
parameters=[generation.StepParameter(**step_parameters)],
)


return self.emit_request(prompt=prompts, image_parameters=image_parameters)


# The motivation here is to facilitate constructing requests by passing protobuf objects directly.
def emit_request(
self,
Expand All @@ -307,14 +312,14 @@ def emit_request(
request_id = str(uuid.uuid4())
if not engine_id:
engine_id = self.engine

rq = generation.Request(
engine_id=engine_id,
request_id=request_id,
prompt=prompt,
image=image_parameters
)

if self.verbose:
logger.info("Sending request.")

Expand Down Expand Up @@ -357,7 +362,7 @@ def emit_request(
"[Deprecation Warning] instead do this:"
"[Deprecation Warning] $ python -m stability_sdk ... "
)

STABILITY_HOST = os.getenv("STABILITY_HOST", "grpc.stability.ai:443")
STABILITY_KEY = os.getenv("STABILITY_KEY", "")

Expand Down
4 changes: 3 additions & 1 deletion src/stability_sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
import stability_sdk.interfaces.gooseai.generation.generation_pb2_grpc as generation_grpc

logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

SAMPLERS: Dict[str, int] = {
"ddim": generation.SAMPLER_DDIM,
"plms": generation.SAMPLER_DDPM,
Expand All @@ -27,7 +30,6 @@
"k_dpmpp_2m": generation.SAMPLER_K_DPMPP_2M,
"k_dpmpp_2s_ancestral": generation.SAMPLER_K_DPMPP_2S_ANCESTRAL
}


MAX_FILENAME_SZ = int(os.getenv("MAX_FILENAME_SZ", 200))

Expand Down

0 comments on commit d8f140f

Please sign in to comment.