Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Commit

Permalink
cmd/tier: add switch and whoami (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
bmizerany authored Nov 16, 2022
1 parent 1771f35 commit 9cd6360
Show file tree
Hide file tree
Showing 17 changed files with 501 additions and 66 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
./tier/cmd/tier/tier
dist/
*.prof
tier.state
29 changes: 29 additions & 0 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,25 @@ func NewHandler(c *control.Client, logf func(string, ...any)) *Handler {
return &Handler{c: c, Logf: logf, helper: func() {}}
}

func isInvalidAccount(err error) bool {
var e *stripe.Error
return errors.As(err, &e) && e.Code == "account_invalid"
}

func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var err error
defer func() {
// h.logf("%s %s %s %s", r.RemoteAddr, r.Method, r.URL, err)
}()
bw := &byteCountResponseWriter{ResponseWriter: w}
err = h.serve(bw, r)
if isInvalidAccount(err) {
trweb.WriteError(w, &trweb.HTTPError{
Status: 401,
Code: "account_invalid",
})
return
}
if trweb.WriteError(w, lookupErr(err)) || trweb.WriteError(w, err) {
return
}
Expand All @@ -92,6 +104,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

func (h *Handler) serve(w http.ResponseWriter, r *http.Request) error {
switch r.URL.Path {
case "/v1/whoami":
return h.serveWhoAmI(w, r)
case "/v1/whois":
return h.serveWhoIs(w, r)
case "/v1/limits":
Expand Down Expand Up @@ -170,6 +184,21 @@ func (h *Handler) serveWhoIs(w http.ResponseWriter, r *http.Request) error {
})
}

func (h *Handler) serveWhoAmI(w http.ResponseWriter, r *http.Request) error {
who, err := h.c.WhoAmI(r.Context())
if err != nil {
return err
}
return httpJSON(w, apitypes.WhoAmIResponse{
ProviderID: who.ProviderID,
Email: who.Email,
Created: who.Created(),
KeySource: who.KeySource,
Isolated: who.Isolated,
URL: who.URL(),
})
}

func (h *Handler) servePhase(w http.ResponseWriter, r *http.Request) error {
org := r.FormValue("org")
ps, err := h.c.LookupPhases(r.Context(), org)
Expand Down
40 changes: 29 additions & 11 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"tier.run/fetch/fetchtest"
"tier.run/refs"
"tier.run/stripe/stroke"
"tier.run/trweb"
)

var (
Expand Down Expand Up @@ -136,7 +135,7 @@ func TestAPISubscribe(t *testing.T) {
diff.Test(t, t.Errorf, got, want, ignore)
}

whoIs("org:test", &trweb.HTTPError{
whoIs("org:test", &apitypes.Error{
Status: 400,
Code: "org_not_found",
Message: "org not found",
Expand All @@ -146,7 +145,7 @@ func TestAPISubscribe(t *testing.T) {

report("org:test", "feature:t", 9, nil)
report("org:test", "feature:t", 1, nil)
report("org:test", "feature:x", 1, &trweb.HTTPError{
report("org:test", "feature:x", 1, &apitypes.Error{
Status: 400,
Code: "invalid_request",
Message: "feature not reportable",
Expand All @@ -165,13 +164,13 @@ func TestAPISubscribe(t *testing.T) {
},
})

report("org:test", "feature:nope", 9, &trweb.HTTPError{
report("org:test", "feature:nope", 9, &apitypes.Error{
Status: 400,
Code: "feature_not_found",
Message: "feature not found",
})

report("org:nope", "feature:t", 9, &trweb.HTTPError{
report("org:nope", "feature:t", 9, &apitypes.Error{
Status: 400,
Code: "org_not_found",
Message: "org not found",
Expand All @@ -182,7 +181,7 @@ func TestAPISubscribe(t *testing.T) {
Plans: mpps("plan:test@0"),
})

sub("org:test", []string{"plan:test@0", "feature:nope@0"}, &trweb.HTTPError{
sub("org:test", []string{"plan:test@0", "feature:nope@0"}, &apitypes.Error{
Status: 400,
Code: "feature_not_found",
Message: "feature not found",
Expand All @@ -195,14 +194,14 @@ func TestPhaseBadOrg(t *testing.T) {
ctx := context.Background()
c, _ := newTestClient(t)

_, err := fetch.OK[struct{}, *trweb.HTTPError](ctx, c, "GET", "/v1/phase?org=org:nope", nil)
diff.Test(t, t.Errorf, err, &trweb.HTTPError{
_, err := fetch.OK[struct{}, *apitypes.Error](ctx, c, "GET", "/v1/phase?org=org:nope", nil)
diff.Test(t, t.Errorf, err, &apitypes.Error{
Status: 404,
Code: "not_found",
Message: "Not Found",
})
_, err = fetch.OK[struct{}, *trweb.HTTPError](ctx, c, "GET", "/v1/phase", nil)
diff.Test(t, t.Errorf, err, &trweb.HTTPError{
_, err = fetch.OK[struct{}, *apitypes.Error](ctx, c, "GET", "/v1/phase", nil)
diff.Test(t, t.Errorf, err, &apitypes.Error{
Status: 400,
Code: "invalid_request",
Message: `org must be prefixed with "org:"`,
Expand Down Expand Up @@ -247,7 +246,7 @@ func TestPhaseFragments(t *testing.T) {
t.Fatal(err)
}

got, err := fetch.OK[apitypes.PhaseResponse, *trweb.HTTPError](ctx, c, "GET", "/v1/phase?org=org:test", nil)
got, err := fetch.OK[apitypes.PhaseResponse, *apitypes.Error](ctx, c, "GET", "/v1/phase?org=org:test", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -309,6 +308,25 @@ func TestTierPull(t *testing.T) {
diff.Test(t, t.Errorf, got, want)
}

func TestWhoAmI(t *testing.T) {
t.Parallel()

ctx := context.Background()
c, _ := newTestClient(t)
tc := &tier.Client{HTTPClient: c}
a, err := tc.WhoAmI(ctx)
if err != nil {
t.Fatal(err)
}

if a.ProviderID == "" {
t.Error("unexpected empty provider id")
}
if a.URL == "" {
t.Error("unexpected empty url")
}
}

func TestTierReport(t *testing.T) {
t.Parallel()

Expand Down
21 changes: 21 additions & 0 deletions api/apitypes/apitypes.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
package apitypes

import (
"fmt"
"time"

"tier.run/refs"
)

type Error struct {
Status int `json:"status"`
Code string `json:"code"` // (e.g. "invalid_request")
Message string `json:"message"`
}

func (e *Error) Error() string {
return fmt.Sprintf("httpError{status:%d code:%q message:%q}",
e.Status, e.Code, e.Message)
}

type Phase struct {
Effective time.Time
Features []string
Expand Down Expand Up @@ -60,3 +72,12 @@ type PushResult struct {
type PushResponse struct {
Results []PushResult `json:"results,omitempty"`
}

type WhoAmIResponse struct {
ProviderID string `json:"id"`
Email string `json:"email"`
Created time.Time `json:"created"`
KeySource string `json:"key_source"`
Isolated bool `json:"isolated"`
URL string `json:"url"`
}
40 changes: 21 additions & 19 deletions client/tier/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"tier.run/api/apitypes"
"tier.run/fetch"
"tier.run/refs"
"tier.run/trweb"
)

const Inf = 1<<63 - 1
Expand All @@ -31,37 +30,37 @@ func (c *Client) client() *http.Client {

// Pull fetches the complete pricing model from Stripe.
func (c *Client) Push(ctx context.Context, m apitypes.Model) (apitypes.PushResponse, error) {
return fetch.OK[apitypes.PushResponse, *trweb.HTTPError](ctx, c.client(), "POST", "/v1/push", m)
return fetch.OK[apitypes.PushResponse, *apitypes.Error](ctx, c.client(), "POST", "/v1/push", m)
}

func (c *Client) PushJSON(ctx context.Context, m []byte) (apitypes.PushResponse, error) {
return fetch.OK[apitypes.PushResponse, *trweb.HTTPError](ctx, c.client(), "POST", "/v1/push", json.RawMessage(m))
return fetch.OK[apitypes.PushResponse, *apitypes.Error](ctx, c.client(), "POST", "/v1/push", json.RawMessage(m))
}

// Pull fetches the complete pricing model from Stripe.
func (c *Client) Pull(ctx context.Context) (apitypes.Model, error) {
return fetch.OK[apitypes.Model, *trweb.HTTPError](ctx, c.client(), "GET", "/v1/pull", nil)
return fetch.OK[apitypes.Model, *apitypes.Error](ctx, c.client(), "GET", "/v1/pull", nil)
}

// PullJSON fetches the complete pricing model from Stripe and returns the raw
// JSON response.
func (c *Client) PullJSON(ctx context.Context) ([]byte, error) {
return fetch.OK[[]byte, *trweb.HTTPError](ctx, c.client(), "GET", "/v1/pull", nil)
return fetch.OK[[]byte, *apitypes.Error](ctx, c.client(), "GET", "/v1/pull", nil)
}

// WhoIS reports the Stripe ID for the given organization.
func (c *Client) WhoIs(ctx context.Context, org string) (apitypes.WhoIsResponse, error) {
return fetch.OK[apitypes.WhoIsResponse, *trweb.HTTPError](ctx, c.client(), "GET", "/v1/whois?org="+org, nil)
return fetch.OK[apitypes.WhoIsResponse, *apitypes.Error](ctx, c.client(), "GET", "/v1/whois?org="+org, nil)
}

// LookupPhase reports information about the current phase the provided org is scheduled in.
func (c *Client) LookupPhase(ctx context.Context, org string) (apitypes.PhaseResponse, error) {
return fetch.OK[apitypes.PhaseResponse, *trweb.HTTPError](ctx, c.client(), "GET", "/v1/phase?org="+org, nil)
return fetch.OK[apitypes.PhaseResponse, *apitypes.Error](ctx, c.client(), "GET", "/v1/phase?org="+org, nil)
}

// LookupLimits reports the current usage and limits for the provided org.
func (c *Client) LookupLimits(ctx context.Context, org string) (apitypes.UsageResponse, error) {
return fetch.OK[apitypes.UsageResponse, *trweb.HTTPError](ctx, c.client(), "GET", "/v1/limits?org="+org, nil)
return fetch.OK[apitypes.UsageResponse, *apitypes.Error](ctx, c.client(), "GET", "/v1/limits?org="+org, nil)
}

// LookupLimit reports the current usage and limits for the provided org and
Expand Down Expand Up @@ -121,17 +120,16 @@ func (c Answer) ReportN(n int) error {
//
// If reporting consumption is not required, it can be used in the form:
//
// if c.Can(ctx, "org:acme", "feature:convert").OK() { ... }
// if c.Can(ctx, "org:acme", "feature:convert").OK() { ... }
//
// reporting usage post consumption looks like:
//
// ans := c.Can(ctx, "org:acme", "feature:convert")
// if !ans.OK() {
// return ""
// }
// defer ans.Report() // or ReportN
// return convert(temp)
//
// ans := c.Can(ctx, "org:acme", "feature:convert")
// if !ans.OK() {
// return ""
// }
// defer ans.Report() // or ReportN
// return convert(temp)
func (c *Client) Can(ctx context.Context, org, feature string) Answer {
limit, used, err := c.LookupLimit(ctx, org, feature)
if err != nil {
Expand All @@ -156,7 +154,7 @@ func (c *Client) Report(ctx context.Context, org, feature string, n int) error {
if err != nil {
return err
}
_, err = fetch.OK[struct{}, *trweb.HTTPError](ctx, c.client(), "POST", "/v1/report", apitypes.ReportRequest{
_, err = fetch.OK[struct{}, *apitypes.Error](ctx, c.client(), "POST", "/v1/report", apitypes.ReportRequest{
Org: org,
Feature: fn,
N: n,
Expand All @@ -167,7 +165,7 @@ func (c *Client) Report(ctx context.Context, org, feature string, n int) error {

// ReportUsage reports usage based on the provided ReportRequest fields.
func (c *Client) ReportUsage(ctx context.Context, r apitypes.ReportRequest) error {
_, err := fetch.OK[struct{}, *trweb.HTTPError](ctx, c.client(), "POST", "/v1/report", r)
_, err := fetch.OK[struct{}, *apitypes.Error](ctx, c.client(), "POST", "/v1/report", r)
return err
}

Expand All @@ -177,9 +175,13 @@ func (c *Client) ReportUsage(ctx context.Context, r apitypes.ReportRequest) erro
// Any in-progress scheduled is overwritten and the customer is billed with
// prorations immediately.
func (c *Client) Subscribe(ctx context.Context, org string, featuresAndPlans ...string) error {
_, err := fetch.OK[struct{}, *trweb.HTTPError](ctx, c.client(), "POST", "/v1/subscribe", apitypes.SubscribeRequest{
_, err := fetch.OK[struct{}, *apitypes.Error](ctx, c.client(), "POST", "/v1/subscribe", apitypes.SubscribeRequest{
Org: org,
Phases: []apitypes.Phase{{Features: featuresAndPlans}},
})
return err
}

func (c *Client) WhoAmI(ctx context.Context) (apitypes.WhoAmIResponse, error) {
return fetch.OK[apitypes.WhoAmIResponse, *apitypes.Error](ctx, c.client(), "GET", "/v1/whoami", nil)
}
12 changes: 5 additions & 7 deletions cmd/tier/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,18 @@ func fetchProfile(ctx context.Context, pollURL string) (*profile.Profile, error)
// getKey returns the API from the environment variable STRIPE_API_KEY, or the
// live mode key in config.json, or the test mode key in config.json, in that
// order. It returns an error if no key is found.
func getKey() (string, error) {
func getKey() (key, source string, err error) {
if envAPIKey != "" {
return envAPIKey, nil
return envAPIKey, "STRIPE_API_KEY", nil
}

p, err := profile.Load("tier")
if err != nil {
return "", err
return "", "", err
}

if *flagLive {
return p.LiveAPIKey, nil
return p.LiveAPIKey, profile.ConfigPath, nil
}
return p.TestAPIKey, nil
return p.TestAPIKey, profile.ConfigPath, nil
}

//lint:ignore U1000 this type is used as a type parameter, but staticcheck seems to not be able to detect that yet. Remove this comment when staticcheck will stop complaining.
Expand Down
Loading

0 comments on commit 9cd6360

Please sign in to comment.