From 6cdac13ef274dae0ebeb052d32c977367c4d20e1 Mon Sep 17 00:00:00 2001 From: Neville Li Date: Mon, 2 Dec 2024 11:19:15 -0500 Subject: [PATCH] Add tests --- go.mod | 3 + internal/server/runner.go | 62 ++++--- internal/tests/async_prediction_test.go | 151 +++++++++++++++++ internal/tests/cog_test.go | 209 ++++++++++++++++++++++++ internal/tests/prediction_test.go | 80 +++++++++ internal/tests/setup_test.go | 56 +++++++ internal/util/util.go | 8 + 7 files changed, 547 insertions(+), 22 deletions(-) create mode 100644 internal/tests/async_prediction_test.go create mode 100644 internal/tests/cog_test.go create mode 100644 internal/tests/prediction_test.go create mode 100644 internal/tests/setup_test.go diff --git a/go.mod b/go.mod index 122ca3a..51d3bc0 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/gabriel-vasile/mimetype v1.4.7 github.com/peterbourgon/ff/v4 v4.0.0-alpha.4 github.com/replicate/go v0.0.0-20241101110715-45e9ae8c2040 + github.com/stretchr/testify v1.9.0 go.uber.org/automaxprocs v1.6.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -15,10 +16,12 @@ require ( github.com/cilium/ebpf v0.9.1 // indirect github.com/containerd/cgroups/v3 v3.0.1 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/docker/go-units v0.4.0 // indirect github.com/godbus/dbus/v5 v5.0.4 // indirect github.com/opencontainers/runtime-spec v1.0.2 // indirect github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sirupsen/logrus v1.8.1 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect diff --git a/internal/server/runner.go b/internal/server/runner.go index cf12063..091d954 100644 --- a/internal/server/runner.go +++ b/internal/server/runner.go @@ -166,6 +166,21 @@ func (r *Runner) wait() { if err != nil { logs := r.rotateLogs() log.Errorw("python runner excited with error", "pid", r.cmd.Process.Pid, "error", err, "logs", logs) + for pid, pr := range r.pending { + resp := PredictionResponse{ + Input: pr.request.Input, + Id: pid, + CreatedAt: pr.request.CreatedAt, + // Runner crashed without writing a response JSON, so we don't actually know + StartedAt: pr.request.CreatedAt, + CompletedAt: util.NowIso(), + Logs: util.JoinLogs(pr.logs), + // Not per prediction logs, but could be useful nonetheless + Error: logs, + Status: "failed", + } + sendResponse(pr, resp) + } if r.status == StatusStarting { r.status = StatusSetupFailed r.setupResult.CompletedAt = util.NowIso() @@ -240,7 +255,7 @@ func (r *Runner) updateSetupResult() { func (r *Runner) handleResponses() { log := logger.Sugar() - completed := make(map[string]bool) + pids := make(map[string]bool) for _, entry := range must.Get(os.ReadDir(r.workingDir)) { m := RESPONSE_REGEX.FindStringSubmatch(entry.Name()) if m == nil { @@ -272,30 +287,22 @@ func (r *Runner) handleResponses() { resp.CreatedAt = pr.request.CreatedAt r.mu.Lock() - resp.Logs = strings.Join(pr.logs, "\n") - if resp.Logs != "" { - resp.Logs += "\n" - } + resp.Logs = util.JoinLogs(pr.logs) r.mu.Unlock() - // FIXME: webhook interval - if pr.request.Webhook != "" { - webhook(pr.request.Webhook, resp) - } - - if _, ok := PredictionCompletedStatuses[resp.Status]; ok { - completed[pid] = true + completed := sendResponse(pr, resp) + if completed { log.Infow("prediction completed", "id", pr.request.Id, "status", resp.Status) - if pr.c != nil { - pr.c <- resp - } } + pids[pid] = completed } r.mu.Lock() defer r.mu.Unlock() - for pid, _ := range completed { - delete(r.pending, pid) - must.Do(os.Remove(path.Join(r.workingDir, fmt.Sprintf(RESPONSE_FMT, pid)))) + for pid, completed := range pids { + if completed { + delete(r.pending, pid) + must.Do(os.Remove(path.Join(r.workingDir, fmt.Sprintf(RESPONSE_FMT, pid)))) + } } } @@ -310,8 +317,22 @@ func (r *Runner) readJson(filename string, v any) error { return json.Unmarshal(bs, v) } +func sendResponse(pr *PendingPrediction, resp PredictionResponse) bool { + _, completed := PredictionCompletedStatuses[resp.Status] + if pr.request.Webhook != "" { + // Async prediction + // FIXME: webhook interval + webhook(pr.request.Webhook, resp) + } else if pr.c != nil && completed { + // Blocking prediction + pr.c <- resp + } + return completed +} + func webhook(url string, response PredictionResponse) { log := logger.Sugar() + log.Infow("sending webhook", "url", url, "response", response) body := bytes.NewBuffer(must.Get(json.Marshal(response))) req := must.Get(http.NewRequest("POST", url, body)) req.Header.Add("Content-Type", "application/json") @@ -346,10 +367,7 @@ func (r *Runner) log(line string) { } func (r *Runner) rotateLogs() string { - logs := strings.Join(r.logs, "\n") - if logs != "" { - logs += "\n" - } + logs := util.JoinLogs(r.logs) r.logs = make([]string, 0) return logs } diff --git a/internal/tests/async_prediction_test.go b/internal/tests/async_prediction_test.go new file mode 100644 index 0000000..2a9e952 --- /dev/null +++ b/internal/tests/async_prediction_test.go @@ -0,0 +1,151 @@ +package tests + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestAsyncPredictionSucceeded(t *testing.T) { + e := NewCogTest(t, "sleep") + assert.NoError(t, e.Start()) + e.StartWebhook() + + hc := e.WaitForSetup() + assert.Equal(t, "READY", hc.Status) + assert.Equal(t, "succeeded", hc.Setup.Status) + + e.AsyncPrediction(map[string]any{"i": 1, "s": "bar"}) + for { + if len(e.WebhookRequests()) == 2 { + break + } + time.Sleep(100 * time.Millisecond) + } + wr := e.WebhookRequests() + for _, r := range wr { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/webhook", r.Path) + } + assert.Equal(t, "starting", wr[0].Response.Status) + assert.Equal(t, nil, wr[0].Response.Output) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\n", wr[0].Response.Logs) + + assert.Equal(t, "succeeded", wr[1].Response.Status) + assert.Equal(t, "*bar*", wr[1].Response.Output) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\ncompleted prediction\n", wr[1].Response.Logs) + + e.Shutdown() + assert.NoError(t, e.Cleanup()) +} + +func TestAsyncPredictionWithIdSucceeded(t *testing.T) { + e := NewCogTest(t, "sleep") + assert.NoError(t, e.Start()) + e.StartWebhook() + + hc := e.WaitForSetup() + assert.Equal(t, "READY", hc.Status) + assert.Equal(t, "succeeded", hc.Setup.Status) + + e.AsyncPredictionWithId("p01", map[string]any{"i": 1, "s": "bar"}) + for { + if len(e.WebhookRequests()) == 2 { + break + } + time.Sleep(100 * time.Millisecond) + } + wr := e.WebhookRequests() + for _, r := range wr { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/webhook", r.Path) + } + + assert.Equal(t, "starting", wr[0].Response.Status) + assert.Equal(t, nil, wr[0].Response.Output) + assert.Equal(t, "p01", wr[0].Response.Id) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\n", wr[0].Response.Logs) + + assert.Equal(t, "succeeded", wr[1].Response.Status) + assert.Equal(t, "*bar*", wr[1].Response.Output) + assert.Equal(t, "p01", wr[1].Response.Id) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\ncompleted prediction\n", wr[1].Response.Logs) + + e.Shutdown() + assert.NoError(t, e.Cleanup()) +} + +func TestAsyncPredictionFailure(t *testing.T) { + e := NewCogTest(t, "sleep") + e.AppendEnvs("PREDICTION_FAILURE=1") + assert.NoError(t, e.Start()) + e.StartWebhook() + + hc := e.WaitForSetup() + assert.Equal(t, "READY", hc.Status) + assert.Equal(t, "succeeded", hc.Setup.Status) + + e.AsyncPrediction(map[string]any{"i": 1, "s": "bar"}) + for { + if len(e.WebhookRequests()) == 2 { + break + } + time.Sleep(100 * time.Millisecond) + } + wr := e.WebhookRequests() + for _, r := range wr { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/webhook", r.Path) + } + + assert.Equal(t, "starting", wr[0].Response.Status) + assert.Equal(t, nil, wr[0].Response.Output) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\n", wr[0].Response.Logs) + + assert.Equal(t, "failed", wr[1].Response.Status) + assert.Equal(t, nil, wr[1].Response.Output) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\nprediction failed\n", wr[1].Response.Logs) + + e.Shutdown() + assert.NoError(t, e.Cleanup()) +} + +func TestAsyncPredictionCrash(t *testing.T) { + e := NewCogTest(t, "sleep") + e.AppendArgs("--await-explicit-shutdown") + e.AppendEnvs("PREDICTION_CRASH=1") + assert.NoError(t, e.Start()) + e.StartWebhook() + + hc := e.WaitForSetup() + assert.Equal(t, "READY", hc.Status) + assert.Equal(t, "succeeded", hc.Setup.Status) + + e.AsyncPrediction(map[string]any{"i": 1, "s": "bar"}) + for { + if len(e.WebhookRequests()) == 2 { + break + } + time.Sleep(100 * time.Millisecond) + } + wr := e.WebhookRequests() + for _, r := range wr { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "/webhook", r.Path) + } + + assert.Equal(t, "starting", wr[0].Response.Status) + assert.Equal(t, nil, wr[0].Response.Output) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\n", wr[0].Response.Logs) + + assert.Equal(t, "failed", wr[1].Response.Status) + assert.Equal(t, nil, wr[1].Response.Output) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\nprediction crashed\n", wr[1].Response.Logs) + + assert.Equal(t, "DEFUNCT", e.HealthCheck().Status) + + e.Shutdown() + assert.NoError(t, e.Cleanup()) +} diff --git a/internal/tests/cog_test.go b/internal/tests/cog_test.go new file mode 100644 index 0000000..6e12ecb --- /dev/null +++ b/internal/tests/cog_test.go @@ -0,0 +1,209 @@ +package tests + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "os/exec" + "path" + "runtime" + "testing" + "time" + + "github.com/replicate/go/logging" + + "github.com/replicate/go/must" + "github.com/stretchr/testify/assert" + + "github.com/replicate/cog-runtime/internal/server" +) + +var ( + _, b, _, _ = runtime.Caller(0) + basePath = path.Dir(path.Dir(path.Dir(b))) + logger = logging.New("cog-test") +) + +type WebhookRequest struct { + Method string + Path string + Response server.PredictionResponse +} + +type WebhookHandler struct { + webhookRequests []WebhookRequest +} + +func (h *WebhookHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var resp server.PredictionResponse + must.Do(json.Unmarshal(must.Get(io.ReadAll(r.Body)), &resp)) + req := WebhookRequest{ + Method: r.Method, + Path: r.URL.Path, + Response: resp, + } + log := logger.Sugar() + log.Infow("webhook", "request", req) + h.webhookRequests = append(h.webhookRequests, req) +} + +var _ = (http.Handler)((*WebhookHandler)(nil)) + +type CogTest struct { + t *testing.T + module string + extraArgs []string + extraEnvs []string + serverPort int + webhookPort int + cmd *exec.Cmd + webhookServer *http.Server +} + +func NewCogTest(t *testing.T, module string) *CogTest { + t.Parallel() + return &CogTest{ + t: t, + module: module, + } +} + +func (e *CogTest) AppendArgs(args ...string) { + e.extraArgs = append(e.extraArgs, args...) +} + +func (e *CogTest) AppendEnvs(envs ...string) { + e.extraEnvs = append(e.extraEnvs, envs...) +} + +func (e *CogTest) StartWebhook() { + e.webhookPort = getFreePort() + e.webhookServer = &http.Server{ + Addr: fmt.Sprintf(":%d", e.webhookPort), + Handler: &WebhookHandler{}, + } + go func() { + err := e.webhookServer.ListenAndServe() + if err != nil && !errors.Is(err, http.ErrServerClosed) { + panic(err) + } + }() +} + +func (e *CogTest) Start() error { + pathEnv := path.Join(basePath, "python", ".venv", "bin") + pythonPathEnv := path.Join(basePath, "python") + e.serverPort = getFreePort() + args := []string{ + "run", path.Join(basePath, "cmd", "cog-server", "main.go"), + "--module-name", fmt.Sprintf("tests.runners.%s", e.module), + "--class-name", "Predictor", + "--port", fmt.Sprintf("%d", e.serverPort), + } + args = append(args, e.extraArgs...) + e.cmd = exec.Command("go", args...) + e.cmd.Env = os.Environ() + e.cmd.Env = append(e.cmd.Env, + fmt.Sprintf("PATH=%s:%s", pathEnv, os.Getenv("PATH")), + fmt.Sprintf("PYTHONPATH=%s", pythonPathEnv), + ) + e.cmd.Env = append(e.cmd.Env, e.extraEnvs...) + e.cmd.Stdout = os.Stdout + e.cmd.Stderr = os.Stderr + return e.cmd.Start() +} + +func (e *CogTest) Cleanup() error { + if e.webhookServer != nil { + must.Do(e.webhookServer.Shutdown(context.Background())) + } + return e.cmd.Wait() +} + +func (e *CogTest) WebhookRequests() []WebhookRequest { + return e.webhookServer.Handler.(*WebhookHandler).webhookRequests +} + +func (e *CogTest) Url(path string) string { + return fmt.Sprintf("http://localhost:%d%s", e.serverPort, path) +} + +func (e *CogTest) HealthCheck() server.HealthCheck { + url := fmt.Sprintf("http://localhost:%d/health-check", e.serverPort) + for { + resp, err := http.DefaultClient.Get(url) + if err == nil { + var hc server.HealthCheck + must.Do(json.Unmarshal(must.Get(io.ReadAll(resp.Body)), &hc)) + return hc + } + time.Sleep(100 * time.Millisecond) + } +} + +func (e *CogTest) WaitForSetup() server.HealthCheck { + for { + hc := e.HealthCheck() + if hc.Status != "STARTING" { + return hc + } + time.Sleep(100 * time.Millisecond) + } +} + +func (e *CogTest) Prediction(input map[string]any) server.PredictionResponse { + return e.prediction(http.MethodPost, e.Url("/predictions"), input) +} + +func (e *CogTest) PredictionWithId(pid string, input map[string]any) server.PredictionResponse { + return e.prediction(http.MethodPut, e.Url(fmt.Sprintf("/predictions/%s", pid)), input) +} + +func (e *CogTest) prediction(method string, url string, input map[string]any) server.PredictionResponse { + req := server.PredictionRequest{Input: input} + data := bytes.NewReader(must.Get(json.Marshal(req))) + r := must.Get(http.NewRequest(method, url, data)) + r.Header.Set("Content-Type", "application/json") + resp := must.Get(http.DefaultClient.Do(r)) + assert.Equal(e.t, http.StatusOK, resp.StatusCode) + var pr server.PredictionResponse + must.Do(json.Unmarshal(must.Get(io.ReadAll(resp.Body)), &pr)) + return pr +} + +func (e *CogTest) AsyncPrediction(input map[string]any) { + e.asyncPrediction(http.MethodPost, e.Url("/predictions"), input) +} + +func (e *CogTest) AsyncPredictionWithId(pid string, input map[string]any) { + e.asyncPrediction(http.MethodPut, e.Url(fmt.Sprintf("/predictions/%s", pid)), input) +} + +func (e *CogTest) asyncPrediction(method string, url string, input map[string]any) { + req := server.PredictionRequest{Input: input, Webhook: fmt.Sprintf("http://localhost:%d/webhook", e.webhookPort)} + data := bytes.NewReader(must.Get(json.Marshal(req))) + r := must.Get(http.NewRequest(method, url, data)) + r.Header.Set("Content-Type", "application/json") + r.Header.Set("Prefer", "respond-async") + resp := must.Get(http.DefaultClient.Do(r)) + assert.Equal(e.t, http.StatusOK, resp.StatusCode) +} + +func (e *CogTest) Shutdown() { + url := fmt.Sprintf("http://localhost:%d/shutdown", e.serverPort) + resp := must.Get(http.DefaultClient.Post(url, "", nil)) + assert.Equal(e.t, http.StatusOK, resp.StatusCode) +} + +func getFreePort() int { + a := must.Get(net.ResolveTCPAddr("tcp", "localhost:0")) + l := must.Get(net.ListenTCP("tcp", a)) + defer l.Close() + return l.Addr().(*net.TCPAddr).Port +} diff --git a/internal/tests/prediction_test.go b/internal/tests/prediction_test.go new file mode 100644 index 0000000..0cdba5b --- /dev/null +++ b/internal/tests/prediction_test.go @@ -0,0 +1,80 @@ +package tests + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPredictionSucceeded(t *testing.T) { + e := NewCogTest(t, "sleep") + assert.NoError(t, e.Start()) + + hc := e.WaitForSetup() + assert.Equal(t, "READY", hc.Status) + assert.Equal(t, "succeeded", hc.Setup.Status) + + resp := e.Prediction(map[string]any{"i": 1, "s": "bar"}) + assert.Equal(t, "succeeded", resp.Status) + assert.Equal(t, "*bar*", resp.Output) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\ncompleted prediction\n", resp.Logs) + + e.Shutdown() + assert.NoError(t, e.Cleanup()) +} + +func TestPredictionWithIdSucceeded(t *testing.T) { + e := NewCogTest(t, "sleep") + assert.NoError(t, e.Start()) + + hc := e.WaitForSetup() + assert.Equal(t, "READY", hc.Status) + assert.Equal(t, "succeeded", hc.Setup.Status) + + resp := e.PredictionWithId("p01", map[string]any{"i": 1, "s": "bar"}) + assert.Equal(t, "succeeded", resp.Status) + assert.Equal(t, "*bar*", resp.Output) + assert.Equal(t, "p01", resp.Id) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\ncompleted prediction\n", resp.Logs) + + e.Shutdown() + assert.NoError(t, e.Cleanup()) +} + +func TestPredictionFailure(t *testing.T) { + e := NewCogTest(t, "sleep") + e.AppendEnvs("PREDICTION_FAILURE=1") + assert.NoError(t, e.Start()) + + hc := e.WaitForSetup() + assert.Equal(t, "READY", hc.Status) + assert.Equal(t, "succeeded", hc.Setup.Status) + + resp := e.Prediction(map[string]any{"i": 1, "s": "bar"}) + assert.Equal(t, "failed", resp.Status) + assert.Equal(t, nil, resp.Output) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\nprediction failed\n", resp.Logs) + + e.Shutdown() + assert.NoError(t, e.Cleanup()) +} + +func TestPredictionCrash(t *testing.T) { + e := NewCogTest(t, "sleep") + e.AppendArgs("--await-explicit-shutdown") + e.AppendEnvs("PREDICTION_CRASH=1") + assert.NoError(t, e.Start()) + + hc := e.WaitForSetup() + assert.Equal(t, "READY", hc.Status) + assert.Equal(t, "succeeded", hc.Setup.Status) + + resp := e.Prediction(map[string]any{"i": 1, "s": "bar"}) + assert.Equal(t, "failed", resp.Status) + assert.Equal(t, nil, resp.Output) + assert.Equal(t, "starting prediction\nprediction in progress 1/1\nprediction crashed\n", resp.Logs) + assert.Equal(t, "DEFUNCT", e.HealthCheck().Status) + + e.Shutdown() + assert.NoError(t, e.Cleanup()) +} diff --git a/internal/tests/setup_test.go b/internal/tests/setup_test.go new file mode 100644 index 0000000..df136dd --- /dev/null +++ b/internal/tests/setup_test.go @@ -0,0 +1,56 @@ +package tests + +import ( + "net/http" + "testing" + + "github.com/replicate/go/must" + + "github.com/stretchr/testify/assert" +) + +func TestSetupSucceeded(t *testing.T) { + e := NewCogTest(t, "sleep") + e.AppendEnvs("SETUP_SLEEP=1") + assert.NoError(t, e.Start()) + assert.Equal(t, "STARTING", e.HealthCheck().Status) + + hc := e.WaitForSetup() + assert.Equal(t, "READY", hc.Status) + assert.Equal(t, "succeeded", hc.Setup.Status) + assert.Equal(t, "starting setup\nsetup in progress 1/1\ncompleted setup\n", hc.Setup.Logs) + assert.Equal(t, http.StatusOK, must.Get(http.DefaultClient.Get(e.Url("/openapi.json"))).StatusCode) + + e.Shutdown() + assert.NoError(t, e.Cleanup()) +} + +func TestSetupFailure(t *testing.T) { + e := NewCogTest(t, "sleep") + e.AppendArgs("--await-explicit-shutdown") + e.AppendEnvs("SETUP_FAILURE=1") + assert.NoError(t, e.Start()) + + hc := e.WaitForSetup() + assert.Equal(t, "SETUP_FAILED", hc.Status) + assert.Equal(t, "failed", hc.Setup.Status) + assert.Equal(t, "starting setup\nsetup failed\n", hc.Setup.Logs) + + e.Shutdown() + assert.NoError(t, e.Cleanup()) +} + +func TestSetupCrash(t *testing.T) { + e := NewCogTest(t, "sleep") + e.AppendArgs("--await-explicit-shutdown") + e.AppendEnvs("SETUP_CRASH=1") + assert.NoError(t, e.Start()) + + hc := e.WaitForSetup() + assert.Equal(t, "SETUP_FAILED", hc.Status) + assert.Equal(t, "failed", hc.Setup.Status) + assert.Equal(t, "starting setup\nsetup crashed\n", hc.Setup.Logs) + + e.Shutdown() + assert.NoError(t, e.Cleanup()) +} diff --git a/internal/util/util.go b/internal/util/util.go index 94277c3..5fccfcc 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -50,3 +50,11 @@ func NowIso() string { // Python: datetime.now(tz=timezone.utc).isoformat() return time.Now().UTC().Format("2006-01-02T15:04:05.999999-07:00") } + +func JoinLogs(logs []string) string { + r := strings.Join(logs, "\n") + if r != "" { + r += "\n" + } + return r +}