diff --git a/experimental/incremental/doc.go b/experimental/incremental/doc.go
new file mode 100644
index 0000000..8665abc
--- /dev/null
+++ b/experimental/incremental/doc.go
@@ -0,0 +1,63 @@
+// Copyright 2020-2024 Buf Technologies, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+Package incremental implements a query-oriented incremental compilation
+framework.
+
+The primary type of this package is [Executor], which executes [Query] values
+and caches their results. Queries can themselves depend on other queries, and
+can request that those dependencies be executed (in parallel) using [Resolve].
+
+Queries are intended to be relatively fine-grained. For example, there might be
+a query that represents "compile a module", whose input is a list of file
+names. It would then depend on the AST queries for each of those files, from
+which it would compute lists of imports, and depend on further queries based
+on those imports.
+
+# Implementing a Query
+
+Each query must provide a key that uniquely identifies it, and a function for
+actually computing it. Queries can partially succeed: in addition to returning
+a value and a fatal error, a query may record non-fatal errors on the [Task]
+argument passed to it.
+
+If a query cannot proceed, it should return a non-nil fatal error, which marks
+the query as failed. Queries that depend on it will observe that error as the
+fatal error of the corresponding [Result]. Non-fatal errors can be recorded
+with [Task.NonFatal].
+
+This means that queries generally do not need to worry about propagating errors
+correctly; this happens automatically in the framework. The entry point for
+query execution, [Run], returns all errors reported by partially-succeeding or
+failing queries.
+
+Why can queries partially succeed? Consider a parsing operation. This may
+generate diagnostics that we want to bubble up to the caller, but whether or
+not the presence of errors is actually fatal depends on what the caller wants
+to do with the query result. Thus, queries should generally not fail unless
+one of their dependencies produced an error they cannot handle.
+
+Queries can inspect errors generated by their direct dependencies, but not by
+those dependencies' dependencies. ([Run], however, returns all transitive errors.)
+
+# Invalidating Queries
+
+[Executor] supports invalidating queries by key, which will cause all queries
+that depended on that query to be discarded and require recomputing. This can be
+used e.g. to mark a file as changed and require that that file, and everything
+that depends on it, be recomputed. See [Executor.Evict].
+*/
+//nolint:dupword // "that that" is grammatical!
+package incremental
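To make the package documentation above concrete, here is a minimal, self-contained sketch of a derived query resolving a leaf query, recording a non-fatal error, and propagating a dependency's fatal error. The `WordCount` and `Contents` types and the `files` map are hypothetical, illustration-only names; the `incremental.*` calls are the ones introduced by this patch.

```go
package main

import (
	"context"
	"fmt"
	"strings"

	"github.com/bufbuild/protocompile/experimental/incremental"
)

// WordCount counts the words in a (hypothetical) in-memory file. Its key is
// the struct value itself, so two WordCount queries for the same path are
// deduplicated by the Executor.
type WordCount struct {
	Path string
}

func (w WordCount) Key() any { return w }

func (w WordCount) Execute(t incremental.Task) (int, error) {
	contents, err := incremental.Resolve(t, Contents{Path: w.Path})
	if err != nil {
		return 0, err // The context expired; propagate immediately.
	}
	if contents[0].Fatal != nil {
		return 0, contents[0].Fatal // A dependency failed fatally.
	}
	return len(strings.Fields(contents[0].Value)), nil
}

// Contents is a leaf query that "reads" a file from a fixed map.
type Contents struct {
	Path string
}

func (c Contents) Key() any { return c }

func (c Contents) Execute(t incremental.Task) (string, error) {
	text, ok := files[c.Path]
	if !ok {
		return "", fmt.Errorf("no such file: %s", c.Path)
	}
	if text == "" {
		// Not fatal: callers may still be able to do something useful.
		t.NonFatal(fmt.Errorf("%s is empty", c.Path))
	}
	return text, nil
}

var files = map[string]string{"a.txt": "the quick brown fox"}

func main() {
	exec := incremental.New(incremental.WithParallelism(4))
	results, err := incremental.Run(context.Background(), exec, WordCount{Path: "a.txt"})
	if err != nil {
		panic(err) // Only possible if the context expired.
	}
	fmt.Println(results[0].Value, results[0].Fatal) // 4 <nil>
}
```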
diff --git a/experimental/incremental/executor.go b/experimental/incremental/executor.go
new file mode 100644
index 0000000..97d3dfd
--- /dev/null
+++ b/experimental/incremental/executor.go
@@ -0,0 +1,192 @@
+// Copyright 2020-2024 Buf Technologies, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package incremental
+
+import (
+	"context"
+	"fmt"
+	"runtime"
+	"slices"
+	"sync"
+
+	"golang.org/x/sync/semaphore"
+)
+
+// Executor is a caching executor for incremental queries.
+//
+// See [New], [Run], and [Executor.Evict].
+type Executor struct {
+	dirty sync.RWMutex
+
+	// TODO: Evaluate alternatives. sync.Map is pretty bad at having predictable
+	// performance, and we may want to add eviction to keep memoization costs
+	// in a long-running process (like, say, a language server) down.
+	// See https://github.com/dgraph-io/ristretto as a potential alternative.
+	tasks sync.Map // [any, *task]
+
+	sema *semaphore.Weighted
+}
+
+// ExecutorOption is an option func for [New].
+type ExecutorOption func(*Executor)
+
+// New constructs a new executor with the given options.
+func New(options ...ExecutorOption) *Executor {
+	exec := &Executor{
+		sema: semaphore.NewWeighted(int64(runtime.GOMAXPROCS(0))),
+	}
+
+	for _, opt := range options {
+		opt(exec)
+	}
+
+	return exec
+}
+
+// WithParallelism sets the maximum number of queries that can execute in
+// parallel. Defaults to GOMAXPROCS if not set explicitly.
+func WithParallelism(n int64) ExecutorOption {
+	return func(e *Executor) { e.sema = semaphore.NewWeighted(n) }
+}
+
+// Keys returns a snapshot of the keys of the queries that are currently
+// memoized in this Executor.
+//
+// The returned slice is sorted.
+func (e *Executor) Keys() (keys []string) {
+	e.tasks.Range(func(k, t any) bool {
+		task := t.(*task) //nolint:errcheck // All values in this map are tasks.
+		result := task.result.Load()
+		if result == nil || !closed(result.done) {
+			return true
+		}
+		keys = append(keys, fmt.Sprintf("%#v", k))
+		return true
+	})
+
+	slices.Sort(keys)
+	return
+}
+
+var runExecutorKey byte
+
+// Run executes a set of queries on this executor in parallel.
+//
+// This function only returns an error if ctx expires during execution, in
+// which case it returns a nil slice and the error from [context.Cause].
+//
+// Errors that occur during each query are contained within the returned results.
+// Unlike [Resolve], these contain the *transitive* errors for each query!
+//
+// Implementations of [Query].Execute MUST NOT UNDER ANY CIRCUMSTANCES call
+// Run. This will result in potential resource starvation or deadlocking, and
+// defeats other correctness verification (such as cycle detection). Instead,
+// use [Resolve], which takes a [Task] instead of an [Executor].
+//
+// Note: this function really wants to be a method of [Executor], but it isn't
+// because it's generic.
+func Run[T any](ctx context.Context, e *Executor, queries ...Query[T]) (results []Result[T], expired error) {
+	e.dirty.RLock()
+	defer e.dirty.RUnlock()
+
+	// Verify we haven't reëntrantly called Run.
+	if callers, ok := ctx.Value(&runExecutorKey).(*[]*Executor); ok {
+		if slices.Contains(*callers, e) {
+			panic("protocompile/incremental: reentrant call to Run")
+		}
+		*callers = append(*callers, e)
+	} else {
+		ctx = context.WithValue(ctx, &runExecutorKey, &[]*Executor{e})
+	}
+	ctx, cancel := context.WithCancelCause(ctx)
+
+	// Need to acquire a hold on the global semaphore to represent the root
+	// task we're about to execute.
+	if e.sema.Acquire(ctx, 1) != nil {
+		return nil, context.Cause(ctx)
+	}
+	defer e.sema.Release(1)
+
+	root := Task{
+		ctx:    ctx,
+		cancel: cancel,
+		exec:   e,
+		result: &result{done: make(chan struct{})},
+	}
+
+	results, expired = Resolve(root, queries...)
+	if expired != nil {
+		return nil, expired
+	}
+
+	// Now, for each result, we need to walk its dependencies and collect
+	// their non-fatal errors.
+	for i, query := range queries {
+		task := e.getTask(query.Key())
+		for dep := range task.deps {
+			r := &results[i]
+			r.NonFatal = append(r.NonFatal, dep.result.Load().NonFatal...)
+		}
+	}
+
+	return results, nil
+}
+
+// Evict marks query keys as invalid, requiring those queries, and anything
+// that depends on them, to be recomputed. Keys that are not cached are ignored.
+//
+// This function cannot execute in parallel with calls to [Run], and will take
+// an exclusive lock (note that [Run] calls themselves can be run in parallel).
+func (e *Executor) Evict(keys ...any) {
+	var queue []*task
+	for _, key := range keys {
+		if t, ok := e.tasks.Load(key); ok {
+			queue = append(queue, t.(*task))
+		}
+	}
+	if len(queue) == 0 {
+		return
+	}
+
+	e.dirty.Lock()
+	defer e.dirty.Unlock()
+
+	for len(queue) > 0 {
+		next := queue[0]
+		queue = queue[1:]
+
+		next.downstream.Range(func(k, _ any) bool {
+			queue = append(queue, k.(*task))
+			return true
+		})
+
+		// Clear everything. We don't need to synchronize here because we have
+		// unique ownership of the task.
+		*next = task{}
+	}
+}
+
+// getTask returns (and creates if necessary) a task pointer for the given key.
+func (e *Executor) getTask(key any) *task {
+	// Avoid allocating a new task object in the common case.
+	if t, ok := e.tasks.Load(key); ok {
+		return t.(*task)
+	}
+
+	t, _ := e.tasks.LoadOrStore(key, new(task))
+	return t.(*task)
+}
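The memoization and eviction behavior documented on Executor.Evict above can be seen in a short sketch. This continues the hypothetical `WordCount`/`Contents` example from earlier (same imports and types); only the `incremental` calls are from this patch.

```go
// demoEviction runs a query, evicts its leaf dependency as if the underlying
// file changed, and re-runs it to show that everything downstream of the
// evicted key is recomputed.
func demoEviction(ctx context.Context, exec *incremental.Executor) error {
	// First run computes and memoizes WordCount{"a.txt"} and Contents{"a.txt"}.
	if _, err := incremental.Run(ctx, exec, WordCount{Path: "a.txt"}); err != nil {
		return err
	}
	fmt.Println(exec.Keys()) // Both query keys are now cached.

	// Pretend a.txt changed on disk: evict the leaf query. Everything that
	// depends on it (WordCount) is discarded as well.
	files["a.txt"] = "the quick brown fox jumps over the lazy dog"
	exec.Evict(Contents{Path: "a.txt"})

	// The next run recomputes both queries and re-memoizes them.
	results, err := incremental.Run(ctx, exec, WordCount{Path: "a.txt"})
	if err != nil {
		return err
	}
	fmt.Println(results[0].Value) // 9
	return nil
}
```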
diff --git a/experimental/incremental/executor_test.go b/experimental/incremental/executor_test.go
new file mode 100644
index 0000000..a5b0acf
--- /dev/null
+++ b/experimental/incremental/executor_test.go
@@ -0,0 +1,226 @@
+// Copyright 2020-2024 Buf Technologies, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package incremental_test + +import ( + "context" + "fmt" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/bufbuild/protocompile/experimental/incremental" +) + +type Root struct{} + +func (r Root) Key() any { + return r +} + +func (Root) Execute(_ incremental.Task) (struct{}, error) { + time.Sleep(100 * time.Millisecond) + return struct{}{}, nil +} + +type ParseInt struct { + Input string +} + +func (i ParseInt) Key() any { + return i +} + +func (i ParseInt) Execute(t incremental.Task) (int, error) { + // This tests that a thundering stampede of queries all waiting on the same + // query (as in a diamond-shaped graph) do not cause any issues. + _, err := incremental.Resolve(t, Root{}) + if err != nil { + return 0, err + } + + v, err := strconv.Atoi(i.Input) + if err != nil { + t.NonFatal(err) + } + if v < 0 { + return 0, fmt.Errorf("negative value: %v", v) + } + return v, nil +} + +type Sum struct { + Input string +} + +func (s Sum) Key() any { + return s +} + +func (s Sum) Execute(t incremental.Task) (int, error) { + var queries []incremental.Query[int] //nolint:prealloc + for _, s := range strings.Split(s.Input, ",") { + queries = append(queries, ParseInt{s}) + } + + ints, err := incremental.Resolve(t, queries...) + if err != nil { + return 0, err + } + + var v int + for _, i := range ints { + if i.Fatal != nil { + return 0, i.Fatal + } + + v += i.Value + } + return v, nil +} + +type Cyclic struct { + Mod, Step int +} + +func (c Cyclic) Key() any { + return c +} + +func (c Cyclic) Execute(t incremental.Task) (int, error) { + next, err := incremental.Resolve(t, Cyclic{ + Mod: c.Mod, + Step: (c.Step + 1) % c.Mod, + }) + if err != nil { + return 0, err + } + if next[0].Fatal != nil { + return 0, next[0].Fatal + } + + return next[0].Value * next[0].Value, nil +} + +func TestSum(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + ctx := context.Background() + exec := incremental.New( + incremental.WithParallelism(4), + ) + + result, err := incremental.Run(ctx, exec, Sum{"1,2,2,3,4"}) + require.NoError(t, err) + assert.Equal(12, result[0].Value) + assert.Empty(result[0].NonFatal) + assert.Equal([]string{ + `incremental_test.ParseInt{Input:"1"}`, + `incremental_test.ParseInt{Input:"2"}`, + `incremental_test.ParseInt{Input:"3"}`, + `incremental_test.ParseInt{Input:"4"}`, + `incremental_test.Root{}`, + `incremental_test.Sum{Input:"1,2,2,3,4"}`, + }, exec.Keys()) + + result, err = incremental.Run(ctx, exec, Sum{"1,2,2,oops,4"}) + require.NoError(t, err) + assert.Equal(9, result[0].Value) + assert.Len(result[0].NonFatal, 1) + assert.Equal([]string{ + `incremental_test.ParseInt{Input:"1"}`, + `incremental_test.ParseInt{Input:"2"}`, + `incremental_test.ParseInt{Input:"3"}`, + `incremental_test.ParseInt{Input:"4"}`, + `incremental_test.ParseInt{Input:"oops"}`, + `incremental_test.Root{}`, + `incremental_test.Sum{Input:"1,2,2,3,4"}`, + `incremental_test.Sum{Input:"1,2,2,oops,4"}`, + }, exec.Keys()) + + exec.Evict(ParseInt{"4"}) + assert.Equal([]string{ + `incremental_test.ParseInt{Input:"1"}`, + `incremental_test.ParseInt{Input:"2"}`, + `incremental_test.ParseInt{Input:"3"}`, + `incremental_test.ParseInt{Input:"oops"}`, + `incremental_test.Root{}`, + }, exec.Keys()) + + result, err = incremental.Run(ctx, exec, Sum{"1,2,2,3,4"}) + require.NoError(t, err) + assert.Equal(12, result[0].Value) + assert.Empty(result[0].NonFatal) + assert.Equal([]string{ + `incremental_test.ParseInt{Input:"1"}`, + 
`incremental_test.ParseInt{Input:"2"}`, + `incremental_test.ParseInt{Input:"3"}`, + `incremental_test.ParseInt{Input:"4"}`, + `incremental_test.ParseInt{Input:"oops"}`, + `incremental_test.Root{}`, + `incremental_test.Sum{Input:"1,2,2,3,4"}`, + }, exec.Keys()) +} + +func TestFatal(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + ctx := context.Background() + exec := incremental.New( + incremental.WithParallelism(4), + ) + + result, err := incremental.Run(ctx, exec, Sum{"1,2,-3,-4"}) + require.NoError(t, err) + // NOTE: This error is deterministic, because it's chosen by Sum.Execute. + assert.Equal("negative value: -3", result[0].Fatal.Error()) + assert.Equal([]string{ + `incremental_test.ParseInt{Input:"-3"}`, + `incremental_test.ParseInt{Input:"-4"}`, + `incremental_test.ParseInt{Input:"1"}`, + `incremental_test.ParseInt{Input:"2"}`, + `incremental_test.Root{}`, + `incremental_test.Sum{Input:"1,2,-3,-4"}`, + }, exec.Keys()) +} + +func TestCyclic(t *testing.T) { + t.Parallel() + assert := assert.New(t) + + ctx := context.Background() + exec := incremental.New( + incremental.WithParallelism(4), + ) + + result, err := incremental.Run(ctx, exec, Cyclic{Mod: 5, Step: 3}) + require.NoError(t, err) + assert.Equal( + `cycle detected: `+ + `incremental_test.Cyclic{Mod:5, Step:3} -> `+ + `incremental_test.Cyclic{Mod:5, Step:4} -> `+ + `incremental_test.Cyclic{Mod:5, Step:0} -> `+ + `incremental_test.Cyclic{Mod:5, Step:1} -> `+ + `incremental_test.Cyclic{Mod:5, Step:2} -> `+ + `incremental_test.Cyclic{Mod:5, Step:3}`, + result[0].Fatal.Error(), + ) +} diff --git a/experimental/incremental/query.go b/experimental/incremental/query.go new file mode 100644 index 0000000..c0cd7fa --- /dev/null +++ b/experimental/incremental/query.go @@ -0,0 +1,143 @@ +// Copyright 2020-2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package incremental + +import ( + "fmt" + "strings" +) + +// Query represents an incremental compilation query. +// +// Types which implement Query can be executed by an [Executor], which +// automatically caches the results of a query. +type Query[T any] interface { + // Returns a unique key representing this query. + // + // This should be a comparable struct type unique to the query type. Failure + // to do so may result in different queries with the same key, which may + // result in incorrect results or panics. + Key() any + + // Executes this query. This function will only be called if the result of + // this query is not already in the [Executor]'s cache. + // + // The error return should only be used to signal if the query failed. For + // non-fatal errors, you should record that information with [Task.NonFatal]. + // + // Implementations of this function MUST NOT call [Run] on the executor that + // is executing them. This will defeat correctness detection, and lead to + // resource starvation (and potentially deadlocks). 
+	//
+	// Panicking in Execute is not interpreted as a fatal error that should be
+	// memoized; instead, it is treated as cancellation of the context that
+	// was passed to [Run].
+	Execute(Task) (value T, fatal error)
+}
+
+// ErrCycle is returned by [Resolve] if a cycle occurs during query execution.
+type ErrCycle struct {
+	// The offending cycle. The first and last queries will have the same key.
+	//
+	// To inspect the concrete types of the cycle members, use [AsTyped],
+	// which will automatically unwrap queries that were wrapped with [AsAny].
+	Cycle []*AnyQuery
+}
+
+// Error implements [error].
+func (e *ErrCycle) Error() string {
+	var buf strings.Builder
+	buf.WriteString("cycle detected: ")
+	for i, q := range e.Cycle {
+		if i != 0 {
+			buf.WriteString(" -> ")
+		}
+		fmt.Fprintf(&buf, "%#v", q.Key())
+	}
+	return buf.String()
+}
+
+// ErrPanic is returned by [Run] if any of the queries it executes panic.
+// This error is used to cancel the [context.Context] that governs the call to
+// [Run].
+type ErrPanic struct {
+	Query     *AnyQuery // The query that panicked.
+	Panic     any       // The actual value passed to panic().
+	Backtrace string    // A backtrace for the panic.
+}
+
+// Error implements [error].
+func (e *ErrPanic) Error() string {
+	return fmt.Sprintf(
+		"call to Query.Execute (key: %#v) panicked: %v\n%s",
+		e.Query.Key(), e.Panic, e.Backtrace,
+	)
+}
+
+// AnyQuery is a [Query] that has been type-erased.
+type AnyQuery struct {
+	actual, key any
+	execute     func(Task) (any, error)
+}
+
+// AsAny type-erases a [Query].
+//
+// This is intended to be combined with [Resolve], for cases where queries
+// of different types want to be run in parallel.
+//
+// Panics if q is nil.
+func AsAny[T any](q Query[T]) *AnyQuery {
+	if q, ok := any(q).(*AnyQuery); ok {
+		return q
+	}
+
+	return &AnyQuery{
+		actual:  q,
+		key:     q.Key(),
+		execute: func(t Task) (any, error) { return q.Execute(t) },
+	}
+}
+
+// Underlying returns the original, non-AnyQuery query this query was
+// constructed with.
+func (q *AnyQuery) Underlying() any {
+	return q.actual
+}
+
+// Key implements [Query].
+func (q *AnyQuery) Key() any { return q.key }
+
+// Execute implements [Query].
+func (q *AnyQuery) Execute(t Task) (any, error) { return q.execute(t) }
+
+// AsTyped undoes the effect of [AsAny].
+//
+// For some Query[any] values, you may be able to use ordinary Go type
+// assertions, if the underlying type actually implements Query[any]. However,
+// to downcast to a concrete Query[T] type, you must use this function.
+func AsTyped[Q Query[T], T any](q Query[any]) (downcast Q, ok bool) {
+	if downcast, ok := q.(Q); ok {
+		return downcast, true
+	}
+
+	qAny, ok := q.(*AnyQuery)
+	if !ok {
+		var zero Q
+		return zero, false
+	}
+
+	downcast, ok = qAny.actual.(Q)
+	return downcast, ok
+}
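A sketch of how AsAny and AsTyped are meant to compose, again reusing the hypothetical `WordCount` and `Contents` queries (and imports) from the earlier example: queries with different result types are erased to `*AnyQuery` so they can be resolved in one call, and `AsTyped` recovers the concrete query type afterwards, e.g. when walking the `*AnyQuery` values stored in an ErrCycle.

```go
// resolveMixed resolves an int-valued and a string-valued query together by
// type-erasing both to Query[any].
func resolveMixed(t incremental.Task) error {
	erased := []incremental.Query[any]{
		incremental.AsAny(WordCount{Path: "a.txt"}),
		incremental.AsAny(Contents{Path: "a.txt"}),
	}
	results, err := incremental.Resolve(t, erased...)
	if err != nil {
		return err
	}
	for _, r := range results {
		if r.Fatal != nil {
			return r.Fatal
		}
	}

	// Values come back as `any` and must be asserted back to their types...
	fmt.Println(results[0].Value.(int), results[1].Value.(string))

	// ...while AsTyped recovers the original, concrete query type.
	if wc, ok := incremental.AsTyped[WordCount, int](erased[0]); ok {
		fmt.Println("word count query for", wc.Path)
	}
	return nil
}
```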
diff --git a/experimental/incremental/task.go b/experimental/incremental/task.go
new file mode 100644
index 0000000..fd727ed
--- /dev/null
+++ b/experimental/incremental/task.go
@@ -0,0 +1,325 @@
+// Copyright 2020-2024 Buf Technologies, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package incremental
+
+import (
+	"context"
+	"runtime/debug"
+	"slices"
+	"sync"
+	"sync/atomic"
+
+	"github.com/bufbuild/protocompile/internal/iter"
+)
+
+// Task represents a query that is currently being executed.
+//
+// Values of type Task are passed to [Query]. The main use of a Task is to
+// be passed to [Resolve] to resolve dependencies.
+type Task struct {
+	// We need all of the contexts for a call to [Run] to be the same, so, to
+	// avoid user implementations of Query getting this wrong (or inserting
+	// inappropriate timeouts), we pass the context as part of the task.
+	ctx    context.Context //nolint:containedctx
+	cancel func(error)
+
+	exec   *Executor
+	task   *task
+	result *result
+
+	// Intrusive linked list node for cycle detection.
+	path path
+}
+
+// Context returns the cancellation context for this task.
+func (t *Task) Context() context.Context {
+	t.checkDone()
+	return t.ctx
+}
+
+// NonFatal adds non-fatal errors to the current query, which will be
+// propagated to all queries that depend on it.
+//
+// This will not cause the query to fail; to signal failure, [Query].Execute
+// should instead return a non-nil fatal error.
+func (t *Task) NonFatal(errs ...error) {
+	t.checkDone()
+	t.result.NonFatal = append(t.result.NonFatal, errs...)
+}
+
+// Resolve executes a set of queries in parallel. Each query is run on its own
+// goroutine.
+//
+// If the context passed to [Run] expires, this will return [context.Cause].
+// The caller must propagate this error to ensure the whole query graph exits
+// quickly. Failure to propagate the error will result in incorrect query
+// results.
+//
+// If a cycle is detected for a given query, the query will automatically fail
+// and produce an [ErrCycle] for its fatal error. If the call to [Query].Execute
+// returns an error, that will be placed into the fatal error value, instead.
+//
+// Callers should make sure to check each result's fatal error before using
+// its value.
+//
+// Non-fatal errors for each result are only those that occurred as a direct
+// result of query execution, and *do not* contain that query's transitive
+// errors. This is unlike the behavior of [Run], which only collects errors at
+// the very end. This ensures that errors are not duplicated, something that
+// is not possible to do mid-query.
+//
+// Note: this function really wants to be a method of [Task], but it isn't
+// because it's generic.
+func Resolve[T any](caller Task, queries ...Query[T]) (results []Result[T], expired error) {
+	caller.checkDone()
+
+	results = make([]Result[T], len(queries))
+	deps := make([]*task, len(queries))
+
+	var wg sync.WaitGroup
+	wg.Add(len(queries))
+
+	// TODO: A potential optimization is to make the current goroutine
+	// execute the zeroth query, which saves on allocating a fresh g for
+	// *every* query.
+	anyAsync := false
+	for i, q := range queries {
+		i := i // Avoid pesky capture-by-ref of loop induction variables.
+
+		q := AsAny(q) // This will also cache the result of q.Key() for us.
+		deps[i] = caller.exec.getTask(q.Key())
+
+		async := deps[i].start(caller, q, func(r *result) {
+			if r != nil {
+				if r.Value != nil {
+					// This type assertion will always succeed, unless the user has
+					// distinct queries with the same key, which is a sufficiently
+					// unrecoverable condition that a panic is acceptable.
+					results[i].Value = r.Value.(T) //nolint:errcheck
+				}
+
+				results[i].NonFatal = r.NonFatal
+				results[i].Fatal = r.Fatal
+			}
+
+			wg.Done()
+		})
+
+		anyAsync = anyAsync || async
+	}
+
+	// Update dependency links for each of our dependencies. This occurs in a
+	// defer block so that it happens regardless of panicking.
+	defer func() {
+		if caller.task == nil {
+			return
+		}
+		for _, dep := range deps {
+			if dep == nil {
+				continue
+			}
+
+			if caller.task.deps == nil {
+				caller.task.deps = map[*task]struct{}{}
+			}
+
+			caller.task.deps[dep] = struct{}{}
+			for dep := range dep.deps {
+				caller.task.deps[dep] = struct{}{}
+			}
+			if caller.task != nil {
+				dep.downstream.Store(caller.task, struct{}{})
+			}
+		}
+	}()
+
+	if anyAsync {
+		// Release our current hold on the global semaphore, since we're about to
+		// go to sleep. This avoids potential resource starvation for deeply-nested
+		// queries on low parallelism settings.
+		caller.exec.sema.Release(1)
+		wg.Wait()
+
+		// Reacquire from the global semaphore before returning, so
+		// execution of the calling task may resume.
+		if caller.exec.sema.Acquire(caller.ctx, 1) != nil {
+			return nil, context.Cause(caller.ctx)
+		}
+	}
+
+	return results, nil
+}
+
+// checkDone panics if this task has already completed. This is to avoid
+// shenanigans with tasks that escape their scope.
+func (t *Task) checkDone() {
+	if closed(t.result.done) {
+		panic("protocompile/incremental: use of Task after the associated Query.Execute call returned")
+	}
+}
+
+// task is book-keeping information for a memoized Task in an Executor.
+type task struct {
+	deps map[*task]struct{} // Transitive.
+
+	// TODO: See the comment on Executor.tasks.
+	downstream sync.Map // [*task, struct{}]
+
+	// If this task has not been started yet, this is nil.
+	// Otherwise, if it is complete, result.done will be closed.
+	result atomic.Pointer[result]
+}
+
+// Result is the result of executing a query on an [Executor], either via
+// [Run] or [Resolve].
+type Result[T any] struct {
+	Value    T
+	NonFatal []error
+	Fatal    error
+}
+
+// result is a Result[any] with a completion channel appended to it.
+type result struct {
+	Result[any]
+	done chan struct{}
+}
+
+// path is a linked list node for tracking cycles in query dependencies.
+type path struct {
+	Query *AnyQuery
+	Prev  *path
+}
+
+// Walk returns an iterator over the linked list.
+func (p *path) Walk() iter.Seq[*path] {
+	return func(yield func(*path) bool) {
+		for node := p; node.Query != nil; node = node.Prev {
+			if !yield(node) {
+				return
+			}
+		}
+	}
+}
+
+// start executes a query in the context of some task and records the result by
+// calling done.
+func (t *task) start(caller Task, q *AnyQuery, done func(*result)) (async bool) {
+	// Common case for cached values; no need to spawn a separate goroutine.
+	r := t.result.Load()
+	if r != nil && closed(r.done) {
+		done(r)
+		return false
+	}
+
+	// Complete the rest of the computation asynchronously.
+	go func() {
+		done(t.run(caller, q))
+	}()
+	return true
+}
+
+// run actually executes the query passed to start. It is called on its own
+// goroutine.
+func (t *task) run(caller Task, q *AnyQuery) (output *result) {
+	output = t.result.Load()
+
+	defer func() {
+		if panicked := recover(); panicked != nil {
+			output = nil
+			caller.cancel(&ErrPanic{
+				Query:     q,
+				Panic:     panicked,
+				Backtrace: string(debug.Stack()),
+			})
+		}
+
+		if output != nil && !closed(output.done) {
+			close(output.done)
+		}
+	}()
+
+	// Check for a potential cycle.
+ var cycle *ErrCycle + caller.path.Walk()(func(node *path) bool { + if node.Query.Key() == q.Key() { + cycle = new(ErrCycle) + + // Re-walk the list to collect the cycle itself. + caller.path.Walk()(func(node2 *path) bool { + cycle.Cycle = append(cycle.Cycle, node2.Query) + return node2 != node + }) + + // Reverse the list so that dependency arrows point to the + // right (i.e., Cycle[n] depends on Cycle[n+1]). + slices.Reverse(cycle.Cycle) + + // Insert a copy of the current query to complete the cycle. + cycle.Cycle = append(cycle.Cycle, AsAny(q)) + return false + } + return true + }) + if cycle != nil { + output.Fatal = cycle + return output + } + + // Try to become the leader (the task responsible for computing the result). + output = &result{done: make(chan struct{})} + if !t.result.CompareAndSwap(nil, output) { + // We failed to become the executor, so we're gonna go to sleep + // until it's done. + select { + case <-t.result.Load().done: + case <-caller.ctx.Done(): + } + + // Reload the result pointer. This is needed if the leader panics, + // because the result will be set to nil. + return t.result.Load() + } + + callee := Task{ + ctx: caller.ctx, + exec: caller.exec, + task: t, + result: output, + path: path{ + Query: q, + Prev: &caller.path, + }, + } + + if callee.exec.sema.Acquire(caller.ctx, 1) != nil { + return nil + } + defer callee.exec.sema.Release(1) + + output.Value, output.Fatal = q.Execute(callee) + return output +} + +// closed checks if ch is closed. This may return false negatives, in that it +// may return false for a channel which is closed immediately after this +// function returns. +func closed[T any](ch <-chan T) bool { + select { + case _, ok := <-ch: + return !ok + default: + return false + } +} diff --git a/go.work.sum b/go.work.sum index 11789fb..a8a6c9e 100644 --- a/go.work.sum +++ b/go.work.sum @@ -21,6 +21,8 @@ cloud.google.com/go/compute v1.6.1/go.mod h1:g85FgpzFvNULZ+S8AYq87axRKuf2Kh7deLq github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/bmatcuk/doublestar/v4 v4.7.1 h1:fdDeAqgT47acgwd9bd9HxJRDmc9UAmPpc+2m0CXv75Q= +github.com/bmatcuk/doublestar/v4 v4.7.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/bufbuild/protocompile v0.2.1-0.20230123224550-da57cd758c2f/go.mod h1:tleDrpPTlLUVmgnEoN6qBliKWqJaZFJXqZdFjTd+ocU= github.com/bufbuild/protocompile v0.13.0/go.mod h1:dr++fGGeMPWHv7jPeT06ZKukm45NJscd7rUxQVzEKRk= github.com/bufbuild/protovalidate-go v0.6.3 h1:wxQyzW035zM16Binbaz/nWAzS12dRIXhZdSUWRY7Fv0= @@ -144,9 +146,10 @@ golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220502124256-b6088ccd6cba/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= 
golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457 h1:zf5N6UOrA487eEFacMePxjXAJctxKmyjKUsjA11Uzuk= golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457/go.mod h1:pRgIJT+bRLFKnoM1ldnzKoxTIn14Yxz928LQRYYgIN0= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= @@ -166,6 +169,7 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= @@ -261,6 +265,8 @@ google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.2-0.20220831092852-f930b1dc76e8/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= +google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc h1:/hemPrYIhOhy8zYrNj+069zDB68us2sMGsfkFJO0iZs=