-
Notifications
You must be signed in to change notification settings - Fork 2
/
disco.go
175 lines (147 loc) · 3.83 KB
/
disco.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
package disco
import (
"context"
"errors"
"net"
"sync"
"time"
"github.com/forestgiant/disco/multicast"
"github.com/forestgiant/disco/node"
)
// Disco keeps track of nodes discovered over multicast. Its zero value is
// usable: Discover lazily creates the state it needs.
type Disco struct {
	// mu guards members.
	mu sync.Mutex

	// members holds every node that is currently registered.
	members []*node.Node

	// discoveredChan is created lazily by Discover.
	// NOTE(review): nothing in this file reads from or writes to it; it is
	// presumably fed by node.Serve() elsewhere — confirm.
	discoveredChan chan *node.Node
}
// Members returns a snapshot of the nodes that are currently registered.
//
// The result is a copy taken under d.mu. Callers may keep, iterate or modify
// it without racing against the discovery goroutines, which add and remove
// entries of the internal slice concurrently. A nil internal slice yields nil.
func (d *Disco) Members() []*node.Node {
	d.mu.Lock()
	defer d.mu.Unlock()

	if d.members == nil {
		return nil
	}
	snapshot := make([]*node.Node, len(d.members))
	copy(snapshot, d.members)
	return snapshot
}
// Discover validates multicastAddress, starts listening on it and returns a
// channel on which discovery events are reported.
//
// multicastAddress must be a "host:port" string whose host is a multicast IP
// literal; otherwise an error is returned. Errors from net.SplitHostPort and
// multicast.Listen are returned unchanged.
//
// A background goroutine decodes every received payload with node.DecodeNode
// (payloads that fail to decode are skipped) and then:
//   - if the node's Action is node.RegisterAction, records the sender in
//     SrcIP and either calls KeepRegistered on the matching existing member
//     or registers the node as a new member (see register);
//   - otherwise removes the matching member via deregister and, if one was
//     removed, sends the node on the returned channel.
//
// The goroutine exits when ctx is cancelled or the listener's channel is
// closed. The returned channel is never closed.
func (d *Disco) Discover(ctx context.Context, multicastAddress string) (<-chan *node.Node, error) {
	if multicastAddress == "" {
		return nil, errors.New("Address is blank")
	}
	ip, _, err := net.SplitHostPort(multicastAddress)
	if err != nil {
		return nil, err
	}
	if !net.ParseIP(ip).IsMulticast() {
		return nil, errors.New("multicastAddress is not valid")
	}

	if d.discoveredChan == nil {
		d.discoveredChan = make(chan *node.Node)
	}

	m := &multicast.Multicast{Address: multicastAddress}
	respChan, err := m.Listen(ctx)
	if err != nil {
		return nil, err
	}

	d.mu.Lock()
	if d.members == nil {
		d.members = []*node.Node{}
	}
	d.mu.Unlock()

	results := make(chan *node.Node)
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case resp, ok := <-respChan:
				if !ok {
					// The listener closed its channel. Without this check
					// the loop would spin forever on zero-value receives.
					return
				}
				rn, err := node.DecodeNode(resp.Payload)
				if err != nil {
					continue // not a node announcement; ignore it
				}

				if rn.Action == node.RegisterAction {
					rn.SrcIP = resp.SrcIP // set the source address

					// Look up and refresh an existing member in a single
					// critical section. An index computed under one lock
					// and used under a later one can point past the end
					// (or at another node) if a heartbeat timeout removed
					// members in between.
					known := false
					d.mu.Lock()
					for _, member := range d.members {
						if member.Equal(rn) {
							member.KeepRegistered()
							known = true
							break
						}
					}
					d.mu.Unlock()

					if !known {
						d.register(results, rn)
					}
					continue
				}

				// Explicit deregistration: report it only if the node was
				// actually a member. Give up on the send once ctx is done
				// so the goroutine cannot block forever on an abandoned
				// results channel.
				if d.deregister(rn) {
					select {
					case results <- rn:
					case <-ctx.Done():
						return
					}
				}
			}
		}
	}()
	return results, nil
}
// register adds rn to d.members as a newly discovered node, sends it on
// results, and then starts a watchdog goroutine for it.
//
// The watchdog waits up to rn.SendInterval*3 for a value on rn.RegisterCh().
// Each value received restarts that wait with a fresh timer. If the wait
// expires, the watchdog sets rn.Action to node.DeregisterAction, removes rn via
// deregister, sends rn on results and exits.
//
// NOTE(review): the sends on results block until a receiver is ready, and the
// watchdog has no cancellation signal. If rn is deregistered explicitly (which
// stops its timer), the watchdog stays blocked in its select unless
// RegisterCh() delivers again.
func (d *Disco) register(results chan *node.Node, rn *node.Node) {
	rn.Action = node.RegisterAction

	// Arm the first heartbeat before rn becomes visible in d.members.
	// deregister calls Heartbeat.Stop() on the member it removes, so a member
	// must never be published while its Heartbeat is still unset.
	rn.Mutex.Lock()
	rn.Heartbeat = time.NewTimer(rn.SendInterval * 3)
	rn.Mutex.Unlock()

	d.mu.Lock()
	d.members = append(d.members, rn)
	d.mu.Unlock()

	// Report the registration before the watchdog exists. A timeout can then
	// neither overwrite rn.Action nor emit the deregistration ahead of it.
	results <- rn

	go func() {
		for {
			rn.Mutex.Lock()
			heartbeat := rn.Heartbeat
			rn.Mutex.Unlock()

			select {
			case <-rn.RegisterCh():
				// Keep-alive received: replace the timer with a fresh one.
				heartbeat.Stop()
				rn.Mutex.Lock()
				rn.Heartbeat = time.NewTimer(rn.SendInterval * 3)
				rn.Mutex.Unlock()
			case <-heartbeat.C:
				// Deregister if it times out.
				rn.Action = node.DeregisterAction
				d.deregister(rn)
				results <- rn
				return
			}
		}
	}()
}
// deregister removes the member equal to n (per node.Node.Equal) from
// d.members, searching from the end of the slice, and stops that member's
// heartbeat timer. It reports whether a matching member was found and removed.
func (d *Disco) deregister(n *node.Node) bool {
	d.mu.Lock()
	defer d.mu.Unlock()

	for i := len(d.members) - 1; i >= 0; i-- {
		m := d.members[i]
		if !m.Equal(n) {
			continue
		}

		// Heartbeat may still be nil if this member's timer has not been
		// created yet; calling Stop on a nil *time.Timer would panic.
		m.Mutex.Lock()
		if m.Heartbeat != nil {
			m.Heartbeat.Stop()
		}
		m.Mutex.Unlock()

		// Copy-on-write removal. The full slice expression caps the prefix
		// at i, so append must allocate a new array instead of shifting
		// elements in place. Any slice previously handed out that still
		// shares the old array is left untouched.
		d.members = append(d.members[:i:i], d.members[i+1:]...)
		return true
	}
	return false
}
// addToMembers reports whether n is NOT yet registered, i.e. whether it should
// be added as a new member. Despite its name it does not modify d.members.
//
// The scan runs under d.mu so it cannot race with concurrent modifications of
// d.members.
func (d *Disco) addToMembers(n *node.Node) bool {
	d.mu.Lock()
	defer d.mu.Unlock()

	for _, m := range d.members {
		if m.Equal(n) {
			return false // node is already a member
		}
	}
	return true
}
// indexOfMember returns the position of the first member equal to n
// (per node.Node.Equal), or -1 if n is not registered.
//
// The scan runs under d.mu so it cannot race with concurrent modifications of
// d.members. The result is only a snapshot: once the lock is released, other
// goroutines may remove members and invalidate the index. Callers must
// re-validate it before indexing d.members.
func (d *Disco) indexOfMember(n *node.Node) int {
	d.mu.Lock()
	defer d.mu.Unlock()

	for i, m := range d.members {
		if m.Equal(n) {
			return i
		}
	}
	return -1
}