-
Notifications
You must be signed in to change notification settings - Fork 73
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
641 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import hmac | ||
import logging | ||
import threading | ||
from collections import deque | ||
from functools import partial | ||
from threading import RLock | ||
|
||
from django.conf import settings | ||
|
||
from judge.balancer.bridge_handler import BridgeHandler | ||
from judge.balancer.judge_handler import JudgeHandler | ||
from judge.bridge.server import Server | ||
|
||
|
||
logger = logging.getLogger('judge.balancer') | ||
|
||
|
||
class JudgeBalancer:
    """Distribute submissions received from bridges among connected judges.

    The balancer serves judges on ``settings.BALANCER_JUDGE_ADDRESS`` and keeps
    one ``BridgeHandler`` per entry of ``config['bridges']`` (the handler's
    ``bridge_id`` is its index in that list).  Submissions queued by a bridge
    are handed to a judge whose ``working`` attribute is falsy.  While a pair is
    active, ``judge_to_bridge`` (keyed by judge name) and ``bridge_to_judge``
    (keyed by bridge index) record it so packets can be routed both ways.

    The judge set, the queue and both pairing maps are only mutated while
    holding ``self.lock``; the lock is re-entrant so locked methods may call
    each other.
    """

    def __init__(self, config):
        """Create the judge server and one bridge handler per configured bridge.

        :param config: mapping with at least a ``'bridges'`` list (keyword
            arguments for ``BridgeHandler``) and a ``'judges'`` list (dicts
            with ``'id'`` and ``'key'``).
        """
        self.executors = {}
        self.config = config
        self.judges = set()
        self.queue = deque()
        self.lock = RLock()
        self.judge_to_bridge = {}
        self.bridge_to_judge = {}

        self.judge_server = Server(
            settings.BALANCER_JUDGE_ADDRESS,
            partial(JudgeHandler, balancer=self),
        )

        # Handlers are appended one by one (not built in a comprehension) so
        # that ``self.bridges`` already exists while each handler is created.
        self.bridges = []
        for bridge_id, bridge in enumerate(config['bridges']):
            self.bridges.append(BridgeHandler(balancer=self, bridge_id=bridge_id, **bridge))

    def run(self):
        """Start serving judges on a background thread and start every bridge reader."""
        threading.Thread(target=self.judge_server.serve_forever).start()
        for bridge in self.bridges:
            bridge.listen()

    def shutdown(self):
        """Stop the judge server and shut down every bridge handler."""
        self.judge_server.shutdown()
        for bridge in self.bridges:
            bridge.shutdown()

    def get_paired_bridge(self, judge_name):
        """Return the bridge index paired with ``judge_name``, or ``None``."""
        return self.judge_to_bridge.get(judge_name)

    def reset_bridge(self, bridge_id):
        """Forget the pairing of ``bridge_id``, if it has one."""
        with self.lock:
            judge = self.bridge_to_judge.pop(bridge_id, None)
            if judge is not None:
                # Tolerate a half-removed pairing instead of raising KeyError.
                self.judge_to_bridge.pop(judge.name, None)

    def _try_judge(self):
        """Hand queued submissions to idle judges until either runs out."""
        with self.lock:
            available = [judge for judge in self.judges if not judge.working]
            while available and self.queue:
                judge = available.pop()
                bridge_id, packet = self.queue.popleft()
                self.judge_to_bridge[judge.name] = bridge_id
                self.bridge_to_judge[bridge_id] = judge

                packet['storage-namespace'] = self.config['bridges'][bridge_id].get('storage_namespace')
                judge.submit(packet)

    def free_judge(self, judge):
        """Release ``judge`` from its bridge and try to dispatch more work.

        A judge whose pairing was already dropped (for example by
        ``reset_bridge``) is accepted silently.
        """
        with self.lock:
            bridge_id = self.judge_to_bridge.pop(judge.name, None)
            if bridge_id is not None:
                self.bridge_to_judge.pop(bridge_id, None)

        self._try_judge()

    @staticmethod
    def _key_bytes(value):
        """Return ``value`` as bytes for constant-time comparison, or ``None``.

        ``hmac.compare_digest`` only accepts ASCII-only ``str`` or bytes-like
        objects, so text is encoded first.  ``surrogatepass`` keeps unusual
        text (e.g. lone surrogates) from raising during encoding.
        """
        if isinstance(value, str):
            return value.encode('utf-8', 'surrogatepass')
        if isinstance(value, (bytes, bytearray)):
            return bytes(value)
        return None

    def authenticate_judge(self, judge_id, key, client_address):
        """Check ``key`` against the configured key of judge ``judge_id``.

        :param judge_id: identifier looked up in ``config['judges']``.
        :param key: key presented by the connecting judge.
        :param client_address: only used in the failure log message.
        :return: ``True`` when the judge is known and the key matches,
            ``False`` otherwise (unknown judge, missing/ill-typed key, or
            mismatch).
        """
        judge_config = next((judge for judge in self.config['judges'] if judge['id'] == judge_id), None)
        if judge_config is None:
            return False

        expected = self._key_bytes(judge_config.get('key'))
        supplied = self._key_bytes(key)
        if expected is None or supplied is None or not hmac.compare_digest(expected, supplied):
            logger.warning('Judge authentication failure: %s', client_address)
            return False

        return True

    def register_judge(self, judge):
        """Add ``judge`` to the pool, then try to dispatch queued work."""
        with self.lock:
            # Disconnect all judges with the same name, see <https://github.com/DMOJ/online-judge/issues/828>
            # ``disconnect`` matches on the name, so pass the name rather than
            # the handler object (which would never compare equal to a name).
            self.disconnect(judge.name, force=True)
            self.judges.add(judge)
            self._try_judge()

    def disconnect(self, judge_id, force=False):
        """Ask every registered judge named ``judge_id`` to disconnect."""
        with self.lock:
            # Iterate over a snapshot: a disconnecting judge may remove itself
            # from ``self.judges`` while we are still looping.
            for judge in list(self.judges):
                if judge.name == judge_id:
                    judge.disconnect(force=force)

    def remove_judge(self, judge):
        """Drop ``judge`` from the pool together with any pairing it holds."""
        with self.lock:
            bridge_id = self.judge_to_bridge.pop(judge.name, None)
            if bridge_id is not None:
                self.bridge_to_judge.pop(bridge_id, None)
            self.judges.discard(judge)

    def set_runtime_versions(self, executors):
        """Store ``executors`` and send them to every bridge."""
        self.executors = executors
        for bridge in self.bridges:
            bridge.executors_packet(executors)

    def get_runtime_versions(self):
        """Return the executors last stored by ``set_runtime_versions``."""
        return self.executors

    def queue_submission(self, bridge_id: int, packet: dict):
        """Queue ``packet`` on behalf of ``bridge_id`` and try to dispatch it."""
        with self.lock:
            self.queue.append((bridge_id, packet))
            self._try_judge()

    def abort_submission(self, bridge_id):
        """Abort the judge paired with ``bridge_id``; no-op when unpaired."""
        judge = self.bridge_to_judge.get(bridge_id)
        if judge is None:
            return
        # Called outside any ``except KeyError`` so that an unrelated KeyError
        # raised by the judge is not silently swallowed.
        judge.abort()

    def forward_packet_to_bridge(self, judge_name, packet: dict):
        """Send ``packet`` to the bridge paired with ``judge_name``; no-op when unpaired."""
        bridge_id = self.judge_to_bridge.get(judge_name)
        if bridge_id is None:
            return
        self.bridges[bridge_id].send_packet(packet)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
import errno | ||
import json | ||
import logging | ||
import socket | ||
import ssl | ||
import struct | ||
import threading | ||
import time | ||
import zlib | ||
from typing import Optional | ||
|
||
from judge.balancer import sysinfo | ||
|
||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class JudgeAuthenticationFailed(Exception):
    """Raised when the handshake response is unreadable or is not a success."""
|
||
|
||
class BridgeHandler:
    """Client connection from the balancer to one bridge.

    The handler opens a TCP (optionally TLS) connection to ``host:port``,
    performs a handshake with ``id``/``key`` and then exchanges packets.  Each
    packet on the wire is a 4-byte big-endian length (``SIZE_PACK``) followed
    by that many bytes of zlib-compressed, UTF-8 encoded JSON.

    Connection failures are retried with a growing delay (``fallback``)
    until ``shutdown`` is requested.
    """

    SIZE_PACK = struct.Struct('!I')

    # Bounds, in seconds, of the delay between reconnection attempts.
    INITIAL_FALLBACK = 4
    MAX_FALLBACK = 60

    ssl_context: Optional[ssl.SSLContext]

    def __init__(
        self,
        host: str,
        port: int,
        id: str,
        key: str,
        balancer,
        bridge_id: int,
        secure: bool = False,
        no_cert_check: bool = False,
        cert_store: Optional[str] = None,
        **kwargs,
    ):
        """Configure the handler and establish the first connection.

        :param host: bridge host name or address.
        :param port: bridge port.
        :param id: name sent in the handshake.
        :param key: key sent in the handshake.
        :param balancer: owning balancer; used for runtime versions and to
            queue/abort submissions.
        :param bridge_id: identifier passed back to the balancer.
        :param secure: wrap the connection in TLS.
        :param no_cert_check: with ``secure``, skip certificate verification.
        :param cert_store: with ``secure``, CA file to verify against instead
            of the default certificates.
        :param kwargs: extra configuration keys; ignored here.
        """
        self.host = host
        self.port = port
        self.balancer = balancer
        self.name = id
        self.key = key
        self.bridge_id = bridge_id

        # Set up everything ``shutdown``/``__del__`` touch before any step that
        # can raise, so a half-constructed instance is still safe to dispose.
        self._closed = False
        self.conn = None
        self.shutdown_requested = False
        self._lock = threading.RLock()

        log.info('Preparing to connect to [%s]:%s as: %s', host, port, id)
        if secure:
            # NOTE(review): PROTOCOL_SSLv23 is a deprecated alias of
            # PROTOCOL_TLS; kept as-is to preserve the negotiated settings.
            self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            self.ssl_context.options |= ssl.OP_NO_SSLv2
            self.ssl_context.options |= ssl.OP_NO_SSLv3

            if not no_cert_check:
                self.ssl_context.verify_mode = ssl.CERT_REQUIRED
                self.ssl_context.check_hostname = True

            if cert_store is None:
                self.ssl_context.load_default_certs()
            else:
                self.ssl_context.load_verify_locations(cafile=cert_store)
            log.info('Configured to use TLS.')
        else:
            self.ssl_context = None
            log.info('TLS not enabled.')

        self.secure = secure
        self.no_cert_check = no_cert_check
        self.cert_store = cert_store

        # Delay before the next reconnection attempt; grows by 1.5x per
        # attempt up to MAX_FALLBACK and is reset after a successful connect.
        self.fallback = self.INITIAL_FALLBACK

        self._do_reconnect()

    def _connect(self):
        """Open the socket, optionally start TLS, and perform the handshake.

        :raises OSError: when connecting or sending the handshake fails.
        :raises JudgeAuthenticationFailed: when the handshake is rejected or
            its response cannot be read.
        """
        problems = []  # should be handled by bridged's monitor
        versions = self.balancer.get_runtime_versions()

        log.info('Opening connection to: [%s]:%s', self.host, self.port)

        while True:
            try:
                self.conn = socket.create_connection((self.host, self.port), timeout=5)
            except OSError as e:
                # Retry only when the call was interrupted by a signal.
                if e.errno != errno.EINTR:
                    raise
            else:
                break

        # A fresh socket is open again, so ``_close`` must act on it.
        self._closed = False

        self.conn.settimeout(300)
        self.conn.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

        if self.ssl_context:
            log.info('Starting TLS on: [%s]:%s', self.host, self.port)
            self.conn = self.ssl_context.wrap_socket(self.conn, server_hostname=self.host)

        log.info('Starting handshake with: [%s]:%s', self.host, self.port)
        self.input = self.conn.makefile('rb')
        self.handshake(problems, versions, self.name, self.key)
        log.info('Judge "%s" online: [%s]:%s', self.name, self.host, self.port)

        # Connected and authenticated: the next outage starts from a short delay.
        self.fallback = self.INITIAL_FALLBACK

    def _wait_before_retry(self) -> bool:
        """Drop the old connection and sleep for the current backoff delay.

        :return: ``False`` when shutdown was requested (nothing is done),
            ``True`` when the caller should attempt to connect.
        """
        if self.shutdown_requested:
            log.info('Shutdown requested, not reconnecting.')
            return False

        log.warning('Attempting reconnection in %.0fs: [%s]:%s', self.fallback, self.host, self.port)

        if self.conn is not None:
            log.info('Dropping old connection.')
            self.conn.close()
        time.sleep(self.fallback)
        self.fallback = min(self.fallback * 1.5, self.MAX_FALLBACK)  # Limit fallback to one minute.
        return True

    def _try_connect(self) -> bool:
        """Attempt one connection; log and return ``False`` on expected failures."""
        try:
            self._connect()
        except JudgeAuthenticationFailed:
            log.error('Authentication as "%s" failed on: [%s]:%s', self.name, self.host, self.port)
            return False
        except socket.error:
            log.exception('Connection failed due to socket error: [%s]:%s', self.host, self.port)
            return False
        return True

    def _reconnect(self):
        """Wait and reconnect, repeating until connected or shutdown is requested.

        Implemented as a loop so a long outage cannot exhaust the call stack.
        """
        while self._wait_before_retry():
            if self._try_connect():
                return

    def _do_reconnect(self):
        """Connect immediately, falling back to ``_reconnect`` on failure."""
        if not self._try_connect():
            self._reconnect()

    def _read_forever(self):
        """Read and dispatch packets until shutdown is requested.

        Any unexpected error aborts the bridge's current submission, clears
        its pairing in the balancer, reconnects, and then resumes reading on
        the new connection.
        """
        while not self.shutdown_requested:
            try:
                while True:
                    packet = self._read_single()
                    if packet is None:
                        return
                    self._receive_packet(packet)
            except Exception:
                log.exception('Error while handling packets from: [%s]:%s', self.host, self.port)
                self.balancer.abort_submission(self.bridge_id)
                self.balancer.reset_bridge(self.bridge_id)
                self._reconnect()

    def _read_single(self) -> Optional[dict]:
        """Return the next decoded packet, or ``None`` once shutdown is requested.

        A read error, end of stream, truncated length header or corrupt
        compressed payload triggers a reconnect followed by another read.
        """
        while not self.shutdown_requested:
            try:
                header = self.input.read(BridgeHandler.SIZE_PACK.size)
            except socket.error:
                self._reconnect()
                continue
            # Empty means the peer closed; short means the stream was cut
            # mid-header.  Either way the length cannot be unpacked.
            if len(header) < BridgeHandler.SIZE_PACK.size:
                self._reconnect()
                continue

            size = BridgeHandler.SIZE_PACK.unpack(header)[0]
            try:
                packet = zlib.decompress(self.input.read(size))
            except zlib.error:
                self._reconnect()
                continue
            return json.loads(packet.decode('utf-8', 'strict'))
        return None

    def listen(self):
        """Start the packet reader on a new thread."""
        threading.Thread(target=self._read_forever).start()

    def shutdown(self):
        """Request shutdown (no further reconnects) and close the connection."""
        self.shutdown_requested = True
        self._close()

    def _close(self):
        """Shut down the current socket once; safe on a partially built instance."""
        conn = getattr(self, 'conn', None)
        if conn and not getattr(self, '_closed', False):
            try:
                # May already be closed despite self._closed == False if a network error occurred and `close` is being
                # called as part of cleanup.
                conn.shutdown(socket.SHUT_RDWR)
            except socket.error:
                pass
            self._closed = True

    def __del__(self):
        self.shutdown()

    @classmethod
    def _encode_packet(cls, packet: dict) -> bytes:
        """Return the wire frame (length prefix + compressed JSON) for ``packet``.

        ``bytes`` values are decoded into a copy, leaving ``packet`` untouched.
        """
        # Make sure we don't have any garbage utf-8 from e.g. weird compilers
        # *cough* fpc *cough* that could cause this routine to crash
        # We cannot use utf8text because it may not be text.
        payload = {
            k: v.decode('utf-8', 'replace') if isinstance(v, bytes) else v
            for k, v in packet.items()
        }
        raw = zlib.compress(json.dumps(payload).encode('utf-8'))
        return cls.SIZE_PACK.pack(len(raw)) + raw

    def send_packet(self, packet: dict):
        """Send ``packet`` to the bridge.

        On any send failure the bridge's current submission is aborted, its
        pairing is cleared in the balancer, and the connection is
        re-established; the packet itself is not retried.
        """
        frame = self._encode_packet(packet)
        with self._lock:
            try:
                conn = self.conn
                if conn is None:
                    raise OSError(errno.ENOTCONN, 'Not connected to bridge')
                conn.sendall(frame)
            except Exception:  # connection reset by peer
                log.warning('Failed to send packet to: [%s]:%s', self.host, self.port, exc_info=True)
                self.balancer.abort_submission(self.bridge_id)
                self.balancer.reset_bridge(self.bridge_id)
                self._reconnect()

    def _receive_packet(self, packet: dict):
        """Dispatch an incoming packet according to its ``'name'``."""
        name = packet['name']
        if name == 'ping':
            self.ping_packet(packet['when'])
        elif name == 'submission-request':
            self.submission_acknowledged_packet(packet['submission-id'])
            self.balancer.queue_submission(self.bridge_id, packet)
        elif name == 'terminate-submission':
            self.balancer.abort_submission(self.bridge_id)
        elif name == 'disconnect':
            self.balancer.abort_submission(self.bridge_id)
            self._close()
        else:
            log.error('Unknown packet %s, payload %s', name, packet)

    def handshake(self, problems: list, runtimes, id: str, key: str):
        """Send the handshake packet and wait for ``handshake-success``.

        The packet is written directly (not through ``send_packet``) so that a
        send failure propagates to the connection logic instead of starting a
        nested reconnect while a connect is already in progress.

        :raises OSError: when the handshake cannot be sent.
        :raises JudgeAuthenticationFailed: when the response cannot be read or
            decoded, or is not a ``handshake-success`` packet.
        """
        frame = self._encode_packet(
            {'name': 'handshake', 'problems': problems, 'executors': runtimes, 'id': id, 'key': key},
        )
        with self._lock:
            self.conn.sendall(frame)

        log.info('Awaiting handshake response: [%s]:%s', self.host, self.port)
        try:
            data = self.input.read(BridgeHandler.SIZE_PACK.size)
            size = BridgeHandler.SIZE_PACK.unpack(data)[0]
            packet = zlib.decompress(self.input.read(size)).decode('utf-8', 'strict')
            resp = json.loads(packet)
        except Exception:
            log.exception('Cannot understand handshake response: [%s]:%s', self.host, self.port)
            raise JudgeAuthenticationFailed()
        else:
            # A well-formed JSON reply that is not an object, or lacks 'name',
            # counts as a failed handshake rather than an unhandled error.
            if not isinstance(resp, dict) or resp.get('name') != 'handshake-success':
                log.error('Handshake failed.')
                raise JudgeAuthenticationFailed()

    def ping_packet(self, when: float):
        """Answer a ping with the current time and every ``sysinfo`` report value."""
        data = {'name': 'ping-response', 'when': when, 'time': time.time()}
        for fn in sysinfo.report_callbacks:
            key, value = fn()
            data[key] = value
        self.send_packet(data)

    def submission_acknowledged_packet(self, sub_id: int):
        """Acknowledge receipt of submission ``sub_id``."""
        self.send_packet({'name': 'submission-acknowledged', 'submission-id': sub_id})

    def executors_packet(self, executors):
        """Send the available ``executors`` to the bridge."""
        self.send_packet({'name': 'executors', 'executors': executors})
Oops, something went wrong.