diff --git a/src/mirrorsrun/config.py b/src/mirrorsrun/config.py
index d59bc0f..a446145 100644
--- a/src/mirrorsrun/config.py
+++ b/src/mirrorsrun/config.py
@@ -4,7 +4,7 @@
 RPC_SECRET = os.environ.get("RPC_SECRET", "")
 BASE_DOMAIN = os.environ.get("BASE_DOMAIN", "local.homeinfra.org")
 
-SCHEME = os.environ.get("SCHEME", None)
+SCHEME = os.environ.get("SCHEME", "http").lower()
 assert SCHEME in ["http", "https"]
 
 CACHE_DIR = os.environ.get("CACHE_DIR", "/app/cache/")
diff --git a/src/mirrorsrun/docker_utils.py b/src/mirrorsrun/docker_utils.py
new file mode 100644
index 0000000..80bdbd8
--- /dev/null
+++ b/src/mirrorsrun/docker_utils.py
@@ -0,0 +1,68 @@
+import base64
+import json
+import re
+import time
+from typing import Dict
+import httpx
+
+
+class CachedToken:
+    token: str
+    exp: int
+
+    def __init__(self, token, exp):
+        self.token = token
+        self.exp = exp
+
+
+cached_tokens: Dict[str, CachedToken] = {}
+
+
+# https://github.com/opencontainers/distribution-spec/blob/main/spec.md
+name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*"
+reference_regex = "[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}"
+
+
+def try_extract_image_name(path):
+    pattern = r"^/v2/(.*)/([a-zA-Z]+)/(.*)$"
+    match = re.search(pattern, path)
+
+    if match:
+        assert len(match.groups()) == 3
+        name, resource, reference = match.groups()
+        assert re.match(name_regex, name)
+        assert re.match(reference_regex, reference)
+        assert resource in ["manifests", "blobs", "tags"]
+        return name, resource, reference
+
+    return None, None, None
+
+
+def get_docker_token(name):
+    cached = cached_tokens.get(name, None)
+    if cached and cached.exp > time.time():
+        return cached.token
+
+    url = "https://auth.docker.io/token"
+    params = {
+        "scope": f"repository:{name}:pull",
+        "service": "registry.docker.io",
+    }
+
+    client = httpx.Client()
+    response = client.get(url, params=params)
+    response.raise_for_status()
+
+    token_data = response.json()
+    token = token_data["token"]
+    payload = token.split(".")[1]
+    padding = -len(payload) % 4
+    payload += "=" * padding
+
+    payload = json.loads(base64.b64decode(payload))
+    assert payload["iss"] == "auth.docker.io"
+    assert len(payload["access"]) > 0
+
+    cached_tokens[name] = CachedToken(exp=payload["exp"], token=token)
+
+    return token
diff --git a/src/mirrorsrun/proxy/direct.py b/src/mirrorsrun/proxy/direct.py
index 60dba95..d6981cf 100644
--- a/src/mirrorsrun/proxy/direct.py
+++ b/src/mirrorsrun/proxy/direct.py
@@ -1,3 +1,4 @@
+import logging
 import typing
 from typing import Callable, Coroutine
 
@@ -18,13 +19,46 @@
 [Request, Response], Coroutine[Request, Response, Response]
 ]
 
 
+PreProcessor = typing.Union[SyncPreProcessor, AsyncPreProcessor, None]
+PostProcessor = typing.Union[SyncPostProcessor, AsyncPostProcessor, None]
+
+logger = logging.getLogger(__name__)
+
+
+async def pre_process_request(
+    request: Request,
+    httpx_req: HttpxRequest,
+    pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
+):
+    if pre_process:
+        new_httpx_req = pre_process(request, httpx_req)
+        if isinstance(new_httpx_req, HttpxRequest):
+            httpx_req = new_httpx_req
+        else:
+            httpx_req = await new_httpx_req
+    return httpx_req
+
+
+async def post_process_response(
+    request: Request,
+    response: Response,
+    post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+):
+    if post_process:
+        new_res = post_process(request, response)
+        if isinstance(new_res, Response):
+            return new_res
+        elif isinstance(new_res, Coroutine):
+            return await new_res
+    # Fall through: no post_process given, or it returned an unexpected value.
+    return response
+
 async def direct_proxy(
     request: Request,
     target_url: str,
     pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
     post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
-    cache_ttl: int = 3600,
 ) -> Response:
     # httpx will use the following environment variables to determine the proxy
     # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
@@ -40,12 +74,7 @@ async def direct_proxy(
         headers=req_headers,
     )
 
-    if pre_process:
-        new_httpx_req = pre_process(request, httpx_req)
-        if isinstance(new_httpx_req, HttpxRequest):
-            httpx_req = new_httpx_req
-        else:
-            httpx_req = await new_httpx_req
+    httpx_req = await pre_process_request(request, httpx_req, pre_process)
 
     upstream_response = await client.send(httpx_req)
 
@@ -54,6 +83,10 @@ async def direct_proxy(
     res_headers.pop("content-length", None)
     res_headers.pop("content-encoding", None)
 
+    logger.info(
+        f"proxy {request.url} to {target_url} {upstream_response.status_code}"
+    )
+
     content = upstream_response.content
     response = Response(
         headers=res_headers,
@@ -61,13 +94,6 @@ async def direct_proxy(
         status_code=upstream_response.status_code,
     )
 
-    if post_process:
-        new_res = post_process(request, response)
-        if isinstance(new_res, Response):
-            final_res = new_res
-        elif isinstance(new_res, Coroutine):
-            final_res = await new_res
-    else:
-        final_res = response
+    response = await post_process_response(request, response, post_process)
 
-    return final_res
+    return response
diff --git a/src/mirrorsrun/proxy/file_cache.py b/src/mirrorsrun/proxy/file_cache.py
index 187cac9..c2f0ba0 100644
--- a/src/mirrorsrun/proxy/file_cache.py
+++ b/src/mirrorsrun/proxy/file_cache.py
@@ -7,15 +7,12 @@
 from urllib.parse import urlparse, quote
 
 import httpx
+from mirrorsrun.aria2_api import add_download
+from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT
 
-from mirrorsrun.aria2_api import add_download
-
-from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
-from typing import Optional, Callable
-
 logger = logging.getLogger(__name__)
 
@@ -76,14 +73,11 @@ async def try_file_based_cache(
     request: Request,
     target_url: str,
     download_wait_time: int = 60,
-    post_process: Optional[Callable[[Request, Response], Response]] = None,
 ) -> Response:
     cache_status = lookup_cache(target_url)
 
     if cache_status == DownloadingStatus.DOWNLOADED:
-        resp = make_cached_response(target_url)
-        if post_process:
-            resp = post_process(request, resp)
-        return resp
+        logger.info(f"Cache hit for {target_url}")
+        return make_cached_response(target_url)
 
     if cache_status == DownloadingStatus.DOWNLOADING:
@@ -95,14 +89,15 @@ async def try_file_based_cache(
     assert cache_status == DownloadingStatus.NOT_FOUND
 
     cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
-    print("prepare to download", target_url, cache_file, cache_file_dir)
+    logger.info(f"prepare to cache, {target_url=} {cache_file=} {cache_file_dir=}")
 
     processed_url = quote(target_url, safe="/:?=&%")
 
     try:
+        logger.info(f"Start download {processed_url}")
         await add_download(processed_url, save_dir=cache_file_dir)
     except Exception as e:
-        logger.error(f"Download error, return 503500 for {target_url}", exc_info=e)
+        logger.error(f"Download error, return 500 for {target_url}", exc_info=e)
         return Response(
             content=f"Failed to add download: {e}",
             status_code=HTTP_500_INTERNAL_SERVER_ERROR,
@@ -113,6 +108,7 @@ async def try_file_based_cache(
         await sleep(1)
         cache_status = lookup_cache(target_url)
         if cache_status == DownloadingStatus.DOWNLOADED:
+            logger.info(f"Cache hit for {target_url}")
             return make_cached_response(target_url)
     logger.info(f"Download is not finished, return 503 for {target_url}")
     return Response(
diff --git a/src/mirrorsrun/server.py b/src/mirrorsrun/server.py
index 6f11357..c04633e 100644
--- a/src/mirrorsrun/server.py
+++ b/src/mirrorsrun/server.py
@@ -1,11 +1,13 @@
 import os
 import sys
-sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))  # noqa: E402
 
 import base64
 import signal
 import urllib.parse
 from typing import Callable
+import logging
 
 import httpx
 import uvicorn
@@ -25,6 +27,19 @@
 from mirrorsrun.sites.npm import npm
 from mirrorsrun.sites.pypi import pypi
 from mirrorsrun.sites.torch import torch
+from mirrorsrun.sites.k8s import k8s
+
+subdomain_mapping = {
+    "pypi": pypi,
+    "torch": torch,
+    "docker": docker,
+    "npm": npm,
+    "k8s": k8s,
+}
+
+logging.basicConfig(level=logging.INFO)
+
+logger = logging.getLogger(__name__)
 
 app = FastAPI()
 
@@ -73,14 +88,10 @@ async def capture_request(request: Request, call_next: Callable):
     if hostname.startswith("aria2."):
         return await aria2(request, call_next)
 
-    if hostname.startswith("pypi."):
-        return await pypi(request)
-    if hostname.startswith("torch."):
-        return await torch(request)
-    if hostname.startswith("docker."):
-        return await docker(request)
-    if hostname.startswith("npm."):
-        return await npm(request)
+    subdomain = hostname.split(".")[0]
+
+    if subdomain in subdomain_mapping:
+        return await subdomain_mapping[subdomain](request)
 
     return await call_next(request)
 
@@ -88,10 +99,10 @@
 if __name__ == "__main__":
     signal.signal(signal.SIGINT, signal.SIG_DFL)
     port = 80
-    print(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")
+    logger.info(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")
 
-    for dn in ["pypi", "torch", "docker", "npm"]:
-        print(f" - {SCHEME}://{dn}.{BASE_DOMAIN}")
+    for dn in subdomain_mapping.keys():
+        logger.info(f" - {SCHEME}://{dn}.{BASE_DOMAIN}")
 
     aria2_secret = base64.b64encode(RPC_SECRET.encode()).decode()
 
@@ -106,14 +117,13 @@
     query_string = urllib.parse.urlencode(params)
 
     aria2_url_with_auth = EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + query_string
-    print(f"Download manager (Aria2) at {aria2_url_with_auth}")
-    # FIXME: only proxy headers if SCHEME is https
-    # reload only in dev mode
+    logger.info(f"Download manager (Aria2) at {aria2_url_with_auth}")
+
     uvicorn.run(
         app="server:app",
         host="0.0.0.0",
         port=port,
-        reload=True,
-        proxy_headers=True,
+        reload=True,  # TODO: reload only in dev mode
+        proxy_headers=True,  # trust x-forwarded-for etc.
         forwarded_allow_ips="*",
     )
diff --git a/src/mirrorsrun/sites/docker.py b/src/mirrorsrun/sites/docker.py
index 69930ac..787bf6d 100644
--- a/src/mirrorsrun/sites/docker.py
+++ b/src/mirrorsrun/sites/docker.py
@@ -1,83 +1,19 @@
-import base64
-import json
 import logging
-import re
-import time
-from typing import Dict
 
 import httpx
 from starlette.requests import Request
 from starlette.responses import Response
 
+from mirrorsrun.docker_utils import get_docker_token
 from mirrorsrun.proxy.direct import direct_proxy
 from mirrorsrun.proxy.file_cache import try_file_based_cache
+from mirrorsrun.docker_utils import try_extract_image_name
 
 logger = logging.getLogger(__name__)
 
 BASE_URL = "https://registry-1.docker.io"
 
 
-class CachedToken:
-    token: str
-    exp: int
-
-    def __init__(self, token, exp):
-        self.token = token
-        self.exp = exp
-
-
-cached_tokens: Dict[str, CachedToken] = {}
-
-# https://github.com/opencontainers/distribution-spec/blob/main/spec.md
-name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*"
-reference_regex = "[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}"
-
-
-def try_extract_image_name(path):
-    pattern = r"^/v2/(.*)/([a-zA-Z]+)/(.*)$"
-    match = re.search(pattern, path)
-
-    if match:
-        assert len(match.groups()) == 3
-        name, resource, reference = match.groups()
-        assert re.match(name_regex, name)
-        assert re.match(reference_regex, reference)
-        assert resource in ["manifests", "blobs", "tags"]
-        return name, resource, reference
-
-    return None, None, None
-
-
-def get_docker_token(name):
-    cached = cached_tokens.get(name, None)
-    if cached and cached.exp > time.time():
-        return cached.token
-
-    url = "https://auth.docker.io/token"
-    params = {
-        "scope": f"repository:{name}:pull",
-        "service": "registry.docker.io",
-    }
-
-    client = httpx.Client()
-    response = client.get(url, params=params)
-    response.raise_for_status()
-
-    token_data = response.json()
-    token = token_data["token"]
-    payload = token.split(".")[1]
-    padding = len(payload) % 4
-    payload += "=" * padding
-
-    payload = json.loads(base64.b64decode(payload))
-    assert payload["iss"] == "auth.docker.io"
-    assert len(payload["access"]) > 0
-
-    cached_tokens[name] = CachedToken(exp=payload["exp"], token=token)
-
-    return token
-
-
 def inject_token(name: str, req: Request, httpx_req: httpx.Request):
     docker_token = get_docker_token(f"{name}")
     httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
@@ -112,11 +48,13 @@ async def docker(request: Request):
 
     target_url = BASE_URL + f"/v2/{name}/{resource}/{reference}"
 
-    logger.info(f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}")
+    logger.info(
+        f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}"
+    )
 
     return await direct_proxy(
         request,
         target_url,
         pre_process=lambda req, http_req: inject_token(name, req, http_req),
-        post_process=post_process,
+        post_process=post_process,  # cache in post_process
     )
diff --git a/src/mirrorsrun/sites/k8s.py b/src/mirrorsrun/sites/k8s.py
new file mode 100644
index 0000000..965a4cb
--- /dev/null
+++ b/src/mirrorsrun/sites/k8s.py
@@ -0,0 +1,50 @@
+import logging
+
+from starlette.requests import Request
+from starlette.responses import Response
+
+from mirrorsrun.docker_utils import try_extract_image_name
+from mirrorsrun.proxy.direct import direct_proxy
+from mirrorsrun.proxy.file_cache import try_file_based_cache
+
+logger = logging.getLogger(__name__)
+
+BASE_URL = "https://registry.k8s.io"
+
+
+async def post_process(request: Request, response: Response):
+    if response.status_code == 307:
+        location = response.headers["location"]
+
+        if "/blobs/" in request.url.path:
+            return await try_file_based_cache(request, location)
+
+        return await direct_proxy(request, location)
+
+    return response
+
+
+async def k8s(request: Request):
+    path = request.url.path
+    if not path.startswith("/v2/"):
+        return Response(content="Not Found", status_code=404)
+
+    if path == "/v2/":
+        return Response(content="OK")
+
+    name, resource, reference = try_extract_image_name(path)
+
+    if not name:
+        return Response(content="404 Not Found", status_code=404)
+
+    target_url = BASE_URL + f"/v2/{name}/{resource}/{reference}"
+
+    logger.info(
+        f"got k8s request, {path=} {name=} {resource=} {reference=} {target_url=}"
+    )
+
+    return await direct_proxy(
+        request,
+        target_url,
+        post_process=post_process,
+    )
diff --git a/src/setup.cfg b/src/setup.cfg
index bf750d2..27af692 100644
--- a/src/setup.cfg
+++ b/src/setup.cfg
@@ -1,2 +1,3 @@
 [flake8]
-max-line-length = 99
\ No newline at end of file
+max-line-length = 99
+extend-ignore = E402
\ No newline at end of file
diff --git a/test/mirrors_test.py b/test/mirrors_test.py
index 60ca621..36c37f6 100644
--- a/test/mirrors_test.py
+++ b/test/mirrors_test.py
@@ -16,5 +16,8 @@ def test_pypi_http(self):
     def test_torch_http(self):
         call(f"pip download -i {TORCH_INDEX} tqdm --trusted-host {TORCH_HOST} --dest /tmp/torch/")
 
-    def test_docker_pull(self):
+    def test_dockerhub_pull(self):
         call(f"docker pull docker.local.homeinfra.org/alpine:3.12")
+
+    def test_k8s_pull(self):
+        call(f"docker pull k8s.local.homeinfra.org/pause:3.5")