-
Notifications
You must be signed in to change notification settings - Fork 2
/
waiter.go
115 lines (96 loc) · 2.18 KB
/
waiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package async
import (
"context"
)
// Waiter collects tasks that each produce a value of type T and then waits
// on them in one of several ways.
type Waiter[T any] interface {
	// Add registers task so that a later wait call runs it.
	Add(task Task[T])
	// Wait waits for all tasks to complete. It returns the values of the
	// tasks that succeeded, the errors of the tasks that failed, and an
	// overall error.
	Wait(ctx context.Context) ([]T, []error, error)
	// WaitAny waits for any single task to complete without error; the
	// remaining tasks may be cancelled.
	WaitAny(ctx context.Context) (T, []error, error)
	// WaitN waits for n tasks to complete without error.
	WaitN(ctx context.Context, n int) ([]T, []error, error)
}
// waiter is the Waiter implementation in this package. It only stores the
// registered tasks; they are started by the wait methods.
type waiter[T any] struct {
	tasks []Task[T] // registered tasks, in the order Add received them
}

// Add appends task to the registered tasks. Tasks are never removed, so every
// subsequent wait call runs all tasks added so far.
func (a *waiter[T]) Add(task Task[T]) {
	a.tasks = append(a.tasks, task)
}
// Wait runs every registered task concurrently, passing each of them ctx, and
// blocks until all results are collected or ctx is done.
//
// It returns the values of the tasks that succeeded and the errors of the
// tasks that failed, each in the order the results were received. The final
// error is:
//   - nil when every task succeeded (including when no task was added),
//   - ctx.Err() when ctx is done before all results were collected,
//   - ErrTooLessDone when at least one task failed.
func (a *waiter[T]) Wait(ctx context.Context) ([]T, []error, error) {
	tasks := a.tasks
	// Buffered so every task goroutine can deliver its result and exit even
	// after Wait has returned early on ctx.Done(); with an unbuffered channel
	// those goroutines would block on send forever.
	wait := make(chan Result[T], len(tasks))
	for _, task := range tasks {
		go func(task func(context.Context) (T, error)) {
			r, err := task(ctx)
			wait <- Result[T]{
				Data:  r,
				Error: err,
			}
		}(task)
	}

	var taskErrs []error
	var items []T
	for range tasks {
		select {
		case r := <-wait:
			if r.Error != nil {
				taskErrs = append(taskErrs, r.Error)
				continue
			}
			items = append(items, r.Data)
		case <-ctx.Done():
			return items, taskErrs, ctx.Err()
		}
	}
	// Each result lands in exactly one of items/taskErrs, so "no errors" is
	// the same as "every task succeeded".
	if len(taskErrs) == 0 {
		return items, taskErrs, nil
	}
	return items, taskErrs, ErrTooLessDone
}
// WaitN runs every registered task concurrently and returns as soon as n of
// them have succeeded.
//
// Tasks receive a context derived from ctx that is cancelled when WaitN
// returns, so tasks that observe it can stop once their result is no longer
// needed. The returned values and errors are in the order the results were
// received. The final error is:
//   - nil once n tasks succeeded (n <= 0 is trivially satisfied: WaitN then
//     returns immediately without starting any task),
//   - ctx.Err() when ctx is done before n successes were collected,
//   - ErrTooLessDone when all tasks finished with fewer than n successes.
func (a *waiter[T]) WaitN(ctx context.Context, n int) ([]T, []error, error) {
	if n <= 0 {
		return nil, nil, nil
	}
	tasks := a.tasks
	cancelCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	// Buffered so tasks that finish after WaitN has returned can still send
	// their result and exit instead of blocking forever.
	wait := make(chan Result[T], len(tasks))
	for _, task := range tasks {
		go func(task func(context.Context) (T, error)) {
			r, err := task(cancelCtx)
			wait <- Result[T]{
				Data:  r,
				Error: err,
			}
		}(task)
	}

	var taskErrs []error
	var items []T
	for range tasks {
		select {
		case r := <-wait:
			if r.Error != nil {
				taskErrs = append(taskErrs, r.Error)
				continue
			}
			items = append(items, r.Data)
			if len(items) == n {
				return items, taskErrs, nil
			}
		case <-ctx.Done():
			return items, taskErrs, ctx.Err()
		}
	}
	return items, taskErrs, ErrTooLessDone
}
// WaitAny returns the value of the first task to succeed by delegating to
// WaitN with n == 1. When WaitN does not yield exactly one value (no task
// succeeded, or ctx was done first), the zero value of T is returned. The
// task errors and the final error are whatever WaitN reported.
func (a *waiter[T]) WaitAny(ctx context.Context) (T, []error, error) {
	values, taskErrs, err := a.WaitN(ctx, 1)
	var first T
	if len(values) == 1 {
		first = values[0]
	}
	return first, taskErrs, err
}