Skip to content

Commit

Permalink
feat: add k8s
Browse files Browse the repository at this point in the history
  • Loading branch information
Anonymous committed Jun 11, 2024
1 parent b6a13cf commit d06c749
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 116 deletions.
2 changes: 1 addition & 1 deletion src/mirrorsrun/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
RPC_SECRET = os.environ.get("RPC_SECRET", "")
BASE_DOMAIN = os.environ.get("BASE_DOMAIN", "local.homeinfra.org")

SCHEME = os.environ.get("SCHEME", None)
SCHEME = os.environ.get("SCHEME", "http").lower()
assert SCHEME in ["http", "https"]

CACHE_DIR = os.environ.get("CACHE_DIR", "/app/cache/")
Expand Down
68 changes: 68 additions & 0 deletions src/mirrorsrun/docker_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import base64
import json
import re
import time
from typing import Dict
import httpx


class CachedToken:
token: str
exp: int

def __init__(self, token, exp):
self.token = token
self.exp = exp


cached_tokens: Dict[str, CachedToken] = {}


# https://github.com/opencontainers/distribution-spec/blob/main/spec.md
name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*"
reference_regex = "[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}"


def try_extract_image_name(path):
pattern = r"^/v2/(.*)/([a-zA-Z]+)/(.*)$"
match = re.search(pattern, path)

if match:
assert len(match.groups()) == 3
name, resource, reference = match.groups()
assert re.match(name_regex, name)
assert re.match(reference_regex, reference)
assert resource in ["manifests", "blobs", "tags"]
return name, resource, reference

return None, None, None


def get_docker_token(name):
cached = cached_tokens.get(name, None)
if cached and cached.exp > time.time():
return cached.token

url = "https://auth.docker.io/token"
params = {
"scope": f"repository:{name}:pull",
"service": "registry.docker.io",
}

client = httpx.Client()
response = client.get(url, params=params)
response.raise_for_status()

token_data = response.json()
token = token_data["token"]
payload = token.split(".")[1]
padding = len(payload) % 4
payload += "=" * padding

payload = json.loads(base64.b64decode(payload))
assert payload["iss"] == "auth.docker.io"
assert len(payload["access"]) > 0

cached_tokens[name] = CachedToken(exp=payload["exp"], token=token)

return token
58 changes: 42 additions & 16 deletions src/mirrorsrun/proxy/direct.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import typing
from typing import Callable, Coroutine

Expand All @@ -18,13 +19,46 @@
[Request, Response], Coroutine[Request, Response, Response]
]

PreProcessor = typing.Union[SyncPreProcessor, AsyncPreProcessor, None]
PostProcessor = typing.Union[SyncPostProcessor, AsyncPostProcessor, None]

logger = logging.getLogger(__name__)


async def pre_process_request(
request: Request,
httpx_req: HttpxRequest,
pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
):
if pre_process:
new_httpx_req = pre_process(request, httpx_req)
if isinstance(new_httpx_req, HttpxRequest):
httpx_req = new_httpx_req
else:
httpx_req = await new_httpx_req
return httpx_req


async def post_process_response(
request: Request,
response: Response,
post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
):
if post_process:
new_res = post_process(request, response)
if isinstance(new_res, Response):
return new_res
elif isinstance(new_res, Coroutine):
return await new_res
else:
return response


async def direct_proxy(
request: Request,
target_url: str,
pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
cache_ttl: int = 3600,
) -> Response:
# httpx will use the following environment variables to determine the proxy
# https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
Expand All @@ -40,12 +74,7 @@ async def direct_proxy(
headers=req_headers,
)

if pre_process:
new_httpx_req = pre_process(request, httpx_req)
if isinstance(new_httpx_req, HttpxRequest):
httpx_req = new_httpx_req
else:
httpx_req = await new_httpx_req
httpx_req = await pre_process_request(request, httpx_req, pre_process)

upstream_response = await client.send(httpx_req)

Expand All @@ -54,20 +83,17 @@ async def direct_proxy(
res_headers.pop("content-length", None)
res_headers.pop("content-encoding", None)

logger.info(
f"proxy {request.url} to {target_url} {upstream_response.status_code}"
)

content = upstream_response.content
response = Response(
headers=res_headers,
content=content,
status_code=upstream_response.status_code,
)

if post_process:
new_res = post_process(request, response)
if isinstance(new_res, Response):
final_res = new_res
elif isinstance(new_res, Coroutine):
final_res = await new_res
else:
final_res = response
response = await post_process_response(request, response, post_process)

return final_res
return response
20 changes: 8 additions & 12 deletions src/mirrorsrun/proxy/file_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,12 @@
from urllib.parse import urlparse, quote

import httpx
from mirrorsrun.aria2_api import add_download
from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
from starlette.requests import Request
from starlette.responses import Response
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT

from mirrorsrun.aria2_api import add_download

from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
from typing import Optional, Callable

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -76,14 +73,11 @@ async def try_file_based_cache(
request: Request,
target_url: str,
download_wait_time: int = 60,
post_process: Optional[Callable[[Request, Response], Response]] = None,
) -> Response:
cache_status = lookup_cache(target_url)
if cache_status == DownloadingStatus.DOWNLOADED:
resp = make_cached_response(target_url)
if post_process:
resp = post_process(request, resp)
return resp
logger.info(f"Cache hit for {target_url}")
return make_cached_response(target_url)

if cache_status == DownloadingStatus.DOWNLOADING:
logger.info(f"Download is not finished, return 503 for {target_url}")
Expand All @@ -95,14 +89,15 @@ async def try_file_based_cache(
assert cache_status == DownloadingStatus.NOT_FOUND

cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
print("prepare to download", target_url, cache_file, cache_file_dir)
logger.info(f"prepare to cache, {target_url=} {cache_file=} {cache_file_dir=}")

processed_url = quote(target_url, safe="/:?=&%")

try:
logger.info(f"Start download {processed_url}")
await add_download(processed_url, save_dir=cache_file_dir)
except Exception as e:
logger.error(f"Download error, return 503500 for {target_url}", exc_info=e)
logger.error(f"Download error, return 500 for {target_url}", exc_info=e)
return Response(
content=f"Failed to add download: {e}",
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
Expand All @@ -113,6 +108,7 @@ async def try_file_based_cache(
await sleep(1)
cache_status = lookup_cache(target_url)
if cache_status == DownloadingStatus.DOWNLOADED:
logger.info(f"Cache hit for {target_url}")
return make_cached_response(target_url)
logger.info(f"Download is not finished, return 503 for {target_url}")
return Response(
Expand Down
44 changes: 27 additions & 17 deletions src/mirrorsrun/server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

sys.path.append(os.path.dirname(os.path.dirname(__file__))) # noqa: E402

import base64
import signal
import urllib.parse
from typing import Callable
import logging

import httpx
import uvicorn
Expand All @@ -25,6 +27,19 @@
from mirrorsrun.sites.npm import npm
from mirrorsrun.sites.pypi import pypi
from mirrorsrun.sites.torch import torch
from mirrorsrun.sites.k8s import k8s

subdomain_mapping = {
"pypi": pypi,
"torch": torch,
"docker": docker,
"npm": npm,
"k8s": k8s,
}

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)

app = FastAPI()

Expand Down Expand Up @@ -73,25 +88,21 @@ async def capture_request(request: Request, call_next: Callable):
if hostname.startswith("aria2."):
return await aria2(request, call_next)

if hostname.startswith("pypi."):
return await pypi(request)
if hostname.startswith("torch."):
return await torch(request)
if hostname.startswith("docker."):
return await docker(request)
if hostname.startswith("npm."):
return await npm(request)
subdomain = hostname.split(".")[0]

if subdomain in subdomain_mapping:
return await subdomain_mapping[subdomain](request)

return await call_next(request)


if __name__ == "__main__":
signal.signal(signal.SIGINT, signal.SIG_DFL)
port = 80
print(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")
logger.info(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")

for dn in ["pypi", "torch", "docker", "npm"]:
print(f" - {SCHEME}://{dn}.{BASE_DOMAIN}")
for dn in subdomain_mapping.keys():
logger.info(f" - {SCHEME}://{dn}.{BASE_DOMAIN}")

aria2_secret = base64.b64encode(RPC_SECRET.encode()).decode()

Expand All @@ -106,14 +117,13 @@ async def capture_request(request: Request, call_next: Callable):
query_string = urllib.parse.urlencode(params)
aria2_url_with_auth = EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + query_string

print(f"Download manager (Aria2) at {aria2_url_with_auth}")
# FIXME: only proxy headers if SCHEME is https
# reload only in dev mode
logger.info(f"Download manager (Aria2) at {aria2_url_with_auth}")

uvicorn.run(
app="server:app",
host="0.0.0.0",
port=port,
reload=True,
proxy_headers=True,
reload=True, # TODO: reload only in dev mode
proxy_headers=True, # trust x-forwarded-for etc.
forwarded_allow_ips="*",
)
Loading

0 comments on commit d06c749

Please sign in to comment.