Skip to content

Commit

Permalink
Rename to coglet plus reorg
Browse files Browse the repository at this point in the history
so that:

- `python -m coglet` is the main entrypoint
- `cog` compatibility imports are a thin wrapper around `coglet.api` and
  are only available for import once `coglet/_compat` is put into
`sys.path`, thus greatly reducing the likelihood of import collisions.
- `cog/internal/` directory collapsed up a level to `coglet/` since the
  above import collision problem is fixed.
  • Loading branch information
meatballhat committed Nov 17, 2024
1 parent 24c0f02 commit 610121a
Show file tree
Hide file tree
Showing 20 changed files with 152 additions and 98 deletions.
1 change: 0 additions & 1 deletion python/README.md

This file was deleted.

10 changes: 10 additions & 0 deletions python/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Coglet

Coglet provides a minimum viable [Cog] runtime primarily for use within the Replicate
platform, e.g.:

```
python -m coglet --working-dir path/to/code/ --module-name predict --class-name Predictor
```

[Cog]: <https://github.com/replicate/cog>
Empty file removed python/cog/internal/__init__.py
Empty file.
16 changes: 16 additions & 0 deletions python/coglet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pathlib
import sys
import warnings

# NOTE: The compatibility import provided in `./_compat/cog.py` **SHOULD NOT** be in
# PYTHONPATH until `coglet` is imported. This prevents `coglet` from interfering with
# normal usage of `cog` within a given python environment.
warnings.warn(
(
'coglet/_compat/ is being added to the front of sys.path '
"for 'cog' import compatibility"
),
category=ImportWarning,
stacklevel=2,
)
sys.path.insert(0, str(pathlib.Path(__file__).absolute().parent / '_compat'))
80 changes: 80 additions & 0 deletions python/coglet/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import argparse
import asyncio
import contextvars
import logging
import sys
from typing import Optional

from coglet import file_runner


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'--working-dir', metavar='DIR', required=True, help='working directory'
)
parser.add_argument(
'--module-name', metavar='NAME', required=True, help='Python module name'
)
parser.add_argument(
'--class-name', metavar='NAME', required=True, help='Python class name'
)

_ctx_pid: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
'pid', default=None
)
_ctx_newline: contextvars.ContextVar[bool] = contextvars.ContextVar(
'newline', default=False
)

logger = logging.getLogger('coglet')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
logging.Formatter(
'%(asctime)s\t%(levelname)s\t[%(name)s]\t%(filename)s:%(lineno)d\t%(message)s'
)
)
logger.addHandler(handler)

_stdout_write = sys.stdout.write
_stderr_write = sys.stderr.write

def _ctx_write(write_fn):
def _write(s: str) -> int:
pid = _ctx_pid.get()
if pid is None:
return write_fn(s)
else:
n = 0
if _ctx_newline.get():
n += write_fn(f'[pid={pid}] ')
if s[-1] == '\n':
_ctx_newline.set(True)
s = s[:-1].replace('\n', f'\n[pid={pid}] ') + '\n'
else:
_ctx_newline.set(False)
s = s.replace('\n', f'\n[pid={pid}] ')
n += write_fn(s)
return n

return _write

sys.stdout.write = _ctx_write(_stdout_write) # type: ignore
sys.stderr.write = _ctx_write(_stderr_write) # type: ignore

args = parser.parse_args()

return asyncio.run(
file_runner.FileRunner(
logger=logger,
working_dir=args.working_dir,
module_name=args.module_name,
class_name=args.class_name,
ctx_pid=_ctx_pid,
).start()
)


if __name__ == '__main__':
sys.exit(main())
9 changes: 8 additions & 1 deletion python/cog/__init__.py → python/coglet/_compat/cog.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from cog.api import BaseModel, BasePredictor, ConcatenateIterator, Input, Path, Secret
from coglet.api import (
BaseModel,
BasePredictor,
ConcatenateIterator,
Input,
Path,
Secret,
)

__all__ = [
'BaseModel',
Expand Down
8 changes: 4 additions & 4 deletions python/cog/internal/adt.py → python/coglet/adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional, Union

import cog
import coglet.api


class Type(Enum):
Expand Down Expand Up @@ -41,8 +41,8 @@ class Kind(Enum):
float: Type.FLOAT,
int: Type.INTEGER,
str: Type.STRING,
cog.Path: Type.PATH,
cog.Secret: Type.SECRET,
coglet.api.Path: Type.PATH,
coglet.api.Secret: Type.SECRET,
}

# Cog types to JSON types
Expand All @@ -68,7 +68,7 @@ class Kind(Enum):
CONTAINER_TO_COG = {
list: Kind.LIST,
typing.get_origin(Iterator): Kind.ITERATOR,
cog.ConcatenateIterator: Kind.CONCAT_ITERATOR,
coglet.api.ConcatenateIterator: Kind.CONCAT_ITERATOR,
}


Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import asyncio
import contextvars
import json
Expand All @@ -9,7 +8,7 @@
import sys
from typing import Any, Dict, Optional

from cog.internal import inspector, runner, schemas, util
from coglet import inspector, runner, schemas, util


class FileRunner:
Expand All @@ -26,16 +25,19 @@ class FileRunner:

def __init__(
self,
*,
logger: logging.Logger,
working_dir: str,
module_name: str,
class_name: str,
ctx_pid: contextvars.ContextVar[Optional[str]],
):
self.logger = logger
self.working_dir = working_dir
self.module_name = module_name
self.class_name = class_name
self.runner: Optional[runner.Runner] = None
self.ctx_pid = ctx_pid
self.isatty = sys.stdout.isatty()

async def start(self) -> int:
Expand Down Expand Up @@ -147,7 +149,7 @@ async def start(self) -> int:

async def _predict(self, pid: str, req: Dict[str, Any]) -> None:
assert self.runner is not None
_ctx_pid.set(pid)
self.ctx_pid.set(pid)
resp: Dict[str, Any] = {
'started_at': util.now_iso(),
'status': 'starting',
Expand Down Expand Up @@ -193,63 +195,3 @@ def _respond(
def _signal(self, signum: int) -> None:
if not self.isatty:
os.kill(os.getppid(), signum)


parser = argparse.ArgumentParser()
parser.add_argument(
'--working-dir', metavar='DIR', required=True, help='working directory'
)
parser.add_argument(
'--module-name', metavar='NAME', required=True, help='Python module name'
)
parser.add_argument(
'--class-name', metavar='NAME', required=True, help='Python class name'
)

_ctx_pid: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
'pid', default=None
)
_ctx_newline: contextvars.ContextVar[bool] = contextvars.ContextVar(
'newline', default=False
)

if __name__ == '__main__':
logger = logging.getLogger('cog-file-runner')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
logging.Formatter(
'%(asctime)s\t%(levelname)s\t[%(name)s]\t%(filename)s:%(lineno)d\t%(message)s'
)
)
logger.addHandler(handler)

_stdout_write = sys.stdout.write
_stderr_write = sys.stderr.write

def _ctx_write(write_fn):
def _write(s: str) -> int:
pid = _ctx_pid.get()
if pid is None:
return write_fn(s)
else:
n = 0
if _ctx_newline.get():
n += write_fn(f'[pid={pid}] ')
if s[-1] == '\n':
_ctx_newline.set(True)
s = s[:-1].replace('\n', f'\n[pid={pid}] ') + '\n'
else:
_ctx_newline.set(False)
s = s.replace('\n', f'\n[pid={pid}] ')
n += write_fn(s)
return n

return _write

sys.stdout.write = _ctx_write(_stdout_write) # type: ignore
sys.stderr.write = _ctx_write(_stderr_write) # type: ignore

args = parser.parse_args()
fr = FileRunner(logger, args.working_dir, args.module_name, args.class_name)
sys.exit(asyncio.run(fr.start()))
11 changes: 5 additions & 6 deletions python/cog/internal/inspector.py → python/coglet/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import typing
from typing import Callable, Optional

import cog
from cog.internal import adt, util
from coglet import adt, api, util


def _check_parent(child: type, parent: type) -> bool:
Expand Down Expand Up @@ -38,7 +37,7 @@ def _validate_predict(f: Callable) -> None:


def _validate_input(
name: str, cog_t: adt.Type, is_list: bool, cog_in: cog.Input
name: str, cog_t: adt.Type, is_list: bool, cog_in: api.Input
) -> None:
defaults = []
if cog_in.default is not None:
Expand Down Expand Up @@ -102,7 +101,7 @@ def _validate_input(


def _input_adt(
order: int, name: str, tpe: type, cog_in: Optional[cog.Input]
order: int, name: str, tpe: type, cog_in: Optional[api.Input]
) -> adt.Input:
cog_t, is_list = util.check_cog_type(tpe)
assert cog_t is not None, f'unsupported input type for {name}'
Expand Down Expand Up @@ -139,7 +138,7 @@ def _input_adt(


def _output_adt(tpe: type) -> adt.Output:
if inspect.isclass(tpe) and _check_parent(tpe, cog.BaseModel):
if inspect.isclass(tpe) and _check_parent(tpe, api.BaseModel):
assert tpe.__name__ == 'Output', 'output type must be named Output'
fields = {}
for name, t in tpe.__annotations__.items():
Expand Down Expand Up @@ -183,7 +182,7 @@ def create_predictor(module_name: str, class_name: str) -> adt.Predictor:
cls = getattr(module, class_name)
assert inspect.isclass(cls), f'not a class: {fullname}'
assert _check_parent(
cls, cog.BasePredictor
cls, api.BasePredictor
), f'predictor {fullname} does not inherit cog.BasePredictor'

assert hasattr(cls, 'setup'), f'setup method not found: {fullname}'
Expand Down
File renamed without changes.
7 changes: 3 additions & 4 deletions python/cog/internal/runner.py → python/coglet/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import re
from typing import Any, AsyncGenerator, Dict

import cog
from cog.internal import adt, util
from coglet import adt, api, util


def _kwargs(adt_ins: Dict[str, adt.Input], inputs: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -96,8 +95,8 @@ async def setup(self) -> None:
kwargs['weights'] = url
self.predictor.setup(weights=url)
elif os.path.exists(path):
kwargs['weights'] = cog.Path(path)
self.predictor.setup(weights=cog.Path(path))
kwargs['weights'] = api.Path(path)
self.predictor.setup(weights=api.Path(path))
else:
kwargs['weights'] = None
if inspect.iscoroutinefunction(self.predictor.setup):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os.path
from typing import Any, Dict, Optional, Union

from cog.internal import adt, util
from coglet import adt, util


def _from_json_type(prop: Dict[str, Any]) -> adt.Type:
Expand Down
7 changes: 3 additions & 4 deletions python/cog/internal/util.py → python/coglet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from datetime import datetime, timezone
from typing import Any, Tuple

import cog
from cog.internal import adt
from coglet import adt, api


def check_cog_type(tpe: type) -> Tuple[adt.Type, bool]:
Expand Down Expand Up @@ -47,9 +46,9 @@ def normalize_value(expected: adt.Type, value: Any) -> Any:
if expected is adt.Type.FLOAT:
return float(value)
elif expected is adt.Type.PATH:
return cog.Path(value) if type(value) is str else value
return api.Path(value) if type(value) is str else value
elif expected is adt.Type.SECRET:
return cog.Secret(value) if type(value) is str else value
return api.Secret(value) if type(value) is str else value
else:
return value

Expand Down
7 changes: 5 additions & 2 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "coggo"
name = "coglet"
version = "0.1.0"
description = "Add your description here"
description = "Minimum viable Cog runtime"
readme = "README.md"
requires-python = ">=3.9"
classifiers = [
Expand Down Expand Up @@ -31,3 +31,6 @@ build-backend = "setuptools.build_meta"

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "function"
filterwarnings = [
"ignore::ImportWarning",
]
4 changes: 2 additions & 2 deletions python/tests/test_file_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
from typing import Dict, List, Optional

from cog.internal.file_runner import FileRunner
from coglet.file_runner import FileRunner


def setup_signals() -> List[int]:
Expand All @@ -28,7 +28,7 @@ def file_runner(
cmd = [
sys.executable,
'-m',
'cog.internal.file_runner',
'coglet',
'--working-dir',
tmp_path,
'--module-name',
Expand Down
Loading

0 comments on commit 610121a

Please sign in to comment.