From 587c3ff1c67447b18709beace9ed7f6755a3bd79 Mon Sep 17 00:00:00 2001 From: Axel Peytavin Date: Tue, 6 Feb 2024 15:38:57 -0800 Subject: [PATCH] Refactor cache initialization and add configurable cache expiry time and frontend endpoint --- natural_frontend/cache.py | 4 +-- natural_frontend/natural_frontend.py | 41 +++++++++++++++++++--------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/natural_frontend/cache.py b/natural_frontend/cache.py index 6b9cbfe..39740b0 100644 --- a/natural_frontend/cache.py +++ b/natural_frontend/cache.py @@ -3,14 +3,14 @@ import time class Cache: - def __init__(self, directory="cache"): + def __init__(self, directory, cache_expiry_time): self.directory = directory if not os.path.exists(directory): try: os.makedirs(directory) except: pass - self.expiry_time = 600 # 600 seconds cache expiration time + self.expiry_time = cache_expiry_time # 600 seconds cache expiration time def get_file_path(self, key): filename = f"{key}.json" diff --git a/natural_frontend/natural_frontend.py b/natural_frontend/natural_frontend.py index 18a2be0..2c9af18 100644 --- a/natural_frontend/natural_frontend.py +++ b/natural_frontend/natural_frontend.py @@ -6,6 +6,7 @@ from fastapi.templating import Jinja2Templates from fastapi.staticfiles import StaticFiles from fastapi import Depends, FastAPI, Form, Request +from typing import Dict, List, Optional import importlib.resources as pkg_resources @@ -20,8 +21,6 @@ API_DOC_GEN_PROMPT = [] -ASK_ENDPOINT = "frontend" - logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) @@ -36,16 +35,17 @@ class NaturalFrontendOptions: def __init__( self, colors: Dict[str, str] = {"primary": "lightblue", "secondary": "purple"}, - personas: List[Dict[str, str]] = None, + personas: Optional[List[Dict[str, str]]] = None, + cache_expiry_time: int = 600, + frontend_endpoint: str = "frontend", ): # Check that colors is a dict with keys "primary" and "secondary" - if colors is not None: - if not isinstance(colors, dict): - raise TypeError("colors must be a dict") - if not "primary" in colors: - raise ValueError("colors must have a 'primary' key") - if not "secondary" in colors: - raise ValueError("colors must have a 'secondary' key") + if not isinstance(colors, dict): + raise TypeError("colors must be a dict") + if not "primary" in colors: + raise ValueError("colors must have a 'primary' key") + if not "secondary" in colors: + raise ValueError("colors must have a 'secondary' key") self.colors = colors @@ -65,15 +65,30 @@ def __init__( self.personas = personas + # Check that cache_expiry_time is an int + if not isinstance(cache_expiry_time, int): + raise TypeError("cache_expiry_time must be an int") + + self.cache_expiry_time = cache_expiry_time # 600 seconds cache expiration time + + # Check that frontend_endpoint is a string + if not isinstance(frontend_endpoint, str): + raise TypeError("frontend_endpoint must be a string") + + self.frontend_endpoint = frontend_endpoint + def NaturalFrontend( app: FastAPI, openai_api_key: str, options: NaturalFrontendOptions = NaturalFrontendOptions() ): app.mount("/static", StaticFiles(directory=str(static_directory)), name="static") + frontend_endpoint = options.frontend_endpoint + frontend_generator = FrontendGenerator(openai_api_key=openai_api_key) templates = Jinja2Templates(directory=str(template_directory)) + cache = Cache(directory=str(cache_directory), cache_expiry_time=options.cache_expiry_time) @app.on_event("startup") async def on_startup(): @@ -96,8 +111,8 @@ async def on_startup(): print("Natural Frontend was initiated successfully") - @app.get("/frontend/", response_class=HTMLResponse) - async def frontend(request: Request, cache: Cache = Depends()): + @app.get(f"/{frontend_endpoint}/", response_class=HTMLResponse) + async def frontend(request: Request): cache_key = "frontend_personas" # Try to get cached response @@ -183,7 +198,7 @@ def parse_potential_personas(personas: str, retries=5): }, ) - @app.post("/gen_frontend/", response_class=HTMLResponse) + @app.post(f"/gen_{frontend_endpoint}/", response_class=HTMLResponse) async def handle_form(persona: str = Form(...)): cache_key = f"html_frontend_{persona.split()[0]}" response_content = cache.get(cache_key)