Skip to content

Commit

Permalink
Refactor cache initialization and add configurable cache expiry time …
Browse files Browse the repository at this point in the history
…and frontend endpoint
  • Loading branch information
axelpey authored and = committed Feb 6, 2024
1 parent 82ce9f8 commit 587c3ff
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
4 changes: 2 additions & 2 deletions natural_frontend/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import time

class Cache:
def __init__(self, directory="cache"):
def __init__(self, directory, cache_expiry_time):
self.directory = directory
if not os.path.exists(directory):
try:
os.makedirs(directory)
except:
pass
self.expiry_time = 600 # 600 seconds cache expiration time
self.expiry_time = cache_expiry_time # 600 seconds cache expiration time

def get_file_path(self, key):
filename = f"{key}.json"
Expand Down
41 changes: 28 additions & 13 deletions natural_frontend/natural_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi import Depends, FastAPI, Form, Request
from typing import Dict, List, Optional

import importlib.resources as pkg_resources

Expand All @@ -20,8 +21,6 @@

API_DOC_GEN_PROMPT = []

ASK_ENDPOINT = "frontend"

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
Expand All @@ -36,16 +35,17 @@ class NaturalFrontendOptions:
def __init__(
self,
colors: Dict[str, str] = {"primary": "lightblue", "secondary": "purple"},
personas: List[Dict[str, str]] = None,
personas: Optional[List[Dict[str, str]]] = None,
cache_expiry_time: int = 600,
frontend_endpoint: str = "frontend",
):
# Check that colors is a dict with keys "primary" and "secondary"
if colors is not None:
if not isinstance(colors, dict):
raise TypeError("colors must be a dict")
if not "primary" in colors:
raise ValueError("colors must have a 'primary' key")
if not "secondary" in colors:
raise ValueError("colors must have a 'secondary' key")
if not isinstance(colors, dict):
raise TypeError("colors must be a dict")
if not "primary" in colors:
raise ValueError("colors must have a 'primary' key")
if not "secondary" in colors:
raise ValueError("colors must have a 'secondary' key")

self.colors = colors

Expand All @@ -65,15 +65,30 @@ def __init__(

self.personas = personas

# Check that cache_expiry_time is an int
if not isinstance(cache_expiry_time, int):
raise TypeError("cache_expiry_time must be an int")

self.cache_expiry_time = cache_expiry_time # 600 seconds cache expiration time

# Check that frontend_endpoint is a string
if not isinstance(frontend_endpoint, str):
raise TypeError("frontend_endpoint must be a string")

self.frontend_endpoint = frontend_endpoint


def NaturalFrontend(
app: FastAPI, openai_api_key: str, options: NaturalFrontendOptions = NaturalFrontendOptions()
):
app.mount("/static", StaticFiles(directory=str(static_directory)), name="static")

frontend_endpoint = options.frontend_endpoint

frontend_generator = FrontendGenerator(openai_api_key=openai_api_key)

templates = Jinja2Templates(directory=str(template_directory))
cache = Cache(directory=str(cache_directory), cache_expiry_time=options.cache_expiry_time)

@app.on_event("startup")
async def on_startup():
Expand All @@ -96,8 +111,8 @@ async def on_startup():

print("Natural Frontend was initiated successfully")

@app.get("/frontend/", response_class=HTMLResponse)
async def frontend(request: Request, cache: Cache = Depends()):
@app.get(f"/{frontend_endpoint}/", response_class=HTMLResponse)
async def frontend(request: Request):
cache_key = "frontend_personas"

# Try to get cached response
Expand Down Expand Up @@ -183,7 +198,7 @@ def parse_potential_personas(personas: str, retries=5):
},
)

@app.post("/gen_frontend/", response_class=HTMLResponse)
@app.post(f"/gen_{frontend_endpoint}/", response_class=HTMLResponse)
async def handle_form(persona: str = Form(...)):
cache_key = f"html_frontend_{persona.split()[0]}"
response_content = cache.get(cache_key)
Expand Down

0 comments on commit 587c3ff

Please sign in to comment.