From 664b9de906ecc9b104012971310cc1eda34ac9e4 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 18 Sep 2024 19:50:10 +0530
Subject: [PATCH 01/81] Add validations to disallow inviting members to a
personal workspace
---
workspaces/models.py | 12 +++++++++++-
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/workspaces/models.py b/workspaces/models.py
index c02bc7974..48d08db04 100644
--- a/workspaces/models.py
+++ b/workspaces/models.py
@@ -350,7 +350,7 @@ class Meta:
fields=["workspace", "user"],
condition=Q(deleted__isnull=True),
name="unique_workspace_user",
- )
+ ),
]
indexes = [
models.Index(fields=["workspace", "role", "deleted"]),
@@ -359,6 +359,11 @@ class Meta:
def __str__(self):
return f"{self.get_role_display()} - {self.user} ({self.workspace})"
+ def clean(self) -> None:
+ if self.workspace.is_personal and self.user_id != self.workspace.created_by_id:
+ raise ValidationError("You cannot add users to a personal workspace")
+ return super().clean()
+
def can_edit_workspace_metadata(self):
return self.role in (WorkspaceRole.OWNER, WorkspaceRole.ADMIN)
@@ -503,6 +508,11 @@ class Meta:
def __str__(self):
return f"{self.email} - {self.workspace} ({self.get_status_display()})"
+ def clean(self) -> None:
+ if self.workspace.is_personal:
+ raise ValidationError("You cannot invite users to a personal workspace")
+ return super().clean()
+
@admin.display(description="Expired")
def has_expired(self):
return timezone.now() - self.updated_at > timedelta(
From d467bea09baa06983fe3dd2bd50aabb34bfd92e6 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 18 Sep 2024 19:50:48 +0530
Subject: [PATCH 02/81] Fix view for validation errors
Earlier: was showing a JSON array of strings
After this fix: shows a single string message
---
workspaces/views.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/workspaces/views.py b/workspaces/views.py
index 519ba42e1..e1e4c6f0f 100644
--- a/workspaces/views.py
+++ b/workspaces/views.py
@@ -77,7 +77,7 @@ def render_workspace_creation_view(user: AppUser):
try:
workspace.create_with_owner()
except ValidationError as e:
- gui.write(str(e), className="text-danger")
+ gui.write(e.message, className="text-danger")
else:
gui.rerun()
@@ -181,7 +181,7 @@ def member_invite_button_with_dialog(membership: WorkspaceMembership):
defaults=dict(role=role),
)
except ValidationError as e:
- gui.write(str(e), className="text-danger")
+ gui.write(e.message, className="text-danger")
else:
ref.set_open(False)
gui.rerun()
@@ -212,7 +212,7 @@ def edit_workspace_button_with_dialog(membership: WorkspaceMembership):
workspace_copy.full_clean()
except ValidationError as e:
# newlines in markdown
- gui.write(str(e), className="text-danger")
+ gui.write(e.message, className="text-danger")
else:
workspace_copy.save()
membership.workspace.refresh_from_db()
From f14d1b559494dce5f21a82c1f3a904da5b16306f Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 18 Sep 2024 19:52:01 +0530
Subject: [PATCH 03/81] Disallow teams tab for personal workspaces
---
routers/account.py | 9 ++++++---
workspaces/views.py | 5 +++++
workspaces/widgets.py | 5 ++++-
3 files changed, 15 insertions(+), 4 deletions(-)
diff --git a/routers/account.py b/routers/account.py
index b05a69e68..e52e568fe 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -212,11 +212,14 @@ def url_path(self) -> str:
return get_route_path(self.route)
@classmethod
- def get_tabs_for_user(cls, user: AppUser | None) -> list["AccountTabs"]:
+ def get_tabs_for_request(cls, request: Request) -> list["AccountTabs"]:
from daras_ai_v2.base import BasePage
ret = list(cls)
- if not BasePage.is_user_admin(user):
+ if (
+ not BasePage.is_user_admin(request.user)
+ or get_current_workspace(request.user, request.session).is_personal
+ ):
ret.remove(cls.workspaces)
return ret
@@ -286,7 +289,7 @@ def account_page_wrapper(request: Request, current_tab: TabData):
with page_wrapper(request):
gui.div(className="mt-5")
with gui.nav_tabs():
- for tab in AccountTabs.get_tabs_for_user(request.user):
+ for tab in AccountTabs.get_tabs_for_request(request):
with gui.nav_item(tab.url_path, active=tab == current_tab):
gui.html(tab.title)
diff --git a/workspaces/views.py b/workspaces/views.py
index e1e4c6f0f..19923985a 100644
--- a/workspaces/views.py
+++ b/workspaces/views.py
@@ -62,7 +62,12 @@ def invitation_page(current_user: AppUser, session: dict, invite: WorkspaceInvit
def workspaces_page(user: AppUser, session: dict):
+ from routers.account import account_route
+
workspace = get_current_workspace(user, session)
+ if workspace.is_personal:
+ raise gui.RedirectException(get_route_path(account_route))
+
membership = workspace.memberships.get(user=user)
render_workspace_by_membership(membership)
diff --git a/workspaces/widgets.py b/workspaces/widgets.py
index b5a3518da..bc32e4c39 100644
--- a/workspaces/widgets.py
+++ b/workspaces/widgets.py
@@ -2,12 +2,15 @@
from app_users.models import AppUser
from daras_ai_v2 import icons
+from daras_ai_v2.fastapi_tricks import get_route_path
from .models import Workspace
SESSION_SELECTED_WORKSPACE = "selected-workspace-id"
def workspace_selector(user: AppUser, session: dict):
+ from routers.account import workspaces_route
+
workspaces = Workspace.objects.filter(
memberships__user=user, memberships__deleted__isnull=True
).order_by("-is_personal", "-created_at")
@@ -31,7 +34,7 @@ def workspace_selector(user: AppUser, session: dict):
workspace.create_with_owner()
gui.session_state[SESSION_SELECTED_WORKSPACE] = workspace.id
session[SESSION_SELECTED_WORKSPACE] = workspace.id
- gui.rerun()
+ raise gui.RedirectException(get_route_path(workspaces_route))
selected_id = gui.selectbox(
label="",
From 8cf22d3a5e167a21229e2eb8446996cd8c3d19d4 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 19 Sep 2024 20:39:07 +0530
Subject: [PATCH 04/81] change billing tab URL to /account/billing and tab
order in account page
---
daras_ai_v2/billing.py | 16 ++++++++--------
daras_ai_v2/send_email.py | 6 +++---
payments/models.py | 4 ++--
payments/tasks.py | 12 ++++++------
routers/account.py | 28 ++++++++++++++++++++--------
routers/paypal.py | 4 ++--
6 files changed, 41 insertions(+), 29 deletions(-)
diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py
index daf650134..bd778de3e 100644
--- a/daras_ai_v2/billing.py
+++ b/daras_ai_v2/billing.py
@@ -309,13 +309,13 @@ def fmt_price(plan: PricingPlan) -> str:
def change_subscription(workspace: "Workspace", new_plan: PricingPlan, **kwargs):
- from routers.account import account_route
+ from routers.account import billing_route
from routers.account import payment_processing_route
current_plan = PricingPlan.from_sub(workspace.subscription)
if new_plan == current_plan:
- raise gui.RedirectException(get_app_route_url(account_route), status_code=303)
+ raise gui.RedirectException(get_app_route_url(billing_route), status_code=303)
if new_plan == PricingPlan.STARTER:
workspace.subscription.cancel()
@@ -491,7 +491,7 @@ def render_stripe_addon_button(dollat_amt: int, workspace: "Workspace", save_pm:
def stripe_addon_checkout_redirect(
workspace: "Workspace", dollat_amt: int, save_pm: bool
):
- from routers.account import account_route
+ from routers.account import billing_route
from routers.account import payment_processing_route
line_item = available_subscriptions["addon"]["stripe"].copy()
@@ -505,7 +505,7 @@ def stripe_addon_checkout_redirect(
line_items=[line_item],
mode="payment",
success_url=get_app_route_url(payment_processing_route),
- cancel_url=get_app_route_url(account_route),
+ cancel_url=get_app_route_url(billing_route),
customer=workspace.get_or_create_stripe_customer(),
invoice_creation={"enabled": True},
allow_promotion_codes=True,
@@ -553,7 +553,7 @@ def render_stripe_subscription_button(
def stripe_subscription_create(workspace: "Workspace", plan: PricingPlan):
- from routers.account import account_route
+ from routers.account import billing_route
from routers.account import payment_processing_route
if workspace.subscription and workspace.subscription.is_paid():
@@ -595,7 +595,7 @@ def stripe_subscription_create(workspace: "Workspace", plan: PricingPlan):
checkout_session = stripe.checkout.Session.create(
mode="subscription",
success_url=get_app_route_url(payment_processing_route),
- cancel_url=get_app_route_url(account_route),
+ cancel_url=get_app_route_url(billing_route),
allow_promotion_codes=True,
customer=customer,
line_items=line_items,
@@ -715,7 +715,7 @@ def render_payment_information(workspace: "Workspace"):
def change_payment_method(workspace: "Workspace"):
from routers.account import payment_processing_route
- from routers.account import account_route
+ from routers.account import billing_route
match workspace.subscription.payment_provider:
case PaymentProvider.STRIPE:
@@ -727,7 +727,7 @@ def change_payment_method(workspace: "Workspace"):
"metadata": {"subscription_id": workspace.subscription.external_id},
},
success_url=get_app_route_url(payment_processing_route),
- cancel_url=get_app_route_url(account_route),
+ cancel_url=get_app_route_url(billing_route),
)
raise gui.RedirectException(session.url, status_code=303)
case _:
diff --git a/daras_ai_v2/send_email.py b/daras_ai_v2/send_email.py
index a556da362..dd5c01538 100644
--- a/daras_ai_v2/send_email.py
+++ b/daras_ai_v2/send_email.py
@@ -49,16 +49,16 @@ def send_low_balance_email(
workspace: "Workspace",
total_credits_consumed: int,
):
- from routers.account import account_route
+ from routers.account import billing_route
- print("sending...")
+ logger.info("Sending low balance email...")
recipeints = "support@gooey.ai, devs@gooey.ai"
for user in workspace.get_owners():
html_body = templates.get_template("low_balance_email.html").render(
user=user,
workspace=workspace,
- url=get_app_route_url(account_route),
+ url=get_app_route_url(billing_route),
total_credits_consumed=total_credits_consumed,
settings=settings,
)
diff --git a/payments/models.py b/payments/models.py
index e9812541f..3721eb1ca 100644
--- a/payments/models.py
+++ b/payments/models.py
@@ -340,7 +340,7 @@ def get_external_management_url(self) -> str:
"""
Get URL to Stripe/PayPal for user to manage the subscription.
"""
- from routers.account import account_route
+ from routers.account import billing_route
match self.payment_provider:
case PaymentProvider.PAYPAL:
@@ -354,7 +354,7 @@ def get_external_management_url(self) -> str:
case PaymentProvider.STRIPE:
portal = stripe.billing_portal.Session.create(
customer=self.stripe_get_customer_id(),
- return_url=get_app_route_url(account_route),
+ return_url=get_app_route_url(billing_route),
)
return portal.url
case _:
diff --git a/payments/tasks.py b/payments/tasks.py
index a43299667..29033f5ac 100644
--- a/payments/tasks.py
+++ b/payments/tasks.py
@@ -11,7 +11,7 @@
@app.task
def send_monthly_spending_notification_email(workspace_id: int):
- from routers.account import account_route
+ from routers.account import billing_route
workspace = Workspace.objects.get(id=workspace_id)
threshold = workspace.subscription.monthly_spending_notification_threshold
@@ -29,7 +29,7 @@ def send_monthly_spending_notification_email(workspace_id: int):
).render(
user=user,
workspace=workspace,
- account_url=get_app_route_url(account_route),
+ account_url=get_app_route_url(billing_route),
),
)
@@ -49,7 +49,7 @@ def send_payment_failed_email_with_invoice(
dollar_amt: float,
subject: str,
):
- from routers.account import account_route
+ from routers.account import billing_route
workspace = Workspace.objects.get(id=workspace_id)
for user in workspace.get_owners():
@@ -65,14 +65,14 @@ def send_payment_failed_email_with_invoice(
user=user,
dollar_amt=f"{dollar_amt:.2f}",
invoice_url=invoice_url,
- account_url=get_app_route_url(account_route),
+ account_url=get_app_route_url(billing_route),
),
message_stream="billing",
)
def send_monthly_budget_reached_email(workspace: Workspace):
- from routers.account import account_route
+ from routers.account import billing_route
for user in workspace.get_owners():
if not user.email:
@@ -81,7 +81,7 @@ def send_monthly_budget_reached_email(workspace: Workspace):
email_body = templates.get_template("monthly_budget_reached_email.html").render(
user=user,
workspace=workspace,
- account_url=get_app_route_url(account_route),
+ account_url=get_app_route_url(billing_route),
)
send_email_via_postmark(
from_address=settings.SUPPORT_EMAIL,
diff --git a/routers/account.py b/routers/account.py
index e52e568fe..bdeaf451a 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -10,7 +10,6 @@
from requests.models import HTTPError
from starlette.responses import Response
-from app_users.models import AppUser
from bots.models import PublishedRun, PublishedRunVisibility, Workflow
from daras_ai_v2 import icons, paypal
from daras_ai_v2.billing import billing_page
@@ -71,7 +70,7 @@ def payment_processing_route(
}, waitingTimeMs);
""",
waitingTimeMs=waiting_time_sec * 1000,
- redirectUrl=get_app_route_url(account_route),
+ redirectUrl=get_app_route_url(billing_route),
)
return dict(
@@ -81,6 +80,19 @@ def payment_processing_route(
@gui.route(app, "/account/")
def account_route(request: Request):
+ from daras_ai_v2.base import BasePage
+
+ if (
+ not BasePage.is_user_admin(request.user)
+ or get_current_workspace(request.user, request.session).is_personal
+ ):
+ raise gui.RedirectException(get_route_path(profile_route))
+ else:
+ raise gui.RedirectException(get_route_path(workspaces_route))
+
+
+@gui.route(app, "/account/billing/")
+def billing_route(request: Request):
with account_page_wrapper(request, AccountTabs.billing):
billing_tab(request)
url = get_og_url_path(request)
@@ -201,11 +213,11 @@ class TabData(typing.NamedTuple):
class AccountTabs(TabData, Enum):
- billing = TabData(title=f"{icons.billing} Billing", route=account_route)
profile = TabData(title=f"{icons.profile} Profile", route=profile_route)
+ workspaces = TabData(title=f"{icons.company} Members", route=workspaces_route)
saved = TabData(title=f"{icons.save} Saved", route=saved_route)
api_keys = TabData(title=f"{icons.api} API Keys", route=api_keys_route)
- workspaces = TabData(title=f"{icons.company} Teams", route=workspaces_route)
+ billing = TabData(title=f"{icons.billing} Billing", route=billing_route)
@property
def url_path(self) -> str:
@@ -216,11 +228,11 @@ def get_tabs_for_request(cls, request: Request) -> list["AccountTabs"]:
from daras_ai_v2.base import BasePage
ret = list(cls)
- if (
- not BasePage.is_user_admin(request.user)
- or get_current_workspace(request.user, request.session).is_personal
- ):
+ workspace = get_current_workspace(request.user, request.session)
+ if not BasePage.is_user_admin(request.user) or workspace.is_personal:
ret.remove(cls.workspaces)
+ elif not workspace.is_personal:
+ ret.remove(cls.profile)
return ret
diff --git a/routers/paypal.py b/routers/paypal.py
index 1459d6da6..e61e6d6c6 100644
--- a/routers/paypal.py
+++ b/routers/paypal.py
@@ -17,7 +17,7 @@
from daras_ai_v2.fastapi_tricks import fastapi_request_json, get_app_route_url
from payments.models import PricingPlan
from payments.webhooks import PaypalWebhookHandler, add_balance_for_payment
-from routers.account import payment_processing_route, account_route
+from routers.account import payment_processing_route, billing_route
from routers.custom_api_router import CustomAPIRouter
from workspaces.models import Workspace
from workspaces.widgets import get_current_workspace
@@ -144,7 +144,7 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json):
"brand_name": "Gooey.AI",
"shipping_preference": "NO_SHIPPING",
"return_url": get_app_route_url(payment_processing_route),
- "cancel_url": get_app_route_url(account_route),
+ "cancel_url": get_app_route_url(billing_route),
},
)
From b53d1a6ad82d864c93b450e9ec44bdc5d2f85d23 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 19 Sep 2024 20:41:17 +0530
Subject: [PATCH 05/81] update meta title for Members tab
---
routers/account.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/routers/account.py b/routers/account.py
index bdeaf451a..5908fcba0 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -165,7 +165,7 @@ def workspaces_route(request: Request):
meta=raw_build_meta_tags(
url=url,
canonical_url=url,
- title="Teams • Gooey.AI",
+ title="Members • Gooey.AI",
description="Your teams.",
robots="noindex,nofollow",
)
From c725ce07ac227e2155064b63f2008cc741a19206 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Fri, 20 Sep 2024 01:00:05 +0530
Subject: [PATCH 06/81] only owners and admins should see the billing tab
---
routers/account.py | 5 +++++
workspaces/models.py | 9 +++++++++
2 files changed, 14 insertions(+)
diff --git a/routers/account.py b/routers/account.py
index 5908fcba0..31c0ccedb 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -234,6 +234,11 @@ def get_tabs_for_request(cls, request: Request) -> list["AccountTabs"]:
elif not workspace.is_personal:
ret.remove(cls.profile)
+ owners_and_admins = workspace.get_owners() | workspace.get_admins()
+ if not owners_and_admins.filter(id=request.user.id).exists():
+ # don't show billing tab if user is not an owner or admin
+ ret.remove(cls.billing)
+
return ret
diff --git a/workspaces/models.py b/workspaces/models.py
index 48d08db04..25eb3132d 100644
--- a/workspaces/models.py
+++ b/workspaces/models.py
@@ -179,6 +179,15 @@ def get_owners(self) -> models.QuerySet[AppUser]:
workspace_memberships__deleted__isnull=True,
)
+ def get_admins(self) -> models.QuerySet[AppUser]:
+ from app_users.models import AppUser
+
+ return AppUser.objects.filter(
+ workspace_memberships__workspace=self,
+ workspace_memberships__role=WorkspaceRole.ADMIN,
+ workspace_memberships__deleted__isnull=True,
+ )
+
@db_middleware
@transaction.atomic
def add_balance(
From 26e4ac57b994189d62b11b9a4f68b207a749d269 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Fri, 20 Sep 2024 01:00:27 +0530
Subject: [PATCH 07/81] Revert "change billing tab URL to /account/billing and
tab order in account page"
This reverts commit 9d4d805de2fbbc792990d5685dbbc57f4bb554eb.
---
daras_ai_v2/billing.py | 16 ++++++++--------
daras_ai_v2/send_email.py | 6 +++---
payments/models.py | 4 ++--
payments/tasks.py | 12 ++++++------
routers/account.py | 28 ++++++++--------------------
routers/paypal.py | 4 ++--
6 files changed, 29 insertions(+), 41 deletions(-)
diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py
index bd778de3e..daf650134 100644
--- a/daras_ai_v2/billing.py
+++ b/daras_ai_v2/billing.py
@@ -309,13 +309,13 @@ def fmt_price(plan: PricingPlan) -> str:
def change_subscription(workspace: "Workspace", new_plan: PricingPlan, **kwargs):
- from routers.account import billing_route
+ from routers.account import account_route
from routers.account import payment_processing_route
current_plan = PricingPlan.from_sub(workspace.subscription)
if new_plan == current_plan:
- raise gui.RedirectException(get_app_route_url(billing_route), status_code=303)
+ raise gui.RedirectException(get_app_route_url(account_route), status_code=303)
if new_plan == PricingPlan.STARTER:
workspace.subscription.cancel()
@@ -491,7 +491,7 @@ def render_stripe_addon_button(dollat_amt: int, workspace: "Workspace", save_pm:
def stripe_addon_checkout_redirect(
workspace: "Workspace", dollat_amt: int, save_pm: bool
):
- from routers.account import billing_route
+ from routers.account import account_route
from routers.account import payment_processing_route
line_item = available_subscriptions["addon"]["stripe"].copy()
@@ -505,7 +505,7 @@ def stripe_addon_checkout_redirect(
line_items=[line_item],
mode="payment",
success_url=get_app_route_url(payment_processing_route),
- cancel_url=get_app_route_url(billing_route),
+ cancel_url=get_app_route_url(account_route),
customer=workspace.get_or_create_stripe_customer(),
invoice_creation={"enabled": True},
allow_promotion_codes=True,
@@ -553,7 +553,7 @@ def render_stripe_subscription_button(
def stripe_subscription_create(workspace: "Workspace", plan: PricingPlan):
- from routers.account import billing_route
+ from routers.account import account_route
from routers.account import payment_processing_route
if workspace.subscription and workspace.subscription.is_paid():
@@ -595,7 +595,7 @@ def stripe_subscription_create(workspace: "Workspace", plan: PricingPlan):
checkout_session = stripe.checkout.Session.create(
mode="subscription",
success_url=get_app_route_url(payment_processing_route),
- cancel_url=get_app_route_url(billing_route),
+ cancel_url=get_app_route_url(account_route),
allow_promotion_codes=True,
customer=customer,
line_items=line_items,
@@ -715,7 +715,7 @@ def render_payment_information(workspace: "Workspace"):
def change_payment_method(workspace: "Workspace"):
from routers.account import payment_processing_route
- from routers.account import billing_route
+ from routers.account import account_route
match workspace.subscription.payment_provider:
case PaymentProvider.STRIPE:
@@ -727,7 +727,7 @@ def change_payment_method(workspace: "Workspace"):
"metadata": {"subscription_id": workspace.subscription.external_id},
},
success_url=get_app_route_url(payment_processing_route),
- cancel_url=get_app_route_url(billing_route),
+ cancel_url=get_app_route_url(account_route),
)
raise gui.RedirectException(session.url, status_code=303)
case _:
diff --git a/daras_ai_v2/send_email.py b/daras_ai_v2/send_email.py
index dd5c01538..a556da362 100644
--- a/daras_ai_v2/send_email.py
+++ b/daras_ai_v2/send_email.py
@@ -49,16 +49,16 @@ def send_low_balance_email(
workspace: "Workspace",
total_credits_consumed: int,
):
- from routers.account import billing_route
+ from routers.account import account_route
- logger.info("Sending low balance email...")
+ print("sending...")
recipeints = "support@gooey.ai, devs@gooey.ai"
for user in workspace.get_owners():
html_body = templates.get_template("low_balance_email.html").render(
user=user,
workspace=workspace,
- url=get_app_route_url(billing_route),
+ url=get_app_route_url(account_route),
total_credits_consumed=total_credits_consumed,
settings=settings,
)
diff --git a/payments/models.py b/payments/models.py
index 3721eb1ca..e9812541f 100644
--- a/payments/models.py
+++ b/payments/models.py
@@ -340,7 +340,7 @@ def get_external_management_url(self) -> str:
"""
Get URL to Stripe/PayPal for user to manage the subscription.
"""
- from routers.account import billing_route
+ from routers.account import account_route
match self.payment_provider:
case PaymentProvider.PAYPAL:
@@ -354,7 +354,7 @@ def get_external_management_url(self) -> str:
case PaymentProvider.STRIPE:
portal = stripe.billing_portal.Session.create(
customer=self.stripe_get_customer_id(),
- return_url=get_app_route_url(billing_route),
+ return_url=get_app_route_url(account_route),
)
return portal.url
case _:
diff --git a/payments/tasks.py b/payments/tasks.py
index 29033f5ac..a43299667 100644
--- a/payments/tasks.py
+++ b/payments/tasks.py
@@ -11,7 +11,7 @@
@app.task
def send_monthly_spending_notification_email(workspace_id: int):
- from routers.account import billing_route
+ from routers.account import account_route
workspace = Workspace.objects.get(id=workspace_id)
threshold = workspace.subscription.monthly_spending_notification_threshold
@@ -29,7 +29,7 @@ def send_monthly_spending_notification_email(workspace_id: int):
).render(
user=user,
workspace=workspace,
- account_url=get_app_route_url(billing_route),
+ account_url=get_app_route_url(account_route),
),
)
@@ -49,7 +49,7 @@ def send_payment_failed_email_with_invoice(
dollar_amt: float,
subject: str,
):
- from routers.account import billing_route
+ from routers.account import account_route
workspace = Workspace.objects.get(id=workspace_id)
for user in workspace.get_owners():
@@ -65,14 +65,14 @@ def send_payment_failed_email_with_invoice(
user=user,
dollar_amt=f"{dollar_amt:.2f}",
invoice_url=invoice_url,
- account_url=get_app_route_url(billing_route),
+ account_url=get_app_route_url(account_route),
),
message_stream="billing",
)
def send_monthly_budget_reached_email(workspace: Workspace):
- from routers.account import billing_route
+ from routers.account import account_route
for user in workspace.get_owners():
if not user.email:
@@ -81,7 +81,7 @@ def send_monthly_budget_reached_email(workspace: Workspace):
email_body = templates.get_template("monthly_budget_reached_email.html").render(
user=user,
workspace=workspace,
- account_url=get_app_route_url(billing_route),
+ account_url=get_app_route_url(account_route),
)
send_email_via_postmark(
from_address=settings.SUPPORT_EMAIL,
diff --git a/routers/account.py b/routers/account.py
index 31c0ccedb..ea8bdd70f 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -10,6 +10,7 @@
from requests.models import HTTPError
from starlette.responses import Response
+from app_users.models import AppUser
from bots.models import PublishedRun, PublishedRunVisibility, Workflow
from daras_ai_v2 import icons, paypal
from daras_ai_v2.billing import billing_page
@@ -70,7 +71,7 @@ def payment_processing_route(
}, waitingTimeMs);
""",
waitingTimeMs=waiting_time_sec * 1000,
- redirectUrl=get_app_route_url(billing_route),
+ redirectUrl=get_app_route_url(account_route),
)
return dict(
@@ -80,19 +81,6 @@ def payment_processing_route(
@gui.route(app, "/account/")
def account_route(request: Request):
- from daras_ai_v2.base import BasePage
-
- if (
- not BasePage.is_user_admin(request.user)
- or get_current_workspace(request.user, request.session).is_personal
- ):
- raise gui.RedirectException(get_route_path(profile_route))
- else:
- raise gui.RedirectException(get_route_path(workspaces_route))
-
-
-@gui.route(app, "/account/billing/")
-def billing_route(request: Request):
with account_page_wrapper(request, AccountTabs.billing):
billing_tab(request)
url = get_og_url_path(request)
@@ -213,11 +201,11 @@ class TabData(typing.NamedTuple):
class AccountTabs(TabData, Enum):
+ billing = TabData(title=f"{icons.billing} Billing", route=account_route)
profile = TabData(title=f"{icons.profile} Profile", route=profile_route)
- workspaces = TabData(title=f"{icons.company} Members", route=workspaces_route)
saved = TabData(title=f"{icons.save} Saved", route=saved_route)
api_keys = TabData(title=f"{icons.api} API Keys", route=api_keys_route)
- billing = TabData(title=f"{icons.billing} Billing", route=billing_route)
+ workspaces = TabData(title=f"{icons.company} Teams", route=workspaces_route)
@property
def url_path(self) -> str:
@@ -228,11 +216,11 @@ def get_tabs_for_request(cls, request: Request) -> list["AccountTabs"]:
from daras_ai_v2.base import BasePage
ret = list(cls)
- workspace = get_current_workspace(request.user, request.session)
- if not BasePage.is_user_admin(request.user) or workspace.is_personal:
+ if (
+ not BasePage.is_user_admin(request.user)
+ or get_current_workspace(request.user, request.session).is_personal
+ ):
ret.remove(cls.workspaces)
- elif not workspace.is_personal:
- ret.remove(cls.profile)
owners_and_admins = workspace.get_owners() | workspace.get_admins()
if not owners_and_admins.filter(id=request.user.id).exists():
diff --git a/routers/paypal.py b/routers/paypal.py
index e61e6d6c6..1459d6da6 100644
--- a/routers/paypal.py
+++ b/routers/paypal.py
@@ -17,7 +17,7 @@
from daras_ai_v2.fastapi_tricks import fastapi_request_json, get_app_route_url
from payments.models import PricingPlan
from payments.webhooks import PaypalWebhookHandler, add_balance_for_payment
-from routers.account import payment_processing_route, billing_route
+from routers.account import payment_processing_route, account_route
from routers.custom_api_router import CustomAPIRouter
from workspaces.models import Workspace
from workspaces.widgets import get_current_workspace
@@ -144,7 +144,7 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json):
"brand_name": "Gooey.AI",
"shipping_preference": "NO_SHIPPING",
"return_url": get_app_route_url(payment_processing_route),
- "cancel_url": get_app_route_url(billing_route),
+ "cancel_url": get_app_route_url(account_route),
},
)
From 4203b0383b0c47c5ebeb1e07c8cd624a465c7ec4 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Fri, 20 Sep 2024 15:05:05 +0530
Subject: [PATCH 08/81] enforce access control on workspace billing
---
routers/account.py | 23 ++++++++++++-----------
workspaces/models.py | 3 ++-
workspaces/views.py | 2 +-
3 files changed, 15 insertions(+), 13 deletions(-)
diff --git a/routers/account.py b/routers/account.py
index ea8bdd70f..4b490864e 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -10,7 +10,6 @@
from requests.models import HTTPError
from starlette.responses import Response
-from app_users.models import AppUser
from bots.models import PublishedRun, PublishedRunVisibility, Workflow
from daras_ai_v2 import icons, paypal
from daras_ai_v2.billing import billing_page
@@ -201,11 +200,11 @@ class TabData(typing.NamedTuple):
class AccountTabs(TabData, Enum):
- billing = TabData(title=f"{icons.billing} Billing", route=account_route)
profile = TabData(title=f"{icons.profile} Profile", route=profile_route)
+ workspaces = TabData(title=f"{icons.company} Teams", route=workspaces_route)
saved = TabData(title=f"{icons.save} Saved", route=saved_route)
api_keys = TabData(title=f"{icons.api} API Keys", route=api_keys_route)
- workspaces = TabData(title=f"{icons.company} Teams", route=workspaces_route)
+ billing = TabData(title=f"{icons.billing} Billing", route=account_route)
@property
def url_path(self) -> str:
@@ -216,22 +215,24 @@ def get_tabs_for_request(cls, request: Request) -> list["AccountTabs"]:
from daras_ai_v2.base import BasePage
ret = list(cls)
- if (
- not BasePage.is_user_admin(request.user)
- or get_current_workspace(request.user, request.session).is_personal
- ):
+
+ workspace = get_current_workspace(request.user, request.session)
+ if not BasePage.is_user_admin(request.user) or workspace.is_personal:
ret.remove(cls.workspaces)
- owners_and_admins = workspace.get_owners() | workspace.get_admins()
- if not owners_and_admins.filter(id=request.user.id).exists():
- # don't show billing tab if user is not an owner or admin
- ret.remove(cls.billing)
+ if not workspace.is_personal:
+ ret.remove(cls.profile)
+
+ if not workspace.memberships.get(user=request.user).can_edit_workspace():
+ ret.remove(cls.billing)
return ret
def billing_tab(request: Request):
workspace = get_current_workspace(request.user, request.session)
+ if not workspace.is_personal and not workspace.memberships.get(user=request.user):
+ raise gui.RedirectException(get_route_path(account_route))
return billing_page(workspace)
diff --git a/workspaces/models.py b/workspaces/models.py
index 25eb3132d..14e459b00 100644
--- a/workspaces/models.py
+++ b/workspaces/models.py
@@ -373,7 +373,8 @@ def clean(self) -> None:
raise ValidationError("You cannot add users to a personal workspace")
return super().clean()
- def can_edit_workspace_metadata(self):
+ def can_edit_workspace(self):
+ # workspace metadata, billing, etc.
return self.role in (WorkspaceRole.OWNER, WorkspaceRole.ADMIN)
def can_leave_workspace(self):
diff --git a/workspaces/views.py b/workspaces/views.py
index 19923985a..4f8692301 100644
--- a/workspaces/views.py
+++ b/workspaces/views.py
@@ -193,7 +193,7 @@ def member_invite_button_with_dialog(membership: WorkspaceMembership):
def edit_workspace_button_with_dialog(membership: WorkspaceMembership):
- if not membership.can_edit_workspace_metadata():
+ if not membership.can_edit_workspace():
return
ref = gui.use_confirm_dialog(key="edit-workspace", close_on_confirm=False)
From fa38031aa2a4884e462b4fe9efd08f0011be6ef9 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Fri, 20 Sep 2024 15:07:11 +0530
Subject: [PATCH 09/81] Revert "Revert "change billing tab URL to
/account/billing and tab order in account page""
This reverts commit 5848ee9c7dc2e7a282f7cc0cd5503b178843692f.
---
daras_ai_v2/billing.py | 16 ++++++++--------
daras_ai_v2/send_email.py | 6 +++---
payments/models.py | 4 ++--
payments/tasks.py | 12 ++++++------
routers/account.py | 27 ++++++++++++++++++++++-----
routers/paypal.py | 4 ++--
6 files changed, 43 insertions(+), 26 deletions(-)
diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py
index daf650134..bd778de3e 100644
--- a/daras_ai_v2/billing.py
+++ b/daras_ai_v2/billing.py
@@ -309,13 +309,13 @@ def fmt_price(plan: PricingPlan) -> str:
def change_subscription(workspace: "Workspace", new_plan: PricingPlan, **kwargs):
- from routers.account import account_route
+ from routers.account import billing_route
from routers.account import payment_processing_route
current_plan = PricingPlan.from_sub(workspace.subscription)
if new_plan == current_plan:
- raise gui.RedirectException(get_app_route_url(account_route), status_code=303)
+ raise gui.RedirectException(get_app_route_url(billing_route), status_code=303)
if new_plan == PricingPlan.STARTER:
workspace.subscription.cancel()
@@ -491,7 +491,7 @@ def render_stripe_addon_button(dollat_amt: int, workspace: "Workspace", save_pm:
def stripe_addon_checkout_redirect(
workspace: "Workspace", dollat_amt: int, save_pm: bool
):
- from routers.account import account_route
+ from routers.account import billing_route
from routers.account import payment_processing_route
line_item = available_subscriptions["addon"]["stripe"].copy()
@@ -505,7 +505,7 @@ def stripe_addon_checkout_redirect(
line_items=[line_item],
mode="payment",
success_url=get_app_route_url(payment_processing_route),
- cancel_url=get_app_route_url(account_route),
+ cancel_url=get_app_route_url(billing_route),
customer=workspace.get_or_create_stripe_customer(),
invoice_creation={"enabled": True},
allow_promotion_codes=True,
@@ -553,7 +553,7 @@ def render_stripe_subscription_button(
def stripe_subscription_create(workspace: "Workspace", plan: PricingPlan):
- from routers.account import account_route
+ from routers.account import billing_route
from routers.account import payment_processing_route
if workspace.subscription and workspace.subscription.is_paid():
@@ -595,7 +595,7 @@ def stripe_subscription_create(workspace: "Workspace", plan: PricingPlan):
checkout_session = stripe.checkout.Session.create(
mode="subscription",
success_url=get_app_route_url(payment_processing_route),
- cancel_url=get_app_route_url(account_route),
+ cancel_url=get_app_route_url(billing_route),
allow_promotion_codes=True,
customer=customer,
line_items=line_items,
@@ -715,7 +715,7 @@ def render_payment_information(workspace: "Workspace"):
def change_payment_method(workspace: "Workspace"):
from routers.account import payment_processing_route
- from routers.account import account_route
+ from routers.account import billing_route
match workspace.subscription.payment_provider:
case PaymentProvider.STRIPE:
@@ -727,7 +727,7 @@ def change_payment_method(workspace: "Workspace"):
"metadata": {"subscription_id": workspace.subscription.external_id},
},
success_url=get_app_route_url(payment_processing_route),
- cancel_url=get_app_route_url(account_route),
+ cancel_url=get_app_route_url(billing_route),
)
raise gui.RedirectException(session.url, status_code=303)
case _:
diff --git a/daras_ai_v2/send_email.py b/daras_ai_v2/send_email.py
index a556da362..dd5c01538 100644
--- a/daras_ai_v2/send_email.py
+++ b/daras_ai_v2/send_email.py
@@ -49,16 +49,16 @@ def send_low_balance_email(
workspace: "Workspace",
total_credits_consumed: int,
):
- from routers.account import account_route
+ from routers.account import billing_route
- print("sending...")
+ logger.info("Sending low balance email...")
recipeints = "support@gooey.ai, devs@gooey.ai"
for user in workspace.get_owners():
html_body = templates.get_template("low_balance_email.html").render(
user=user,
workspace=workspace,
- url=get_app_route_url(account_route),
+ url=get_app_route_url(billing_route),
total_credits_consumed=total_credits_consumed,
settings=settings,
)
diff --git a/payments/models.py b/payments/models.py
index e9812541f..3721eb1ca 100644
--- a/payments/models.py
+++ b/payments/models.py
@@ -340,7 +340,7 @@ def get_external_management_url(self) -> str:
"""
Get URL to Stripe/PayPal for user to manage the subscription.
"""
- from routers.account import account_route
+ from routers.account import billing_route
match self.payment_provider:
case PaymentProvider.PAYPAL:
@@ -354,7 +354,7 @@ def get_external_management_url(self) -> str:
case PaymentProvider.STRIPE:
portal = stripe.billing_portal.Session.create(
customer=self.stripe_get_customer_id(),
- return_url=get_app_route_url(account_route),
+ return_url=get_app_route_url(billing_route),
)
return portal.url
case _:
diff --git a/payments/tasks.py b/payments/tasks.py
index a43299667..29033f5ac 100644
--- a/payments/tasks.py
+++ b/payments/tasks.py
@@ -11,7 +11,7 @@
@app.task
def send_monthly_spending_notification_email(workspace_id: int):
- from routers.account import account_route
+ from routers.account import billing_route
workspace = Workspace.objects.get(id=workspace_id)
threshold = workspace.subscription.monthly_spending_notification_threshold
@@ -29,7 +29,7 @@ def send_monthly_spending_notification_email(workspace_id: int):
).render(
user=user,
workspace=workspace,
- account_url=get_app_route_url(account_route),
+ account_url=get_app_route_url(billing_route),
),
)
@@ -49,7 +49,7 @@ def send_payment_failed_email_with_invoice(
dollar_amt: float,
subject: str,
):
- from routers.account import account_route
+ from routers.account import billing_route
workspace = Workspace.objects.get(id=workspace_id)
for user in workspace.get_owners():
@@ -65,14 +65,14 @@ def send_payment_failed_email_with_invoice(
user=user,
dollar_amt=f"{dollar_amt:.2f}",
invoice_url=invoice_url,
- account_url=get_app_route_url(account_route),
+ account_url=get_app_route_url(billing_route),
),
message_stream="billing",
)
def send_monthly_budget_reached_email(workspace: Workspace):
- from routers.account import account_route
+ from routers.account import billing_route
for user in workspace.get_owners():
if not user.email:
@@ -81,7 +81,7 @@ def send_monthly_budget_reached_email(workspace: Workspace):
email_body = templates.get_template("monthly_budget_reached_email.html").render(
user=user,
workspace=workspace,
- account_url=get_app_route_url(account_route),
+ account_url=get_app_route_url(billing_route),
)
send_email_via_postmark(
from_address=settings.SUPPORT_EMAIL,
diff --git a/routers/account.py b/routers/account.py
index 4b490864e..d29120769 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -70,7 +70,7 @@ def payment_processing_route(
}, waitingTimeMs);
""",
waitingTimeMs=waiting_time_sec * 1000,
- redirectUrl=get_app_route_url(account_route),
+ redirectUrl=get_app_route_url(billing_route),
)
return dict(
@@ -80,6 +80,19 @@ def payment_processing_route(
@gui.route(app, "/account/")
def account_route(request: Request):
+ from daras_ai_v2.base import BasePage
+
+ if (
+ not BasePage.is_user_admin(request.user)
+ or get_current_workspace(request.user, request.session).is_personal
+ ):
+ raise gui.RedirectException(get_route_path(profile_route))
+ else:
+ raise gui.RedirectException(get_route_path(workspaces_route))
+
+
+@gui.route(app, "/account/billing/")
+def billing_route(request: Request):
with account_page_wrapper(request, AccountTabs.billing):
billing_tab(request)
url = get_og_url_path(request)
@@ -201,10 +214,10 @@ class TabData(typing.NamedTuple):
class AccountTabs(TabData, Enum):
profile = TabData(title=f"{icons.profile} Profile", route=profile_route)
- workspaces = TabData(title=f"{icons.company} Teams", route=workspaces_route)
+ workspaces = TabData(title=f"{icons.company} Members", route=workspaces_route)
saved = TabData(title=f"{icons.save} Saved", route=saved_route)
api_keys = TabData(title=f"{icons.api} API Keys", route=api_keys_route)
- billing = TabData(title=f"{icons.billing} Billing", route=account_route)
+ billing = TabData(title=f"{icons.billing} Billing", route=billing_route)
@property
def url_path(self) -> str:
@@ -215,10 +228,11 @@ def get_tabs_for_request(cls, request: Request) -> list["AccountTabs"]:
from daras_ai_v2.base import BasePage
ret = list(cls)
-
workspace = get_current_workspace(request.user, request.session)
if not BasePage.is_user_admin(request.user) or workspace.is_personal:
ret.remove(cls.workspaces)
+ elif not workspace.is_personal:
+ ret.remove(cls.profile)
if not workspace.is_personal:
ret.remove(cls.profile)
@@ -231,7 +245,10 @@ def get_tabs_for_request(cls, request: Request) -> list["AccountTabs"]:
def billing_tab(request: Request):
workspace = get_current_workspace(request.user, request.session)
- if not workspace.is_personal and not workspace.memberships.get(user=request.user):
+ if (
+ not workspace.is_personal
+ and not workspace.memberships.get(user=request.user).can_edit_workspace()
+ ):
raise gui.RedirectException(get_route_path(account_route))
return billing_page(workspace)
diff --git a/routers/paypal.py b/routers/paypal.py
index 1459d6da6..e61e6d6c6 100644
--- a/routers/paypal.py
+++ b/routers/paypal.py
@@ -17,7 +17,7 @@
from daras_ai_v2.fastapi_tricks import fastapi_request_json, get_app_route_url
from payments.models import PricingPlan
from payments.webhooks import PaypalWebhookHandler, add_balance_for_payment
-from routers.account import payment_processing_route, account_route
+from routers.account import payment_processing_route, billing_route
from routers.custom_api_router import CustomAPIRouter
from workspaces.models import Workspace
from workspaces.widgets import get_current_workspace
@@ -144,7 +144,7 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json):
"brand_name": "Gooey.AI",
"shipping_preference": "NO_SHIPPING",
"return_url": get_app_route_url(payment_processing_route),
- "cancel_url": get_app_route_url(account_route),
+ "cancel_url": get_app_route_url(billing_route),
},
)
From b320cfebbd25cf4901b83e4cc4358d7b86af4430 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Fri, 20 Sep 2024 15:08:36 +0530
Subject: [PATCH 10/81] fix accidental bug from conflict resolution
---
routers/account.py | 3 ---
1 file changed, 3 deletions(-)
diff --git a/routers/account.py b/routers/account.py
index d29120769..903f080c9 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -231,12 +231,9 @@ def get_tabs_for_request(cls, request: Request) -> list["AccountTabs"]:
workspace = get_current_workspace(request.user, request.session)
if not BasePage.is_user_admin(request.user) or workspace.is_personal:
ret.remove(cls.workspaces)
- elif not workspace.is_personal:
- ret.remove(cls.profile)
if not workspace.is_personal:
ret.remove(cls.profile)
-
if not workspace.memberships.get(user=request.user).can_edit_workspace():
ret.remove(cls.billing)
From ac75f6e16d20698c6a06bafba7d14b4c078df82c Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sat, 21 Sep 2024 00:53:29 +0530
Subject: [PATCH 11/81] Add workspace to published run
---
.../migrations/0083_publishedrun_workspace.py | 20 +++++++++++++++++++
bots/models.py | 6 ++++++
scripts/migrate_workspaces.py | 3 +--
3 files changed, 27 insertions(+), 2 deletions(-)
create mode 100644 bots/migrations/0083_publishedrun_workspace.py
diff --git a/bots/migrations/0083_publishedrun_workspace.py b/bots/migrations/0083_publishedrun_workspace.py
new file mode 100644
index 000000000..9c6c0fba4
--- /dev/null
+++ b/bots/migrations/0083_publishedrun_workspace.py
@@ -0,0 +1,20 @@
+# Generated by Django 5.1.1 on 2024-09-20 14:22
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('bots', '0082_savedrun_workspace'),
+ ('workspaces', '0002_alter_workspace_domain_name'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='publishedrun',
+ name='workspace',
+ field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, to='workspaces.workspace'),
+ ),
+ ]
diff --git a/bots/models.py b/bots/models.py
index 711672ef4..60a53932f 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -1714,6 +1714,12 @@ class PublishedRun(models.Model):
null=True,
)
+ workspace = models.ForeignKey(
+ "workspaces.Workspace",
+ on_delete=models.SET_NULL,
+ null=True,
+ )
+
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
diff --git a/scripts/migrate_workspaces.py b/scripts/migrate_workspaces.py
index b245bc4ef..ba2cee100 100644
--- a/scripts/migrate_workspaces.py
+++ b/scripts/migrate_workspaces.py
@@ -16,8 +16,7 @@ def run():
migrate_personal_workspaces()
migrate_txns()
migrate_saved_runs()
- ## LATER
- # migrate_published_runs()
+ migrate_published_runs()
@transaction.atomic
From 82e17870acaea1222b1cd1eb255d8f610976de40 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sat, 21 Sep 2024 01:10:40 +0530
Subject: [PATCH 12/81] fix can_edit_published_run access control for
workspaces
---
daras_ai_v2/base.py | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 1fa8eecd0..2c4966e60 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -441,11 +441,9 @@ def can_user_save_run(
)
def can_user_edit_published_run(self, published_run: PublishedRun) -> bool:
- return self.is_current_user_admin() or bool(
- self.request
- and self.request.user
- and published_run.created_by_id
- and published_run.created_by_id == self.request.user.id
+ return (
+ self.is_current_user_admin()
+ or published_run.workspace == self.current_workspace
)
def _render_title(self, title: str):
From 80c80a8f1a0678009b2fffeea006ea4f6bae8e2a Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sat, 21 Sep 2024 01:11:30 +0530
Subject: [PATCH 13/81] =?UTF-8?q?feat:=20show=20history=20belonging=20to?=
=?UTF-8?q?=20uid=20=E2=9C=96=EF=B8=8F=20current=20workspace?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
daras_ai_v2/base.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 2c4966e60..388aa9a2e 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -1841,9 +1841,11 @@ def _render(pr: PublishedRun):
def _history_tab(self):
self.ensure_authentication(anon_ok=True)
+ workspace_id = self.current_workspace.id
uid = self.request.user.uid
if self.is_current_user_admin():
uid = self.request.query_params.get("uid", uid)
+ workspace_id = self.request.query_params.get("workspace_id", workspace_id)
before = self.request.query_params.get("updated_at__lt", None)
if before:
@@ -1853,8 +1855,9 @@ def _history_tab(self):
run_history = list(
SavedRun.objects.filter(
workflow=self.workflow,
- uid=uid,
updated_at__lt=before,
+ uid=uid,
+ workspace_id=workspace_id,
)[:25]
)
if not run_history:
From 86722393d736bbde1fe5d001d8eef61749ea836b Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sat, 21 Sep 2024 01:13:06 +0530
Subject: [PATCH 14/81] feat: show (and create) API keys belonging to workspace
---
daras_ai_v2/base.py | 2 +-
daras_ai_v2/manage_api_keys_widget.py | 30 ++++++++++++++++++++-------
routers/account.py | 3 ++-
routers/root.py | 2 +-
4 files changed, 26 insertions(+), 11 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 388aa9a2e..267833ba6 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -2078,7 +2078,7 @@ def run_as_api_tab(self):
with gui.tag("a", id="api-keys"):
gui.write("### 🔐 API keys")
- manage_api_keys(self.request.user)
+ manage_api_keys(self.current_workspace, self.request.user)
def ensure_credits_and_auto_recharge(self, sr: SavedRun, state: dict):
if not settings.CREDITS_TO_DEDUCT_PER_RUN:
diff --git a/daras_ai_v2/manage_api_keys_widget.py b/daras_ai_v2/manage_api_keys_widget.py
index 1db291be7..4b94a78e1 100644
--- a/daras_ai_v2/manage_api_keys_widget.py
+++ b/daras_ai_v2/manage_api_keys_widget.py
@@ -1,4 +1,7 @@
import datetime
+import typing
+
+from google.cloud import firestore
import gooey_gui as gui
from app_users.models import AppUser
@@ -12,11 +15,14 @@
get_random_api_key,
)
+if typing.TYPE_CHECKING:
+ from workspaces.models import Workspace
+
-def manage_api_keys(user: AppUser):
+def manage_api_keys(workspace: "Workspace", user: AppUser):
gui.write(
- """
-Your secret API keys are listed below.
+ f"""
+{workspace.display_name()} API keys are listed below.
Please note that we do not display your secret API keys again after you generate them.
Do not share your API key with others, or expose it in the browser or other client-side code.
@@ -27,13 +33,14 @@ def manage_api_keys(user: AppUser):
)
db_collection = db.get_client().collection(db.API_KEYS_COLLECTION)
- api_keys = _load_api_keys(db_collection, user)
+ api_keys = _load_api_keys(db_collection, workspace)
table_area = gui.div()
if gui.button("+ Create new secret key"):
doc = _generate_new_key_doc()
doc["uid"] = user.uid
+ doc["workspace_id"] = workspace.id
api_keys.append(doc)
db_collection.add(doc)
@@ -54,12 +61,19 @@ def manage_api_keys(user: AppUser):
)
-def _load_api_keys(db_collection, user):
+def _load_api_keys(
+ db_collection: firestore.CollectionReference, workspace: "Workspace"
+):
+ filter = firestore.FieldFilter("workspace_id", "==", workspace.id)
+ if workspace.is_personal:
+ # for backwards compatibility with existing keys
+ filter = firestore.Or(
+ [filter, firestore.FieldFilter("uid", "==", workspace.created_by.uid)]
+ )
+
return [
snap.to_dict()
- for snap in db_collection.where("uid", "==", user.uid)
- .order_by("created_at")
- .get()
+ for snap in db_collection.where(filter=filter).order_by("created_at").get()
]
diff --git a/routers/account.py b/routers/account.py
index 903f080c9..9781cf6fb 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -296,7 +296,8 @@ def _render_run(pr: PublishedRun):
def api_keys_tab(request: Request):
gui.write("# 🔐 API Keys")
- manage_api_keys(request.user)
+ workspace = get_current_workspace(request.user, request.session)
+ manage_api_keys(workspace, request.user)
@contextmanager
diff --git a/routers/root.py b/routers/root.py
index 9cb94daad..3ef490654 100644
--- a/routers/root.py
+++ b/routers/root.py
@@ -347,7 +347,7 @@ def _api_docs_page(request: Request):
)
return
- manage_api_keys(page.request.user)
+ manage_api_keys(page.current_workspace, page.request.user)
@gui.route(
From 4bf557023eccebd985e905f3bed237f468027e59 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sat, 21 Sep 2024 01:13:34 +0530
Subject: [PATCH 15/81] feat: fix view for saved tab to filter by workspace
---
routers/account.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/routers/account.py b/routers/account.py
index 9781cf6fb..a712f8a1d 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -255,9 +255,8 @@ def profile_tab(request: Request):
def all_saved_runs_tab(request: Request):
- prs = PublishedRun.objects.filter(
- created_by=request.user,
- ).order_by("-updated_at")
+ workspace = get_current_workspace(request.user, request.session)
+ prs = PublishedRun.objects.filter(workspace=workspace).order_by("-updated_at")
def _render_run(pr: PublishedRun):
workflow = Workflow(pr.workflow)
From 02740e869816340a4a7fcba9f6a51ef5a78de4a0 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sat, 21 Sep 2024 01:15:02 +0530
Subject: [PATCH 16/81] feat: add current balance & plan to members tab
+ other UI changes in that tab
---
workspaces/views.py | 81 ++++++++++++++++++++++++++++-----------------
1 file changed, 51 insertions(+), 30 deletions(-)
diff --git a/workspaces/views.py b/workspaces/views.py
index 4f8692301..1b7a8d739 100644
--- a/workspaces/views.py
+++ b/workspaces/views.py
@@ -10,6 +10,7 @@
from daras_ai_v2.copy_to_clipboard_button_widget import copy_to_clipboard_button
from daras_ai_v2.fastapi_tricks import get_route_path
from daras_ai_v2.user_date_widgets import render_local_date_attrs
+from payments.plans import PricingPlan
from .models import Workspace, WorkspaceInvite, WorkspaceMembership, WorkspaceRole
from .widgets import get_current_workspace, set_current_workspace
@@ -94,32 +95,53 @@ def render_workspace_by_membership(membership: WorkspaceMembership):
- current user
- current user's role in the workspace (and other metadata)
"""
+ from routers.account import billing_route
+
workspace = membership.workspace
- with gui.div(
- className="d-xs-block d-sm-flex flex-row-reverse justify-content-between"
- ):
- with gui.div(className="d-flex justify-content-center align-items-center"):
- edit_workspace_button_with_dialog(membership)
+ with gui.div(className="d-block d-sm-flex justify-content-between"):
+ col1 = gui.div(
+ className="d-block d-md-flex text-center text-sm-start align-items-center"
+ )
+ col2 = gui.div()
- with gui.div(className="d-flex align-items-center"):
- gui.image(
- workspace.logo or DEFAULT_WORKSPACE_LOGO,
- className="my-0 me-4 rounded",
- style={"width": "128px", "height": "128px", "object-fit": "contain"},
+ with col1:
+ gui.image(
+ workspace.logo or DEFAULT_WORKSPACE_LOGO,
+ className="my-0 me-4 rounded",
+ style={"width": "128px", "height": "128px", "object-fit": "contain"},
+ )
+ with gui.div(className="d-flex flex-column justify-content-end"):
+ plan = (
+ workspace.subscription
+ and PricingPlan.from_sub(workspace.subscription)
+ or PricingPlan.STARTER
)
- with gui.div(className="d-flex flex-column justify-content-center"):
- gui.write(f"# {workspace.display_name(membership.user)}")
- if workspace.domain_name:
- gui.write(
- f"Workspace Domain: `@{workspace.domain_name}`",
- className="text-muted",
- )
+
+ with gui.tag("h1", className="mb-0" if workspace.domain_name else ""):
+ gui.html(html_lib.escape(workspace.display_name(membership.user)))
+ if workspace.domain_name:
+ gui.caption(f"Domain: `@{workspace.domain_name}`")
+
+ billing_info = f"""
+ Credits: {workspace.balance}
+ Plan: {plan.title}
+ """.strip()
+ if membership.can_edit_workspace() and plan in (
+ PricingPlan.STARTER,
+ PricingPlan.CREATOR,
+ ):
+ billing_info += f" [Upgrade]({get_route_path(billing_route)})"
+ gui.write(billing_info, unsafe_allow_html=True)
+
+ if membership.can_edit_workspace():
+ with col2, gui.div(className="mt-2"):
+ edit_workspace_button_with_dialog(membership)
gui.newline()
with gui.div(className="d-flex justify-content-between align-items-center"):
- gui.write("## Members")
+ gui.write("#### Members")
member_invite_button_with_dialog(membership)
render_members_list(workspace=workspace, current_member=membership)
@@ -134,12 +156,14 @@ def render_workspace_by_membership(membership: WorkspaceMembership):
gui.newline()
dialog_ref = gui.use_confirm_dialog(key="leave-workspace", close_on_confirm=False)
- with gui.div(className="text-end"):
- if gui.button(
- label=f"{icons.sign_out} Leave",
- className="py-2 bg-danger border-danger text-light",
- ):
- dialog_ref.set_open(True)
+ with gui.div():
+ gui.write("#### Danger Zone")
+ with gui.div():
+ if gui.button(
+ label=f"{icons.sign_out} Leave",
+ className="py-2 text-danger",
+ ):
+ dialog_ref.set_open(True)
if dialog_ref.is_open:
new_owner_id = render_workspace_leave_dialog(dialog_ref, membership)
@@ -193,9 +217,6 @@ def member_invite_button_with_dialog(membership: WorkspaceMembership):
def edit_workspace_button_with_dialog(membership: WorkspaceMembership):
- if not membership.can_edit_workspace():
- return
-
ref = gui.use_confirm_dialog(key="edit-workspace", close_on_confirm=False)
if gui.button(label=f"{icons.edit} Edit"):
@@ -407,7 +428,7 @@ def render_membership_actions(
gui.button_with_confirm_dialog(
ref=ref,
trigger_label=f"{icons.remove_user} Remove",
- trigger_className="btn-sm my-0 py-0 bg-danger border-danger text-light",
+ trigger_className="btn-sm my-0 py-0 text-danger",
modal_title="#### Remove a Member",
modal_content=f"Are you sure you want to remove **{m.user.full_name()}** from **{m.workspace.display_name(m.user)}**?",
confirm_label="Remove",
@@ -425,7 +446,7 @@ def render_pending_invites_list(
if not pending_invites:
return
- gui.write("## Pending")
+ gui.write("#### Pending")
with gui.tag("table", className="table table-responsive"):
with gui.tag("thead"), gui.tag("tr"):
with gui.tag("th", scope="col"):
@@ -488,7 +509,7 @@ def render_invitation_actions(
gui.button_with_confirm_dialog(
ref=ref,
trigger_label=f"{icons.cancel} Cancel",
- trigger_className="btn-sm my-0 py-0 bg-danger border-danger text-light",
+ trigger_className="btn-sm my-0 py-0 text-danger",
modal_title="#### Cancel Invitation",
modal_content=f"Are you sure you want to cancel the invitation to **{invite.email}**?",
cancel_label="No, keep it",
From 18c0b8431c0bb77799955eb7d7e3237c7418f8af Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sat, 21 Sep 2024 04:30:13 +0530
Subject: [PATCH 17/81] feat: fix URLs for team & personal workspaces in
account settings
---
daras_ai_v2/base.py | 7 ++-
daras_ai_v2/profiles.py | 2 +-
routers/account.py | 117 ++++++++++++++++++++++++++++++++--------
routers/root.py | 9 +++-
workspaces/models.py | 2 +
workspaces/widgets.py | 67 +++++++++++++++++++++--
6 files changed, 173 insertions(+), 31 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 267833ba6..4bdec65cf 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -550,14 +550,17 @@ def _render_publish_form(
options = {
str(enum.value): enum.help_text() for enum in PublishedRunVisibility
}
- if self.request.user and self.request.user.handle:
+ if not self.current_workspace.is_personal:
+ # TODO: implement with workspace handles
+ pass
+ elif self.request.user and self.request.user.handle:
profile_url = self.request.user.handle.get_app_url()
pretty_profile_url = urls.remove_scheme(profile_url).rstrip("/")
options[
str(PublishedRunVisibility.PUBLIC.value)
] += f' on [{pretty_profile_url}]({profile_url})'
elif self.request.user and not self.request.user.is_anonymous:
- edit_profile_url = AccountTabs.profile.url_path
+ edit_profile_url = AccountTabs.profile.get_url_path(self.request)
options[
str(PublishedRunVisibility.PUBLIC.value)
] += f' on my [profile page]({edit_profile_url})'
diff --git a/daras_ai_v2/profiles.py b/daras_ai_v2/profiles.py
index b622fbf25..b52c7ebbb 100644
--- a/daras_ai_v2/profiles.py
+++ b/daras_ai_v2/profiles.py
@@ -94,7 +94,7 @@ def user_profile_header(request, user: AppUser):
from routers.account import AccountTabs
with gui.link(
- to=AccountTabs.profile.url_path,
+ to=AccountTabs.profile.get_url_path(request),
className="text-decoration-none btn btn-theme btn-secondary mb-0",
):
gui.html(f"{icons.edit} Edit Profile")
diff --git a/routers/account.py b/routers/account.py
index a712f8a1d..269459ada 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -21,9 +21,13 @@
from payments.webhooks import PaypalWebhookHandler
from routers.custom_api_router import CustomAPIRouter
from routers.root import page_wrapper, get_og_url_path
-from workspaces.models import WorkspaceInvite
+from workspaces.models import Workspace, WorkspaceInvite, WorkspaceMembership
from workspaces.views import invitation_page, workspaces_page
-from workspaces.widgets import get_current_workspace
+from workspaces.widgets import (
+ get_current_workspace,
+ get_workspaces_route_path,
+ set_current_workspace,
+)
app = CustomAPIRouter()
@@ -82,17 +86,19 @@ def payment_processing_route(
def account_route(request: Request):
from daras_ai_v2.base import BasePage
- if (
- not BasePage.is_user_admin(request.user)
- or get_current_workspace(request.user, request.session).is_personal
- ):
+ workspace = get_current_workspace(request.user, request.session)
+ if not BasePage.is_user_admin(request.user) or workspace.is_personal:
raise gui.RedirectException(get_route_path(profile_route))
else:
- raise gui.RedirectException(get_route_path(workspaces_route))
+ raise gui.RedirectException(
+ get_workspaces_route_path(workspaces_route, workspace)
+ )
@gui.route(app, "/account/billing/")
-def billing_route(request: Request):
+@gui.route(app, "/workspaces/{workspace_slug}-{workspace_hashid}/billing/")
+def billing_route(request: Request, workspace_slug: str, workspace_hashid: str | None):
+ validate_and_set_current_workspace(request, workspace_hashid)
with account_page_wrapper(request, AccountTabs.billing):
billing_tab(request)
url = get_og_url_path(request)
@@ -124,7 +130,16 @@ def profile_route(request: Request):
@gui.route(app, "/saved/")
-def saved_route(request: Request):
+def saved_shortcut_route():
+ raise RedirectException(get_route_path(saved_route))
+
+
+@gui.route(app, "/account/saved/")
+@gui.route(app, "/workspaces/{workspace_slug}-{workspace_hashid}/saved/")
+def saved_route(
+ request: Request, workspace_slug: str, workspace_hashid: str | None = None
+):
+ validate_and_set_current_workspace(request, workspace_hashid)
with account_page_wrapper(request, AccountTabs.saved):
all_saved_runs_tab(request)
url = get_og_url_path(request)
@@ -140,7 +155,9 @@ def saved_route(request: Request):
@gui.route(app, "/account/api-keys/")
-def api_keys_route(request: Request):
+@gui.route(app, "/workspaces/{workspace_slug}-{workspace_hashid}/api-keys/")
+def api_keys_route(request: Request, workspace_slug: str, workspace_hashid: str | None):
+ validate_and_set_current_workspace(request, workspace_hashid)
with account_page_wrapper(request, AccountTabs.api_keys):
api_keys_tab(request)
url = get_og_url_path(request)
@@ -155,9 +172,31 @@ def api_keys_route(request: Request):
)
-@gui.route(app, "/workspaces/")
-def workspaces_route(request: Request):
- with account_page_wrapper(request, AccountTabs.workspaces):
+@gui.route(app, "/workspaces/{workspace_slug}-{workspace_hashid}/")
+def workspaces_route(
+ request: Request,
+ workspace_hashid: str,
+ workspace_slug: str | None,
+):
+ raise RedirectException(
+ get_route_path(
+ workspaces_members_route,
+ path_params={
+ "workspace_slug": workspace_slug,
+ "workspace_hashid": workspace_hashid,
+ },
+ )
+ )
+
+
+@gui.route(app, "/workspaces/{workspace_slug}-{workspace_hashid}/members/")
+def workspaces_members_route(
+ request: Request,
+ workspace_hashid: str,
+ workspace_slug: str | None,
+):
+ validate_and_set_current_workspace(request, workspace_hashid)
+ with account_page_wrapper(request, AccountTabs.members):
workspaces_page(request.user, request.session)
url = get_og_url_path(request)
@@ -214,15 +253,11 @@ class TabData(typing.NamedTuple):
class AccountTabs(TabData, Enum):
profile = TabData(title=f"{icons.profile} Profile", route=profile_route)
- workspaces = TabData(title=f"{icons.company} Members", route=workspaces_route)
+ members = TabData(title=f"{icons.company} Members", route=workspaces_members_route)
saved = TabData(title=f"{icons.save} Saved", route=saved_route)
api_keys = TabData(title=f"{icons.api} API Keys", route=api_keys_route)
billing = TabData(title=f"{icons.billing} Billing", route=billing_route)
- @property
- def url_path(self) -> str:
- return get_route_path(self.route)
-
@classmethod
def get_tabs_for_request(cls, request: Request) -> list["AccountTabs"]:
from daras_ai_v2.base import BasePage
@@ -230,7 +265,7 @@ def get_tabs_for_request(cls, request: Request) -> list["AccountTabs"]:
ret = list(cls)
workspace = get_current_workspace(request.user, request.session)
if not BasePage.is_user_admin(request.user) or workspace.is_personal:
- ret.remove(cls.workspaces)
+ ret.remove(cls.members)
if not workspace.is_personal:
ret.remove(cls.profile)
@@ -239,6 +274,13 @@ def get_tabs_for_request(cls, request: Request) -> list["AccountTabs"]:
return ret
+ def get_url_path(self, request: Request) -> str:
+ workspace = get_current_workspace(request.user, request.session)
+ if workspace.is_personal or self == AccountTabs.profile:
+ return get_route_path(self.route)
+ else:
+ return get_workspaces_route_path(self.route, workspace)
+
def billing_tab(request: Request):
workspace = get_current_workspace(request.user, request.session)
@@ -251,6 +293,9 @@ def billing_tab(request: Request):
def profile_tab(request: Request):
+ workspace = get_current_workspace(request.user, request.session)
+ if not workspace.is_personal:
+ raise gui.RedirectException(get_route_path(account_route))
return edit_user_profile_page(user=request.user)
@@ -281,7 +326,7 @@ def _render_run(pr: PublishedRun):
f"profile page at {request.user.handle.get_app_url()}."
)
else:
- edit_profile_url = AccountTabs.profile.url_path
+ edit_profile_url = AccountTabs.profile.get_url_path(request)
gui.caption(
"All your Saved workflows are here. Public ones will be listed on your "
f"profile page if you [create a username]({edit_profile_url})."
@@ -300,17 +345,20 @@ def api_keys_tab(request: Request):
@contextmanager
-def account_page_wrapper(request: Request, current_tab: TabData):
+def account_page_wrapper(request: Request, current_tab: AccountTabs):
if not request.user or request.user.is_anonymous:
next_url = request.query_params.get("next", "/account/")
redirect_url = furl("/login", query_params={"next": next_url})
raise gui.RedirectException(str(redirect_url))
- with page_wrapper(request):
+ with page_wrapper(request, route_fn=current_tab.route):
+ if request.url.path != current_tab.get_url_path(request):
+ raise gui.RedirectException(current_tab.get_url_path(request))
+
gui.div(className="mt-5")
with gui.nav_tabs():
for tab in AccountTabs.get_tabs_for_request(request):
- with gui.nav_item(tab.url_path, active=tab == current_tab):
+ with gui.nav_item(tab.get_url_path(request), active=tab == current_tab):
gui.html(tab.title)
with gui.nav_tab_content():
@@ -328,3 +376,26 @@ def threaded_paypal_handle_subscription_updated(subscription_id: str) -> bool:
logger.exception(f"Unexpected PayPal error for sub: {subscription_id}")
return False
return True
+
+
+def validate_and_set_current_workspace(request: Request, workspace_hashid: str | None):
+ from routers.root import login
+
+ if not request.user or request.user.is_anonymous:
+ next_url = request.url.path
+ redirect_url = str(furl(get_route_path(login), query_params={"next": next_url}))
+ raise gui.RedirectException(redirect_url)
+
+ if not workspace_hashid:
+ # not a workspace URL, we set the current workspace to user's personal workspace
+ workspace, _ = request.user.get_or_create_personal_workspace()
+ set_current_workspace(request.session, workspace.id)
+ return
+
+ try:
+ workspace_id = Workspace.api_hashids.decode(workspace_hashid)[0]
+ WorkspaceMembership.objects.get(workspace_id=workspace_id, user=request.user)
+ except (IndexError, WorkspaceMembership.DoesNotExist):
+ return Response(status_code=404)
+ else:
+ set_current_workspace(request.session, workspace_id)
diff --git a/routers/root.py b/routers/root.py
index 3ef490654..25274b421 100644
--- a/routers/root.py
+++ b/routers/root.py
@@ -704,7 +704,12 @@ def get_og_url_path(request) -> str:
@contextmanager
-def page_wrapper(request: Request, className=""):
+def page_wrapper(
+ request: Request,
+ className="",
+ *,
+ route_fn: typing.Callable | None = None,
+):
from daras_ai_v2.base import BasePage
context = {
@@ -730,7 +735,7 @@ def page_wrapper(request: Request, className=""):
),
gui.div(style=dict(minWidth="200pt")),
):
- workspace_selector(request.user, request.session)
+ workspace_selector(request.user, request.session, route_fn=route_fn)
with gui.div(id="main-content", className="container-xxl " + className):
yield
diff --git a/workspaces/models.py b/workspaces/models.py
index 14e459b00..0a0dfb38e 100644
--- a/workspaces/models.py
+++ b/workspaces/models.py
@@ -127,6 +127,8 @@ class Workspace(SafeDeleteModel):
objects = WorkspaceQuerySet.as_manager()
+ api_hashids = hashids.Hashids(salt=settings.HASHIDS_API_SALT + "/workspaces")
+
class Meta:
constraints = [
models.UniqueConstraint(
diff --git a/workspaces/widgets.py b/workspaces/widgets.py
index bc32e4c39..5b2ffccdc 100644
--- a/workspaces/widgets.py
+++ b/workspaces/widgets.py
@@ -1,3 +1,5 @@
+import typing
+
import gooey_gui as gui
from app_users.models import AppUser
@@ -8,7 +10,12 @@
SESSION_SELECTED_WORKSPACE = "selected-workspace-id"
-def workspace_selector(user: AppUser, session: dict):
+def workspace_selector(
+ user: AppUser,
+ session: dict,
+ *,
+ route_fn: typing.Callable | None = None,
+):
from routers.account import workspaces_route
workspaces = Workspace.objects.filter(
@@ -34,7 +41,9 @@ def workspace_selector(user: AppUser, session: dict):
workspace.create_with_owner()
gui.session_state[SESSION_SELECTED_WORKSPACE] = workspace.id
session[SESSION_SELECTED_WORKSPACE] = workspace.id
- raise gui.RedirectException(get_route_path(workspaces_route))
+ raise gui.RedirectException(
+ get_workspaces_route_path(workspaces_route, workspace)
+ )
selected_id = gui.selectbox(
label="",
@@ -45,9 +54,17 @@ def workspace_selector(user: AppUser, session: dict):
value=session.get(SESSION_SELECTED_WORKSPACE),
)
set_current_workspace(session, int(selected_id))
+ if route_fn:
+ workspace = next(w for w in workspaces if w.id == int(selected_id))
+ next_route_fn = get_next_route_fn(route_fn, workspace)
+ if route_fn != next_route_fn:
+ # redirect is needed
+ raise gui.RedirectException(
+ get_workspaces_route_path(next_route_fn, workspace)
+ )
-def get_current_workspace(user: AppUser, session: dict) -> "Workspace":
+def get_current_workspace(user: AppUser, session: dict) -> Workspace:
try:
workspace_id = session[SESSION_SELECTED_WORKSPACE]
return Workspace.objects.get(
@@ -63,3 +80,47 @@ def get_current_workspace(user: AppUser, session: dict) -> "Workspace":
def set_current_workspace(session: dict, workspace_id: int):
session[SESSION_SELECTED_WORKSPACE] = workspace_id
+
+
+def get_next_route_fn(
+ route_fn: typing.Callable, workspace: Workspace
+) -> typing.Callable:
+ """
+ When we need to redirect after user changes the workspace.
+ """
+ from routers.account import (
+ account_route,
+ profile_route,
+ workspaces_route,
+ workspaces_members_route,
+ )
+
+ if workspace.is_personal and route_fn in (
+ workspaces_members_route,
+ workspaces_route,
+ ):
+ # personal workspaces don't have a members page
+ # account_route does the right redirect instead
+ return account_route
+ elif not workspace.is_personal and route_fn == profile_route:
+ # team workspaces don't have a profile page
+ # workspaces_route does the right redirect instead
+ return workspaces_route
+
+ return route_fn
+
+
+def get_workspaces_route_path(route_fn: typing.Callable, workspace: Workspace):
+ """
+ For routes like /workspaces/{workspace_slug}-{workspace_hashid}/...
+ """
+ if workspace.is_personal:
+ return get_route_path(route_fn)
+ else:
+ return get_route_path(
+ route_fn,
+ path_params={
+ "workspace_hashid": Workspace.api_hashids.encode(workspace.id),
+ "workspace_slug": workspace.get_slug(),
+ },
+ )
From 396e541b30939a1aaed14e1f09992e1934aa3b8b Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sat, 21 Sep 2024 04:38:06 +0530
Subject: [PATCH 18/81] feat: allow published runs to be created with workspace
---
bots/models.py | 7 ++++++-
daras_ai_v2/base.py | 5 +++++
2 files changed, 11 insertions(+), 1 deletion(-)
diff --git a/bots/models.py b/bots/models.py
index 60a53932f..92b2485b1 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -22,8 +22,9 @@
from gooeysite.custom_create import get_or_create_lazy
if typing.TYPE_CHECKING:
- from daras_ai_v2.base import BasePage
import celery.result
+ from daras_ai_v2.base import BasePage
+ from workspaces.models import Workspace
CHATML_ROLE_USER = "user"
CHATML_ROLE_ASSISSTANT = "assistant"
@@ -1627,6 +1628,7 @@ def get_or_create_with_version(
published_run_id: str,
saved_run: SavedRun,
user: AppUser | None,
+ workspace: "Workspace | None",
title: str,
notes: str,
visibility: PublishedRunVisibility,
@@ -1639,6 +1641,7 @@ def get_or_create_with_version(
**kwargs,
saved_run=saved_run,
user=user,
+ workspace=workspace,
title=title,
notes=notes,
visibility=visibility,
@@ -1652,6 +1655,7 @@ def create_with_version(
published_run_id: str,
saved_run: SavedRun,
user: AppUser | None,
+ workspace: "Workspace | None",
title: str,
notes: str,
visibility: PublishedRunVisibility,
@@ -1662,6 +1666,7 @@ def create_with_version(
published_run_id=published_run_id,
created_by=user,
last_edited_by=user,
+ workspace=workspace,
title=title,
)
pr.add_version(
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 4bdec65cf..1538b51f2 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -637,6 +637,7 @@ def _render_publish_form(
published_run_id=get_random_doc_id(),
saved_run=sr,
user=self.request.user,
+ workspace=self.current_workspace,
title=published_run_title.strip(),
notes=published_run_notes.strip(),
visibility=published_run_visibility,
@@ -741,6 +742,7 @@ def _saved_options_modal(self, *, sr: SavedRun, pr: PublishedRun):
published_run_id=get_random_doc_id(),
saved_run=sr,
user=self.request.user,
+ workspace=self.current_workspace,
title=f"{pr.title} (Copy)",
notes=pr.notes,
visibility=PublishedRunVisibility(PublishedRunVisibility.UNLISTED),
@@ -1194,6 +1196,7 @@ def get_root_pr(cls) -> PublishedRun:
defaults=dict(state=cls.load_state_defaults({})),
)[0],
user=None,
+ workspace=None,
title=cls.title,
notes=cls().preview_description(state=cls.sane_defaults),
visibility=PublishedRunVisibility.PUBLIC,
@@ -1206,6 +1209,7 @@ def create_published_run(
published_run_id: str,
saved_run: SavedRun,
user: AppUser | None,
+ workspace: "Workspace",
title: str,
notes: str,
visibility: PublishedRunVisibility,
@@ -1215,6 +1219,7 @@ def create_published_run(
published_run_id=published_run_id,
saved_run=saved_run,
user=user,
+ workspace=workspace,
title=title,
notes=notes,
visibility=visibility,
From b3e4563b7d47e9fd7dc85c4233ab4fcc8a88286e Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sat, 21 Sep 2024 06:12:38 +0530
Subject: [PATCH 19/81] fix routes in account
---
routers/account.py | 18 ++++++++++++++----
1 file changed, 14 insertions(+), 4 deletions(-)
diff --git a/routers/account.py b/routers/account.py
index 269459ada..940a628be 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -97,7 +97,11 @@ def account_route(request: Request):
@gui.route(app, "/account/billing/")
@gui.route(app, "/workspaces/{workspace_slug}-{workspace_hashid}/billing/")
-def billing_route(request: Request, workspace_slug: str, workspace_hashid: str | None):
+def billing_route(
+ request: Request,
+ workspace_slug: str | None = None,
+ workspace_hashid: str | None = None,
+):
validate_and_set_current_workspace(request, workspace_hashid)
with account_page_wrapper(request, AccountTabs.billing):
billing_tab(request)
@@ -137,7 +141,9 @@ def saved_shortcut_route():
@gui.route(app, "/account/saved/")
@gui.route(app, "/workspaces/{workspace_slug}-{workspace_hashid}/saved/")
def saved_route(
- request: Request, workspace_slug: str, workspace_hashid: str | None = None
+ request: Request,
+ workspace_slug: str | None = None,
+ workspace_hashid: str | None = None,
):
validate_and_set_current_workspace(request, workspace_hashid)
with account_page_wrapper(request, AccountTabs.saved):
@@ -156,7 +162,11 @@ def saved_route(
@gui.route(app, "/account/api-keys/")
@gui.route(app, "/workspaces/{workspace_slug}-{workspace_hashid}/api-keys/")
-def api_keys_route(request: Request, workspace_slug: str, workspace_hashid: str | None):
+def api_keys_route(
+ request: Request,
+ workspace_slug: str | None = None,
+ workspace_hashid: str | None = None,
+):
validate_and_set_current_workspace(request, workspace_hashid)
with account_page_wrapper(request, AccountTabs.api_keys):
api_keys_tab(request)
@@ -193,7 +203,7 @@ def workspaces_route(
def workspaces_members_route(
request: Request,
workspace_hashid: str,
- workspace_slug: str | None,
+ workspace_slug: str | None = None,
):
validate_and_set_current_workspace(request, workspace_hashid)
with account_page_wrapper(request, AccountTabs.members):
From 2f997d1780d24b9e643a7b5fa35c7558b1d80a82 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sat, 21 Sep 2024 06:13:49 +0530
Subject: [PATCH 20/81] fix: register workspace when duplicating published run
---
bots/models.py | 2 ++
daras_ai_v2/base.py | 2 ++
2 files changed, 4 insertions(+)
diff --git a/bots/models.py b/bots/models.py
index 92b2485b1..a01ee4baf 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -1780,6 +1780,7 @@ def duplicate(
self,
*,
user: AppUser,
+ workspace: "Workspace",
title: str,
notes: str,
visibility: PublishedRunVisibility,
@@ -1789,6 +1790,7 @@ def duplicate(
published_run_id=get_random_doc_id(),
saved_run=self.saved_run,
user=user,
+ workspace=workspace,
title=title,
notes=notes,
visibility=visibility,
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 1538b51f2..8be28a874 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -729,6 +729,7 @@ def _saved_options_modal(self, *, sr: SavedRun, pr: PublishedRun):
if duplicate_button:
duplicate_pr = pr.duplicate(
user=self.request.user,
+ workspace=self.current_workspace,
title=f"{pr.title} (Copy)",
notes=pr.notes,
visibility=PublishedRunVisibility(PublishedRunVisibility.UNLISTED),
@@ -782,6 +783,7 @@ def _unsaved_options_button_with_dialog(self):
pr = self.current_pr
duplicate_pr = pr.duplicate(
user=self.request.user,
+ workspace=self.current_workspace,
title=f"{self.request.user.first_name_possesive()} {pr.title}",
notes=pr.notes,
visibility=PublishedRunVisibility(PublishedRunVisibility.UNLISTED),
From 05d09defd5f90c14a07981449b4b1990ec86abd8 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sat, 21 Sep 2024 06:16:34 +0530
Subject: [PATCH 21/81] render workspace as the owner on published runs
---
daras_ai_v2/base.py | 74 ++++++++++++++++++++++++++++++++------------
recipes/VideoBots.py | 1 +
2 files changed, 55 insertions(+), 20 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 8be28a874..9d9362253 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -83,9 +83,7 @@
from routers.account import AccountTabs
from routers.root import RecipeTabs
from workspaces.widgets import get_current_workspace
-
-if typing.TYPE_CHECKING:
- from workspaces.models import Workspace
+from workspaces.models import Workspace
DEFAULT_META_IMG = (
@@ -382,7 +380,7 @@ def _render_header(self):
)
if is_example:
- author = pr.created_by
+ author = pr.workspace
else:
author = self.current_sr_user or sr.get_creator()
if not is_root_example:
@@ -1118,7 +1116,7 @@ def update_flag_for_run(self, is_flagged: bool):
gui.session_state["is_flagged"] = is_flagged
@cached_property
- def current_workspace(self) -> "Workspace":
+ def current_workspace(self) -> Workspace:
assert self.request.user
return get_current_workspace(self.request.user, self.request.session)
@@ -1211,7 +1209,7 @@ def create_published_run(
published_run_id: str,
saved_run: SavedRun,
user: AppUser | None,
- workspace: "Workspace",
+ workspace: Workspace,
title: str,
notes: str,
visibility: PublishedRunVisibility,
@@ -1247,16 +1245,56 @@ def render_form_v2(self):
def validate_form_v2(self):
pass
- @staticmethod
+ @classmethod
def render_author(
- user: AppUser,
+ cls,
+ workspace_or_user: "Workspace | AppUser | None",
*,
image_size: str = "30px",
responsive: bool = True,
show_as_link: bool = True,
text_size: str | None = None,
):
- if not user or (not user.photo_url and not user.display_name):
+ if not workspace_or_user:
+ return
+
+ link = None
+ if isinstance(workspace_or_user, Workspace):
+ workspace = workspace_or_user
+ photo = workspace.logo
+ if not photo and workspace.is_personal:
+ photo = workspace.created_by.photo_url
+ name = workspace.display_name()
+ if show_as_link and workspace.is_personal and workspace.created_by.handle:
+ link = workspace.created_by.handle.get_app_url()
+ else:
+ user = workspace_or_user
+ photo = user.photo_url
+ name = user.display_name
+ if show_as_link and user.handle:
+ link = user.handle.get_app_url()
+
+ return cls._render_author(
+ photo=photo,
+ name=name,
+ link=link,
+ image_size=image_size,
+ responsive=responsive,
+ text_size=text_size,
+ )
+
+ @classmethod
+ def _render_author(
+ cls,
+ photo: str | None,
+ name: str | None,
+ link: str | None,
+ *,
+ image_size: str,
+ responsive: bool,
+ text_size: str | None,
+ ):
+ if not photo and not name:
return
responsive_image_size = (
@@ -1268,13 +1306,9 @@ def render_author(
if responsive:
class_name += "-responsive"
- if show_as_link and user and user.handle:
- linkto = gui.link(to=user.handle.get_app_url())
- else:
- linkto = gui.dummy()
-
+ linkto = link and gui.link(to=link) or gui.dummy()
with linkto, gui.div(className="d-flex align-items-center"):
- if user.photo_url:
+ if photo:
gui.html(
f"""
"""
)
- gui.image(user.photo_url, className=class_name)
+ gui.image(photo, className=class_name)
- if user.display_name:
+ if name:
name_style = {"fontSize": text_size} if text_size else {}
with gui.tag("span", style=name_style):
- gui.html(html.escape(user.display_name))
+ gui.html(html.escape(name))
def get_credits_click_url(self):
if self.request.user and self.request.user.is_anonymous:
@@ -1962,10 +1996,10 @@ def _render_example_preview(
):
tb = get_title_breadcrumbs(self, published_run.saved_run, published_run)
- if published_run.created_by:
+ if published_run.workspace:
with gui.div(className="mb-1 text-truncate", style={"height": "1.5rem"}):
self.render_author(
- published_run.created_by,
+ published_run.workspace,
image_size="20px",
text_size="0.9rem",
)
diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py
index 9cd906dd5..62291ddb7 100644
--- a/recipes/VideoBots.py
+++ b/recipes/VideoBots.py
@@ -1147,6 +1147,7 @@ def render_integrations_add(self, label: str, run_title: str, pr: PublishedRun):
run_title = f"{self.request.user and self.request.user.first_name_possesive()} {run_title}"
pr = pr.duplicate(
user=self.request.user,
+ workspace=self.current_workspace,
title=run_title,
notes=pr.notes,
visibility=PublishedRunVisibility.UNLISTED,
From 12d858a81073214c652fee217a867699849db952 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sat, 21 Sep 2024 14:33:11 +0530
Subject: [PATCH 22/81] fix rendering of validationerror messages
---
workspaces/views.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/workspaces/views.py b/workspaces/views.py
index 1b7a8d739..b8de45d5c 100644
--- a/workspaces/views.py
+++ b/workspaces/views.py
@@ -83,7 +83,7 @@ def render_workspace_creation_view(user: AppUser):
try:
workspace.create_with_owner()
except ValidationError as e:
- gui.write(e.message, className="text-danger")
+ gui.write("\n".join(e.messages), className="text-danger")
else:
gui.rerun()
@@ -210,7 +210,7 @@ def member_invite_button_with_dialog(membership: WorkspaceMembership):
defaults=dict(role=role),
)
except ValidationError as e:
- gui.write(e.message, className="text-danger")
+ gui.write("\n".join(e.messages), className="text-danger")
else:
ref.set_open(False)
gui.rerun()
@@ -238,7 +238,7 @@ def edit_workspace_button_with_dialog(membership: WorkspaceMembership):
workspace_copy.full_clean()
except ValidationError as e:
# newlines in markdown
- gui.write(e.message, className="text-danger")
+ gui.write("\n".join(e.messages), className="text-danger")
else:
workspace_copy.save()
membership.workspace.refresh_from_db()
From 6c3d8451ca13b47aba351cd3ffbf24213b80ca7a Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 23 Sep 2024 20:47:25 +0530
Subject: [PATCH 23/81] Refactor workspace_selector with current_tab for
clarity
---
routers/account.py | 2 +-
routers/root.py | 10 +++++++--
workspaces/widgets.py | 51 +++++++++++++++----------------------------
3 files changed, 26 insertions(+), 37 deletions(-)
diff --git a/routers/account.py b/routers/account.py
index 940a628be..15d7202bb 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -361,7 +361,7 @@ def account_page_wrapper(request: Request, current_tab: AccountTabs):
redirect_url = furl("/login", query_params={"next": next_url})
raise gui.RedirectException(str(redirect_url))
- with page_wrapper(request, route_fn=current_tab.route):
+ with page_wrapper(request, current_tab=current_tab):
if request.url.path != current_tab.get_url_path(request):
raise gui.RedirectException(current_tab.get_url_path(request))
diff --git a/routers/root.py b/routers/root.py
index 25274b421..b3e9faa98 100644
--- a/routers/root.py
+++ b/routers/root.py
@@ -45,6 +45,10 @@
from routers.static_pages import serve_static_file
from workspaces.widgets import workspace_selector
+if typing.TYPE_CHECKING:
+ from routers.account import AccountTabs
+
+
app = CustomAPIRouter()
DEFAULT_LOGIN_REDIRECT = "/explore/"
@@ -708,7 +712,7 @@ def page_wrapper(
request: Request,
className="",
*,
- route_fn: typing.Callable | None = None,
+ current_tab: "AccountTabs | None" = None,
):
from daras_ai_v2.base import BasePage
@@ -735,7 +739,9 @@ def page_wrapper(
),
gui.div(style=dict(minWidth="200pt")),
):
- workspace_selector(request.user, request.session, route_fn=route_fn)
+ workspace_selector(
+ request.user, request.session, current_tab=current_tab
+ )
with gui.div(id="main-content", className="container-xxl " + className):
yield
diff --git a/workspaces/widgets.py b/workspaces/widgets.py
index 5b2ffccdc..a4df379b3 100644
--- a/workspaces/widgets.py
+++ b/workspaces/widgets.py
@@ -7,6 +7,10 @@
from daras_ai_v2.fastapi_tricks import get_route_path
from .models import Workspace
+if typing.TYPE_CHECKING:
+ from routers.account import AccountTabs
+
+
SESSION_SELECTED_WORKSPACE = "selected-workspace-id"
@@ -14,9 +18,9 @@ def workspace_selector(
user: AppUser,
session: dict,
*,
- route_fn: typing.Callable | None = None,
+ current_tab: "AccountTabs | None" = None,
):
- from routers.account import workspaces_route
+ from routers.account import workspaces_route, account_route
workspaces = Workspace.objects.filter(
memberships__user=user, memberships__deleted__isnull=True
@@ -54,14 +58,11 @@ def workspace_selector(
value=session.get(SESSION_SELECTED_WORKSPACE),
)
set_current_workspace(session, int(selected_id))
- if route_fn:
+ if current_tab:
workspace = next(w for w in workspaces if w.id == int(selected_id))
- next_route_fn = get_next_route_fn(route_fn, workspace)
- if route_fn != next_route_fn:
- # redirect is needed
- raise gui.RedirectException(
- get_workspaces_route_path(next_route_fn, workspace)
- )
+ if not validate_tab_for_workspace(current_tab, workspace):
+ # account_route will redirect to the correct tab
+ raise gui.RedirectException(get_route_path(account_route))
def get_current_workspace(user: AppUser, session: dict) -> Workspace:
@@ -82,32 +83,14 @@ def set_current_workspace(session: dict, workspace_id: int):
session[SESSION_SELECTED_WORKSPACE] = workspace_id
-def get_next_route_fn(
- route_fn: typing.Callable, workspace: Workspace
-) -> typing.Callable:
- """
- When we need to redirect after user changes the workspace.
- """
- from routers.account import (
- account_route,
- profile_route,
- workspaces_route,
- workspaces_members_route,
- )
+def validate_tab_for_workspace(tab: "AccountTabs", workspace: Workspace) -> bool:
+ from routers.account import AccountTabs
- if workspace.is_personal and route_fn in (
- workspaces_members_route,
- workspaces_route,
- ):
- # personal workspaces don't have a members page
- # account_route does the right redirect instead
- return account_route
- elif not workspace.is_personal and route_fn == profile_route:
- # team workspaces don't have a profile page
- # workspaces_route does the right redirect instead
- return workspaces_route
-
- return route_fn
+ if tab == AccountTabs.members:
+ return not workspace.is_personal
+ if tab == AccountTabs.profile:
+ return workspace.is_personal
+ return True
def get_workspaces_route_path(route_fn: typing.Callable, workspace: Workspace):
From e2834e49de343799e734173a8ae9d8436767b4aa Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 25 Sep 2024 17:21:38 +0530
Subject: [PATCH 24/81] Add workspace selector to 'Save'/'Update' menu
---
app_users/models.py | 7 +++++++
daras_ai_v2/base.py | 35 +++++++++++++++++++++++++++++++++--
workspaces/models.py | 4 +++-
workspaces/widgets.py | 31 ++++++++++++++++++++-----------
4 files changed, 63 insertions(+), 14 deletions(-)
diff --git a/app_users/models.py b/app_users/models.py
index 4dba297e4..7f740bf16 100644
--- a/app_users/models.py
+++ b/app_users/models.py
@@ -233,6 +233,13 @@ def get_or_create_personal_workspace(self) -> tuple["Workspace", bool]:
return Workspace.objects.get_or_create_from_user(self)
+ def get_workspaces(self) -> models.QuerySet["Workspace"]:
+ from workspaces.models import Workspace
+
+ return Workspace.objects.filter(
+ memberships__user=self, memberships__deleted__isnull=True
+ )
+
class TransactionReason(models.IntegerChoices):
DEDUCT = 1, "Deduct"
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 9d9362253..4625dead5 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -82,7 +82,12 @@
)
from routers.account import AccountTabs
from routers.root import RecipeTabs
-from workspaces.widgets import get_current_workspace
+from workspaces.widgets import (
+ create_workspace_with_defaults,
+ get_current_workspace,
+ set_current_workspace,
+ workspace_selector,
+)
from workspaces.models import Workspace
@@ -600,6 +605,31 @@ def _render_publish_form(
value=(pr.notes or self.preview_description(gui.session_state) or ""),
)
+ col1, col2 = gui.columns([1, 3])
+ with col1, gui.div(className="mt-2"):
+ gui.write("###### Workspace")
+ with col2:
+ if self.request.user.get_workspaces().count() > 1:
+ workspace_selector(self.request.user, self.request.session)
+ else:
+ with gui.div(className="p-2 mb-2"):
+ self.render_author(
+ self.current_workspace,
+ show_as_link=False,
+ current_user=self.request.user,
+ )
+ with gui.div(className="align-middle alert alert-warning"):
+ gui.html(icons.company + " ")
+ if gui.button(
+ "Create a team workspace",
+ type="link",
+ className="d-inline m-0",
+ ):
+ workspace = create_workspace_with_defaults(self.request.user)
+ set_current_workspace(self.request.session, workspace.id)
+ gui.rerun()
+ gui.html(" " + "to edit with others")
+
self._render_admin_options(sr, pr)
if not dialog.pressed_confirm:
@@ -1254,6 +1284,7 @@ def render_author(
responsive: bool = True,
show_as_link: bool = True,
text_size: str | None = None,
+ current_user: AppUser | None = None,
):
if not workspace_or_user:
return
@@ -1264,7 +1295,7 @@ def render_author(
photo = workspace.logo
if not photo and workspace.is_personal:
photo = workspace.created_by.photo_url
- name = workspace.display_name()
+ name = workspace.display_name(current_user=current_user)
if show_as_link and workspace.is_personal and workspace.created_by.handle:
link = workspace.created_by.handle.get_app_url()
else:
diff --git a/workspaces/models.py b/workspaces/models.py
index 0a0dfb38e..c19ac07df 100644
--- a/workspaces/models.py
+++ b/workspaces/models.py
@@ -317,7 +317,9 @@ def display_name(self, current_user: AppUser | None = None) -> str:
elif (
self.is_personal and current_user and self.created_by_id == current_user.id
):
- return "Personal Account"
+ return f"{current_user.full_name()} (personal)"
+ elif self.is_personal:
+ return self.created_by.full_name()
else:
return f"{self.created_by.first_name_possesive()} Workspace"
diff --git a/workspaces/widgets.py b/workspaces/widgets.py
index a4df379b3..1721a0015 100644
--- a/workspaces/widgets.py
+++ b/workspaces/widgets.py
@@ -20,11 +20,9 @@ def workspace_selector(
*,
current_tab: "AccountTabs | None" = None,
):
- from routers.account import workspaces_route, account_route
+ from routers.account import account_route, workspaces_route
- workspaces = Workspace.objects.filter(
- memberships__user=user, memberships__deleted__isnull=True
- ).order_by("-is_personal", "-created_at")
+ workspaces = user.get_workspaces().order_by("-is_personal", "-created_at")
if not workspaces:
workspaces = [user.get_or_create_personal_workspace()[0]]
@@ -38,13 +36,10 @@ def workspace_selector(
}
if gui.session_state.get(SESSION_SELECTED_WORKSPACE) == "":
- name = f"{user.first_name_possesive()} Team Workspace"
- if len(workspaces) > 1:
- name += f" {len(workspaces) - 1}"
- workspace = Workspace(name=name, created_by=user)
- workspace.create_with_owner()
- gui.session_state[SESSION_SELECTED_WORKSPACE] = workspace.id
- session[SESSION_SELECTED_WORKSPACE] = workspace.id
+ suffix = f" {len(workspaces) - 1}" if len(workspaces) > 1 else ""
+ name = get_default_name_for_new_workspace(user, suffix=suffix)
+ workspace = create_workspace_with_defaults(user, name=name)
+ set_current_workspace(session, workspace.id)
raise gui.RedirectException(
get_workspaces_route_path(workspaces_route, workspace)
)
@@ -107,3 +102,17 @@ def get_workspaces_route_path(route_fn: typing.Callable, workspace: Workspace):
"workspace_slug": workspace.get_slug(),
},
)
+
+
+def create_workspace_with_defaults(user: AppUser, name: str | None = None):
+ if not name:
+ workspace_count = user.get_workspaces().count()
+ suffix = f" {workspace_count - 1}" if workspace_count > 1 else ""
+ name = get_default_name_for_new_workspace(user, suffix=suffix)
+ workspace = Workspace(name=name, created_by=user)
+ workspace.create_with_owner()
+ return workspace
+
+
+def get_default_name_for_new_workspace(user: AppUser, suffix: str = "") -> str:
+ return f"{user.first_name_possesive()} Team Workspace" + suffix
From d8caf286378b5453106eddafe32293d7c50c39b6 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 25 Sep 2024 18:36:04 +0530
Subject: [PATCH 25/81] Add workspace API keys, bot integrations (#476)
* fix API key authentication to return Workspace
* feat: scope API keys to workspace - creation and display widget
* Add workspace to published run & sr.submit_api_call
* Bill to workspace for recipe_functions & safety checker
* call_recipe_functions should accept workspace param
* Pass workspace= in runs created by bulk runner
* Add migration to add workspace to published runs
* Use workspace in bot integration's submit_api_call
* Add script to migrate bot integrations & published runs to workspaces
* fix logic to create fixtures to have workspaces
* modify logic to display workspace name to show author's name if is_personal
* fix logic for creation & fetching of API keys (for workspace)
* Add API Keys model and fallback to firebase
* display api_keys from DB and Firebase both
* disable migration from firebase in manage_api_keys_widget
* Add script to migrate API keys from firebase
* Add migration for adding workspace to bot integrations
---
api_keys/__init__.py | 0
api_keys/admin.py | 8 ++
api_keys/apps.py | 7 ++
api_keys/migrations/0001_initial.py | 33 ++++++
api_keys/migrations/__init__.py | 0
api_keys/models.py | 67 ++++++++++++
api_keys/tests.py | 3 +
api_keys/views.py | 3 +
auth/token_authentication.py | 59 ++++++++---
.../0084_botintegration_workspace.py | 20 ++++
bots/models.py | 14 ++-
bots/tasks.py | 5 +-
daras_ai_v2/base.py | 4 +-
daras_ai_v2/bots.py | 19 ++--
daras_ai_v2/manage_api_keys_widget.py | 100 ++++++++++--------
daras_ai_v2/safety_checker.py | 4 +-
daras_ai_v2/settings.py | 1 +
functions/recipe_functions.py | 3 +
recipes/BulkRunner.py | 2 +
recipes/VideoBots.py | 3 +-
recipes/VideoBotsStats.py | 6 +-
routers/account.py | 2 +-
routers/api.py | 40 ++++---
routers/broadcast_api.py | 9 +-
routers/root.py | 2 +-
routers/twilio_api.py | 2 +-
scripts/create_fixture.py | 5 +
scripts/migrate_workspaces.py | 85 ++++++++++++++-
workspaces/models.py | 7 +-
29 files changed, 401 insertions(+), 112 deletions(-)
create mode 100644 api_keys/__init__.py
create mode 100644 api_keys/admin.py
create mode 100644 api_keys/apps.py
create mode 100644 api_keys/migrations/0001_initial.py
create mode 100644 api_keys/migrations/__init__.py
create mode 100644 api_keys/models.py
create mode 100644 api_keys/tests.py
create mode 100644 api_keys/views.py
create mode 100644 bots/migrations/0084_botintegration_workspace.py
diff --git a/api_keys/__init__.py b/api_keys/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/api_keys/admin.py b/api_keys/admin.py
new file mode 100644
index 000000000..208923086
--- /dev/null
+++ b/api_keys/admin.py
@@ -0,0 +1,8 @@
+from django.contrib import admin
+
+from .models import ApiKey
+
+
+@admin.register(ApiKey)
+class ApiKeyAdmin(admin.ModelAdmin):
+ list_display = ("preview", "workspace", "created_by", "created_at")
diff --git a/api_keys/apps.py b/api_keys/apps.py
new file mode 100644
index 000000000..2b6f1300e
--- /dev/null
+++ b/api_keys/apps.py
@@ -0,0 +1,7 @@
+from django.apps import AppConfig
+
+
+class ApiKeysConfig(AppConfig):
+ default_auto_field = "django.db.models.BigAutoField"
+ name = "api_keys"
+ verbose_name = "API Keys"
diff --git a/api_keys/migrations/0001_initial.py b/api_keys/migrations/0001_initial.py
new file mode 100644
index 000000000..126156f74
--- /dev/null
+++ b/api_keys/migrations/0001_initial.py
@@ -0,0 +1,33 @@
+# Generated by Django 5.1.1 on 2024-09-25 00:37
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ initial = True
+
+ dependencies = [
+ ('app_users', '0023_alter_appusertransaction_workspace'),
+ ('workspaces', '0002_alter_workspace_domain_name'),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name='ApiKey',
+ fields=[
+ ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+ ('hash', models.CharField(max_length=128, unique=True)),
+ ('preview', models.CharField(max_length=32)),
+ ('created_at', models.DateTimeField(auto_now_add=True)),
+ ('updated_at', models.DateTimeField(auto_now=True)),
+ ('created_by', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, to='app_users.appuser')),
+ ('workspace', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='workspaces.workspace')),
+ ],
+ options={
+ 'verbose_name': 'API Key',
+ 'verbose_name_plural': 'API Keys',
+ },
+ ),
+ ]
diff --git a/api_keys/migrations/__init__.py b/api_keys/migrations/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/api_keys/models.py b/api_keys/models.py
new file mode 100644
index 000000000..c89ba96e3
--- /dev/null
+++ b/api_keys/models.py
@@ -0,0 +1,67 @@
+import typing
+
+from django.db import models
+
+from daras_ai_v2.crypto import PBKDF2PasswordHasher, get_random_api_key, safe_preview
+
+if typing.TYPE_CHECKING:
+ from workspaces.models import Workspace
+
+
+class ApiKeyQueySet(models.QuerySet):
+ def create_api_key(self, workspace: "Workspace", **kwargs) -> tuple["ApiKey", str]:
+ """
+ Returns a tuple of the created ApiKey instance and the secret key.
+ """
+ secret_key = get_random_api_key()
+ hasher = PBKDF2PasswordHasher()
+ api_key = self.create(
+ hash=hasher.encode(secret_key),
+ preview=safe_preview(secret_key),
+ workspace=workspace,
+ **kwargs,
+ )
+ return api_key, secret_key
+
+ def create_from_secret_key(
+ self, secret_key: str, workspace: "Workspace", **kwargs
+ ) -> "ApiKey":
+ """
+ `key` must be a valid plain-text key.
+ """
+ hasher = PBKDF2PasswordHasher()
+ return self.get_or_create(
+ hash=hasher.encode(secret_key),
+ defaults=dict(
+ preview=safe_preview(secret_key),
+ workspace=workspace,
+ **kwargs,
+ ),
+ )[0]
+
+ def get_from_secret_key(self, secret_key: str) -> "ApiKey | None":
+ hasher = PBKDF2PasswordHasher()
+ hash = hasher.encode(secret_key)
+ return self.filter(hash=hash).first()
+
+
+class ApiKey(models.Model):
+ hash = models.CharField(max_length=128, unique=True)
+ preview = models.CharField(max_length=32)
+ workspace = models.ForeignKey(
+ "workspaces.Workspace", on_delete=models.CASCADE, related_name="api_keys"
+ )
+ created_by = models.ForeignKey(
+ "app_users.AppUser", on_delete=models.SET_NULL, null=True
+ )
+ created_at = models.DateTimeField(auto_now_add=True)
+ updated_at = models.DateTimeField(auto_now=True)
+
+ objects = ApiKeyQueySet.as_manager()
+
+ def __str__(self):
+ return self.preview
+
+ class Meta:
+ verbose_name = "API Key"
+ verbose_name_plural = "API Keys"
diff --git a/api_keys/tests.py b/api_keys/tests.py
new file mode 100644
index 000000000..7ce503c2d
--- /dev/null
+++ b/api_keys/tests.py
@@ -0,0 +1,3 @@
+from django.test import TestCase
+
+# Create your tests here.
diff --git a/api_keys/views.py b/api_keys/views.py
new file mode 100644
index 000000000..91ea44a21
--- /dev/null
+++ b/api_keys/views.py
@@ -0,0 +1,3 @@
+from django.shortcuts import render
+
+# Create your views here.
diff --git a/auth/token_authentication.py b/auth/token_authentication.py
index a1281faa7..19286cdc3 100644
--- a/auth/token_authentication.py
+++ b/auth/token_authentication.py
@@ -4,12 +4,14 @@
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel, SecuritySchemeType
from fastapi.security.base import SecurityBase
+from loguru import logger
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
-from app_users.models import AppUser
+from api_keys.models import ApiKey
from auth.auth_backend import authlocal
from daras_ai_v2 import db
from daras_ai_v2.crypto import PBKDF2PasswordHasher
+from workspaces.models import Workspace
class AuthenticationError(HTTPException):
@@ -26,30 +28,53 @@ def __init__(self, msg: str):
super().__init__(status_code=self.status_code, detail={"error": msg})
-def authenticate_credentials(token: str) -> AppUser:
- db_collection = db.get_client().collection(db.API_KEYS_COLLECTION)
+def _authenticate_credentials_from_firebase(token: str) -> Workspace | None:
hasher = PBKDF2PasswordHasher()
- secret_key_hash = hasher.encode(token)
+ hash = hasher.encode(token)
+ db_collection = db.get_client().collection(db.API_KEYS_COLLECTION)
try:
- doc = (
- db_collection.where("secret_key_hash", "==", secret_key_hash)
- .limit(1)
- .get()[0]
- )
+ doc = db_collection.where("secret_key_hash", "==", hash).limit(1).get()[0]
except IndexError:
- raise AuthorizationError("Invalid API Key.")
+ return None
+
+ try:
+ workspace_id = doc.get("workspace_id")
+ except KeyError:
+ uid = doc.get("uid")
+ return Workspace.objects.get_or_create_from_uid(uid)[0]
+
+ try:
+ return Workspace.objects.get(id=workspace_id)
+ except Workspace.DoesNotExist:
+ logger.warning(f"Workspace {workspace_id} not found (for API key {doc.id=}).")
+ return None
+
+
+def authenticate_credentials(token: str) -> Workspace:
+ api_key = ApiKey.objects.select_related("workspace").get_from_secret_key(token)
+ if not api_key:
+ workspace = _authenticate_credentials_from_firebase(token)
+ if not workspace:
+ raise AuthorizationError("Invalid API key.")
+
+ # firebase was used for API Keys before team workspaces, so we
+ # can assume that api_key.created_by_id = workspace.created_by
+ api_key = ApiKey.objects.create_from_secret_key(
+ token,
+ workspace=workspace,
+ created_by_id=workspace.created_by_id,
+ )
- uid = doc.get("uid")
- user = AppUser.objects.get_or_create_from_uid(uid)[0]
- if user.is_disabled:
+ workspace = api_key.workspace
+ if workspace.is_personal and workspace.created_by.is_disabled:
msg = (
"Your Gooey.AI account has been disabled for violating our Terms of Service. "
"Contact us at support@gooey.ai if you think this is a mistake."
)
raise AuthenticationError(msg)
- return user
+ return workspace
class APIAuth(SecurityBase):
@@ -59,8 +84,8 @@ class APIAuth(SecurityBase):
```python
api_auth = APIAuth(scheme_name="bearer", description="Bearer $GOOEY_API_KEY")
- @app.get("/api/users")
- def get_users(authenticated_user: AppUser = Depends(api_auth)):
+ @app.get("/api/runs")
+ def get_runs(current_workspace: Workspace = Depends(api_auth)):
...
```
"""
@@ -77,7 +102,7 @@ def __init__(
self.scheme_name = scheme_name
self.description = description
- def __call__(self, request: Request) -> AppUser:
+ def __call__(self, request: Request) -> Workspace:
if authlocal: # testing only!
return authlocal[0]
diff --git a/bots/migrations/0084_botintegration_workspace.py b/bots/migrations/0084_botintegration_workspace.py
new file mode 100644
index 000000000..98626d5b2
--- /dev/null
+++ b/bots/migrations/0084_botintegration_workspace.py
@@ -0,0 +1,20 @@
+# Generated by Django 5.1.1 on 2024-09-23 08:11
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('bots', '0083_publishedrun_workspace'),
+ ('workspaces', '0002_alter_workspace_domain_name'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='botintegration',
+ name='workspace',
+ field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='botintegrations', to='workspaces.workspace'),
+ ),
+ ]
diff --git a/bots/models.py b/bots/models.py
index a01ee4baf..da7cf8a8b 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -363,11 +363,12 @@ def copy_from_firebase_state(self, state: dict) -> "SavedRun":
def submit_api_call(
self,
*,
- current_user: AppUser,
+ workspace: "Workspace",
request_body: dict,
enable_rate_limits: bool = False,
deduct_credits: bool = True,
parent_pr: "PublishedRun" = None,
+ current_user: AppUser | None = None,
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
from routers.api import submit_api_call
@@ -386,6 +387,7 @@ def submit_api_call(
kwds=dict(
page_cls=page_cls,
query_params=query_params,
+ workspace=workspace,
current_user=current_user,
request_body=request_body,
enable_rate_limits=enable_rate_limits,
@@ -504,6 +506,12 @@ class BotIntegration(models.Model):
help_text="The gooey account uid where the credits will be deducted from",
db_index=True,
)
+ workspace = models.ForeignKey(
+ "workspaces.Workspace",
+ on_delete=models.CASCADE,
+ related_name="botintegrations",
+ null=True,
+ )
user_language = models.TextField(
default="",
help_text="The response language (same as user language in video bots)",
@@ -1852,12 +1860,14 @@ def get_run_count(self):
def submit_api_call(
self,
*,
- current_user: AppUser,
+ workspace: "Workspace",
request_body: dict,
enable_rate_limits: bool = False,
deduct_credits: bool = True,
+ current_user: AppUser | None = None,
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
return self.saved_run.submit_api_call(
+ workspace=workspace,
current_user=current_user,
request_body=request_body,
enable_rate_limits=enable_rate_limits,
diff --git a/bots/tasks.py b/bots/tasks.py
index 0a76e42bc..90229c6aa 100644
--- a/bots/tasks.py
+++ b/bots/tasks.py
@@ -61,9 +61,6 @@ def msg_analysis(self, msg_id: int, anal_id: int, countdown: int | None):
msg.role == CHATML_ROLE_ASSISSTANT
), f"the message being analyzed must must be an {CHATML_ROLE_ASSISSTANT} msg"
- billing_account = AppUser.objects.get(
- uid=msg.conversation.bot_integration.billing_account_uid
- )
analysis_sr = anal.get_active_saved_run()
variables = analysis_sr.state.get("variables", {})
@@ -89,7 +86,7 @@ def msg_analysis(self, msg_id: int, anal_id: int, countdown: int | None):
# make the api call
result, sr = analysis_sr.submit_api_call(
- current_user=billing_account,
+ workspace=msg.conversation.bot_integration.workspace,
request_body=dict(variables=variables),
parent_pr=anal.published_run,
)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 4625dead5..063c2385d 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -1484,6 +1484,7 @@ def main(self, sr: SavedRun, state: dict) -> typing.Iterator[str | None]:
yield from call_recipe_functions(
saved_run=sr,
+ workspace=self.current_workspace,
current_user=self.request.user,
request_model=self.RequestModel,
response_model=self.ResponseModel,
@@ -1495,6 +1496,7 @@ def main(self, sr: SavedRun, state: dict) -> typing.Iterator[str | None]:
yield from call_recipe_functions(
saved_run=sr,
+ workspace=self.current_workspace,
current_user=self.request.user,
request_model=self.RequestModel,
response_model=self.ResponseModel,
@@ -2153,7 +2155,7 @@ def run_as_api_tab(self):
with gui.tag("a", id="api-keys"):
gui.write("### 🔐 API keys")
- manage_api_keys(self.current_workspace, self.request.user)
+ manage_api_keys(workspace=self.current_workspace, user=self.request.user)
def ensure_credits_and_auto_recharge(self, sr: SavedRun, state: dict):
if not settings.CREDITS_TO_DEDUCT_PER_RUN:
diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py
index e8fb61c6b..0534f9611 100644
--- a/daras_ai_v2/bots.py
+++ b/daras_ai_v2/bots.py
@@ -1,7 +1,6 @@
import mimetypes
import typing
from datetime import datetime
-from types import SimpleNamespace
import gooey_gui as gui
from django.db import transaction
@@ -10,7 +9,6 @@
from furl import furl
from pydantic import BaseModel, Field
-from app_users.models import AppUser
from bots.models import (
Platform,
Message,
@@ -26,9 +24,10 @@
from daras_ai_v2.base import BasePage, RecipeRunState, StateKeys
from daras_ai_v2.language_model import CHATML_ROLE_USER, CHATML_ROLE_ASSISTANT
from daras_ai_v2.vector_search import doc_url_to_file_metadata
-from gooeysite.bg_db_conn import db_middleware, get_celery_result_db_safe
+from gooeysite.bg_db_conn import db_middleware
from recipes.VideoBots import VideoBotsPage, ReplyButton
from routers.api import submit_api_call
+from workspaces.models import Workspace
PAGE_NOT_CONNECTED_ERROR = (
"💔 Looks like you haven't connected this page to a gooey.ai workflow. "
@@ -87,7 +86,7 @@ class BotInterface:
page_cls: typing.Type[BasePage] = None
query_params: dict
user_language: str = None
- billing_account_uid: str
+ workspace: Workspace
show_feedback_buttons: bool = False
streaming_enabled: bool = False
input_glossary: str | None = None
@@ -131,7 +130,7 @@ def __init__(self):
elif should_translate_lang(user_language):
self.user_language = user_language
- self.billing_account_uid = self.bi.billing_account_uid
+ self.workspace = self.bi.workspace
self.show_feedback_buttons = self.bi.show_feedback_buttons
self.streaming_enabled = self.bi.streaming_enabled
@@ -252,9 +251,7 @@ def _msg_handler(bot: BotInterface):
# mark message as read
bot.mark_read()
# get the attached billing account
- billing_account_user = AppUser.objects.get_or_create_from_uid(
- bot.billing_account_uid
- )[0]
+ workspace = bot.workspace
# get the user's input
# print("input type:", bot.input_type)
input_text = (bot.get_input_text() or "").strip()
@@ -311,7 +308,7 @@ def _msg_handler(bot: BotInterface):
_handle_feedback_msg(bot, input_text)
else:
_process_and_send_msg(
- billing_account_user=billing_account_user,
+ workspace=workspace,
bot=bot,
input_images=input_images,
input_documents=input_documents,
@@ -345,7 +342,7 @@ def _handle_feedback_msg(bot: BotInterface, input_text):
def _process_and_send_msg(
*,
- billing_account_user: AppUser,
+ workspace: Workspace,
bot: BotInterface,
input_images: list[str] | None,
input_audio: str | None,
@@ -378,7 +375,7 @@ def _process_and_send_msg(
result, sr = submit_api_call(
page_cls=bot.page_cls,
query_params=bot.query_params,
- current_user=billing_account_user,
+ workspace=workspace,
request_body=body,
)
bot.on_run_created(sr)
diff --git a/daras_ai_v2/manage_api_keys_widget.py b/daras_ai_v2/manage_api_keys_widget.py
index 4b94a78e1..869d249e1 100644
--- a/daras_ai_v2/manage_api_keys_widget.py
+++ b/daras_ai_v2/manage_api_keys_widget.py
@@ -1,19 +1,10 @@
-import datetime
import typing
-from google.cloud import firestore
-
import gooey_gui as gui
+from api_keys.models import ApiKey
from app_users.models import AppUser
from daras_ai_v2 import db
-from daras_ai_v2.copy_to_clipboard_button_widget import (
- copy_to_clipboard_button,
-)
-from daras_ai_v2.crypto import (
- PBKDF2PasswordHasher,
- safe_preview,
- get_random_api_key,
-)
+from daras_ai_v2.copy_to_clipboard_button_widget import copy_to_clipboard_button
if typing.TYPE_CHECKING:
from workspaces.models import Workspace
@@ -32,17 +23,12 @@ def manage_api_keys(workspace: "Workspace", user: AppUser):
"""
)
- db_collection = db.get_client().collection(db.API_KEYS_COLLECTION)
- api_keys = _load_api_keys(db_collection, workspace)
-
+ api_keys = load_api_keys(workspace)
table_area = gui.div()
if gui.button("+ Create new secret key"):
- doc = _generate_new_key_doc()
- doc["uid"] = user.uid
- doc["workspace_id"] = workspace.id
- api_keys.append(doc)
- db_collection.add(doc)
+ api_key = generate_new_api_key(workspace=workspace, user=user)
+ api_keys.append(api_key)
with table_area:
import pandas as pd
@@ -52,8 +38,8 @@ def manage_api_keys(workspace: "Workspace", user: AppUser):
columns=["Secret Key (Preview)", "Created At"],
data=[
(
- api_key["secret_key_preview"],
- api_key["created_at"].strftime("%B %d, %Y at %I:%M:%S %p %Z"),
+ api_key.preview,
+ api_key.created_at.strftime("%B %d, %Y at %I:%M:%S %p %Z"),
)
for api_key in api_keys
],
@@ -61,30 +47,54 @@ def manage_api_keys(workspace: "Workspace", user: AppUser):
)
-def _load_api_keys(
- db_collection: firestore.CollectionReference, workspace: "Workspace"
-):
- filter = firestore.FieldFilter("workspace_id", "==", workspace.id)
- if workspace.is_personal:
- # for backwards compatibility with existing keys
- filter = firestore.Or(
- [filter, firestore.FieldFilter("uid", "==", workspace.created_by.uid)]
- )
-
- return [
- snap.to_dict()
- for snap in db_collection.where(filter=filter).order_by("created_at").get()
+def load_api_keys(workspace: "Workspace") -> list[ApiKey]:
+ db_api_keys = {api_key.hash: api_key for api_key in workspace.api_keys.all()}
+ firebase_api_keys = [
+ d
+ for d in _load_api_keys_from_firebase(workspace)
+ if d["secret_key_hash"] not in db_api_keys
]
+ if firebase_api_keys:
+ new_api_keys = [
+ ApiKey(
+ hash=d["secret_key_hash"],
+ preview=d["secret_key_preview"],
+ workspace=workspace,
+ created_by_id=workspace.created_by_id,
+ )
+ for d in firebase_api_keys
+ ]
+ # TODO: also update created_at for migrated keys
+ # migrated_api_keys = ApiKey.objects.bulk_create(
+ # new_api_keys,
+ # ignore_conflicts=True,
+ # batch_size=100,
+ # )
+ db_api_keys.update({api_key.hash: api_key for api_key in new_api_keys})
+
+ return sorted(
+ db_api_keys.values(), key=lambda api_key: api_key.created_at, reverse=True
+ )
-def _generate_new_key_doc() -> dict:
- new_api_key = get_random_api_key()
- hasher = PBKDF2PasswordHasher()
- secret_key_hash = hasher.encode(new_api_key)
- created_at = datetime.datetime.utcnow()
+def _load_api_keys_from_firebase(workspace: "Workspace"):
+ db_collection = db.get_client().collection(db.API_KEYS_COLLECTION)
+ if workspace.is_personal:
+ return [
+ snap.to_dict()
+ for snap in db_collection.where("uid", "==", workspace.created_by.uid)
+ .order_by("created_at")
+ .get()
+ ]
+ else:
+ return []
+
+
+def generate_new_api_key(workspace: "Workspace", user: AppUser) -> ApiKey:
+ api_key, secret_key = ApiKey.objects.create_api_key(workspace, created_by=user)
gui.success(
- f"""
+ """
##### API key generated
Please save this secret key somewhere safe and accessible.
@@ -98,17 +108,13 @@ def _generate_new_key_doc() -> dict:
"recipe url",
label_visibility="collapsed",
disabled=True,
- value=new_api_key,
+ value=secret_key,
)
with col2:
copy_to_clipboard_button(
"📎 Copy Secret Key",
- value=new_api_key,
+ value=secret_key,
style="height: 3.2rem",
)
- return {
- "secret_key_hash": secret_key_hash,
- "secret_key_preview": safe_preview(new_api_key),
- "created_at": created_at,
- }
+ return api_key
diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py
index 7faf6b66b..e7d4de660 100644
--- a/daras_ai_v2/safety_checker.py
+++ b/daras_ai_v2/safety_checker.py
@@ -21,16 +21,18 @@ def safety_checker(*, text: str | None = None, image: str | None = None):
def safety_checker_text(text_input: str):
- # ge the billing account for the checker
+ # get the billing account for the checker
billing_account = AppUser.objects.get_or_create_from_email(
settings.SAFETY_CHECKER_BILLING_EMAIL
)[0]
+ workspace, _ = billing_account.get_or_create_personal_workspace()
# run in a thread to avoid messing up threadlocals
result, sr = (
CompareLLMPage()
.get_pr_from_example_id(example_id=settings.SAFETY_CHECKER_EXAMPLE_ID)
.submit_api_call(
+ workspace=workspace,
current_user=billing_account,
request_body=dict(variables=dict(input=text_input)),
deduct_credits=False,
diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py
index 44768289a..381f40636 100644
--- a/daras_ai_v2/settings.py
+++ b/daras_ai_v2/settings.py
@@ -64,6 +64,7 @@
"payments",
"functions",
"workspaces",
+ "api_keys",
]
MIDDLEWARE = [
diff --git a/functions/recipe_functions.py b/functions/recipe_functions.py
index 21d3fa185..223946448 100644
--- a/functions/recipe_functions.py
+++ b/functions/recipe_functions.py
@@ -15,11 +15,13 @@
if typing.TYPE_CHECKING:
from bots.models import SavedRun
+ from workspaces.models import Workspace
def call_recipe_functions(
*,
saved_run: "SavedRun",
+ workspace: "Workspace",
current_user: AppUser,
request_model: typing.Type[BaseModel],
response_model: typing.Type[BaseModel],
@@ -48,6 +50,7 @@ def call_recipe_functions(
# run the function
page_cls, sr, pr = url_to_runs(fun.url)
result, sr = sr.submit_api_call(
+ workspace=workspace,
current_user=current_user,
parent_pr=pr,
request_body=dict(
diff --git a/recipes/BulkRunner.py b/recipes/BulkRunner.py
index 7ad9e67e9..25c99a2dc 100644
--- a/recipes/BulkRunner.py
+++ b/recipes/BulkRunner.py
@@ -318,6 +318,7 @@ def run_v2(
yield f"{progress}%"
result, sr = sr.submit_api_call(
+ workspace=self.current_workspace,
current_user=self.request.user,
request_body=request_body,
parent_pr=pr,
@@ -389,6 +390,7 @@ def run_v2(
documents=response.output_documents
).dict(exclude_unset=True)
result, sr = sr.submit_api_call(
+ workspace=self.current_workspace,
current_user=self.request.user,
request_body=request_body,
parent_pr=pr,
diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py
index 62291ddb7..d55c61427 100644
--- a/recipes/VideoBots.py
+++ b/recipes/VideoBots.py
@@ -43,7 +43,6 @@
broadcast_input,
get_bot_test_link,
web_widget_config,
- get_web_widget_embed_code,
integrations_welcome_screen,
)
from daras_ai_v2.doc_search_settings_widgets import (
@@ -1157,7 +1156,7 @@ def render_integrations_add(self, label: str, run_title: str, pr: PublishedRun):
case Platform.WEB:
bi = BotIntegration.objects.create(
name=run_title,
- billing_account_uid=self.request.user.uid,
+ workspace=self.current_workspace,
platform=Platform.WEB,
)
redirect_url = connect_bot_to_published_run(bi, pr)
diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py
index dffb79dc6..5bec1fb6d 100644
--- a/recipes/VideoBotsStats.py
+++ b/recipes/VideoBotsStats.py
@@ -16,7 +16,6 @@
from furl import furl
from pydantic import BaseModel
-from app_users.models import AppUser
from bots.models import (
Workflow,
Platform,
@@ -96,10 +95,7 @@ def show_title_breadcrumb_share(
),
)
- author = (
- AppUser.objects.filter(uid=bi.billing_account_uid).first()
- or self.request.user
- )
+ author = bi.workspace or self.request.user
VideoBotsPage.render_author(
author,
show_as_link=self.is_current_user_admin(),
diff --git a/routers/account.py b/routers/account.py
index 15d7202bb..312fef365 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -351,7 +351,7 @@ def _render_run(pr: PublishedRun):
def api_keys_tab(request: Request):
gui.write("# 🔐 API Keys")
workspace = get_current_workspace(request.user, request.session)
- manage_api_keys(workspace, request.user)
+ manage_api_keys(workspace=workspace, user=request.user)
@contextmanager
diff --git a/routers/api.py b/routers/api.py
index 102b6ea1a..a37a2899b 100644
--- a/routers/api.py
+++ b/routers/api.py
@@ -41,10 +41,12 @@
from daras_ai_v2.fastapi_tricks import fastapi_request_form
from functions.models import CalledFunctionResponse
from routers.custom_api_router import CustomAPIRouter
+from workspaces.models import Workspace
+from workspaces.widgets import set_current_workspace
if typing.TYPE_CHECKING:
- from bots.models import SavedRun
import celery.result
+ from bots.models import SavedRun
app = CustomAPIRouter()
@@ -156,13 +158,13 @@ def script_to_api(page_cls: typing.Type[BasePage]):
def run_api_json(
request: Request,
page_request: request_model,
- user: AppUser = Depends(api_auth_header),
+ workspace: Workspace = Depends(api_auth_header),
):
result, sr = submit_api_call(
page_cls=page_cls,
query_params=dict(request.query_params),
retention_policy=RetentionPolicy[page_request.settings.retention_policy],
- current_user=user,
+ workspace=workspace,
request_body=page_request.dict(exclude_unset=True),
enable_rate_limits=True,
)
@@ -180,14 +182,14 @@ def run_api_json(
)
def run_api_form(
request: Request,
- user: AppUser = Depends(api_auth_header),
+ workspace: Workspace = Depends(api_auth_header),
form_data=fastapi_request_form,
page_request_json: str = Form(alias="json"),
):
# parse form data
page_request = _parse_form_data(request_model, form_data, page_request_json)
# call regular json api
- return run_api_json(request, page_request=page_request, user=user)
+ return run_api_json(request, page_request=page_request, workspace=workspace)
endpoint = endpoint.replace("v2", "v3")
response_model = AsyncApiResponseModelV3
@@ -205,13 +207,13 @@ def run_api_json_async(
request: Request,
response: Response,
page_request: request_model,
- user: AppUser = Depends(api_auth_header),
+ workspace: Workspace = Depends(api_auth_header),
):
result, sr = submit_api_call(
page_cls=page_cls,
query_params=dict(request.query_params),
retention_policy=RetentionPolicy[page_request.settings.retention_policy],
- current_user=user,
+ workspace=workspace,
request_body=page_request.dict(exclude_unset=True),
enable_rate_limits=True,
)
@@ -232,7 +234,7 @@ def run_api_json_async(
def run_api_form_async(
request: Request,
response: Response,
- user: AppUser = Depends(api_auth_header),
+ workspace: Workspace = Depends(api_auth_header),
form_data=fastapi_request_form,
page_request_json: str = Form(alias="json"),
):
@@ -240,7 +242,10 @@ def run_api_form_async(
page_request = _parse_form_data(request_model, form_data, page_request_json)
# call regular json api
return run_api_json_async(
- request, response=response, page_request=page_request, user=user
+ request,
+ response=response,
+ page_request=page_request,
+ workspace=workspace,
)
response_model = create_model(
@@ -258,10 +263,14 @@ def run_api_form_async(
)
def get_run_status(
run_id: str,
- user: AppUser = Depends(api_auth_header),
+ workspace: Workspace = Depends(api_auth_header),
):
+ # TODO: current_user doesn't make sense for API calls
+ user = workspace.created_by
+
# init a new page for every request
self = page_cls(user=user, query_params=dict(run_id=run_id, uid=user.uid))
+ set_current_workspace(self.request.session, workspace.id)
sr = self.current_sr
web_url = str(furl(self.app_url(run_id=run_id, uid=user.uid)))
ret = {
@@ -332,15 +341,21 @@ def submit_api_call(
page_cls: typing.Type[BasePage],
query_params: dict,
retention_policy: RetentionPolicy = None,
- current_user: AppUser,
+ workspace: Workspace,
request_body: dict,
enable_rate_limits: bool = False,
deduct_credits: bool = True,
+ current_user: AppUser | None = None,
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
+ # TODO: current_user doesn't make sense for API calls
+ current_user = current_user or workspace.created_by
+
# init a new page for every request
query_params.setdefault("uid", current_user.uid)
page = page_cls(user=current_user, query_params=query_params)
+ set_current_workspace(page.request.session, workspace.id)
+
# get saved state from db
state = page.current_sr_to_session_state()
# load request data
@@ -427,8 +442,7 @@ class BalanceResponse(BaseModel):
@app.get("/v1/balance/", response_model=BalanceResponse, tags=["Misc"])
-def get_balance(user: AppUser = Depends(api_auth_header)):
- workspace = user.get_or_create_personal_workspace()[0]
+def get_balance(workspace: Workspace = Depends(api_auth_header)):
return BalanceResponse(balance=workspace.balance)
diff --git a/routers/broadcast_api.py b/routers/broadcast_api.py
index cbf1ca07c..878a9e463 100644
--- a/routers/broadcast_api.py
+++ b/routers/broadcast_api.py
@@ -5,12 +5,13 @@
from fastapi import HTTPException
from pydantic import BaseModel, Field
-from app_users.models import AppUser
from auth.token_authentication import api_auth_header
from bots.models import BotIntegration
from bots.tasks import send_broadcast_msgs_chunked
from recipes.VideoBots import ReplyButton, VideoBotsPage
from routers.custom_api_router import CustomAPIRouter
+from workspaces.models import Workspace
+
app = CustomAPIRouter()
@@ -58,18 +59,18 @@ class BotBroadcastRequestModel(BaseModel):
)
def broadcast_api_json(
bot_request: BotBroadcastRequestModel,
- user: AppUser = Depends(api_auth_header),
+ workspace: Workspace = Depends(api_auth_header),
example_id: str | None = None,
run_id: str | None = None,
):
- bi_qs = BotIntegration.objects.filter(billing_account_uid=user.uid)
+ bi_qs = BotIntegration.objects.filter(workspace=workspace)
if example_id:
bi_qs = bi_qs.filter(
Q(published_run__published_run_id=example_id)
| Q(saved_run__example_id=example_id)
)
elif run_id:
- bi_qs = bi_qs.filter(saved_run__run_id=run_id, saved_run__uid=user.uid)
+ bi_qs = bi_qs.filter(saved_run__run_id=run_id, saved_run__workspace=workspace)
else:
return HTTPException(
status_code=400,
diff --git a/routers/root.py b/routers/root.py
index b3e9faa98..d9f1e74c8 100644
--- a/routers/root.py
+++ b/routers/root.py
@@ -351,7 +351,7 @@ def _api_docs_page(request: Request):
)
return
- manage_api_keys(page.current_workspace, page.request.user)
+ manage_api_keys(workspace=page.current_workspace, user=page.request.user)
@gui.route(
diff --git a/routers/twilio_api.py b/routers/twilio_api.py
index 108f3e47b..be1405d10 100644
--- a/routers/twilio_api.py
+++ b/routers/twilio_api.py
@@ -263,7 +263,7 @@ def resp_say_or_tts_play(
{**bot.saved_run.state, "text_prompt": text}
).dict()
result, sr = TextToSpeechPage.get_root_pr().submit_api_call(
- current_user=AppUser.objects.get(uid=bot.billing_account_uid),
+ workspace=bot.workspace,
request_body=tts_state,
)
# wait for the TTS to finish
diff --git a/scripts/create_fixture.py b/scripts/create_fixture.py
index f660febf5..0f217e6eb 100644
--- a/scripts/create_fixture.py
+++ b/scripts/create_fixture.py
@@ -55,10 +55,15 @@ def get_objects(*args):
if "bots" not in args:
return
for obj in BotIntegration.objects.all():
+ # TODO: deprecate billing_account_uid
user = AppUser.objects.get(uid=obj.billing_account_uid)
yield user.handle
yield export(user, only_include=USER_FIELDS)
+ if obj.workspace:
+ yield export(obj.workspace.created_by, only_include=USER_FIELDS)
+ yield export(obj.workspace, include_fks=("created_by",))
+
if obj.saved_run_id:
yield export(obj.saved_run)
diff --git a/scripts/migrate_workspaces.py b/scripts/migrate_workspaces.py
index ba2cee100..ea7b490c7 100644
--- a/scripts/migrate_workspaces.py
+++ b/scripts/migrate_workspaces.py
@@ -1,13 +1,18 @@
+from contextlib import contextmanager
+from itertools import islice
from time import sleep
from django.db import transaction
from django.db.models import OuterRef, Subquery
+from api_keys.models import ApiKey
from app_users.models import AppUser, AppUserTransaction
-from bots.models import SavedRun, PublishedRun
+from bots.models import BotIntegration, SavedRun, PublishedRun
+from daras_ai_v2 import db
from workspaces.models import Workspace, WorkspaceMembership, WorkspaceRole
BATCH_SIZE = 10_000
+FIREBASE_BATCH_SIZE = 100
DELAY = 0.1
SEP = " ... "
@@ -16,7 +21,9 @@ def run():
migrate_personal_workspaces()
migrate_txns()
migrate_saved_runs()
+ migrate_api_keys()
migrate_published_runs()
+ migrate_bot_integrations()
@transaction.atomic
@@ -90,6 +97,62 @@ def migrate_published_runs():
)
+def migrate_bot_integrations():
+ qs = BotIntegration.objects.filter(
+ workspace__isnull=True,
+ billing_account_uid__isnull=False,
+ )
+ print(f"migrating {qs.count()} bot integrations", end=SEP)
+ update_in_batches(
+ qs,
+ workspace_id=Workspace.objects.filter(
+ is_personal=True,
+ created_by__uid=OuterRef("billing_account_uid"),
+ ).values("id")[:1],
+ )
+
+
+def migrate_api_keys():
+ firebase_stream = db.get_client().collection(db.API_KEYS_COLLECTION).stream()
+
+ total = 0
+ while True:
+ batch = list(islice(firebase_stream, FIREBASE_BATCH_SIZE))
+ if not batch:
+ print("Done!")
+ break
+ cached_workspaces = Workspace.objects.select_related("created_by").filter(
+ is_personal=True,
+ created_by__uid__in=[snap.get("uid") for snap in batch],
+ )
+ cached_workspaces_by_uid = {w.created_by.uid: w for w in cached_workspaces}
+
+ with (
+ disable_auto_now_add(ApiKey, "created_at"),
+ disable_auto_now(ApiKey, "updated_at"),
+ ):
+ migrated_keys = ApiKey.objects.bulk_create(
+ [
+ ApiKey(
+ hash=snap.get("secret_key_hash"),
+ preview=snap.get("secret_key_preview"),
+ workspace_id=cached_workspaces_by_uid[snap.get("uid")].id,
+ created_by_id=cached_workspaces_by_uid[
+ snap.get("uid")
+ ].created_by.id,
+ created_at=snap.get("created_at"),
+ updated_at=snap.get("created_at"),
+ )
+ for snap in batch
+ if snap.get("uid") in cached_workspaces_by_uid
+ ],
+ ignore_conflicts=True,
+ unique_fields=("hash",),
+ )
+ print(total, f"({len(migrated_keys)}/{len(batch)})", end=SEP)
+ total += len(migrated_keys)
+
+
def update_in_batches(qs, **kwargs):
total = 0
while True:
@@ -100,3 +163,23 @@ def update_in_batches(qs, **kwargs):
break
total += rows
print(total, end=SEP)
+
+
+@contextmanager
+def disable_auto_now(model, field_name):
+ for field in model._meta.local_fields:
+ if field.name == field_name:
+ field.auto_now = False
+ yield
+ field.auto_now = True
+ break
+
+
+@contextmanager
+def disable_auto_now_add(model, field_name):
+ for field in model._meta.local_fields:
+ if field.name == field_name:
+ field.auto_now_add = False
+ yield
+ field.auto_now_add = True
+ break
diff --git a/workspaces/models.py b/workspaces/models.py
index c19ac07df..eac7bb803 100644
--- a/workspaces/models.py
+++ b/workspaces/models.py
@@ -150,6 +150,11 @@ def __str__(self):
else:
return self.display_name()
+ def clean(self) -> None:
+ if not self.is_personal and not self.name:
+ raise ValidationError("Team name is required for workspaces")
+ return super().clean()
+
def get_slug(self):
return slugify(self.display_name())
@@ -321,7 +326,7 @@ def display_name(self, current_user: AppUser | None = None) -> str:
elif self.is_personal:
return self.created_by.full_name()
else:
- return f"{self.created_by.first_name_possesive()} Workspace"
+ return f"{self.created_by.first_name_possessive()} Workspace"
def html_icon(self, current_user: AppUser | None = None) -> str:
if self.is_personal and self.created_by_id == current_user.id:
From 26692c8106786240c15fef0a8c60f22c7e7535bf Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 25 Sep 2024 18:42:34 +0530
Subject: [PATCH 26/81] fix black formatting
---
recipes/TextToSpeech.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py
index 5f0011c45..6b812ecaa 100644
--- a/recipes/TextToSpeech.py
+++ b/recipes/TextToSpeech.py
@@ -61,6 +61,7 @@ class TextToSpeechSettings(BaseModel):
openai_tts_model: OpenAI_TTS_Models.api_choices | None
ghana_nlp_tts_language: GHANA_NLP_TTS_LANGUAGES.api_choices | None
+
class TextToSpeechPage(BasePage):
title = "Compare AI Voice Generators"
explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/3621e11a-88d9-11ee-b549-02420a000167/Compare%20AI%20voice%20generators.png.png"
From 95efc507795f825c0b35653cd32719f71d6368ac Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 26 Sep 2024 00:41:04 +0530
Subject: [PATCH 27/81] Add migration for ApiKey related_name for workspaces
---
.../migrations/0002_alter_apikey_workspace.py | 20 +++++++++++++++++++
1 file changed, 20 insertions(+)
create mode 100644 api_keys/migrations/0002_alter_apikey_workspace.py
diff --git a/api_keys/migrations/0002_alter_apikey_workspace.py b/api_keys/migrations/0002_alter_apikey_workspace.py
new file mode 100644
index 000000000..647e570e6
--- /dev/null
+++ b/api_keys/migrations/0002_alter_apikey_workspace.py
@@ -0,0 +1,20 @@
+# Generated by Django 5.1.1 on 2024-09-25 19:10
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('api_keys', '0001_initial'),
+ ('workspaces', '0002_alter_workspace_domain_name'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='apikey',
+ name='workspace',
+ field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='api_keys', to='workspaces.workspace'),
+ ),
+ ]
From 06ff80a7eda2d6e9fb3d2e80b7ffacae27f71b5c Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 26 Sep 2024 00:43:26 +0530
Subject: [PATCH 28/81] Add change notes to published run version & UI
---
.../0085_publishedrunversion_change_notes.py | 18 ++++++++++++++++++
bots/models.py | 4 ++++
daras_ai_v2/base.py | 12 ++++++++++++
3 files changed, 34 insertions(+)
create mode 100644 bots/migrations/0085_publishedrunversion_change_notes.py
diff --git a/bots/migrations/0085_publishedrunversion_change_notes.py b/bots/migrations/0085_publishedrunversion_change_notes.py
new file mode 100644
index 000000000..c37186022
--- /dev/null
+++ b/bots/migrations/0085_publishedrunversion_change_notes.py
@@ -0,0 +1,18 @@
+# Generated by Django 5.1.1 on 2024-09-25 19:12
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('bots', '0084_botintegration_workspace'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='publishedrunversion',
+ name='change_notes',
+ field=models.TextField(blank=True, default=''),
+ ),
+ ]
diff --git a/bots/models.py b/bots/models.py
index da7cf8a8b..0598af353 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -1683,6 +1683,7 @@ def create_with_version(
title=title,
visibility=visibility,
notes=notes,
+ change_notes="First Version",
)
return pr
@@ -1817,6 +1818,7 @@ def add_version(
visibility: PublishedRunVisibility,
title: str,
notes: str,
+ change_notes: str,
):
assert saved_run.workflow == self.workflow
@@ -1829,6 +1831,7 @@ def add_version(
title=title,
notes=notes,
visibility=visibility,
+ change_notes=change_notes,
)
version.save()
self.update_fields_to_latest_version()
@@ -1896,6 +1899,7 @@ class PublishedRunVersion(models.Model):
)
title = models.TextField(blank=True, default="")
notes = models.TextField(blank=True, default="")
+ change_notes = models.TextField(blank=True, default="")
visibility = models.IntegerField(
choices=PublishedRunVisibility.choices,
default=PublishedRunVisibility.UNLISTED,
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 063c2385d..fe16d238b 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -604,6 +604,11 @@ def _render_publish_form(
key="published_run_notes",
value=(pr.notes or self.preview_description(gui.session_state) or ""),
)
+ change_notes = gui.text_area(
+ "###### Change Notes",
+ key="published_run_change_notes",
+ value="",
+ )
col1, col2 = gui.columns([1, 3])
with col1, gui.div(className="mt-2"):
@@ -655,6 +660,7 @@ def _render_publish_form(
title=published_run_title.strip(),
notes=published_run_notes.strip(),
visibility=published_run_visibility,
+ change_notes=change_notes.strip(),
)
if not self._has_published_run_changed(published_run=pr, **updates):
gui.error("No changes to publish", icon="⚠️")
@@ -881,6 +887,11 @@ def _render_admin_options(self, current_run: SavedRun, published_run: PublishedR
f"This will overwrite the contents of {self.app_url()}",
className="text-danger",
)
+ change_notes = gui.text_area(
+ "Change Notes",
+ key="change_notes",
+ value="",
+ )
if gui.button("👌 Yes, Update the Root Workflow"):
root_run = self.get_root_pr()
root_run.add_version(
@@ -889,6 +900,7 @@ def _render_admin_options(self, current_run: SavedRun, published_run: PublishedR
notes=published_run.notes,
saved_run=published_run.saved_run,
visibility=PublishedRunVisibility.PUBLIC,
+ change_notes=change_notes,
)
raise gui.RedirectException(self.app_url())
From 474ce997be7cb5a1594f60a5af218ce622d6c0e9 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 26 Sep 2024 00:45:54 +0530
Subject: [PATCH 29/81] Remove change_notes from _has_published_run_changed
---
daras_ai_v2/base.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index fe16d238b..6357a43d6 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -660,12 +660,13 @@ def _render_publish_form(
title=published_run_title.strip(),
notes=published_run_notes.strip(),
visibility=published_run_visibility,
- change_notes=change_notes.strip(),
)
if not self._has_published_run_changed(published_run=pr, **updates):
gui.error("No changes to publish", icon="⚠️")
return
- pr.add_version(user=self.request.user, **updates)
+ pr.add_version(
+ user=self.request.user, change_notes=change_notes.strip(), **updates
+ )
else:
pr = self.create_published_run(
published_run_id=get_random_doc_id(),
From 0275214c4685e87d4ffddfbfc311228a3e4ab6fa Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 26 Sep 2024 01:16:41 +0530
Subject: [PATCH 30/81] Redo rendering of version history with change notes
---
daras_ai_v2/base.py | 47 ++++++++++++++++++++++----------------------
daras_ai_v2/icons.py | 1 +
2 files changed, 25 insertions(+), 23 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 6357a43d6..5a830fc6c 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -788,7 +788,11 @@ def _saved_options_modal(self, *, sr: SavedRun, pr: PublishedRun):
)
with gui.div(className="mt-4"):
- gui.write("#### Version History", className="mb-4")
+ gui.write(
+ f"#### {icons.time} Version History",
+ className="mb-4 fw-bold",
+ unsafe_allow_html=True,
+ )
self._render_version_history()
def _unsaved_options_button_with_dialog(self):
@@ -992,20 +996,16 @@ def _render_version_row(
run_id=version.saved_run.run_id,
uid=version.saved_run.uid,
)
- with gui.link(to=url, className="text-decoration-none"):
+ with gui.link(to=url, className="d-block text-decoration-none my-3"):
with gui.div(
- className="d-flex mb-4 disable-p-margin",
- style={"minWidth": "min(100vw, 500px)"},
+ className="d-flex justify-content-between align-items-middle fw-bold"
):
- col1 = gui.div(className="me-4")
- col2 = gui.div()
- with col1:
- with gui.div(className="fs-5 mt-1"):
- gui.html('')
- with col2:
- is_first_version = not older_version
- with gui.div(className="fs-5 d-flex align-items-center"):
- with gui.tag("span"):
+ if version.changed_by:
+ with gui.tag("h6", className="mb-0"):
+ self.render_author(version.changed_by, responsive=False)
+ else:
+ gui.write("###### Deleted User", className="disable-p-margin")
+ with gui.tag("h6", className="mb-0"):
gui.html(
"Loading...",
**render_local_dt_attrs(
@@ -1013,18 +1013,19 @@ def _render_version_row(
date_options={"month": "short", "day": "numeric"},
),
)
+ with gui.div(className="disable-p-margin"):
+ is_first_version = not older_version
if is_first_version:
- with gui.tag("span", className="badge bg-secondary px-3 ms-2"):
+ with gui.tag("span", className="badge bg-secondary px-3"):
gui.write("FIRST VERSION")
- with gui.div(className="text-muted"):
- if older_version and older_version.title != version.title:
- gui.write(f"Renamed: {version.title}")
- elif not older_version:
- gui.write(version.title)
- with gui.div(className="mt-1", style={"fontSize": "0.85rem"}):
- self.render_author(
- version.changed_by, image_size="18px", responsive=False
- )
+ elif older_version and older_version.title != version.title:
+ gui.caption(f"Renamed: {version.title}")
+
+ if version.change_notes:
+ gui.caption(
+ f"{icons.notes} {html.escape(version.change_notes)}",
+ unsafe_allow_html=True,
+ )
def render_related_workflows(self):
page_clses = self.related_workflows()
diff --git a/daras_ai_v2/icons.py b/daras_ai_v2/icons.py
index 3b3cd7823..439dd30d6 100644
--- a/daras_ai_v2/icons.py
+++ b/daras_ai_v2/icons.py
@@ -26,6 +26,7 @@
remove_user = ''
add_user = ''
home = ''
+notes = ''
# brands
github = ''
From f7f50d7fa4ca114cb9ca3c69a88168fe0dbc038b Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 26 Sep 2024 01:34:41 +0530
Subject: [PATCH 31/81] Add change notes in publish menu
---
daras_ai_v2/base.py | 22 ++++++++++++++--------
1 file changed, 14 insertions(+), 8 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 5a830fc6c..0544f1d2d 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -588,7 +588,7 @@ def _render_publish_form(
checked_by_default=False,
)
- with gui.div(className="mt-4"):
+ with gui.div(className="my-4"):
if is_update_mode:
title = pr.title or self.title
else:
@@ -599,16 +599,22 @@ def _render_publish_form(
key="published_run_title",
value=title,
)
- published_run_notes = gui.text_area(
- "###### Notes",
+ published_run_notes = gui.text_input(
+ "###### Description",
key="published_run_notes",
value=(pr.notes or self.preview_description(gui.session_state) or ""),
)
- change_notes = gui.text_area(
- "###### Change Notes",
- key="published_run_change_notes",
- value="",
- )
+ with gui.div(className="d-flex align-items-start"):
+ with gui.div(className="fs-2 text-muted"):
+ gui.html(icons.notes)
+ with gui.div(className="flex-grow-1"):
+ change_notes = gui.text_input(
+ "",
+ key="published_run_change_notes",
+ value="",
+ className="ms-2",
+ placeholder="Add change notes",
+ )
col1, col2 = gui.columns([1, 3])
with col1, gui.div(className="mt-2"):
From 76658a5fafcf89c1ebdd27bf284897cb954a3916 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 3 Oct 2024 19:13:32 +0530
Subject: [PATCH 32/81] fix api keys widget to not perform key migration from
firebase
---
daras_ai_v2/manage_api_keys_widget.py | 39 ++++++++++-----------------
1 file changed, 14 insertions(+), 25 deletions(-)
diff --git a/daras_ai_v2/manage_api_keys_widget.py b/daras_ai_v2/manage_api_keys_widget.py
index 869d249e1..d190522e6 100644
--- a/daras_ai_v2/manage_api_keys_widget.py
+++ b/daras_ai_v2/manage_api_keys_widget.py
@@ -48,36 +48,25 @@ def manage_api_keys(workspace: "Workspace", user: AppUser):
def load_api_keys(workspace: "Workspace") -> list[ApiKey]:
- db_api_keys = {api_key.hash: api_key for api_key in workspace.api_keys.all()}
- firebase_api_keys = [
- d
- for d in _load_api_keys_from_firebase(workspace)
- if d["secret_key_hash"] not in db_api_keys
- ]
- if firebase_api_keys:
- new_api_keys = [
- ApiKey(
- hash=d["secret_key_hash"],
- preview=d["secret_key_preview"],
- workspace=workspace,
- created_by_id=workspace.created_by_id,
- )
- for d in firebase_api_keys
- ]
- # TODO: also update created_at for migrated keys
- # migrated_api_keys = ApiKey.objects.bulk_create(
- # new_api_keys,
- # ignore_conflicts=True,
- # batch_size=100,
- # )
- db_api_keys.update({api_key.hash: api_key for api_key in new_api_keys})
+ api_keys = {api_key.hash: api_key for api_key in workspace.api_keys.all()}
+ for legacy_key in _load_api_keys_from_firebase(workspace):
+ hash = legacy_key["secret_key_hash"]
+ api_keys[hash] = ApiKey(
+ workspace=workspace,
+ hash=hash,
+ preview=legacy_key["secret_key_preview"],
+ created_at=legacy_key["created_at"],
+ created_by_id=workspace.created_by_id,
+ )
return sorted(
- db_api_keys.values(), key=lambda api_key: api_key.created_at, reverse=True
+ api_keys.values(),
+ key=lambda api_key: api_key.created_at,
+ reverse=True,
)
-def _load_api_keys_from_firebase(workspace: "Workspace"):
+def _load_api_keys_from_firebase(workspace: "Workspace") -> list[dict]:
db_collection = db.get_client().collection(db.API_KEYS_COLLECTION)
if workspace.is_personal:
return [
From 9817033faf85455618500fc215747d6c61b62b56 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 3 Oct 2024 19:14:14 +0530
Subject: [PATCH 33/81] revert API key migration from firebase->DB in API calls
only do it as a django script
---
auth/token_authentication.py | 23 +++--------------------
1 file changed, 3 insertions(+), 20 deletions(-)
diff --git a/auth/token_authentication.py b/auth/token_authentication.py
index 19286cdc3..d2c273cfc 100644
--- a/auth/token_authentication.py
+++ b/auth/token_authentication.py
@@ -4,7 +4,6 @@
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel, SecuritySchemeType
from fastapi.security.base import SecurityBase
-from loguru import logger
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from api_keys.models import ApiKey
@@ -29,6 +28,7 @@ def __init__(self, msg: str):
def _authenticate_credentials_from_firebase(token: str) -> Workspace | None:
+ """Legacy method to authenticate API keys stored in Firebase."""
hasher = PBKDF2PasswordHasher()
hash = hasher.encode(token)
@@ -38,17 +38,8 @@ def _authenticate_credentials_from_firebase(token: str) -> Workspace | None:
except IndexError:
return None
- try:
- workspace_id = doc.get("workspace_id")
- except KeyError:
- uid = doc.get("uid")
- return Workspace.objects.get_or_create_from_uid(uid)[0]
-
- try:
- return Workspace.objects.get(id=workspace_id)
- except Workspace.DoesNotExist:
- logger.warning(f"Workspace {workspace_id} not found (for API key {doc.id=}).")
- return None
+ uid = doc.get("uid")
+ return Workspace.objects.get_or_create_from_uid(uid)[0]
def authenticate_credentials(token: str) -> Workspace:
@@ -58,14 +49,6 @@ def authenticate_credentials(token: str) -> Workspace:
if not workspace:
raise AuthorizationError("Invalid API key.")
- # firebase was used for API Keys before team workspaces, so we
- # can assume that api_key.created_by_id = workspace.created_by
- api_key = ApiKey.objects.create_from_secret_key(
- token,
- workspace=workspace,
- created_by_id=workspace.created_by_id,
- )
-
workspace = api_key.workspace
if workspace.is_personal and workspace.created_by.is_disabled:
msg = (
From 4ca72698fa17a401043d2740e4705acbc9a3995e Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 7 Oct 2024 19:33:37 +0530
Subject: [PATCH 34/81] fix badly resolved conflicts
---
daras_ai_v2/base.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 12a8638a7..4292f6392 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -518,7 +518,7 @@ def _render_share_modal(self, dialog: gui.AlertDialogRef):
str(PublishedRunVisibility.PUBLIC.value)
] += f' on [{pretty_profile_url}]({profile_url})'
elif self.request.user and not self.request.user.is_anonymous:
- edit_profile_url = AccountTabs.profile.url_path
+ edit_profile_url = AccountTabs.profile.get_url_path(self.request)
options[
str(PublishedRunVisibility.PUBLIC.value)
] += f' on my [profile page]({edit_profile_url})'
@@ -558,12 +558,14 @@ def _render_share_modal(self, dialog: gui.AlertDialogRef):
)
if pressed_copy or pressed_done:
if self.current_pr.visibility != published_run_visibility:
+ visibility = PublishedRunVisibility(published_run_visibility)
self.current_pr.add_version(
user=self.request.user,
saved_run=self.current_pr.saved_run,
title=self.current_pr.title,
notes=self.current_pr.notes,
- visibility=PublishedRunVisibility(published_run_visibility),
+ visibility=visibility,
+ change_notes=f"Visibility changed to {visibility.name.title()}",
)
dialog.set_open(False)
From 0c66422103a59e1add31cf26bf88e3a85fb0407d Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 7 Oct 2024 19:34:36 +0530
Subject: [PATCH 35/81] make change notes icon fs-3 in publish form
---
daras_ai_v2/base.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 4292f6392..d02b4da8f 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -673,8 +673,8 @@ def _render_publish_form(
key="published_run_notes",
value=(pr.notes or self.preview_description(gui.session_state) or ""),
)
- with gui.div(className="d-flex align-items-start"):
- with gui.div(className="fs-2 text-muted"):
+ with gui.div(className="d-flex align-items-center"):
+ with gui.div(className="fs-3 text-muted mb-3"):
gui.html(icons.notes)
with gui.div(className="flex-grow-1"):
change_notes = gui.text_input(
From 32315c715157599fa34c1033f1863baefd282170 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 7 Oct 2024 20:14:50 +0530
Subject: [PATCH 36/81] saved options menu: show duplicate and delete buttons
on the same line
---
daras_ai_v2/base.py | 59 ++++++++++++++++++++++++---------------------
1 file changed, 31 insertions(+), 28 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index d02b4da8f..0c0c91245 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -825,35 +825,38 @@ def _saved_options_modal(self):
is_latest_version = self.current_pr.saved_run == self.current_sr
- duplicate_button = None
- save_as_new_button = None
- if is_latest_version:
- duplicate_button = gui.button(f"{icons.fork} Duplicate", className="w-100")
- else:
- save_as_new_button = gui.button(
- f"{icons.fork} Save as New", className="w-100"
- )
-
- if not self.current_pr.is_root():
- ref = gui.use_confirm_dialog(key="--delete-run-modal")
- gui.button_with_confirm_dialog(
- ref=ref,
- trigger_label=' Delete',
- trigger_className="w-100 text-danger",
- modal_title="#### Are you sure?",
- modal_content=f"""
-Are you sure you want to delete this published run?
-
-**{self.current_pr.title}**
+ with gui.div(className="mb-3 d-flex justify-content-around align-items-center"):
+ duplicate_button = None
+ save_as_new_button = None
+ if is_latest_version:
+ duplicate_button = gui.button(
+ f"{icons.fork} Duplicate", className="w-100"
+ )
+ else:
+ save_as_new_button = gui.button(
+ f"{icons.fork} Save as New", className="w-100"
+ )
-This will also delete all the associated versions.
- """,
- confirm_label="Delete",
- confirm_className="border-danger bg-danger text-white",
- )
- if ref.pressed_confirm:
- self.current_pr.delete()
- raise gui.RedirectException(self.app_url())
+ if not self.current_pr.is_root():
+ ref = gui.use_confirm_dialog(key="--delete-run-modal")
+ gui.button_with_confirm_dialog(
+ ref=ref,
+ trigger_label=' Delete',
+ trigger_className="w-100 text-danger",
+ modal_title="#### Are you sure?",
+ modal_content=f"""
+ Are you sure you want to delete this published run?
+
+ **{self.current_pr.title}**
+
+ This will also delete all the associated versions.
+ """,
+ confirm_label="Delete",
+ confirm_className="border-danger bg-danger text-white",
+ )
+ if ref.pressed_confirm:
+ self.current_pr.delete()
+ raise gui.RedirectException(self.app_url())
if duplicate_button:
duplicate_pr = self.current_pr.duplicate(
From 5ed9c59c01d42bfb7a534abdce98a1be8f482492 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 7 Oct 2024 20:16:46 +0530
Subject: [PATCH 37/81] only show one of FIRST VERSION tag, change notes, or
Renamed: ... text in version history
---
daras_ai_v2/base.py | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 0c0c91245..66ed7ebea 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -1091,14 +1091,13 @@ def _render_version_row(
if is_first_version:
with gui.tag("span", className="badge bg-secondary px-3"):
gui.write("FIRST VERSION")
- elif older_version and older_version.title != version.title:
- gui.caption(f"Renamed: {version.title}")
-
- if version.change_notes:
+ elif version.change_notes:
gui.caption(
f"{icons.notes} {html.escape(version.change_notes)}",
unsafe_allow_html=True,
)
+ elif older_version and older_version.title != version.title:
+ gui.caption(f"Renamed: {version.title}")
def render_related_workflows(self):
page_clses = self.related_workflows()
From 208ea2d376f75b79a80302fe37380b46e103b7b3 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 7 Oct 2024 20:18:14 +0530
Subject: [PATCH 38/81] fix: use user.full_name() instead of user.display_name
for version history
---
daras_ai_v2/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 66ed7ebea..1534d9da9 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -1387,7 +1387,7 @@ def render_author(
else:
user = workspace_or_user
photo = user.photo_url
- name = user.display_name
+ name = user.full_name()
if show_as_link and user.handle:
link = user.handle.get_app_url()
From 95ac836af74f9ca51c5b5e625d76fcd9381c7404 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Tue, 8 Oct 2024 18:58:02 +0530
Subject: [PATCH 39/81] only show pending list to admins
---
workspaces/views.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/workspaces/views.py b/workspaces/views.py
index b8de45d5c..e75182c6e 100644
--- a/workspaces/views.py
+++ b/workspaces/views.py
@@ -147,7 +147,8 @@ def render_workspace_by_membership(membership: WorkspaceMembership):
gui.newline()
- render_pending_invites_list(workspace=workspace, current_member=membership)
+ if membership.can_invite():
+ render_pending_invites_list(workspace=workspace, current_member=membership)
can_leave = membership.can_leave_workspace()
if not can_leave:
From 15169ed36994e2db38da9e12e6df6267e8bd7ff0 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 10 Oct 2024 10:55:41 +0530
Subject: [PATCH 40/81] fix: recipe page breaks for anon users
---
daras_ai_v2/base.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 1534d9da9..fdaafc3b6 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -446,9 +446,9 @@ def can_user_save_run(
)
def can_user_edit_published_run(self, published_run: PublishedRun) -> bool:
- return (
+ return bool(self.request.user) and (
self.is_current_user_admin()
- or published_run.workspace == self.current_workspace
+ or published_run.workspace_id == self.current_workspace.id
)
def _render_title(self, title: str):
From cab2a19049b0fcdca9d42c9217f770eca9de532e Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 10 Oct 2024 18:41:25 +0530
Subject: [PATCH 41/81] rename: published_run_notes ->
published_run_description
---
daras_ai_v2/base.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 14c1587f0..00d2cc952 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -668,9 +668,9 @@ def _render_publish_form(
key="published_run_title",
value=title,
)
- published_run_notes = gui.text_input(
+ published_run_description = gui.text_input(
"###### Description",
- key="published_run_notes",
+ key="published_run_description",
value=(pr.notes or self.preview_description(gui.session_state) or ""),
)
with gui.div(className="d-flex align-items-center"):
@@ -751,14 +751,14 @@ def _render_publish_form(
user=self.request.user,
workspace=self.current_workspace,
title=published_run_title.strip(),
- notes=published_run_notes.strip(),
+ notes=published_run_description.strip(),
visibility=PublishedRunVisibility(pr.visibility),
)
else:
updates = dict(
saved_run=sr,
title=published_run_title.strip(),
- notes=published_run_notes.strip(),
+ notes=published_run_description.strip(),
visibility=published_run_visibility,
)
if not self._has_published_run_changed(published_run=pr, **updates):
From 95c7fc7d3f0146db06b9c11f1c2b5a75e9823f03 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 10 Oct 2024 18:41:51 +0530
Subject: [PATCH 42/81] fix: visibility for duplicate runs should be UNLISTED
---
daras_ai_v2/base.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 00d2cc952..70361db4c 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -759,7 +759,7 @@ def _render_publish_form(
saved_run=sr,
title=published_run_title.strip(),
notes=published_run_description.strip(),
- visibility=published_run_visibility,
+ visibility=PublishedRunVisibility.UNLISTED,
)
if not self._has_published_run_changed(published_run=pr, **updates):
gui.error("No changes to publish", icon="⚠️")
@@ -1782,6 +1782,7 @@ def publish_and_redirect(self) -> typing.NoReturn | None:
published_run_id=get_random_doc_id(),
saved_run=self.current_sr,
user=self.request.user,
+ workspace=self.current_workspace,
title=self._get_default_pr_title(),
notes=self.current_pr.notes,
visibility=PublishedRunVisibility(PublishedRunVisibility.UNLISTED),
From 9b8a0e5cbb003385c960c35cc3b3abe4815f5b2f Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 10 Oct 2024 18:42:45 +0530
Subject: [PATCH 43/81] fix padding in publish dialog
---
daras_ai_v2/base.py | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 70361db4c..f34492ae2 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -658,7 +658,7 @@ def _render_publish_form(
f'If you want to create a new example, press "{icons.fork} Save as New".'
)
- with gui.div(className="my-4"):
+ with gui.div():
if is_update_mode:
title = pr.title or self.title
else:
@@ -674,19 +674,18 @@ def _render_publish_form(
value=(pr.notes or self.preview_description(gui.session_state) or ""),
)
with gui.div(className="d-flex align-items-center"):
- with gui.div(className="fs-3 text-muted mb-3"):
+ with gui.div(className="fs-3 text-muted mb-3 me-2"):
gui.html(icons.notes)
with gui.div(className="flex-grow-1"):
change_notes = gui.text_input(
"",
key="published_run_change_notes",
value="",
- className="ms-2",
placeholder="Add change notes",
)
col1, col2 = gui.columns([1, 3])
- with col1, gui.div(className="mt-2"):
+ with col1:
gui.write("###### Workspace")
with col2:
if self.request.user and self.request.user.get_workspaces().count() > 1:
From 563e0811ad921f801486a935dc6b58a676266dec Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 10 Oct 2024 19:40:58 +0530
Subject: [PATCH 44/81] fix s/workspace.logo/workspace.photo_url
---
daras_ai_v2/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index f34492ae2..cd78a75ff 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -1377,7 +1377,7 @@ def render_author(
link = None
if isinstance(workspace_or_user, Workspace):
workspace = workspace_or_user
- photo = workspace.logo
+ photo = workspace.photo_url
if not photo and workspace.is_personal:
photo = workspace.created_by.photo_url
name = workspace.display_name(current_user=current_user)
From 68a9bd41edeac932d9fd0b0804663674c99b4425 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 14 Oct 2024 17:18:39 +0530
Subject: [PATCH 45/81] rename personal->Personal
---
workspaces/models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/workspaces/models.py b/workspaces/models.py
index 38d945b75..85de756ef 100644
--- a/workspaces/models.py
+++ b/workspaces/models.py
@@ -322,7 +322,7 @@ def display_name(self, current_user: AppUser | None = None) -> str:
elif (
self.is_personal and current_user and self.created_by_id == current_user.id
):
- return f"{current_user.full_name()} (personal)"
+ return f"{current_user.full_name()} (Personal)"
elif self.is_personal:
return self.created_by.full_name()
else:
From c0277e7dcbf1de8139f2d2410616fe700e7f958e Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 16 Oct 2024 20:58:03 +0530
Subject: [PATCH 46/81] commit everything
---
bots/models.py | 35 ++++++-
daras_ai_v2/base.py | 231 +++++++++++++++++++++++++++---------------
daras_ai_v2/icons.py | 4 +
routers/account.py | 6 +-
workspaces/models.py | 21 +++-
workspaces/views.py | 38 ++++---
workspaces/widgets.py | 4 +-
7 files changed, 238 insertions(+), 101 deletions(-)
diff --git a/bots/models.py b/bots/models.py
index 926ca264e..4b63a1de5 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -15,12 +15,14 @@
from app_users.models import AppUser
from bots.admin_links import open_in_new_tab
from bots.custom_fields import PostgresJSONEncoder, CustomURLField
-from daras_ai_v2 import icons
+from daras_ai_v2 import icons, urls
from daras_ai_v2.crypto import get_random_doc_id
+from daras_ai_v2.fastapi_tricks import get_route_path
from daras_ai_v2.language_model import format_chat_entry
from functions.models import CalledFunctionResponse
from gooeysite.bg_db_conn import get_celery_result_db_safe
from gooeysite.custom_create import get_or_create_lazy
+from workspaces.widgets import get_route_path_for_workspace
if typing.TYPE_CHECKING:
import celery.result
@@ -36,13 +38,40 @@
class PublishedRunVisibility(models.IntegerChoices):
UNLISTED = 1
PUBLIC = 2
+ INTERNAL = 3
+
+ @classmethod
+ def for_workspace(
+ cls, workspace: "Workspace"
+ ) -> typing.Iterable["PublishedRunVisibility"]:
+ if workspace.is_personal:
+ return [cls.UNLISTED, cls.PUBLIC]
+ else:
+ # TODO: Add cls.PUBLIC when team-handles are added
+ return [cls.UNLISTED, cls.INTERNAL]
+
+ def help_text(self, workspace: "Workspace | None" = None):
+ from routers.account import profile_route, saved_route
- def help_text(self):
match self:
case PublishedRunVisibility.UNLISTED:
return f"{self.get_icon()} Only me + people with a link"
+ case PublishedRunVisibility.PUBLIC if workspace and workspace.is_personal:
+ user = workspace.created_by
+ if user.handle:
+ profile_url = user.handle.get_app_url()
+ pretty_profile_url = urls.remove_scheme(profile_url).rstrip("/")
+ return f'{self.get_icon()} Public on {profile_url}'
+ else:
+ edit_profile_url = get_route_path(profile_route)
+ return f'{self.get_icon()} Public on my profile page'
case PublishedRunVisibility.PUBLIC:
return f"{self.get_icon()} Public"
+ case PublishedRunVisibility.INTERNAL if workspace:
+ saved_route_url = get_route_path_for_workspace(saved_route, workspace)
+ return f'{self.get_icon()} Members can find and edit'
+ case PublishedRunVisibility.INTERNAL:
+ return f"{self.get_icon()} Members can find and edit"
def get_icon(self):
match self:
@@ -50,6 +79,8 @@ def get_icon(self):
return icons.lock
case PublishedRunVisibility.PUBLIC:
return icons.globe
+ case PublishedRunVisibility.INTERNAL:
+ return icons.company_solid
def get_badge_html(self):
match self:
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index cd78a75ff..3265378f2 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -36,7 +36,7 @@
)
from daras_ai.image_input import truncate_text_words
from daras_ai.text_format import format_number_with_suffix
-from daras_ai_v2 import settings, urls, icons
+from daras_ai_v2 import settings, icons
from daras_ai_v2.api_examples_widget import api_example_generator
from daras_ai_v2.breadcrumbs import render_breadcrumbs, get_title_breadcrumbs
from daras_ai_v2.copy_to_clipboard_button_widget import (
@@ -68,13 +68,11 @@
should_attempt_auto_recharge,
run_auto_recharge_gracefully,
)
-from routers.account import AccountTabs
from routers.root import RecipeTabs
from workspaces.widgets import (
create_workspace_with_defaults,
get_current_workspace,
set_current_workspace,
- workspace_selector,
)
from workspaces.models import Workspace
@@ -490,7 +488,8 @@ def _render_share_button(self):
if dialog.is_open:
with gui.alert_dialog(
- ref=dialog, modal_title=f"#### Share: {self.current_pr.title}"
+ ref=dialog,
+ modal_title=f"#### Share: {self.current_pr.title}",
):
self._render_share_modal(dialog=dialog)
else:
@@ -507,43 +506,53 @@ def _render_copy_link_button(
)
def _render_share_modal(self, dialog: gui.AlertDialogRef):
- with gui.div(className="visibility-radio mb-5"):
- options = {
- str(enum.value): enum.help_text() for enum in PublishedRunVisibility
- }
- if self.request.user and self.request.user.handle:
- profile_url = self.request.user.handle.get_app_url()
- pretty_profile_url = urls.remove_scheme(profile_url).rstrip("/")
- options[
- str(PublishedRunVisibility.PUBLIC.value)
- ] += f' on [{pretty_profile_url}]({profile_url})'
- elif self.request.user and not self.request.user.is_anonymous:
- edit_profile_url = AccountTabs.profile.get_url_path(self.request)
- options[
- str(PublishedRunVisibility.PUBLIC.value)
- ] += f' on my [profile page]({edit_profile_url})'
-
- published_run_visibility = PublishedRunVisibility(
- int(
- gui.radio(
- "",
- options=options,
- format_func=options.__getitem__,
- key="published_run_visibility",
- value=str(self.current_pr.visibility),
- )
+ if not self.current_pr.workspace.is_personal:
+ with gui.div(className="mb-4"):
+ self._render_workspace_with_invite_button(self.current_pr.workspace)
+
+ options = {
+ str(enum.value): enum.help_text(self.current_pr.workspace)
+ for enum in PublishedRunVisibility.for_workspace(self.current_pr.workspace)
+ }
+ published_run_visibility = PublishedRunVisibility(
+ int(
+ gui.radio(
+ "",
+ options=options,
+ format_func=options.__getitem__,
+ key="published_run_visibility",
+ value=str(self.current_pr.visibility),
)
)
- gui.radio(
- "",
- options=[
- 'Anyone at my workspace (coming soon)'
- ],
- disabled=True,
- checked_by_default=False,
- )
+ )
- with gui.div(className="d-flex justify-content-between"):
+ if (
+ self.current_workspace.is_personal
+ and self.request.user.get_workspaces().count() > 1
+ ):
+ with gui.div(className="alert alert-warning mb-0 mt-4"):
+ duplicate = gui.button(
+ f"{icons.fork} Duplicate", type="link", className="d-inline m-0 p-0"
+ )
+ gui.html(" this workflow to edit with others")
+ ref = gui.use_alert_dialog(key="publish-modal")
+ if duplicate:
+ self.clear_publish_form()
+ ref.set_open(True)
+ if ref.is_open:
+ gui.session_state["published_run_workspace"] = (
+ self.request.user.get_workspaces()
+ .filter(is_personal=False)
+ .first()
+ .id
+ )
+ return self._render_publish_dialog(ref=ref)
+
+ elif self.current_workspace.is_personal:
+ with gui.div(className="alert alert-warning mb-0 mt-4"):
+ gui.html(f"{icons.company} Create a team workspace to edit with others")
+
+ with gui.div(className="d-flex justify-content-between pt-4"):
pressed_copy = copy_to_clipboard_button_with_return(
label="Copy Link",
key="copy-link-in-share-modal",
@@ -571,17 +580,29 @@ def _render_share_modal(self, dialog: gui.AlertDialogRef):
dialog.set_open(False)
gui.rerun()
- def _render_save_button(self):
- can_edit = self.can_user_edit_published_run(self.current_pr)
+ def _render_workspace_with_invite_button(self, workspace: Workspace):
+ from workspaces.views import member_invite_button_with_dialog
+ col1, col2 = gui.columns([9, 3])
+ with col1:
+ with gui.tag("p", className="mb-1 text-muted"):
+ gui.html("WORKSPACE")
+ self.render_author(workspace, current_user=self.request.user)
+ with col2:
+ membership = workspace.memberships.get(user_id=self.request.user.id)
+ member_invite_button_with_dialog(
+ membership,
+ close_on_confirm=False,
+ type="tertiary",
+ className="mb-0",
+ )
+
+ def _render_save_button(self):
with gui.div(className="d-flex justify-content-end"):
gui.html(
"""