diff --git a/ch_tools/monrun_checks/dns.py b/ch_tools/monrun_checks/dns.py index d753a368..46746da3 100644 --- a/ch_tools/monrun_checks/dns.py +++ b/ch_tools/monrun_checks/dns.py @@ -1,7 +1,8 @@ import json import socket from functools import lru_cache -from typing import List, Tuple +from ipaddress import IPv4Address, IPv6Address, ip_address, ip_network +from typing import List, Optional, Tuple, Union import click import dns.resolver @@ -37,19 +38,26 @@ def __init__(self, fqdn: str, private: bool, strict: bool): @click.option( "--imdsv2", "imdsv2", is_flag=True, help="Use imdsv2 token for non gcp hosts" ) +@click.option( + "--nameserver", + "nameserver", + type=str, + help="Custom nameserver to query records from", +) def dns_command( cluster: bool, private: bool, ipv4: bool, ipv6: bool, imdsv2: bool, + nameserver: Optional[str] = None, ) -> Result: """ Check presence and correctness of DNS records. """ err = [] for record in _get_host_dns(cluster, private): - err.extend(_check_fqdn(record, ipv4, ipv6, imdsv2)) + err.extend(_check_fqdn(record, ipv4, ipv6, imdsv2, nameserver)) if not err: return Result(OK) @@ -57,14 +65,25 @@ def dns_command( return Result(CRIT, ", ".join(err)) -def _check_fqdn(target: _TargetRecord, ipv4: bool, ipv6: bool, imdsv2: bool) -> list: +def _check_fqdn( + target: _TargetRecord, + ipv4: bool, + ipv6: bool, + imdsv2: bool, + nameserver: Optional[str], +) -> list: err = [] resolver = dns.resolver.Resolver() + if nameserver: + resolver.nameservers = [nameserver] def _compare(record_type: str, ip_type: str) -> Tuple[bool, set, set]: try: actual_addr = set( - map(lambda a: a.to_text(), resolver.resolve(target.fqdn, record_type)) + map( + lambda a: ip_address(a.to_text()), + resolver.resolve(target.fqdn, record_type), + ) ) except dns.resolver.NXDOMAIN: actual_addr = set() @@ -91,7 +110,7 @@ def _compare(record_type: str, ip_type: str) -> Tuple[bool, set, set]: @lru_cache(maxsize=None) -def _get_host_ip(addr_type: str, imdsv2: bool) -> str: +def _get_host_ip(addr_type: str, imdsv2: bool) -> Union[IPv4Address, IPv6Address]: # pylint: disable=missing-timeout if _is_gcp(): resp = requests.get( @@ -103,7 +122,7 @@ def _get_host_ip(addr_type: str, imdsv2: bool) -> str: headers["X-aws-ec2-metadata-token"] = _get_imdsv2_token() resp = requests.get(IP_METADATA_PATHS[addr_type], headers=headers) resp.raise_for_status() - return resp.text.strip() + return ip_network(resp.text.strip())[0] @lru_cache(maxsize=None)