diff --git a/README.md b/README.md index 50971ae..2b8a2fb 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ Made in 🇩🇰 by [maragu](https://www.maragu.dk/), maker of [online Go course - Messages are sent to and received from the queue, and are guaranteed to not be redelivered before a timeout occurs. - Support for multiple queues in one table. - Message timeouts can be extended, to support e.g. long-running tasks. +- A job runner abstraction is provided on top of the queue, for your background tasks. - A simple HTTP handler is provided for your convenience. - No non-test dependencies. Bring your own SQLite driver. @@ -45,6 +46,8 @@ func main() { if err != nil { log.Fatalln(err) } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) if err := goqite.Setup(context.Background(), db); err != nil { log.Fatalln(err) diff --git a/docs/examples/jobs/main.go b/docs/examples/jobs/main.go new file mode 100644 index 0000000..537d75a --- /dev/null +++ b/docs/examples/jobs/main.go @@ -0,0 +1,62 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "log/slog" + "time" + + _ "github.com/mattn/go-sqlite3" + + "github.com/maragudk/goqite" + "github.com/maragudk/goqite/jobs" +) + +func main() { + log := slog.Default() + + // Setup the db and goqite schema. + db, err := sql.Open("sqlite3", ":memory:?_journal=WAL&_timeout=5000&_fk=true") + if err != nil { + log.Info("Error opening db", "error", err) + } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + + if err := goqite.Setup(context.Background(), db); err != nil { + log.Info("Error in setup", "error", err) + } + + // Make a new queue for the jobs. You can have as many of these as you like, just name them differently. + q := goqite.New(goqite.NewOpts{ + DB: db, + Name: "jobs", + }) + + // Make a job runner with a job limit of 1 and a short message poll interval. + r := jobs.NewRunner(jobs.NewRunnerOpts{ + Limit: 1, + Log: slog.Default(), + PollInterval: 10 * time.Millisecond, + Queue: q, + }) + + // Register our "print" job. + r.Register("print", func(ctx context.Context, m []byte) error { + fmt.Println(string(m)) + return nil + }) + + // Create a "print" job with a message. + if err := jobs.Create(context.Background(), q, "print", []byte("Yo")); err != nil { + log.Info("Error creating job", "error", err) + } + + // Stop the job runner after a timeout. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel() + + // Start the job runner and see the job run. + r.Start(ctx) +} diff --git a/docs/example.go b/docs/examples/queue/main.go similarity index 97% rename from docs/example.go rename to docs/examples/queue/main.go index 4a017b8..85137d8 100644 --- a/docs/example.go +++ b/docs/examples/queue/main.go @@ -21,6 +21,8 @@ func main() { if err != nil { log.Fatalln(err) } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) if err := goqite.Setup(context.Background(), db); err != nil { log.Fatalln(err) diff --git a/docs/index.html b/docs/index.html index e0f7f85..30bf403 100644 --- a/docs/index.html +++ b/docs/index.html @@ -24,7 +24,9 @@

goqite

$ go get github.com/maragudk/goqite

See goqite on Github

-

Example

+

Examples

+ +

Queue

package main
 
@@ -49,6 +51,8 @@ 

Example

if err != nil { log.Fatalln(err) } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) if err := goqite.Setup(context.Background(), db); err != nil { log.Fatalln(err) @@ -92,11 +96,77 @@

Example

log.Fatalln(err) } } +
+ +

Jobs

+ +
package main
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"log/slog"
+	"time"
+
+	_ "github.com/mattn/go-sqlite3"
+
+	"github.com/maragudk/goqite"
+	"github.com/maragudk/goqite/jobs"
+)
+
+func main() {
+	log := slog.Default()
+
+	// Setup the db and goqite schema.
+	db, err := sql.Open("sqlite3", ":memory:?_journal=WAL&_timeout=5000&_fk=true")
+	if err != nil {
+		log.Info("Error opening db", "error", err)
+	}
+	db.SetMaxOpenConns(1)
+	db.SetMaxIdleConns(1)
+
+	if err := goqite.Setup(context.Background(), db); err != nil {
+		log.Info("Error in setup", "error", err)
+	}
+
+	// Make a new queue for the jobs. You can have as many of these as you like, just name them differently.
+	q := goqite.New(goqite.NewOpts{
+		DB:   db,
+		Name: "jobs",
+	})
+
+	// Make a job runner with a job limit of 1 and a short message poll interval.
+	r := jobs.NewRunner(jobs.NewRunnerOpts{
+		Limit:        1,
+		Log:          slog.Default(),
+		PollInterval: 10 * time.Millisecond,
+		Queue:        q,
+	})
+
+	// Register our "print" job.
+	r.Register("print", func(ctx context.Context, m []byte) error {
+		fmt.Println(string(m))
+		return nil
+	})
+
+	// Create a "print" job with a message.
+	if err := jobs.Create(context.Background(), q, "print", []byte("Yo")); err != nil {
+		log.Info("Error creating job", "error", err)
+	}
+
+	// Stop the job runner after a timeout.
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
+	defer cancel()
+
+	// Start the job runner and see the job run.
+	r.Start(ctx)
+}
 
- + diff --git a/goqite.go b/goqite.go index 05714e8..27fb730 100644 --- a/goqite.go +++ b/goqite.go @@ -7,8 +7,9 @@ import ( "database/sql" _ "embed" "errors" - "fmt" "time" + + internalsql "github.com/maragudk/goqite/internal/sql" ) //go:embed schema.sql @@ -18,13 +19,8 @@ var schema string // zeros removed. const rfc3339Milli = "2006-01-02T15:04:05.000Z07:00" -type logger interface { - Println(v ...any) -} - type NewOpts struct { DB *sql.DB - Log logger MaxReceive int // Max receive count for messages before they cannot be received anymore. Name string Timeout time.Duration // Default timeout for messages before they can be re-received. @@ -44,10 +40,6 @@ func New(opts NewOpts) *Queue { panic("name cannot be empty") } - if opts.Log == nil { - opts.Log = &discardLogger{} - } - if opts.MaxReceive < 0 { panic("max receive cannot be negative") } @@ -67,7 +59,6 @@ func New(opts NewOpts) *Queue { return &Queue{ db: opts.DB, name: opts.Name, - log: opts.Log, maxReceive: opts.MaxReceive, timeout: opts.Timeout, } @@ -75,7 +66,6 @@ func New(opts NewOpts) *Queue { type Queue struct { db *sql.DB - log logger maxReceive int name string timeout time.Duration @@ -91,7 +81,7 @@ type Message struct { // Send a Message to the queue with an optional delay. func (q *Queue) Send(ctx context.Context, m Message) error { - return q.inTx(func(tx *sql.Tx) error { + return internalsql.InTx(q.db, func(tx *sql.Tx) error { return q.SendTx(ctx, tx, m) }) } @@ -114,7 +104,7 @@ func (q *Queue) SendTx(ctx context.Context, tx *sql.Tx, m Message) error { // Receive a Message from the queue, or nil if there is none. func (q *Queue) Receive(ctx context.Context) (*Message, error) { var m *Message - err := q.inTx(func(tx *sql.Tx) error { + err := internalsql.InTx(q.db, func(tx *sql.Tx) error { var err error m, err = q.ReceiveTx(ctx, tx) return err @@ -154,8 +144,8 @@ func (q *Queue) ReceiveTx(ctx context.Context, tx *sql.Tx) (*Message, error) { return &m, nil } -// ReceiveAndWait for a Message from the queue or the context is cancelled. -// If the context is cancelled, the error will be non-nil. See context.Context.Err. +// ReceiveAndWait for a Message from the queue, polling at the given interval, until the context is cancelled. +// If the context is cancelled, the error will be non-nil. See [context.Context.Err]. func (q *Queue) ReceiveAndWait(ctx context.Context, interval time.Duration) (*Message, error) { ticker := time.NewTicker(interval) defer ticker.Stop() @@ -178,7 +168,7 @@ func (q *Queue) ReceiveAndWait(ctx context.Context, interval time.Duration) (*Me // Extend a Message timeout by the given delay from now. func (q *Queue) Extend(ctx context.Context, id ID, delay time.Duration) error { - return q.inTx(func(tx *sql.Tx) error { + return internalsql.InTx(q.db, func(tx *sql.Tx) error { return q.ExtendTx(ctx, tx, id, delay) }) } @@ -197,7 +187,7 @@ func (q *Queue) ExtendTx(ctx context.Context, tx *sql.Tx, id ID, delay time.Dura // Delete a Message from the queue by id. func (q *Queue) Delete(ctx context.Context, id ID) error { - return q.inTx(func(tx *sql.Tx) error { + return internalsql.InTx(q.db, func(tx *sql.Tx) error { return q.DeleteTx(ctx, tx, id) }) } @@ -208,41 +198,6 @@ func (q *Queue) DeleteTx(ctx context.Context, tx *sql.Tx, id ID) error { return err } -func (q *Queue) inTx(cb func(*sql.Tx) error) (err error) { - tx, txErr := q.db.Begin() - if txErr != nil { - return fmt.Errorf("cannot start tx: %w", txErr) - } - - defer func() { - if rec := recover(); rec != nil { - err = rollback(tx, nil) - panic(rec) - } - }() - - if err := cb(tx); err != nil { - return rollback(tx, err) - } - - if txErr := tx.Commit(); txErr != nil { - return fmt.Errorf("cannot commit tx: %w", txErr) - } - - return nil -} - -func rollback(tx *sql.Tx, err error) error { - if txErr := tx.Rollback(); txErr != nil { - return fmt.Errorf("cannot roll back tx after error (tx error: %v), original error: %w", txErr, err) - } - return err -} - -type discardLogger struct{} - -func (l *discardLogger) Println(v ...any) {} - // Setup the queue in the database. func Setup(ctx context.Context, db *sql.DB) error { _, err := db.ExecContext(ctx, schema) diff --git a/internal/sql/tx.go b/internal/sql/tx.go new file mode 100644 index 0000000..216a92a --- /dev/null +++ b/internal/sql/tx.go @@ -0,0 +1,37 @@ +package sql + +import ( + "database/sql" + "fmt" +) + +func InTx(db *sql.DB, cb func(*sql.Tx) error) (err error) { + tx, txErr := db.Begin() + if txErr != nil { + return fmt.Errorf("cannot start tx: %w", txErr) + } + + defer func() { + if rec := recover(); rec != nil { + err = rollback(tx, nil) + panic(rec) + } + }() + + if err := cb(tx); err != nil { + return rollback(tx, err) + } + + if txErr := tx.Commit(); txErr != nil { + return fmt.Errorf("cannot commit tx: %w", txErr) + } + + return nil +} + +func rollback(tx *sql.Tx, err error) error { + if txErr := tx.Rollback(); txErr != nil { + return fmt.Errorf("cannot roll back tx after error (tx error: %v), original error: %w", txErr, err) + } + return err +} diff --git a/internal/testing/schema.sql b/internal/testing/schema.sql new file mode 100644 index 0000000..0e47ffb --- /dev/null +++ b/internal/testing/schema.sql @@ -0,0 +1,15 @@ +create table goqite ( + id text primary key default ('m_' || lower(hex(randomblob(16)))), + created text not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + updated text not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + queue text not null, + body blob not null, + timeout text not null default (strftime('%Y-%m-%dT%H:%M:%fZ')), + received integer not null default 0 +) strict; + +create trigger goqite_updated_timestamp after update on goqite begin + update goqite set updated = strftime('%Y-%m-%dT%H:%M:%fZ') where id = old.id; +end; + +create index goqite_queue_created_idx on goqite (queue, created); diff --git a/internal/testing/testing.go b/internal/testing/testing.go new file mode 100644 index 0000000..f10976a --- /dev/null +++ b/internal/testing/testing.go @@ -0,0 +1,82 @@ +package testing + +import ( + "database/sql" + _ "embed" + "fmt" + "os" + "testing" + + _ "github.com/mattn/go-sqlite3" + + "github.com/maragudk/goqite" +) + +//go:embed schema.sql +var schema string + +func NewDB(t testing.TB, path string) *sql.DB { + t.Helper() + + // Check if file exists already + exists := false + if _, err := os.Stat(path); err == nil { + exists = true + } + + if path != ":memory:" && !exists { + t.Cleanup(func() { + for _, p := range []string{path, path + "-shm", path + "-wal"} { + if err := os.Remove(p); err != nil { + t.Fatal(err) + } + } + }) + } + + db, err := sql.Open("sqlite3", path+"?_journal=WAL&_timeout=5000&_fk=true") + if err != nil { + t.Fatal(err) + } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + + if !exists { + _, err = db.Exec(schema) + if err != nil { + t.Fatal(err) + } + } + + return db +} + +func NewQ(t testing.TB, opts goqite.NewOpts, path string) *goqite.Queue { + t.Helper() + + opts.DB = NewDB(t, path) + + if opts.Name == "" { + opts.Name = "test" + } + + return goqite.New(opts) +} + +type Logger func(msg string, args ...any) + +func (f Logger) Info(msg string, args ...any) { + f(msg, args...) +} + +func NewLogger(t *testing.T) Logger { + t.Helper() + + return Logger(func(msg string, args ...any) { + logArgs := []any{msg} + for i := 0; i < len(args); i += 2 { + logArgs = append(logArgs, fmt.Sprintf("%v=%v", args[i], args[i+1])) + } + t.Log(logArgs...) + }) +} diff --git a/jobs/runner.go b/jobs/runner.go new file mode 100644 index 0000000..6e3715f --- /dev/null +++ b/jobs/runner.go @@ -0,0 +1,215 @@ +// Package jobs provides a [Runner] which can run registered job [Func]s by name, when a message for it is received +// on the underlying queue. +// +// It provides: +// - Limit on how many jobs can be run simultaneously +// - Automatic message timeout extension while the job is running +// - Graceful shutdown +package jobs + +import ( + "bytes" + "context" + "encoding/gob" + "errors" + "fmt" + "runtime" + "sort" + "sync" + "time" + + "github.com/maragudk/goqite" +) + +type NewRunnerOpts struct { + Limit int + Log logger + PollInterval time.Duration + Queue *goqite.Queue +} + +func NewRunner(opts NewRunnerOpts) *Runner { + if opts.Log == nil { + opts.Log = &discardLogger{} + } + + if opts.Limit == 0 { + opts.Limit = runtime.GOMAXPROCS(0) + } + + if opts.PollInterval == 0 { + opts.PollInterval = 100 * time.Millisecond + } + + return &Runner{ + jobCountLimit: opts.Limit, + jobs: make(map[string]Func), + log: opts.Log, + pollInterval: opts.PollInterval, + queue: opts.Queue, + } +} + +type Runner struct { + jobCount int + jobCountLimit int + jobCountLock sync.RWMutex + jobs map[string]Func + log logger + pollInterval time.Duration + queue *goqite.Queue +} + +type message struct { + Name string + Message []byte +} + +// Start the Runner, blocking until the given context is cancelled. +// When the context is cancelled, waits for the jobs to finish. +func (r *Runner) Start(ctx context.Context) { + var names []string + for k := range r.jobs { + names = append(names, k) + } + sort.Strings(names) + + r.log.Info("Starting", "jobs", names) + + var wg sync.WaitGroup + + for { + select { + case <-ctx.Done(): + r.log.Info("Stopping") + wg.Wait() + r.log.Info("Stopped") + return + default: + r.receiveAndRun(ctx, &wg) + } + } +} + +func (r *Runner) receiveAndRun(ctx context.Context, wg *sync.WaitGroup) { + r.jobCountLock.RLock() + if r.jobCount == r.jobCountLimit { + r.jobCountLock.RUnlock() + // This is to avoid a busy loop + time.Sleep(r.pollInterval) + return + } else { + r.jobCountLock.RUnlock() + } + + m, err := r.queue.ReceiveAndWait(ctx, r.pollInterval) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return + } + r.log.Info("Error receiving job", "error", err) + // Sleep a bit to not hammer the queue if there's an error with it + time.Sleep(time.Second) + return + } + + if m == nil { + return + } + + var jm message + if err := gob.NewDecoder(bytes.NewReader(m.Body)).Decode(&jm); err != nil { + r.log.Info("Error decoding job message body", "error", err) + return + } + + job, ok := r.jobs[jm.Name] + if !ok { + panic(fmt.Sprintf(`job "%v" not registered`, jm.Name)) + } + + r.jobCountLock.Lock() + r.jobCount++ + r.jobCountLock.Unlock() + + wg.Add(1) + go func() { + defer wg.Done() + + defer func() { + r.jobCountLock.Lock() + r.jobCount-- + r.jobCountLock.Unlock() + }() + + defer func() { + if rec := recover(); rec != nil { + r.log.Info("Recovered from panic in job", "error", rec) + } + }() + + jobCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Extend the job message while the job is running + done := make(chan struct{}, 1) + defer func() { + done <- struct{}{} + }() + + go func() { + for { + select { + case <-done: + return + default: + if err := r.queue.Extend(jobCtx, m.ID, 5*time.Second); err != nil { + r.log.Info("Error extending message timeout", "error", err) + } + time.Sleep(3 * time.Second) + } + } + }() + + before := time.Now() + if err := job(jobCtx, jm.Message); err != nil { + r.log.Info("Error running job", "name", jm.Name, "error", err) + return + } + duration := time.Since(before) + r.log.Info("Ran job", "name", jm.Name, "duration", duration) + + deleteCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := r.queue.Delete(deleteCtx, m.ID); err != nil { + r.log.Info("Error deleting job from queue", "error", err) + } + }() +} + +// Func is a job to be done. It gets the message m from the queue. +type Func func(ctx context.Context, m []byte) error + +func (r *Runner) Register(name string, job Func) { + if _, ok := r.jobs[name]; ok { + panic(fmt.Sprintf(`job "%v" already registered`, name)) + } + r.jobs[name] = job +} + +func Create(ctx context.Context, q *goqite.Queue, name string, m []byte) error { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(message{Name: name, Message: m}); err != nil { + return err + } + return q.Send(ctx, goqite.Message{Body: buf.Bytes()}) +} + +// logger matches the info level method from the slog.Logger. +type logger interface { + Info(msg string, args ...any) +} + +type discardLogger struct{} + +func (d *discardLogger) Info(msg string, args ...any) {} diff --git a/jobs/runner_test.go b/jobs/runner_test.go new file mode 100644 index 0000000..f03323e --- /dev/null +++ b/jobs/runner_test.go @@ -0,0 +1,180 @@ +package jobs_test + +import ( + "context" + "database/sql" + "fmt" + "log/slog" + "testing" + "time" + + "github.com/maragudk/is" + _ "github.com/mattn/go-sqlite3" + + "github.com/maragudk/goqite" + internaltesting "github.com/maragudk/goqite/internal/testing" + "github.com/maragudk/goqite/jobs" +) + +func TestRunner_Register(t *testing.T) { + t.Run("can register a new job", func(t *testing.T) { + r := jobs.NewRunner(jobs.NewRunnerOpts{}) + r.Register("test", func(ctx context.Context, m []byte) error { + return nil + }) + }) + + t.Run("panics if the same job is registered twice", func(t *testing.T) { + r := jobs.NewRunner(jobs.NewRunnerOpts{}) + r.Register("test", func(ctx context.Context, m []byte) error { + return nil + }) + defer func() { + r := recover() + if r == nil { + t.Fatal("did not panic") + } + is.Equal(t, `job "test" already registered`, r) + }() + r.Register("test", func(ctx context.Context, m []byte) error { + return nil + }) + }) +} + +func TestRunner_Start(t *testing.T) { + t.Run("can run a named job", func(t *testing.T) { + q, r := newRunner(t) + + var ran bool + ctx, cancel := context.WithCancel(context.Background()) + r.Register("test", func(ctx context.Context, m []byte) error { + ran = true + is.Equal(t, "yo", string(m)) + cancel() + return nil + }) + + err := jobs.Create(ctx, q, "test", []byte("yo")) + is.NotError(t, err) + + r.Start(ctx) + is.True(t, ran) + }) + + t.Run("doesn't run a different job", func(t *testing.T) { + q, r := newRunner(t) + + var ranTest, ranDifferentTest bool + ctx, cancel := context.WithCancel(context.Background()) + r.Register("test", func(ctx context.Context, m []byte) error { + ranTest = true + return nil + }) + r.Register("different-test", func(ctx context.Context, m []byte) error { + ranDifferentTest = true + cancel() + return nil + }) + + err := jobs.Create(ctx, q, "different-test", []byte("yo")) + is.NotError(t, err) + + r.Start(ctx) + is.True(t, !ranTest) + is.True(t, ranDifferentTest) + }) + + t.Run("panics if the job is not registered", func(t *testing.T) { + q, r := newRunner(t) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + err := jobs.Create(ctx, q, "test", []byte("yo")) + is.NotError(t, err) + + defer func() { + r := recover() + if r == nil { + t.Fatal("did not panic") + } + is.Equal(t, `job "test" not registered`, r) + }() + r.Start(ctx) + }) + + t.Run("does not panic if job panics", func(t *testing.T) { + q, r := newRunner(t) + + ctx, cancel := context.WithCancel(context.Background()) + + r.Register("test", func(ctx context.Context, m []byte) error { + cancel() + panic("test panic") + }) + + err := jobs.Create(ctx, q, "test", []byte("yo")) + is.NotError(t, err) + + r.Start(ctx) + }) +} + +func ExampleRunner_Start() { + log := slog.Default() + + // Setup the db and goqite schema. + db, err := sql.Open("sqlite3", ":memory:?_journal=WAL&_timeout=5000&_fk=true") + if err != nil { + log.Info("Error opening db", "error", err) + } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + + if err := goqite.Setup(context.Background(), db); err != nil { + log.Info("Error in setup", "error", err) + } + + // Make a new queue for the jobs. You can have as many of these as you like, just name them differently. + q := goqite.New(goqite.NewOpts{ + DB: db, + Name: "jobs", + }) + + // Make a job runner with a job limit of 1 and a short message poll interval. + r := jobs.NewRunner(jobs.NewRunnerOpts{ + Limit: 1, + Log: slog.Default(), + PollInterval: 10 * time.Millisecond, + Queue: q, + }) + + // Register our "print" job. + r.Register("print", func(ctx context.Context, m []byte) error { + fmt.Println(string(m)) + return nil + }) + + // Create a "print" job with a message. + if err := jobs.Create(context.Background(), q, "print", []byte("Yo")); err != nil { + log.Info("Error creating job", "error", err) + } + + // Stop the job runner after a timeout. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel() + + // Start the job runner and see the job run. + r.Start(ctx) + + // Output: Yo +} + +func newRunner(t *testing.T) (*goqite.Queue, *jobs.Runner) { + t.Helper() + + q := internaltesting.NewQ(t, goqite.NewOpts{}, ":memory:") + r := jobs.NewRunner(jobs.NewRunnerOpts{Log: internaltesting.NewLogger(t), Queue: q}) + return q, r +}