diff --git a/app/billing_client.go b/app/billing_client.go new file mode 100644 index 00000000000..299c4f6fd35 --- /dev/null +++ b/app/billing_client.go @@ -0,0 +1,310 @@ +package app + +import ( + "context" + "errors" + "io" + + pb "go.viam.com/api/app/v1" + "go.viam.com/utils/rpc" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// UsageCostType specifies the type of usage cost. +type UsageCostType int32 + +const ( + // UsageCostTypeUnspecified is an unspecified usage cost type. + UsageCostTypeUnspecified UsageCostType = iota + // UsageCostTypeDataUpload represents the usage cost from data upload. + UsageCostTypeDataUpload + // UsageCostTypeDataEgress represents the usage cost from data egress. + UsageCostTypeDataEgress + // UsageCostTypeRemoteControl represents the usage cost from remote control. + UsageCostTypeRemoteControl + // UsageCostTypeStandardCompute represents the usage cost from standard compute. + UsageCostTypeStandardCompute + // UsageCostTypeCloudStorage represents the usage cost from cloud storage. + UsageCostTypeCloudStorage + // UsageCostTypeBinaryDataCloudStorage represents the usage cost from binary data cloud storage. + UsageCostTypeBinaryDataCloudStorage + // UsageCostTypeOtherCloudStorage represents the usage cost from other cloud storage. + UsageCostTypeOtherCloudStorage + // UsageCostTypePerMachine represents the usage cost per machine. + UsageCostTypePerMachine +) + +func usageCostTypeFromProto(costType pb.UsageCostType) UsageCostType { + switch costType { + case pb.UsageCostType_USAGE_COST_TYPE_UNSPECIFIED: + return UsageCostTypeUnspecified + case pb.UsageCostType_USAGE_COST_TYPE_DATA_UPLOAD: + return UsageCostTypeDataUpload + case pb.UsageCostType_USAGE_COST_TYPE_DATA_EGRESS: + return UsageCostTypeDataEgress + case pb.UsageCostType_USAGE_COST_TYPE_REMOTE_CONTROL: + return UsageCostTypeRemoteControl + case pb.UsageCostType_USAGE_COST_TYPE_STANDARD_COMPUTE: + return UsageCostTypeStandardCompute + case pb.UsageCostType_USAGE_COST_TYPE_CLOUD_STORAGE: + return UsageCostTypeCloudStorage + case pb.UsageCostType_USAGE_COST_TYPE_BINARY_DATA_CLOUD_STORAGE: + return UsageCostTypeBinaryDataCloudStorage + case pb.UsageCostType_USAGE_COST_TYPE_OTHER_CLOUD_STORAGE: + return UsageCostTypeOtherCloudStorage + case pb.UsageCostType_USAGE_COST_TYPE_PER_MACHINE: + return UsageCostTypePerMachine + default: + return UsageCostTypeUnspecified + } +} + +// UsageCost contains the cost and cost type. +type UsageCost struct { + ResourceType UsageCostType + Cost float64 +} + +func usageCostFromProto(cost *pb.UsageCost) *UsageCost { + return &UsageCost{ + ResourceType: usageCostTypeFromProto(cost.ResourceType), + Cost: cost.Cost, + } +} + +// ResourceUsageCosts holds the usage costs with discount information. +type ResourceUsageCosts struct { + UsageCosts []*UsageCost + Discount float64 + TotalWithDiscount float64 + TotalWithoutDiscount float64 +} + +func resourceUsageCostsFromProto(costs *pb.ResourceUsageCosts) *ResourceUsageCosts { + var usageCosts []*UsageCost + for _, cost := range costs.UsageCosts { + usageCosts = append(usageCosts, usageCostFromProto(cost)) + } + return &ResourceUsageCosts{ + UsageCosts: usageCosts, + Discount: costs.Discount, + TotalWithDiscount: costs.TotalWithDiscount, + TotalWithoutDiscount: costs.TotalWithoutDiscount, + } +} + +// SourceType is the type of source from which a cost is coming from. +type SourceType int32 + +const ( + // SourceTypeUnspecified represents an unspecified source type. + SourceTypeUnspecified SourceType = iota + // SourceTypeOrg represents an organization. + SourceTypeOrg + // SourceTypeFragment represents a fragment. + SourceTypeFragment +) + +func sourceTypeFromProto(sourceType pb.SourceType) SourceType { + switch sourceType { + case pb.SourceType_SOURCE_TYPE_UNSPECIFIED: + return SourceTypeUnspecified + case pb.SourceType_SOURCE_TYPE_ORG: + return SourceTypeOrg + case pb.SourceType_SOURCE_TYPE_FRAGMENT: + return SourceTypeFragment + default: + return SourceTypeUnspecified + } +} + +// ResourceUsageCostsBySource contains the resource usage costs of a source. +type ResourceUsageCostsBySource struct { + SourceType SourceType + ResourceUsageCosts *ResourceUsageCosts + TierName string +} + +func resourceUsageCostsBySourceFromProto(costs *pb.ResourceUsageCostsBySource) *ResourceUsageCostsBySource { + return &ResourceUsageCostsBySource{ + SourceType: sourceTypeFromProto(costs.SourceType), + ResourceUsageCosts: resourceUsageCostsFromProto(costs.ResourceUsageCosts), + TierName: costs.TierName, + } +} + +// GetCurrentMonthUsageResponse contains the current month usage information. +type GetCurrentMonthUsageResponse struct { + StartDate *timestamppb.Timestamp + EndDate *timestamppb.Timestamp + ResourceUsageCostsBySource []*ResourceUsageCostsBySource + Subtotal float64 +} + +func getCurrentMonthUsageResponseFromProto(response *pb.GetCurrentMonthUsageResponse) *GetCurrentMonthUsageResponse { + var costs []*ResourceUsageCostsBySource + for _, cost := range response.ResourceUsageCostsBySource { + costs = append(costs, resourceUsageCostsBySourceFromProto(cost)) + } + return &GetCurrentMonthUsageResponse{ + StartDate: response.StartDate, + EndDate: response.EndDate, + ResourceUsageCostsBySource: costs, + Subtotal: response.Subtotal, + } +} + +// PaymentMethodType is the type of payment method. +type PaymentMethodType int32 + +const ( + // PaymentMethodTypeUnspecified represents an unspecified payment method. + PaymentMethodTypeUnspecified PaymentMethodType = iota + // PaymentMethodtypeCard represents a payment by card. + PaymentMethodtypeCard +) + +func paymentMethodTypeFromProto(methodType pb.PaymentMethodType) PaymentMethodType { + switch methodType { + case pb.PaymentMethodType_PAYMENT_METHOD_TYPE_UNSPECIFIED: + return PaymentMethodTypeUnspecified + case pb.PaymentMethodType_PAYMENT_METHOD_TYPE_CARD: + return PaymentMethodtypeCard + default: + return PaymentMethodTypeUnspecified + } +} + +// PaymentMethodCard holds the information of a card used for payment. +type PaymentMethodCard struct { + Brand string + LastFourDigits string +} + +func paymentMethodCardFromProto(card *pb.PaymentMethodCard) *PaymentMethodCard { + return &PaymentMethodCard{ + Brand: card.Brand, + LastFourDigits: card.LastFourDigits, + } +} + +// GetOrgBillingInformationResponse contains the information of an organization's billing information. +type GetOrgBillingInformationResponse struct { + Type PaymentMethodType + BillingEmail string + // defined if type is PaymentMethodTypeCard + Method *PaymentMethodCard + // only return for billing dashboard admin users + BillingTier *string +} + +func getOrgBillingInformationResponseFromProto(resp *pb.GetOrgBillingInformationResponse) *GetOrgBillingInformationResponse { + return &GetOrgBillingInformationResponse{ + Type: paymentMethodTypeFromProto(resp.Type), + BillingEmail: resp.BillingEmail, + Method: paymentMethodCardFromProto(resp.Method), + BillingTier: resp.BillingTier, + } +} + +// InvoiceSummary holds the information of an invoice summary. +type InvoiceSummary struct { + ID string + InvoiceDate *timestamppb.Timestamp + InvoiceAmount float64 + Status string + DueDate *timestamppb.Timestamp + PaidDate *timestamppb.Timestamp +} + +func invoiceSummaryFromProto(summary *pb.InvoiceSummary) *InvoiceSummary { + return &InvoiceSummary{ + ID: summary.Id, + InvoiceDate: summary.InvoiceDate, + InvoiceAmount: summary.InvoiceAmount, + Status: summary.Status, + DueDate: summary.DueDate, + PaidDate: summary.PaidDate, + } +} + +// BillingClient is a gRPC client for method calls to the Billing API. +type BillingClient struct { + client pb.BillingServiceClient +} + +// NewBillingClient constructs a new BillingClient using the connection passed in by the Viam client. +func NewBillingClient(conn rpc.ClientConn) *BillingClient { + return &BillingClient{client: pb.NewBillingServiceClient(conn)} +} + +// GetCurrentMonthUsage gets the data usage information for the current month for an organization. +func (c *BillingClient) GetCurrentMonthUsage(ctx context.Context, orgID string) (*GetCurrentMonthUsageResponse, error) { + resp, err := c.client.GetCurrentMonthUsage(ctx, &pb.GetCurrentMonthUsageRequest{ + OrgId: orgID, + }) + if err != nil { + return nil, err + } + return getCurrentMonthUsageResponseFromProto(resp), nil +} + +// GetOrgBillingInformation gets the billing information of an organization. +func (c *BillingClient) GetOrgBillingInformation(ctx context.Context, orgID string) (*GetOrgBillingInformationResponse, error) { + resp, err := c.client.GetOrgBillingInformation(ctx, &pb.GetOrgBillingInformationRequest{ + OrgId: orgID, + }) + if err != nil { + return nil, err + } + return getOrgBillingInformationResponseFromProto(resp), nil +} + +// GetInvoicesSummary returns the outstanding balance and the invoice summaries of an organization. +func (c *BillingClient) GetInvoicesSummary(ctx context.Context, orgID string) (float64, []*InvoiceSummary, error) { + resp, err := c.client.GetInvoicesSummary(ctx, &pb.GetInvoicesSummaryRequest{ + OrgId: orgID, + }) + if err != nil { + return 0, nil, err + } + var invoices []*InvoiceSummary + for _, invoice := range resp.Invoices { + invoices = append(invoices, invoiceSummaryFromProto(invoice)) + } + return resp.OutstandingBalance, invoices, nil +} + +// GetInvoicePDF gets the invoice PDF data. +func (c *BillingClient) GetInvoicePDF(ctx context.Context, id, orgID string) ([]byte, error) { + stream, err := c.client.GetInvoicePdf(ctx, &pb.GetInvoicePdfRequest{ + Id: id, + OrgId: orgID, + }) + if err != nil { + return nil, err + } + + var data []byte + for { + resp, err := stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return data, err + } + data = append(data, resp.Chunk...) + } + + return data, nil +} + +// SendPaymentRequiredEmail sends an email about payment requirement. +func (c *BillingClient) SendPaymentRequiredEmail(ctx context.Context, customerOrgID, billingOwnerOrgID string) error { + _, err := c.client.SendPaymentRequiredEmail(ctx, &pb.SendPaymentRequiredEmailRequest{ + CustomerOrgId: customerOrgID, + BillingOwnerOrgId: billingOwnerOrgID, + }) + return err +} diff --git a/app/billing_client_test.go b/app/billing_client_test.go new file mode 100644 index 00000000000..e8af00302a0 --- /dev/null +++ b/app/billing_client_test.go @@ -0,0 +1,265 @@ +package app + +import ( + "bytes" + "context" + "io" + "testing" + + pb "go.viam.com/api/app/v1" + "go.viam.com/test" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/timestamppb" + + "go.viam.com/rdk/testutils/inject" +) + +const ( + subtotal = 37 + sourceType = SourceTypeOrg + usageCostType = UsageCostTypeCloudStorage + cost float64 = 20 + discount float64 = 9 + totalWithDiscount = cost - discount + totalWithoutDiscount float64 = cost + email = "email" + paymentMethodType = PaymentMethodtypeCard + brand = "brand" + digits = "1234" + invoiceID = "invoice_id" + invoiceAmount float64 = 100.12 + status = "status" + balance float64 = 73.21 + billingOwnerOrgID = "billing_owner_organization_id" +) + +var ( + start = timestamppb.Timestamp{Seconds: 92, Nanos: 0} + end = timestamppb.Timestamp{Seconds: 99, Nanos: 999} + tier = "tier" + getCurrentMonthUsageResponse = GetCurrentMonthUsageResponse{ + StartDate: &start, + EndDate: &end, + ResourceUsageCostsBySource: []*ResourceUsageCostsBySource{ + { + SourceType: sourceType, + ResourceUsageCosts: &ResourceUsageCosts{ + UsageCosts: []*UsageCost{ + { + ResourceType: usageCostType, + Cost: cost, + }, + }, + Discount: discount, + TotalWithDiscount: totalWithDiscount, + TotalWithoutDiscount: totalWithoutDiscount, + }, + TierName: tier, + }, + }, + Subtotal: subtotal, + } + getOrgBillingInformationResponse = GetOrgBillingInformationResponse{ + Type: paymentMethodType, + BillingEmail: email, + Method: &PaymentMethodCard{ + Brand: brand, + LastFourDigits: digits, + }, + BillingTier: &tier, + } + invoiceDate = timestamppb.Timestamp{Seconds: 287, Nanos: 0} + dueDate = timestamppb.Timestamp{Seconds: 1241, Nanos: 40} + paidDate = timestamppb.Timestamp{Seconds: 827, Nanos: 62} + invoiceSummary = InvoiceSummary{ + ID: invoiceID, + InvoiceDate: &invoiceDate, + InvoiceAmount: invoiceAmount, + Status: status, + DueDate: &dueDate, + PaidDate: &paidDate, + } + chunk1 = []byte{4, 8} + chunk2 = []byte("chunk1") + chunk3 = []byte("chunk2") + chunks = [][]byte{chunk1, chunk2, chunk3} + chunkCount = len(chunks) +) + +func sourceTypeToProto(sourceType SourceType) pb.SourceType { + switch sourceType { + case SourceTypeUnspecified: + return pb.SourceType_SOURCE_TYPE_UNSPECIFIED + case SourceTypeOrg: + return pb.SourceType_SOURCE_TYPE_ORG + case SourceTypeFragment: + return pb.SourceType_SOURCE_TYPE_FRAGMENT + default: + return pb.SourceType_SOURCE_TYPE_UNSPECIFIED + } +} + +func usageCostTypeToProto(costType UsageCostType) pb.UsageCostType { + switch costType { + case UsageCostTypeUnspecified: + return pb.UsageCostType_USAGE_COST_TYPE_UNSPECIFIED + case UsageCostTypeDataUpload: + return pb.UsageCostType_USAGE_COST_TYPE_DATA_UPLOAD + case UsageCostTypeDataEgress: + return pb.UsageCostType_USAGE_COST_TYPE_DATA_EGRESS + case UsageCostTypeRemoteControl: + return pb.UsageCostType_USAGE_COST_TYPE_REMOTE_CONTROL + case UsageCostTypeStandardCompute: + return pb.UsageCostType_USAGE_COST_TYPE_STANDARD_COMPUTE + case UsageCostTypeCloudStorage: + return pb.UsageCostType_USAGE_COST_TYPE_CLOUD_STORAGE + case UsageCostTypeBinaryDataCloudStorage: + return pb.UsageCostType_USAGE_COST_TYPE_BINARY_DATA_CLOUD_STORAGE + case UsageCostTypeOtherCloudStorage: + return pb.UsageCostType_USAGE_COST_TYPE_OTHER_CLOUD_STORAGE + case UsageCostTypePerMachine: + return pb.UsageCostType_USAGE_COST_TYPE_PER_MACHINE + default: + return pb.UsageCostType_USAGE_COST_TYPE_UNSPECIFIED + } +} + +func paymentMethodTypeToProto(methodType PaymentMethodType) pb.PaymentMethodType { + switch methodType { + case PaymentMethodTypeUnspecified: + return pb.PaymentMethodType_PAYMENT_METHOD_TYPE_UNSPECIFIED + case PaymentMethodtypeCard: + return pb.PaymentMethodType_PAYMENT_METHOD_TYPE_CARD + default: + return pb.PaymentMethodType_PAYMENT_METHOD_TYPE_UNSPECIFIED + } +} + +func createBillingGrpcClient() *inject.BillingServiceClient { + return &inject.BillingServiceClient{} +} + +func TestBillingClient(t *testing.T) { + grpcClient := createBillingGrpcClient() + client := BillingClient{client: grpcClient} + + t.Run("GetCurrentMonthUsage", func(t *testing.T) { + pbResponse := pb.GetCurrentMonthUsageResponse{ + StartDate: getCurrentMonthUsageResponse.StartDate, + EndDate: getCurrentMonthUsageResponse.EndDate, + ResourceUsageCostsBySource: []*pb.ResourceUsageCostsBySource{ + { + SourceType: sourceTypeToProto(sourceType), + ResourceUsageCosts: &pb.ResourceUsageCosts{ + UsageCosts: []*pb.UsageCost{ + { + ResourceType: usageCostTypeToProto(usageCostType), + Cost: cost, + }, + }, + Discount: discount, + TotalWithDiscount: totalWithDiscount, + TotalWithoutDiscount: totalWithoutDiscount, + }, + TierName: tier, + }, + }, + Subtotal: getCurrentMonthUsageResponse.Subtotal, + } + grpcClient.GetCurrentMonthUsageFunc = func( + ctx context.Context, in *pb.GetCurrentMonthUsageRequest, opts ...grpc.CallOption, + ) (*pb.GetCurrentMonthUsageResponse, error) { + test.That(t, in.OrgId, test.ShouldEqual, organizationID) + return &pbResponse, nil + } + resp, err := client.GetCurrentMonthUsage(context.Background(), organizationID) + test.That(t, err, test.ShouldBeNil) + test.That(t, resp, test.ShouldResemble, &getCurrentMonthUsageResponse) + }) + + t.Run("GetOrgBillingInformation", func(t *testing.T) { + pbResponse := pb.GetOrgBillingInformationResponse{ + Type: paymentMethodTypeToProto(getOrgBillingInformationResponse.Type), + BillingEmail: getOrgBillingInformationResponse.BillingEmail, + Method: &pb.PaymentMethodCard{ + Brand: getOrgBillingInformationResponse.Method.Brand, + LastFourDigits: getOrgBillingInformationResponse.Method.LastFourDigits, + }, + BillingTier: getOrgBillingInformationResponse.BillingTier, + } + grpcClient.GetOrgBillingInformationFunc = func( + ctx context.Context, in *pb.GetOrgBillingInformationRequest, opts ...grpc.CallOption, + ) (*pb.GetOrgBillingInformationResponse, error) { + test.That(t, in.OrgId, test.ShouldEqual, organizationID) + return &pbResponse, nil + } + resp, err := client.GetOrgBillingInformation(context.Background(), organizationID) + test.That(t, err, test.ShouldBeNil) + test.That(t, resp, test.ShouldResemble, &getOrgBillingInformationResponse) + }) + + t.Run("GetInvoicesSummary", func(t *testing.T) { + expectedInvoices := []*InvoiceSummary{&invoiceSummary} + grpcClient.GetInvoicesSummaryFunc = func( + ctx context.Context, in *pb.GetInvoicesSummaryRequest, opts ...grpc.CallOption, + ) (*pb.GetInvoicesSummaryResponse, error) { + test.That(t, in.OrgId, test.ShouldEqual, organizationID) + return &pb.GetInvoicesSummaryResponse{ + OutstandingBalance: balance, + Invoices: []*pb.InvoiceSummary{ + { + Id: invoiceSummary.ID, + InvoiceDate: invoiceSummary.InvoiceDate, + InvoiceAmount: invoiceSummary.InvoiceAmount, + Status: invoiceSummary.Status, + DueDate: invoiceSummary.DueDate, + PaidDate: invoiceSummary.PaidDate, + }, + }, + }, nil + } + outstandingBalance, invoices, err := client.GetInvoicesSummary(context.Background(), organizationID) + test.That(t, err, test.ShouldBeNil) + test.That(t, outstandingBalance, test.ShouldResemble, balance) + test.That(t, invoices, test.ShouldResemble, expectedInvoices) + }) + + t.Run("GetInvoicePDF", func(t *testing.T) { + expectedData := bytes.Join(chunks, nil) + var count int + mockStream := &inject.BillingServiceGetInvoicePdfClient{ + RecvFunc: func() (*pb.GetInvoicePdfResponse, error) { + if count >= chunkCount { + return nil, io.EOF + } + chunk := chunks[count] + count++ + return &pb.GetInvoicePdfResponse{ + Chunk: chunk, + }, nil + }, + } + grpcClient.GetInvoicePdfFunc = func( + ctx context.Context, in *pb.GetInvoicePdfRequest, opts ...grpc.CallOption, + ) (pb.BillingService_GetInvoicePdfClient, error) { + test.That(t, in.Id, test.ShouldEqual, invoiceID) + test.That(t, in.OrgId, test.ShouldEqual, organizationID) + return mockStream, nil + } + data, err := client.GetInvoicePDF(context.Background(), invoiceID, organizationID) + test.That(t, err, test.ShouldBeNil) + test.That(t, data, test.ShouldResemble, expectedData) + }) + + t.Run("SendPaymentRequiredEmail", func(t *testing.T) { + grpcClient.SendPaymentRequiredEmailFunc = func( + ctx context.Context, in *pb.SendPaymentRequiredEmailRequest, opts ...grpc.CallOption, + ) (*pb.SendPaymentRequiredEmailResponse, error) { + test.That(t, in.CustomerOrgId, test.ShouldEqual, organizationID) + test.That(t, in.BillingOwnerOrgId, test.ShouldEqual, billingOwnerOrgID) + return &pb.SendPaymentRequiredEmailResponse{}, nil + } + err := client.SendPaymentRequiredEmail(context.Background(), organizationID, billingOwnerOrgID) + test.That(t, err, test.ShouldBeNil) + }) +} diff --git a/app/viam_client.go b/app/viam_client.go index 6a8e3075012..74e155ed46a 100644 --- a/app/viam_client.go +++ b/app/viam_client.go @@ -14,8 +14,9 @@ import ( // ViamClient is a gRPC client for method calls to Viam app. type ViamClient struct { - conn rpc.ClientConn - dataClient *DataClient + conn rpc.ClientConn + billingClient *BillingClient + dataClient *DataClient } // Options has the options necessary to connect through gRPC. @@ -63,6 +64,16 @@ func CreateViamClientWithAPIKey( return CreateViamClientWithOptions(ctx, options, logger) } +// Billingclient initializes and returns a Billingclient instance used to make app method calls. +// To use Billingclient, you must first instantiate a ViamClient. +func (c *ViamClient) Billingclient() *BillingClient { + if c.billingClient != nil { + return c.billingClient + } + c.billingClient = NewBillingClient(c.conn) + return c.billingClient +} + // DataClient initializes and returns a DataClient instance used to make data method calls. // To use DataClient, you must first instantiate a ViamClient. func (c *ViamClient) DataClient() *DataClient { diff --git a/app/viam_client_test.go b/app/viam_client_test.go index 08b93ad1c31..6bd4eaa574f 100644 --- a/app/viam_client_test.go +++ b/app/viam_client_test.go @@ -5,7 +5,8 @@ import ( "testing" "github.com/viamrobotics/webrtc/v3" - pb "go.viam.com/api/app/data/v1" + datapb "go.viam.com/api/app/data/v1" + apppb "go.viam.com/api/app/v1" "go.viam.com/test" "go.viam.com/utils" "go.viam.com/utils/rpc" @@ -119,7 +120,7 @@ func TestCreateViamClientWithAPIKeyTests(t *testing.T) { } } -func TestNewDataClient(t *testing.T) { +func TestNewAppClients(t *testing.T) { originalDialDirectGRPC := dialDirectGRPC dialDirectGRPC = mockDialDirectGRPC defer func() { dialDirectGRPC = originalDialDirectGRPC }() @@ -135,10 +136,20 @@ func TestNewDataClient(t *testing.T) { test.That(t, err, test.ShouldBeNil) defer client.Close() + billingClient := client.Billingclient() + test.That(t, billingClient, test.ShouldNotBeNil) + test.That(t, billingClient, test.ShouldHaveSameTypeAs, &BillingClient{}) + test.That(t, billingClient.client, test.ShouldImplement, (*apppb.BillingServiceClient)(nil)) + + // Testing that a second call to Billingclient() returns the same instance + billingClient2 := client.Billingclient() + test.That(t, billingClient2, test.ShouldNotBeNil) + test.That(t, billingClient, test.ShouldEqual, billingClient2) + dataClient := client.DataClient() test.That(t, dataClient, test.ShouldNotBeNil) test.That(t, dataClient, test.ShouldHaveSameTypeAs, &DataClient{}) - test.That(t, dataClient.client, test.ShouldImplement, (*pb.DataServiceClient)(nil)) + test.That(t, dataClient.client, test.ShouldImplement, (*datapb.DataServiceClient)(nil)) // Testing that a second call to DataClient() returns the same instance dataClient2 := client.DataClient() diff --git a/testutils/inject/billing_service_client.go b/testutils/inject/billing_service_client.go new file mode 100644 index 00000000000..d7982b29749 --- /dev/null +++ b/testutils/inject/billing_service_client.go @@ -0,0 +1,87 @@ +package inject + +import ( + "context" + + billingpb "go.viam.com/api/app/v1" + "google.golang.org/grpc" +) + +// BillingServiceClient represents a fake instance of a billing service client. +type BillingServiceClient struct { + billingpb.BillingServiceClient + GetCurrentMonthUsageFunc func(ctx context.Context, in *billingpb.GetCurrentMonthUsageRequest, + opts ...grpc.CallOption) (*billingpb.GetCurrentMonthUsageResponse, error) + GetOrgBillingInformationFunc func(ctx context.Context, in *billingpb.GetOrgBillingInformationRequest, + opts ...grpc.CallOption) (*billingpb.GetOrgBillingInformationResponse, error) + GetInvoicesSummaryFunc func(ctx context.Context, in *billingpb.GetInvoicesSummaryRequest, + opts ...grpc.CallOption) (*billingpb.GetInvoicesSummaryResponse, error) + GetInvoicePdfFunc func(ctx context.Context, in *billingpb.GetInvoicePdfRequest, + opts ...grpc.CallOption) (billingpb.BillingService_GetInvoicePdfClient, error) + SendPaymentRequiredEmailFunc func(ctx context.Context, in *billingpb.SendPaymentRequiredEmailRequest, + opts ...grpc.CallOption) (*billingpb.SendPaymentRequiredEmailResponse, error) +} + +// GetCurrentMonthUsage calls the injected GetCurrentMonthUsageFunc or the real version. +func (bsc *BillingServiceClient) GetCurrentMonthUsage(ctx context.Context, in *billingpb.GetCurrentMonthUsageRequest, + opts ...grpc.CallOption, +) (*billingpb.GetCurrentMonthUsageResponse, error) { + if bsc.GetCurrentMonthUsageFunc == nil { + return bsc.BillingServiceClient.GetCurrentMonthUsage(ctx, in, opts...) + } + return bsc.GetCurrentMonthUsageFunc(ctx, in, opts...) +} + +// GetOrgBillingInformation calls the injected GetOrgBillingInformationFunc or the real version. +func (bsc *BillingServiceClient) GetOrgBillingInformation(ctx context.Context, in *billingpb.GetOrgBillingInformationRequest, + opts ...grpc.CallOption, +) (*billingpb.GetOrgBillingInformationResponse, error) { + if bsc.GetOrgBillingInformationFunc == nil { + return bsc.BillingServiceClient.GetOrgBillingInformation(ctx, in, opts...) + } + return bsc.GetOrgBillingInformationFunc(ctx, in, opts...) +} + +// GetInvoicesSummary calls the injected GetInvoicesSummaryFunc or the real version. +func (bsc *BillingServiceClient) GetInvoicesSummary(ctx context.Context, in *billingpb.GetInvoicesSummaryRequest, + opts ...grpc.CallOption, +) (*billingpb.GetInvoicesSummaryResponse, error) { + if bsc.GetInvoicesSummaryFunc == nil { + return bsc.BillingServiceClient.GetInvoicesSummary(ctx, in, opts...) + } + return bsc.GetInvoicesSummaryFunc(ctx, in, opts...) +} + +// GetInvoicePdf calls the injected GetInvoicePdfFunc or the real version. +func (bsc *BillingServiceClient) GetInvoicePdf(ctx context.Context, in *billingpb.GetInvoicePdfRequest, + opts ...grpc.CallOption, +) (billingpb.BillingService_GetInvoicePdfClient, error) { + if bsc.GetInvoicePdfFunc == nil { + return bsc.BillingServiceClient.GetInvoicePdf(ctx, in, opts...) + } + return bsc.GetInvoicePdfFunc(ctx, in, opts...) +} + +// BillingServiceGetInvoicePdfClient represents a fake instance of a proto BillingService_GetInvoicePdfClient. +type BillingServiceGetInvoicePdfClient struct { + billingpb.BillingService_GetInvoicePdfClient + RecvFunc func() (*billingpb.GetInvoicePdfResponse, error) +} + +// Recv calls the injected RecvFunc or the real version. +func (c *BillingServiceGetInvoicePdfClient) Recv() (*billingpb.GetInvoicePdfResponse, error) { + if c.RecvFunc == nil { + return c.BillingService_GetInvoicePdfClient.Recv() + } + return c.RecvFunc() +} + +// SendPaymentRequiredEmail calls the injected SendPaymentRequiredEmailFunc or the real version. +func (bsc *BillingServiceClient) SendPaymentRequiredEmail(ctx context.Context, in *billingpb.SendPaymentRequiredEmailRequest, + opts ...grpc.CallOption, +) (*billingpb.SendPaymentRequiredEmailResponse, error) { + if bsc.SendPaymentRequiredEmailFunc == nil { + return bsc.BillingServiceClient.SendPaymentRequiredEmail(ctx, in, opts...) + } + return bsc.SendPaymentRequiredEmailFunc(ctx, in, opts...) +}