diff --git a/packages/backend/server/migrations/20241202070358_universal_subscription/migration.sql b/packages/backend/server/migrations/20241202070358_universal_subscription/migration.sql new file mode 100644 index 0000000000000..5c82cbd51102a --- /dev/null +++ b/packages/backend/server/migrations/20241202070358_universal_subscription/migration.sql @@ -0,0 +1,53 @@ +-- DropForeignKey +ALTER TABLE "user_invoices" DROP CONSTRAINT "user_invoices_user_id_fkey"; + +-- DropForeignKey +ALTER TABLE "user_subscriptions" DROP CONSTRAINT "user_subscriptions_user_id_fkey"; + +-- CreateTable +CREATE TABLE "subscriptions" ( + "id" SERIAL NOT NULL, + "target_id" VARCHAR NOT NULL, + "plan" VARCHAR(20) NOT NULL, + "recurring" VARCHAR(20) NOT NULL, + "variant" VARCHAR(20), + "quantity" INTEGER NOT NULL DEFAULT 1, + "stripe_subscription_id" TEXT, + "stripe_schedule_id" VARCHAR, + "status" VARCHAR(20) NOT NULL, + "start" TIMESTAMPTZ(3) NOT NULL, + "end" TIMESTAMPTZ(3), + "next_bill_at" TIMESTAMPTZ(3), + "canceled_at" TIMESTAMPTZ(3), + "trial_start" TIMESTAMPTZ(3), + "trial_end" TIMESTAMPTZ(3), + "created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMPTZ(3) NOT NULL, + + CONSTRAINT "subscriptions_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "invoices" ( + "stripe_invoice_id" TEXT NOT NULL, + "target_id" VARCHAR NOT NULL, + "currency" VARCHAR(3) NOT NULL, + "amount" INTEGER NOT NULL, + "status" VARCHAR(20) NOT NULL, + "created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMPTZ(3) NOT NULL, + "reason" VARCHAR, + "last_payment_error" TEXT, + "link" TEXT, + + CONSTRAINT "invoices_pkey" PRIMARY KEY ("stripe_invoice_id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "subscriptions_stripe_subscription_id_key" ON "subscriptions"("stripe_subscription_id"); + +-- CreateIndex +CREATE UNIQUE INDEX "subscriptions_target_id_plan_key" ON "subscriptions"("target_id", "plan"); + +-- CreateIndex +CREATE INDEX "invoices_target_id_idx" ON "invoices"("target_id"); diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index 865a22a3b133f..269cd8b37259f 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -23,9 +23,7 @@ model User { registered Boolean @default(true) features UserFeature[] - customer UserStripeCustomer? - subscriptions UserSubscription[] - invoices UserInvoice[] + userStripeCustomer UserStripeCustomer? workspacePermissions WorkspaceUserPermission[] pagePermissions WorkspacePageUserPermission[] connectedAccounts ConnectedAccount[] @@ -318,77 +316,6 @@ model SnapshotHistory { @@map("snapshot_histories") } -model UserStripeCustomer { - userId String @id @map("user_id") @db.VarChar - stripeCustomerId String @unique @map("stripe_customer_id") @db.VarChar - createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) - - user User @relation(fields: [userId], references: [id], onDelete: Cascade) - - @@map("user_stripe_customers") -} - -model UserSubscription { - id Int @id @default(autoincrement()) @db.Integer - userId String @map("user_id") @db.VarChar - plan String @db.VarChar(20) - // yearly/monthly/lifetime - recurring String @db.VarChar(20) - // onetime subscription or anything else - variant String? @db.VarChar(20) - // subscription.id, null for linefetime payment or one time payment subscription - stripeSubscriptionId String? @unique @map("stripe_subscription_id") - // subscription.status, active/past_due/canceled/unpaid... - status String @db.VarChar(20) - // subscription.current_period_start - start DateTime @map("start") @db.Timestamptz(3) - // subscription.current_period_end, null for lifetime payment - end DateTime? @map("end") @db.Timestamptz(3) - // subscription.billing_cycle_anchor - nextBillAt DateTime? @map("next_bill_at") @db.Timestamptz(3) - // subscription.canceled_at - canceledAt DateTime? @map("canceled_at") @db.Timestamptz(3) - // subscription.trial_start - trialStart DateTime? @map("trial_start") @db.Timestamptz(3) - // subscription.trial_end - trialEnd DateTime? @map("trial_end") @db.Timestamptz(3) - stripeScheduleId String? @map("stripe_schedule_id") @db.VarChar - - createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) - updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3) - user User @relation(fields: [userId], references: [id], onDelete: Cascade) - - @@unique([userId, plan]) - @@map("user_subscriptions") -} - -model UserInvoice { - id Int @id @default(autoincrement()) @db.Integer - userId String @map("user_id") @db.VarChar - stripeInvoiceId String @unique @map("stripe_invoice_id") - currency String @db.VarChar(3) - // CNY 12.50 stored as 1250 - amount Int @db.Integer - status String @db.VarChar(20) - createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) - updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3) - // billing reason - reason String? @db.VarChar - lastPaymentError String? @map("last_payment_error") @db.Text - // stripe hosted invoice link - link String? @db.Text - - // @deprecated - plan String? @db.VarChar(20) - // @deprecated - recurring String? @db.VarChar(20) - - user User @relation(fields: [userId], references: [id], onDelete: Cascade) - - @@index([userId]) - @@map("user_invoices") -} - enum AiPromptRole { system assistant @@ -503,3 +430,124 @@ model RuntimeConfig { @@unique([module, key]) @@map("app_runtime_settings") } + +model DeprecatedUserSubscription { + id Int @id @default(autoincrement()) @db.Integer + userId String @map("user_id") @db.VarChar + plan String @db.VarChar(20) + // yearly/monthly/lifetime + recurring String @db.VarChar(20) + // onetime subscription or anything else + variant String? @db.VarChar(20) + // subscription.id, null for lifetime payment or one time payment subscription + stripeSubscriptionId String? @unique @map("stripe_subscription_id") + // subscription.status, active/past_due/canceled/unpaid... + status String @db.VarChar(20) + // subscription.current_period_start + start DateTime @map("start") @db.Timestamptz(3) + // subscription.current_period_end, null for lifetime payment + end DateTime? @map("end") @db.Timestamptz(3) + // subscription.billing_cycle_anchor + nextBillAt DateTime? @map("next_bill_at") @db.Timestamptz(3) + // subscription.canceled_at + canceledAt DateTime? @map("canceled_at") @db.Timestamptz(3) + // subscription.trial_start + trialStart DateTime? @map("trial_start") @db.Timestamptz(3) + // subscription.trial_end + trialEnd DateTime? @map("trial_end") @db.Timestamptz(3) + stripeScheduleId String? @map("stripe_schedule_id") @db.VarChar + + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) + updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3) + + @@unique([userId, plan]) + @@map("user_subscriptions") +} + +model DeprecatedUserInvoice { + id Int @id @default(autoincrement()) @db.Integer + userId String @map("user_id") @db.VarChar + stripeInvoiceId String @unique @map("stripe_invoice_id") + currency String @db.VarChar(3) + // CNY 12.50 stored as 1250 + amount Int @db.Integer + status String @db.VarChar(20) + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) + updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3) + // billing reason + reason String? @db.VarChar + lastPaymentError String? @map("last_payment_error") @db.Text + // stripe hosted invoice link + link String? @db.Text + + // @deprecated + plan String? @db.VarChar(20) + // @deprecated + recurring String? @db.VarChar(20) + + @@index([userId]) + @@map("user_invoices") +} + +model UserStripeCustomer { + userId String @id @map("user_id") @db.VarChar + stripeCustomerId String @unique @map("stripe_customer_id") @db.VarChar + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) + + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@map("user_stripe_customers") +} + +model Subscription { + id Int @id @default(autoincrement()) @db.Integer + targetId String @map("target_id") @db.VarChar + plan String @db.VarChar(20) + // yearly/monthly/lifetime + recurring String @db.VarChar(20) + // onetime subscription or anything else + variant String? @db.VarChar(20) + quantity Int @default(1) @db.Integer + // subscription.id, null for lifetime payment or one time payment subscription + stripeSubscriptionId String? @unique @map("stripe_subscription_id") + // stripe schedule id + stripeScheduleId String? @map("stripe_schedule_id") @db.VarChar + // subscription.status, active/past_due/canceled/unpaid... + status String @db.VarChar(20) + // subscription.current_period_start + start DateTime @map("start") @db.Timestamptz(3) + // subscription.current_period_end, null for lifetime payment + end DateTime? @map("end") @db.Timestamptz(3) + // subscription.billing_cycle_anchor + nextBillAt DateTime? @map("next_bill_at") @db.Timestamptz(3) + // subscription.canceled_at + canceledAt DateTime? @map("canceled_at") @db.Timestamptz(3) + // subscription.trial_start + trialStart DateTime? @map("trial_start") @db.Timestamptz(3) + // subscription.trial_end + trialEnd DateTime? @map("trial_end") @db.Timestamptz(3) + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) + updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3) + + @@unique([targetId, plan]) + @@map("subscriptions") +} + +model Invoice { + stripeInvoiceId String @id @map("stripe_invoice_id") + targetId String @map("target_id") @db.VarChar + currency String @db.VarChar(3) + // CNY 12.50 stored as 1250 + amount Int @db.Integer + status String @db.VarChar(20) + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) + updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3) + // billing reason + reason String? @db.VarChar + lastPaymentError String? @map("last_payment_error") @db.Text + // stripe hosted invoice link + link String? @db.Text + + @@index([targetId]) + @@map("invoices") +} diff --git a/packages/backend/server/src/core/user/resolver.ts b/packages/backend/server/src/core/user/resolver.ts index 00723651f2358..4eba70f6da835 100644 --- a/packages/backend/server/src/core/user/resolver.ts +++ b/packages/backend/server/src/core/user/resolver.ts @@ -5,7 +5,6 @@ import { Int, Mutation, Query, - ResolveField, Resolver, } from '@nestjs/graphql'; import { PrismaClient } from '@prisma/client'; @@ -37,7 +36,6 @@ import { @Resolver(() => UserType) export class UserResolver { constructor( - private readonly prisma: PrismaClient, private readonly storage: AvatarStorage, private readonly users: UserService ) {} @@ -72,16 +70,6 @@ export class UserResolver { }; } - @ResolveField(() => Int, { - name: 'invoiceCount', - description: 'Get user invoice count', - }) - async invoiceCount(@CurrentUser() user: CurrentUser) { - return this.prisma.userInvoice.count({ - where: { userId: user.id }, - }); - } - @Mutation(() => UserType, { name: 'uploadAvatar', description: 'Upload user avatar', diff --git a/packages/backend/server/src/core/workspaces/index.ts b/packages/backend/server/src/core/workspaces/index.ts index f100a953ca472..78d6e707387cb 100644 --- a/packages/backend/server/src/core/workspaces/index.ts +++ b/packages/backend/server/src/core/workspaces/index.ts @@ -37,4 +37,4 @@ import { }) export class WorkspaceModule {} -export type { InvitationType, WorkspaceType } from './types'; +export { InvitationType, WorkspaceType } from './types'; diff --git a/packages/backend/server/src/data/migrations/1733125339942-universal-subscription.ts b/packages/backend/server/src/data/migrations/1733125339942-universal-subscription.ts new file mode 100644 index 0000000000000..8233c57877eb8 --- /dev/null +++ b/packages/backend/server/src/data/migrations/1733125339942-universal-subscription.ts @@ -0,0 +1,29 @@ +import { PrismaClient } from '@prisma/client'; + +import { loop } from './utils/loop'; + +export class UniversalSubscription1733125339942 { + // do the migration + static async up(db: PrismaClient) { + await loop(async (offset, take) => { + const oldSubscriptions = await db.deprecatedUserSubscription.findMany({ + skip: offset, + take, + }); + + await db.subscription.createMany({ + data: oldSubscriptions.map(s => ({ + targetId: s.userId, + ...s, + })), + }); + + return oldSubscriptions.length; + }, 50); + } + + // revert the migration + static async down(_db: PrismaClient) { + // noop + } +} diff --git a/packages/backend/server/src/fundamentals/config/index.ts b/packages/backend/server/src/fundamentals/config/index.ts index 826d20b5be6dc..91afcdbf1a834 100644 --- a/packages/backend/server/src/fundamentals/config/index.ts +++ b/packages/backend/server/src/fundamentals/config/index.ts @@ -36,3 +36,5 @@ export class ConfigModule { }; }; } + +export { Runtime }; diff --git a/packages/backend/server/src/fundamentals/error/def.ts b/packages/backend/server/src/fundamentals/error/def.ts index 10403ee477d44..385eebd3fc727 100644 --- a/packages/backend/server/src/fundamentals/error/def.ts +++ b/packages/backend/server/src/fundamentals/error/def.ts @@ -412,15 +412,28 @@ export const USER_FRIENDLY_ERRORS = { }, // Subscription Errors + unsupported_subscription_plan: { + type: 'invalid_input', + args: { plan: 'string' }, + message: ({ plan }) => `Unsupported subscription plan: ${plan}.`, + }, failed_to_checkout: { type: 'internal_server_error', message: 'Failed to create checkout session.', }, + invalid_checkout_parameters: { + type: 'invalid_input', + message: 'Invalid checkout parameters provided.', + }, subscription_already_exists: { type: 'resource_already_exists', args: { plan: 'string' }, message: ({ plan }) => `You have already subscribed to the ${plan} plan.`, }, + invalid_subscription_parameters: { + type: 'invalid_input', + message: 'Invalid subscription parameters provided.', + }, subscription_not_exists: { type: 'resource_not_found', args: { plan: 'string' }, @@ -430,6 +443,10 @@ export const USER_FRIENDLY_ERRORS = { type: 'action_forbidden', message: 'Your subscription has already been canceled.', }, + subscription_has_not_been_canceled: { + type: 'action_forbidden', + message: 'Your subscription has not been canceled.', + }, subscription_expired: { type: 'action_forbidden', message: 'Your subscription has expired.', @@ -453,6 +470,14 @@ export const USER_FRIENDLY_ERRORS = { type: 'action_forbidden', message: 'You cannot update an onetime payment subscription.', }, + workspace_id_required_for_team_subscription: { + type: 'invalid_input', + message: 'A workspace is required to checkout for team subscription.', + }, + workspace_id_required_to_update_team_subscription: { + type: 'invalid_input', + message: 'Workspace id is required to update team subscription.', + }, // Copilot errors copilot_session_not_found: { diff --git a/packages/backend/server/src/fundamentals/error/errors.gen.ts b/packages/backend/server/src/fundamentals/error/errors.gen.ts index ed0506970479d..c3212d71cc815 100644 --- a/packages/backend/server/src/fundamentals/error/errors.gen.ts +++ b/packages/backend/server/src/fundamentals/error/errors.gen.ts @@ -328,12 +328,28 @@ export class FailedToUpsertSnapshot extends UserFriendlyError { super('internal_server_error', 'failed_to_upsert_snapshot', message); } } +@ObjectType() +class UnsupportedSubscriptionPlanDataType { + @Field() plan!: string +} + +export class UnsupportedSubscriptionPlan extends UserFriendlyError { + constructor(args: UnsupportedSubscriptionPlanDataType, message?: string | ((args: UnsupportedSubscriptionPlanDataType) => string)) { + super('invalid_input', 'unsupported_subscription_plan', message, args); + } +} export class FailedToCheckout extends UserFriendlyError { constructor(message?: string) { super('internal_server_error', 'failed_to_checkout', message); } } + +export class InvalidCheckoutParameters extends UserFriendlyError { + constructor(message?: string) { + super('invalid_input', 'invalid_checkout_parameters', message); + } +} @ObjectType() class SubscriptionAlreadyExistsDataType { @Field() plan!: string @@ -344,6 +360,12 @@ export class SubscriptionAlreadyExists extends UserFriendlyError { super('resource_already_exists', 'subscription_already_exists', message, args); } } + +export class InvalidSubscriptionParameters extends UserFriendlyError { + constructor(message?: string) { + super('invalid_input', 'invalid_subscription_parameters', message); + } +} @ObjectType() class SubscriptionNotExistsDataType { @Field() plan!: string @@ -361,6 +383,12 @@ export class SubscriptionHasBeenCanceled extends UserFriendlyError { } } +export class SubscriptionHasNotBeenCanceled extends UserFriendlyError { + constructor(message?: string) { + super('action_forbidden', 'subscription_has_not_been_canceled', message); + } +} + export class SubscriptionExpired extends UserFriendlyError { constructor(message?: string) { super('action_forbidden', 'subscription_expired', message); @@ -400,6 +428,18 @@ export class CantUpdateOnetimePaymentSubscription extends UserFriendlyError { } } +export class WorkspaceIdRequiredForTeamSubscription extends UserFriendlyError { + constructor(message?: string) { + super('invalid_input', 'workspace_id_required_for_team_subscription', message); + } +} + +export class WorkspaceIdRequiredToUpdateTeamSubscription extends UserFriendlyError { + constructor(message?: string) { + super('invalid_input', 'workspace_id_required_to_update_team_subscription', message); + } +} + export class CopilotSessionNotFound extends UserFriendlyError { constructor(message?: string) { super('resource_not_found', 'copilot_session_not_found', message); @@ -587,15 +627,21 @@ export enum ErrorNames { PAGE_IS_NOT_PUBLIC, FAILED_TO_SAVE_UPDATES, FAILED_TO_UPSERT_SNAPSHOT, + UNSUPPORTED_SUBSCRIPTION_PLAN, FAILED_TO_CHECKOUT, + INVALID_CHECKOUT_PARAMETERS, SUBSCRIPTION_ALREADY_EXISTS, + INVALID_SUBSCRIPTION_PARAMETERS, SUBSCRIPTION_NOT_EXISTS, SUBSCRIPTION_HAS_BEEN_CANCELED, + SUBSCRIPTION_HAS_NOT_BEEN_CANCELED, SUBSCRIPTION_EXPIRED, SAME_SUBSCRIPTION_RECURRING, CUSTOMER_PORTAL_CREATE_FAILED, SUBSCRIPTION_PLAN_NOT_FOUND, CANT_UPDATE_ONETIME_PAYMENT_SUBSCRIPTION, + WORKSPACE_ID_REQUIRED_FOR_TEAM_SUBSCRIPTION, + WORKSPACE_ID_REQUIRED_TO_UPDATE_TEAM_SUBSCRIPTION, COPILOT_SESSION_NOT_FOUND, COPILOT_SESSION_DELETED, NO_COPILOT_PROVIDER_AVAILABLE, @@ -624,5 +670,5 @@ registerEnumType(ErrorNames, { export const ErrorDataUnionType = createUnionType({ name: 'ErrorDataUnion', types: () => - [UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, SpaceNotFoundDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderSideErrorDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const, + [UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, SpaceNotFoundDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderSideErrorDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const, }); diff --git a/packages/backend/server/src/plugins/payment/config.ts b/packages/backend/server/src/plugins/payment/config.ts index 0c8c2101fa2b4..707311b12f9f2 100644 --- a/packages/backend/server/src/plugins/payment/config.ts +++ b/packages/backend/server/src/plugins/payment/config.ts @@ -29,6 +29,6 @@ defineStartupConfig('plugins.payment', {}); defineRuntimeConfig('plugins.payment', { showLifetimePrice: { desc: 'Whether enable lifetime price and allow user to pay for it.', - default: false, + default: true, }, }); diff --git a/packages/backend/server/src/plugins/payment/cron.ts b/packages/backend/server/src/plugins/payment/cron.ts index dacacaf824959..9b96cd634ae08 100644 --- a/packages/backend/server/src/plugins/payment/cron.ts +++ b/packages/backend/server/src/plugins/payment/cron.ts @@ -19,7 +19,7 @@ export class SubscriptionCronJobs { @Cron(CronExpression.EVERY_HOUR) async cleanExpiredOnetimeSubscriptions() { - const subscriptions = await this.db.userSubscription.findMany({ + const subscriptions = await this.db.subscription.findMany({ where: { variant: SubscriptionVariant.Onetime, end: { @@ -30,7 +30,7 @@ export class SubscriptionCronJobs { for (const subscription of subscriptions) { this.event.emit('user.subscription.canceled', { - userId: subscription.userId, + userId: subscription.targetId, plan: subscription.plan as SubscriptionPlan, recurring: subscription.variant as SubscriptionRecurring, }); @@ -42,10 +42,10 @@ export class SubscriptionCronJobs { userId, plan, }: EventPayload<'user.subscription.canceled'>) { - await this.db.userSubscription.delete({ + await this.db.subscription.delete({ where: { - userId_plan: { - userId, + targetId_plan: { + targetId: userId, plan, }, }, diff --git a/packages/backend/server/src/plugins/payment/index.ts b/packages/backend/server/src/plugins/payment/index.ts index c7f85354e6da4..c59072bf43d95 100644 --- a/packages/backend/server/src/plugins/payment/index.ts +++ b/packages/backend/server/src/plugins/payment/index.ts @@ -2,19 +2,27 @@ import './config'; import { ServerFeature } from '../../core/config'; import { FeatureModule } from '../../core/features'; +import { PermissionModule } from '../../core/permission'; import { UserModule } from '../../core/user'; import { Plugin } from '../registry'; import { StripeWebhookController } from './controller'; import { SubscriptionCronJobs } from './cron'; -import { UserSubscriptionManager } from './manager'; -import { SubscriptionResolver, UserSubscriptionResolver } from './resolver'; +import { + UserSubscriptionManager, + WorkspaceSubscriptionManager, +} from './manager'; +import { + SubscriptionResolver, + UserSubscriptionResolver, + WorkspaceSubscriptionResolver, +} from './resolver'; import { SubscriptionService } from './service'; import { StripeProvider } from './stripe'; import { StripeWebhook } from './webhook'; @Plugin({ name: 'payment', - imports: [FeatureModule, UserModule], + imports: [FeatureModule, UserModule, PermissionModule], providers: [ StripeProvider, SubscriptionService, @@ -22,7 +30,9 @@ import { StripeWebhook } from './webhook'; UserSubscriptionResolver, StripeWebhook, UserSubscriptionManager, + WorkspaceSubscriptionManager, SubscriptionCronJobs, + WorkspaceSubscriptionResolver, ], controllers: [StripeWebhookController], requires: [ diff --git a/packages/backend/server/src/plugins/payment/manager/common.ts b/packages/backend/server/src/plugins/payment/manager/common.ts index 29336df980082..c8d16e4191e01 100644 --- a/packages/backend/server/src/plugins/payment/manager/common.ts +++ b/packages/backend/server/src/plugins/payment/manager/common.ts @@ -1,13 +1,23 @@ -import { UserStripeCustomer } from '@prisma/client'; +import { PrismaClient, UserStripeCustomer } from '@prisma/client'; +import Stripe from 'stripe'; +import { z } from 'zod'; +import { UserNotFound } from '../../../fundamentals'; +import { ScheduleManager } from '../schedule'; import { + encodeLookupKey, + KnownStripeInvoice, KnownStripePrice, KnownStripeSubscription, + LookupKey, SubscriptionPlan, SubscriptionRecurring, + SubscriptionVariant, } from '../types'; export interface Subscription { + stripeSubscriptionId: string | null; + stripeScheduleId: string | null; status: string; plan: string; recurring: string; @@ -21,36 +31,225 @@ export interface Subscription { } export interface Invoice { + stripeInvoiceId: string; currency: string; amount: number; status: string; - createdAt: Date; + reason: string | null; lastPaymentError: string | null; link: string | null; } -export interface SubscriptionManager { - filterPrices( +export const SubscriptionIdentity = z.object({ + plan: z.nativeEnum(SubscriptionPlan), +}); + +export const CheckoutParams = z.object({ + plan: z.nativeEnum(SubscriptionPlan), + recurring: z.nativeEnum(SubscriptionRecurring), + variant: z.nativeEnum(SubscriptionVariant).nullable().optional(), + coupon: z.string().nullable().optional(), + quantity: z.number().min(1).nullable().optional(), + successCallbackLink: z.string(), +}); + +export abstract class SubscriptionManager { + protected readonly scheduleManager = new ScheduleManager(this.stripe); + constructor( + protected readonly stripe: Stripe, + protected readonly db: PrismaClient + ) {} + + abstract filterPrices( prices: KnownStripePrice[], customer?: UserStripeCustomer - ): Promise; + ): KnownStripePrice[] | Promise; + + abstract checkout( + price: KnownStripePrice, + params: z.infer, + args: any + ): Promise; - saveSubscription( + abstract saveStripeSubscription( subscription: KnownStripeSubscription ): Promise; - deleteSubscription(subscription: KnownStripeSubscription): Promise; + abstract deleteStripeSubscription( + subscription: KnownStripeSubscription + ): Promise; - getSubscription( - id: string, - plan: SubscriptionPlan + abstract getSubscription( + identity: z.infer ): Promise; + abstract cancelSubscription( + subscription: Subscription + ): Promise; - cancelSubscription(subscription: Subscription): Promise; - - resumeSubscription(subscription: Subscription): Promise; + abstract resumeSubscription( + subscription: Subscription + ): Promise; - updateSubscriptionRecurring( + abstract updateSubscriptionRecurring( subscription: Subscription, recurring: SubscriptionRecurring ): Promise; + + abstract saveInvoice(knownInvoice: KnownStripeInvoice): Promise; + + transformSubscription({ + lookupKey, + stripeSubscription: subscription, + }: KnownStripeSubscription): Subscription { + return { + ...lookupKey, + stripeScheduleId: subscription.schedule as string | null, + stripeSubscriptionId: subscription.id, + status: subscription.status, + start: new Date(subscription.current_period_start * 1000), + end: new Date(subscription.current_period_end * 1000), + trialStart: subscription.trial_start + ? new Date(subscription.trial_start * 1000) + : null, + trialEnd: subscription.trial_end + ? new Date(subscription.trial_end * 1000) + : null, + nextBillAt: !subscription.canceled_at + ? new Date(subscription.current_period_end * 1000) + : null, + canceledAt: subscription.canceled_at + ? new Date(subscription.canceled_at * 1000) + : null, + }; + } + + async transformInvoice({ + stripeInvoice, + }: KnownStripeInvoice): Promise { + const status = stripeInvoice.status ?? 'void'; + let error: string | boolean | null = null; + + if (status !== 'paid') { + if (stripeInvoice.last_finalization_error) { + error = stripeInvoice.last_finalization_error.message ?? true; + } else if ( + stripeInvoice.attempt_count > 1 && + stripeInvoice.payment_intent + ) { + const paymentIntent = + typeof stripeInvoice.payment_intent === 'string' + ? await this.stripe.paymentIntents.retrieve( + stripeInvoice.payment_intent + ) + : stripeInvoice.payment_intent; + + if (paymentIntent.last_payment_error) { + error = paymentIntent.last_payment_error.message ?? true; + } + } + } + + // fallback to generic error message + if (error === true) { + error = 'Payment Error. Please contact support.'; + } + + return { + stripeInvoiceId: stripeInvoice.id, + status, + link: stripeInvoice.hosted_invoice_url || null, + reason: stripeInvoice.billing_reason, + amount: stripeInvoice.total, + currency: stripeInvoice.currency, + lastPaymentError: error, + }; + } + + async getOrCreateCustomer(userId: string): Promise { + const user = await this.db.user.findUnique({ + where: { + id: userId, + }, + select: { + email: true, + userStripeCustomer: true, + }, + }); + + if (!user) { + throw new UserNotFound(); + } + + let customer = user.userStripeCustomer; + if (!customer) { + const stripeCustomersList = await this.stripe.customers.list({ + email: user.email, + limit: 1, + }); + + let stripeCustomer: Stripe.Customer | undefined; + if (stripeCustomersList.data.length) { + stripeCustomer = stripeCustomersList.data[0]; + } else { + stripeCustomer = await this.stripe.customers.create({ + email: user.email, + }); + } + + customer = await this.db.userStripeCustomer.create({ + data: { + userId, + stripeCustomerId: stripeCustomer.id, + }, + }); + } + + return customer; + } + + protected async getPrice( + lookupKey: LookupKey + ): Promise { + const prices = await this.stripe.prices.list({ + lookup_keys: [encodeLookupKey(lookupKey)], + limit: 1, + }); + + const price = prices.data[0]; + + return price + ? { + lookupKey, + price, + } + : null; + } + + protected async getCouponFromPromotionCode( + userFacingPromotionCode: string, + customer: UserStripeCustomer + ) { + const list = await this.stripe.promotionCodes.list({ + code: userFacingPromotionCode, + active: true, + limit: 1, + }); + + const code = list.data[0]; + if (!code) { + return null; + } + + // the coupons are always bound to products, we need to check it first + // but the logic would be too complicated, and stripe will complain if the code is not applicable when checking out + // It's safe to skip the check here + // code.coupon.applies_to.products.forEach() + + // check if the code is bound to a specific customer + return !code.customer || + (typeof code.customer === 'string' + ? code.customer === customer.stripeCustomerId + : code.customer.id === customer.stripeCustomerId) + ? code.coupon.id + : null; + } } diff --git a/packages/backend/server/src/plugins/payment/manager/index.ts b/packages/backend/server/src/plugins/payment/manager/index.ts index 13c53c6d662ba..21d7a26a20784 100644 --- a/packages/backend/server/src/plugins/payment/manager/index.ts +++ b/packages/backend/server/src/plugins/payment/manager/index.ts @@ -1,2 +1,3 @@ export * from './common'; export * from './user'; +export * from './workspace'; diff --git a/packages/backend/server/src/plugins/payment/manager/user.ts b/packages/backend/server/src/plugins/payment/manager/user.ts index be2103dea667f..32b20c570d930 100644 --- a/packages/backend/server/src/plugins/payment/manager/user.ts +++ b/packages/backend/server/src/plugins/payment/manager/user.ts @@ -1,10 +1,8 @@ import { Injectable } from '@nestjs/common'; -import { - PrismaClient, - UserStripeCustomer, - UserSubscription, -} from '@prisma/client'; +import { PrismaClient, UserStripeCustomer } from '@prisma/client'; +import { omit, pick } from 'lodash-es'; import Stripe from 'stripe'; +import { z } from 'zod'; import { EarlyAccessType, @@ -14,6 +12,9 @@ import { Config, EventEmitter, InternalServerError, + SubscriptionAlreadyExists, + SubscriptionPlanNotFound, + URLHelper, } from '../../../fundamentals'; import { CouponType, @@ -26,7 +27,7 @@ import { SubscriptionStatus, SubscriptionVariant, } from '../types'; -import { SubscriptionManager } from './common'; +import { CheckoutParams, Subscription, SubscriptionManager } from './common'; interface PriceStrategyStatus { proEarlyAccess: boolean; @@ -36,15 +37,30 @@ interface PriceStrategyStatus { onetime: boolean; } +export const UserSubscriptionIdentity = z.object({ + plan: z.enum([SubscriptionPlan.Pro, SubscriptionPlan.AI]), + userId: z.string(), +}); + +export const UserSubscriptionCheckoutArgs = z.object({ + user: z.object({ + id: z.string(), + email: z.string(), + }), +}); + @Injectable() -export class UserSubscriptionManager implements SubscriptionManager { +export class UserSubscriptionManager extends SubscriptionManager { constructor( - private readonly db: PrismaClient, + stripe: Stripe, + db: PrismaClient, private readonly config: Config, - private readonly stripe: Stripe, private readonly feature: FeatureManagementService, - private readonly event: EventEmitter - ) {} + private readonly event: EventEmitter, + private readonly url: URLHelper + ) { + super(stripe, db); + } async filterPrices( prices: KnownStripePrice[], @@ -71,11 +87,105 @@ export class UserSubscriptionManager implements SubscriptionManager { return availablePrices; } - async getSubscription(userId: string, plan: SubscriptionPlan) { - return this.db.userSubscription.findFirst({ + async checkout( + price: KnownStripePrice, + params: z.infer, + { user }: z.infer + ) { + const lookupKey = price.lookupKey; + const subscription = await this.getSubscription({ + // @ts-expect-error filtered already + plan: price.lookupKey.plan, + user, + }); + + if ( + subscription && + // do not allow to re-subscribe unless + !( + /* current subscription is a onetime subscription and so as the one that's checking out */ + ( + (subscription.variant === SubscriptionVariant.Onetime && + lookupKey.variant === SubscriptionVariant.Onetime) || + /* current subscription is normal subscription and is checking-out a lifetime subscription */ + (subscription.recurring !== SubscriptionRecurring.Lifetime && + subscription.variant !== SubscriptionVariant.Onetime && + lookupKey.recurring === SubscriptionRecurring.Lifetime) + ) + ) + ) { + throw new SubscriptionAlreadyExists({ plan: lookupKey.plan }); + } + + const customer = await this.getOrCreateCustomer(user.id); + const strategy = await this.strategyStatus(customer); + const available = await this.isPriceAvailable(price, { + ...strategy, + onetime: true, + }); + + if (!available) { + throw new SubscriptionPlanNotFound({ + plan: lookupKey.plan, + recurring: lookupKey.recurring, + }); + } + + const discounts = await (async () => { + const coupon = await this.getBuildInCoupon(customer, price); + if (coupon) { + return { discounts: [{ coupon }] }; + } else if (params.coupon) { + const couponId = await this.getCouponFromPromotionCode( + params.coupon, + customer + ); + if (couponId) { + return { discounts: [{ coupon: couponId }] }; + } + } + + return { allow_promotion_codes: true }; + })(); + + // mode: 'subscription' or 'payment' for lifetime and onetime payment + const mode = + lookupKey.recurring === SubscriptionRecurring.Lifetime || + lookupKey.variant === SubscriptionVariant.Onetime + ? { + mode: 'payment' as const, + invoice_creation: { + enabled: true, + }, + } + : { + mode: 'subscription' as const, + }; + + return this.stripe.checkout.sessions.create({ + line_items: [ + { + price: price.price.id, + quantity: 1, + }, + ], + tax_id_collection: { + enabled: true, + }, + ...discounts, + ...mode, + success_url: this.url.link(params.successCallbackLink, { + session_id: '{CHECKOUT_SESSION_ID}', + }), + customer: customer.stripeCustomerId, + }); + } + + async getSubscription(args: z.infer) { + return this.db.subscription.findFirst({ where: { - userId, - plan, + targetId: args.userId, + plan: args.plan, status: { in: [SubscriptionStatus.Active, SubscriptionStatus.Trialing], }, @@ -83,11 +193,8 @@ export class UserSubscriptionManager implements SubscriptionManager { }); } - async saveSubscription({ - userId, - lookupKey, - stripeSubscription: subscription, - }: KnownStripeSubscription) { + async saveStripeSubscription(subscription: KnownStripeSubscription) { + const { userId, lookupKey, stripeSubscription } = subscription; // update features first, features modify are idempotent // so there is no need to skip if a subscription already exists. // TODO(@forehalo): @@ -99,43 +206,85 @@ export class UserSubscriptionManager implements SubscriptionManager { recurring: lookupKey.recurring, }); - const commonData = { - status: subscription.status, - stripeScheduleId: subscription.schedule as string | null, - nextBillAt: !subscription.canceled_at - ? new Date(subscription.current_period_end * 1000) - : null, - canceledAt: subscription.canceled_at - ? new Date(subscription.canceled_at * 1000) - : null, - }; + const subscriptionData = this.transformSubscription(subscription); - return await this.db.userSubscription.upsert({ + // @deprecated backward compatibility + await this.db.deprecatedUserSubscription.upsert({ where: { - stripeSubscriptionId: subscription.id, + stripeSubscriptionId: stripeSubscription.id, }, - update: commonData, + update: pick(subscriptionData, [ + 'status', + 'stripeScheduleId', + 'nextBillAt', + 'canceledAt', + ]), create: { userId, - ...lookupKey, - stripeSubscriptionId: subscription.id, - start: new Date(subscription.current_period_start * 1000), - end: new Date(subscription.current_period_end * 1000), - trialStart: subscription.trial_start - ? new Date(subscription.trial_start * 1000) - : null, - trialEnd: subscription.trial_end - ? new Date(subscription.trial_end * 1000) - : null, - ...commonData, + ...subscriptionData, + }, + }); + + return this.db.subscription.upsert({ + where: { + stripeSubscriptionId: stripeSubscription.id, + }, + update: pick(subscriptionData, [ + 'status', + 'stripeScheduleId', + 'nextBillAt', + 'canceledAt', + ]), + create: { + targetId: userId, + ...subscriptionData, }, }); } - async cancelSubscription(subscription: UserSubscription) { - return this.db.userSubscription.update({ + async deleteStripeSubscription({ + userId, + lookupKey, + stripeSubscription, + }: KnownStripeSubscription) { + const deleted = await this.db.subscription.deleteMany({ where: { - id: subscription.id, + stripeSubscriptionId: stripeSubscription.id, + }, + }); + + // @deprecated backward compatibility + await this.db.deprecatedUserSubscription.deleteMany({ + where: { + stripeSubscriptionId: stripeSubscription.id, + }, + }); + + if (deleted.count > 0) { + this.event.emit('user.subscription.canceled', { + userId, + plan: lookupKey.plan, + recurring: lookupKey.recurring, + }); + } + } + + async cancelSubscription(subscription: Subscription) { + // @deprecated backward compatibility + await this.db.deprecatedUserSubscription.updateMany({ + where: { + stripeSubscriptionId: subscription.stripeSubscriptionId, + }, + data: { + canceledAt: new Date(), + nextBillAt: null, + }, + }); + + return this.db.subscription.update({ + where: { + // @ts-expect-error checked outside + stripeSubscriptionId: subscription.stripeSubscriptionId, }, data: { canceledAt: new Date(), @@ -144,9 +293,23 @@ export class UserSubscriptionManager implements SubscriptionManager { }); } - async resumeSubscription(subscription: UserSubscription) { - return this.db.userSubscription.update({ - where: { id: subscription.id }, + async resumeSubscription(subscription: Subscription) { + // @deprecated backward compatibility + await this.db.deprecatedUserSubscription.updateMany({ + where: { + stripeSubscriptionId: subscription.stripeSubscriptionId, + }, + data: { + canceledAt: null, + nextBillAt: subscription.end, + }, + }); + + return this.db.subscription.update({ + where: { + // @ts-expect-error checked outside + stripeSubscriptionId: subscription.stripeSubscriptionId, + }, data: { canceledAt: null, nextBillAt: subscription.end, @@ -155,34 +318,30 @@ export class UserSubscriptionManager implements SubscriptionManager { } async updateSubscriptionRecurring( - subscription: UserSubscription, + subscription: Subscription, recurring: SubscriptionRecurring ) { - return this.db.userSubscription.update({ - where: { id: subscription.id }, + // @deprecated backward compatibility + await this.db.deprecatedUserSubscription.updateMany({ + where: { + stripeSubscriptionId: subscription.stripeSubscriptionId, + }, data: { recurring }, }); - } - async deleteSubscription({ - userId, - lookupKey, - stripeSubscription, - }: KnownStripeSubscription) { - await this.db.userSubscription.delete({ + return this.db.subscription.update({ where: { - stripeSubscriptionId: stripeSubscription.id, + // @ts-expect-error checked outside + stripeSubscriptionId: subscription.stripeSubscriptionId, }, - }); - - this.event.emit('user.subscription.canceled', { - userId, - plan: lookupKey.plan, - recurring: lookupKey.recurring, + data: { recurring }, }); } - async validatePrice(price: KnownStripePrice, customer: UserStripeCustomer) { + private async getBuildInCoupon( + customer: UserStripeCustomer, + price: KnownStripePrice + ) { const strategyStatus = await this.strategyStatus(customer); // onetime price is allowed for checkout @@ -192,7 +351,7 @@ export class UserSubscriptionManager implements SubscriptionManager { return null; } - let coupon: CouponType | null = null; + let coupon: CouponType | undefined; if (price.lookupKey.variant === SubscriptionVariant.EA) { if (price.lookupKey.plan === SubscriptionPlan.Pro) { @@ -207,69 +366,40 @@ export class UserSubscriptionManager implements SubscriptionManager { } } - return { - price, - coupon, - }; + return coupon; } async saveInvoice(knownInvoice: KnownStripeInvoice) { const { userId, lookupKey, stripeInvoice } = knownInvoice; - const status = stripeInvoice.status ?? 'void'; - let error: string | boolean | null = null; - - if (status !== 'paid') { - if (stripeInvoice.last_finalization_error) { - error = stripeInvoice.last_finalization_error.message ?? true; - } else if ( - stripeInvoice.attempt_count > 1 && - stripeInvoice.payment_intent - ) { - const paymentIntent = - typeof stripeInvoice.payment_intent === 'string' - ? await this.stripe.paymentIntents.retrieve( - stripeInvoice.payment_intent - ) - : stripeInvoice.payment_intent; - - if (paymentIntent.last_payment_error) { - error = paymentIntent.last_payment_error.message ?? true; - } - } - } - - // fallback to generic error message - if (error === true) { - error = 'Payment Error. Please contact support.'; - } + const invoiceData = await this.transformInvoice(knownInvoice); - const invoice = this.db.userInvoice.upsert({ + // @deprecated backward compatibility + await this.db.deprecatedUserInvoice.upsert({ where: { stripeInvoiceId: stripeInvoice.id, }, - update: { - status, - link: stripeInvoice.hosted_invoice_url, - amount: stripeInvoice.total, - currency: stripeInvoice.currency, - lastPaymentError: error, - }, + update: omit(invoiceData, 'stripeInvoiceId'), create: { userId, + ...invoiceData, + }, + }); + + const invoice = this.db.invoice.upsert({ + where: { stripeInvoiceId: stripeInvoice.id, - status, - link: stripeInvoice.hosted_invoice_url, - reason: stripeInvoice.billing_reason, - amount: stripeInvoice.total, - currency: stripeInvoice.currency, - lastPaymentError: error, + }, + update: omit(invoiceData, 'stripeInvoiceId'), + create: { + targetId: userId, + ...invoiceData, }, }); // onetime and lifetime subscription is a special "subscription" that doesn't get involved with stripe subscription system // we track the deals by invoice only. - if (status === 'paid') { + if (stripeInvoice.status === 'paid') { if (lookupKey.recurring === SubscriptionRecurring.Lifetime) { await this.saveLifetimeSubscription(knownInvoice); } else if (lookupKey.variant === SubscriptionVariant.Onetime) { @@ -282,45 +412,49 @@ export class UserSubscriptionManager implements SubscriptionManager { async saveLifetimeSubscription( knownInvoice: KnownStripeInvoice - ): Promise { + ): Promise { // cancel previous non-lifetime subscription - const prevSubscription = await this.db.userSubscription.findUnique({ + const prevSubscription = await this.db.subscription.findUnique({ where: { - userId_plan: { - userId: knownInvoice.userId, + targetId_plan: { + targetId: knownInvoice.userId, plan: SubscriptionPlan.Pro, }, }, }); - let subscription: UserSubscription; - if (prevSubscription && prevSubscription.stripeSubscriptionId) { - subscription = await this.db.userSubscription.update({ - where: { - id: prevSubscription.id, - }, - data: { - stripeScheduleId: null, - stripeSubscriptionId: null, - plan: knownInvoice.lookupKey.plan, - recurring: SubscriptionRecurring.Lifetime, - start: new Date(), - end: null, - status: SubscriptionStatus.Active, - nextBillAt: null, - }, - }); - - await this.stripe.subscriptions.cancel( - prevSubscription.stripeSubscriptionId, - { - prorate: true, - } - ); + let subscription: Subscription; + if (prevSubscription) { + if (prevSubscription.stripeSubscriptionId) { + subscription = await this.db.subscription.update({ + where: { + id: prevSubscription.id, + }, + data: { + stripeScheduleId: null, + stripeSubscriptionId: null, + plan: knownInvoice.lookupKey.plan, + recurring: SubscriptionRecurring.Lifetime, + start: new Date(), + end: null, + status: SubscriptionStatus.Active, + nextBillAt: null, + }, + }); + + await this.stripe.subscriptions.cancel( + prevSubscription.stripeSubscriptionId, + { + prorate: true, + } + ); + } else { + subscription = prevSubscription; + } } else { - subscription = await this.db.userSubscription.create({ + subscription = await this.db.subscription.create({ data: { - userId: knownInvoice.userId, + targetId: knownInvoice.userId, stripeSubscriptionId: null, plan: knownInvoice.lookupKey.plan, recurring: SubscriptionRecurring.Lifetime, @@ -343,12 +477,13 @@ export class UserSubscriptionManager implements SubscriptionManager { async saveOnetimePaymentSubscription( knownInvoice: KnownStripeInvoice - ): Promise { + ): Promise { + // TODO(@forehalo): identify whether the invoice has already been redeemed. const { userId, lookupKey } = knownInvoice; - const existingSubscription = await this.db.userSubscription.findUnique({ + const existingSubscription = await this.db.subscription.findUnique({ where: { - userId_plan: { - userId, + targetId_plan: { + targetId: userId, plan: lookupKey.plan, }, }, @@ -362,7 +497,7 @@ export class UserSubscriptionManager implements SubscriptionManager { 60 * 1000; - let subscription: UserSubscription; + let subscription: Subscription; // extends the subscription time if exists if (existingSubscription) { @@ -385,16 +520,16 @@ export class UserSubscriptionManager implements SubscriptionManager { ), }; - subscription = await this.db.userSubscription.update({ + subscription = await this.db.subscription.update({ where: { id: existingSubscription.id, }, data: period, }); } else { - subscription = await this.db.userSubscription.create({ + subscription = await this.db.subscription.create({ data: { - userId, + targetId: userId, stripeSubscriptionId: null, ...lookupKey, start: new Date(), diff --git a/packages/backend/server/src/plugins/payment/manager/workspace.ts b/packages/backend/server/src/plugins/payment/manager/workspace.ts new file mode 100644 index 0000000000000..5eba7ae036e82 --- /dev/null +++ b/packages/backend/server/src/plugins/payment/manager/workspace.ts @@ -0,0 +1,305 @@ +import { Injectable } from '@nestjs/common'; +import { PrismaClient, UserStripeCustomer } from '@prisma/client'; +import { omit, pick } from 'lodash-es'; +import Stripe from 'stripe'; +import { z } from 'zod'; + +import { + EventEmitter, + type EventPayload, + OnEvent, + SubscriptionAlreadyExists, + URLHelper, +} from '../../../fundamentals'; +import { + KnownStripeInvoice, + KnownStripePrice, + KnownStripeSubscription, + retriveLookupKeyFromStripeSubscription, + SubscriptionPlan, + SubscriptionRecurring, + SubscriptionStatus, +} from '../types'; +import { + CheckoutParams, + Invoice, + Subscription, + SubscriptionManager, +} from './common'; + +export const WorkspaceSubscriptionIdentity = z.object({ + plan: z.literal(SubscriptionPlan.Team), + workspaceId: z.string(), +}); + +export const WorkspaceSubscriptionCheckoutArgs = z.object({ + plan: z.literal(SubscriptionPlan.Team), + workspaceId: z.string(), + user: z.object({ + id: z.string(), + email: z.string(), + }), +}); + +@Injectable() +export class WorkspaceSubscriptionManager extends SubscriptionManager { + constructor( + stripe: Stripe, + db: PrismaClient, + private readonly url: URLHelper, + private readonly event: EventEmitter + ) { + super(stripe, db); + } + + filterPrices( + prices: KnownStripePrice[], + _customer?: UserStripeCustomer + ): KnownStripePrice[] { + return prices.filter( + price => price.lookupKey.plan === SubscriptionPlan.Team + ); + } + + async checkout( + { price }: KnownStripePrice, + params: z.infer, + args: z.infer + ) { + const subscription = await this.getSubscription({ + plan: SubscriptionPlan.Team, + workspaceId: args.workspaceId, + }); + + if (subscription) { + throw new SubscriptionAlreadyExists({ plan: SubscriptionPlan.Team }); + } + + const customer = await this.getOrCreateCustomer(args.user.id); + + const discounts = await (async () => { + if (params.coupon) { + const couponId = await this.getCouponFromPromotionCode( + params.coupon, + customer + ); + if (couponId) { + return { discounts: [{ coupon: couponId }] }; + } + } + + return { allow_promotion_codes: true }; + })(); + + const count = await this.db.workspaceUserPermission.count({ + where: { + workspaceId: args.workspaceId, + // @TODO(darksky): replace with [status: WorkspaceUserPermissionStatus.Accepted] + accepted: true, + }, + }); + + return this.stripe.checkout.sessions.create({ + line_items: [ + { + price: price.id, + quantity: count, + }, + ], + tax_id_collection: { + enabled: true, + }, + + ...discounts, + mode: 'subscription', + success_url: this.url.link(params.successCallbackLink), + customer: customer.stripeCustomerId, + subscription_data: { + metadata: { + workspaceId: args.workspaceId, + }, + }, + }); + } + + async saveStripeSubscription(subscription: KnownStripeSubscription) { + const { lookupKey, quantity, stripeSubscription } = subscription; + + const workspaceId = stripeSubscription.metadata.workspaceId; + + if (!workspaceId) { + throw new Error( + 'Workspace ID is required in workspace subscription metadata' + ); + } + + this.event.emit('workspace.subscription.activated', { + workspaceId, + plan: lookupKey.plan, + recurring: lookupKey.recurring, + quantity, + }); + + const subscriptionData = this.transformSubscription(subscription); + + return this.db.subscription.upsert({ + where: { + stripeSubscriptionId: stripeSubscription.id, + }, + update: { + quantity, + ...pick(subscriptionData, [ + 'status', + 'stripeScheduleId', + 'nextBillAt', + 'canceledAt', + ]), + }, + create: { + targetId: workspaceId, + quantity, + ...subscriptionData, + }, + }); + } + + async deleteStripeSubscription({ + lookupKey, + stripeSubscription, + }: KnownStripeSubscription) { + const workspaceId = stripeSubscription.metadata.workspaceId; + + if (!workspaceId) { + throw new Error( + 'Workspace ID is required in workspace subscription metadata' + ); + } + + const deleted = await this.db.subscription.deleteMany({ + where: { stripeSubscriptionId: stripeSubscription.id }, + }); + + if (deleted.count > 0) { + this.event.emit('workspace.subscription.canceled', { + workspaceId, + plan: lookupKey.plan, + recurring: lookupKey.recurring, + }); + } + } + + getSubscription(identity: z.infer) { + return this.db.subscription.findFirst({ + where: { + targetId: identity.workspaceId, + status: { + in: [SubscriptionStatus.Active, SubscriptionStatus.Trialing], + }, + }, + }); + } + + async cancelSubscription(subscription: Subscription) { + return await this.db.subscription.update({ + where: { + // @ts-expect-error checked outside + stripeSubscriptionId: subscription.stripeSubscriptionId, + }, + data: { + canceledAt: new Date(), + nextBillAt: null, + }, + }); + } + + resumeSubscription(subscription: Subscription): Promise { + return this.db.subscription.update({ + where: { + // @ts-expect-error checked outside + stripeSubscriptionId: subscription.stripeSubscriptionId, + }, + data: { + canceledAt: null, + nextBillAt: subscription.end, + }, + }); + } + + updateSubscriptionRecurring( + subscription: Subscription, + recurring: SubscriptionRecurring + ): Promise { + return this.db.subscription.update({ + where: { + // @ts-expect-error checked outside + stripeSubscriptionId: subscription.stripeSubscriptionId, + }, + data: { recurring }, + }); + } + + async saveInvoice(knownInvoice: KnownStripeInvoice): Promise { + const { metadata, stripeInvoice } = knownInvoice; + + const workspaceId = metadata.workspaceId; + + if (!workspaceId) { + throw new Error('Workspace ID is required in workspace invoice metadata'); + } + + const invoiceData = await this.transformInvoice(knownInvoice); + + return this.db.invoice.upsert({ + where: { + stripeInvoiceId: stripeInvoice.id, + }, + update: omit(invoiceData, 'stripeInvoiceId'), + create: { + targetId: workspaceId, + ...invoiceData, + }, + }); + } + + @OnEvent('workspace.members.updated') + async onMembersUpdated({ + workspaceId, + count, + }: EventPayload<'workspace.members.updated'>) { + const subscription = await this.getSubscription({ + plan: SubscriptionPlan.Team, + workspaceId, + }); + + if (!subscription || !subscription.stripeSubscriptionId) { + return; + } + const stripeSubscription = await this.stripe.subscriptions.retrieve( + subscription.stripeSubscriptionId + ); + + const lookupKey = + retriveLookupKeyFromStripeSubscription(stripeSubscription); + + await this.stripe.subscriptions.update(stripeSubscription.id, { + items: [ + { + id: stripeSubscription.items.data[0].id, + quantity: count, + }, + ], + payment_behavior: 'pending_if_incomplete', + proration_behavior: + lookupKey?.recurring === SubscriptionRecurring.Yearly + ? 'always_invoice' + : 'none', + }); + + if (subscription.stripeScheduleId) { + const schedule = await this.scheduleManager.fromSchedule( + subscription.stripeScheduleId + ); + await schedule.updateQuantity(count); + } + } +} diff --git a/packages/backend/server/src/plugins/payment/resolver.ts b/packages/backend/server/src/plugins/payment/resolver.ts index 6a2cabecef7d0..d71399fba3de6 100644 --- a/packages/backend/server/src/plugins/payment/resolver.ts +++ b/packages/backend/server/src/plugins/payment/resolver.ts @@ -12,15 +12,23 @@ import { ResolveField, Resolver, } from '@nestjs/graphql'; -import type { User, UserSubscription } from '@prisma/client'; +import type { User } from '@prisma/client'; import { PrismaClient } from '@prisma/client'; +import { GraphQLJSONObject } from 'graphql-scalars'; import { groupBy } from 'lodash-es'; +import { z } from 'zod'; import { CurrentUser, Public } from '../../core/auth'; +import { Permission, PermissionService } from '../../core/permission'; import { UserType } from '../../core/user'; -import { AccessDenied, FailedToCheckout, URLHelper } from '../../fundamentals'; -import { Invoice, Subscription } from './manager'; -import { SubscriptionService } from './service'; +import { WorkspaceType } from '../../core/workspaces'; +import { + AccessDenied, + FailedToCheckout, + WorkspaceIdRequiredToUpdateTeamSubscription, +} from '../../fundamentals'; +import { Invoice, Subscription, WorkspaceSubscriptionManager } from './manager'; +import { CheckoutParams, SubscriptionService } from './service'; import { InvoiceStatus, SubscriptionPlan, @@ -57,7 +65,7 @@ class SubscriptionPrice { } @ObjectType() -export class SubscriptionType implements Subscription { +export class SubscriptionType implements Partial { @Field(() => SubscriptionPlan, { description: "The 'Free' plan just exists to be a placeholder and for the type convenience of frontend.\nThere won't actually be a subscription with plan 'Free'", @@ -107,7 +115,7 @@ export class SubscriptionType implements Subscription { } @ObjectType() -export class InvoiceType implements Invoice { +export class InvoiceType implements Partial { @Field() currency!: string; @@ -138,7 +146,7 @@ export class InvoiceType implements Invoice { nullable: true, deprecationReason: 'removed', }) - stripeInvoiceId!: string | null; + stripeInvoiceId?: string; @Field(() => SubscriptionPlan, { nullable: true, @@ -154,7 +162,7 @@ export class InvoiceType implements Invoice { } @InputType() -class CreateCheckoutSessionInput { +class CreateCheckoutSessionInput implements z.infer { @Field(() => SubscriptionRecurring, { nullable: true, defaultValue: SubscriptionRecurring.Yearly, @@ -170,7 +178,7 @@ class CreateCheckoutSessionInput { @Field(() => SubscriptionVariant, { nullable: true, }) - variant?: SubscriptionVariant; + variant!: SubscriptionVariant | null; @Field(() => String, { nullable: true }) coupon!: string | null; @@ -180,17 +188,17 @@ class CreateCheckoutSessionInput { @Field(() => String, { nullable: true, - deprecationReason: 'use header `Idempotency-Key`', + deprecationReason: 'not required anymore', }) idempotencyKey?: string; + + @Field(() => GraphQLJSONObject, { nullable: true }) + args!: { workspaceId?: string }; } @Resolver(() => SubscriptionType) export class SubscriptionResolver { - constructor( - private readonly service: SubscriptionService, - private readonly url: URLHelper - ) {} + constructor(private readonly service: SubscriptionService) {} @Public() @Query(() => [SubscriptionPrice]) @@ -232,7 +240,11 @@ export class SubscriptionResolver { } // extend it when new plans are added - const fixedPlans = [SubscriptionPlan.Pro, SubscriptionPlan.AI]; + const fixedPlans = [ + SubscriptionPlan.Pro, + SubscriptionPlan.AI, + SubscriptionPlan.Team, + ]; return fixedPlans.reduce((prices, plan) => { const price = findPrice(plan); @@ -255,26 +267,19 @@ export class SubscriptionResolver { async createCheckoutSession( @CurrentUser() user: CurrentUser, @Args({ name: 'input', type: () => CreateCheckoutSessionInput }) - input: CreateCheckoutSessionInput, - @Headers('idempotency-key') idempotencyKey?: string + input: CreateCheckoutSessionInput ) { - const session = await this.service.checkout({ + const session = await this.service.checkout(input, { + plan: input.plan as any, user, - lookupKey: { - plan: input.plan, - recurring: input.recurring, - variant: input.variant, - }, - promotionCode: input.coupon, - redirectUrl: this.url.link(input.successCallbackLink), - idempotencyKey, + workspaceId: input.args?.workspaceId, }); if (!session.url) { throw new FailedToCheckout(); } - return session.url; + return session; } @Mutation(() => String, { @@ -294,6 +299,8 @@ export class SubscriptionResolver { defaultValue: SubscriptionPlan.Pro, }) plan: SubscriptionPlan, + @Args({ name: 'workspaceId', type: () => String, nullable: true }) + workspaceId: string | null, @Headers('idempotency-key') idempotencyKey?: string, @Args('idempotencyKey', { type: () => String, @@ -302,7 +309,25 @@ export class SubscriptionResolver { }) _?: string ) { - return this.service.cancelSubscription(user.id, plan, idempotencyKey); + if (plan === SubscriptionPlan.Team) { + if (!workspaceId) { + throw new WorkspaceIdRequiredToUpdateTeamSubscription(); + } + + return this.service.cancelSubscription( + { workspaceId, plan }, + idempotencyKey + ); + } + + return this.service.cancelSubscription( + { + targetId: user.id, + // @ts-expect-error exam inside + plan, + }, + idempotencyKey + ); } @Mutation(() => SubscriptionType) @@ -315,6 +340,8 @@ export class SubscriptionResolver { defaultValue: SubscriptionPlan.Pro, }) plan: SubscriptionPlan, + @Args({ name: 'workspaceId', type: () => String, nullable: true }) + workspaceId: string | null, @Headers('idempotency-key') idempotencyKey?: string, @Args('idempotencyKey', { type: () => String, @@ -323,14 +350,30 @@ export class SubscriptionResolver { }) _?: string ) { - return this.service.resumeSubscription(user.id, plan, idempotencyKey); + if (plan === SubscriptionPlan.Team) { + if (!workspaceId) { + throw new WorkspaceIdRequiredToUpdateTeamSubscription(); + } + + return this.service.resumeSubscription( + { workspaceId, plan }, + idempotencyKey + ); + } + + return this.service.resumeSubscription( + { + targetId: user.id, + // @ts-expect-error exam inside + plan, + }, + idempotencyKey + ); } @Mutation(() => SubscriptionType) async updateSubscriptionRecurring( @CurrentUser() user: CurrentUser, - @Args({ name: 'recurring', type: () => SubscriptionRecurring }) - recurring: SubscriptionRecurring, @Args({ name: 'plan', type: () => SubscriptionPlan, @@ -338,6 +381,10 @@ export class SubscriptionResolver { defaultValue: SubscriptionPlan.Pro, }) plan: SubscriptionPlan, + @Args({ name: 'workspaceId', type: () => String, nullable: true }) + workspaceId: string | null, + @Args({ name: 'recurring', type: () => SubscriptionRecurring }) + recurring: SubscriptionRecurring, @Headers('idempotency-key') idempotencyKey?: string, @Args('idempotencyKey', { type: () => String, @@ -346,9 +393,24 @@ export class SubscriptionResolver { }) _?: string ) { + if (plan === SubscriptionPlan.Team) { + if (!workspaceId) { + throw new WorkspaceIdRequiredToUpdateTeamSubscription(); + } + + return this.service.updateSubscriptionRecurring( + { workspaceId, plan }, + recurring, + idempotencyKey + ); + } + return this.service.updateSubscriptionRecurring( - user.id, - plan, + { + userId: user.id, + // @ts-expect-error exam inside + plan, + }, recurring, idempotencyKey ); @@ -363,14 +425,14 @@ export class UserSubscriptionResolver { async subscriptions( @CurrentUser() me: User, @Parent() user: User - ): Promise { + ): Promise { if (me.id !== user.id) { throw new AccessDenied(); } - const subscriptions = await this.db.userSubscription.findMany({ + const subscriptions = await this.db.subscription.findMany({ where: { - userId: user.id, + targetId: user.id, status: SubscriptionStatus.Active, }, }); @@ -389,6 +451,16 @@ export class UserSubscriptionResolver { return subscriptions; } + @ResolveField(() => Int, { + name: 'invoiceCount', + description: 'Get user invoice count', + }) + async invoiceCount(@CurrentUser() user: CurrentUser) { + return this.db.invoice.count({ + where: { targetId: user.id }, + }); + } + @ResolveField(() => [InvoiceType]) async invoices( @CurrentUser() me: User, @@ -401,14 +473,72 @@ export class UserSubscriptionResolver { throw new AccessDenied(); } - return this.db.userInvoice.findMany({ + return this.db.invoice.findMany({ where: { - userId: user.id, + targetId: user.id, + }, + take, + skip, + orderBy: { + createdAt: 'desc', + }, + }); + } +} + +@Resolver(() => WorkspaceType) +export class WorkspaceSubscriptionResolver { + constructor( + private readonly service: WorkspaceSubscriptionManager, + private readonly db: PrismaClient, + private readonly permission: PermissionService + ) {} + + @ResolveField(() => SubscriptionType, { + nullable: true, + description: 'The team subscription of the workspace, if exists.', + }) + async subscription(@Parent() workspace: WorkspaceType) { + return this.service.getSubscription({ + plan: SubscriptionPlan.Team, + workspaceId: workspace.id, + }); + } + + @ResolveField(() => Int, { + name: 'invoiceCount', + description: 'Get user invoice count', + }) + async invoiceCount( + @CurrentUser() me: CurrentUser, + @Parent() workspace: WorkspaceType + ) { + await this.permission.checkWorkspace(workspace.id, me.id, Permission.Owner); + return this.db.invoice.count({ + where: { + targetId: workspace.id, + }, + }); + } + + @ResolveField(() => [InvoiceType]) + async invoices( + @CurrentUser() me: CurrentUser, + @Parent() workspace: WorkspaceType, + @Args('take', { type: () => Int, nullable: true, defaultValue: 8 }) + take: number, + @Args('skip', { type: () => Int, nullable: true }) skip?: number + ) { + await this.permission.checkWorkspace(workspace.id, me.id, Permission.Owner); + + return this.db.invoice.findMany({ + where: { + targetId: workspace.id, }, take, skip, orderBy: { - id: 'desc', + createdAt: 'desc', }, }); } diff --git a/packages/backend/server/src/plugins/payment/schedule.ts b/packages/backend/server/src/plugins/payment/schedule.ts index f0714cf1c5e10..b4c2abb35fb3d 100644 --- a/packages/backend/server/src/plugins/payment/schedule.ts +++ b/packages/backend/server/src/plugins/payment/schedule.ts @@ -101,7 +101,7 @@ export class ScheduleManager { items: [ { price: this.currentPhase.items[0].price as string, - quantity: 1, + quantity: this.currentPhase.items[0].quantity, }, ], coupon: (this.currentPhase.coupon as string | null) ?? undefined, @@ -143,10 +143,9 @@ export class ScheduleManager { items: [ { price: this.currentPhase.items[0].price as string, - quantity: 1, + quantity: this.currentPhase.items[0].quantity, }, ], - coupon: (this.currentPhase.coupon as string | null) ?? undefined, start_date: this.currentPhase.start_date, end_date: this.currentPhase.end_date, metadata: { @@ -161,7 +160,7 @@ export class ScheduleManager { items: [ { price: this.currentPhase.metadata.next_price, - quantity: 1, + quantity: this.currentPhase.items[0].quantity, }, ], coupon: this.currentPhase.metadata.next_coupon || undefined, @@ -212,6 +211,7 @@ export class ScheduleManager { items: [ { price: this.currentPhase.items[0].price as string, + quantity: this.currentPhase.items[0].quantity, }, ], start_date: this.currentPhase.start_date, @@ -221,6 +221,7 @@ export class ScheduleManager { items: [ { price: price, + quantity: this.currentPhase.items[0].quantity, }, ], }, @@ -230,4 +231,31 @@ export class ScheduleManager { ); } } + + async updateQuantity(quantity: number, idempotencyKey?: string) { + if (!this._schedule) { + throw new Error('No schedule'); + } + + if (!this.isActive || !this.currentPhase) { + throw new Error('Unexpected subscription schedule status'); + } + + await this.stripe.subscriptionSchedules.update( + this._schedule.id, + { + phases: this._schedule.phases.map(phase => ({ + items: [ + { + price: phase.items[0].price as string, + quantity, + }, + ], + start_date: phase.start_date, + end_date: phase.end_date, + })), + }, + { idempotencyKey } + ); + } } diff --git a/packages/backend/server/src/plugins/payment/service.ts b/packages/backend/server/src/plugins/payment/service.ts index 2de5f031d8e07..facdb19df1eb0 100644 --- a/packages/backend/server/src/plugins/payment/service.ts +++ b/packages/backend/server/src/plugins/payment/service.ts @@ -1,12 +1,8 @@ import { Injectable, Logger } from '@nestjs/common'; -import type { - User, - UserInvoice, - UserStripeCustomer, - UserSubscription, -} from '@prisma/client'; +import type { User, UserStripeCustomer } from '@prisma/client'; import { PrismaClient } from '@prisma/client'; import Stripe from 'stripe'; +import { z } from 'zod'; import { CurrentUser } from '../../core/auth'; import { FeatureManagementService } from '../../core/features'; @@ -17,30 +13,56 @@ import { Config, CustomerPortalCreateFailed, InternalServerError, + InvalidCheckoutParameters, + InvalidSubscriptionParameters, OnEvent, SameSubscriptionRecurring, - SubscriptionAlreadyExists, SubscriptionExpired, SubscriptionHasBeenCanceled, + SubscriptionHasNotBeenCanceled, SubscriptionNotExists, SubscriptionPlanNotFound, + UnsupportedSubscriptionPlan, UserNotFound, } from '../../fundamentals'; -import { UserSubscriptionManager } from './manager'; +import { + CheckoutParams, + Invoice, + Subscription, + SubscriptionManager, + UserSubscriptionCheckoutArgs, + UserSubscriptionIdentity, + UserSubscriptionManager, + WorkspaceSubscriptionCheckoutArgs, + WorkspaceSubscriptionIdentity, + WorkspaceSubscriptionManager, +} from './manager'; import { ScheduleManager } from './schedule'; import { encodeLookupKey, KnownStripeInvoice, KnownStripePrice, + KnownStripeSubscription, LookupKey, retriveLookupKeyFromStripePrice, retriveLookupKeyFromStripeSubscription, SubscriptionPlan, SubscriptionRecurring, SubscriptionStatus, - SubscriptionVariant, } from './types'; +export const CheckoutExtraArgs = z.union([ + UserSubscriptionCheckoutArgs, + WorkspaceSubscriptionCheckoutArgs, +]); + +export const SubscriptionIdentity = z.union([ + UserSubscriptionIdentity, + WorkspaceSubscriptionIdentity, +]); + +export { CheckoutParams }; + @Injectable() export class SubscriptionService { private readonly logger = new Logger(SubscriptionService.name); @@ -52,143 +74,86 @@ export class SubscriptionService { private readonly db: PrismaClient, private readonly feature: FeatureManagementService, private readonly user: UserService, - private readonly userManager: UserSubscriptionManager + private readonly userManager: UserSubscriptionManager, + private readonly workspaceManager: WorkspaceSubscriptionManager ) {} - async listPrices(user?: CurrentUser): Promise { - const customer = user ? await this.getOrCreateCustomer(user) : undefined; + private select(plan: SubscriptionPlan): SubscriptionManager { + switch (plan) { + case SubscriptionPlan.Team: + return this.workspaceManager; + case SubscriptionPlan.Pro: + case SubscriptionPlan.AI: + return this.userManager; + default: + throw new UnsupportedSubscriptionPlan({ plan }); + } + } - // TODO(@forehalo): cache - const prices = await this.stripe.prices.list({ - active: true, - limit: 100, - }); + async listPrices(user?: CurrentUser): Promise { + const prices = await this.listStripePrices(); - return this.userManager.filterPrices( - prices.data - .map(price => this.parseStripePrice(price)) - .filter(Boolean) as KnownStripePrice[], - customer - ); + const customer = user + ? await this.getOrCreateCustomer({ + userId: user.id, + userEmail: user.email, + }) + : undefined; + + return [ + ...(await this.userManager.filterPrices(prices, customer)), + ...this.workspaceManager.filterPrices(prices, customer), + ]; } - async checkout({ - user, - lookupKey, - promotionCode, - redirectUrl, - idempotencyKey, - }: { - user: CurrentUser; - lookupKey: LookupKey; - promotionCode?: string | null; - redirectUrl: string; - idempotencyKey?: string; - }) { + async checkout( + params: z.infer, + args: z.infer + ) { + const { plan, recurring, variant } = params; + if ( this.config.deploy && this.config.affine.canary && - !this.feature.isStaff(user.email) + !this.feature.isStaff(args.user.email) ) { throw new ActionForbidden(); } - const currentSubscription = await this.userManager.getSubscription( - user.id, - lookupKey.plan - ); - - if ( - currentSubscription && - // do not allow to re-subscribe unless - !( - /* current subscription is a onetime subscription and so as the one that's checking out */ - ( - (currentSubscription.variant === SubscriptionVariant.Onetime && - lookupKey.variant === SubscriptionVariant.Onetime) || - /* current subscription is normal subscription and is checking-out a lifetime subscription */ - (currentSubscription.recurring !== SubscriptionRecurring.Lifetime && - currentSubscription.variant !== SubscriptionVariant.Onetime && - lookupKey.recurring === SubscriptionRecurring.Lifetime) - ) - ) - ) { - throw new SubscriptionAlreadyExists({ plan: lookupKey.plan }); - } - - const price = await this.getPrice(lookupKey); - const customer = await this.getOrCreateCustomer(user); - - const priceAndAutoCoupon = price - ? await this.userManager.validatePrice(price, customer) - : null; + const price = await this.getPrice({ + plan, + recurring, + variant: variant ?? null, + }); - if (!priceAndAutoCoupon) { + if (!price) { throw new SubscriptionPlanNotFound({ - plan: lookupKey.plan, - recurring: lookupKey.recurring, + plan, + recurring, }); } - let discounts: Stripe.Checkout.SessionCreateParams['discounts'] = []; + const manager = this.select(plan); + const result = CheckoutExtraArgs.safeParse(args); - if (priceAndAutoCoupon.coupon) { - discounts = [{ coupon: priceAndAutoCoupon.coupon }]; - } else if (promotionCode) { - const coupon = await this.getCouponFromPromotionCode( - promotionCode, - customer - ); - if (coupon) { - discounts = [{ coupon }]; - } + if (!result.success) { + throw new InvalidCheckoutParameters(); } - return await this.stripe.checkout.sessions.create( - { - line_items: [ - { - price: priceAndAutoCoupon.price.price.id, - quantity: 1, - }, - ], - tax_id_collection: { - enabled: true, - }, - // discount - ...(discounts.length ? { discounts } : { allow_promotion_codes: true }), - // mode: 'subscription' or 'payment' for lifetime and onetime payment - ...(lookupKey.recurring === SubscriptionRecurring.Lifetime || - lookupKey.variant === SubscriptionVariant.Onetime - ? { - mode: 'payment', - invoice_creation: { - enabled: true, - }, - } - : { - mode: 'subscription', - }), - success_url: redirectUrl, - customer: customer.stripeCustomerId, - customer_update: { - address: 'auto', - name: 'auto', - }, - }, - { idempotencyKey } - ); + return manager.checkout(price, params, args); } async cancelSubscription( - userId: string, - plan: SubscriptionPlan, + identity: z.infer, idempotencyKey?: string - ): Promise { - const subscription = await this.userManager.getSubscription(userId, plan); + ): Promise { + this.assertSubscriptionIdentity(identity); + + const manager = this.select(identity.plan); + const subscription = await manager.getSubscription(identity); if (!subscription) { - throw new SubscriptionNotExists({ plan }); + throw new SubscriptionNotExists({ plan: identity.plan }); } if (!subscription.stripeSubscriptionId) { @@ -202,7 +167,7 @@ export class SubscriptionService { } // update the subscription in db optimistically - const newSubscription = this.userManager.cancelSubscription(subscription); + const newSubscription = manager.cancelSubscription(subscription); // should release the schedule first if (subscription.stripeScheduleId) { @@ -224,18 +189,21 @@ export class SubscriptionService { } async resumeSubscription( - userId: string, - plan: SubscriptionPlan, + identity: z.infer, idempotencyKey?: string - ): Promise { - const subscription = await this.userManager.getSubscription(userId, plan); + ): Promise { + this.assertSubscriptionIdentity(identity); + + const manager = this.select(identity.plan); + + const subscription = await manager.getSubscription(identity); if (!subscription) { - throw new SubscriptionNotExists({ plan }); + throw new SubscriptionNotExists({ plan: identity.plan }); } if (!subscription.canceledAt) { - throw new SubscriptionHasBeenCanceled(); + throw new SubscriptionHasNotBeenCanceled(); } if (!subscription.stripeSubscriptionId || !subscription.end) { @@ -249,8 +217,7 @@ export class SubscriptionService { } // update the subscription in db optimistically - const newSubscription = - await this.userManager.resumeSubscription(subscription); + const newSubscription = await manager.resumeSubscription(subscription); if (subscription.stripeScheduleId) { const manager = await this.scheduleManager.fromSchedule( @@ -269,15 +236,17 @@ export class SubscriptionService { } async updateSubscriptionRecurring( - userId: string, - plan: SubscriptionPlan, + identity: z.infer, recurring: SubscriptionRecurring, idempotencyKey?: string - ): Promise { - const subscription = await this.userManager.getSubscription(userId, plan); + ): Promise { + this.assertSubscriptionIdentity(identity); + + const manager = this.select(identity.plan); + const subscription = await manager.getSubscription(identity); if (!subscription) { - throw new SubscriptionNotExists({ plan }); + throw new SubscriptionNotExists({ plan: identity.plan }); } if (!subscription.stripeSubscriptionId) { @@ -293,25 +262,29 @@ export class SubscriptionService { } const price = await this.getPrice({ - plan, + plan: identity.plan, recurring, + variant: null, }); if (!price) { - throw new SubscriptionPlanNotFound({ plan, recurring }); + throw new SubscriptionPlanNotFound({ + plan: identity.plan, + recurring, + }); } // update the subscription in db optimistically - const newSubscription = this.userManager.updateSubscriptionRecurring( + const newSubscription = manager.updateSubscriptionRecurring( subscription, recurring ); - const manager = await this.scheduleManager.fromSubscription( + const scheduleManager = await this.scheduleManager.fromSubscription( subscription.stripeSubscriptionId ); - await manager.update(price.price.id, idempotencyKey); + await scheduleManager.update(price.price.id, idempotencyKey); return newSubscription; } @@ -339,14 +312,14 @@ export class SubscriptionService { } } - async saveStripeInvoice(stripeInvoice: Stripe.Invoice): Promise { + async saveStripeInvoice(stripeInvoice: Stripe.Invoice): Promise { const knownInvoice = await this.parseStripeInvoice(stripeInvoice); if (!knownInvoice) { throw new InternalServerError('Failed to parse stripe invoice.'); } - return this.userManager.saveInvoice(knownInvoice); + return this.select(knownInvoice.lookupKey.plan).saveInvoice(knownInvoice); } async saveStripeSubscription(subscription: Stripe.Subscription) { @@ -360,10 +333,12 @@ export class SubscriptionService { subscription.status === SubscriptionStatus.Active || subscription.status === SubscriptionStatus.Trialing; + const manager = this.select(knownSubscription.lookupKey.plan); + if (!isPlanActive) { - await this.userManager.deleteSubscription(knownSubscription); + await manager.deleteStripeSubscription(knownSubscription); } else { - await this.userManager.saveSubscription(knownSubscription); + await manager.saveStripeSubscription(knownSubscription); } } @@ -374,19 +349,26 @@ export class SubscriptionService { throw new InternalServerError('Failed to parse stripe subscription.'); } - await this.userManager.deleteSubscription(knownSubscription); + const manager = this.select(knownSubscription.lookupKey.plan); + await manager.deleteStripeSubscription(knownSubscription); } - async getOrCreateCustomer(user: CurrentUser): Promise { + async getOrCreateCustomer({ + userId, + userEmail, + }: { + userId: string; + userEmail: string; + }): Promise { let customer = await this.db.userStripeCustomer.findUnique({ where: { - userId: user.id, + userId, }, }); if (!customer) { const stripeCustomersList = await this.stripe.customers.list({ - email: user.email, + email: userEmail, limit: 1, }); @@ -395,13 +377,13 @@ export class SubscriptionService { stripeCustomer = stripeCustomersList.data[0]; } else { stripeCustomer = await this.stripe.customers.create({ - email: user.email, + email: userEmail, }); } customer = await this.db.userStripeCustomer.create({ data: { - userId: user.id, + userId, stripeCustomerId: stripeCustomer.id, }, }); @@ -467,6 +449,17 @@ export class SubscriptionService { return user.id; } + private async listStripePrices(): Promise { + const prices = await this.stripe.prices.list({ + active: true, + limit: 100, + }); + + return prices.data + .map(price => this.parseStripePrice(price)) + .filter(Boolean) as KnownStripePrice[]; + } + private async getPrice( lookupKey: LookupKey ): Promise { @@ -485,35 +478,6 @@ export class SubscriptionService { : null; } - private async getCouponFromPromotionCode( - userFacingPromotionCode: string, - customer: UserStripeCustomer - ) { - const list = await this.stripe.promotionCodes.list({ - code: userFacingPromotionCode, - active: true, - limit: 1, - }); - - const code = list.data[0]; - if (!code) { - return null; - } - - // the coupons are always bound to products, we need to check it first - // but the logic would be too complicated, and stripe will complain if the code is not applicable when checking out - // It's safe to skip the check here - // code.coupon.applies_to.products.forEach() - - // check if the code is bound to a specific customer - return !code.customer || - (typeof code.customer === 'string' - ? code.customer === customer.stripeCustomerId - : code.customer.id === customer.stripeCustomerId) - ? code.coupon.id - : null; - } - private async parseStripeInvoice( invoice: Stripe.Invoice ): Promise { @@ -549,10 +513,13 @@ export class SubscriptionService { userId: user.id, stripeInvoice: invoice, lookupKey, + metadata: invoice.subscription_details?.metadata ?? {}, }; } - private async parseStripeSubscription(subscription: Stripe.Subscription) { + private async parseStripeSubscription( + subscription: Stripe.Subscription + ): Promise { const lookupKey = retriveLookupKeyFromStripeSubscription(subscription); if (!lookupKey) { @@ -569,6 +536,8 @@ export class SubscriptionService { userId, lookupKey, stripeSubscription: subscription, + quantity: subscription.items.data[0]?.quantity ?? 1, + metadata: subscription.metadata, }; } @@ -582,4 +551,14 @@ export class SubscriptionService { } : null; } + + private assertSubscriptionIdentity( + args: z.infer + ) { + const result = SubscriptionIdentity.safeParse(args); + + if (!result.success) { + throw new InvalidSubscriptionParameters(); + } + } } diff --git a/packages/backend/server/src/plugins/payment/types.ts b/packages/backend/server/src/plugins/payment/types.ts index 196c57d98e5c0..efb09089b9c66 100644 --- a/packages/backend/server/src/plugins/payment/types.ts +++ b/packages/backend/server/src/plugins/payment/types.ts @@ -1,4 +1,4 @@ -import type { User } from '@prisma/client'; +import type { User, Workspace } from '@prisma/client'; import Stripe from 'stripe'; import type { Payload } from '../../fundamentals/event/def'; @@ -64,12 +64,31 @@ declare module '../../fundamentals/event/def' { }>; }; } + + interface WorkspaceEvents { + subscription: { + activated: Payload<{ + workspaceId: Workspace['id']; + plan: SubscriptionPlan; + recurring: SubscriptionRecurring; + quantity: number; + }>; + canceled: Payload<{ + workspaceId: Workspace['id']; + plan: SubscriptionPlan; + recurring: SubscriptionRecurring; + }>; + }; + members: { + updated: Payload<{ workspaceId: Workspace['id']; count: number }>; + }; + } } export interface LookupKey { plan: SubscriptionPlan; recurring: SubscriptionRecurring; - variant?: SubscriptionVariant; + variant: SubscriptionVariant | null; } export interface KnownStripeInvoice { @@ -87,6 +106,11 @@ export interface KnownStripeInvoice { * The invoice object from Stripe. */ stripeInvoice: Stripe.Invoice; + + /** + * The metadata of the subscription related to the invoice. + */ + metadata: Record; } export interface KnownStripeSubscription { @@ -104,6 +128,16 @@ export interface KnownStripeSubscription { * The subscription object from Stripe. */ stripeSubscription: Stripe.Subscription; + + /** + * The quantity of the subscription items. + */ + quantity: number; + + /** + * The metadata of the subscription. + */ + metadata: Record; } export interface KnownStripePrice { @@ -167,7 +201,7 @@ export function decodeLookupKey(key: string): LookupKey | null { return { plan: plan as SubscriptionPlan, recurring: recurring as SubscriptionRecurring, - variant: variant as SubscriptionVariant | undefined, + variant: variant as SubscriptionVariant, }; } diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index 505ac394a1aeb..a0b0d9a980415 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -140,6 +140,7 @@ input CreateChatSessionInput { } input CreateCheckoutSessionInput { + args: JSONObject coupon: String idempotencyKey: String plan: SubscriptionPlan = Pro @@ -208,7 +209,7 @@ type EditorType { name: String! } -union ErrorDataUnion = AlreadyInSpaceDataType | BlobNotFoundDataType | CopilotMessageNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderSideErrorDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidEmailDataType | InvalidHistoryTimestampDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInSpaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SpaceAccessDeniedDataType | SpaceNotFoundDataType | SpaceOwnerNotFoundDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | VersionRejectedDataType +union ErrorDataUnion = AlreadyInSpaceDataType | BlobNotFoundDataType | CopilotMessageNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderSideErrorDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidHistoryTimestampDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInSpaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SpaceAccessDeniedDataType | SpaceNotFoundDataType | SpaceOwnerNotFoundDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | UnsupportedSubscriptionPlanDataType | VersionRejectedDataType enum ErrorNames { ACCESS_DENIED @@ -246,12 +247,14 @@ enum ErrorNames { FAILED_TO_SAVE_UPDATES FAILED_TO_UPSERT_SNAPSHOT INTERNAL_SERVER_ERROR + INVALID_CHECKOUT_PARAMETERS INVALID_EMAIL INVALID_EMAIL_TOKEN INVALID_HISTORY_TIMESTAMP INVALID_OAUTH_CALLBACK_STATE INVALID_PASSWORD_LENGTH INVALID_RUNTIME_CONFIG_TYPE + INVALID_SUBSCRIPTION_PARAMETERS LINK_EXPIRED MAILER_SERVICE_IS_NOT_CONFIGURED MEMBER_QUOTA_EXCEEDED @@ -273,14 +276,18 @@ enum ErrorNames { SUBSCRIPTION_ALREADY_EXISTS SUBSCRIPTION_EXPIRED SUBSCRIPTION_HAS_BEEN_CANCELED + SUBSCRIPTION_HAS_NOT_BEEN_CANCELED SUBSCRIPTION_NOT_EXISTS SUBSCRIPTION_PLAN_NOT_FOUND TOO_MANY_REQUEST UNKNOWN_OAUTH_PROVIDER UNSPLASH_IS_NOT_CONFIGURED + UNSUPPORTED_SUBSCRIPTION_PLAN USER_AVATAR_NOT_FOUND USER_NOT_FOUND VERSION_REJECTED + WORKSPACE_ID_REQUIRED_FOR_TEAM_SUBSCRIPTION + WORKSPACE_ID_REQUIRED_TO_UPDATE_TEAM_SUBSCRIPTION WRONG_SIGN_IN_CREDENTIALS WRONG_SIGN_IN_METHOD } @@ -315,10 +322,6 @@ type HumanReadableQuotaType { storageQuota: String! } -type InvalidEmailDataType { - email: String! -} - type InvalidHistoryTimestampDataType { timestamp: String! } @@ -444,7 +447,7 @@ type MissingOauthQueryParameterDataType { type Mutation { acceptInviteById(inviteId: String!, sendAcceptMail: Boolean, workspaceId: String!): Boolean! addWorkspaceFeature(feature: FeatureType!, workspaceId: String!): Int! - cancelSubscription(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro): SubscriptionType! + cancelSubscription(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro, workspaceId: String): SubscriptionType! changeEmail(email: String!, token: String!): UserType! changePassword(newPassword: String!, token: String!, userId: String): Boolean! @@ -491,7 +494,7 @@ type Mutation { """Remove user avatar""" removeAvatar: RemoveAvatar! removeWorkspaceFeature(feature: FeatureType!, workspaceId: String!): Int! - resumeSubscription(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro): SubscriptionType! + resumeSubscription(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro, workspaceId: String): SubscriptionType! revoke(userId: String!, workspaceId: String!): Boolean! revokePage(pageId: String!, workspaceId: String!): Boolean! @deprecated(reason: "use revokePublicPage") revokePublicPage(pageId: String!, workspaceId: String!): WorkspacePage! @@ -513,7 +516,7 @@ type Mutation { """update multiple server runtime configurable settings""" updateRuntimeConfigs(updates: JSONObject!): [ServerRuntimeConfigType!]! - updateSubscriptionRecurring(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro, recurring: SubscriptionRecurring!): SubscriptionType! + updateSubscriptionRecurring(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro, recurring: SubscriptionRecurring!, workspaceId: String): SubscriptionType! """Update a user""" updateUser(id: String!, input: ManageUserInput!): UserType! @@ -814,6 +817,10 @@ type UnknownOauthProviderDataType { name: String! } +type UnsupportedSubscriptionPlanDataType { + plan: String! +} + input UpdateUserInput { """User name""" name: String @@ -929,6 +936,10 @@ type WorkspaceType { """is current workspace initialized""" initialized: Boolean! + """Get user invoice count""" + invoiceCount: Int! + invoices(skip: Int, take: Int = 8): [InvoiceType!]! + """member count of workspace""" memberCount: Int! @@ -958,6 +969,9 @@ type WorkspaceType { """Shared pages of workspace""" sharedPages: [String!]! @deprecated(reason: "use WorkspaceType.publicPages") + + """The team subscription of the workspace, if exists.""" + subscription: SubscriptionType } type tokenType { diff --git a/packages/backend/server/tests/payment/service.spec.ts b/packages/backend/server/tests/payment/service.spec.ts index 2692ccc8a563a..2d8eee2055545 100644 --- a/packages/backend/server/tests/payment/service.spec.ts +++ b/packages/backend/server/tests/payment/service.spec.ts @@ -14,7 +14,7 @@ import { FeatureManagementService, } from '../../src/core/features'; import { EventEmitter } from '../../src/fundamentals'; -import { Config, ConfigModule } from '../../src/fundamentals/config'; +import { Config, ConfigModule, Runtime } from '../../src/fundamentals/config'; import { SubscriptionService } from '../../src/plugins/payment/service'; import { CouponType, @@ -32,6 +32,8 @@ const PRO_LIFETIME = `${SubscriptionPlan.Pro}_${SubscriptionRecurring.Lifetime}` const PRO_EA_YEARLY = `${SubscriptionPlan.Pro}_${SubscriptionRecurring.Yearly}_${SubscriptionVariant.EA}`; const AI_YEARLY = `${SubscriptionPlan.AI}_${SubscriptionRecurring.Yearly}`; const AI_YEARLY_EA = `${SubscriptionPlan.AI}_${SubscriptionRecurring.Yearly}_${SubscriptionVariant.EA}`; +const TEAM_MONTHLY = `${SubscriptionPlan.Team}_${SubscriptionRecurring.Monthly}`; +const TEAM_YEARLY = `${SubscriptionPlan.Team}_${SubscriptionRecurring.Yearly}`; // prices for code redeeming const PRO_MONTHLY_CODE = `${SubscriptionPlan.Pro}_${SubscriptionRecurring.Monthly}_${SubscriptionVariant.Onetime}`; const PRO_YEARLY_CODE = `${SubscriptionPlan.Pro}_${SubscriptionRecurring.Yearly}_${SubscriptionVariant.Onetime}`; @@ -107,6 +109,18 @@ const PRICES = { id: AI_YEARLY_CODE, lookup_key: AI_YEARLY_CODE, }, + [TEAM_MONTHLY]: { + unit_amount: 1500, + currency: 'usd', + id: TEAM_MONTHLY, + lookup_key: TEAM_MONTHLY, + }, + [TEAM_YEARLY]: { + unit_amount: 14400, + currency: 'usd', + id: TEAM_YEARLY, + lookup_key: TEAM_YEARLY, + }, } as any as Record; const sub: Stripe.Subscription = { @@ -139,6 +153,7 @@ const sub: Stripe.Subscription = { trial_end: null, trial_start: null, schedule: null, + metadata: {}, }; const test = ava as TestFn<{ @@ -148,6 +163,7 @@ const test = ava as TestFn<{ service: SubscriptionService; event: Sinon.SinonStubbedInstance; feature: Sinon.SinonStubbedInstance; + runtime: Sinon.SinonStubbedInstance; stripe: { customers: Sinon.SinonStubbedInstance; prices: Sinon.SinonStubbedInstance; @@ -156,6 +172,7 @@ const test = ava as TestFn<{ checkout: { sessions: Sinon.SinonStubbedInstance; }; + promotionCodes: Sinon.SinonStubbedInstance; }; }>; @@ -192,12 +209,14 @@ test.before(async t => { m.overrideProvider(EventEmitter).useValue( Sinon.createStubInstance(EventEmitter) ); + m.overrideProvider(Runtime).useValue(Sinon.createStubInstance(Runtime)); }, }); t.context.event = app.get(EventEmitter); t.context.service = app.get(SubscriptionService); t.context.feature = app.get(FeatureManagementService); + t.context.runtime = app.get(Runtime); t.context.db = app.get(PrismaClient); t.context.app = app; @@ -210,6 +229,7 @@ test.before(async t => { checkout: { sessions: Sinon.stub(stripe.checkout.sessions), }, + promotionCodes: Sinon.stub(stripe.promotionCodes), }; t.context.stripe = stripeStubs; @@ -219,10 +239,17 @@ test.beforeEach(async t => { const { db, app, stripe } = t.context; Sinon.reset(); await initTestingDB(db); - // TODO(@forehalo): workaround for runtime module, need to init all runtime configs in [initTestingDB] - await app.get(Config).runtime.onModuleInit(); + t.context.runtime.fetch + .withArgs('plugins.payment/showLifetimePrice') + .resolves(true); t.context.u1 = await app.get(AuthService).signUp('u1@affine.pro', '1'); + await db.workspace.create({ + data: { + id: 'ws_1', + public: false, + }, + }); await db.userStripeCustomer.create({ data: { userId: t.context.u1.id, @@ -270,6 +297,15 @@ test('should list normal prices for authenticated user', async t => { t.snapshot(prices.map(p => encodeLookupKey(p.lookupKey))); }); +test('should not show lifetime price if not enabled', async t => { + const { service, runtime } = t.context; + runtime.fetch.withArgs('plugins.payment/showLifetimePrice').resolves(false); + + const prices = await service.listPrices(t.context.u1); + + t.snapshot(prices.map(p => encodeLookupKey(p.lookupKey))); +}); + test('should list early access prices for pro ea user', async t => { const { feature, service, u1 } = t.context; @@ -368,9 +404,9 @@ test('should list normal prices for ai ea user with old subscriptions', async t test('should throw if user has subscription already', async t => { const { service, u1, db } = t.context; - await db.userSubscription.create({ + await db.subscription.create({ data: { - userId: u1.id, + targetId: u1.id, stripeSubscriptionId: 'sub_1', plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Monthly, @@ -382,15 +418,14 @@ test('should throw if user has subscription already', async t => { await t.throwsAsync( () => - service.checkout({ - user: u1, - lookupKey: { + service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Monthly, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }), + { user: u1 } + ), { message: 'You have already subscribed to the pro plan.' } ); }); @@ -401,15 +436,14 @@ test('should get correct pro plan price for checking out', async t => { { feature.isEarlyAccessUser.resolves(false); - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Monthly, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }); + { user: u1 } + ); t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), { price: PRO_MONTHLY, @@ -420,15 +454,14 @@ test('should get correct pro plan price for checking out', async t => { // ea user, but monthly { feature.isEarlyAccessUser.resolves(true); - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Monthly, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }); + { user: u1 } + ); t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), { price: PRO_MONTHLY, @@ -439,16 +472,15 @@ test('should get correct pro plan price for checking out', async t => { // ea user, yearly { feature.isEarlyAccessUser.resolves(true); - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Yearly, variant: SubscriptionVariant.EA, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }); + { user: u1 } + ); t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), { price: PRO_EA_YEARLY, @@ -478,15 +510,14 @@ test('should get correct pro plan price for checking out', async t => { ], }); - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Yearly, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }); + { user: u1 } + ); t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), { price: PRO_YEARLY, @@ -495,16 +526,15 @@ test('should get correct pro plan price for checking out', async t => { await t.throwsAsync( () => - service.checkout({ - user: u1, - lookupKey: { + service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Yearly, variant: SubscriptionVariant.EA, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }), + { user: u1 } + ), { message: 'You are trying to access a unknown subscription plan.', } @@ -517,15 +547,14 @@ test('should get correct pro plan price for checking out', async t => { const config = app.get(Config); await config.runtime.set('plugins.payment/showLifetimePrice', true); - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Lifetime, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }); + { user: u1 } + ); t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), { price: PRO_LIFETIME, @@ -541,15 +570,14 @@ test('should get correct ai plan price for checking out', async t => { { feature.isEarlyAccessUser.resolves(false); - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.AI, recurring: SubscriptionRecurring.Yearly, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }); + { user: u1 } + ); t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), { price: AI_YEARLY, @@ -561,16 +589,15 @@ test('should get correct ai plan price for checking out', async t => { { feature.isEarlyAccessUser.resolves(true); - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.AI, recurring: SubscriptionRecurring.Yearly, variant: SubscriptionVariant.EA, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }); + { user: u1 } + ); t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), { price: AI_YEARLY_EA, @@ -603,15 +630,14 @@ test('should get correct ai plan price for checking out', async t => { ], }); - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.AI, recurring: SubscriptionRecurring.Yearly, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }); + { user: u1 } + ); t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), { price: AI_YEARLY, @@ -620,16 +646,15 @@ test('should get correct ai plan price for checking out', async t => { await t.throwsAsync( () => - service.checkout({ - user: u1, - lookupKey: { + service.checkout( + { plan: SubscriptionPlan.AI, recurring: SubscriptionRecurring.Yearly, variant: SubscriptionVariant.EA, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }), + { user: u1 } + ), { message: 'You are trying to access a unknown subscription plan.', } @@ -645,15 +670,14 @@ test('should get correct ai plan price for checking out', async t => { // @ts-expect-error stub stripe.subscriptions.list.resolves({ data: [] }); - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.AI, recurring: SubscriptionRecurring.Yearly, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }); + { user: u1 } + ); t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), { price: AI_YEARLY, @@ -686,15 +710,14 @@ test('should get correct ai plan price for checking out', async t => { ], }); - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.AI, recurring: SubscriptionRecurring.Yearly, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }); + { user: u1 } + ); t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), { price: AI_YEARLY, @@ -706,19 +729,26 @@ test('should get correct ai plan price for checking out', async t => { test('should apply user coupon for checking out', async t => { const { service, u1, stripe } = t.context; - // @ts-expect-error private member - Sinon.stub(service, 'getCouponFromPromotionCode').resolves('coupon_1'); + stripe.promotionCodes.list.resolves({ + data: [ + { + // @ts-expect-error mock + coupon: { + id: 'coupon_1', + }, + }, + ], + }); - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Monthly, + successCallbackLink: '', + coupon: 'test', }, - redirectUrl: '', - idempotencyKey: '', - promotionCode: 'test', - }); + { user: u1 } + ); t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), { price: PRO_MONTHLY, @@ -732,8 +762,8 @@ test('should be able to create subscription', async t => { await service.saveStripeSubscription(sub); - const subInDB = await db.userSubscription.findFirst({ - where: { userId: u1.id }, + const subInDB = await db.subscription.findFirst({ + where: { targetId: u1.id }, }); t.true( @@ -764,8 +794,8 @@ test('should be able to update subscription', async t => { }) ); - const subInDB = await db.userSubscription.findFirst({ - where: { userId: u1.id }, + const subInDB = await db.subscription.findFirst({ + where: { targetId: u1.id }, }); t.is(subInDB?.status, SubscriptionStatus.Active); @@ -789,8 +819,8 @@ test('should be able to delete subscription', async t => { }) ); - const subInDB = await db.userSubscription.findFirst({ - where: { userId: u1.id }, + const subInDB = await db.subscription.findFirst({ + where: { targetId: u1.id }, }); t.is(subInDB, null); @@ -799,9 +829,9 @@ test('should be able to delete subscription', async t => { test('should be able to cancel subscription', async t => { const { service, db, u1, stripe } = t.context; - await db.userSubscription.create({ + await db.subscription.create({ data: { - userId: u1.id, + targetId: u1.id, stripeSubscriptionId: 'sub_1', plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Yearly, @@ -817,7 +847,10 @@ test('should be able to cancel subscription', async t => { canceled_at: 1714118236, } as any); - const subInDB = await service.cancelSubscription(u1.id, SubscriptionPlan.Pro); + const subInDB = await service.cancelSubscription({ + userId: u1.id, + plan: SubscriptionPlan.Pro, + }); t.true( stripe.subscriptions.update.calledOnceWith('sub_1', { @@ -831,9 +864,9 @@ test('should be able to cancel subscription', async t => { test('should be able to resume subscription', async t => { const { service, db, u1, stripe } = t.context; - await db.userSubscription.create({ + await db.subscription.create({ data: { - userId: u1.id, + targetId: u1.id, stripeSubscriptionId: 'sub_1', plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Yearly, @@ -846,7 +879,10 @@ test('should be able to resume subscription', async t => { stripe.subscriptions.update.resolves(sub as any); - const subInDB = await service.resumeSubscription(u1.id, SubscriptionPlan.Pro); + const subInDB = await service.resumeSubscription({ + userId: u1.id, + plan: SubscriptionPlan.Pro, + }); t.true( stripe.subscriptions.update.calledOnceWith('sub_1', { @@ -868,10 +904,21 @@ const subscriptionSchedule: Stripe.SubscriptionSchedule = { // @ts-expect-error mock { price: PRO_MONTHLY, + quantity: 1, + }, + ], + start_date: Math.floor(Date.now() / 1000), + end_date: Math.floor((Date.now() + 30 * 24 * 60 * 60 * 1000) / 1000), + }, + { + items: [ + // @ts-expect-error mock + { + price: PRO_YEARLY, + quantity: 1, }, ], - start_date: 1714118236, - end_date: 1745654236, + start_date: Math.floor((Date.now() + 30 * 24 * 60 * 60 * 1000) / 1000), }, ], }; @@ -879,9 +926,9 @@ const subscriptionSchedule: Stripe.SubscriptionSchedule = { test('should be able to update recurring', async t => { const { service, db, u1, stripe } = t.context; - await db.userSubscription.create({ + await db.subscription.create({ data: { - userId: u1.id, + targetId: u1.id, stripeSubscriptionId: 'sub_1', plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Monthly, @@ -892,51 +939,54 @@ test('should be able to update recurring', async t => { }); // 1. turn a subscription into a subscription schedule - // 2. update the schedule - // 2.1 update the current phase with an end date - // 2.2 add a new phase with a start date - + // 2. update the current phase with an end date stripe.subscriptions.retrieve.resolves(sub as any); stripe.subscriptionSchedules.create.resolves(subscriptionSchedule as any); stripe.subscriptionSchedules.update.resolves(subscriptionSchedule as any); await service.updateSubscriptionRecurring( - u1.id, - SubscriptionPlan.Pro, + { + userId: u1.id, + plan: SubscriptionPlan.Pro, + }, SubscriptionRecurring.Yearly ); - t.true(stripe.subscriptionSchedules.update.calledOnce); - const arg = stripe.subscriptionSchedules.update.firstCall.args; - t.is(arg[0], subscriptionSchedule.id); - t.deepEqual(arg[1], { - phases: [ + t.true( + stripe.subscriptionSchedules.update.calledOnceWith( + subscriptionSchedule.id, { - items: [ + phases: [ { - price: PRO_MONTHLY, + items: [ + { + price: PRO_MONTHLY, + quantity: 1, + }, + ], + start_date: subscriptionSchedule.phases[0].start_date, + end_date: subscriptionSchedule.phases[0].end_date, }, - ], - start_date: 1714118236, - end_date: 1745654236, - }, - { - items: [ { - price: PRO_YEARLY, + items: [ + { + price: PRO_YEARLY, + quantity: 1, + }, + ], }, ], - }, - ], - }); + } + ) + ); }); test('should release the schedule if the new recurring is the same as the current phase', async t => { const { service, db, u1, stripe } = t.context; - await db.userSubscription.create({ + await db.subscription.create({ data: { - userId: u1.id, + targetId: u1.id, stripeSubscriptionId: 'sub_1', stripeScheduleId: 'sub_sched_1', plan: SubscriptionPlan.Pro, @@ -953,15 +1003,129 @@ test('should release the schedule if the new recurring is the same as the curren } as any); stripe.subscriptionSchedules.retrieve.resolves(subscriptionSchedule as any); - await service.updateSubscriptionRecurring( - u1.id, - SubscriptionPlan.Pro, + const subInDB = await service.updateSubscriptionRecurring( + { + userId: u1.id, + plan: SubscriptionPlan.Pro, + }, SubscriptionRecurring.Monthly ); t.true( stripe.subscriptionSchedules.release.calledOnceWith(subscriptionSchedule.id) ); + + t.is(subInDB.recurring, SubscriptionRecurring.Monthly); +}); + +test('should be able to cancel subscription with schedule', async t => { + const { service, u1, stripe } = t.context; + + await service.saveStripeSubscription({ + ...sub, + schedule: 'sub_sched_1', + }); + + stripe.subscriptionSchedules.retrieve.resolves(subscriptionSchedule as any); + + const subInDB = await service.cancelSubscription({ + userId: u1.id, + plan: SubscriptionPlan.Pro, + }); + + t.true( + stripe.subscriptionSchedules.update.calledOnceWith( + subscriptionSchedule.id, + { + phases: [ + { + items: [ + { + price: PRO_MONTHLY, + quantity: 1, + }, + ], + coupon: undefined, + start_date: subscriptionSchedule.phases[0].start_date, + end_date: subscriptionSchedule.phases[0].end_date, + metadata: { + next_coupon: null, + next_price: PRO_YEARLY, + }, + }, + ], + end_behavior: 'cancel', + } + ) + ); + + t.is(subInDB.status, SubscriptionStatus.Active); + t.truthy(subInDB.canceledAt); + t.falsy(subInDB.nextBillAt); +}); + +test('should be able to resume subscription with schedule', async t => { + const { service, u1, stripe } = t.context; + + await service.saveStripeSubscription({ + ...sub, + canceled_at: 1714118236, + schedule: 'sub_sched_1', + }); + + stripe.subscriptionSchedules.retrieve.resolves({ + ...subscriptionSchedule, + phases: [ + { + items: [ + // @ts-expect-error mock + { + price: PRO_MONTHLY, + quantity: 1, + }, + ], + start_date: subscriptionSchedule.phases[0].start_date, + end_date: subscriptionSchedule.phases[0].end_date, + metadata: { + next_price: PRO_YEARLY, + }, + }, + ], + end_behavior: 'cancel', + }); + + const subInDB = await service.resumeSubscription({ + userId: u1.id, + plan: SubscriptionPlan.Pro, + }); + + t.true( + stripe.subscriptionSchedules.update.calledOnceWith( + subscriptionSchedule.id, + { + phases: [ + { + items: [{ price: PRO_MONTHLY, quantity: 1 }], + start_date: subscriptionSchedule.phases[0].start_date, + end_date: subscriptionSchedule.phases[0].end_date, + metadata: { + next_price: null, + next_coupon: null, + }, + }, + { + items: [{ price: PRO_YEARLY, quantity: 1 }], + coupon: undefined, + }, + ], + end_behavior: 'release', + } + ) + ); + + t.is(subInDB.status, SubscriptionStatus.Active); + t.falsy(subInDB.canceledAt); + t.truthy(subInDB.nextBillAt); }); // ============== Lifetime Subscription =============== @@ -1023,38 +1187,40 @@ const onetimeYearlyInvoice: Stripe.Invoice = { }; test('should not be able to checkout for lifetime recurring if not enabled', async t => { - const { service, u1, app } = t.context; - await app.get(Config).runtime.set('plugins.payment/showLifetimePrice', false); + const { service, u1, runtime } = t.context; + runtime.fetch.withArgs('plugins.payment/showLifetimePrice').resolves(false); await t.throwsAsync( () => - service.checkout({ - user: u1, - lookupKey: { + service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Lifetime, + variant: null, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }), + { + user: u1, + } + ), { message: 'You are trying to access a unknown subscription plan.' } ); }); test('should be able to checkout for lifetime recurring', async t => { - const { service, u1, app, stripe } = t.context; - const config = app.get(Config); - await config.runtime.set('plugins.payment/showLifetimePrice', true); + const { service, u1, stripe } = t.context; - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Lifetime, + variant: null, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }); + { + user: u1, + } + ); t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), { price: PRO_LIFETIME, @@ -1065,9 +1231,9 @@ test('should be able to checkout for lifetime recurring', async t => { test('should not be able to checkout for lifetime recurring if already subscribed', async t => { const { service, u1, db } = t.context; - await db.userSubscription.create({ + await db.subscription.create({ data: { - userId: u1.id, + targetId: u1.id, stripeSubscriptionId: null, plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Lifetime, @@ -1078,20 +1244,22 @@ test('should not be able to checkout for lifetime recurring if already subscribe await t.throwsAsync( () => - service.checkout({ - user: u1, - lookupKey: { + service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Lifetime, + variant: null, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }), + { + user: u1, + } + ), { message: 'You have already subscribed to the pro plan.' } ); - await db.userSubscription.updateMany({ - where: { userId: u1.id }, + await db.subscription.updateMany({ + where: { targetId: u1.id }, data: { stripeSubscriptionId: null, recurring: SubscriptionRecurring.Monthly, @@ -1102,15 +1270,17 @@ test('should not be able to checkout for lifetime recurring if already subscribe await t.throwsAsync( () => - service.checkout({ - user: u1, - lookupKey: { + service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Lifetime, + variant: null, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }), + { + user: u1, + } + ), { message: 'You have already subscribed to the pro plan.' } ); }); @@ -1121,8 +1291,8 @@ test('should be able to subscribe to lifetime recurring', async t => { await service.saveStripeInvoice(lifetimeInvoice); - const subInDB = await db.userSubscription.findFirst({ - where: { userId: u1.id }, + const subInDB = await db.subscription.findFirst({ + where: { targetId: u1.id }, }); t.true( @@ -1141,9 +1311,9 @@ test('should be able to subscribe to lifetime recurring', async t => { test('should be able to subscribe to lifetime recurring with old subscription', async t => { const { service, stripe, db, u1, event } = t.context; - await db.userSubscription.create({ + await db.subscription.create({ data: { - userId: u1.id, + targetId: u1.id, stripeSubscriptionId: 'sub_1', plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Monthly, @@ -1156,8 +1326,8 @@ test('should be able to subscribe to lifetime recurring with old subscription', stripe.subscriptions.cancel.resolves(sub as any); await service.saveStripeInvoice(lifetimeInvoice); - const subInDB = await db.userSubscription.findFirst({ - where: { userId: u1.id }, + const subInDB = await db.subscription.findFirst({ + where: { targetId: u1.id }, }); t.true( @@ -1176,9 +1346,9 @@ test('should be able to subscribe to lifetime recurring with old subscription', test('should not be able to cancel lifetime subscription', async t => { const { service, db, u1 } = t.context; - await db.userSubscription.create({ + await db.subscription.create({ data: { - userId: u1.id, + targetId: u1.id, plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Lifetime, status: SubscriptionStatus.Active, @@ -1188,7 +1358,11 @@ test('should not be able to cancel lifetime subscription', async t => { }); await t.throwsAsync( - () => service.cancelSubscription(u1.id, SubscriptionPlan.Pro), + () => + service.cancelSubscription({ + plan: SubscriptionPlan.Pro, + userId: u1.id, + }), { message: 'Onetime payment subscription cannot be canceled.' } ); }); @@ -1196,9 +1370,9 @@ test('should not be able to cancel lifetime subscription', async t => { test('should not be able to update lifetime recurring', async t => { const { service, db, u1 } = t.context; - await db.userSubscription.create({ + await db.subscription.create({ data: { - userId: u1.id, + targetId: u1.id, plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Lifetime, status: SubscriptionStatus.Active, @@ -1210,8 +1384,10 @@ test('should not be able to update lifetime recurring', async t => { await t.throwsAsync( () => service.updateSubscriptionRecurring( - u1.id, - SubscriptionPlan.Pro, + { + plan: SubscriptionPlan.Pro, + userId: u1.id, + }, SubscriptionRecurring.Monthly ), { message: 'You cannot update an onetime payment subscription.' } @@ -1222,16 +1398,17 @@ test('should not be able to update lifetime recurring', async t => { test('should be able to checkout for onetime payment', async t => { const { service, u1, stripe } = t.context; - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Monthly, variant: SubscriptionVariant.Onetime, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }); + { + user: u1, + } + ); t.true(stripe.checkout.sessions.create.calledOnce); const arg = stripe.checkout.sessions.create.firstCall @@ -1246,9 +1423,9 @@ test('should be able to checkout for onetime payment', async t => { test('should be able to checkout onetime payment if previous subscription is onetime', async t => { const { service, u1, stripe, db } = t.context; - await db.userSubscription.create({ + await db.subscription.create({ data: { - userId: u1.id, + targetId: u1.id, stripeSubscriptionId: 'sub_1', plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Monthly, @@ -1259,16 +1436,17 @@ test('should be able to checkout onetime payment if previous subscription is one }, }); - await service.checkout({ - user: u1, - lookupKey: { + await service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Monthly, variant: SubscriptionVariant.Onetime, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }); + { + user: u1, + } + ); t.true(stripe.checkout.sessions.create.calledOnce); const arg = stripe.checkout.sessions.create.firstCall @@ -1283,9 +1461,9 @@ test('should be able to checkout onetime payment if previous subscription is one test('should not be able to checkout out onetime payment if previous subscription is not onetime', async t => { const { service, u1, db } = t.context; - await db.userSubscription.create({ + await db.subscription.create({ data: { - userId: u1.id, + targetId: u1.id, stripeSubscriptionId: 'sub_1', plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Monthly, @@ -1297,21 +1475,22 @@ test('should not be able to checkout out onetime payment if previous subscriptio await t.throwsAsync( () => - service.checkout({ - user: u1, - lookupKey: { + service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Monthly, variant: SubscriptionVariant.Onetime, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }), + { + user: u1, + } + ), { message: 'You have already subscribed to the pro plan.' } ); - await db.userSubscription.updateMany({ - where: { userId: u1.id }, + await db.subscription.updateMany({ + where: { targetId: u1.id }, data: { stripeSubscriptionId: null, recurring: SubscriptionRecurring.Lifetime, @@ -1320,16 +1499,17 @@ test('should not be able to checkout out onetime payment if previous subscriptio await t.throwsAsync( () => - service.checkout({ - user: u1, - lookupKey: { + service.checkout( + { plan: SubscriptionPlan.Pro, recurring: SubscriptionRecurring.Monthly, variant: SubscriptionVariant.Onetime, + successCallbackLink: '', }, - redirectUrl: '', - idempotencyKey: '', - }), + { + user: u1, + } + ), { message: 'You have already subscribed to the pro plan.' } ); }); @@ -1339,8 +1519,8 @@ test('should be able to subscribe onetime payment subscription', async t => { await service.saveStripeInvoice(onetimeMonthlyInvoice); - const subInDB = await db.userSubscription.findFirst({ - where: { userId: u1.id }, + const subInDB = await db.subscription.findFirst({ + where: { targetId: u1.id }, }); t.true( @@ -1365,16 +1545,16 @@ test('should be able to recalculate onetime payment subscription period', async await service.saveStripeInvoice(onetimeMonthlyInvoice); - let subInDB = await db.userSubscription.findFirst({ - where: { userId: u1.id }, + let subInDB = await db.subscription.findFirst({ + where: { targetId: u1.id }, }); t.truthy(subInDB); let end = subInDB!.end!; await service.saveStripeInvoice(onetimeMonthlyInvoice); - subInDB = await db.userSubscription.findFirst({ - where: { userId: u1.id }, + subInDB = await db.subscription.findFirst({ + where: { targetId: u1.id }, }); // add 30 days @@ -1382,23 +1562,23 @@ test('should be able to recalculate onetime payment subscription period', async end = subInDB!.end!; await service.saveStripeInvoice(onetimeYearlyInvoice); - subInDB = await db.userSubscription.findFirst({ - where: { userId: u1.id }, + subInDB = await db.subscription.findFirst({ + where: { targetId: u1.id }, }); // add 365 days t.is(subInDB!.end!.getTime(), end.getTime() + 365 * 24 * 60 * 60 * 1000); // make subscription expired - await db.userSubscription.update({ + await db.subscription.update({ where: { id: subInDB!.id }, data: { end: new Date(Date.now() - 1000), }, }); await service.saveStripeInvoice(onetimeYearlyInvoice); - subInDB = await db.userSubscription.findFirst({ - where: { userId: u1.id }, + subInDB = await db.subscription.findFirst({ + where: { targetId: u1.id }, }); // add 365 days from now @@ -1407,3 +1587,148 @@ test('should be able to recalculate onetime payment subscription period', async new Date(Date.now() + 365 * 24 * 60 * 60 * 1000).toDateString() ); }); + +// TEAM +test('should be able to list prices for team', async t => { + const { service } = t.context; + + const prices = await service.listPrices(undefined); + + t.snapshot(prices.map(p => encodeLookupKey(p.lookupKey))); +}); + +test('should be able to checkout for team', async t => { + const { service, u1, stripe } = t.context; + + await service.checkout( + { + plan: SubscriptionPlan.Team, + recurring: SubscriptionRecurring.Monthly, + variant: null, + successCallbackLink: '', + }, + { + user: u1, + workspaceId: 'ws_1', + } + ); + + t.deepEqual(getLastCheckoutPrice(stripe.checkout.sessions.create), { + price: TEAM_MONTHLY, + coupon: undefined, + }); +}); + +test('should not be able to checkout for workspace if subscribed', async t => { + const { service, u1, db } = t.context; + + await db.subscription.create({ + data: { + targetId: 'ws_1', + stripeSubscriptionId: 'sub_1', + plan: SubscriptionPlan.Team, + recurring: SubscriptionRecurring.Monthly, + status: SubscriptionStatus.Active, + start: new Date(), + end: new Date(), + quantity: 1, + }, + }); + + await t.throwsAsync( + () => + service.checkout( + { + plan: SubscriptionPlan.Team, + recurring: SubscriptionRecurring.Monthly, + variant: null, + successCallbackLink: '', + }, + { + user: u1, + workspaceId: 'ws_1', + } + ), + { message: 'You have already subscribed to the team plan.' } + ); +}); + +const teamSub: Stripe.Subscription = { + ...sub, + items: { + object: 'list', + data: [ + { + id: 'si_1', + // @ts-expect-error stub + price: { + id: TEAM_MONTHLY, + lookup_key: 'team_monthly', + }, + subscription: 'sub_1', + quantity: 1, + }, + ], + }, + metadata: { + workspaceId: 'ws_1', + }, +}; + +test('should be able to create team subscription', async t => { + const { event, service, db } = t.context; + + await service.saveStripeSubscription(teamSub); + + const subInDB = await db.subscription.findFirst({ + where: { targetId: 'ws_1' }, + }); + + t.true( + event.emit.calledOnceWith('workspace.subscription.activated', { + workspaceId: 'ws_1', + plan: SubscriptionPlan.Team, + recurring: SubscriptionRecurring.Monthly, + quantity: 1, + }) + ); + t.is(subInDB?.stripeSubscriptionId, sub.id); +}); + +test('should be able to update team subscription', async t => { + const { service, db, event } = t.context; + + await service.saveStripeSubscription(teamSub); + + await service.saveStripeSubscription({ + ...teamSub, + items: { + ...teamSub.items, + data: [ + { + ...teamSub.items.data[0], + quantity: 2, + }, + ], + }, + }); + + const subInDB = await db.subscription.findFirst({ + where: { targetId: 'ws_1' }, + }); + + t.is(subInDB?.quantity, 2); + + t.true( + event.emit.calledWith('workspace.subscription.activated', { + workspaceId: 'ws_1', + plan: SubscriptionPlan.Team, + recurring: SubscriptionRecurring.Monthly, + quantity: 2, + }) + ); +}); + +// NOTE(@forehalo): cancel and resume a team subscription share the same logic with user subscription +test.skip('should be able to cancel team subscription', async () => {}); +test.skip('should be able to resume team subscription', async () => {}); diff --git a/packages/backend/server/tests/payment/snapshots/service.spec.ts.md b/packages/backend/server/tests/payment/snapshots/service.spec.ts.md index b10c48de7847f..cbce3754c87b4 100644 --- a/packages/backend/server/tests/payment/snapshots/service.spec.ts.md +++ b/packages/backend/server/tests/payment/snapshots/service.spec.ts.md @@ -11,7 +11,10 @@ Generated by [AVA](https://avajs.dev). [ 'pro_monthly', 'pro_yearly', + 'pro_lifetime', 'ai_yearly', + 'team_monthly', + 'team_yearly', ] ## should list normal prices for authenticated user @@ -21,7 +24,22 @@ Generated by [AVA](https://avajs.dev). [ 'pro_monthly', 'pro_yearly', + 'pro_lifetime', 'ai_yearly', + 'team_monthly', + 'team_yearly', + ] + +## should not show lifetime price if not enabled + +> Snapshot 1 + + [ + 'pro_monthly', + 'pro_yearly', + 'ai_yearly', + 'team_monthly', + 'team_yearly', ] ## should list early access prices for pro ea user @@ -30,8 +48,11 @@ Generated by [AVA](https://avajs.dev). [ 'pro_monthly', + 'pro_lifetime', 'pro_yearly_earlyaccess', 'ai_yearly', + 'team_monthly', + 'team_yearly', ] ## should list normal prices for pro ea user with old subscriptions @@ -41,7 +62,10 @@ Generated by [AVA](https://avajs.dev). [ 'pro_monthly', 'pro_yearly', + 'pro_lifetime', 'ai_yearly', + 'team_monthly', + 'team_yearly', ] ## should list early access prices for ai ea user @@ -51,7 +75,10 @@ Generated by [AVA](https://avajs.dev). [ 'pro_monthly', 'pro_yearly', + 'pro_lifetime', 'ai_yearly_earlyaccess', + 'team_monthly', + 'team_yearly', ] ## should list early access prices for pro and ai ea user @@ -60,8 +87,11 @@ Generated by [AVA](https://avajs.dev). [ 'pro_monthly', + 'pro_lifetime', 'pro_yearly_earlyaccess', 'ai_yearly_earlyaccess', + 'team_monthly', + 'team_yearly', ] ## should list normal prices for ai ea user with old subscriptions @@ -71,5 +101,21 @@ Generated by [AVA](https://avajs.dev). [ 'pro_monthly', 'pro_yearly', + 'pro_lifetime', + 'ai_yearly', + 'team_monthly', + 'team_yearly', + ] + +## should be able to list prices for team + +> Snapshot 1 + + [ + 'pro_monthly', + 'pro_yearly', + 'pro_lifetime', 'ai_yearly', + 'team_monthly', + 'team_yearly', ] diff --git a/packages/backend/server/tests/payment/snapshots/service.spec.ts.snap b/packages/backend/server/tests/payment/snapshots/service.spec.ts.snap index e442c5ca70755..165ed37186952 100644 Binary files a/packages/backend/server/tests/payment/snapshots/service.spec.ts.snap and b/packages/backend/server/tests/payment/snapshots/service.spec.ts.snap differ diff --git a/packages/frontend/core/src/modules/cloud/stores/subscription.ts b/packages/frontend/core/src/modules/cloud/stores/subscription.ts index 144695f28d3c4..5e32a1eac629b 100644 --- a/packages/frontend/core/src/modules/cloud/stores/subscription.ts +++ b/packages/frontend/core/src/modules/cloud/stores/subscription.ts @@ -23,7 +23,7 @@ const SUBSCRIPTION_CACHE_KEY = 'subscription:'; const getDefaultSubscriptionSuccessCallbackLink = ( baseUrl: string, - plan: SubscriptionPlan | null, + plan?: SubscriptionPlan | null, scheme?: string ) => { const path = diff --git a/packages/frontend/graphql/codegen.yml b/packages/frontend/graphql/codegen.yml index 220382380c36b..9ae652f4e6be8 100644 --- a/packages/frontend/graphql/codegen.yml +++ b/packages/frontend/graphql/codegen.yml @@ -5,7 +5,11 @@ config: strict: true maybeValue: T | null declarationKind: interface - avoidOptionals: true + avoidOptionals: + field: true + inputValue: false + object: false + defaultValue: false preResolveTypes: true namingConvention: enumValues: keep diff --git a/packages/frontend/graphql/src/schema.ts b/packages/frontend/graphql/src/schema.ts index 15014413e986c..c2cdfd4f4367e 100644 --- a/packages/frontend/graphql/src/schema.ts +++ b/packages/frontend/graphql/src/schema.ts @@ -77,8 +77,8 @@ export interface Copilot { } export interface CopilotHistoriesArgs { - docId: InputMaybe; - options: InputMaybe; + docId?: InputMaybe; + options?: InputMaybe; } export interface CopilotHistories { @@ -111,11 +111,11 @@ export enum CopilotModels { } export interface CopilotPromptConfigInput { - frequencyPenalty: InputMaybe; - jsonMode: InputMaybe; - presencePenalty: InputMaybe; - temperature: InputMaybe; - topP: InputMaybe; + frequencyPenalty?: InputMaybe; + jsonMode?: InputMaybe; + presencePenalty?: InputMaybe; + temperature?: InputMaybe; + topP?: InputMaybe; } export interface CopilotPromptConfigType { @@ -129,7 +129,7 @@ export interface CopilotPromptConfigType { export interface CopilotPromptMessageInput { content: Scalars['String']['input']; - params: InputMaybe; + params?: InputMaybe; role: CopilotPromptMessageRole; } @@ -174,10 +174,10 @@ export interface CopilotQuota { } export interface CreateChatMessageInput { - attachments: InputMaybe>; - blobs: InputMaybe>; - content: InputMaybe; - params: InputMaybe; + attachments?: InputMaybe>; + blobs?: InputMaybe>; + content?: InputMaybe; + params?: InputMaybe; sessionId: Scalars['String']['input']; } @@ -189,17 +189,19 @@ export interface CreateChatSessionInput { } export interface CreateCheckoutSessionInput { - coupon: InputMaybe; - idempotencyKey: InputMaybe; - plan: InputMaybe; - recurring: InputMaybe; + args?: InputMaybe; + coupon?: InputMaybe; + idempotencyKey?: InputMaybe; + plan?: InputMaybe; + quantity?: InputMaybe; + recurring?: InputMaybe; successCallbackLink: Scalars['String']['input']; - variant: InputMaybe; + variant?: InputMaybe; } export interface CreateCopilotPromptInput { - action: InputMaybe; - config: InputMaybe; + action?: InputMaybe; + config?: InputMaybe; messages: Array; model: CopilotModels; name: Scalars['String']['input']; @@ -207,7 +209,7 @@ export interface CreateCopilotPromptInput { export interface CreateUserInput { email: Scalars['String']['input']; - name: InputMaybe; + name?: InputMaybe; } export interface CredentialsRequirementType { @@ -282,6 +284,7 @@ export type ErrorDataUnion = | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType + | UnsupportedSubscriptionPlanDataType | VersionRejectedDataType; export enum ErrorNames { @@ -320,12 +323,14 @@ export enum ErrorNames { FAILED_TO_SAVE_UPDATES = 'FAILED_TO_SAVE_UPDATES', FAILED_TO_UPSERT_SNAPSHOT = 'FAILED_TO_UPSERT_SNAPSHOT', INTERNAL_SERVER_ERROR = 'INTERNAL_SERVER_ERROR', + INVALID_CHECKOUT_PARAMETERS = 'INVALID_CHECKOUT_PARAMETERS', INVALID_EMAIL = 'INVALID_EMAIL', INVALID_EMAIL_TOKEN = 'INVALID_EMAIL_TOKEN', INVALID_HISTORY_TIMESTAMP = 'INVALID_HISTORY_TIMESTAMP', INVALID_OAUTH_CALLBACK_STATE = 'INVALID_OAUTH_CALLBACK_STATE', INVALID_PASSWORD_LENGTH = 'INVALID_PASSWORD_LENGTH', INVALID_RUNTIME_CONFIG_TYPE = 'INVALID_RUNTIME_CONFIG_TYPE', + INVALID_SUBSCRIPTION_PARAMETERS = 'INVALID_SUBSCRIPTION_PARAMETERS', LINK_EXPIRED = 'LINK_EXPIRED', MAILER_SERVICE_IS_NOT_CONFIGURED = 'MAILER_SERVICE_IS_NOT_CONFIGURED', MEMBER_QUOTA_EXCEEDED = 'MEMBER_QUOTA_EXCEEDED', @@ -347,14 +352,18 @@ export enum ErrorNames { SUBSCRIPTION_ALREADY_EXISTS = 'SUBSCRIPTION_ALREADY_EXISTS', SUBSCRIPTION_EXPIRED = 'SUBSCRIPTION_EXPIRED', SUBSCRIPTION_HAS_BEEN_CANCELED = 'SUBSCRIPTION_HAS_BEEN_CANCELED', + SUBSCRIPTION_HAS_NOT_BEEN_CANCELED = 'SUBSCRIPTION_HAS_NOT_BEEN_CANCELED', SUBSCRIPTION_NOT_EXISTS = 'SUBSCRIPTION_NOT_EXISTS', SUBSCRIPTION_PLAN_NOT_FOUND = 'SUBSCRIPTION_PLAN_NOT_FOUND', TOO_MANY_REQUEST = 'TOO_MANY_REQUEST', UNKNOWN_OAUTH_PROVIDER = 'UNKNOWN_OAUTH_PROVIDER', UNSPLASH_IS_NOT_CONFIGURED = 'UNSPLASH_IS_NOT_CONFIGURED', + UNSUPPORTED_SUBSCRIPTION_PLAN = 'UNSUPPORTED_SUBSCRIPTION_PLAN', USER_AVATAR_NOT_FOUND = 'USER_AVATAR_NOT_FOUND', USER_NOT_FOUND = 'USER_NOT_FOUND', VERSION_REJECTED = 'VERSION_REJECTED', + WORKSPACE_ID_REQUIRED_FOR_TEAM_SUBSCRIPTION = 'WORKSPACE_ID_REQUIRED_FOR_TEAM_SUBSCRIPTION', + WORKSPACE_ID_REQUIRED_TO_UPDATE_TEAM_SUBSCRIPTION = 'WORKSPACE_ID_REQUIRED_TO_UPDATE_TEAM_SUBSCRIPTION', WRONG_SIGN_IN_CREDENTIALS = 'WRONG_SIGN_IN_CREDENTIALS', WRONG_SIGN_IN_METHOD = 'WRONG_SIGN_IN_METHOD', } @@ -485,15 +494,15 @@ export interface LimitedUserType { } export interface ListUserInput { - first: InputMaybe; - skip: InputMaybe; + first?: InputMaybe; + skip?: InputMaybe; } export interface ManageUserInput { /** User email */ - email: InputMaybe; + email?: InputMaybe; /** User name */ - name: InputMaybe; + name?: InputMaybe; } export interface MissingOauthQueryParameterDataType { @@ -575,7 +584,7 @@ export interface Mutation { export interface MutationAcceptInviteByIdArgs { inviteId: Scalars['String']['input']; - sendAcceptMail: InputMaybe; + sendAcceptMail?: InputMaybe; workspaceId: Scalars['String']['input']; } @@ -585,8 +594,9 @@ export interface MutationAddWorkspaceFeatureArgs { } export interface MutationCancelSubscriptionArgs { - idempotencyKey: InputMaybe; + idempotencyKey?: InputMaybe; plan?: InputMaybe; + workspaceId?: InputMaybe; } export interface MutationChangeEmailArgs { @@ -597,7 +607,7 @@ export interface MutationChangeEmailArgs { export interface MutationChangePasswordArgs { newPassword: Scalars['String']['input']; token: Scalars['String']['input']; - userId: InputMaybe; + userId?: InputMaybe; } export interface MutationCleanupCopilotSessionArgs { @@ -630,7 +640,7 @@ export interface MutationCreateUserArgs { } export interface MutationCreateWorkspaceArgs { - init: InputMaybe; + init?: InputMaybe; } export interface MutationDeleteBlobArgs { @@ -653,12 +663,12 @@ export interface MutationForkCopilotSessionArgs { export interface MutationInviteArgs { email: Scalars['String']['input']; permission: Permission; - sendInviteMail: InputMaybe; + sendInviteMail?: InputMaybe; workspaceId: Scalars['String']['input']; } export interface MutationLeaveWorkspaceArgs { - sendLeaveMail: InputMaybe; + sendLeaveMail?: InputMaybe; workspaceId: Scalars['String']['input']; workspaceName: Scalars['String']['input']; } @@ -681,8 +691,9 @@ export interface MutationRemoveWorkspaceFeatureArgs { } export interface MutationResumeSubscriptionArgs { - idempotencyKey: InputMaybe; + idempotencyKey?: InputMaybe; plan?: InputMaybe; + workspaceId?: InputMaybe; } export interface MutationRevokeArgs { @@ -702,17 +713,17 @@ export interface MutationRevokePublicPageArgs { export interface MutationSendChangeEmailArgs { callbackUrl: Scalars['String']['input']; - email: InputMaybe; + email?: InputMaybe; } export interface MutationSendChangePasswordEmailArgs { callbackUrl: Scalars['String']['input']; - email: InputMaybe; + email?: InputMaybe; } export interface MutationSendSetPasswordEmailArgs { callbackUrl: Scalars['String']['input']; - email: InputMaybe; + email?: InputMaybe; } export interface MutationSendVerifyChangeEmailArgs { @@ -760,9 +771,10 @@ export interface MutationUpdateRuntimeConfigsArgs { } export interface MutationUpdateSubscriptionRecurringArgs { - idempotencyKey: InputMaybe; + idempotencyKey?: InputMaybe; plan?: InputMaybe; recurring: SubscriptionRecurring; + workspaceId?: InputMaybe; } export interface MutationUpdateUserArgs { @@ -839,7 +851,6 @@ export interface Query { /** List all copilot prompts */ listCopilotPrompts: Array; listWorkspaceFeatures: Array; - /** @deprecated use `userPrices` instead */ prices: Array; /** server config */ serverConfig: ServerConfigType; @@ -908,13 +919,13 @@ export interface QueryWorkspaceArgs { } export interface QueryChatHistoriesInput { - action: InputMaybe; - fork: InputMaybe; - limit: InputMaybe; - messageOrder: InputMaybe; - sessionId: InputMaybe; - sessionOrder: InputMaybe; - skip: InputMaybe; + action?: InputMaybe; + fork?: InputMaybe; + limit?: InputMaybe; + messageOrder?: InputMaybe; + sessionId?: InputMaybe; + sessionOrder?: InputMaybe; + skip?: InputMaybe; } export interface QuotaQueryType { @@ -1117,17 +1128,22 @@ export interface UnknownOauthProviderDataType { name: Scalars['String']['output']; } +export interface UnsupportedSubscriptionPlanDataType { + __typename?: 'UnsupportedSubscriptionPlanDataType'; + plan: Scalars['String']['output']; +} + export interface UpdateUserInput { /** User name */ - name: InputMaybe; + name?: InputMaybe; } export interface UpdateWorkspaceInput { /** Enable url previous when sharing */ - enableUrlPreview: InputMaybe; + enableUrlPreview?: InputMaybe; id: Scalars['ID']['input']; /** is Public workspace */ - public: InputMaybe; + public?: InputMaybe; } export type UserOrLimitedUser = LimitedUserType | UserType; @@ -1182,11 +1198,11 @@ export interface UserType { } export interface UserTypeCopilotArgs { - workspaceId: InputMaybe; + workspaceId?: InputMaybe; } export interface UserTypeInvoicesArgs { - skip: InputMaybe; + skip?: InputMaybe; take?: InputMaybe; } @@ -1258,17 +1274,19 @@ export interface WorkspaceType { * @deprecated use WorkspaceType.publicPages */ sharedPages: Array; + /** The team subscription of the workspace, if exists. */ + subscription: Maybe; } export interface WorkspaceTypeHistoriesArgs { - before: InputMaybe; + before?: InputMaybe; guid: Scalars['String']['input']; - take: InputMaybe; + take?: InputMaybe; } export interface WorkspaceTypeMembersArgs { - skip: InputMaybe; - take: InputMaybe; + skip?: InputMaybe; + take?: InputMaybe; } export interface WorkspaceTypePageMetaArgs { @@ -1513,8 +1531,8 @@ export type PasswordLimitsFragment = { export type GetCopilotHistoriesQueryVariables = Exact<{ workspaceId: Scalars['String']['input']; - docId: InputMaybe; - options: InputMaybe; + docId?: InputMaybe; + options?: InputMaybe; }>; export type GetCopilotHistoriesQuery = { @@ -1544,8 +1562,8 @@ export type GetCopilotHistoriesQuery = { export type GetCopilotHistoryIdsQueryVariables = Exact<{ workspaceId: Scalars['String']['input']; - docId: InputMaybe; - options: InputMaybe; + docId?: InputMaybe; + options?: InputMaybe; }>; export type GetCopilotHistoryIdsQuery = { @@ -1915,8 +1933,8 @@ export type GetWorkspacesQuery = { export type ListHistoryQueryVariables = Exact<{ workspaceId: Scalars['String']['input']; pageDocId: Scalars['String']['input']; - take: InputMaybe; - before: InputMaybe; + take?: InputMaybe; + before?: InputMaybe; }>; export type ListHistoryQuery = { @@ -1970,7 +1988,7 @@ export type InvoicesQuery = { export type LeaveWorkspaceMutationVariables = Exact<{ workspaceId: Scalars['String']['input']; workspaceName: Scalars['String']['input']; - sendLeaveMail: InputMaybe; + sendLeaveMail?: InputMaybe; }>; export type LeaveWorkspaceMutation = { @@ -2422,7 +2440,7 @@ export type InviteByEmailMutationVariables = Exact<{ workspaceId: Scalars['String']['input']; email: Scalars['String']['input']; permission: Permission; - sendInviteMail: InputMaybe; + sendInviteMail?: InputMaybe; }>; export type InviteByEmailMutation = { __typename?: 'Mutation'; invite: string }; @@ -2430,7 +2448,7 @@ export type InviteByEmailMutation = { __typename?: 'Mutation'; invite: string }; export type AcceptInviteByInviteIdMutationVariables = Exact<{ workspaceId: Scalars['String']['input']; inviteId: Scalars['String']['input']; - sendAcceptMail: InputMaybe; + sendAcceptMail?: InputMaybe; }>; export type AcceptInviteByInviteIdMutation = {