Skip to content

Commit

Permalink
Fix save flow for anonymous users
Browse files Browse the repository at this point in the history
  • Loading branch information
nikochiko committed Sep 16, 2024
1 parent 00cd4f5 commit 231e9ef
Showing 1 changed file with 88 additions and 6 deletions.
94 changes: 88 additions & 6 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
import binascii
import datetime
import hashlib
import html
Expand All @@ -21,6 +23,7 @@
from fastapi import HTTPException
from firebase_admin import auth
from furl import furl
from loguru import logger
from pydantic import BaseModel, Field, ValidationError
from sentry_sdk.tracing import (
TRANSACTION_SOURCE_ROUTE,
Expand Down Expand Up @@ -99,6 +102,8 @@
gooey_rng = Random()

SUBMIT_AFTER_LOGIN_Q = "submitafterlogin"
SAVE_AFTER_LOGIN_Q = "saveafterlogin"
REQUEST_STATE_Q = "requeststate"


class RecipeRunState(Enum):
Expand Down Expand Up @@ -325,6 +330,17 @@ def sentry_event_set_user(self, event, hint):
}
return event

def _load_state_from_query_params(self):
request_state = self.request.query_params.get(REQUEST_STATE_Q)
if not request_state:
return

try:
request_state = base64.urlsafe_b64decode(request_state).decode()
gui.session_state.update(json.loads(request_state))
except (json.JSONDecodeError, binascii.Error) as e:
logger.warning(f"Failed to load request state from query params: {e}")

def refresh_state(self):
sr = self.current_sr
channel = self.realtime_channel_name(sr.run_id, sr.uid)
Expand All @@ -334,6 +350,7 @@ def refresh_state(self):

def render(self):
self.setup_sentry()
self._load_state_from_query_params()

if self.get_run_state(gui.session_state) == RecipeRunState.running:
self.refresh_state()
Expand All @@ -343,6 +360,12 @@ def render(self):
self._user_disabled_check()
self._check_if_flagged()

if self.should_submit_after_login():
self.submit_and_redirect(save=self.should_save_after_login())

if self.should_save_after_login():
self.save_and_redirect()

if gui.session_state.get("show_report_workflow"):
self.render_report_form()
return
Expand Down Expand Up @@ -508,6 +531,9 @@ def _render_published_run_save_buttons(self, *, sr: SavedRun, pr: PublishedRun):
className="mb-0 ms-lg-2 px-lg-4",
type="primary",
):
if not self.request.user or self.request.user.is_anonymous:
self._save_for_anonymous_user()

self.clear_publish_form()
ref.set_open(True)

Expand All @@ -523,6 +549,16 @@ def _render_published_run_save_buttons(self, *, sr: SavedRun, pr: PublishedRun):
sr=sr, pr=pr, dialog=ref, is_update_mode=can_edit
)

def _save_for_anonymous_user(self):
query_params = {SUBMIT_AFTER_LOGIN_Q: "1", SAVE_AFTER_LOGIN_Q: "1"}
if diff := self._get_request_diff_to_save():
query_params[REQUEST_STATE_Q] = base64.urlsafe_b64encode(
json.dumps(diff).encode()
).decode()
raise gui.RedirectException(
self.get_auth_url(next_url=self.current_app_url(query_params=query_params))
)

@staticmethod
def clear_publish_form():
keys = {k for k in gui.session_state.keys() if k.startswith("published_run_")}
Expand Down Expand Up @@ -588,8 +624,7 @@ def _render_publish_form(
if is_update_mode:
title = pr.title or self.title
else:
recipe_title = self.get_root_pr().title or self.title
title = f"{self.request.user and self.request.user.first_name_possesive()} {recipe_title}"
title = self._get_default_pr_title()
published_run_title = gui.text_input(
"###### Title",
key="published_run_title",
Expand Down Expand Up @@ -642,6 +677,14 @@ def _render_publish_form(
)
raise gui.RedirectException(pr.get_app_url())

def _get_default_pr_title(self):
recipe_title = self.get_root_pr().title or self.title
if self.request.user and not self.request.user.is_anonymous:
title = f"{self.request.user.first_name_possesive()} {recipe_title}"
else:
title = f"My {recipe_title}"
return title

def _validate_published_run_title(self, title: str):
if slugify(title) in settings.DISALLOWED_TITLE_SLUGS:
raise TitleValidationError(
Expand Down Expand Up @@ -689,6 +732,14 @@ def _has_request_changed(self) -> bool:
else:
return False

def _get_request_diff_to_save(self) -> dict[str, typing.Any]:
sr_state = self.current_sr_to_session_state()
return {
k: gui.session_state[k]
for k in RequestModel.__fields__
if k in gui.session_state and sr_state.get(k) != gui.session_state[k]
}

def _saved_options_modal(self, *, sr: SavedRun, pr: PublishedRun):
is_latest_version = pr.saved_run == sr

Expand Down Expand Up @@ -1316,7 +1367,7 @@ def render_submit_button(self, key="--submit-1"):
self.render_run_cost()
with col2:
submitted = gui.button(
"🏃 Submit",
"🏃 Run",
key=key,
type="primary",
# disabled=bool(gui.session_state.get(StateKeys.run_status)),
Expand Down Expand Up @@ -1531,7 +1582,7 @@ def _render_output_col(self, *, submitted: bool = False, is_deleted: bool = Fals
gui.session_state.pop(StateKeys.pressed_randomize, None)
submitted = True

if submitted or self.should_submit_after_login():
if submitted:
self.submit_and_redirect()

run_state = self.get_run_state(gui.session_state)
Expand Down Expand Up @@ -1594,11 +1645,35 @@ def render_extra_waiting_output(self):
def estimate_run_duration(self) -> int | None:
pass

def submit_and_redirect(self):
def submit_and_redirect(self, save: bool = False) -> typing.NoReturn | None:
sr = self.on_submit()
if not sr:
return
raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))
if save:
self.save_and_redirect(sr)
else:
raise gui.RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))

def save_and_redirect(self, sr: SavedRun | None = None) -> typing.NoReturn:
pr = self.create_published_run(
published_run_id=get_random_doc_id(),
saved_run=sr or self.current_sr,
user=self.request.user,
title=gui.session_state.get(
"published_run_title", self._get_default_pr_title()
),
notes=gui.session_state.get(
"published_run_notes", self.current_pr and self.current_pr.notes or ""
),
visibility=PublishedRunVisibility(
int(
gui.session_state.get(
"published_run_visibility", PublishedRunVisibility.UNLISTED
)
)
),
)
raise gui.RedirectException(pr.get_app_url())

def on_submit(self):
try:
Expand All @@ -1621,6 +1696,13 @@ def should_submit_after_login(self) -> bool:
and not self.request.user.is_anonymous
)

def should_save_after_login(self) -> bool:
return bool(
self.request.query_params.get(SAVE_AFTER_LOGIN_Q)
and self.request.user
and not self.request.user.is_anonymous
)

def create_new_run(
self, *, enable_rate_limits: bool = False, **defaults
) -> SavedRun:
Expand Down

0 comments on commit 231e9ef

Please sign in to comment.