-
Notifications
You must be signed in to change notification settings - Fork 33
/
dns-update-proxy
executable file
·366 lines (312 loc) · 14.5 KB
/
dns-update-proxy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
#!/usr/bin/env python3
import os, sys, logging, contextlib, asyncio, socket, signal
import stat, tempfile, re, string, base64, json, inspect, secrets
import unicodedata, pathlib as pl, subprocess as sp
import libnacl, libnacl.public
def err_fmt(err):
    '''Return "[ExceptionClassName] message" string for the given exception instance.'''
    return f'[{err.__class__.__name__}] {err}'
class LogMessage:
    '''Lazily-formatted log message using str.format() placeholders.

    Stores the format string with positional (a) and keyword (k) arguments,
    and only applies them when the object is converted to a string.'''

    def __init__(self, fmt, a, k):
        self.fmt = fmt
        self.a = a
        self.k = k

    def __str__(self):
        # Format string is returned verbatim when there is nothing to substitute
        if not (self.a or self.k): return self.fmt
        return self.fmt.format(*self.a, **self.k)
class LogStyleAdapter(logging.LoggerAdapter):
    '''LoggerAdapter that accepts str.format()-style ("{}") message templates.

    Messages are wrapped into LogMessage objects,
    so that formatting is only done when the record is actually rendered.'''

    def __init__(self, logger, extra=None):
        super().__init__(logger, extra or {})

    def log(self, level, msg, *args, **kws):
        '''Log msg at the specified level, with args/kws used as str.format() values.
        "exc_info" keyword is passed through to the underlying logger instead.'''
        if not self.isEnabledFor(level): return
        log_kws = dict()
        if 'exc_info' in kws: log_kws['exc_info'] = kws.pop('exc_info')
        msg, kws = self.process(msg, kws)
        # Positional args are kept in LogMessage, hence empty args-tuple for the logger
        self.logger._log(level, LogMessage(msg, args, kws), (), **log_kws)
def get_logger(name):
    '''Return LogStyleAdapter wrapping the stdlib logger with the specified name.'''
    return LogStyleAdapter(logging.getLogger(name))

# Module-level logger used by the main script code
log = get_logger('dup.main')
class adict(dict):
    '''Dict subclass with keys also accessible as instance attributes.'''

    def __init__(self, *args, **kws):
        dict.__init__(self, *args, **kws)
        # Instance attribute storage is the dict itself, so d.key is d['key']
        self.__dict__ = self
def b64_decode(s):
    '''Decode base64 string, using urlsafe alphabet if it has "-" or "_" characters.'''
    urlsafe = '-' in s or '_' in s
    decode = base64.urlsafe_b64decode if urlsafe else base64.standard_b64decode
    return decode(s)

def str_norm(v):
    '''Return whitespace-stripped, NFKC-normalized and casefolded string, for comparisons.'''
    return unicodedata.normalize('NFKC', v.strip()).casefold()
def str_part(s, sep, default=None):
    '''Split s in two parts on a separator, with a default for the missing part.

    Separator is specified with "<" or ">" marker, indicating split direction
    and which part is optional: "<@" splits on first "@" with default on the left,
    ":>" splits on last ":" with default on the right.
    Examples: str_part("user@host", "<@", "root"), str_part("host:port", ":>")'''
    delim = sep.strip('<>')
    found = delim in s
    if sep.strip(delim) == '<':
        return s.split(delim, 1) if found else (default, s)
    return s.rsplit(delim, 1) if found else (s, default)
def sockopt_resolve(prefix, v):
    '''Return name of a socket module constant with specified prefix and value.

    Prefix is case-insensitive and is stripped from the result,
    e.g. ("af_", socket.AF_INET) gives "INET". First matching name
    in dir(socket) order is used, and v is returned as-is if there is none.'''
    prefix = prefix.upper()
    names = (
        k[len(prefix):] for k in dir(socket)
        if k.startswith(prefix) and getattr(socket, k) == v )
    return next(names, v)
async def aio_await_wrap(res):
    '''Resolve passed value to a final result and return it.

    Non-awaitable callable is called first, then
    any resulting awaitable is awaited, anything else is returned as-is.'''
    if callable(res) and not inspect.isawaitable(res): res = res()
    if not inspect.isawaitable(res): return res
    return await res
async def aio_task_cancel(task_list):
    '''Cancel task(s) and wait for them to finish.

    task_list can be either a single awaitable or an iterable of such,
    where any false-y values (e.g. None) are skipped.
    All tasks are cancelled first, and then awaited one by one,
    with asyncio.CancelledError from these being suppressed.'''
    if inspect.isawaitable(task_list): task_list = [task_list]
    pending = [task for task in task_list if task]
    for task in pending:
        try: task.cancel()
        except asyncio.CancelledError: pass
    for task in pending:
        try: await task
        except asyncio.CancelledError: pass
def retries_within_timeout(
        tries, timeout, backoff_func=lambda e, n: ((e**n - 1) / e), slack=1e-2):
    '''Return list of delays to make exactly n tries within timeout, with backoff_func.

    Bisects over the backoff_func base value "e" within (0, timeout) range,
    until sum of backoff_func(e, n) for n in range(tries) is within slack of timeout.

    Arguments:
        tries: number of delay values to return.
        timeout: total time that sum of all returned delays should add up to.
        backoff_func: callable(e, n) returning n-th delay for the base value e.
        slack: max allowed absolute difference between sum of delays and timeout.

    Raises ValueError if no suitable base value can be found, e.g. when
    too few tries are requested for delays to ever add up to timeout,
    instead of looping forever in such case.'''
    lo, hi = 0, timeout
    while True:
        mid = (lo + hi) / 2
        delays = [backoff_func(mid, n) for n in range(tries)]
        error = sum(delays) - timeout
        if abs(error) < slack: return delays
        # Search interval collapsed to float precision without reaching the target,
        #  so there is no solution, and bisecting it further would never terminate
        if not lo < mid < hi:
            raise ValueError( 'Unable to fit {} tries'
                ' into {}s timeout with specified backoff function'.format(tries, timeout) )
        if error > 0: hi = mid
        else: lo = mid
@contextlib.contextmanager
def safe_replacement(path, *open_args, mode=None, **open_kws):
    '''Context manager to replace file at path via temporary file and rename.

    Yields temporary file object, created in the same directory as path
    and prefixed with its basename. If the managed block exits without error,
    file contents are flushed, synced to disk and the file is renamed over path.
    Temporary file is removed if it was not renamed.

    Arguments:
        path: file path to replace, str or path-like object.
        open_args, open_kws: passed to tempfile.NamedTemporaryFile,
            except for delete/dir/prefix, which are always set here.
            Temporary file is opened in text "w" mode if no open_args are specified.
        mode: permission bits to set on the new file.
            Default is to copy these from the existing path, if there is one.'''
    path = str(path)
    if mode is None:
        try: mode = stat.S_IMODE(os.lstat(path).st_mode)
        except OSError: pass # no existing file to copy mode from
    open_kws.update( delete=False,
        dir=os.path.dirname(path), prefix=os.path.basename(path)+'.' )
    if not open_args: open_kws['mode'] = 'w'
    with tempfile.NamedTemporaryFile(*open_args, **open_kws) as tmp:
        try:
            if mode is not None: os.fchmod(tmp.fileno(), mode)
            yield tmp
            if not tmp.closed:
                # Buffered data has to be flushed to the fd first, for fsync to cover it
                tmp.flush()
                os.fsync(tmp.fileno())
                tmp.close()
            os.rename(tmp.name, path)
        finally:
            # Cleanup for any error above - tmp file will be gone after successful rename
            try: os.unlink(tmp.name)
            except OSError: pass
class StacklessContext:
    '''Async context manager for tracking background tasks,
    which can finish at any point without leaving stack frames.
    All tasks that are still tracked get cancelled and awaited on context exit.'''

    def __init__(self, log=None):
        self.tasks = dict() # id(task) -> task, replaced by None on exit
        self.log = log

    async def __aenter__(self):
        return self

    async def __aexit__(self, *err):
        if not self.tasks: return
        # Dropping self.tasks here also stops finished tasks from modifying it
        running, self.tasks = self.tasks.values(), None
        await aio_task_cancel(running)

    async def close(self):
        await self.__aexit__(None, None, None)

    def add_task(self, coro, run_after=None):
        '''Start background task for coro and return it.

        Task removes itself from the tracked tasks when it is done.
        run_after is an optional callback to run after coro finishes without error.
        Cancellation of the task is suppressed, any other exception
        gets logged via the logger passed to constructor, if any.'''
        task_id = None

        async def _task_wrapper(coro=coro):
            try:
                await aio_await_wrap(coro)
                if run_after:
                    coro = run_after()
                    await aio_await_wrap(coro)
            except asyncio.CancelledError:
                pass
            except Exception as err:
                if self.log:
                    self.log.exception('Background task failed: {} - {}', coro, err_fmt(err))
            finally:
                # task_id gets assigned below, before this wrapper starts running
                assert task_id is not None, task_id
                if self.tasks: self.tasks.pop(task_id, None)

        task = asyncio.create_task(_task_wrapper())
        task_id = id(task)
        self.tasks[task_id] = task
        return task

    add = add_task
class DUPServerProtocol:
    '''asyncio datagram protocol that puts received packets into a queue.

    Each received datagram is stored in the "updates"
    queue as (data, addr) tuple, and None is put there when connection is lost.'''

    transport = None

    def __init__(self):
        self.updates = asyncio.Queue()
        self.log = get_logger('dup.udp')

    def connection_made(self, transport):
        self.log.debug('Connection made')
        self.transport = transport

    def datagram_received(self, data, addr):
        self.log.debug('Received {:,d}B from {!r}', len(data), addr)
        self.updates.put_nowait((data, addr))

    def error_received(self, err):
        # Part of the asyncio datagram protocol interface, called by transport
        #  on send/receive errors, so has to be here to avoid AttributeError in such case
        self.log.debug('Socket error: {}', err)

    def connection_lost(self, err):
        self.log.debug('Connection lost: {}', err or 'done')
        # None is an end-of-stream marker for the queue consumer
        self.updates.put_nowait(None)
class DUPServerError(Exception):
    '''Exception type for DNS update proxy server errors.'''
def parse_update_blob(box_key, pk_whitelist, req, req_addr):
    '''Authenticate and decrypt received update packet.

    Packet is expected to be client public key, followed by
    libnacl crypto_box ciphertext (with nonce) of a "domain:addr" string.

    Arguments:
        box_key: local secret key to use for decryption.
        pk_whitelist: collection of allowed client public keys, with empty one allowing any key.
        req: raw packet contents (bytes).
        req_addr: packet source address, not used here.

    Returns [domain, addr] list, or None if packet is
    too short, has non-whitelisted key, fails to decrypt or is malformed.
    Packets come from untrusted network sources, so none of these raise exceptions.'''
    n = libnacl.crypto_box_PUBLICKEYBYTES
    # Too short to have public key, nonce and any data after these
    if len(req) <= n + libnacl.crypto_box_NONCEBYTES: return
    pk, req = req[:n], req[n:]
    if pk_whitelist and not any(
        secrets.compare_digest(pk, pk_chk) for pk_chk in pk_whitelist ): return
    # ValueError is what libnacl is expected to raise for bad key/nonce sizes
    # NOTE(review): this also covers ValueError from a wrong-sized box_key - confirm
    try: update = libnacl.public.Box(box_key, pk).decrypt(req)
    except (libnacl.CryptError, ValueError): return
    # ValueError here is either UnicodeDecodeError or missing ":" separator
    try: domain, addr = update.decode().split(':', 1)
    except ValueError: return
    return [domain, addr]
async def dup_listen(box_key, pk_whitelist, sock_af, sock_p, host, port, api_conf):
    '''Listen for update packets on UDP socket and start DNS updates for valid ones.

    Runs until cancelled, which is also done on SIGINT/SIGTERM signals here.

    Arguments:
        box_key, pk_whitelist: keys to pass to parse_update_blob() for each packet.
        sock_af, sock_p: socket address family and protocol for the listening socket.
        host, port: local address to bind the socket to.
        api_conf: adict of update options. Keys used here are "gandi" (API key),
            "update_filter" (list of allowed domains) and "update_timeout"
            (min interval between updates for the same domain, in seconds).
            "http" key is set here to aiohttp session, if "gandi" is used.

    Raises DUPServerError if listening socket gets closed while waiting for packets.'''
    update_ts, loop, dup = dict(), asyncio.get_running_loop(), asyncio.current_task()
    for sig in 'INT TERM'.split():
        loop.add_signal_handler(getattr(signal, f'SIG{sig}'), dup.cancel)

    async with contextlib.AsyncExitStack() as ctx:
        transport, proto = await loop.create_datagram_endpoint(
            lambda: DUPServerProtocol(), local_addr=(host, port), family=sock_af, proto=sock_p )
        ctx.callback(transport.close)
        tasks = await ctx.enter_async_context(StacklessContext(log))

        if api_conf.get('gandi'):
            import aiohttp
            api_conf.http = await ctx.enter_async_context(aiohttp.ClientSession(
                headers = {'X-Api-Key': api_conf.gandi},
                timeout=aiohttp.ClientTimeout(connect=40.0, sock_connect=30.0) ))

        while True:
            update_req = await proto.updates.get()
            # None is queued by the protocol when socket is lost - nothing more to receive
            if update_req is None:
                raise DUPServerError('UDP socket was closed unexpectedly')
            req, req_addr = update_req

            update = parse_update_blob(box_key, pk_whitelist, req, req_addr)
            if not update or len(update) != 2:
                log.debug('Skipping update-blob [{}B] from {} - auth/crypto error', len(req), req_addr)
                continue
            domain, addr = update

            # Only specified domains or their subdomains are allowed, if there are any
            if api_conf.update_filter:
                for d in api_conf.update_filter:
                    if domain == d or domain.endswith(f'.{d}'): break
                else: continue

            # Rate-limiting of updates for the same domain
            ts, ts_chk = loop.time(), update_ts.get(domain)
            if ts_chk is not None and ts_chk + api_conf.update_timeout > ts: continue
            update_ts[domain] = ts

            tasks.add(dns_update(api_conf, domain, addr))
async def dns_update(api_conf, domain, addr):
    '''Update DNS record for domain to addr, using all configured update methods.

    If "check" option is set in api_conf, domain is resolved via getaddrinfo() first,
    using address family matching addr (IPv6 if it has ":", IPv4 otherwise),
    and nothing is done if any of the resolved addresses is the same as addr.
    Name resolution errors there are ignored, same as if nothing was resolved.

    Update methods are local zone file ("zone_file" key in api_conf)
    and gandi.net API ("gandi" key), each one is used if its key is set.'''
    if api_conf.check:
        # Lookup has to use same address type as addr, to compare it to anything
        family = socket.AF_INET6 if ':' in addr else socket.AF_INET
        try:
            dns_addr_list = await asyncio.get_running_loop()\
                .getaddrinfo(domain, None, family=family)
        except OSError: dns_addr_list = list() # includes socket.gaierror
        if any(gai_tuple[4][0] == addr for gai_tuple in dns_addr_list): return
    if api_conf.get('zone_file'): await dns_update_local(api_conf, domain, addr)
    if api_conf.get('gandi'): await dns_update_gandi(api_conf, domain, addr)
async def dns_update_gandi(api_conf, domain, addr):
    '''Set A/AAAA record for domain to addr via gandi.net API.

    Nothing is done if domain is not a subdomain of api_conf.gandi_zone.
    Record is sent via HTTP PUT request using api_conf.http session,
    which is repeated with api_conf.delays intervals until 200/201 response status.
    Record type is AAAA if addr has ":" in it, and A otherwise.'''
    suffix = f'.{api_conf.gandi_zone}'
    loop = asyncio.get_running_loop()
    if not domain.endswith(suffix): return

    rec = domain[:-len(suffix)]
    rec_t = 'AAAA' if ':' in addr else 'A'
    rec_url = ( 'https://dns.api.gandi.net/api/'
        f'v5/domains/{api_conf.gandi_zone}/records/{rec}/{rec_t}' )
    rec_addr = dict(rrset_values=[addr], rrset_ttl=10 * 60)

    n = 0
    # Extra zero-delay at the end is for the last attempt, with no sleep after it
    for delay in api_conf.delays + [0]:
        n += 1
        deadline = loop.time() + delay
        log.debug('[{}] Gandi zone update (attempt={}): {} {}', rec, n, rec_url, rec_addr)
        async with await api_conf.http.put(rec_url, json=rec_addr) as res:
            if res.status in [200, 201]:
                log.debug('[{}] Gandi zone update success (attempt={})', rec, n)
                return
            res_body = await res.read()
        log_func = log.error if delay else log.info
        log_func(
            '[{}] Failed to set record {!r} to {}={!r} (attempt={}): {} {} - {}',
            rec, rec, rec_t, addr, n, res.status, res.reason, res_body )
        # Delay is counted from the start of the attempt, not from its failure
        if delay: await asyncio.sleep(max(0, deadline - loop.time()))
async def dns_update_local(api_conf, domain, addr):
    '''Set A/AAAA record for domain to addr in a local zone file.

    Nothing is done if domain is not a subdomain of api_conf.zone_name.
    Zone file (api_conf.zone_file path) is checked for "<name> IN <type> <addr>" lines,
    where first one is replaced to have new addr, and any following ones are removed,
    or new line is added at the end if there are none.
    File is not changed if any of these lines has the same addr already.
    Record type is AAAA if addr has ":" in it, and A otherwise.

    After zone file is replaced, api_conf.zone_cmd command is run, if specified,
    with an error logged if it exits with non-zero code.'''
    suffix = f'.{api_conf.zone_name}'
    if not domain.endswith(suffix): return
    rec, rec_t = domain[:-len(suffix)], 'AAAA' if ':' in addr else 'A'

    # replaced: None - no record lines found, True - line was replaced, False - addr is there already
    zone, replaced = list(), None
    line_re = re.compile(
        r'(?i)^(\s*' + re.escape(rec)
        + r'\s+IN\s+' + rec_t + r'\s+)(\S+)\s*$' )
    for line in api_conf.zone_file.read_text().splitlines():
        m = line_re.search(line)
        if m:
            if str_norm(m.group(2)) == str_norm(addr): replaced = False
            if replaced is None: line, replaced = m.group(1) + addr, True
            elif replaced is True: line = None # extra line for the same record
        if line is not None: zone.append(line.rstrip())
    log.debug('[{}] Zone line replace-status: {!r}', rec, replaced)
    if replaced is False: return # no change needed

    if replaced is None:
        while zone and not zone[-1].strip(): zone.pop()
        zone.append(f'{rec} IN {rec_t} {addr}')
    # Empty last element is for the newline at the end of file
    if zone and zone[-1] != '': zone.append('')
    with safe_replacement(api_conf.zone_file) as dst: dst.write('\n'.join(zone))
    log.debug('[{}] Zone file updated ({} lines)', rec, len(zone))

    if api_conf.zone_cmd:
        log.debug('[{}] Running zone-update cmd: {}', rec, ' '.join(api_conf.zone_cmd))
        proc = sp.run(api_conf.zone_cmd)
        # Failure here likely means that updated zone file is not used, so should not be silent
        if proc.returncode:
            log.error('[{}] Zone-update cmd failed (exit code={})', rec, proc.returncode)
def main(args=None):
    '''Script entry point - parse and check command-line options, run UDP listener.

    args is a list of command-line arguments to use, with sys.argv[1:] as default.
    Any errors in options are reported via parser.error(), which exits the script.
    Returns after dup_listen() loop finishes.'''
    import argparse
    parser = argparse.ArgumentParser(
        description='Script to update DNS information via received UDP packets.')

    group = parser.add_argument_group('Update checks')
    group.add_argument('-d', '--update', metavar='domain', action='append',
        help='Domain(s) to update subdomains in. Requests to update other ones are ignored.')
    group.add_argument('-c', '--check', action='store_true',
        help='Check if received domain name already resolves to IP before updating it.'
            ' Any IP in getaddrinfo() counts, any system-wide options for that apply.')

    group = parser.add_argument_group('Bind/auth options')
    group.add_argument('bind', nargs='?', default='::',
        help='Host or address (to be resolved via gai) to listen on.'
            ' Can include port, which will override -p/--port option, if specified here.'
            ' Default is to use "::" wildcard IPv4/IPv6 binding.')
    group.add_argument('-p', '--port',
        default=1234, type=int, metavar='port',
        help='Local UDP port to listen on for info from clients (default: %(default)s).'
            ' Can also be specified in the "bind" argument, which overrides this option.')
    group.add_argument('-k', '--key-file', required=True, metavar='path',
        help='Path to file with base64-encoded ed25519 key to use for incoming packets.')
    group.add_argument('-a', '--auth-key', metavar='key/path',
        help='Space-separated list base64-encoded client public keys'
            ' or path to a file with such, if it starts with "." or "/".'
            ' If specified, only these clients will be authorized for name-update requests.')

    group = parser.add_argument_group('DNS update - local zone file')
    group.add_argument('--zone-file', metavar='path',
        help='Path to BIND-style zone file in which'
            ' to lookup/add/update corresponding A record.')
    group.add_argument('--zone-name', metavar='domain',
        help='Zone name/suffix to assume/strip for names in --zone-file.')
    group.add_argument('--zone-update-cmd', metavar='cmd',
        help='Command to run after --zone-file is updated, split on spaces.')

    group = parser.add_argument_group('DNS update - gandi.net')
    group.add_argument('--gandi-api-key-file', metavar='path',
        help='Path to gandi.net LiveDNS API key (X-Api-Key header).')
    group.add_argument('--gandi-zone', metavar='domain',
        help='Top-level domain to operate on, corresponding to zone name.')

    group = parser.add_argument_group('Misc options')
    group.add_argument('-t', '--update-timeout', type=float, metavar='s', default=1800,
        help='Min delay before checking/updating same domain again. Default: %(default)ss')
    group.add_argument('-r', '--retry', metavar='n:sec', default='5:600',
        help='DNS update retry options, in "attempts:timespan" format. Default: %(default)s')
    group.add_argument('--debug', action='store_true', help='Verbose operation mode.')

    opts = parser.parse_args(sys.argv[1:] if args is None else args)

    logging.basicConfig(level=logging.DEBUG if opts.debug else logging.WARNING)

    # Local key and optional list of allowed client keys, either inline or from a file
    box_key = b64_decode(pl.Path(opts.key_file).read_text().strip())
    pk_whitelist = (opts.auth_key or '').strip() or set()
    if pk_whitelist:
        if pk_whitelist[0] in './': pk_whitelist = pl.Path(pk_whitelist).read_text()
        pk_whitelist = set(map(b64_decode, pk_whitelist.split()))

    # Bad -r/--retry values are reported as usage errors instead of a traceback
    try:
        retry_n, retry_timeout = opts.retry.split(':', 1)
        delays = retries_within_timeout(int(retry_n), float(retry_timeout))
    except ValueError as err:
        parser.error(f'Invalid -r/--retry value {opts.retry!r}: {err_fmt(err)}')
    api_conf = adict( check=opts.check, delays=delays,
        update_timeout=opts.update_timeout,
        update_filter=list(d.strip('.') for d in (opts.update or list())) )

    if not (opts.gandi_api_key_file or opts.zone_file):
        parser.error('No DNS update APIs enabled, exiting')
    if opts.gandi_api_key_file:
        if not opts.gandi_zone:
            parser.error('gandi: --gandi-zone must be specified')
        api_conf.update( gandi_zone=opts.gandi_zone,
            gandi=pl.Path(opts.gandi_api_key_file).read_text().strip() )
    if opts.zone_file:
        # Without zone name, no domain can match zone suffix, so all updates would be dropped
        if not opts.zone_name:
            parser.error('zone-file: --zone-name must be specified')
        zone_file = pl.Path(opts.zone_file)
        if not zone_file.exists():
            parser.error(f'Specified --zone-file does not exist: {str(zone_file)!r}')
        zone_cmd = opts.zone_update_cmd
        if zone_cmd: zone_cmd = zone_cmd.split()
        api_conf.update( zone_file=zone_file,
            zone_name=opts.zone_name, zone_cmd=zone_cmd )

    # Bind spec can be host, host:port, IPv6 address or [IPv6]:port
    host, port, family = opts.bind, opts.port, 0
    if host.count(':') > 1: host, port = str_part(host, ']:>', port)
    else: host, port = str_part(host, ':>', port)
    if '[' in host: family = socket.AF_INET6
    host, port = host.strip('[]'), int(port)
    try:
        addrinfo = socket.getaddrinfo( host, str(port),
            family=family, type=socket.SOCK_DGRAM, proto=socket.IPPROTO_UDP )
        if not addrinfo: raise socket.gaierror(f'No addrinfo for host: {host}')
    except (socket.gaierror, socket.error) as err:
        parser.error( 'Failed to resolve socket parameters'
            ' via getaddrinfo: {!r} - {}'.format((host, port), err_fmt(err)) )
    # First getaddrinfo result is the one to use
    sock_af, sock_t, sock_p, _, sock_addr = addrinfo[0]
    log.debug(
        'Resolved host:port {!r}:{!r} to endpoint: {} (family: {}, type: {}, proto: {})',
        host, port, sock_addr,
        *(sockopt_resolve(pre, n) for pre, n in [
            ('af_', sock_af), ('sock_', sock_t), ('ipproto_', sock_p) ]) )
    host, port = sock_addr[:2]

    asyncio.run(dup_listen( box_key,
        pk_whitelist, sock_af, sock_p, host, port, api_conf ))
# Run main() and use its return value as exit code, when started as a script
if __name__ == '__main__':
    sys.exit(main())