diff --git a/packages/backend/server/migrations/20241202082025_licenses/migration.sql b/packages/backend/server/migrations/20241202082025_licenses/migration.sql new file mode 100644 index 0000000000000..3c541bd5600d7 --- /dev/null +++ b/packages/backend/server/migrations/20241202082025_licenses/migration.sql @@ -0,0 +1,20 @@ +-- CreateTable +CREATE TABLE "licenses" ( + "key" VARCHAR NOT NULL, + "installed_at" TIMESTAMPTZ(3) DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "licenses_pkey" PRIMARY KEY ("key") +); + +-- CreateTable +CREATE TABLE "installed_licenses" ( + "key" VARCHAR NOT NULL, + "workspace_id" VARCHAR NOT NULL, + "installed_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "revalidated_at" TIMESTAMPTZ(3), + + CONSTRAINT "installed_licenses_pkey" PRIMARY KEY ("key") +); + +-- CreateIndex +CREATE UNIQUE INDEX "installed_licenses_workspace_id_key" ON "installed_licenses"("workspace_id"); diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index 269cd8b37259f..b43d40d3a5ef2 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -551,3 +551,19 @@ model Invoice { @@index([targetId]) @@map("invoices") } + +model License { + key String @id @map("key") @db.VarChar + installedAt DateTime? @default(now()) @map("installed_at") @db.Timestamptz(3) + + @@map("licenses") +} + +model InstalledLicense { + key String @id @map("key") @db.VarChar + workspaceId String @unique @map("workspace_id") @db.VarChar + installedAt DateTime @default(now()) @map("installed_at") @db.Timestamptz(3) + revalidatedAt DateTime? @map("revalidated_at") @db.Timestamptz(3) + + @@map("installed_licenses") +} diff --git a/packages/backend/server/src/core/quota/service.ts b/packages/backend/server/src/core/quota/service.ts index c5193e768e192..6683071dabff6 100644 --- a/packages/backend/server/src/core/quota/service.ts +++ b/packages/backend/server/src/core/quota/service.ts @@ -133,6 +133,19 @@ export class QuotaService { }); } + async switchWorkspaceQuota( + workspaceId: string, + quota: QuotaType, + reason?: string, + expiredAt?: Date + ) { + // TODO(@darksky): implement + } + + async switchWorkspaceQuotaToDefault(workspaceId: string, reason?: string) { + // TODO(@darksky): implement + } + async hasQuota(userId: string, quota: QuotaType, tx?: PrismaTransaction) { const executor = tx ?? this.prisma; diff --git a/packages/backend/server/src/fundamentals/error/def.ts b/packages/backend/server/src/fundamentals/error/def.ts index 385eebd3fc727..a6df823650d1f 100644 --- a/packages/backend/server/src/fundamentals/error/def.ts +++ b/packages/backend/server/src/fundamentals/error/def.ts @@ -478,6 +478,25 @@ export const USER_FRIENDLY_ERRORS = { type: 'invalid_input', message: 'Workspace id is required to update team subscription.', }, + license_not_found: { + type: 'resource_not_found', + message: 'License not found.', + }, + invalid_license_to_activate: { + type: 'bad_request', + message: 'Invalid license to activate.', + }, + invalid_license_update_params: { + type: 'invalid_input', + args: { reason: 'string' }, + message: ({ reason }) => `Invalid license update params. ${reason}`, + }, + workspace_members_exceed_limit_to_downgrade: { + type: 'bad_request', + args: { limit: 'number' }, + message: ({ limit }) => + `You cannot downgrade the workspace from team workspace because there are more than ${limit} members that are currently active.`, + }, // Copilot errors copilot_session_not_found: { diff --git a/packages/backend/server/src/fundamentals/error/errors.gen.ts b/packages/backend/server/src/fundamentals/error/errors.gen.ts index c3212d71cc815..e3440a0c5dcbf 100644 --- a/packages/backend/server/src/fundamentals/error/errors.gen.ts +++ b/packages/backend/server/src/fundamentals/error/errors.gen.ts @@ -440,6 +440,38 @@ export class WorkspaceIdRequiredToUpdateTeamSubscription extends UserFriendlyErr } } +export class LicenseNotFound extends UserFriendlyError { + constructor(message?: string) { + super('resource_not_found', 'license_not_found', message); + } +} + +export class InvalidLicenseToActivate extends UserFriendlyError { + constructor(message?: string) { + super('bad_request', 'invalid_license_to_activate', message); + } +} +@ObjectType() +class InvalidLicenseUpdateParamsDataType { + @Field() reason!: string +} + +export class InvalidLicenseUpdateParams extends UserFriendlyError { + constructor(args: InvalidLicenseUpdateParamsDataType, message?: string | ((args: InvalidLicenseUpdateParamsDataType) => string)) { + super('invalid_input', 'invalid_license_update_params', message, args); + } +} +@ObjectType() +class WorkspaceMembersExceedLimitToDowngradeDataType { + @Field() limit!: number +} + +export class WorkspaceMembersExceedLimitToDowngrade extends UserFriendlyError { + constructor(args: WorkspaceMembersExceedLimitToDowngradeDataType, message?: string | ((args: WorkspaceMembersExceedLimitToDowngradeDataType) => string)) { + super('bad_request', 'workspace_members_exceed_limit_to_downgrade', message, args); + } +} + export class CopilotSessionNotFound extends UserFriendlyError { constructor(message?: string) { super('resource_not_found', 'copilot_session_not_found', message); @@ -642,6 +674,10 @@ export enum ErrorNames { CANT_UPDATE_ONETIME_PAYMENT_SUBSCRIPTION, WORKSPACE_ID_REQUIRED_FOR_TEAM_SUBSCRIPTION, WORKSPACE_ID_REQUIRED_TO_UPDATE_TEAM_SUBSCRIPTION, + LICENSE_NOT_FOUND, + INVALID_LICENSE_TO_ACTIVATE, + INVALID_LICENSE_UPDATE_PARAMS, + WORKSPACE_MEMBERS_EXCEED_LIMIT_TO_DOWNGRADE, COPILOT_SESSION_NOT_FOUND, COPILOT_SESSION_DELETED, NO_COPILOT_PROVIDER_AVAILABLE, @@ -670,5 +706,5 @@ registerEnumType(ErrorNames, { export const ErrorDataUnionType = createUnionType({ name: 'ErrorDataUnion', types: () => - [UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, SpaceNotFoundDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderSideErrorDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const, + [UnknownOauthProviderDataType, MissingOauthQueryParameterDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, SpaceNotFoundDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, DocNotFoundDataType, DocAccessDeniedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, InvalidLicenseUpdateParamsDataType, WorkspaceMembersExceedLimitToDowngradeDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderSideErrorDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType] as const, }); diff --git a/packages/backend/server/src/plugins/payment/index.ts b/packages/backend/server/src/plugins/payment/index.ts index c59072bf43d95..5b373aabf4e51 100644 --- a/packages/backend/server/src/plugins/payment/index.ts +++ b/packages/backend/server/src/plugins/payment/index.ts @@ -8,6 +8,7 @@ import { Plugin } from '../registry'; import { StripeWebhookController } from './controller'; import { SubscriptionCronJobs } from './cron'; import { + SelfhostTeamSubscriptionManager, UserSubscriptionManager, WorkspaceSubscriptionManager, } from './manager'; @@ -31,6 +32,7 @@ import { StripeWebhook } from './webhook'; StripeWebhook, UserSubscriptionManager, WorkspaceSubscriptionManager, + SelfhostTeamSubscriptionManager, SubscriptionCronJobs, WorkspaceSubscriptionResolver, ], diff --git a/packages/backend/server/src/plugins/payment/license/controller.ts b/packages/backend/server/src/plugins/payment/license/controller.ts new file mode 100644 index 0000000000000..502332432ecae --- /dev/null +++ b/packages/backend/server/src/plugins/payment/license/controller.ts @@ -0,0 +1,164 @@ +import { Body, Controller, Param, Post } from '@nestjs/common'; +import { PrismaClient } from '@prisma/client'; +import { z } from 'zod'; + +import { + InvalidLicenseToActivate, + InvalidLicenseUpdateParams, + LicenseNotFound, + Mutex, +} from '../../../fundamentals'; +import { SelfhostTeamSubscriptionManager } from '../manager/selfhost'; +import { SubscriptionService } from '../service'; +import { + SubscriptionPlan, + SubscriptionRecurring, + SubscriptionStatus, +} from '../types'; + +const UpdateSeatsParams = z.object({ + seats: z.number().min(1), +}); + +const UpdateRecurringParams = z.object({ + recurring: z.enum([ + SubscriptionRecurring.Monthly, + SubscriptionRecurring.Yearly, + ]), +}); + +@Controller('/api/team/licenses') +export class LicenseController { + constructor( + private readonly db: PrismaClient, + private readonly mutex: Mutex, + private readonly subscription: SubscriptionService, + private readonly manager: SelfhostTeamSubscriptionManager + ) {} + + @Post('/:license/activate') + async activate(@Param('license') key: string) { + await using lock = await this.mutex.lock(`license-activation:${key}`); + + if (!lock) { + throw new InvalidLicenseToActivate(); + } + + const license = await this.db.license.findUnique({ + where: { + key, + }, + }); + + if (!license) { + throw new InvalidLicenseToActivate(); + } + + const subscription = await this.manager.getSubscription({ + key: license.key, + plan: SubscriptionPlan.SelfHostedTeam, + }); + + if ( + !subscription || + license.installedAt || + subscription.status !== SubscriptionStatus.Active + ) { + throw new InvalidLicenseToActivate(); + } + + await this.db.license.update({ + where: { + key, + }, + data: { + installedAt: new Date(), + }, + }); + + return { + quota: {}, + endAt: subscription.end?.getTime(), + }; + } + + @Post('/:license/deactivate') + async deactivate(@Param('license') key: string) { + await this.db.license.update({ + where: { + key, + }, + data: { + installedAt: null, + }, + }); + + return { + success: true, + }; + } + + @Post('/:license/seats') + async updateSeats( + @Param('license') key: string, + @Body() body: z.infer + ) { + const parseResult = UpdateSeatsParams.safeParse(body); + + if (parseResult.error) { + throw new InvalidLicenseUpdateParams({ + reason: parseResult.error.message, + }); + } + + const license = await this.db.license.findUnique({ + where: { + key, + }, + }); + + if (!license) { + throw new LicenseNotFound(); + } + + await this.subscription.updateSubscriptionQuantity( + { + key: license.key, + plan: SubscriptionPlan.SelfHostedTeam, + }, + parseResult.data.seats + ); + } + + @Post('/:license/recurring') + async updateRecurring( + @Param('license') key: string, + @Body() body: z.infer + ) { + const parseResult = UpdateRecurringParams.safeParse(body); + + if (parseResult.error) { + throw new InvalidLicenseUpdateParams({ + reason: parseResult.error.message, + }); + } + + const license = await this.db.license.findUnique({ + where: { + key, + }, + }); + + if (!license) { + throw new LicenseNotFound(); + } + + await this.subscription.updateSubscriptionRecurring( + { + key: license.key, + plan: SubscriptionPlan.SelfHostedTeam, + }, + parseResult.data.recurring + ); + } +} diff --git a/packages/backend/server/src/plugins/payment/license/selfhost.ts b/packages/backend/server/src/plugins/payment/license/selfhost.ts new file mode 100644 index 0000000000000..f1e2f35e90382 --- /dev/null +++ b/packages/backend/server/src/plugins/payment/license/selfhost.ts @@ -0,0 +1,236 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { Cron, CronExpression } from '@nestjs/schedule'; +import { InstalledLicense, PrismaClient } from '@prisma/client'; + +import { QuotaService } from '../../../core/quota'; +import { + type EventPayload, + InternalServerError, + LicenseNotFound, + OnEvent, + UserFriendlyError, + WorkspaceMembersExceedLimitToDowngrade, +} from '../../../fundamentals'; +import { SubscriptionRecurring } from '../types'; + +@Injectable() +export class SelfhostService { + private readonly logger = new Logger(SelfhostService.name); + + constructor( + private readonly db: PrismaClient, + private readonly quota: QuotaService + ) {} + + async activateTeamLicense(workspaceId: string, license: string) { + // TODO(@darksky): check if the workspace is a team workspace already, and update the quota type + const res = await this.fetch<{ quota: any; endAt: number }>( + `/api/team/licenses/${license}/activate`, + { + method: 'POST', + } + ); + + await this.db.installedLicense.upsert({ + where: { + workspaceId, + }, + update: { + key: license, + revalidatedAt: new Date(), + }, + create: { + key: license, + workspaceId, + revalidatedAt: new Date(), + }, + }); + + // TODO(@darksky): upgrade workspace to team + await this.quota.switchWorkspaceQuota( + workspaceId, + res.quota, + 'activate team license', + new Date(res.endAt) + ); + } + + async deactivateTeamLicense(workspaceId: string) { + const license = await this.db.installedLicense.findUnique({ + where: { + workspaceId, + }, + }); + + if (!license) { + throw new LicenseNotFound(); + } + + const count = await this.db.workspaceUserPermission.count({ + where: { + workspaceId, + // TODO(@darksky): to replace with intivation status + accepted: true, + }, + }); + + // TODO(@darksky): get the default quota for workspace + if (count > 10) { + throw new WorkspaceMembersExceedLimitToDowngrade({ limit: 10 }); + } + + await this.fetch(`/api/team/licenses/${license.key}/deactivate`, { + method: 'POST', + }); + + await this.quota.switchWorkspaceQuotaToDefault( + workspaceId, + 'deactivate team license' + ); + + await this.db.installedLicense.deleteMany({ + where: { + workspaceId, + }, + }); + } + + async updateTeamRecurring(key: string, recurring: SubscriptionRecurring) { + await this.fetch(`/api/team/licenses/${key}/recurring`, { + method: 'POST', + body: JSON.stringify({ + recurring, + }), + headers: { + 'Content-Type': 'application/json', + }, + }); + } + + @OnEvent('workspace.members.updated') + async updateTeamSeats(payload: EventPayload<'workspace.members.updated'>) { + const { workspaceId, count } = payload; + + const license = await this.db.installedLicense.findUnique({ + where: { + workspaceId, + }, + }); + + if (!license) { + return; + } + + await this.fetch(`/api/team/licenses/${license.key}/seats`, { + method: 'POST', + body: JSON.stringify({ + quantity: count, + }), + headers: { + 'Content-Type': 'application/json', + }, + }); + + // stripe payment is async, we can't directly the charge result in update calling + await this.waitUntilLicenseUpdated(license.key, count); + } + + async waitUntilLicenseUpdated(key: string, memberRequired: number) { + let tried = 0; + while (tried++ < 10) { + try { + const res = await this.fetch<{ quota: any; endAt: number }>( + `/api/team/licenses/${key}/health` + ); + + if (res.quota.memberLimit === memberRequired) { + break; + } + } catch (e) { + this.logger.error('Failed to check license health', e); + } + + await new Promise(resolve => setTimeout(resolve, tried * 2000)); + } + } + + @Cron(CronExpression.EVERY_HOUR) + async licensesHealthCheck() { + const licenses = await this.db.installedLicense.findMany({ + where: { + revalidatedAt: { + lte: new Date(Date.now() - 1000 * 60 * 60), + }, + }, + }); + + for (const license of licenses) { + await this.revalidateLicense(license); + } + } + + private async revalidateLicense(license: InstalledLicense) { + try { + const res = await this.fetch<{ quota: any; endAt: number }>( + `/api/team/licenses/${license.key}/health` + ); + + await this.db.installedLicense.update({ + where: { + key: license.key, + }, + data: { + revalidatedAt: new Date(), + }, + }); + + await this.quota.switchWorkspaceQuota( + license.workspaceId, + res.quota, + 'revalidate license', + new Date(res.endAt) + ); + } catch (e) { + this.logger.error('Failed to revalidate license', e); + + // only treat known error as invalid license response + if ( + e instanceof UserFriendlyError && + e.name !== 'internal_server_error' + ) { + await this.quota.switchWorkspaceQuotaToDefault( + license.workspaceId, + 'revalidate license' + ); + } + } + } + + private async fetch(path: string, init?: RequestInit) { + try { + const res = await fetch('https://app.affine.pro' + path, init); + + if (!res.ok) { + const body = (await res.json()) as UserFriendlyError; + throw new UserFriendlyError( + body.type as any, + body.name as any, + body.message, + body.data + ); + } + + return res.json() as T; + } catch (e) { + if (e instanceof UserFriendlyError) { + throw e; + } + + throw new InternalServerError( + e instanceof Error + ? e.message + : 'Failed to contact with https://app.affine.pro' + ); + } + } +} diff --git a/packages/backend/server/src/plugins/payment/manager/common.ts b/packages/backend/server/src/plugins/payment/manager/common.ts index c8d16e4191e01..13647bda417ac 100644 --- a/packages/backend/server/src/plugins/payment/manager/common.ts +++ b/packages/backend/server/src/plugins/payment/manager/common.ts @@ -226,7 +226,7 @@ export abstract class SubscriptionManager { protected async getCouponFromPromotionCode( userFacingPromotionCode: string, - customer: UserStripeCustomer + customer?: UserStripeCustomer ) { const list = await this.stripe.promotionCodes.list({ code: userFacingPromotionCode, @@ -245,11 +245,20 @@ export abstract class SubscriptionManager { // code.coupon.applies_to.products.forEach() // check if the code is bound to a specific customer - return !code.customer || - (typeof code.customer === 'string' - ? code.customer === customer.stripeCustomerId - : code.customer.id === customer.stripeCustomerId) - ? code.coupon.id - : null; + if (code.customer) { + if (!customer) { + return null; + } + + return ( + typeof code.customer === 'string' + ? code.customer === customer.stripeCustomerId + : code.customer.id === customer.stripeCustomerId + ) + ? code.coupon.id + : null; + } + + return code.coupon.id; } } diff --git a/packages/backend/server/src/plugins/payment/manager/index.ts b/packages/backend/server/src/plugins/payment/manager/index.ts index 21d7a26a20784..8a03bb3f398fb 100644 --- a/packages/backend/server/src/plugins/payment/manager/index.ts +++ b/packages/backend/server/src/plugins/payment/manager/index.ts @@ -1,3 +1,4 @@ export * from './common'; +export * from './selfhost'; export * from './user'; export * from './workspace'; diff --git a/packages/backend/server/src/plugins/payment/manager/selfhost.ts b/packages/backend/server/src/plugins/payment/manager/selfhost.ts new file mode 100644 index 0000000000000..56cc1b43fd397 --- /dev/null +++ b/packages/backend/server/src/plugins/payment/manager/selfhost.ts @@ -0,0 +1,205 @@ +import { randomUUID } from 'node:crypto'; + +import { Injectable } from '@nestjs/common'; +import { PrismaClient, UserStripeCustomer } from '@prisma/client'; +import { pick } from 'lodash-es'; +import Stripe from 'stripe'; +import { z } from 'zod'; + +import { URLHelper } from '../../../fundamentals'; +import { + KnownStripeInvoice, + KnownStripePrice, + KnownStripeSubscription, + SubscriptionPlan, + SubscriptionRecurring, + SubscriptionStatus, +} from '../types'; +import { + CheckoutParams, + Invoice, + Subscription, + SubscriptionManager, +} from './common'; + +export const SelfhostTeamCheckoutArgs = z.object({ + user: z + .object({ + id: z.string(), + email: z.string(), + }) + .optional(), + quantity: z.number(), +}); + +export const SelfhostTeamSubscriptionIdentity = z.object({ + plan: z.literal(SubscriptionPlan.SelfHostedTeam), + key: z.string(), +}); + +@Injectable() +export class SelfhostTeamSubscriptionManager extends SubscriptionManager { + constructor( + stripe: Stripe, + db: PrismaClient, + private readonly url: URLHelper + ) { + super(stripe, db); + } + + filterPrices( + prices: KnownStripePrice[], + _customer?: UserStripeCustomer + ): KnownStripePrice[] { + return prices.filter( + price => price.lookupKey.plan === SubscriptionPlan.SelfHostedTeam + ); + } + + async checkout( + { price }: KnownStripePrice, + params: z.infer, + args: z.infer + ) { + const { quantity } = args; + + const discounts = await (async () => { + if (params.coupon) { + const couponId = await this.getCouponFromPromotionCode(params.coupon); + if (couponId) { + return { discounts: [{ coupon: couponId }] }; + } + } + + return { allow_promotion_codes: true }; + })(); + + return this.stripe.checkout.sessions.create({ + line_items: [ + { + price: price.id, + quantity, + }, + ], + tax_id_collection: { + enabled: true, + }, + ...discounts, + mode: 'subscription', + success_url: this.url.link(params.successCallbackLink, { + session_id: '{CHECKOUT_SESSION_ID}', + }), + }); + } + + async saveStripeSubscription(subscription: KnownStripeSubscription) { + const { stripeSubscription } = subscription; + + const subscriptionData = this.transformSubscription(subscription); + + const existingSubscription = await this.db.subscription.findFirst({ + where: { + stripeSubscriptionId: stripeSubscription.id, + }, + }); + + if (!existingSubscription) { + const key = randomUUID(); + return this.db.subscription.create({ + data: { + targetId: key, + ...subscriptionData, + }, + }); + } else { + return this.db.subscription.update({ + where: { + stripeSubscriptionId: stripeSubscription.id, + }, + data: pick(subscriptionData, [ + 'status', + 'stripeScheduleId', + 'nextBillAt', + 'canceledAt', + ]), + }); + } + } + + async deleteStripeSubscription({ + stripeSubscription, + }: KnownStripeSubscription) { + const subscription = await this.db.subscription.findFirst({ + where: { stripeSubscriptionId: stripeSubscription.id }, + }); + + if (!subscription) { + return; + } + + await this.db.subscription.deleteMany({ + where: { stripeSubscriptionId: stripeSubscription.id }, + }); + + await this.db.license.deleteMany({ + where: { key: subscription.targetId }, + }); + } + + getSubscription(identity: z.infer) { + return this.db.subscription.findFirst({ + where: { + targetId: identity.key, + plan: identity.plan, + status: { + in: [SubscriptionStatus.Active, SubscriptionStatus.Trialing], + }, + }, + }); + } + + async cancelSubscription(subscription: Subscription) { + return await this.db.subscription.update({ + where: { + // @ts-expect-error checked outside + stripeSubscriptionId: subscription.stripeSubscriptionId, + }, + data: { + canceledAt: new Date(), + nextBillAt: null, + }, + }); + } + + resumeSubscription(subscription: Subscription): Promise { + return this.db.subscription.update({ + where: { + // @ts-expect-error checked outside + stripeSubscriptionId: subscription.stripeSubscriptionId, + }, + data: { + canceledAt: null, + nextBillAt: subscription.end, + }, + }); + } + + updateSubscriptionRecurring( + subscription: Subscription, + recurring: SubscriptionRecurring + ): Promise { + return this.db.subscription.update({ + where: { + // @ts-expect-error checked outside + stripeSubscriptionId: subscription.stripeSubscriptionId, + }, + data: { recurring }, + }); + } + + async saveInvoice(knownInvoice: KnownStripeInvoice): Promise { + const invoiceData = await this.transformInvoice(knownInvoice); + + return invoiceData; + } +} diff --git a/packages/backend/server/src/plugins/payment/manager/user.ts b/packages/backend/server/src/plugins/payment/manager/user.ts index 32b20c570d930..c661c9a08208e 100644 --- a/packages/backend/server/src/plugins/payment/manager/user.ts +++ b/packages/backend/server/src/plugins/payment/manager/user.ts @@ -174,9 +174,7 @@ export class UserSubscriptionManager extends SubscriptionManager { }, ...discounts, ...mode, - success_url: this.url.link(params.successCallbackLink, { - session_id: '{CHECKOUT_SESSION_ID}', - }), + success_url: this.url.link(params.successCallbackLink), customer: customer.stripeCustomerId, }); } @@ -195,6 +193,8 @@ export class UserSubscriptionManager extends SubscriptionManager { async saveStripeSubscription(subscription: KnownStripeSubscription) { const { userId, lookupKey, stripeSubscription } = subscription; + this.assertUserIdExists(userId); + // update features first, features modify are idempotent // so there is no need to skip if a subscription already exists. // TODO(@forehalo): @@ -247,6 +247,8 @@ export class UserSubscriptionManager extends SubscriptionManager { lookupKey, stripeSubscription, }: KnownStripeSubscription) { + this.assertUserIdExists(userId); + const deleted = await this.db.subscription.deleteMany({ where: { stripeSubscriptionId: stripeSubscription.id, @@ -371,6 +373,7 @@ export class UserSubscriptionManager extends SubscriptionManager { async saveInvoice(knownInvoice: KnownStripeInvoice) { const { userId, lookupKey, stripeInvoice } = knownInvoice; + this.assertUserIdExists(userId); const invoiceData = await this.transformInvoice(knownInvoice); @@ -413,6 +416,8 @@ export class UserSubscriptionManager extends SubscriptionManager { async saveLifetimeSubscription( knownInvoice: KnownStripeInvoice ): Promise { + this.assertUserIdExists(knownInvoice.userId); + // cancel previous non-lifetime subscription const prevSubscription = await this.db.subscription.findUnique({ where: { @@ -478,6 +483,8 @@ export class UserSubscriptionManager extends SubscriptionManager { async saveOnetimePaymentSubscription( knownInvoice: KnownStripeInvoice ): Promise { + this.assertUserIdExists(knownInvoice.userId); + // TODO(@forehalo): identify whether the invoice has already been redeemed. const { userId, lookupKey } = knownInvoice; const existingSubscription = await this.db.subscription.findUnique({ @@ -666,4 +673,12 @@ export class UserSubscriptionManager extends SubscriptionManager { onetime: false, }; } + + private assertUserIdExists( + userId: string | undefined + ): asserts userId is string { + if (!userId) { + throw new Error('user should exists for stripe subscription or invoice.'); + } + } } diff --git a/packages/backend/server/src/plugins/payment/resolver.ts b/packages/backend/server/src/plugins/payment/resolver.ts index d71399fba3de6..2eae27c068313 100644 --- a/packages/backend/server/src/plugins/payment/resolver.ts +++ b/packages/backend/server/src/plugins/payment/resolver.ts @@ -16,6 +16,7 @@ import type { User } from '@prisma/client'; import { PrismaClient } from '@prisma/client'; import { GraphQLJSONObject } from 'graphql-scalars'; import { groupBy } from 'lodash-es'; +import Stripe from 'stripe'; import { z } from 'zod'; import { CurrentUser, Public } from '../../core/auth'; @@ -24,6 +25,7 @@ import { UserType } from '../../core/user'; import { WorkspaceType } from '../../core/workspaces'; import { AccessDenied, + AuthenticationRequired, FailedToCheckout, WorkspaceIdRequiredToUpdateTeamSubscription, } from '../../fundamentals'; @@ -193,7 +195,7 @@ class CreateCheckoutSessionInput implements z.infer { idempotencyKey?: string; @Field(() => GraphQLJSONObject, { nullable: true }) - args!: { workspaceId?: string }; + args!: { workspaceId?: string; quantity?: number }; } @Resolver(() => SubscriptionType) @@ -261,19 +263,33 @@ export class SubscriptionResolver { }, [] as SubscriptionPrice[]); } + @Public() @Mutation(() => String, { description: 'Create a subscription checkout link of stripe', }) async createCheckoutSession( - @CurrentUser() user: CurrentUser, + @CurrentUser() user: CurrentUser | null, @Args({ name: 'input', type: () => CreateCheckoutSessionInput }) input: CreateCheckoutSessionInput ) { - const session = await this.service.checkout(input, { - plan: input.plan as any, - user, - workspaceId: input.args?.workspaceId, - }); + let session: Stripe.Checkout.Session; + + if (input.plan === SubscriptionPlan.SelfHostedTeam) { + session = await this.service.checkout(input, { + plan: input.plan as any, + quantity: input.args.quantity ?? 10, + }); + } else { + if (!user) { + throw new AuthenticationRequired(); + } + + session = await this.service.checkout(input, { + plan: input.plan as any, + user, + workspaceId: input.args?.workspaceId, + }); + } if (!session.url) { throw new FailedToCheckout(); diff --git a/packages/backend/server/src/plugins/payment/service.ts b/packages/backend/server/src/plugins/payment/service.ts index 3b57a972690e4..36f28444ad963 100644 --- a/packages/backend/server/src/plugins/payment/service.ts +++ b/packages/backend/server/src/plugins/payment/service.ts @@ -38,6 +38,11 @@ import { WorkspaceSubscriptionIdentity, WorkspaceSubscriptionManager, } from './manager'; +import { + SelfhostTeamCheckoutArgs, + SelfhostTeamSubscriptionIdentity, + SelfhostTeamSubscriptionManager, +} from './manager/selfhost'; import { ScheduleManager } from './schedule'; import { decodeLookupKey, @@ -58,11 +63,13 @@ import { export const CheckoutExtraArgs = z.union([ UserSubscriptionCheckoutArgs, WorkspaceSubscriptionCheckoutArgs, + SelfhostTeamCheckoutArgs, ]); export const SubscriptionIdentity = z.union([ UserSubscriptionIdentity, WorkspaceSubscriptionIdentity, + SelfhostTeamSubscriptionIdentity, ]); export { CheckoutParams }; @@ -80,6 +87,7 @@ export class SubscriptionService implements OnApplicationBootstrap { private readonly user: UserService, private readonly userManager: UserSubscriptionManager, private readonly workspaceManager: WorkspaceSubscriptionManager, + private readonly selfhostManager: SelfhostTeamSubscriptionManager, private readonly mutex: Mutex ) {} @@ -99,6 +107,8 @@ export class SubscriptionService implements OnApplicationBootstrap { case SubscriptionPlan.Pro: case SubscriptionPlan.AI: return this.userManager; + case SubscriptionPlan.SelfHostedTeam: + return this.selfhostManager; default: throw new UnsupportedSubscriptionPlan({ plan }); } @@ -129,7 +139,7 @@ export class SubscriptionService implements OnApplicationBootstrap { if ( this.config.deploy && this.config.affine.canary && - !this.feature.isStaff(args.user.email) + (!args.user || !this.feature.isStaff(args.user.email)) ) { throw new ActionForbidden(); } @@ -303,6 +313,53 @@ export class SubscriptionService implements OnApplicationBootstrap { return newSubscription; } + async updateSubscriptionQuantity( + identity: z.infer, + count: number + ) { + this.assertSubscriptionIdentity(identity); + + const subscription = await this.select(identity.plan).getSubscription( + identity + ); + + if (!subscription) { + throw new SubscriptionNotExists({ plan: identity.plan }); + } + + if (!subscription.stripeSubscriptionId) { + throw new CantUpdateOnetimePaymentSubscription(); + } + + const stripeSubscription = await this.stripe.subscriptions.retrieve( + subscription.stripeSubscriptionId + ); + + const lookupKey = + retriveLookupKeyFromStripeSubscription(stripeSubscription); + + await this.stripe.subscriptions.update(stripeSubscription.id, { + items: [ + { + id: stripeSubscription.items.data[0].id, + quantity: count, + }, + ], + payment_behavior: 'pending_if_incomplete', + proration_behavior: + lookupKey?.recurring === SubscriptionRecurring.Yearly + ? 'always_invoice' + : 'none', + }); + + if (subscription.stripeScheduleId) { + const schedule = await this.scheduleManager.fromSchedule( + subscription.stripeScheduleId + ); + await schedule.updateQuantity(count); + } + } + async createCustomerPortal(id: string) { const user = await this.db.userStripeCustomer.findUnique({ where: { @@ -428,15 +485,18 @@ export class SubscriptionService implements OnApplicationBootstrap { private async retrieveUserFromCustomer( customer: string | Stripe.Customer | Stripe.DeletedCustomer - ) { + ): Promise<{ id?: string; email: string } | null> { const userStripeCustomer = await this.db.userStripeCustomer.findUnique({ where: { stripeCustomerId: typeof customer === 'string' ? customer : customer.id, }, + select: { + user: true, + }, }); if (userStripeCustomer) { - return userStripeCustomer.userId; + return userStripeCustomer.user; } if (typeof customer === 'string') { @@ -450,7 +510,10 @@ export class SubscriptionService implements OnApplicationBootstrap { const user = await this.user.findUserByEmail(customer.email); if (!user) { - return null; + return { + id: undefined, + email: customer.email, + }; } await this.db.userStripeCustomer.create({ @@ -460,7 +523,7 @@ export class SubscriptionService implements OnApplicationBootstrap { }, }); - return user.id; + return user; } private async listStripePrices(): Promise { @@ -517,14 +580,8 @@ export class SubscriptionService implements OnApplicationBootstrap { const user = await this.user.findUserByEmail(invoice.customer_email); - // TODO(@forehalo): the email may actually not appear to be AFFiNE user - // There is coming feature that allow anonymous user with only email provided to buy selfhost licenses - if (!user) { - return null; - } - return { - userId: user.id, + userId: user?.id, stripeInvoice: invoice, lookupKey, metadata: invoice.subscription_details?.metadata ?? {}, @@ -540,14 +597,10 @@ export class SubscriptionService implements OnApplicationBootstrap { return null; } - const userId = await this.retrieveUserFromCustomer(subscription.customer); - - if (!userId) { - return null; - } + const user = await this.retrieveUserFromCustomer(subscription.customer); return { - userId, + userId: user?.id, lookupKey, stripeSubscription: subscription, quantity: subscription.items.data[0]?.quantity ?? 1, diff --git a/packages/backend/server/src/plugins/payment/types.ts b/packages/backend/server/src/plugins/payment/types.ts index 8a241fbf135e6..6b020d685a9ca 100644 --- a/packages/backend/server/src/plugins/payment/types.ts +++ b/packages/backend/server/src/plugins/payment/types.ts @@ -16,6 +16,7 @@ export enum SubscriptionPlan { Team = 'team', Enterprise = 'enterprise', SelfHosted = 'selfhosted', + SelfHostedTeam = 'selfhostedteam', } export enum SubscriptionVariant { @@ -95,7 +96,7 @@ export interface KnownStripeInvoice { /** * User in AFFiNE system. */ - userId: string; + userId?: string; /** * The lookup key of the price that the invoice is for. @@ -117,7 +118,7 @@ export interface KnownStripeSubscription { /** * User in AFFiNE system. */ - userId: string; + userId?: string; /** * The lookup key of the price that the invoice is for. diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index bd3651b7af7dd..74549be4cc87d 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -209,7 +209,7 @@ type EditorType { name: String! } -union ErrorDataUnion = AlreadyInSpaceDataType | BlobNotFoundDataType | CopilotMessageNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderSideErrorDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidEmailDataType | InvalidHistoryTimestampDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInSpaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SpaceAccessDeniedDataType | SpaceNotFoundDataType | SpaceOwnerNotFoundDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | UnsupportedSubscriptionPlanDataType | VersionRejectedDataType +union ErrorDataUnion = AlreadyInSpaceDataType | BlobNotFoundDataType | CopilotMessageNotFoundDataType | CopilotPromptNotFoundDataType | CopilotProviderSideErrorDataType | DocAccessDeniedDataType | DocHistoryNotFoundDataType | DocNotFoundDataType | InvalidEmailDataType | InvalidHistoryTimestampDataType | InvalidLicenseUpdateParamsDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType | NotInSpaceDataType | RuntimeConfigNotFoundDataType | SameSubscriptionRecurringDataType | SpaceAccessDeniedDataType | SpaceNotFoundDataType | SpaceOwnerNotFoundDataType | SubscriptionAlreadyExistsDataType | SubscriptionNotExistsDataType | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | UnsupportedSubscriptionPlanDataType | VersionRejectedDataType | WorkspaceMembersExceedLimitToDowngradeDataType enum ErrorNames { ACCESS_DENIED @@ -251,10 +251,13 @@ enum ErrorNames { INVALID_EMAIL INVALID_EMAIL_TOKEN INVALID_HISTORY_TIMESTAMP + INVALID_LICENSE_TO_ACTIVATE + INVALID_LICENSE_UPDATE_PARAMS INVALID_OAUTH_CALLBACK_STATE INVALID_PASSWORD_LENGTH INVALID_RUNTIME_CONFIG_TYPE INVALID_SUBSCRIPTION_PARAMETERS + LICENSE_NOT_FOUND LINK_EXPIRED MAILER_SERVICE_IS_NOT_CONFIGURED MEMBER_QUOTA_EXCEEDED @@ -288,6 +291,7 @@ enum ErrorNames { VERSION_REJECTED WORKSPACE_ID_REQUIRED_FOR_TEAM_SUBSCRIPTION WORKSPACE_ID_REQUIRED_TO_UPDATE_TEAM_SUBSCRIPTION + WORKSPACE_MEMBERS_EXCEED_LIMIT_TO_DOWNGRADE WRONG_SIGN_IN_CREDENTIALS WRONG_SIGN_IN_METHOD } @@ -330,6 +334,10 @@ type InvalidHistoryTimestampDataType { timestamp: String! } +type InvalidLicenseUpdateParamsDataType { + reason: String! +} + type InvalidPasswordLengthDataType { max: Int! min: Int! @@ -757,6 +765,7 @@ enum SubscriptionPlan { Free Pro SelfHosted + SelfHostedTeam Team } @@ -902,6 +911,10 @@ type WorkspaceBlobSizes { size: SafeInt! } +type WorkspaceMembersExceedLimitToDowngradeDataType { + limit: Int! +} + type WorkspacePage { id: String! mode: PublicPageMode! diff --git a/packages/frontend/graphql/src/schema.ts b/packages/frontend/graphql/src/schema.ts index 30d609acb1785..430a6d39350b1 100644 --- a/packages/frontend/graphql/src/schema.ts +++ b/packages/frontend/graphql/src/schema.ts @@ -272,6 +272,7 @@ export type ErrorDataUnion = | DocNotFoundDataType | InvalidEmailDataType | InvalidHistoryTimestampDataType + | InvalidLicenseUpdateParamsDataType | InvalidPasswordLengthDataType | InvalidRuntimeConfigTypeDataType | MissingOauthQueryParameterDataType @@ -286,7 +287,8 @@ export type ErrorDataUnion = | SubscriptionPlanNotFoundDataType | UnknownOauthProviderDataType | UnsupportedSubscriptionPlanDataType - | VersionRejectedDataType; + | VersionRejectedDataType + | WorkspaceMembersExceedLimitToDowngradeDataType; export enum ErrorNames { ACCESS_DENIED = 'ACCESS_DENIED', @@ -328,10 +330,13 @@ export enum ErrorNames { INVALID_EMAIL = 'INVALID_EMAIL', INVALID_EMAIL_TOKEN = 'INVALID_EMAIL_TOKEN', INVALID_HISTORY_TIMESTAMP = 'INVALID_HISTORY_TIMESTAMP', + INVALID_LICENSE_TO_ACTIVATE = 'INVALID_LICENSE_TO_ACTIVATE', + INVALID_LICENSE_UPDATE_PARAMS = 'INVALID_LICENSE_UPDATE_PARAMS', INVALID_OAUTH_CALLBACK_STATE = 'INVALID_OAUTH_CALLBACK_STATE', INVALID_PASSWORD_LENGTH = 'INVALID_PASSWORD_LENGTH', INVALID_RUNTIME_CONFIG_TYPE = 'INVALID_RUNTIME_CONFIG_TYPE', INVALID_SUBSCRIPTION_PARAMETERS = 'INVALID_SUBSCRIPTION_PARAMETERS', + LICENSE_NOT_FOUND = 'LICENSE_NOT_FOUND', LINK_EXPIRED = 'LINK_EXPIRED', MAILER_SERVICE_IS_NOT_CONFIGURED = 'MAILER_SERVICE_IS_NOT_CONFIGURED', MEMBER_QUOTA_EXCEEDED = 'MEMBER_QUOTA_EXCEEDED', @@ -365,6 +370,7 @@ export enum ErrorNames { VERSION_REJECTED = 'VERSION_REJECTED', WORKSPACE_ID_REQUIRED_FOR_TEAM_SUBSCRIPTION = 'WORKSPACE_ID_REQUIRED_FOR_TEAM_SUBSCRIPTION', WORKSPACE_ID_REQUIRED_TO_UPDATE_TEAM_SUBSCRIPTION = 'WORKSPACE_ID_REQUIRED_TO_UPDATE_TEAM_SUBSCRIPTION', + WORKSPACE_MEMBERS_EXCEED_LIMIT_TO_DOWNGRADE = 'WORKSPACE_MEMBERS_EXCEED_LIMIT_TO_DOWNGRADE', WRONG_SIGN_IN_CREDENTIALS = 'WRONG_SIGN_IN_CREDENTIALS', WRONG_SIGN_IN_METHOD = 'WRONG_SIGN_IN_METHOD', } @@ -407,6 +413,11 @@ export interface InvalidHistoryTimestampDataType { timestamp: Scalars['String']['output']; } +export interface InvalidLicenseUpdateParamsDataType { + __typename?: 'InvalidLicenseUpdateParamsDataType'; + reason: Scalars['String']['output']; +} + export interface InvalidPasswordLengthDataType { __typename?: 'InvalidPasswordLengthDataType'; max: Scalars['Int']['output']; @@ -1066,6 +1077,7 @@ export enum SubscriptionPlan { Free = 'Free', Pro = 'Pro', SelfHosted = 'SelfHosted', + SelfHostedTeam = 'SelfHostedTeam', Team = 'Team', } @@ -1223,6 +1235,11 @@ export interface WorkspaceBlobSizes { size: Scalars['SafeInt']['output']; } +export interface WorkspaceMembersExceedLimitToDowngradeDataType { + __typename?: 'WorkspaceMembersExceedLimitToDowngradeDataType'; + limit: Scalars['Int']['output']; +} + export interface WorkspacePage { __typename?: 'WorkspacePage'; id: Scalars['String']['output'];