Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nevillelyh committed Dec 2, 2024
1 parent cdf9985 commit da71619
Show file tree
Hide file tree
Showing 7 changed files with 547 additions and 22 deletions.
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/gabriel-vasile/mimetype v1.4.7
github.com/peterbourgon/ff/v4 v4.0.0-alpha.4
github.com/replicate/go v0.0.0-20241101110715-45e9ae8c2040
github.com/stretchr/testify v1.9.0
go.uber.org/automaxprocs v1.6.0
gopkg.in/yaml.v3 v3.0.1
)
Expand All @@ -15,10 +16,12 @@ require (
github.com/cilium/ebpf v0.9.1 // indirect
github.com/containerd/cgroups/v3 v3.0.1 // indirect
github.com/coreos/go-systemd/v22 v22.3.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/docker/go-units v0.4.0 // indirect
github.com/godbus/dbus/v5 v5.0.4 // indirect
github.com/opencontainers/runtime-spec v1.0.2 // indirect
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/sirupsen/logrus v1.8.1 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
Expand Down
62 changes: 40 additions & 22 deletions internal/server/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,21 @@ func (r *Runner) wait() {
if err != nil {
logs := r.rotateLogs()
log.Errorw("python runner excited with error", "pid", r.cmd.Process.Pid, "error", err, "logs", logs)
for pid, pr := range r.pending {
resp := PredictionResponse{
Input: pr.request.Input,
Id: pid,
CreatedAt: pr.request.CreatedAt,
// Runner crashed without writing a response JSON, so we don't actually know
StartedAt: pr.request.CreatedAt,
CompletedAt: util.NowIso(),
Logs: util.JoinLogs(pr.logs),
// Not per prediction logs, but could be useful nonetheless
Error: logs,
Status: "failed",
}
sendResponse(pr, resp)
}
if r.status == StatusStarting {
r.status = StatusSetupFailed
r.setupResult.CompletedAt = util.NowIso()
Expand Down Expand Up @@ -240,7 +255,7 @@ func (r *Runner) updateSetupResult() {

func (r *Runner) handleResponses() {
log := logger.Sugar()
completed := make(map[string]bool)
pids := make(map[string]bool)
for _, entry := range must.Get(os.ReadDir(r.workingDir)) {
m := RESPONSE_REGEX.FindStringSubmatch(entry.Name())
if m == nil {
Expand Down Expand Up @@ -272,30 +287,22 @@ func (r *Runner) handleResponses() {
resp.CreatedAt = pr.request.CreatedAt

r.mu.Lock()
resp.Logs = strings.Join(pr.logs, "\n")
if resp.Logs != "" {
resp.Logs += "\n"
}
resp.Logs = util.JoinLogs(pr.logs)
r.mu.Unlock()

// FIXME: webhook interval
if pr.request.Webhook != "" {
webhook(pr.request.Webhook, resp)
}

if _, ok := PredictionCompletedStatuses[resp.Status]; ok {
completed[pid] = true
completed := sendResponse(pr, resp)
if completed {
log.Infow("prediction completed", "id", pr.request.Id, "status", resp.Status)
if pr.c != nil {
pr.c <- resp
}
}
pids[pid] = completed
}
r.mu.Lock()
defer r.mu.Unlock()
for pid, _ := range completed {
delete(r.pending, pid)
must.Do(os.Remove(path.Join(r.workingDir, fmt.Sprintf(RESPONSE_FMT, pid))))
for pid, completed := range pids {
if completed {
delete(r.pending, pid)
must.Do(os.Remove(path.Join(r.workingDir, fmt.Sprintf(RESPONSE_FMT, pid))))
}
}
}

Expand All @@ -310,8 +317,22 @@ func (r *Runner) readJson(filename string, v any) error {
return json.Unmarshal(bs, v)
}

func sendResponse(pr *PendingPrediction, resp PredictionResponse) bool {
_, completed := PredictionCompletedStatuses[resp.Status]
if pr.request.Webhook != "" {
// Async prediction
// FIXME: webhook interval
webhook(pr.request.Webhook, resp)
} else if pr.c != nil && completed {
// Blocking prediction
pr.c <- resp
}
return completed
}

func webhook(url string, response PredictionResponse) {
log := logger.Sugar()
log.Infow("sending webhook", "url", url, "response", response)
body := bytes.NewBuffer(must.Get(json.Marshal(response)))
req := must.Get(http.NewRequest("POST", url, body))
req.Header.Add("Content-Type", "application/json")
Expand Down Expand Up @@ -346,10 +367,7 @@ func (r *Runner) log(line string) {
}

func (r *Runner) rotateLogs() string {
logs := strings.Join(r.logs, "\n")
if logs != "" {
logs += "\n"
}
logs := util.JoinLogs(r.logs)
r.logs = make([]string, 0)
return logs
}
Expand Down
151 changes: 151 additions & 0 deletions internal/tests/async_prediction_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package tests

import (
"net/http"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestAsyncPredictionSucceeded(t *testing.T) {
e := NewCogTest(t, "sleep")
assert.NoError(t, e.Start())
e.StartWebhook()

hc := e.WaitForSetup()
assert.Equal(t, "READY", hc.Status)
assert.Equal(t, "succeeded", hc.Setup.Status)

e.AsyncPrediction(map[string]any{"i": 1, "s": "bar"})
for {
if len(e.WebhookRequests()) == 2 {
break
}
time.Sleep(100 * time.Millisecond)
}
wr := e.WebhookRequests()
for _, r := range wr {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/webhook", r.Path)
}
assert.Equal(t, "starting", wr[0].Response.Status)
assert.Equal(t, nil, wr[0].Response.Output)
assert.Equal(t, "starting prediction\nprediction in progress 1/1\n", wr[0].Response.Logs)

assert.Equal(t, "succeeded", wr[1].Response.Status)
assert.Equal(t, "*bar*", wr[1].Response.Output)
assert.Equal(t, "starting prediction\nprediction in progress 1/1\ncompleted prediction\n", wr[1].Response.Logs)

e.Shutdown()
assert.NoError(t, e.Cleanup())
}

func TestAsyncPredictionWithIdSucceeded(t *testing.T) {
e := NewCogTest(t, "sleep")
assert.NoError(t, e.Start())
e.StartWebhook()

hc := e.WaitForSetup()
assert.Equal(t, "READY", hc.Status)
assert.Equal(t, "succeeded", hc.Setup.Status)

e.AsyncPredictionWithId("p01", map[string]any{"i": 1, "s": "bar"})
for {
if len(e.WebhookRequests()) == 2 {
break
}
time.Sleep(100 * time.Millisecond)
}
wr := e.WebhookRequests()
for _, r := range wr {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/webhook", r.Path)
}

assert.Equal(t, "starting", wr[0].Response.Status)
assert.Equal(t, nil, wr[0].Response.Output)
assert.Equal(t, "p01", wr[0].Response.Id)
assert.Equal(t, "starting prediction\nprediction in progress 1/1\n", wr[0].Response.Logs)

assert.Equal(t, "succeeded", wr[1].Response.Status)
assert.Equal(t, "*bar*", wr[1].Response.Output)
assert.Equal(t, "p01", wr[1].Response.Id)
assert.Equal(t, "starting prediction\nprediction in progress 1/1\ncompleted prediction\n", wr[1].Response.Logs)

e.Shutdown()
assert.NoError(t, e.Cleanup())
}

func TestAsyncPredictionFailure(t *testing.T) {
e := NewCogTest(t, "sleep")
e.AppendEnvs("PREDICTION_FAILURE=1")
assert.NoError(t, e.Start())
e.StartWebhook()

hc := e.WaitForSetup()
assert.Equal(t, "READY", hc.Status)
assert.Equal(t, "succeeded", hc.Setup.Status)

e.AsyncPrediction(map[string]any{"i": 1, "s": "bar"})
for {
if len(e.WebhookRequests()) == 2 {
break
}
time.Sleep(100 * time.Millisecond)
}
wr := e.WebhookRequests()
for _, r := range wr {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/webhook", r.Path)
}

assert.Equal(t, "starting", wr[0].Response.Status)
assert.Equal(t, nil, wr[0].Response.Output)
assert.Equal(t, "starting prediction\nprediction in progress 1/1\n", wr[0].Response.Logs)

assert.Equal(t, "failed", wr[1].Response.Status)
assert.Equal(t, nil, wr[1].Response.Output)
assert.Equal(t, "starting prediction\nprediction in progress 1/1\nprediction failed\n", wr[1].Response.Logs)

e.Shutdown()
assert.NoError(t, e.Cleanup())
}

func TestAsyncPredictionCrash(t *testing.T) {
e := NewCogTest(t, "sleep")
e.AppendArgs("--await-explicit-shutdown")
e.AppendEnvs("PREDICTION_CRASH=1")
assert.NoError(t, e.Start())
e.StartWebhook()

hc := e.WaitForSetup()
assert.Equal(t, "READY", hc.Status)
assert.Equal(t, "succeeded", hc.Setup.Status)

e.AsyncPrediction(map[string]any{"i": 1, "s": "bar"})
for {
if len(e.WebhookRequests()) == 2 {
break
}
time.Sleep(100 * time.Millisecond)
}
wr := e.WebhookRequests()
for _, r := range wr {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/webhook", r.Path)
}

assert.Equal(t, "starting", wr[0].Response.Status)
assert.Equal(t, nil, wr[0].Response.Output)
assert.Equal(t, "starting prediction\nprediction in progress 1/1\n", wr[0].Response.Logs)

assert.Equal(t, "failed", wr[1].Response.Status)
assert.Equal(t, nil, wr[1].Response.Output)
assert.Equal(t, "starting prediction\nprediction in progress 1/1\nprediction crashed\n", wr[1].Response.Logs)

assert.Equal(t, "DEFUNCT", e.HealthCheck().Status)

e.Shutdown()
assert.NoError(t, e.Cleanup())
}
Loading

0 comments on commit da71619

Please sign in to comment.