Initial commit
This commit is contained in:
72
.venv/Lib/site-packages/dns/__init__.py
Normal file
72
.venv/Lib/site-packages/dns/__init__.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2003-2007, 2009, 2011 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
"""dnspython DNS toolkit"""
|
||||
|
||||
__all__ = [
|
||||
"asyncbackend",
|
||||
"asyncquery",
|
||||
"asyncresolver",
|
||||
"btree",
|
||||
"btreezone",
|
||||
"dnssec",
|
||||
"dnssecalgs",
|
||||
"dnssectypes",
|
||||
"e164",
|
||||
"edns",
|
||||
"entropy",
|
||||
"exception",
|
||||
"flags",
|
||||
"immutable",
|
||||
"inet",
|
||||
"ipv4",
|
||||
"ipv6",
|
||||
"message",
|
||||
"name",
|
||||
"namedict",
|
||||
"node",
|
||||
"opcode",
|
||||
"query",
|
||||
"quic",
|
||||
"rcode",
|
||||
"rdata",
|
||||
"rdataclass",
|
||||
"rdataset",
|
||||
"rdatatype",
|
||||
"renderer",
|
||||
"resolver",
|
||||
"reversename",
|
||||
"rrset",
|
||||
"serial",
|
||||
"set",
|
||||
"tokenizer",
|
||||
"transaction",
|
||||
"tsig",
|
||||
"tsigkeyring",
|
||||
"ttl",
|
||||
"rdtypes",
|
||||
"update",
|
||||
"version",
|
||||
"versioned",
|
||||
"wire",
|
||||
"xfr",
|
||||
"zone",
|
||||
"zonetypes",
|
||||
"zonefile",
|
||||
]
|
||||
|
||||
from dns.version import version as __version__ # noqa
|
||||
BIN
.venv/Lib/site-packages/dns/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/_ddr.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/_ddr.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/_no_ssl.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/_no_ssl.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/btree.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/btree.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/dnssec.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/dnssec.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/e164.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/e164.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/edns.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/edns.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/entropy.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/entropy.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/enum.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/enum.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/flags.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/flags.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/grange.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/grange.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/inet.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/inet.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/ipv4.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/ipv4.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/ipv6.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/ipv6.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/message.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/message.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/name.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/name.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/namedict.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/namedict.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/node.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/node.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/opcode.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/opcode.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/query.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/query.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/rcode.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/rcode.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/rdata.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/rdata.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/rdataset.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/rdataset.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/renderer.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/renderer.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/resolver.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/resolver.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/rrset.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/rrset.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/serial.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/serial.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/set.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/set.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/tsig.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/tsig.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/ttl.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/ttl.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/update.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/update.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/version.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/version.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/wire.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/wire.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/xfr.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/xfr.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/zone.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/zone.cpython-312.pyc
Normal file
Binary file not shown.
BIN
.venv/Lib/site-packages/dns/__pycache__/zonefile.cpython-312.pyc
Normal file
BIN
.venv/Lib/site-packages/dns/__pycache__/zonefile.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
100
.venv/Lib/site-packages/dns/_asyncbackend.py
Normal file
100
.venv/Lib/site-packages/dns/_asyncbackend.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# This is a nullcontext for both sync and async. 3.7 has a nullcontext,
|
||||
# but it is only for sync use.
|
||||
|
||||
|
||||
class NullContext:
|
||||
def __init__(self, enter_result=None):
|
||||
self.enter_result = enter_result
|
||||
|
||||
def __enter__(self):
|
||||
return self.enter_result
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.enter_result
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
|
||||
# These are declared here so backends can import them without creating
|
||||
# circular dependencies with dns.asyncbackend.
|
||||
|
||||
|
||||
class Socket: # pragma: no cover
|
||||
def __init__(self, family: int, type: int):
|
||||
self.family = family
|
||||
self.type = type
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
async def getpeername(self):
|
||||
raise NotImplementedError
|
||||
|
||||
async def getsockname(self):
|
||||
raise NotImplementedError
|
||||
|
||||
async def getpeercert(self, timeout):
|
||||
raise NotImplementedError
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
await self.close()
|
||||
|
||||
|
||||
class DatagramSocket(Socket): # pragma: no cover
|
||||
async def sendto(self, what, destination, timeout):
|
||||
raise NotImplementedError
|
||||
|
||||
async def recvfrom(self, size, timeout):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class StreamSocket(Socket): # pragma: no cover
|
||||
async def sendall(self, what, timeout):
|
||||
raise NotImplementedError
|
||||
|
||||
async def recv(self, size, timeout):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class NullTransport:
|
||||
async def connect_tcp(self, host, port, timeout, local_address):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Backend: # pragma: no cover
|
||||
def name(self) -> str:
|
||||
return "unknown"
|
||||
|
||||
async def make_socket(
|
||||
self,
|
||||
af,
|
||||
socktype,
|
||||
proto=0,
|
||||
source=None,
|
||||
destination=None,
|
||||
timeout=None,
|
||||
ssl_context=None,
|
||||
server_hostname=None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def datagram_connection_required(self):
|
||||
return False
|
||||
|
||||
async def sleep(self, interval):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_transport_class(self):
|
||||
raise NotImplementedError
|
||||
|
||||
async def wait_for(self, awaitable, timeout):
|
||||
raise NotImplementedError
|
||||
276
.venv/Lib/site-packages/dns/_asyncio_backend.py
Normal file
276
.venv/Lib/site-packages/dns/_asyncio_backend.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
"""asyncio library query support"""
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import sys
|
||||
|
||||
import dns._asyncbackend
|
||||
import dns._features
|
||||
import dns.exception
|
||||
import dns.inet
|
||||
|
||||
_is_win32 = sys.platform == "win32"
|
||||
|
||||
|
||||
def _get_running_loop():
|
||||
try:
|
||||
return asyncio.get_running_loop()
|
||||
except AttributeError: # pragma: no cover
|
||||
return asyncio.get_event_loop()
|
||||
|
||||
|
||||
class _DatagramProtocol:
|
||||
def __init__(self):
|
||||
self.transport = None
|
||||
self.recvfrom = None
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
|
||||
def datagram_received(self, data, addr):
|
||||
if self.recvfrom and not self.recvfrom.done():
|
||||
self.recvfrom.set_result((data, addr))
|
||||
|
||||
def error_received(self, exc): # pragma: no cover
|
||||
if self.recvfrom and not self.recvfrom.done():
|
||||
self.recvfrom.set_exception(exc)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if self.recvfrom and not self.recvfrom.done():
|
||||
if exc is None:
|
||||
# EOF we triggered. Is there a better way to do this?
|
||||
try:
|
||||
raise EOFError("EOF")
|
||||
except EOFError as e:
|
||||
self.recvfrom.set_exception(e)
|
||||
else:
|
||||
self.recvfrom.set_exception(exc)
|
||||
|
||||
def close(self):
|
||||
if self.transport is not None:
|
||||
self.transport.close()
|
||||
|
||||
|
||||
async def _maybe_wait_for(awaitable, timeout):
|
||||
if timeout is not None:
|
||||
try:
|
||||
return await asyncio.wait_for(awaitable, timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise dns.exception.Timeout(timeout=timeout)
|
||||
else:
|
||||
return await awaitable
|
||||
|
||||
|
||||
class DatagramSocket(dns._asyncbackend.DatagramSocket):
|
||||
def __init__(self, family, transport, protocol):
|
||||
super().__init__(family, socket.SOCK_DGRAM)
|
||||
self.transport = transport
|
||||
self.protocol = protocol
|
||||
|
||||
async def sendto(self, what, destination, timeout): # pragma: no cover
|
||||
# no timeout for asyncio sendto
|
||||
self.transport.sendto(what, destination)
|
||||
return len(what)
|
||||
|
||||
async def recvfrom(self, size, timeout):
|
||||
# ignore size as there's no way I know to tell protocol about it
|
||||
done = _get_running_loop().create_future()
|
||||
try:
|
||||
assert self.protocol.recvfrom is None
|
||||
self.protocol.recvfrom = done
|
||||
await _maybe_wait_for(done, timeout)
|
||||
return done.result()
|
||||
finally:
|
||||
self.protocol.recvfrom = None
|
||||
|
||||
async def close(self):
|
||||
self.protocol.close()
|
||||
|
||||
async def getpeername(self):
|
||||
return self.transport.get_extra_info("peername")
|
||||
|
||||
async def getsockname(self):
|
||||
return self.transport.get_extra_info("sockname")
|
||||
|
||||
async def getpeercert(self, timeout):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class StreamSocket(dns._asyncbackend.StreamSocket):
|
||||
def __init__(self, af, reader, writer):
|
||||
super().__init__(af, socket.SOCK_STREAM)
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
|
||||
async def sendall(self, what, timeout):
|
||||
self.writer.write(what)
|
||||
return await _maybe_wait_for(self.writer.drain(), timeout)
|
||||
|
||||
async def recv(self, size, timeout):
|
||||
return await _maybe_wait_for(self.reader.read(size), timeout)
|
||||
|
||||
async def close(self):
|
||||
self.writer.close()
|
||||
|
||||
async def getpeername(self):
|
||||
return self.writer.get_extra_info("peername")
|
||||
|
||||
async def getsockname(self):
|
||||
return self.writer.get_extra_info("sockname")
|
||||
|
||||
async def getpeercert(self, timeout):
|
||||
return self.writer.get_extra_info("peercert")
|
||||
|
||||
|
||||
if dns._features.have("doh"):
|
||||
import anyio
|
||||
import httpcore
|
||||
import httpcore._backends.anyio
|
||||
import httpx
|
||||
|
||||
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
|
||||
_CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream # pyright: ignore
|
||||
|
||||
from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
|
||||
|
||||
class _NetworkBackend(_CoreAsyncNetworkBackend):
|
||||
def __init__(self, resolver, local_port, bootstrap_address, family):
|
||||
super().__init__()
|
||||
self._local_port = local_port
|
||||
self._resolver = resolver
|
||||
self._bootstrap_address = bootstrap_address
|
||||
self._family = family
|
||||
if local_port != 0:
|
||||
raise NotImplementedError(
|
||||
"the asyncio transport for HTTPX cannot set the local port"
|
||||
)
|
||||
|
||||
async def connect_tcp(
|
||||
self, host, port, timeout=None, local_address=None, socket_options=None
|
||||
): # pylint: disable=signature-differs
|
||||
addresses = []
|
||||
_, expiration = _compute_times(timeout)
|
||||
if dns.inet.is_address(host):
|
||||
addresses.append(host)
|
||||
elif self._bootstrap_address is not None:
|
||||
addresses.append(self._bootstrap_address)
|
||||
else:
|
||||
timeout = _remaining(expiration)
|
||||
family = self._family
|
||||
if local_address:
|
||||
family = dns.inet.af_for_address(local_address)
|
||||
answers = await self._resolver.resolve_name(
|
||||
host, family=family, lifetime=timeout
|
||||
)
|
||||
addresses = answers.addresses()
|
||||
for address in addresses:
|
||||
try:
|
||||
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
|
||||
timeout = _remaining(attempt_expiration)
|
||||
with anyio.fail_after(timeout):
|
||||
stream = await anyio.connect_tcp(
|
||||
remote_host=address,
|
||||
remote_port=port,
|
||||
local_host=local_address,
|
||||
)
|
||||
return _CoreAnyIOStream(stream)
|
||||
except Exception:
|
||||
pass
|
||||
raise httpcore.ConnectError
|
||||
|
||||
async def connect_unix_socket(
|
||||
self, path, timeout=None, socket_options=None
|
||||
): # pylint: disable=signature-differs
|
||||
raise NotImplementedError
|
||||
|
||||
async def sleep(self, seconds): # pylint: disable=signature-differs
|
||||
await anyio.sleep(seconds)
|
||||
|
||||
class _HTTPTransport(httpx.AsyncHTTPTransport):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
local_port=0,
|
||||
bootstrap_address=None,
|
||||
resolver=None,
|
||||
family=socket.AF_UNSPEC,
|
||||
**kwargs,
|
||||
):
|
||||
if resolver is None and bootstrap_address is None:
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
import dns.asyncresolver
|
||||
|
||||
resolver = dns.asyncresolver.Resolver()
|
||||
super().__init__(*args, **kwargs)
|
||||
self._pool._network_backend = _NetworkBackend(
|
||||
resolver, local_port, bootstrap_address, family
|
||||
)
|
||||
|
||||
else:
|
||||
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
|
||||
|
||||
|
||||
class Backend(dns._asyncbackend.Backend):
|
||||
def name(self):
|
||||
return "asyncio"
|
||||
|
||||
async def make_socket(
|
||||
self,
|
||||
af,
|
||||
socktype,
|
||||
proto=0,
|
||||
source=None,
|
||||
destination=None,
|
||||
timeout=None,
|
||||
ssl_context=None,
|
||||
server_hostname=None,
|
||||
):
|
||||
loop = _get_running_loop()
|
||||
if socktype == socket.SOCK_DGRAM:
|
||||
if _is_win32 and source is None:
|
||||
# Win32 wants explicit binding before recvfrom(). This is the
|
||||
# proper fix for [#637].
|
||||
source = (dns.inet.any_for_af(af), 0)
|
||||
transport, protocol = await loop.create_datagram_endpoint(
|
||||
_DatagramProtocol, # pyright: ignore
|
||||
source,
|
||||
family=af,
|
||||
proto=proto,
|
||||
remote_addr=destination,
|
||||
)
|
||||
return DatagramSocket(af, transport, protocol)
|
||||
elif socktype == socket.SOCK_STREAM:
|
||||
if destination is None:
|
||||
# This shouldn't happen, but we check to make code analysis software
|
||||
# happier.
|
||||
raise ValueError("destination required for stream sockets")
|
||||
(r, w) = await _maybe_wait_for(
|
||||
asyncio.open_connection(
|
||||
destination[0],
|
||||
destination[1],
|
||||
ssl=ssl_context,
|
||||
family=af,
|
||||
proto=proto,
|
||||
local_addr=source,
|
||||
server_hostname=server_hostname,
|
||||
),
|
||||
timeout,
|
||||
)
|
||||
return StreamSocket(af, r, w)
|
||||
raise NotImplementedError(
|
||||
"unsupported socket " + f"type {socktype}"
|
||||
) # pragma: no cover
|
||||
|
||||
async def sleep(self, interval):
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
def datagram_connection_required(self):
|
||||
return False
|
||||
|
||||
def get_transport_class(self):
|
||||
return _HTTPTransport
|
||||
|
||||
async def wait_for(self, awaitable, timeout):
|
||||
return await _maybe_wait_for(awaitable, timeout)
|
||||
154
.venv/Lib/site-packages/dns/_ddr.py
Normal file
154
.venv/Lib/site-packages/dns/_ddr.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
#
|
||||
# Support for Discovery of Designated Resolvers
|
||||
|
||||
import socket
|
||||
import time
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import dns.asyncbackend
|
||||
import dns.inet
|
||||
import dns.name
|
||||
import dns.nameserver
|
||||
import dns.query
|
||||
import dns.rdtypes.svcbbase
|
||||
|
||||
# The special name of the local resolver when using DDR
|
||||
_local_resolver_name = dns.name.from_text("_dns.resolver.arpa")
|
||||
|
||||
|
||||
#
|
||||
# Processing is split up into I/O independent and I/O dependent parts to
|
||||
# make supporting sync and async versions easy.
|
||||
#
|
||||
|
||||
|
||||
class _SVCBInfo:
|
||||
def __init__(self, bootstrap_address, port, hostname, nameservers):
|
||||
self.bootstrap_address = bootstrap_address
|
||||
self.port = port
|
||||
self.hostname = hostname
|
||||
self.nameservers = nameservers
|
||||
|
||||
def ddr_check_certificate(self, cert):
|
||||
"""Verify that the _SVCBInfo's address is in the cert's subjectAltName (SAN)"""
|
||||
for name, value in cert["subjectAltName"]:
|
||||
if name == "IP Address" and value == self.bootstrap_address:
|
||||
return True
|
||||
return False
|
||||
|
||||
def make_tls_context(self):
|
||||
ssl = dns.query.ssl
|
||||
ctx = ssl.create_default_context()
|
||||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
return ctx
|
||||
|
||||
def ddr_tls_check_sync(self, lifetime):
|
||||
ctx = self.make_tls_context()
|
||||
expiration = time.time() + lifetime
|
||||
with socket.create_connection(
|
||||
(self.bootstrap_address, self.port), lifetime
|
||||
) as s:
|
||||
with ctx.wrap_socket(s, server_hostname=self.hostname) as ts:
|
||||
ts.settimeout(dns.query._remaining(expiration))
|
||||
ts.do_handshake()
|
||||
cert = ts.getpeercert()
|
||||
return self.ddr_check_certificate(cert)
|
||||
|
||||
async def ddr_tls_check_async(self, lifetime, backend=None):
|
||||
if backend is None:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
ctx = self.make_tls_context()
|
||||
expiration = time.time() + lifetime
|
||||
async with await backend.make_socket(
|
||||
dns.inet.af_for_address(self.bootstrap_address),
|
||||
socket.SOCK_STREAM,
|
||||
0,
|
||||
None,
|
||||
(self.bootstrap_address, self.port),
|
||||
lifetime,
|
||||
ctx,
|
||||
self.hostname,
|
||||
) as ts:
|
||||
cert = await ts.getpeercert(dns.query._remaining(expiration))
|
||||
return self.ddr_check_certificate(cert)
|
||||
|
||||
|
||||
def _extract_nameservers_from_svcb(answer):
|
||||
bootstrap_address = answer.nameserver
|
||||
if not dns.inet.is_address(bootstrap_address):
|
||||
return []
|
||||
infos = []
|
||||
for rr in answer.rrset.processing_order():
|
||||
nameservers = []
|
||||
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.ALPN)
|
||||
if param is None:
|
||||
continue
|
||||
alpns = set(param.ids)
|
||||
host = rr.target.to_text(omit_final_dot=True)
|
||||
port = None
|
||||
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.PORT)
|
||||
if param is not None:
|
||||
port = param.port
|
||||
# For now we ignore address hints and address resolution and always use the
|
||||
# bootstrap address
|
||||
if b"h2" in alpns:
|
||||
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.DOHPATH)
|
||||
if param is None or not param.value.endswith(b"{?dns}"):
|
||||
continue
|
||||
path = param.value[:-6].decode()
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if port is None:
|
||||
port = 443
|
||||
url = f"https://{host}:{port}{path}"
|
||||
# check the URL
|
||||
try:
|
||||
urlparse(url)
|
||||
nameservers.append(dns.nameserver.DoHNameserver(url, bootstrap_address))
|
||||
except Exception:
|
||||
# continue processing other ALPN types
|
||||
pass
|
||||
if b"dot" in alpns:
|
||||
if port is None:
|
||||
port = 853
|
||||
nameservers.append(
|
||||
dns.nameserver.DoTNameserver(bootstrap_address, port, host)
|
||||
)
|
||||
if b"doq" in alpns:
|
||||
if port is None:
|
||||
port = 853
|
||||
nameservers.append(
|
||||
dns.nameserver.DoQNameserver(bootstrap_address, port, True, host)
|
||||
)
|
||||
if len(nameservers) > 0:
|
||||
infos.append(_SVCBInfo(bootstrap_address, port, host, nameservers))
|
||||
return infos
|
||||
|
||||
|
||||
def _get_nameservers_sync(answer, lifetime):
|
||||
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
|
||||
answer."""
|
||||
nameservers = []
|
||||
infos = _extract_nameservers_from_svcb(answer)
|
||||
for info in infos:
|
||||
try:
|
||||
if info.ddr_tls_check_sync(lifetime):
|
||||
nameservers.extend(info.nameservers)
|
||||
except Exception:
|
||||
pass
|
||||
return nameservers
|
||||
|
||||
|
||||
async def _get_nameservers_async(answer, lifetime):
|
||||
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
|
||||
answer."""
|
||||
nameservers = []
|
||||
infos = _extract_nameservers_from_svcb(answer)
|
||||
for info in infos:
|
||||
try:
|
||||
if await info.ddr_tls_check_async(lifetime):
|
||||
nameservers.extend(info.nameservers)
|
||||
except Exception:
|
||||
pass
|
||||
return nameservers
|
||||
95
.venv/Lib/site-packages/dns/_features.py
Normal file
95
.venv/Lib/site-packages/dns/_features.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import importlib.metadata
|
||||
import itertools
|
||||
import string
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
def _tuple_from_text(version: str) -> Tuple:
|
||||
text_parts = version.split(".")
|
||||
int_parts = []
|
||||
for text_part in text_parts:
|
||||
digit_prefix = "".join(
|
||||
itertools.takewhile(lambda x: x in string.digits, text_part)
|
||||
)
|
||||
try:
|
||||
int_parts.append(int(digit_prefix))
|
||||
except Exception:
|
||||
break
|
||||
return tuple(int_parts)
|
||||
|
||||
|
||||
def _version_check(
|
||||
requirement: str,
|
||||
) -> bool:
|
||||
"""Is the requirement fulfilled?
|
||||
|
||||
The requirement must be of the form
|
||||
|
||||
package>=version
|
||||
"""
|
||||
package, minimum = requirement.split(">=")
|
||||
try:
|
||||
version = importlib.metadata.version(package)
|
||||
# This shouldn't happen, but it apparently can.
|
||||
if version is None:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
t_version = _tuple_from_text(version)
|
||||
t_minimum = _tuple_from_text(minimum)
|
||||
if t_version < t_minimum:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
_cache: Dict[str, bool] = {}
|
||||
|
||||
|
||||
def have(feature: str) -> bool:
|
||||
"""Is *feature* available?
|
||||
|
||||
This tests if all optional packages needed for the
|
||||
feature are available and recent enough.
|
||||
|
||||
Returns ``True`` if the feature is available,
|
||||
and ``False`` if it is not or if metadata is
|
||||
missing.
|
||||
"""
|
||||
value = _cache.get(feature)
|
||||
if value is not None:
|
||||
return value
|
||||
requirements = _requirements.get(feature)
|
||||
if requirements is None:
|
||||
# we make a cache entry here for consistency not performance
|
||||
_cache[feature] = False
|
||||
return False
|
||||
ok = True
|
||||
for requirement in requirements:
|
||||
if not _version_check(requirement):
|
||||
ok = False
|
||||
break
|
||||
_cache[feature] = ok
|
||||
return ok
|
||||
|
||||
|
||||
def force(feature: str, enabled: bool) -> None:
|
||||
"""Force the status of *feature* to be *enabled*.
|
||||
|
||||
This method is provided as a workaround for any cases
|
||||
where importlib.metadata is ineffective, or for testing.
|
||||
"""
|
||||
_cache[feature] = enabled
|
||||
|
||||
|
||||
_requirements: Dict[str, List[str]] = {
|
||||
### BEGIN generated requirements
|
||||
"dnssec": ["cryptography>=45"],
|
||||
"doh": ["httpcore>=1.0.0", "httpx>=0.28.0", "h2>=4.2.0"],
|
||||
"doq": ["aioquic>=1.2.0"],
|
||||
"idna": ["idna>=3.10"],
|
||||
"trio": ["trio>=0.30"],
|
||||
"wmi": ["wmi>=1.5.1"],
|
||||
### END generated requirements
|
||||
}
|
||||
76
.venv/Lib/site-packages/dns/_immutable_ctx.py
Normal file
76
.venv/Lib/site-packages/dns/_immutable_ctx.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# This implementation of the immutable decorator requires python >=
|
||||
# 3.7, and is significantly more storage efficient when making classes
|
||||
# with slots immutable. It's also faster.
|
||||
|
||||
import contextvars
|
||||
import inspect
|
||||
|
||||
_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)
|
||||
|
||||
|
||||
class _Immutable:
|
||||
"""Immutable mixin class"""
|
||||
|
||||
# We set slots to the empty list to say "we don't have any attributes".
|
||||
# We do this so that if we're mixed in with a class with __slots__, we
|
||||
# don't cause a __dict__ to be added which would waste space.
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if _in__init__.get() is not self:
|
||||
raise TypeError("object doesn't support attribute assignment")
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def __delattr__(self, name):
|
||||
if _in__init__.get() is not self:
|
||||
raise TypeError("object doesn't support attribute assignment")
|
||||
else:
|
||||
super().__delattr__(name)
|
||||
|
||||
|
||||
def _immutable_init(f):
|
||||
def nf(*args, **kwargs):
|
||||
previous = _in__init__.set(args[0])
|
||||
try:
|
||||
# call the actual __init__
|
||||
f(*args, **kwargs)
|
||||
finally:
|
||||
_in__init__.reset(previous)
|
||||
|
||||
nf.__signature__ = inspect.signature(f) # pyright: ignore
|
||||
return nf
|
||||
|
||||
|
||||
def immutable(cls):
|
||||
if _Immutable in cls.__mro__:
|
||||
# Some ancestor already has the mixin, so just make sure we keep
|
||||
# following the __init__ protocol.
|
||||
cls.__init__ = _immutable_init(cls.__init__)
|
||||
if hasattr(cls, "__setstate__"):
|
||||
cls.__setstate__ = _immutable_init(cls.__setstate__)
|
||||
ncls = cls
|
||||
else:
|
||||
# Mixin the Immutable class and follow the __init__ protocol.
|
||||
class ncls(_Immutable, cls):
|
||||
# We have to do the __slots__ declaration here too!
|
||||
__slots__ = ()
|
||||
|
||||
@_immutable_init
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if hasattr(cls, "__setstate__"):
|
||||
|
||||
@_immutable_init
|
||||
def __setstate__(self, *args, **kwargs):
|
||||
super().__setstate__(*args, **kwargs)
|
||||
|
||||
# make ncls have the same name and module as cls
|
||||
ncls.__name__ = cls.__name__
|
||||
ncls.__qualname__ = cls.__qualname__
|
||||
ncls.__module__ = cls.__module__
|
||||
return ncls
|
||||
61
.venv/Lib/site-packages/dns/_no_ssl.py
Normal file
61
.venv/Lib/site-packages/dns/_no_ssl.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import enum
|
||||
from typing import Any
|
||||
|
||||
CERT_NONE = 0
|
||||
|
||||
|
||||
class TLSVersion(enum.IntEnum):
|
||||
TLSv1_2 = 12
|
||||
|
||||
|
||||
class WantReadException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WantWriteException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SSLWantReadError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SSLWantWriteError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SSLContext:
|
||||
def __init__(self) -> None:
|
||||
self.minimum_version: Any = TLSVersion.TLSv1_2
|
||||
self.check_hostname: bool = False
|
||||
self.verify_mode: int = CERT_NONE
|
||||
|
||||
def wrap_socket(self, *args, **kwargs) -> "SSLSocket": # type: ignore
|
||||
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
|
||||
|
||||
def set_alpn_protocols(self, *args, **kwargs): # type: ignore
|
||||
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
|
||||
|
||||
|
||||
class SSLSocket:
|
||||
def pending(self) -> bool:
|
||||
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
|
||||
|
||||
def do_handshake(self) -> None:
|
||||
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
|
||||
|
||||
def settimeout(self, value: Any) -> None:
|
||||
pass
|
||||
|
||||
def getpeercert(self) -> Any:
|
||||
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
|
||||
def create_default_context(*args, **kwargs) -> SSLContext: # type: ignore
|
||||
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
|
||||
19
.venv/Lib/site-packages/dns/_tls_util.py
Normal file
19
.venv/Lib/site-packages/dns/_tls_util.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def convert_verify_to_cafile_and_capath(
|
||||
verify: bool | str,
|
||||
) -> Tuple[str | None, str | None]:
|
||||
cafile: str | None = None
|
||||
capath: str | None = None
|
||||
if isinstance(verify, str):
|
||||
if os.path.isfile(verify):
|
||||
cafile = verify
|
||||
elif os.path.isdir(verify):
|
||||
capath = verify
|
||||
else:
|
||||
raise ValueError("invalid verify string")
|
||||
return cafile, capath
|
||||
255
.venv/Lib/site-packages/dns/_trio_backend.py
Normal file
255
.venv/Lib/site-packages/dns/_trio_backend.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
"""trio async I/O library query support"""
|
||||
|
||||
import socket
|
||||
|
||||
import trio
|
||||
import trio.socket # type: ignore
|
||||
|
||||
import dns._asyncbackend
|
||||
import dns._features
|
||||
import dns.exception
|
||||
import dns.inet
|
||||
|
||||
if not dns._features.have("trio"):
|
||||
raise ImportError("trio not found or too old")
|
||||
|
||||
|
||||
def _maybe_timeout(timeout):
|
||||
if timeout is not None:
|
||||
return trio.move_on_after(timeout)
|
||||
else:
|
||||
return dns._asyncbackend.NullContext()
|
||||
|
||||
|
||||
# for brevity
|
||||
_lltuple = dns.inet.low_level_address_tuple
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
|
||||
class DatagramSocket(dns._asyncbackend.DatagramSocket):
|
||||
def __init__(self, sock):
|
||||
super().__init__(sock.family, socket.SOCK_DGRAM)
|
||||
self.socket = sock
|
||||
|
||||
async def sendto(self, what, destination, timeout):
|
||||
with _maybe_timeout(timeout):
|
||||
if destination is None:
|
||||
return await self.socket.send(what)
|
||||
else:
|
||||
return await self.socket.sendto(what, destination)
|
||||
raise dns.exception.Timeout(
|
||||
timeout=timeout
|
||||
) # pragma: no cover lgtm[py/unreachable-statement]
|
||||
|
||||
async def recvfrom(self, size, timeout):
|
||||
with _maybe_timeout(timeout):
|
||||
return await self.socket.recvfrom(size)
|
||||
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
|
||||
|
||||
async def close(self):
|
||||
self.socket.close()
|
||||
|
||||
async def getpeername(self):
|
||||
return self.socket.getpeername()
|
||||
|
||||
async def getsockname(self):
|
||||
return self.socket.getsockname()
|
||||
|
||||
async def getpeercert(self, timeout):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class StreamSocket(dns._asyncbackend.StreamSocket):
|
||||
def __init__(self, family, stream, tls=False):
|
||||
super().__init__(family, socket.SOCK_STREAM)
|
||||
self.stream = stream
|
||||
self.tls = tls
|
||||
|
||||
async def sendall(self, what, timeout):
|
||||
with _maybe_timeout(timeout):
|
||||
return await self.stream.send_all(what)
|
||||
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
|
||||
|
||||
async def recv(self, size, timeout):
|
||||
with _maybe_timeout(timeout):
|
||||
return await self.stream.receive_some(size)
|
||||
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
|
||||
|
||||
async def close(self):
|
||||
await self.stream.aclose()
|
||||
|
||||
async def getpeername(self):
|
||||
if self.tls:
|
||||
return self.stream.transport_stream.socket.getpeername()
|
||||
else:
|
||||
return self.stream.socket.getpeername()
|
||||
|
||||
async def getsockname(self):
|
||||
if self.tls:
|
||||
return self.stream.transport_stream.socket.getsockname()
|
||||
else:
|
||||
return self.stream.socket.getsockname()
|
||||
|
||||
async def getpeercert(self, timeout):
|
||||
if self.tls:
|
||||
with _maybe_timeout(timeout):
|
||||
await self.stream.do_handshake()
|
||||
return self.stream.getpeercert()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
if dns._features.have("doh"):
|
||||
import httpcore
|
||||
import httpcore._backends.trio
|
||||
import httpx
|
||||
|
||||
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
|
||||
_CoreTrioStream = httpcore._backends.trio.TrioStream
|
||||
|
||||
from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
|
||||
|
||||
class _NetworkBackend(_CoreAsyncNetworkBackend):
|
||||
def __init__(self, resolver, local_port, bootstrap_address, family):
|
||||
super().__init__()
|
||||
self._local_port = local_port
|
||||
self._resolver = resolver
|
||||
self._bootstrap_address = bootstrap_address
|
||||
self._family = family
|
||||
|
||||
async def connect_tcp(
|
||||
self, host, port, timeout=None, local_address=None, socket_options=None
|
||||
): # pylint: disable=signature-differs
|
||||
addresses = []
|
||||
_, expiration = _compute_times(timeout)
|
||||
if dns.inet.is_address(host):
|
||||
addresses.append(host)
|
||||
elif self._bootstrap_address is not None:
|
||||
addresses.append(self._bootstrap_address)
|
||||
else:
|
||||
timeout = _remaining(expiration)
|
||||
family = self._family
|
||||
if local_address:
|
||||
family = dns.inet.af_for_address(local_address)
|
||||
answers = await self._resolver.resolve_name(
|
||||
host, family=family, lifetime=timeout
|
||||
)
|
||||
addresses = answers.addresses()
|
||||
for address in addresses:
|
||||
try:
|
||||
af = dns.inet.af_for_address(address)
|
||||
if local_address is not None or self._local_port != 0:
|
||||
source = (local_address, self._local_port)
|
||||
else:
|
||||
source = None
|
||||
destination = (address, port)
|
||||
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
|
||||
timeout = _remaining(attempt_expiration)
|
||||
sock = await Backend().make_socket(
|
||||
af, socket.SOCK_STREAM, 0, source, destination, timeout
|
||||
)
|
||||
assert isinstance(sock, StreamSocket)
|
||||
return _CoreTrioStream(sock.stream)
|
||||
except Exception:
|
||||
continue
|
||||
raise httpcore.ConnectError
|
||||
|
||||
async def connect_unix_socket(
|
||||
self, path, timeout=None, socket_options=None
|
||||
): # pylint: disable=signature-differs
|
||||
raise NotImplementedError
|
||||
|
||||
async def sleep(self, seconds): # pylint: disable=signature-differs
|
||||
await trio.sleep(seconds)
|
||||
|
||||
class _HTTPTransport(httpx.AsyncHTTPTransport):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
local_port=0,
|
||||
bootstrap_address=None,
|
||||
resolver=None,
|
||||
family=socket.AF_UNSPEC,
|
||||
**kwargs,
|
||||
):
|
||||
if resolver is None and bootstrap_address is None:
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
import dns.asyncresolver
|
||||
|
||||
resolver = dns.asyncresolver.Resolver()
|
||||
super().__init__(*args, **kwargs)
|
||||
self._pool._network_backend = _NetworkBackend(
|
||||
resolver, local_port, bootstrap_address, family
|
||||
)
|
||||
|
||||
else:
|
||||
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
|
||||
|
||||
|
||||
class Backend(dns._asyncbackend.Backend):
|
||||
def name(self):
|
||||
return "trio"
|
||||
|
||||
async def make_socket(
|
||||
self,
|
||||
af,
|
||||
socktype,
|
||||
proto=0,
|
||||
source=None,
|
||||
destination=None,
|
||||
timeout=None,
|
||||
ssl_context=None,
|
||||
server_hostname=None,
|
||||
):
|
||||
s = trio.socket.socket(af, socktype, proto)
|
||||
stream = None
|
||||
try:
|
||||
if source:
|
||||
await s.bind(_lltuple(source, af))
|
||||
if socktype == socket.SOCK_STREAM or destination is not None:
|
||||
connected = False
|
||||
with _maybe_timeout(timeout):
|
||||
assert destination is not None
|
||||
await s.connect(_lltuple(destination, af))
|
||||
connected = True
|
||||
if not connected:
|
||||
raise dns.exception.Timeout(
|
||||
timeout=timeout
|
||||
) # lgtm[py/unreachable-statement]
|
||||
except Exception: # pragma: no cover
|
||||
s.close()
|
||||
raise
|
||||
if socktype == socket.SOCK_DGRAM:
|
||||
return DatagramSocket(s)
|
||||
elif socktype == socket.SOCK_STREAM:
|
||||
stream = trio.SocketStream(s)
|
||||
tls = False
|
||||
if ssl_context:
|
||||
tls = True
|
||||
try:
|
||||
stream = trio.SSLStream(
|
||||
stream, ssl_context, server_hostname=server_hostname
|
||||
)
|
||||
except Exception: # pragma: no cover
|
||||
await stream.aclose()
|
||||
raise
|
||||
return StreamSocket(af, stream, tls)
|
||||
raise NotImplementedError(
|
||||
"unsupported socket " + f"type {socktype}"
|
||||
) # pragma: no cover
|
||||
|
||||
async def sleep(self, interval):
|
||||
await trio.sleep(interval)
|
||||
|
||||
def get_transport_class(self):
|
||||
return _HTTPTransport
|
||||
|
||||
async def wait_for(self, awaitable, timeout):
|
||||
with _maybe_timeout(timeout):
|
||||
return await awaitable
|
||||
raise dns.exception.Timeout(
|
||||
timeout=timeout
|
||||
) # pragma: no cover lgtm[py/unreachable-statement]
|
||||
101
.venv/Lib/site-packages/dns/asyncbackend.py
Normal file
101
.venv/Lib/site-packages/dns/asyncbackend.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import dns.exception
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import]
|
||||
Backend,
|
||||
DatagramSocket,
|
||||
Socket,
|
||||
StreamSocket,
|
||||
)
|
||||
|
||||
# pylint: enable=unused-import
|
||||
|
||||
_default_backend = None
|
||||
|
||||
_backends: Dict[str, Backend] = {}
|
||||
|
||||
# Allow sniffio import to be disabled for testing purposes
|
||||
_no_sniffio = False
|
||||
|
||||
|
||||
class AsyncLibraryNotFoundError(dns.exception.DNSException):
|
||||
pass
|
||||
|
||||
|
||||
def get_backend(name: str) -> Backend:
|
||||
"""Get the specified asynchronous backend.
|
||||
|
||||
*name*, a ``str``, the name of the backend. Currently the "trio"
|
||||
and "asyncio" backends are available.
|
||||
|
||||
Raises NotImplementedError if an unknown backend name is specified.
|
||||
"""
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
backend = _backends.get(name)
|
||||
if backend:
|
||||
return backend
|
||||
if name == "trio":
|
||||
import dns._trio_backend
|
||||
|
||||
backend = dns._trio_backend.Backend()
|
||||
elif name == "asyncio":
|
||||
import dns._asyncio_backend
|
||||
|
||||
backend = dns._asyncio_backend.Backend()
|
||||
else:
|
||||
raise NotImplementedError(f"unimplemented async backend {name}")
|
||||
_backends[name] = backend
|
||||
return backend
|
||||
|
||||
|
||||
def sniff() -> str:
|
||||
"""Attempt to determine the in-use asynchronous I/O library by using
|
||||
the ``sniffio`` module if it is available.
|
||||
|
||||
Returns the name of the library, or raises AsyncLibraryNotFoundError
|
||||
if the library cannot be determined.
|
||||
"""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
try:
|
||||
if _no_sniffio:
|
||||
raise ImportError
|
||||
import sniffio
|
||||
|
||||
try:
|
||||
return sniffio.current_async_library()
|
||||
except sniffio.AsyncLibraryNotFoundError:
|
||||
raise AsyncLibraryNotFoundError("sniffio cannot determine async library")
|
||||
except ImportError:
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return "asyncio"
|
||||
except RuntimeError:
|
||||
raise AsyncLibraryNotFoundError("no async library detected")
|
||||
|
||||
|
||||
def get_default_backend() -> Backend:
|
||||
"""Get the default backend, initializing it if necessary."""
|
||||
if _default_backend:
|
||||
return _default_backend
|
||||
|
||||
return set_default_backend(sniff())
|
||||
|
||||
|
||||
def set_default_backend(name: str) -> Backend:
|
||||
"""Set the default backend.
|
||||
|
||||
It's not normally necessary to call this method, as
|
||||
``get_default_backend()`` will initialize the backend
|
||||
appropriately in many cases. If ``sniffio`` is not installed, or
|
||||
in testing situations, this function allows the backend to be set
|
||||
explicitly.
|
||||
"""
|
||||
global _default_backend
|
||||
_default_backend = get_backend(name)
|
||||
return _default_backend
|
||||
953
.venv/Lib/site-packages/dns/asyncquery.py
Normal file
953
.venv/Lib/site-packages/dns/asyncquery.py
Normal file
@@ -0,0 +1,953 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2003-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
"""Talk to a DNS server."""
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import random
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional, Tuple, cast
|
||||
|
||||
import dns.asyncbackend
|
||||
import dns.exception
|
||||
import dns.inet
|
||||
import dns.message
|
||||
import dns.name
|
||||
import dns.quic
|
||||
import dns.rdatatype
|
||||
import dns.transaction
|
||||
import dns.tsig
|
||||
import dns.xfr
|
||||
from dns._asyncbackend import NullContext
|
||||
from dns.query import (
|
||||
BadResponse,
|
||||
HTTPVersion,
|
||||
NoDOH,
|
||||
NoDOQ,
|
||||
UDPMode,
|
||||
_check_status,
|
||||
_compute_times,
|
||||
_matches_destination,
|
||||
_remaining,
|
||||
have_doh,
|
||||
make_ssl_context,
|
||||
)
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
import dns._no_ssl as ssl # type: ignore
|
||||
|
||||
if have_doh:
|
||||
import httpx
|
||||
|
||||
# for brevity
|
||||
_lltuple = dns.inet.low_level_address_tuple
|
||||
|
||||
|
||||
def _source_tuple(af, address, port):
|
||||
# Make a high level source tuple, or return None if address and port
|
||||
# are both None
|
||||
if address or port:
|
||||
if address is None:
|
||||
if af == socket.AF_INET:
|
||||
address = "0.0.0.0"
|
||||
elif af == socket.AF_INET6:
|
||||
address = "::"
|
||||
else:
|
||||
raise NotImplementedError(f"unknown address family {af}")
|
||||
return (address, port)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _timeout(expiration, now=None):
|
||||
if expiration is not None:
|
||||
if not now:
|
||||
now = time.time()
|
||||
return max(expiration - now, 0)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def send_udp(
|
||||
sock: dns.asyncbackend.DatagramSocket,
|
||||
what: dns.message.Message | bytes,
|
||||
destination: Any,
|
||||
expiration: float | None = None,
|
||||
) -> Tuple[int, float]:
|
||||
"""Send a DNS message to the specified UDP socket.
|
||||
|
||||
*sock*, a ``dns.asyncbackend.DatagramSocket``.
|
||||
|
||||
*what*, a ``bytes`` or ``dns.message.Message``, the message to send.
|
||||
|
||||
*destination*, a destination tuple appropriate for the address family
|
||||
of the socket, specifying where to send the query.
|
||||
|
||||
*expiration*, a ``float`` or ``None``, the absolute time at which
|
||||
a timeout exception should be raised. If ``None``, no timeout will
|
||||
occur. The expiration value is meaningless for the asyncio backend, as
|
||||
asyncio's transport sendto() never blocks.
|
||||
|
||||
Returns an ``(int, float)`` tuple of bytes sent and the sent time.
|
||||
"""
|
||||
|
||||
if isinstance(what, dns.message.Message):
|
||||
what = what.to_wire()
|
||||
sent_time = time.time()
|
||||
n = await sock.sendto(what, destination, _timeout(expiration, sent_time))
|
||||
return (n, sent_time)
|
||||
|
||||
|
||||
async def receive_udp(
|
||||
sock: dns.asyncbackend.DatagramSocket,
|
||||
destination: Any | None = None,
|
||||
expiration: float | None = None,
|
||||
ignore_unexpected: bool = False,
|
||||
one_rr_per_rrset: bool = False,
|
||||
keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
|
||||
request_mac: bytes | None = b"",
|
||||
ignore_trailing: bool = False,
|
||||
raise_on_truncation: bool = False,
|
||||
ignore_errors: bool = False,
|
||||
query: dns.message.Message | None = None,
|
||||
) -> Any:
|
||||
"""Read a DNS message from a UDP socket.
|
||||
|
||||
*sock*, a ``dns.asyncbackend.DatagramSocket``.
|
||||
|
||||
See :py:func:`dns.query.receive_udp()` for the documentation of the other
|
||||
parameters, and exceptions.
|
||||
|
||||
Returns a ``(dns.message.Message, float, tuple)`` tuple of the received message, the
|
||||
received time, and the address where the message arrived from.
|
||||
"""
|
||||
|
||||
wire = b""
|
||||
while True:
|
||||
(wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
|
||||
if not _matches_destination(
|
||||
sock.family, from_address, destination, ignore_unexpected
|
||||
):
|
||||
continue
|
||||
received_time = time.time()
|
||||
try:
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=keyring,
|
||||
request_mac=request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
raise_on_truncation=raise_on_truncation,
|
||||
)
|
||||
except dns.message.Truncated as e:
|
||||
# See the comment in query.py for details.
|
||||
if (
|
||||
ignore_errors
|
||||
and query is not None
|
||||
and not query.is_response(e.message())
|
||||
):
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
except Exception:
|
||||
if ignore_errors:
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
if ignore_errors and query is not None and not query.is_response(r):
|
||||
continue
|
||||
return (r, received_time, from_address)
|
||||
|
||||
|
||||
async def udp(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: float | None = None,
|
||||
port: int = 53,
|
||||
source: str | None = None,
|
||||
source_port: int = 0,
|
||||
ignore_unexpected: bool = False,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
raise_on_truncation: bool = False,
|
||||
sock: dns.asyncbackend.DatagramSocket | None = None,
|
||||
backend: dns.asyncbackend.Backend | None = None,
|
||||
ignore_errors: bool = False,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via UDP.
|
||||
|
||||
*sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
|
||||
the socket to use for the query. If ``None``, the default, a
|
||||
socket is created. Note that if a socket is provided, the
|
||||
*source*, *source_port*, and *backend* are ignored.
|
||||
|
||||
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
|
||||
the default, then dnspython will use the default backend.
|
||||
|
||||
See :py:func:`dns.query.udp()` for the documentation of the other
|
||||
parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
wire = q.to_wire()
|
||||
(begin_time, expiration) = _compute_times(timeout)
|
||||
af = dns.inet.af_for_address(where)
|
||||
destination = _lltuple((where, port), af)
|
||||
if sock:
|
||||
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
|
||||
else:
|
||||
if not backend:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
stuple = _source_tuple(af, source, source_port)
|
||||
if backend.datagram_connection_required():
|
||||
dtuple = (where, port)
|
||||
else:
|
||||
dtuple = None
|
||||
cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
|
||||
async with cm as s:
|
||||
await send_udp(s, wire, destination, expiration) # pyright: ignore
|
||||
(r, received_time, _) = await receive_udp(
|
||||
s, # pyright: ignore
|
||||
destination,
|
||||
expiration,
|
||||
ignore_unexpected,
|
||||
one_rr_per_rrset,
|
||||
q.keyring,
|
||||
q.mac,
|
||||
ignore_trailing,
|
||||
raise_on_truncation,
|
||||
ignore_errors,
|
||||
q,
|
||||
)
|
||||
r.time = received_time - begin_time
|
||||
# We don't need to check q.is_response() if we are in ignore_errors mode
|
||||
# as receive_udp() will have checked it.
|
||||
if not (ignore_errors or q.is_response(r)):
|
||||
raise BadResponse
|
||||
return r
|
||||
|
||||
|
||||
async def udp_with_fallback(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: float | None = None,
|
||||
port: int = 53,
|
||||
source: str | None = None,
|
||||
source_port: int = 0,
|
||||
ignore_unexpected: bool = False,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
udp_sock: dns.asyncbackend.DatagramSocket | None = None,
|
||||
tcp_sock: dns.asyncbackend.StreamSocket | None = None,
|
||||
backend: dns.asyncbackend.Backend | None = None,
|
||||
ignore_errors: bool = False,
|
||||
) -> Tuple[dns.message.Message, bool]:
|
||||
"""Return the response to the query, trying UDP first and falling back
|
||||
to TCP if UDP results in a truncated response.
|
||||
|
||||
*udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
|
||||
the socket to use for the UDP query. If ``None``, the default, a
|
||||
socket is created. Note that if a socket is provided the *source*,
|
||||
*source_port*, and *backend* are ignored for the UDP query.
|
||||
|
||||
*tcp_sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the
|
||||
socket to use for the TCP query. If ``None``, the default, a
|
||||
socket is created. Note that if a socket is provided *where*,
|
||||
*source*, *source_port*, and *backend* are ignored for the TCP query.
|
||||
|
||||
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
|
||||
the default, then dnspython will use the default backend.
|
||||
|
||||
See :py:func:`dns.query.udp_with_fallback()` for the documentation
|
||||
of the other parameters, exceptions, and return type of this
|
||||
method.
|
||||
"""
|
||||
try:
|
||||
response = await udp(
|
||||
q,
|
||||
where,
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
ignore_unexpected,
|
||||
one_rr_per_rrset,
|
||||
ignore_trailing,
|
||||
True,
|
||||
udp_sock,
|
||||
backend,
|
||||
ignore_errors,
|
||||
)
|
||||
return (response, False)
|
||||
except dns.message.Truncated:
|
||||
response = await tcp(
|
||||
q,
|
||||
where,
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
one_rr_per_rrset,
|
||||
ignore_trailing,
|
||||
tcp_sock,
|
||||
backend,
|
||||
)
|
||||
return (response, True)
|
||||
|
||||
|
||||
async def send_tcp(
|
||||
sock: dns.asyncbackend.StreamSocket,
|
||||
what: dns.message.Message | bytes,
|
||||
expiration: float | None = None,
|
||||
) -> Tuple[int, float]:
|
||||
"""Send a DNS message to the specified TCP socket.
|
||||
|
||||
*sock*, a ``dns.asyncbackend.StreamSocket``.
|
||||
|
||||
See :py:func:`dns.query.send_tcp()` for the documentation of the other
|
||||
parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
|
||||
if isinstance(what, dns.message.Message):
|
||||
tcpmsg = what.to_wire(prepend_length=True)
|
||||
else:
|
||||
# copying the wire into tcpmsg is inefficient, but lets us
|
||||
# avoid writev() or doing a short write that would get pushed
|
||||
# onto the net
|
||||
tcpmsg = len(what).to_bytes(2, "big") + what
|
||||
sent_time = time.time()
|
||||
await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
|
||||
return (len(tcpmsg), sent_time)
|
||||
|
||||
|
||||
async def _read_exactly(sock, count, expiration):
|
||||
"""Read the specified number of bytes from stream. Keep trying until we
|
||||
either get the desired amount, or we hit EOF.
|
||||
"""
|
||||
s = b""
|
||||
while count > 0:
|
||||
n = await sock.recv(count, _timeout(expiration))
|
||||
if n == b"":
|
||||
raise EOFError("EOF")
|
||||
count = count - len(n)
|
||||
s = s + n
|
||||
return s
|
||||
|
||||
|
||||
async def receive_tcp(
|
||||
sock: dns.asyncbackend.StreamSocket,
|
||||
expiration: float | None = None,
|
||||
one_rr_per_rrset: bool = False,
|
||||
keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
|
||||
request_mac: bytes | None = b"",
|
||||
ignore_trailing: bool = False,
|
||||
) -> Tuple[dns.message.Message, float]:
|
||||
"""Read a DNS message from a TCP socket.
|
||||
|
||||
*sock*, a ``dns.asyncbackend.StreamSocket``.
|
||||
|
||||
See :py:func:`dns.query.receive_tcp()` for the documentation of the other
|
||||
parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
|
||||
ldata = await _read_exactly(sock, 2, expiration)
|
||||
(l,) = struct.unpack("!H", ldata)
|
||||
wire = await _read_exactly(sock, l, expiration)
|
||||
received_time = time.time()
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=keyring,
|
||||
request_mac=request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
)
|
||||
return (r, received_time)
|
||||
|
||||
|
||||
async def tcp(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: float | None = None,
|
||||
port: int = 53,
|
||||
source: str | None = None,
|
||||
source_port: int = 0,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
sock: dns.asyncbackend.StreamSocket | None = None,
|
||||
backend: dns.asyncbackend.Backend | None = None,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via TCP.
|
||||
|
||||
*sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
|
||||
socket to use for the query. If ``None``, the default, a socket
|
||||
is created. Note that if a socket is provided
|
||||
*where*, *port*, *source*, *source_port*, and *backend* are ignored.
|
||||
|
||||
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
|
||||
the default, then dnspython will use the default backend.
|
||||
|
||||
See :py:func:`dns.query.tcp()` for the documentation of the other
|
||||
parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
|
||||
wire = q.to_wire()
|
||||
(begin_time, expiration) = _compute_times(timeout)
|
||||
if sock:
|
||||
# Verify that the socket is connected, as if it's not connected,
|
||||
# it's not writable, and the polling in send_tcp() will time out or
|
||||
# hang forever.
|
||||
await sock.getpeername()
|
||||
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
|
||||
else:
|
||||
# These are simple (address, port) pairs, not family-dependent tuples
|
||||
# you pass to low-level socket code.
|
||||
af = dns.inet.af_for_address(where)
|
||||
stuple = _source_tuple(af, source, source_port)
|
||||
dtuple = (where, port)
|
||||
if not backend:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
cm = await backend.make_socket(
|
||||
af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
|
||||
)
|
||||
async with cm as s:
|
||||
await send_tcp(s, wire, expiration) # pyright: ignore
|
||||
(r, received_time) = await receive_tcp(
|
||||
s, # pyright: ignore
|
||||
expiration,
|
||||
one_rr_per_rrset,
|
||||
q.keyring,
|
||||
q.mac,
|
||||
ignore_trailing,
|
||||
)
|
||||
r.time = received_time - begin_time
|
||||
if not q.is_response(r):
|
||||
raise BadResponse
|
||||
return r
|
||||
|
||||
|
||||
async def tls(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: float | None = None,
|
||||
port: int = 853,
|
||||
source: str | None = None,
|
||||
source_port: int = 0,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
sock: dns.asyncbackend.StreamSocket | None = None,
|
||||
backend: dns.asyncbackend.Backend | None = None,
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
server_hostname: str | None = None,
|
||||
verify: bool | str = True,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via TLS.
|
||||
|
||||
*sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
|
||||
to use for the query. If ``None``, the default, a socket is
|
||||
created. Note that if a socket is provided, it must be a
|
||||
connected SSL stream socket, and *where*, *port*,
|
||||
*source*, *source_port*, *backend*, *ssl_context*, and *server_hostname*
|
||||
are ignored.
|
||||
|
||||
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
|
||||
the default, then dnspython will use the default backend.
|
||||
|
||||
See :py:func:`dns.query.tls()` for the documentation of the other
|
||||
parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
(begin_time, expiration) = _compute_times(timeout)
|
||||
if sock:
|
||||
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
|
||||
else:
|
||||
if ssl_context is None:
|
||||
ssl_context = make_ssl_context(verify, server_hostname is not None, ["dot"])
|
||||
af = dns.inet.af_for_address(where)
|
||||
stuple = _source_tuple(af, source, source_port)
|
||||
dtuple = (where, port)
|
||||
if not backend:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
cm = await backend.make_socket(
|
||||
af,
|
||||
socket.SOCK_STREAM,
|
||||
0,
|
||||
stuple,
|
||||
dtuple,
|
||||
timeout,
|
||||
ssl_context,
|
||||
server_hostname,
|
||||
)
|
||||
async with cm as s:
|
||||
timeout = _timeout(expiration)
|
||||
response = await tcp(
|
||||
q,
|
||||
where,
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
one_rr_per_rrset,
|
||||
ignore_trailing,
|
||||
s,
|
||||
backend,
|
||||
)
|
||||
end_time = time.time()
|
||||
response.time = end_time - begin_time
|
||||
return response
|
||||
|
||||
|
||||
def _maybe_get_resolver(
|
||||
resolver: Optional["dns.asyncresolver.Resolver"], # pyright: ignore
|
||||
) -> "dns.asyncresolver.Resolver": # pyright: ignore
|
||||
# We need a separate method for this to avoid overriding the global
|
||||
# variable "dns" with the as-yet undefined local variable "dns"
|
||||
# in https().
|
||||
if resolver is None:
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
import dns.asyncresolver
|
||||
|
||||
resolver = dns.asyncresolver.Resolver()
|
||||
return resolver
|
||||
|
||||
|
||||
async def https(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: float | None = None,
|
||||
port: int = 443,
|
||||
source: str | None = None,
|
||||
source_port: int = 0, # pylint: disable=W0613
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
client: Optional["httpx.AsyncClient|dns.quic.AsyncQuicConnection"] = None,
|
||||
path: str = "/dns-query",
|
||||
post: bool = True,
|
||||
verify: bool | str | ssl.SSLContext = True,
|
||||
bootstrap_address: str | None = None,
|
||||
resolver: Optional["dns.asyncresolver.Resolver"] = None, # pyright: ignore
|
||||
family: int = socket.AF_UNSPEC,
|
||||
http_version: HTTPVersion = HTTPVersion.DEFAULT,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via DNS-over-HTTPS.
|
||||
|
||||
*client*, a ``httpx.AsyncClient``. If provided, the client to use for
|
||||
the query.
|
||||
|
||||
Unlike the other dnspython async functions, a backend cannot be provided
|
||||
in this function because httpx always auto-detects the async backend.
|
||||
|
||||
See :py:func:`dns.query.https()` for the documentation of the other
|
||||
parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
|
||||
try:
|
||||
af = dns.inet.af_for_address(where)
|
||||
except ValueError:
|
||||
af = None
|
||||
# we bind url and then override as pyright can't figure out all paths bind.
|
||||
url = where
|
||||
if af is not None and dns.inet.is_address(where):
|
||||
if af == socket.AF_INET:
|
||||
url = f"https://{where}:{port}{path}"
|
||||
elif af == socket.AF_INET6:
|
||||
url = f"https://[{where}]:{port}{path}"
|
||||
|
||||
extensions = {}
|
||||
if bootstrap_address is None:
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
if parsed.hostname is None:
|
||||
raise ValueError("no hostname in URL")
|
||||
if dns.inet.is_address(parsed.hostname):
|
||||
bootstrap_address = parsed.hostname
|
||||
extensions["sni_hostname"] = parsed.hostname
|
||||
if parsed.port is not None:
|
||||
port = parsed.port
|
||||
|
||||
if http_version == HTTPVersion.H3 or (
|
||||
http_version == HTTPVersion.DEFAULT and not have_doh
|
||||
):
|
||||
if bootstrap_address is None:
|
||||
resolver = _maybe_get_resolver(resolver)
|
||||
assert parsed.hostname is not None # pyright: ignore
|
||||
answers = await resolver.resolve_name( # pyright: ignore
|
||||
parsed.hostname, family # pyright: ignore
|
||||
)
|
||||
bootstrap_address = random.choice(list(answers.addresses()))
|
||||
if client and not isinstance(
|
||||
client, dns.quic.AsyncQuicConnection
|
||||
): # pyright: ignore
|
||||
raise ValueError("client parameter must be a dns.quic.AsyncQuicConnection.")
|
||||
assert client is None or isinstance(client, dns.quic.AsyncQuicConnection)
|
||||
return await _http3(
|
||||
q,
|
||||
bootstrap_address,
|
||||
url,
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
one_rr_per_rrset,
|
||||
ignore_trailing,
|
||||
verify=verify,
|
||||
post=post,
|
||||
connection=client,
|
||||
)
|
||||
|
||||
if not have_doh:
|
||||
raise NoDOH # pragma: no cover
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
if client and not isinstance(client, httpx.AsyncClient): # pyright: ignore
|
||||
raise ValueError("client parameter must be an httpx.AsyncClient")
|
||||
# pylint: enable=possibly-used-before-assignment
|
||||
|
||||
wire = q.to_wire()
|
||||
headers = {"accept": "application/dns-message"}
|
||||
|
||||
h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
|
||||
h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
|
||||
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
|
||||
if source is None:
|
||||
local_address = None
|
||||
local_port = 0
|
||||
else:
|
||||
local_address = source
|
||||
local_port = source_port
|
||||
|
||||
if client:
|
||||
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
|
||||
else:
|
||||
transport = backend.get_transport_class()(
|
||||
local_address=local_address,
|
||||
http1=h1,
|
||||
http2=h2,
|
||||
verify=verify,
|
||||
local_port=local_port,
|
||||
bootstrap_address=bootstrap_address,
|
||||
resolver=resolver,
|
||||
family=family,
|
||||
)
|
||||
|
||||
cm = httpx.AsyncClient( # pyright: ignore
|
||||
http1=h1, http2=h2, verify=verify, transport=transport # type: ignore
|
||||
)
|
||||
|
||||
async with cm as the_client:
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
|
||||
# GET and POST examples
|
||||
if post:
|
||||
headers.update(
|
||||
{
|
||||
"content-type": "application/dns-message",
|
||||
"content-length": str(len(wire)),
|
||||
}
|
||||
)
|
||||
response = await backend.wait_for(
|
||||
the_client.post( # pyright: ignore
|
||||
url,
|
||||
headers=headers,
|
||||
content=wire,
|
||||
extensions=extensions,
|
||||
),
|
||||
timeout,
|
||||
)
|
||||
else:
|
||||
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
|
||||
twire = wire.decode() # httpx does a repr() if we give it bytes
|
||||
response = await backend.wait_for(
|
||||
the_client.get( # pyright: ignore
|
||||
url,
|
||||
headers=headers,
|
||||
params={"dns": twire},
|
||||
extensions=extensions,
|
||||
),
|
||||
timeout,
|
||||
)
|
||||
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
|
||||
# status codes
|
||||
if response.status_code < 200 or response.status_code > 299:
|
||||
raise ValueError(
|
||||
f"{where} responded with status code {response.status_code}"
|
||||
f"\nResponse body: {response.content!r}"
|
||||
)
|
||||
r = dns.message.from_wire(
|
||||
response.content,
|
||||
keyring=q.keyring,
|
||||
request_mac=q.request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
)
|
||||
r.time = response.elapsed.total_seconds()
|
||||
if not q.is_response(r):
|
||||
raise BadResponse
|
||||
return r
|
||||
|
||||
|
||||
async def _http3(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
url: str,
|
||||
timeout: float | None = None,
|
||||
port: int = 443,
|
||||
source: str | None = None,
|
||||
source_port: int = 0,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
verify: bool | str | ssl.SSLContext = True,
|
||||
backend: dns.asyncbackend.Backend | None = None,
|
||||
post: bool = True,
|
||||
connection: dns.quic.AsyncQuicConnection | None = None,
|
||||
) -> dns.message.Message:
|
||||
if not dns.quic.have_quic:
|
||||
raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
|
||||
|
||||
url_parts = urllib.parse.urlparse(url)
|
||||
hostname = url_parts.hostname
|
||||
assert hostname is not None
|
||||
if url_parts.port is not None:
|
||||
port = url_parts.port
|
||||
|
||||
q.id = 0
|
||||
wire = q.to_wire()
|
||||
the_connection: dns.quic.AsyncQuicConnection
|
||||
if connection:
|
||||
cfactory = dns.quic.null_factory
|
||||
mfactory = dns.quic.null_factory
|
||||
else:
|
||||
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
|
||||
|
||||
async with cfactory() as context:
|
||||
async with mfactory(
|
||||
context, verify_mode=verify, server_name=hostname, h3=True
|
||||
) as the_manager:
|
||||
if connection:
|
||||
the_connection = connection
|
||||
else:
|
||||
the_connection = the_manager.connect( # pyright: ignore
|
||||
where, port, source, source_port
|
||||
)
|
||||
(start, expiration) = _compute_times(timeout)
|
||||
stream = await the_connection.make_stream(timeout) # pyright: ignore
|
||||
async with stream:
|
||||
# note that send_h3() does not need await
|
||||
stream.send_h3(url, wire, post)
|
||||
wire = await stream.receive(_remaining(expiration))
|
||||
_check_status(stream.headers(), where, wire)
|
||||
finish = time.time()
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=q.keyring,
|
||||
request_mac=q.request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
)
|
||||
r.time = max(finish - start, 0.0)
|
||||
if not q.is_response(r):
|
||||
raise BadResponse
|
||||
return r
|
||||
|
||||
|
||||
async def quic(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: float | None = None,
|
||||
port: int = 853,
|
||||
source: str | None = None,
|
||||
source_port: int = 0,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
connection: dns.quic.AsyncQuicConnection | None = None,
|
||||
verify: bool | str = True,
|
||||
backend: dns.asyncbackend.Backend | None = None,
|
||||
hostname: str | None = None,
|
||||
server_hostname: str | None = None,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending an asynchronous query via
|
||||
DNS-over-QUIC.
|
||||
|
||||
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
|
||||
the default, then dnspython will use the default backend.
|
||||
|
||||
See :py:func:`dns.query.quic()` for the documentation of the other
|
||||
parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
|
||||
if not dns.quic.have_quic:
|
||||
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
|
||||
|
||||
if server_hostname is not None and hostname is None:
|
||||
hostname = server_hostname
|
||||
|
||||
q.id = 0
|
||||
wire = q.to_wire()
|
||||
the_connection: dns.quic.AsyncQuicConnection
|
||||
if connection:
|
||||
cfactory = dns.quic.null_factory
|
||||
mfactory = dns.quic.null_factory
|
||||
the_connection = connection
|
||||
else:
|
||||
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
|
||||
|
||||
async with cfactory() as context:
|
||||
async with mfactory(
|
||||
context,
|
||||
verify_mode=verify,
|
||||
server_name=server_hostname,
|
||||
) as the_manager:
|
||||
if not connection:
|
||||
the_connection = the_manager.connect( # pyright: ignore
|
||||
where, port, source, source_port
|
||||
)
|
||||
(start, expiration) = _compute_times(timeout)
|
||||
stream = await the_connection.make_stream(timeout) # pyright: ignore
|
||||
async with stream:
|
||||
await stream.send(wire, True)
|
||||
wire = await stream.receive(_remaining(expiration))
|
||||
finish = time.time()
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=q.keyring,
|
||||
request_mac=q.request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
)
|
||||
r.time = max(finish - start, 0.0)
|
||||
if not q.is_response(r):
|
||||
raise BadResponse
|
||||
return r
|
||||
|
||||
|
||||
async def _inbound_xfr(
|
||||
txn_manager: dns.transaction.TransactionManager,
|
||||
s: dns.asyncbackend.Socket,
|
||||
query: dns.message.Message,
|
||||
serial: int | None,
|
||||
timeout: float | None,
|
||||
expiration: float,
|
||||
) -> Any:
|
||||
"""Given a socket, does the zone transfer."""
|
||||
rdtype = query.question[0].rdtype
|
||||
is_ixfr = rdtype == dns.rdatatype.IXFR
|
||||
origin = txn_manager.from_wire_origin()
|
||||
wire = query.to_wire()
|
||||
is_udp = s.type == socket.SOCK_DGRAM
|
||||
if is_udp:
|
||||
udp_sock = cast(dns.asyncbackend.DatagramSocket, s)
|
||||
await udp_sock.sendto(wire, None, _timeout(expiration))
|
||||
else:
|
||||
tcp_sock = cast(dns.asyncbackend.StreamSocket, s)
|
||||
tcpmsg = struct.pack("!H", len(wire)) + wire
|
||||
await tcp_sock.sendall(tcpmsg, expiration)
|
||||
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
|
||||
done = False
|
||||
tsig_ctx = None
|
||||
r: dns.message.Message | None = None
|
||||
while not done:
|
||||
(_, mexpiration) = _compute_times(timeout)
|
||||
if mexpiration is None or (
|
||||
expiration is not None and mexpiration > expiration
|
||||
):
|
||||
mexpiration = expiration
|
||||
if is_udp:
|
||||
timeout = _timeout(mexpiration)
|
||||
(rwire, _) = await udp_sock.recvfrom(65535, timeout) # pyright: ignore
|
||||
else:
|
||||
ldata = await _read_exactly(tcp_sock, 2, mexpiration) # pyright: ignore
|
||||
(l,) = struct.unpack("!H", ldata)
|
||||
rwire = await _read_exactly(tcp_sock, l, mexpiration) # pyright: ignore
|
||||
r = dns.message.from_wire(
|
||||
rwire,
|
||||
keyring=query.keyring,
|
||||
request_mac=query.mac,
|
||||
xfr=True,
|
||||
origin=origin,
|
||||
tsig_ctx=tsig_ctx,
|
||||
multi=(not is_udp),
|
||||
one_rr_per_rrset=is_ixfr,
|
||||
)
|
||||
done = inbound.process_message(r)
|
||||
yield r
|
||||
tsig_ctx = r.tsig_ctx
|
||||
if query.keyring and r is not None and not r.had_tsig:
|
||||
raise dns.exception.FormError("missing TSIG")
|
||||
|
||||
|
||||
async def inbound_xfr(
|
||||
where: str,
|
||||
txn_manager: dns.transaction.TransactionManager,
|
||||
query: dns.message.Message | None = None,
|
||||
port: int = 53,
|
||||
timeout: float | None = None,
|
||||
lifetime: float | None = None,
|
||||
source: str | None = None,
|
||||
source_port: int = 0,
|
||||
udp_mode: UDPMode = UDPMode.NEVER,
|
||||
backend: dns.asyncbackend.Backend | None = None,
|
||||
) -> None:
|
||||
"""Conduct an inbound transfer and apply it via a transaction from the
|
||||
txn_manager.
|
||||
|
||||
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
|
||||
the default, then dnspython will use the default backend.
|
||||
|
||||
See :py:func:`dns.query.inbound_xfr()` for the documentation of
|
||||
the other parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
if query is None:
|
||||
(query, serial) = dns.xfr.make_query(txn_manager)
|
||||
else:
|
||||
serial = dns.xfr.extract_serial_from_query(query)
|
||||
af = dns.inet.af_for_address(where)
|
||||
stuple = _source_tuple(af, source, source_port)
|
||||
dtuple = (where, port)
|
||||
if not backend:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
(_, expiration) = _compute_times(lifetime)
|
||||
if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
|
||||
s = await backend.make_socket(
|
||||
af, socket.SOCK_DGRAM, 0, stuple, dtuple, _timeout(expiration)
|
||||
)
|
||||
async with s:
|
||||
try:
|
||||
async for _ in _inbound_xfr( # pyright: ignore
|
||||
txn_manager,
|
||||
s,
|
||||
query,
|
||||
serial,
|
||||
timeout,
|
||||
expiration, # pyright: ignore
|
||||
):
|
||||
pass
|
||||
return
|
||||
except dns.xfr.UseTCP:
|
||||
if udp_mode == UDPMode.ONLY:
|
||||
raise
|
||||
|
||||
s = await backend.make_socket(
|
||||
af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration)
|
||||
)
|
||||
async with s:
|
||||
async for _ in _inbound_xfr( # pyright: ignore
|
||||
txn_manager, s, query, serial, timeout, expiration # pyright: ignore
|
||||
):
|
||||
pass
|
||||
478
.venv/Lib/site-packages/dns/asyncresolver.py
Normal file
478
.venv/Lib/site-packages/dns/asyncresolver.py
Normal file
@@ -0,0 +1,478 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2003-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
"""Asynchronous DNS stub resolver."""
|
||||
|
||||
import socket
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import dns._ddr
|
||||
import dns.asyncbackend
|
||||
import dns.asyncquery
|
||||
import dns.exception
|
||||
import dns.inet
|
||||
import dns.name
|
||||
import dns.nameserver
|
||||
import dns.query
|
||||
import dns.rdataclass
|
||||
import dns.rdatatype
|
||||
import dns.resolver # lgtm[py/import-and-import-from]
|
||||
import dns.reversename
|
||||
|
||||
# import some resolver symbols for brevity
|
||||
from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute
|
||||
|
||||
# for indentation purposes below
|
||||
_udp = dns.asyncquery.udp
|
||||
_tcp = dns.asyncquery.tcp
|
||||
|
||||
|
||||
class Resolver(dns.resolver.BaseResolver):
|
||||
"""Asynchronous DNS stub resolver."""
|
||||
|
||||
async def resolve(
|
||||
self,
|
||||
qname: dns.name.Name | str,
|
||||
rdtype: dns.rdatatype.RdataType | str = dns.rdatatype.A,
|
||||
rdclass: dns.rdataclass.RdataClass | str = dns.rdataclass.IN,
|
||||
tcp: bool = False,
|
||||
source: str | None = None,
|
||||
raise_on_no_answer: bool = True,
|
||||
source_port: int = 0,
|
||||
lifetime: float | None = None,
|
||||
search: bool | None = None,
|
||||
backend: dns.asyncbackend.Backend | None = None,
|
||||
) -> dns.resolver.Answer:
|
||||
"""Query nameservers asynchronously to find the answer to the question.
|
||||
|
||||
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
|
||||
the default, then dnspython will use the default backend.
|
||||
|
||||
See :py:func:`dns.resolver.Resolver.resolve()` for the
|
||||
documentation of the other parameters, exceptions, and return
|
||||
type of this method.
|
||||
"""
|
||||
|
||||
resolution = dns.resolver._Resolution(
|
||||
self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search
|
||||
)
|
||||
if not backend:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
start = time.time()
|
||||
while True:
|
||||
(request, answer) = resolution.next_request()
|
||||
# Note we need to say "if answer is not None" and not just
|
||||
# "if answer" because answer implements __len__, and python
|
||||
# will call that. We want to return if we have an answer
|
||||
# object, including in cases where its length is 0.
|
||||
if answer is not None:
|
||||
# cache hit!
|
||||
return answer
|
||||
assert request is not None # needed for type checking
|
||||
done = False
|
||||
while not done:
|
||||
(nameserver, tcp, backoff) = resolution.next_nameserver()
|
||||
if backoff:
|
||||
await backend.sleep(backoff)
|
||||
timeout = self._compute_timeout(start, lifetime, resolution.errors)
|
||||
try:
|
||||
response = await nameserver.async_query(
|
||||
request,
|
||||
timeout=timeout,
|
||||
source=source,
|
||||
source_port=source_port,
|
||||
max_size=tcp,
|
||||
backend=backend,
|
||||
)
|
||||
except Exception as ex:
|
||||
(_, done) = resolution.query_result(None, ex)
|
||||
continue
|
||||
(answer, done) = resolution.query_result(response, None)
|
||||
# Note we need to say "if answer is not None" and not just
|
||||
# "if answer" because answer implements __len__, and python
|
||||
# will call that. We want to return if we have an answer
|
||||
# object, including in cases where its length is 0.
|
||||
if answer is not None:
|
||||
return answer
|
||||
|
||||
async def resolve_address(
|
||||
self, ipaddr: str, *args: Any, **kwargs: Any
|
||||
) -> dns.resolver.Answer:
|
||||
"""Use an asynchronous resolver to run a reverse query for PTR
|
||||
records.
|
||||
|
||||
This utilizes the resolve() method to perform a PTR lookup on the
|
||||
specified IP address.
|
||||
|
||||
*ipaddr*, a ``str``, the IPv4 or IPv6 address you want to get
|
||||
the PTR record for.
|
||||
|
||||
All other arguments that can be passed to the resolve() function
|
||||
except for rdtype and rdclass are also supported by this
|
||||
function.
|
||||
|
||||
"""
|
||||
# We make a modified kwargs for type checking happiness, as otherwise
|
||||
# we get a legit warning about possibly having rdtype and rdclass
|
||||
# in the kwargs more than once.
|
||||
modified_kwargs: Dict[str, Any] = {}
|
||||
modified_kwargs.update(kwargs)
|
||||
modified_kwargs["rdtype"] = dns.rdatatype.PTR
|
||||
modified_kwargs["rdclass"] = dns.rdataclass.IN
|
||||
return await self.resolve(
|
||||
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
|
||||
)
|
||||
|
||||
async def resolve_name(
|
||||
self,
|
||||
name: dns.name.Name | str,
|
||||
family: int = socket.AF_UNSPEC,
|
||||
**kwargs: Any,
|
||||
) -> dns.resolver.HostAnswers:
|
||||
"""Use an asynchronous resolver to query for address records.
|
||||
|
||||
This utilizes the resolve() method to perform A and/or AAAA lookups on
|
||||
the specified name.
|
||||
|
||||
*qname*, a ``dns.name.Name`` or ``str``, the name to resolve.
|
||||
|
||||
*family*, an ``int``, the address family. If socket.AF_UNSPEC
|
||||
(the default), both A and AAAA records will be retrieved.
|
||||
|
||||
All other arguments that can be passed to the resolve() function
|
||||
except for rdtype and rdclass are also supported by this
|
||||
function.
|
||||
"""
|
||||
# We make a modified kwargs for type checking happiness, as otherwise
|
||||
# we get a legit warning about possibly having rdtype and rdclass
|
||||
# in the kwargs more than once.
|
||||
modified_kwargs: Dict[str, Any] = {}
|
||||
modified_kwargs.update(kwargs)
|
||||
modified_kwargs.pop("rdtype", None)
|
||||
modified_kwargs["rdclass"] = dns.rdataclass.IN
|
||||
|
||||
if family == socket.AF_INET:
|
||||
v4 = await self.resolve(name, dns.rdatatype.A, **modified_kwargs)
|
||||
return dns.resolver.HostAnswers.make(v4=v4)
|
||||
elif family == socket.AF_INET6:
|
||||
v6 = await self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
|
||||
return dns.resolver.HostAnswers.make(v6=v6)
|
||||
elif family != socket.AF_UNSPEC:
|
||||
raise NotImplementedError(f"unknown address family {family}")
|
||||
|
||||
raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
|
||||
lifetime = modified_kwargs.pop("lifetime", None)
|
||||
start = time.time()
|
||||
v6 = await self.resolve(
|
||||
name,
|
||||
dns.rdatatype.AAAA,
|
||||
raise_on_no_answer=False,
|
||||
lifetime=self._compute_timeout(start, lifetime),
|
||||
**modified_kwargs,
|
||||
)
|
||||
# Note that setting name ensures we query the same name
|
||||
# for A as we did for AAAA. (This is just in case search lists
|
||||
# are active by default in the resolver configuration and
|
||||
# we might be talking to a server that says NXDOMAIN when it
|
||||
# wants to say NOERROR no data.
|
||||
name = v6.qname
|
||||
v4 = await self.resolve(
|
||||
name,
|
||||
dns.rdatatype.A,
|
||||
raise_on_no_answer=False,
|
||||
lifetime=self._compute_timeout(start, lifetime),
|
||||
**modified_kwargs,
|
||||
)
|
||||
answers = dns.resolver.HostAnswers.make(
|
||||
v6=v6, v4=v4, add_empty=not raise_on_no_answer
|
||||
)
|
||||
if not answers:
|
||||
raise NoAnswer(response=v6.response)
|
||||
return answers
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
async def canonical_name(self, name: dns.name.Name | str) -> dns.name.Name:
|
||||
"""Determine the canonical name of *name*.
|
||||
|
||||
The canonical name is the name the resolver uses for queries
|
||||
after all CNAME and DNAME renamings have been applied.
|
||||
|
||||
*name*, a ``dns.name.Name`` or ``str``, the query name.
|
||||
|
||||
This method can raise any exception that ``resolve()`` can
|
||||
raise, other than ``dns.resolver.NoAnswer`` and
|
||||
``dns.resolver.NXDOMAIN``.
|
||||
|
||||
Returns a ``dns.name.Name``.
|
||||
"""
|
||||
try:
|
||||
answer = await self.resolve(name, raise_on_no_answer=False)
|
||||
canonical_name = answer.canonical_name
|
||||
except dns.resolver.NXDOMAIN as e:
|
||||
canonical_name = e.canonical_name
|
||||
return canonical_name
|
||||
|
||||
async def try_ddr(self, lifetime: float = 5.0) -> None:
|
||||
"""Try to update the resolver's nameservers using Discovery of Designated
|
||||
Resolvers (DDR). If successful, the resolver will subsequently use
|
||||
DNS-over-HTTPS or DNS-over-TLS for future queries.
|
||||
|
||||
*lifetime*, a float, is the maximum time to spend attempting DDR. The default
|
||||
is 5 seconds.
|
||||
|
||||
If the SVCB query is successful and results in a non-empty list of nameservers,
|
||||
then the resolver's nameservers are set to the returned servers in priority
|
||||
order.
|
||||
|
||||
The current implementation does not use any address hints from the SVCB record,
|
||||
nor does it resolve addresses for the SCVB target name, rather it assumes that
|
||||
the bootstrap nameserver will always be one of the addresses and uses it.
|
||||
A future revision to the code may offer fuller support. The code verifies that
|
||||
the bootstrap nameserver is in the Subject Alternative Name field of the
|
||||
TLS certficate.
|
||||
"""
|
||||
try:
|
||||
expiration = time.time() + lifetime
|
||||
answer = await self.resolve(
|
||||
dns._ddr._local_resolver_name, "svcb", lifetime=lifetime
|
||||
)
|
||||
timeout = dns.query._remaining(expiration)
|
||||
nameservers = await dns._ddr._get_nameservers_async(answer, timeout)
|
||||
if len(nameservers) > 0:
|
||||
self.nameservers = nameservers
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
default_resolver = None
|
||||
|
||||
|
||||
def get_default_resolver() -> Resolver:
|
||||
"""Get the default asynchronous resolver, initializing it if necessary."""
|
||||
if default_resolver is None:
|
||||
reset_default_resolver()
|
||||
assert default_resolver is not None
|
||||
return default_resolver
|
||||
|
||||
|
||||
def reset_default_resolver() -> None:
|
||||
"""Re-initialize default asynchronous resolver.
|
||||
|
||||
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
|
||||
systems) will be re-read immediately.
|
||||
"""
|
||||
|
||||
global default_resolver
|
||||
default_resolver = Resolver()
|
||||
|
||||
|
||||
async def resolve(
|
||||
qname: dns.name.Name | str,
|
||||
rdtype: dns.rdatatype.RdataType | str = dns.rdatatype.A,
|
||||
rdclass: dns.rdataclass.RdataClass | str = dns.rdataclass.IN,
|
||||
tcp: bool = False,
|
||||
source: str | None = None,
|
||||
raise_on_no_answer: bool = True,
|
||||
source_port: int = 0,
|
||||
lifetime: float | None = None,
|
||||
search: bool | None = None,
|
||||
backend: dns.asyncbackend.Backend | None = None,
|
||||
) -> dns.resolver.Answer:
|
||||
"""Query nameservers asynchronously to find the answer to the question.
|
||||
|
||||
This is a convenience function that uses the default resolver
|
||||
object to make the query.
|
||||
|
||||
See :py:func:`dns.asyncresolver.Resolver.resolve` for more
|
||||
information on the parameters.
|
||||
"""
|
||||
|
||||
return await get_default_resolver().resolve(
|
||||
qname,
|
||||
rdtype,
|
||||
rdclass,
|
||||
tcp,
|
||||
source,
|
||||
raise_on_no_answer,
|
||||
source_port,
|
||||
lifetime,
|
||||
search,
|
||||
backend,
|
||||
)
|
||||
|
||||
|
||||
async def resolve_address(
|
||||
ipaddr: str, *args: Any, **kwargs: Any
|
||||
) -> dns.resolver.Answer:
|
||||
"""Use a resolver to run a reverse query for PTR records.
|
||||
|
||||
See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more
|
||||
information on the parameters.
|
||||
"""
|
||||
|
||||
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
|
||||
|
||||
|
||||
async def resolve_name(
|
||||
name: dns.name.Name | str, family: int = socket.AF_UNSPEC, **kwargs: Any
|
||||
) -> dns.resolver.HostAnswers:
|
||||
"""Use a resolver to asynchronously query for address records.
|
||||
|
||||
See :py:func:`dns.asyncresolver.Resolver.resolve_name` for more
|
||||
information on the parameters.
|
||||
"""
|
||||
|
||||
return await get_default_resolver().resolve_name(name, family, **kwargs)
|
||||
|
||||
|
||||
async def canonical_name(name: dns.name.Name | str) -> dns.name.Name:
|
||||
"""Determine the canonical name of *name*.
|
||||
|
||||
See :py:func:`dns.resolver.Resolver.canonical_name` for more
|
||||
information on the parameters and possible exceptions.
|
||||
"""
|
||||
|
||||
return await get_default_resolver().canonical_name(name)
|
||||
|
||||
|
||||
async def try_ddr(timeout: float = 5.0) -> None:
|
||||
"""Try to update the default resolver's nameservers using Discovery of Designated
|
||||
Resolvers (DDR). If successful, the resolver will subsequently use
|
||||
DNS-over-HTTPS or DNS-over-TLS for future queries.
|
||||
|
||||
See :py:func:`dns.resolver.Resolver.try_ddr` for more information.
|
||||
"""
|
||||
return await get_default_resolver().try_ddr(timeout)
|
||||
|
||||
|
||||
async def zone_for_name(
|
||||
name: dns.name.Name | str,
|
||||
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
|
||||
tcp: bool = False,
|
||||
resolver: Resolver | None = None,
|
||||
backend: dns.asyncbackend.Backend | None = None,
|
||||
) -> dns.name.Name:
|
||||
"""Find the name of the zone which contains the specified name.
|
||||
|
||||
See :py:func:`dns.resolver.Resolver.zone_for_name` for more
|
||||
information on the parameters and possible exceptions.
|
||||
"""
|
||||
|
||||
if isinstance(name, str):
|
||||
name = dns.name.from_text(name, dns.name.root)
|
||||
if resolver is None:
|
||||
resolver = get_default_resolver()
|
||||
if not name.is_absolute():
|
||||
raise NotAbsolute(name)
|
||||
while True:
|
||||
try:
|
||||
answer = await resolver.resolve(
|
||||
name, dns.rdatatype.SOA, rdclass, tcp, backend=backend
|
||||
)
|
||||
assert answer.rrset is not None
|
||||
if answer.rrset.name == name:
|
||||
return name
|
||||
# otherwise we were CNAMEd or DNAMEd and need to look higher
|
||||
except (NXDOMAIN, NoAnswer):
|
||||
pass
|
||||
try:
|
||||
name = name.parent()
|
||||
except dns.name.NoParent: # pragma: no cover
|
||||
raise NoRootSOA
|
||||
|
||||
|
||||
async def make_resolver_at(
|
||||
where: dns.name.Name | str,
|
||||
port: int = 53,
|
||||
family: int = socket.AF_UNSPEC,
|
||||
resolver: Resolver | None = None,
|
||||
) -> Resolver:
|
||||
"""Make a stub resolver using the specified destination as the full resolver.
|
||||
|
||||
*where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the
|
||||
full resolver.
|
||||
|
||||
*port*, an ``int``, the port to use. If not specified, the default is 53.
|
||||
|
||||
*family*, an ``int``, the address family to use. This parameter is used if
|
||||
*where* is not an address. The default is ``socket.AF_UNSPEC`` in which case
|
||||
the first address returned by ``resolve_name()`` will be used, otherwise the
|
||||
first address of the specified family will be used.
|
||||
|
||||
*resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the resolver to use for
|
||||
resolution of hostnames. If not specified, the default resolver will be used.
|
||||
|
||||
Returns a ``dns.resolver.Resolver`` or raises an exception.
|
||||
"""
|
||||
if resolver is None:
|
||||
resolver = get_default_resolver()
|
||||
nameservers: List[str | dns.nameserver.Nameserver] = []
|
||||
if isinstance(where, str) and dns.inet.is_address(where):
|
||||
nameservers.append(dns.nameserver.Do53Nameserver(where, port))
|
||||
else:
|
||||
answers = await resolver.resolve_name(where, family)
|
||||
for address in answers.addresses():
|
||||
nameservers.append(dns.nameserver.Do53Nameserver(address, port))
|
||||
res = Resolver(configure=False)
|
||||
res.nameservers = nameservers
|
||||
return res
|
||||
|
||||
|
||||
async def resolve_at(
|
||||
where: dns.name.Name | str,
|
||||
qname: dns.name.Name | str,
|
||||
rdtype: dns.rdatatype.RdataType | str = dns.rdatatype.A,
|
||||
rdclass: dns.rdataclass.RdataClass | str = dns.rdataclass.IN,
|
||||
tcp: bool = False,
|
||||
source: str | None = None,
|
||||
raise_on_no_answer: bool = True,
|
||||
source_port: int = 0,
|
||||
lifetime: float | None = None,
|
||||
search: bool | None = None,
|
||||
backend: dns.asyncbackend.Backend | None = None,
|
||||
port: int = 53,
|
||||
family: int = socket.AF_UNSPEC,
|
||||
resolver: Resolver | None = None,
|
||||
) -> dns.resolver.Answer:
|
||||
"""Query nameservers to find the answer to the question.
|
||||
|
||||
This is a convenience function that calls ``dns.asyncresolver.make_resolver_at()``
|
||||
to make a resolver, and then uses it to resolve the query.
|
||||
|
||||
See ``dns.asyncresolver.Resolver.resolve`` for more information on the resolution
|
||||
parameters, and ``dns.asyncresolver.make_resolver_at`` for information about the
|
||||
resolver parameters *where*, *port*, *family*, and *resolver*.
|
||||
|
||||
If making more than one query, it is more efficient to call
|
||||
``dns.asyncresolver.make_resolver_at()`` and then use that resolver for the queries
|
||||
instead of calling ``resolve_at()`` multiple times.
|
||||
"""
|
||||
res = await make_resolver_at(where, port, family, resolver)
|
||||
return await res.resolve(
|
||||
qname,
|
||||
rdtype,
|
||||
rdclass,
|
||||
tcp,
|
||||
source,
|
||||
raise_on_no_answer,
|
||||
source_port,
|
||||
lifetime,
|
||||
search,
|
||||
backend,
|
||||
)
|
||||
850
.venv/Lib/site-packages/dns/btree.py
Normal file
850
.venv/Lib/site-packages/dns/btree.py
Normal file
@@ -0,0 +1,850 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
"""
|
||||
A BTree in the style of Cormen, Leiserson, and Rivest's "Algorithms" book, with
|
||||
copy-on-write node updates, cursors, and optional space optimization for mostly-in-order
|
||||
insertion.
|
||||
"""
|
||||
|
||||
from collections.abc import MutableMapping, MutableSet
|
||||
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, cast
|
||||
|
||||
DEFAULT_T = 127
|
||||
|
||||
KT = TypeVar("KT") # the type of a key in Element
|
||||
|
||||
|
||||
class Element(Generic[KT]):
|
||||
"""All items stored in the BTree are Elements."""
|
||||
|
||||
def key(self) -> KT:
|
||||
"""The key for this element; the returned type must implement comparison."""
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
|
||||
ET = TypeVar("ET", bound=Element) # the type of a value in a _KV
|
||||
|
||||
|
||||
def _MIN(t: int) -> int:
|
||||
"""The minimum number of keys in a non-root node for a BTree with the specified
|
||||
``t``
|
||||
"""
|
||||
return t - 1
|
||||
|
||||
|
||||
def _MAX(t: int) -> int:
|
||||
"""The maximum number of keys in node for a BTree with the specified ``t``"""
|
||||
return 2 * t - 1
|
||||
|
||||
|
||||
class _Creator:
|
||||
"""A _Creator class instance is used as a unique id for the BTree which created
|
||||
a node.
|
||||
|
||||
We use a dedicated creator rather than just a BTree reference to avoid circularity
|
||||
that would complicate GC.
|
||||
"""
|
||||
|
||||
def __str__(self): # pragma: no cover
|
||||
return f"{id(self):x}"
|
||||
|
||||
|
||||
class _Node(Generic[KT, ET]):
|
||||
"""A Node in the BTree.
|
||||
|
||||
A Node (leaf or internal) of the BTree.
|
||||
"""
|
||||
|
||||
__slots__ = ["t", "creator", "is_leaf", "elts", "children"]
|
||||
|
||||
def __init__(self, t: int, creator: _Creator, is_leaf: bool):
|
||||
assert t >= 3
|
||||
self.t = t
|
||||
self.creator = creator
|
||||
self.is_leaf = is_leaf
|
||||
self.elts: list[ET] = []
|
||||
self.children: list[_Node[KT, ET]] = []
|
||||
|
||||
def is_maximal(self) -> bool:
|
||||
"""Does this node have the maximal number of keys?"""
|
||||
assert len(self.elts) <= _MAX(self.t)
|
||||
return len(self.elts) == _MAX(self.t)
|
||||
|
||||
def is_minimal(self) -> bool:
|
||||
"""Does this node have the minimal number of keys?"""
|
||||
assert len(self.elts) >= _MIN(self.t)
|
||||
return len(self.elts) == _MIN(self.t)
|
||||
|
||||
def search_in_node(self, key: KT) -> tuple[int, bool]:
|
||||
"""Get the index of the ``Element`` matching ``key`` or the index of its
|
||||
least successor.
|
||||
|
||||
Returns a tuple of the index and an ``equal`` boolean that is ``True`` iff.
|
||||
the key was found.
|
||||
"""
|
||||
l = len(self.elts)
|
||||
if l > 0 and key > self.elts[l - 1].key():
|
||||
# This is optimizing near in-order insertion.
|
||||
return l, False
|
||||
l = 0
|
||||
i = len(self.elts)
|
||||
r = i - 1
|
||||
equal = False
|
||||
while l <= r:
|
||||
m = (l + r) // 2
|
||||
k = self.elts[m].key()
|
||||
if key == k:
|
||||
i = m
|
||||
equal = True
|
||||
break
|
||||
elif key < k:
|
||||
i = m
|
||||
r = m - 1
|
||||
else:
|
||||
l = m + 1
|
||||
return i, equal
|
||||
|
||||
def maybe_cow_child(self, index: int) -> "_Node[KT, ET]":
|
||||
assert not self.is_leaf
|
||||
child = self.children[index]
|
||||
cloned = child.maybe_cow(self.creator)
|
||||
if cloned:
|
||||
self.children[index] = cloned
|
||||
return cloned
|
||||
else:
|
||||
return child
|
||||
|
||||
def _get_node(self, key: KT) -> Tuple[Optional["_Node[KT, ET]"], int]:
|
||||
"""Get the node associated with key and its index, doing
|
||||
copy-on-write if we have to descend.
|
||||
|
||||
Returns a tuple of the node and the index, or the tuple ``(None, 0)``
|
||||
if the key was not found.
|
||||
"""
|
||||
i, equal = self.search_in_node(key)
|
||||
if equal:
|
||||
return (self, i)
|
||||
elif self.is_leaf:
|
||||
return (None, 0)
|
||||
else:
|
||||
child = self.maybe_cow_child(i)
|
||||
return child._get_node(key)
|
||||
|
||||
def get(self, key: KT) -> ET | None:
|
||||
"""Get the element associated with *key* or return ``None``"""
|
||||
i, equal = self.search_in_node(key)
|
||||
if equal:
|
||||
return self.elts[i]
|
||||
elif self.is_leaf:
|
||||
return None
|
||||
else:
|
||||
return self.children[i].get(key)
|
||||
|
||||
def optimize_in_order_insertion(self, index: int) -> None:
|
||||
"""Try to minimize the number of Nodes in a BTree where the insertion
|
||||
is done in-order or close to it, by stealing as much as we can from our
|
||||
right sibling.
|
||||
|
||||
If we don't do this, then an in-order insertion will produce a BTree
|
||||
where most of the nodes are minimal.
|
||||
"""
|
||||
if index == 0:
|
||||
return
|
||||
left = self.children[index - 1]
|
||||
if len(left.elts) == _MAX(self.t):
|
||||
return
|
||||
left = self.maybe_cow_child(index - 1)
|
||||
while len(left.elts) < _MAX(self.t):
|
||||
if not left.try_right_steal(self, index - 1):
|
||||
break
|
||||
|
||||
def insert_nonfull(self, element: ET, in_order: bool) -> ET | None:
|
||||
assert not self.is_maximal()
|
||||
while True:
|
||||
key = element.key()
|
||||
i, equal = self.search_in_node(key)
|
||||
if equal:
|
||||
# replace
|
||||
old = self.elts[i]
|
||||
self.elts[i] = element
|
||||
return old
|
||||
elif self.is_leaf:
|
||||
self.elts.insert(i, element)
|
||||
return None
|
||||
else:
|
||||
child = self.maybe_cow_child(i)
|
||||
if child.is_maximal():
|
||||
self.adopt(*child.split())
|
||||
# Splitting might result in our target moving to us, so
|
||||
# search again.
|
||||
continue
|
||||
oelt = child.insert_nonfull(element, in_order)
|
||||
if in_order:
|
||||
self.optimize_in_order_insertion(i)
|
||||
return oelt
|
||||
|
||||
def split(self) -> tuple["_Node[KT, ET]", ET, "_Node[KT, ET]"]:
|
||||
"""Split a maximal node into two minimal ones and a central element."""
|
||||
assert self.is_maximal()
|
||||
right = self.__class__(self.t, self.creator, self.is_leaf)
|
||||
right.elts = list(self.elts[_MIN(self.t) + 1 :])
|
||||
middle = self.elts[_MIN(self.t)]
|
||||
self.elts = list(self.elts[: _MIN(self.t)])
|
||||
if not self.is_leaf:
|
||||
right.children = list(self.children[_MIN(self.t) + 1 :])
|
||||
self.children = list(self.children[: _MIN(self.t) + 1])
|
||||
return self, middle, right
|
||||
|
||||
def try_left_steal(self, parent: "_Node[KT, ET]", index: int) -> bool:
|
||||
"""Try to steal from this Node's left sibling for balancing purposes.
|
||||
|
||||
Returns ``True`` if the theft was successful, or ``False`` if not.
|
||||
"""
|
||||
if index != 0:
|
||||
left = parent.children[index - 1]
|
||||
if not left.is_minimal():
|
||||
left = parent.maybe_cow_child(index - 1)
|
||||
elt = parent.elts[index - 1]
|
||||
parent.elts[index - 1] = left.elts.pop()
|
||||
self.elts.insert(0, elt)
|
||||
if not left.is_leaf:
|
||||
assert not self.is_leaf
|
||||
child = left.children.pop()
|
||||
self.children.insert(0, child)
|
||||
return True
|
||||
return False
|
||||
|
||||
def try_right_steal(self, parent: "_Node[KT, ET]", index: int) -> bool:
|
||||
"""Try to steal from this Node's right sibling for balancing purposes.
|
||||
|
||||
Returns ``True`` if the theft was successful, or ``False`` if not.
|
||||
"""
|
||||
if index + 1 < len(parent.children):
|
||||
right = parent.children[index + 1]
|
||||
if not right.is_minimal():
|
||||
right = parent.maybe_cow_child(index + 1)
|
||||
elt = parent.elts[index]
|
||||
parent.elts[index] = right.elts.pop(0)
|
||||
self.elts.append(elt)
|
||||
if not right.is_leaf:
|
||||
assert not self.is_leaf
|
||||
child = right.children.pop(0)
|
||||
self.children.append(child)
|
||||
return True
|
||||
return False
|
||||
|
||||
def adopt(self, left: "_Node[KT, ET]", middle: ET, right: "_Node[KT, ET]") -> None:
|
||||
"""Adopt left, middle, and right into our Node (which must not be maximal,
|
||||
and which must not be a leaf). In the case were we are not the new root,
|
||||
then the left child must already be in the Node."""
|
||||
assert not self.is_maximal()
|
||||
assert not self.is_leaf
|
||||
key = middle.key()
|
||||
i, equal = self.search_in_node(key)
|
||||
assert not equal
|
||||
self.elts.insert(i, middle)
|
||||
if len(self.children) == 0:
|
||||
# We are the new root
|
||||
self.children = [left, right]
|
||||
else:
|
||||
assert self.children[i] == left
|
||||
self.children.insert(i + 1, right)
|
||||
|
||||
def merge(self, parent: "_Node[KT, ET]", index: int) -> None:
|
||||
"""Merge this node's parent and its right sibling into this node."""
|
||||
right = parent.children.pop(index + 1)
|
||||
self.elts.append(parent.elts.pop(index))
|
||||
self.elts.extend(right.elts)
|
||||
if not self.is_leaf:
|
||||
self.children.extend(right.children)
|
||||
|
||||
def minimum(self) -> ET:
|
||||
"""The least element in this subtree."""
|
||||
if self.is_leaf:
|
||||
return self.elts[0]
|
||||
else:
|
||||
return self.children[0].minimum()
|
||||
|
||||
def maximum(self) -> ET:
|
||||
"""The greatest element in this subtree."""
|
||||
if self.is_leaf:
|
||||
return self.elts[-1]
|
||||
else:
|
||||
return self.children[-1].maximum()
|
||||
|
||||
def balance(self, parent: "_Node[KT, ET]", index: int) -> None:
|
||||
"""This Node is minimal, and we want to make it non-minimal so we can delete.
|
||||
We try to steal from our siblings, and if that doesn't work we will merge
|
||||
with one of them."""
|
||||
assert not parent.is_leaf
|
||||
if self.try_left_steal(parent, index):
|
||||
return
|
||||
if self.try_right_steal(parent, index):
|
||||
return
|
||||
# Stealing didn't work, so both siblings must be minimal.
|
||||
if index == 0:
|
||||
# We are the left-most node so merge with our right sibling.
|
||||
self.merge(parent, index)
|
||||
else:
|
||||
# Have our left sibling merge with us. This lets us only have "merge right"
|
||||
# code.
|
||||
left = parent.maybe_cow_child(index - 1)
|
||||
left.merge(parent, index - 1)
|
||||
|
||||
def delete(
|
||||
self, key: KT, parent: Optional["_Node[KT, ET]"], exact: ET | None
|
||||
) -> ET | None:
|
||||
"""Delete an element matching *key* if it exists. If *exact* is not ``None``
|
||||
then it must be an exact match with that element. The Node must not be
|
||||
minimal unless it is the root."""
|
||||
assert parent is None or not self.is_minimal()
|
||||
i, equal = self.search_in_node(key)
|
||||
original_key = None
|
||||
if equal:
|
||||
# Note we use "is" here as we meant "exactly this object".
|
||||
if exact is not None and self.elts[i] is not exact:
|
||||
raise ValueError("exact delete did not match existing elt")
|
||||
if self.is_leaf:
|
||||
return self.elts.pop(i)
|
||||
# Note we need to ensure exact is None going forward as we've
|
||||
# already checked exactness and are about to change our target key
|
||||
# to the least successor.
|
||||
exact = None
|
||||
original_key = key
|
||||
least_successor = self.children[i + 1].minimum()
|
||||
key = least_successor.key()
|
||||
i = i + 1
|
||||
if self.is_leaf:
|
||||
# No match
|
||||
if exact is not None:
|
||||
raise ValueError("exact delete had no match")
|
||||
return None
|
||||
# recursively delete in the appropriate child
|
||||
child = self.maybe_cow_child(i)
|
||||
if child.is_minimal():
|
||||
child.balance(self, i)
|
||||
# Things may have moved.
|
||||
i, equal = self.search_in_node(key)
|
||||
assert not equal
|
||||
child = self.children[i]
|
||||
assert not child.is_minimal()
|
||||
elt = child.delete(key, self, exact)
|
||||
if original_key is not None:
|
||||
node, i = self._get_node(original_key)
|
||||
assert node is not None
|
||||
assert elt is not None
|
||||
oelt = node.elts[i]
|
||||
node.elts[i] = elt
|
||||
elt = oelt
|
||||
return elt
|
||||
|
||||
def visit_in_order(self, visit: Callable[[ET], None]) -> None:
|
||||
"""Call *visit* on all of the elements in order."""
|
||||
for i, elt in enumerate(self.elts):
|
||||
if not self.is_leaf:
|
||||
self.children[i].visit_in_order(visit)
|
||||
visit(elt)
|
||||
if not self.is_leaf:
|
||||
self.children[-1].visit_in_order(visit)
|
||||
|
||||
def _visit_preorder_by_node(self, visit: Callable[["_Node[KT, ET]"], None]) -> None:
|
||||
"""Visit nodes in preorder. This method is only used for testing."""
|
||||
visit(self)
|
||||
if not self.is_leaf:
|
||||
for child in self.children:
|
||||
child._visit_preorder_by_node(visit)
|
||||
|
||||
def maybe_cow(self, creator: _Creator) -> Optional["_Node[KT, ET]"]:
|
||||
"""Return a clone of this Node if it was not created by *creator*, or ``None``
|
||||
otherwise (i.e. copy for copy-on-write if we haven't already copied it)."""
|
||||
if self.creator is not creator:
|
||||
return self.clone(creator)
|
||||
else:
|
||||
return None
|
||||
|
||||
def clone(self, creator: _Creator) -> "_Node[KT, ET]":
|
||||
"""Make a shallow-copy duplicate of this node."""
|
||||
cloned = self.__class__(self.t, creator, self.is_leaf)
|
||||
cloned.elts.extend(self.elts)
|
||||
if not self.is_leaf:
|
||||
cloned.children.extend(self.children)
|
||||
return cloned
|
||||
|
||||
def __str__(self): # pragma: no cover
|
||||
if not self.is_leaf:
|
||||
children = " " + " ".join([f"{id(c):x}" for c in self.children])
|
||||
else:
|
||||
children = ""
|
||||
return f"{id(self):x} {self.creator} {self.elts}{children}"
|
||||
|
||||
|
||||
class Cursor(Generic[KT, ET]):
|
||||
"""A seekable cursor for a BTree.
|
||||
|
||||
If you are going to use a cursor on a mutable BTree, you should use it
|
||||
in a ``with`` block so that any mutations of the BTree automatically park
|
||||
the cursor.
|
||||
"""
|
||||
|
||||
def __init__(self, btree: "BTree[KT, ET]"):
|
||||
self.btree = btree
|
||||
self.current_node: _Node | None = None
|
||||
# The current index is the element index within the current node, or
|
||||
# if there is no current node then it is 0 on the left boundary and 1
|
||||
# on the right boundary.
|
||||
self.current_index: int = 0
|
||||
self.recurse = False
|
||||
self.increasing = True
|
||||
self.parents: list[tuple[_Node, int]] = []
|
||||
self.parked = False
|
||||
self.parking_key: KT | None = None
|
||||
self.parking_key_read = False
|
||||
|
||||
def _seek_least(self) -> None:
|
||||
# seek to the least value in the subtree beneath the current index of the
|
||||
# current node
|
||||
assert self.current_node is not None
|
||||
while not self.current_node.is_leaf:
|
||||
self.parents.append((self.current_node, self.current_index))
|
||||
self.current_node = self.current_node.children[self.current_index]
|
||||
assert self.current_node is not None
|
||||
self.current_index = 0
|
||||
|
||||
def _seek_greatest(self) -> None:
|
||||
# seek to the greatest value in the subtree beneath the current index of the
|
||||
# current node
|
||||
assert self.current_node is not None
|
||||
while not self.current_node.is_leaf:
|
||||
self.parents.append((self.current_node, self.current_index))
|
||||
self.current_node = self.current_node.children[self.current_index]
|
||||
assert self.current_node is not None
|
||||
self.current_index = len(self.current_node.elts)
|
||||
|
||||
def park(self):
|
||||
"""Park the cursor.
|
||||
|
||||
A cursor must be "parked" before mutating the BTree to avoid undefined behavior.
|
||||
Cursors created in a ``with`` block register with their BTree and will park
|
||||
automatically. Note that a parked cursor may not observe some changes made when
|
||||
it is parked; for example a cursor being iterated with next() will not see items
|
||||
inserted before its current position.
|
||||
"""
|
||||
if not self.parked:
|
||||
self.parked = True
|
||||
|
||||
def _maybe_unpark(self):
|
||||
if self.parked:
|
||||
if self.parking_key is not None:
|
||||
# remember our increasing hint, as seeking might change it
|
||||
increasing = self.increasing
|
||||
if self.parking_key_read:
|
||||
# We've already returned the parking key, so we want to be before it
|
||||
# if decreasing and after it if increasing.
|
||||
before = not self.increasing
|
||||
else:
|
||||
# We haven't returned the parking key, so we've parked right
|
||||
# after seeking or are on a boundary. Either way, the before
|
||||
# hint we want is the value of self.increasing.
|
||||
before = self.increasing
|
||||
self.seek(self.parking_key, before)
|
||||
self.increasing = increasing # might have been altered by seek()
|
||||
self.parked = False
|
||||
self.parking_key = None
|
||||
|
||||
def prev(self) -> ET | None:
|
||||
"""Get the previous element, or return None if on the left boundary."""
|
||||
self._maybe_unpark()
|
||||
self.parking_key = None
|
||||
if self.current_node is None:
|
||||
# on a boundary
|
||||
if self.current_index == 0:
|
||||
# left boundary, there is no prev
|
||||
return None
|
||||
else:
|
||||
assert self.current_index == 1
|
||||
# right boundary; seek to the actual boundary
|
||||
# so we can do a prev()
|
||||
self.current_node = self.btree.root
|
||||
self.current_index = len(self.btree.root.elts)
|
||||
self._seek_greatest()
|
||||
while True:
|
||||
if self.recurse:
|
||||
if not self.increasing:
|
||||
# We only want to recurse if we are continuing in the decreasing
|
||||
# direction.
|
||||
self._seek_greatest()
|
||||
self.recurse = False
|
||||
self.increasing = False
|
||||
self.current_index -= 1
|
||||
if self.current_index >= 0:
|
||||
elt = self.current_node.elts[self.current_index]
|
||||
if not self.current_node.is_leaf:
|
||||
self.recurse = True
|
||||
self.parking_key = elt.key()
|
||||
self.parking_key_read = True
|
||||
return elt
|
||||
else:
|
||||
if len(self.parents) > 0:
|
||||
self.current_node, self.current_index = self.parents.pop()
|
||||
else:
|
||||
self.current_node = None
|
||||
self.current_index = 0
|
||||
return None
|
||||
|
||||
def next(self) -> ET | None:
|
||||
"""Get the next element, or return None if on the right boundary."""
|
||||
self._maybe_unpark()
|
||||
self.parking_key = None
|
||||
if self.current_node is None:
|
||||
# on a boundary
|
||||
if self.current_index == 1:
|
||||
# right boundary, there is no next
|
||||
return None
|
||||
else:
|
||||
assert self.current_index == 0
|
||||
# left boundary; seek to the actual boundary
|
||||
# so we can do a next()
|
||||
self.current_node = self.btree.root
|
||||
self.current_index = 0
|
||||
self._seek_least()
|
||||
while True:
|
||||
if self.recurse:
|
||||
if self.increasing:
|
||||
# We only want to recurse if we are continuing in the increasing
|
||||
# direction.
|
||||
self._seek_least()
|
||||
self.recurse = False
|
||||
self.increasing = True
|
||||
if self.current_index < len(self.current_node.elts):
|
||||
elt = self.current_node.elts[self.current_index]
|
||||
self.current_index += 1
|
||||
if not self.current_node.is_leaf:
|
||||
self.recurse = True
|
||||
self.parking_key = elt.key()
|
||||
self.parking_key_read = True
|
||||
return elt
|
||||
else:
|
||||
if len(self.parents) > 0:
|
||||
self.current_node, self.current_index = self.parents.pop()
|
||||
else:
|
||||
self.current_node = None
|
||||
self.current_index = 1
|
||||
return None
|
||||
|
||||
def _adjust_for_before(self, before: bool, i: int) -> None:
|
||||
if before:
|
||||
self.current_index = i
|
||||
else:
|
||||
self.current_index = i + 1
|
||||
|
||||
def seek(self, key: KT, before: bool = True) -> None:
|
||||
"""Seek to the specified key.
|
||||
|
||||
If *before* is ``True`` (the default) then the cursor is positioned just
|
||||
before *key* if it exists, or before its least successor if it doesn't. A
|
||||
subsequent next() will retrieve this value. If *before* is ``False``, then
|
||||
the cursor is positioned just after *key* if it exists, or its greatest
|
||||
precessessor if it doesn't. A subsequent prev() will return this value.
|
||||
"""
|
||||
self.current_node = self.btree.root
|
||||
assert self.current_node is not None
|
||||
self.recurse = False
|
||||
self.parents = []
|
||||
self.increasing = before
|
||||
self.parked = False
|
||||
self.parking_key = key
|
||||
self.parking_key_read = False
|
||||
while not self.current_node.is_leaf:
|
||||
i, equal = self.current_node.search_in_node(key)
|
||||
if equal:
|
||||
self._adjust_for_before(before, i)
|
||||
if before:
|
||||
self._seek_greatest()
|
||||
else:
|
||||
self._seek_least()
|
||||
return
|
||||
self.parents.append((self.current_node, i))
|
||||
self.current_node = self.current_node.children[i]
|
||||
assert self.current_node is not None
|
||||
i, equal = self.current_node.search_in_node(key)
|
||||
if equal:
|
||||
self._adjust_for_before(before, i)
|
||||
else:
|
||||
self.current_index = i
|
||||
|
||||
def seek_first(self) -> None:
|
||||
"""Seek to the left boundary (i.e. just before the least element).
|
||||
|
||||
A subsequent next() will return the least element if the BTree isn't empty."""
|
||||
self.current_node = None
|
||||
self.current_index = 0
|
||||
self.recurse = False
|
||||
self.increasing = True
|
||||
self.parents = []
|
||||
self.parked = False
|
||||
self.parking_key = None
|
||||
|
||||
def seek_last(self) -> None:
|
||||
"""Seek to the right boundary (i.e. just after the greatest element).
|
||||
|
||||
A subsequent prev() will return the greatest element if the BTree isn't empty.
|
||||
"""
|
||||
self.current_node = None
|
||||
self.current_index = 1
|
||||
self.recurse = False
|
||||
self.increasing = False
|
||||
self.parents = []
|
||||
self.parked = False
|
||||
self.parking_key = None
|
||||
|
||||
def __enter__(self):
|
||||
self.btree.register_cursor(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.btree.deregister_cursor(self)
|
||||
return False
|
||||
|
||||
|
||||
class Immutable(Exception):
|
||||
"""The BTree is immutable."""
|
||||
|
||||
|
||||
class BTree(Generic[KT, ET]):
|
||||
"""An in-memory BTree with copy-on-write and cursors."""
|
||||
|
||||
def __init__(self, *, t: int = DEFAULT_T, original: Optional["BTree"] = None):
|
||||
"""Create a BTree.
|
||||
|
||||
If *original* is not ``None``, then the BTree is shallow-cloned from
|
||||
*original* using copy-on-write. Otherwise a new BTree with the specified
|
||||
*t* value is created.
|
||||
|
||||
The BTree is not thread-safe.
|
||||
"""
|
||||
# We don't use a reference to ourselves as a creator as we don't want
|
||||
# to prevent GC of old btrees.
|
||||
self.creator = _Creator()
|
||||
self._immutable = False
|
||||
self.t: int
|
||||
self.root: _Node
|
||||
self.size: int
|
||||
self.cursors: set[Cursor] = set()
|
||||
if original is not None:
|
||||
if not original._immutable:
|
||||
raise ValueError("original BTree is not immutable")
|
||||
self.t = original.t
|
||||
self.root = original.root
|
||||
self.size = original.size
|
||||
else:
|
||||
if t < 3:
|
||||
raise ValueError("t must be >= 3")
|
||||
self.t = t
|
||||
self.root = _Node(self.t, self.creator, True)
|
||||
self.size = 0
|
||||
|
||||
def make_immutable(self):
|
||||
"""Make the BTree immutable.
|
||||
|
||||
Attempts to alter the BTree after making it immutable will raise an
|
||||
Immutable exception. This operation cannot be undone.
|
||||
"""
|
||||
if not self._immutable:
|
||||
self._immutable = True
|
||||
|
||||
def _check_mutable_and_park(self) -> None:
|
||||
if self._immutable:
|
||||
raise Immutable
|
||||
for cursor in self.cursors:
|
||||
cursor.park()
|
||||
|
||||
# Note that we don't use insert() and delete() but rather insert_element() and
|
||||
# delete_key() so that BTreeDict can be a proper MutableMapping and supply the
|
||||
# rest of the standard mapping API.
|
||||
|
||||
def insert_element(self, elt: ET, in_order: bool = False) -> ET | None:
|
||||
"""Insert the element into the BTree.
|
||||
|
||||
If *in_order* is ``True``, then extra work will be done to make left siblings
|
||||
full, which optimizes storage space when the the elements are inserted in-order
|
||||
or close to it.
|
||||
|
||||
Returns the previously existing element at the element's key or ``None``.
|
||||
"""
|
||||
self._check_mutable_and_park()
|
||||
cloned = self.root.maybe_cow(self.creator)
|
||||
if cloned:
|
||||
self.root = cloned
|
||||
if self.root.is_maximal():
|
||||
old_root = self.root
|
||||
self.root = _Node(self.t, self.creator, False)
|
||||
self.root.adopt(*old_root.split())
|
||||
oelt = self.root.insert_nonfull(elt, in_order)
|
||||
if oelt is None:
|
||||
# We did not replace, so something was added.
|
||||
self.size += 1
|
||||
return oelt
|
||||
|
||||
def get_element(self, key: KT) -> ET | None:
|
||||
"""Get the element matching *key* from the BTree, or return ``None`` if it
|
||||
does not exist.
|
||||
"""
|
||||
return self.root.get(key)
|
||||
|
||||
def _delete(self, key: KT, exact: ET | None) -> ET | None:
|
||||
self._check_mutable_and_park()
|
||||
cloned = self.root.maybe_cow(self.creator)
|
||||
if cloned:
|
||||
self.root = cloned
|
||||
elt = self.root.delete(key, None, exact)
|
||||
if elt is not None:
|
||||
# We deleted something
|
||||
self.size -= 1
|
||||
if len(self.root.elts) == 0:
|
||||
# The root is now empty. If there is a child, then collapse this root
|
||||
# level and make the child the new root.
|
||||
if not self.root.is_leaf:
|
||||
assert len(self.root.children) == 1
|
||||
self.root = self.root.children[0]
|
||||
return elt
|
||||
|
||||
def delete_key(self, key: KT) -> ET | None:
|
||||
"""Delete the element matching *key* from the BTree.
|
||||
|
||||
Returns the matching element or ``None`` if it does not exist.
|
||||
"""
|
||||
return self._delete(key, None)
|
||||
|
||||
def delete_exact(self, element: ET) -> ET | None:
|
||||
"""Delete *element* from the BTree.
|
||||
|
||||
Returns the matching element or ``None`` if it was not in the BTree.
|
||||
"""
|
||||
delt = self._delete(element.key(), element)
|
||||
assert delt is element
|
||||
return delt
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
def visit_in_order(self, visit: Callable[[ET], None]) -> None:
|
||||
"""Call *visit*(element) on all elements in the tree in sorted order."""
|
||||
self.root.visit_in_order(visit)
|
||||
|
||||
def _visit_preorder_by_node(self, visit: Callable[[_Node], None]) -> None:
|
||||
self.root._visit_preorder_by_node(visit)
|
||||
|
||||
def cursor(self) -> Cursor[KT, ET]:
|
||||
"""Create a cursor."""
|
||||
return Cursor(self)
|
||||
|
||||
def register_cursor(self, cursor: Cursor) -> None:
|
||||
"""Register a cursor for the automatic parking service."""
|
||||
self.cursors.add(cursor)
|
||||
|
||||
def deregister_cursor(self, cursor: Cursor) -> None:
|
||||
"""Deregister a cursor from the automatic parking service."""
|
||||
self.cursors.discard(cursor)
|
||||
|
||||
def __copy__(self):
|
||||
return self.__class__(original=self)
|
||||
|
||||
def __iter__(self):
|
||||
with self.cursor() as cursor:
|
||||
while True:
|
||||
elt = cursor.next()
|
||||
if elt is None:
|
||||
break
|
||||
yield elt.key()
|
||||
|
||||
|
||||
VT = TypeVar("VT") # the type of a value in a BTreeDict
|
||||
|
||||
|
||||
class KV(Element, Generic[KT, VT]):
|
||||
"""The BTree element type used in a ``BTreeDict``."""
|
||||
|
||||
def __init__(self, key: KT, value: VT):
|
||||
self._key = key
|
||||
self._value = value
|
||||
|
||||
def key(self) -> KT:
|
||||
return self._key
|
||||
|
||||
def value(self) -> VT:
|
||||
return self._value
|
||||
|
||||
def __str__(self): # pragma: no cover
|
||||
return f"KV({self._key}, {self._value})"
|
||||
|
||||
def __repr__(self): # pragma: no cover
|
||||
return f"KV({self._key}, {self._value})"
|
||||
|
||||
|
||||
class BTreeDict(Generic[KT, VT], BTree[KT, KV[KT, VT]], MutableMapping[KT, VT]):
|
||||
"""A MutableMapping implemented with a BTree.
|
||||
|
||||
Unlike a normal Python dict, the BTreeDict may be mutated while iterating.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
t: int = DEFAULT_T,
|
||||
original: BTree | None = None,
|
||||
in_order: bool = False,
|
||||
):
|
||||
super().__init__(t=t, original=original)
|
||||
self.in_order = in_order
|
||||
|
||||
def __getitem__(self, key: KT) -> VT:
|
||||
elt = self.get_element(key)
|
||||
if elt is None:
|
||||
raise KeyError
|
||||
else:
|
||||
return cast(KV, elt).value()
|
||||
|
||||
def __setitem__(self, key: KT, value: VT) -> None:
|
||||
elt = KV(key, value)
|
||||
self.insert_element(elt, self.in_order)
|
||||
|
||||
def __delitem__(self, key: KT) -> None:
|
||||
if self.delete_key(key) is None:
|
||||
raise KeyError
|
||||
|
||||
|
||||
class Member(Element, Generic[KT]):
|
||||
"""The BTree element type used in a ``BTreeSet``."""
|
||||
|
||||
def __init__(self, key: KT):
|
||||
self._key = key
|
||||
|
||||
def key(self) -> KT:
|
||||
return self._key
|
||||
|
||||
|
||||
class BTreeSet(BTree, Generic[KT], MutableSet[KT]):
|
||||
"""A MutableSet implemented with a BTree.
|
||||
|
||||
Unlike a normal Python set, the BTreeSet may be mutated while iterating.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
t: int = DEFAULT_T,
|
||||
original: BTree | None = None,
|
||||
in_order: bool = False,
|
||||
):
|
||||
super().__init__(t=t, original=original)
|
||||
self.in_order = in_order
|
||||
|
||||
def __contains__(self, key: Any) -> bool:
|
||||
return self.get_element(key) is not None
|
||||
|
||||
def add(self, value: KT) -> None:
|
||||
elt = Member(value)
|
||||
self.insert_element(elt, self.in_order)
|
||||
|
||||
def discard(self, value: KT) -> None:
|
||||
self.delete_key(value)
|
||||
367
.venv/Lib/site-packages/dns/btreezone.py
Normal file
367
.venv/Lib/site-packages/dns/btreezone.py
Normal file
@@ -0,0 +1,367 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# A derivative of a dnspython VersionedZone and related classes, using a BTreeDict and
|
||||
# a separate per-version delegation index. These additions let us
|
||||
#
|
||||
# 1) Do efficient CoW versioning (useful for future online updates).
|
||||
# 2) Maintain sort order
|
||||
# 3) Allow delegations to be found easily
|
||||
# 4) Handle glue
|
||||
# 5) Add Node flags ORIGIN, DELEGATION, and GLUE whenever relevant. The ORIGIN
|
||||
# flag is set at the origin node, the DELEGATION FLAG is set at delegation
|
||||
# points, and the GLUE flag is set on nodes beneath delegation points.
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, MutableMapping, Tuple, cast
|
||||
|
||||
import dns.btree
|
||||
import dns.immutable
|
||||
import dns.name
|
||||
import dns.node
|
||||
import dns.rdataclass
|
||||
import dns.rdataset
|
||||
import dns.rdatatype
|
||||
import dns.versioned
|
||||
import dns.zone
|
||||
|
||||
|
||||
class NodeFlags(enum.IntFlag):
|
||||
ORIGIN = 0x01
|
||||
DELEGATION = 0x02
|
||||
GLUE = 0x04
|
||||
|
||||
|
||||
class Node(dns.node.Node):
|
||||
__slots__ = ["flags", "id"]
|
||||
|
||||
def __init__(self, flags: NodeFlags | None = None):
|
||||
super().__init__()
|
||||
if flags is None:
|
||||
# We allow optional flags rather than a default
|
||||
# as pyright doesn't like assigning a literal 0
|
||||
# to flags.
|
||||
flags = NodeFlags(0)
|
||||
self.flags = flags
|
||||
self.id = 0
|
||||
|
||||
def is_delegation(self):
|
||||
return (self.flags & NodeFlags.DELEGATION) != 0
|
||||
|
||||
def is_glue(self):
|
||||
return (self.flags & NodeFlags.GLUE) != 0
|
||||
|
||||
def is_origin(self):
|
||||
return (self.flags & NodeFlags.ORIGIN) != 0
|
||||
|
||||
def is_origin_or_glue(self):
|
||||
return (self.flags & (NodeFlags.ORIGIN | NodeFlags.GLUE)) != 0
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
class ImmutableNode(Node):
|
||||
def __init__(self, node: Node):
|
||||
super().__init__()
|
||||
self.id = node.id
|
||||
self.rdatasets = tuple( # type: ignore
|
||||
[dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
|
||||
)
|
||||
self.flags = node.flags
|
||||
|
||||
def find_rdataset(
|
||||
self,
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
|
||||
create: bool = False,
|
||||
) -> dns.rdataset.Rdataset:
|
||||
if create:
|
||||
raise TypeError("immutable")
|
||||
return super().find_rdataset(rdclass, rdtype, covers, False)
|
||||
|
||||
def get_rdataset(
|
||||
self,
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
|
||||
create: bool = False,
|
||||
) -> dns.rdataset.Rdataset | None:
|
||||
if create:
|
||||
raise TypeError("immutable")
|
||||
return super().get_rdataset(rdclass, rdtype, covers, False)
|
||||
|
||||
def delete_rdataset(
|
||||
self,
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
|
||||
) -> None:
|
||||
raise TypeError("immutable")
|
||||
|
||||
def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
|
||||
raise TypeError("immutable")
|
||||
|
||||
def is_immutable(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class Delegations(dns.btree.BTreeSet[dns.name.Name]):
|
||||
def get_delegation(self, name: dns.name.Name) -> Tuple[dns.name.Name | None, bool]:
|
||||
"""Get the delegation applicable to *name*, if it exists.
|
||||
|
||||
If there delegation, then return a tuple consisting of the name of
|
||||
the delegation point, and a boolean which is `True` if the name is a proper
|
||||
subdomain of the delegation point, and `False` if it is equal to the delegation
|
||||
point.
|
||||
"""
|
||||
cursor = self.cursor()
|
||||
cursor.seek(name, before=False)
|
||||
prev = cursor.prev()
|
||||
if prev is None:
|
||||
return None, False
|
||||
cut = prev.key()
|
||||
reln, _, _ = name.fullcompare(cut)
|
||||
is_subdomain = reln == dns.name.NameRelation.SUBDOMAIN
|
||||
if is_subdomain or reln == dns.name.NameRelation.EQUAL:
|
||||
return cut, is_subdomain
|
||||
else:
|
||||
return None, False
|
||||
|
||||
def is_glue(self, name: dns.name.Name) -> bool:
|
||||
"""Is *name* glue, i.e. is it beneath a delegation?"""
|
||||
cursor = self.cursor()
|
||||
cursor.seek(name, before=False)
|
||||
cut, is_subdomain = self.get_delegation(name)
|
||||
if cut is None:
|
||||
return False
|
||||
return is_subdomain
|
||||
|
||||
|
||||
class WritableVersion(dns.zone.WritableVersion):
|
||||
def __init__(self, zone: dns.zone.Zone, replacement: bool = False):
|
||||
super().__init__(zone, True)
|
||||
if not replacement:
|
||||
assert isinstance(zone, dns.versioned.Zone)
|
||||
version = zone._versions[-1]
|
||||
self.nodes: dns.btree.BTreeDict[dns.name.Name, Node] = dns.btree.BTreeDict(
|
||||
original=version.nodes # type: ignore
|
||||
)
|
||||
self.delegations = Delegations(original=version.delegations) # type: ignore
|
||||
else:
|
||||
self.delegations = Delegations()
|
||||
|
||||
def _is_origin(self, name: dns.name.Name) -> bool:
|
||||
# Assumes name has already been validated (and thus adjusted to the right
|
||||
# relativity too)
|
||||
if self.zone.relativize:
|
||||
return name == dns.name.empty
|
||||
else:
|
||||
return name == self.zone.origin
|
||||
|
||||
def _maybe_cow_with_name(
|
||||
self, name: dns.name.Name
|
||||
) -> Tuple[dns.node.Node, dns.name.Name]:
|
||||
(node, name) = super()._maybe_cow_with_name(name)
|
||||
node = cast(Node, node)
|
||||
if self._is_origin(name):
|
||||
node.flags |= NodeFlags.ORIGIN
|
||||
elif self.delegations.is_glue(name):
|
||||
node.flags |= NodeFlags.GLUE
|
||||
return (node, name)
|
||||
|
||||
def update_glue_flag(self, name: dns.name.Name, is_glue: bool) -> None:
|
||||
cursor = self.nodes.cursor() # type: ignore
|
||||
cursor.seek(name, False)
|
||||
updates = []
|
||||
while True:
|
||||
elt = cursor.next()
|
||||
if elt is None:
|
||||
break
|
||||
ename = elt.key()
|
||||
if not ename.is_subdomain(name):
|
||||
break
|
||||
node = cast(dns.node.Node, elt.value())
|
||||
if ename not in self.changed:
|
||||
new_node = self.zone.node_factory()
|
||||
new_node.id = self.id # type: ignore
|
||||
new_node.rdatasets.extend(node.rdatasets)
|
||||
self.changed.add(ename)
|
||||
node = new_node
|
||||
assert isinstance(node, Node)
|
||||
if is_glue:
|
||||
node.flags |= NodeFlags.GLUE
|
||||
else:
|
||||
node.flags &= ~NodeFlags.GLUE
|
||||
# We don't update node here as any insertion could disturb the
|
||||
# btree and invalidate our cursor. We could use the cursor in a
|
||||
# with block and avoid this, but it would do a lot of parking and
|
||||
# unparking so the deferred update mode may still be better.
|
||||
updates.append((ename, node))
|
||||
for ename, node in updates:
|
||||
self.nodes[ename] = node
|
||||
|
||||
def delete_node(self, name: dns.name.Name) -> None:
|
||||
name = self._validate_name(name)
|
||||
node = self.nodes.get(name)
|
||||
if node is not None:
|
||||
if node.is_delegation(): # type: ignore
|
||||
self.delegations.discard(name)
|
||||
self.update_glue_flag(name, False)
|
||||
del self.nodes[name]
|
||||
self.changed.add(name)
|
||||
|
||||
def put_rdataset(
|
||||
self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset
|
||||
) -> None:
|
||||
(node, name) = self._maybe_cow_with_name(name)
|
||||
if (
|
||||
rdataset.rdtype == dns.rdatatype.NS and not node.is_origin_or_glue() # type: ignore
|
||||
):
|
||||
node.flags |= NodeFlags.DELEGATION # type: ignore
|
||||
if name not in self.delegations:
|
||||
self.delegations.add(name)
|
||||
self.update_glue_flag(name, True)
|
||||
node.replace_rdataset(rdataset)
|
||||
|
||||
def delete_rdataset(
|
||||
self,
|
||||
name: dns.name.Name,
|
||||
rdtype: dns.rdatatype.RdataType,
|
||||
covers: dns.rdatatype.RdataType,
|
||||
) -> None:
|
||||
(node, name) = self._maybe_cow_with_name(name)
|
||||
if rdtype == dns.rdatatype.NS and name in self.delegations: # type: ignore
|
||||
node.flags &= ~NodeFlags.DELEGATION # type: ignore
|
||||
self.delegations.discard(name) # type: ignore
|
||||
self.update_glue_flag(name, False)
|
||||
node.delete_rdataset(self.zone.rdclass, rdtype, covers)
|
||||
if len(node) == 0:
|
||||
del self.nodes[name]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Bounds:
|
||||
name: dns.name.Name
|
||||
left: dns.name.Name
|
||||
right: dns.name.Name | None
|
||||
closest_encloser: dns.name.Name
|
||||
is_equal: bool
|
||||
is_delegation: bool
|
||||
|
||||
def __str__(self):
|
||||
if self.is_equal:
|
||||
op = "="
|
||||
else:
|
||||
op = "<"
|
||||
if self.is_delegation:
|
||||
zonecut = " zonecut"
|
||||
else:
|
||||
zonecut = ""
|
||||
return (
|
||||
f"{self.left} {op} {self.name} < {self.right}{zonecut}; "
|
||||
f"{self.closest_encloser}"
|
||||
)
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
class ImmutableVersion(dns.zone.Version):
|
||||
def __init__(self, version: dns.zone.Version):
|
||||
if not isinstance(version, WritableVersion):
|
||||
raise ValueError(
|
||||
"a dns.btreezone.ImmutableVersion requires a "
|
||||
"dns.btreezone.WritableVersion"
|
||||
)
|
||||
super().__init__(version.zone, True)
|
||||
self.id = version.id
|
||||
self.origin = version.origin
|
||||
for name in version.changed:
|
||||
node = version.nodes.get(name)
|
||||
if node:
|
||||
version.nodes[name] = ImmutableNode(node)
|
||||
# the cast below is for mypy
|
||||
self.nodes = cast(MutableMapping[dns.name.Name, dns.node.Node], version.nodes)
|
||||
self.nodes.make_immutable() # type: ignore
|
||||
self.delegations = version.delegations
|
||||
self.delegations.make_immutable()
|
||||
|
||||
def bounds(self, name: dns.name.Name | str) -> Bounds:
|
||||
"""Return the 'bounds' of *name* in its zone.
|
||||
|
||||
The bounds information is useful when making an authoritative response, as
|
||||
it can be used to determine whether the query name is at or beneath a delegation
|
||||
point. The other data in the ``Bounds`` object is useful for making on-the-fly
|
||||
DNSSEC signatures.
|
||||
|
||||
The left bound of *name* is *name* itself if it is in the zone, or the greatest
|
||||
predecessor which is in the zone.
|
||||
|
||||
The right bound of *name* is the least successor of *name*, or ``None`` if
|
||||
no name in the zone is greater than *name*.
|
||||
|
||||
The closest encloser of *name* is *name* itself, if *name* is in the zone;
|
||||
otherwise it is the name with the largest number of labels in common with
|
||||
*name* that is in the zone, either explicitly or by the implied existence
|
||||
of empty non-terminals.
|
||||
|
||||
The bounds *is_equal* field is ``True`` if and only if *name* is equal to
|
||||
its left bound.
|
||||
|
||||
The bounds *is_delegation* field is ``True`` if and only if the left bound is a
|
||||
delegation point.
|
||||
"""
|
||||
assert self.origin is not None
|
||||
# validate the origin because we may need to relativize
|
||||
origin = self.zone._validate_name(self.origin)
|
||||
name = self.zone._validate_name(name)
|
||||
cut, _ = self.delegations.get_delegation(name)
|
||||
if cut is not None:
|
||||
target = cut
|
||||
is_delegation = True
|
||||
else:
|
||||
target = name
|
||||
is_delegation = False
|
||||
c = cast(dns.btree.BTreeDict, self.nodes).cursor()
|
||||
c.seek(target, False)
|
||||
left = c.prev()
|
||||
assert left is not None
|
||||
c.next() # skip over left
|
||||
while True:
|
||||
right = c.next()
|
||||
if right is None or not right.value().is_glue():
|
||||
break
|
||||
left_comparison = left.key().fullcompare(name)
|
||||
if right is not None:
|
||||
right_key = right.key()
|
||||
right_comparison = right_key.fullcompare(name)
|
||||
else:
|
||||
right_comparison = (
|
||||
dns.name.NAMERELN_COMMONANCESTOR,
|
||||
-1,
|
||||
len(origin),
|
||||
)
|
||||
right_key = None
|
||||
closest_encloser = dns.name.Name(
|
||||
name[-max(left_comparison[2], right_comparison[2]) :]
|
||||
)
|
||||
return Bounds(
|
||||
name,
|
||||
left.key(),
|
||||
right_key,
|
||||
closest_encloser,
|
||||
left_comparison[0] == dns.name.NameRelation.EQUAL,
|
||||
is_delegation,
|
||||
)
|
||||
|
||||
|
||||
class Zone(dns.versioned.Zone):
|
||||
node_factory: Callable[[], dns.node.Node] = Node
|
||||
map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = cast(
|
||||
Callable[[], MutableMapping[dns.name.Name, dns.node.Node]],
|
||||
dns.btree.BTreeDict[dns.name.Name, Node],
|
||||
)
|
||||
writable_version_factory: (
|
||||
Callable[[dns.zone.Zone, bool], dns.zone.Version] | None
|
||||
) = WritableVersion
|
||||
immutable_version_factory: Callable[[dns.zone.Version], dns.zone.Version] | None = (
|
||||
ImmutableVersion
|
||||
)
|
||||
1242
.venv/Lib/site-packages/dns/dnssec.py
Normal file
1242
.venv/Lib/site-packages/dns/dnssec.py
Normal file
File diff suppressed because it is too large
Load Diff
124
.venv/Lib/site-packages/dns/dnssecalgs/__init__.py
Normal file
124
.venv/Lib/site-packages/dns/dnssecalgs/__init__.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import Dict, Tuple, Type
|
||||
|
||||
import dns._features
|
||||
import dns.name
|
||||
from dns.dnssecalgs.base import GenericPrivateKey
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.exception import UnsupportedAlgorithm
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
|
||||
# pyright: reportPossiblyUnboundVariable=false
|
||||
|
||||
if dns._features.have("dnssec"):
|
||||
from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
|
||||
from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
|
||||
from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519
|
||||
from dns.dnssecalgs.rsa import (
|
||||
PrivateRSAMD5,
|
||||
PrivateRSASHA1,
|
||||
PrivateRSASHA1NSEC3SHA1,
|
||||
PrivateRSASHA256,
|
||||
PrivateRSASHA512,
|
||||
)
|
||||
|
||||
_have_cryptography = True
|
||||
else:
|
||||
_have_cryptography = False
|
||||
|
||||
AlgorithmPrefix = bytes | dns.name.Name | None
|
||||
|
||||
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}
|
||||
if _have_cryptography:
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
algorithms.update(
|
||||
{
|
||||
(Algorithm.RSAMD5, None): PrivateRSAMD5,
|
||||
(Algorithm.DSA, None): PrivateDSA,
|
||||
(Algorithm.RSASHA1, None): PrivateRSASHA1,
|
||||
(Algorithm.DSANSEC3SHA1, None): PrivateDSANSEC3SHA1,
|
||||
(Algorithm.RSASHA1NSEC3SHA1, None): PrivateRSASHA1NSEC3SHA1,
|
||||
(Algorithm.RSASHA256, None): PrivateRSASHA256,
|
||||
(Algorithm.RSASHA512, None): PrivateRSASHA512,
|
||||
(Algorithm.ECDSAP256SHA256, None): PrivateECDSAP256SHA256,
|
||||
(Algorithm.ECDSAP384SHA384, None): PrivateECDSAP384SHA384,
|
||||
(Algorithm.ED25519, None): PrivateED25519,
|
||||
(Algorithm.ED448, None): PrivateED448,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def get_algorithm_cls(
|
||||
algorithm: int | str, prefix: AlgorithmPrefix = None
|
||||
) -> Type[GenericPrivateKey]:
|
||||
"""Get Private Key class from Algorithm.
|
||||
|
||||
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
|
||||
|
||||
Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
|
||||
|
||||
Returns a ``dns.dnssecalgs.GenericPrivateKey``
|
||||
"""
|
||||
algorithm = Algorithm.make(algorithm)
|
||||
cls = algorithms.get((algorithm, prefix))
|
||||
if cls:
|
||||
return cls
|
||||
raise UnsupportedAlgorithm(
|
||||
f'algorithm "{Algorithm.to_text(algorithm)}" not supported by dnspython'
|
||||
)
|
||||
|
||||
|
||||
def get_algorithm_cls_from_dnskey(dnskey: DNSKEY) -> Type[GenericPrivateKey]:
|
||||
"""Get Private Key class from DNSKEY.
|
||||
|
||||
*dnskey*, a ``DNSKEY`` to get Algorithm class for.
|
||||
|
||||
Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
|
||||
|
||||
Returns a ``dns.dnssecalgs.GenericPrivateKey``
|
||||
"""
|
||||
prefix: AlgorithmPrefix = None
|
||||
if dnskey.algorithm == Algorithm.PRIVATEDNS:
|
||||
prefix, _ = dns.name.from_wire(dnskey.key, 0)
|
||||
elif dnskey.algorithm == Algorithm.PRIVATEOID:
|
||||
length = int(dnskey.key[0])
|
||||
prefix = dnskey.key[0 : length + 1]
|
||||
return get_algorithm_cls(dnskey.algorithm, prefix)
|
||||
|
||||
|
||||
def register_algorithm_cls(
|
||||
algorithm: int | str,
|
||||
algorithm_cls: Type[GenericPrivateKey],
|
||||
name: dns.name.Name | str | None = None,
|
||||
oid: bytes | None = None,
|
||||
) -> None:
|
||||
"""Register Algorithm Private Key class.
|
||||
|
||||
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
|
||||
|
||||
*algorithm_cls*: A `GenericPrivateKey` class.
|
||||
|
||||
*name*, an optional ``dns.name.Name`` or ``str``, for for PRIVATEDNS algorithms.
|
||||
|
||||
*oid*: an optional BER-encoded `bytes` for PRIVATEOID algorithms.
|
||||
|
||||
Raises ``ValueError`` if a name or oid is specified incorrectly.
|
||||
"""
|
||||
if not issubclass(algorithm_cls, GenericPrivateKey):
|
||||
raise TypeError("Invalid algorithm class")
|
||||
algorithm = Algorithm.make(algorithm)
|
||||
prefix: AlgorithmPrefix = None
|
||||
if algorithm == Algorithm.PRIVATEDNS:
|
||||
if name is None:
|
||||
raise ValueError("Name required for PRIVATEDNS algorithms")
|
||||
if isinstance(name, str):
|
||||
name = dns.name.from_text(name)
|
||||
prefix = name
|
||||
elif algorithm == Algorithm.PRIVATEOID:
|
||||
if oid is None:
|
||||
raise ValueError("OID required for PRIVATEOID algorithms")
|
||||
prefix = bytes([len(oid)]) + oid
|
||||
elif name:
|
||||
raise ValueError("Name only supported for PRIVATEDNS algorithm")
|
||||
elif oid:
|
||||
raise ValueError("OID only supported for PRIVATEOID algorithm")
|
||||
algorithms[(algorithm, prefix)] = algorithm_cls
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
89
.venv/Lib/site-packages/dns/dnssecalgs/base.py
Normal file
89
.venv/Lib/site-packages/dns/dnssecalgs/base.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from abc import ABC, abstractmethod # pylint: disable=no-name-in-module
|
||||
from typing import Any, Type
|
||||
|
||||
import dns.rdataclass
|
||||
import dns.rdatatype
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.exception import AlgorithmKeyMismatch
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
from dns.rdtypes.dnskeybase import Flag
|
||||
|
||||
|
||||
class GenericPublicKey(ABC):
|
||||
algorithm: Algorithm
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, key: Any) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
"""Verify signed DNSSEC data"""
|
||||
|
||||
@abstractmethod
|
||||
def encode_key_bytes(self) -> bytes:
|
||||
"""Encode key as bytes for DNSKEY"""
|
||||
|
||||
@classmethod
|
||||
def _ensure_algorithm_key_combination(cls, key: DNSKEY) -> None:
|
||||
if key.algorithm != cls.algorithm:
|
||||
raise AlgorithmKeyMismatch
|
||||
|
||||
def to_dnskey(self, flags: int = Flag.ZONE, protocol: int = 3) -> DNSKEY:
|
||||
"""Return public key as DNSKEY"""
|
||||
return DNSKEY(
|
||||
rdclass=dns.rdataclass.IN,
|
||||
rdtype=dns.rdatatype.DNSKEY,
|
||||
flags=flags,
|
||||
protocol=protocol,
|
||||
algorithm=self.algorithm,
|
||||
key=self.encode_key_bytes(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_dnskey(cls, key: DNSKEY) -> "GenericPublicKey":
|
||||
"""Create public key from DNSKEY"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
|
||||
"""Create public key from PEM-encoded SubjectPublicKeyInfo as specified
|
||||
in RFC 5280"""
|
||||
|
||||
@abstractmethod
|
||||
def to_pem(self) -> bytes:
|
||||
"""Return public-key as PEM-encoded SubjectPublicKeyInfo as specified
|
||||
in RFC 5280"""
|
||||
|
||||
|
||||
class GenericPrivateKey(ABC):
|
||||
public_cls: Type[GenericPublicKey]
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, key: Any) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sign(
|
||||
self,
|
||||
data: bytes,
|
||||
verify: bool = False,
|
||||
deterministic: bool = True,
|
||||
) -> bytes:
|
||||
"""Sign DNSSEC data"""
|
||||
|
||||
@abstractmethod
|
||||
def public_key(self) -> "GenericPublicKey":
|
||||
"""Return public key instance"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_pem(
|
||||
cls, private_pem: bytes, password: bytes | None = None
|
||||
) -> "GenericPrivateKey":
|
||||
"""Create private key from PEM-encoded PKCS#8"""
|
||||
|
||||
@abstractmethod
|
||||
def to_pem(self, password: bytes | None = None) -> bytes:
|
||||
"""Return private key as PEM-encoded PKCS#8"""
|
||||
68
.venv/Lib/site-packages/dns/dnssecalgs/cryptography.py
Normal file
68
.venv/Lib/site-packages/dns/dnssecalgs/cryptography.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import Any, Type
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
|
||||
from dns.exception import AlgorithmKeyMismatch
|
||||
|
||||
|
||||
class CryptographyPublicKey(GenericPublicKey):
|
||||
key: Any = None
|
||||
key_cls: Any = None
|
||||
|
||||
def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
|
||||
if self.key_cls is None:
|
||||
raise TypeError("Undefined private key class")
|
||||
if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
|
||||
key, self.key_cls
|
||||
):
|
||||
raise AlgorithmKeyMismatch
|
||||
self.key = key
|
||||
|
||||
@classmethod
|
||||
def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
|
||||
key = serialization.load_pem_public_key(public_pem)
|
||||
return cls(key=key)
|
||||
|
||||
def to_pem(self) -> bytes:
|
||||
return self.key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
|
||||
class CryptographyPrivateKey(GenericPrivateKey):
|
||||
key: Any = None
|
||||
key_cls: Any = None
|
||||
public_cls: Type[CryptographyPublicKey] # pyright: ignore
|
||||
|
||||
def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
|
||||
if self.key_cls is None:
|
||||
raise TypeError("Undefined private key class")
|
||||
if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
|
||||
key, self.key_cls
|
||||
):
|
||||
raise AlgorithmKeyMismatch
|
||||
self.key = key
|
||||
|
||||
def public_key(self) -> "CryptographyPublicKey":
|
||||
return self.public_cls(key=self.key.public_key())
|
||||
|
||||
@classmethod
|
||||
def from_pem(
|
||||
cls, private_pem: bytes, password: bytes | None = None
|
||||
) -> "GenericPrivateKey":
|
||||
key = serialization.load_pem_private_key(private_pem, password=password)
|
||||
return cls(key=key)
|
||||
|
||||
def to_pem(self, password: bytes | None = None) -> bytes:
|
||||
encryption_algorithm: serialization.KeySerializationEncryption
|
||||
if password:
|
||||
encryption_algorithm = serialization.BestAvailableEncryption(password)
|
||||
else:
|
||||
encryption_algorithm = serialization.NoEncryption()
|
||||
return self.key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=encryption_algorithm,
|
||||
)
|
||||
108
.venv/Lib/site-packages/dns/dnssecalgs/dsa.py
Normal file
108
.venv/Lib/site-packages/dns/dnssecalgs/dsa.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import struct
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import dsa, utils
|
||||
|
||||
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
|
||||
|
||||
class PublicDSA(CryptographyPublicKey):
|
||||
key: dsa.DSAPublicKey
|
||||
key_cls = dsa.DSAPublicKey
|
||||
algorithm = Algorithm.DSA
|
||||
chosen_hash = hashes.SHA1()
|
||||
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
sig_r = signature[1:21]
|
||||
sig_s = signature[21:]
|
||||
sig = utils.encode_dss_signature(
|
||||
int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
|
||||
)
|
||||
self.key.verify(sig, data, self.chosen_hash)
|
||||
|
||||
def encode_key_bytes(self) -> bytes:
|
||||
"""Encode a public key per RFC 2536, section 2."""
|
||||
pn = self.key.public_numbers()
|
||||
dsa_t = (self.key.key_size // 8 - 64) // 8
|
||||
if dsa_t > 8:
|
||||
raise ValueError("unsupported DSA key size")
|
||||
octets = 64 + dsa_t * 8
|
||||
res = struct.pack("!B", dsa_t)
|
||||
res += pn.parameter_numbers.q.to_bytes(20, "big")
|
||||
res += pn.parameter_numbers.p.to_bytes(octets, "big")
|
||||
res += pn.parameter_numbers.g.to_bytes(octets, "big")
|
||||
res += pn.y.to_bytes(octets, "big")
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def from_dnskey(cls, key: DNSKEY) -> "PublicDSA":
|
||||
cls._ensure_algorithm_key_combination(key)
|
||||
keyptr = key.key
|
||||
(t,) = struct.unpack("!B", keyptr[0:1])
|
||||
keyptr = keyptr[1:]
|
||||
octets = 64 + t * 8
|
||||
dsa_q = keyptr[0:20]
|
||||
keyptr = keyptr[20:]
|
||||
dsa_p = keyptr[0:octets]
|
||||
keyptr = keyptr[octets:]
|
||||
dsa_g = keyptr[0:octets]
|
||||
keyptr = keyptr[octets:]
|
||||
dsa_y = keyptr[0:octets]
|
||||
return cls(
|
||||
key=dsa.DSAPublicNumbers( # type: ignore
|
||||
int.from_bytes(dsa_y, "big"),
|
||||
dsa.DSAParameterNumbers(
|
||||
int.from_bytes(dsa_p, "big"),
|
||||
int.from_bytes(dsa_q, "big"),
|
||||
int.from_bytes(dsa_g, "big"),
|
||||
),
|
||||
).public_key(default_backend()),
|
||||
)
|
||||
|
||||
|
||||
class PrivateDSA(CryptographyPrivateKey):
|
||||
key: dsa.DSAPrivateKey
|
||||
key_cls = dsa.DSAPrivateKey
|
||||
public_cls = PublicDSA
|
||||
|
||||
def sign(
|
||||
self,
|
||||
data: bytes,
|
||||
verify: bool = False,
|
||||
deterministic: bool = True,
|
||||
) -> bytes:
|
||||
"""Sign using a private key per RFC 2536, section 3."""
|
||||
public_dsa_key = self.key.public_key()
|
||||
if public_dsa_key.key_size > 1024:
|
||||
raise ValueError("DSA key size overflow")
|
||||
der_signature = self.key.sign(
|
||||
data, self.public_cls.chosen_hash # pyright: ignore
|
||||
)
|
||||
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
|
||||
dsa_t = (public_dsa_key.key_size // 8 - 64) // 8
|
||||
octets = 20
|
||||
signature = (
|
||||
struct.pack("!B", dsa_t)
|
||||
+ int.to_bytes(dsa_r, length=octets, byteorder="big")
|
||||
+ int.to_bytes(dsa_s, length=octets, byteorder="big")
|
||||
)
|
||||
if verify:
|
||||
self.public_key().verify(signature, data)
|
||||
return signature
|
||||
|
||||
@classmethod
|
||||
def generate(cls, key_size: int) -> "PrivateDSA":
|
||||
return cls(
|
||||
key=dsa.generate_private_key(key_size=key_size),
|
||||
)
|
||||
|
||||
|
||||
class PublicDSANSEC3SHA1(PublicDSA):
|
||||
algorithm = Algorithm.DSANSEC3SHA1
|
||||
|
||||
|
||||
class PrivateDSANSEC3SHA1(PrivateDSA):
|
||||
public_cls = PublicDSANSEC3SHA1
|
||||
100
.venv/Lib/site-packages/dns/dnssecalgs/ecdsa.py
Normal file
100
.venv/Lib/site-packages/dns/dnssecalgs/ecdsa.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import ec, utils
|
||||
|
||||
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
|
||||
|
||||
class PublicECDSA(CryptographyPublicKey):
|
||||
key: ec.EllipticCurvePublicKey
|
||||
key_cls = ec.EllipticCurvePublicKey
|
||||
algorithm: Algorithm
|
||||
chosen_hash: hashes.HashAlgorithm
|
||||
curve: ec.EllipticCurve
|
||||
octets: int
|
||||
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
sig_r = signature[0 : self.octets]
|
||||
sig_s = signature[self.octets :]
|
||||
sig = utils.encode_dss_signature(
|
||||
int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
|
||||
)
|
||||
self.key.verify(sig, data, ec.ECDSA(self.chosen_hash))
|
||||
|
||||
def encode_key_bytes(self) -> bytes:
|
||||
"""Encode a public key per RFC 6605, section 4."""
|
||||
pn = self.key.public_numbers()
|
||||
return pn.x.to_bytes(self.octets, "big") + pn.y.to_bytes(self.octets, "big")
|
||||
|
||||
@classmethod
|
||||
def from_dnskey(cls, key: DNSKEY) -> "PublicECDSA":
|
||||
cls._ensure_algorithm_key_combination(key)
|
||||
ecdsa_x = key.key[0 : cls.octets]
|
||||
ecdsa_y = key.key[cls.octets : cls.octets * 2]
|
||||
return cls(
|
||||
key=ec.EllipticCurvePublicNumbers(
|
||||
curve=cls.curve,
|
||||
x=int.from_bytes(ecdsa_x, "big"),
|
||||
y=int.from_bytes(ecdsa_y, "big"),
|
||||
).public_key(default_backend()),
|
||||
)
|
||||
|
||||
|
||||
class PrivateECDSA(CryptographyPrivateKey):
|
||||
key: ec.EllipticCurvePrivateKey
|
||||
key_cls = ec.EllipticCurvePrivateKey
|
||||
public_cls = PublicECDSA
|
||||
|
||||
def sign(
|
||||
self,
|
||||
data: bytes,
|
||||
verify: bool = False,
|
||||
deterministic: bool = True,
|
||||
) -> bytes:
|
||||
"""Sign using a private key per RFC 6605, section 4."""
|
||||
algorithm = ec.ECDSA(
|
||||
self.public_cls.chosen_hash, # pyright: ignore
|
||||
deterministic_signing=deterministic,
|
||||
)
|
||||
der_signature = self.key.sign(data, algorithm)
|
||||
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
|
||||
signature = int.to_bytes(
|
||||
dsa_r, length=self.public_cls.octets, byteorder="big" # pyright: ignore
|
||||
) + int.to_bytes(
|
||||
dsa_s, length=self.public_cls.octets, byteorder="big" # pyright: ignore
|
||||
)
|
||||
if verify:
|
||||
self.public_key().verify(signature, data)
|
||||
return signature
|
||||
|
||||
@classmethod
|
||||
def generate(cls) -> "PrivateECDSA":
|
||||
return cls(
|
||||
key=ec.generate_private_key(
|
||||
curve=cls.public_cls.curve, backend=default_backend() # pyright: ignore
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class PublicECDSAP256SHA256(PublicECDSA):
|
||||
algorithm = Algorithm.ECDSAP256SHA256
|
||||
chosen_hash = hashes.SHA256()
|
||||
curve = ec.SECP256R1()
|
||||
octets = 32
|
||||
|
||||
|
||||
class PrivateECDSAP256SHA256(PrivateECDSA):
|
||||
public_cls = PublicECDSAP256SHA256
|
||||
|
||||
|
||||
class PublicECDSAP384SHA384(PublicECDSA):
|
||||
algorithm = Algorithm.ECDSAP384SHA384
|
||||
chosen_hash = hashes.SHA384()
|
||||
curve = ec.SECP384R1()
|
||||
octets = 48
|
||||
|
||||
|
||||
class PrivateECDSAP384SHA384(PrivateECDSA):
|
||||
public_cls = PublicECDSAP384SHA384
|
||||
70
.venv/Lib/site-packages/dns/dnssecalgs/eddsa.py
Normal file
70
.venv/Lib/site-packages/dns/dnssecalgs/eddsa.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from typing import Type
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed448, ed25519
|
||||
|
||||
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
|
||||
|
||||
class PublicEDDSA(CryptographyPublicKey):
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
self.key.verify(signature, data)
|
||||
|
||||
def encode_key_bytes(self) -> bytes:
|
||||
"""Encode a public key per RFC 8080, section 3."""
|
||||
return self.key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA":
|
||||
cls._ensure_algorithm_key_combination(key)
|
||||
return cls(
|
||||
key=cls.key_cls.from_public_bytes(key.key),
|
||||
)
|
||||
|
||||
|
||||
class PrivateEDDSA(CryptographyPrivateKey):
|
||||
public_cls: Type[PublicEDDSA] # pyright: ignore
|
||||
|
||||
def sign(
|
||||
self,
|
||||
data: bytes,
|
||||
verify: bool = False,
|
||||
deterministic: bool = True,
|
||||
) -> bytes:
|
||||
"""Sign using a private key per RFC 8080, section 4."""
|
||||
signature = self.key.sign(data)
|
||||
if verify:
|
||||
self.public_key().verify(signature, data)
|
||||
return signature
|
||||
|
||||
@classmethod
|
||||
def generate(cls) -> "PrivateEDDSA":
|
||||
return cls(key=cls.key_cls.generate())
|
||||
|
||||
|
||||
class PublicED25519(PublicEDDSA):
|
||||
key: ed25519.Ed25519PublicKey
|
||||
key_cls = ed25519.Ed25519PublicKey
|
||||
algorithm = Algorithm.ED25519
|
||||
|
||||
|
||||
class PrivateED25519(PrivateEDDSA):
|
||||
key: ed25519.Ed25519PrivateKey
|
||||
key_cls = ed25519.Ed25519PrivateKey
|
||||
public_cls = PublicED25519
|
||||
|
||||
|
||||
class PublicED448(PublicEDDSA):
|
||||
key: ed448.Ed448PublicKey
|
||||
key_cls = ed448.Ed448PublicKey
|
||||
algorithm = Algorithm.ED448
|
||||
|
||||
|
||||
class PrivateED448(PrivateEDDSA):
|
||||
key: ed448.Ed448PrivateKey
|
||||
key_cls = ed448.Ed448PrivateKey
|
||||
public_cls = PublicED448
|
||||
126
.venv/Lib/site-packages/dns/dnssecalgs/rsa.py
Normal file
126
.venv/Lib/site-packages/dns/dnssecalgs/rsa.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import math
|
||||
import struct
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||
|
||||
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
|
||||
|
||||
class PublicRSA(CryptographyPublicKey):
|
||||
key: rsa.RSAPublicKey
|
||||
key_cls = rsa.RSAPublicKey
|
||||
algorithm: Algorithm
|
||||
chosen_hash: hashes.HashAlgorithm
|
||||
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash)
|
||||
|
||||
def encode_key_bytes(self) -> bytes:
|
||||
"""Encode a public key per RFC 3110, section 2."""
|
||||
pn = self.key.public_numbers()
|
||||
_exp_len = math.ceil(int.bit_length(pn.e) / 8)
|
||||
exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
|
||||
if _exp_len > 255:
|
||||
exp_header = b"\0" + struct.pack("!H", _exp_len)
|
||||
else:
|
||||
exp_header = struct.pack("!B", _exp_len)
|
||||
if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
|
||||
raise ValueError("unsupported RSA key length")
|
||||
return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
|
||||
|
||||
@classmethod
|
||||
def from_dnskey(cls, key: DNSKEY) -> "PublicRSA":
|
||||
cls._ensure_algorithm_key_combination(key)
|
||||
keyptr = key.key
|
||||
(bytes_,) = struct.unpack("!B", keyptr[0:1])
|
||||
keyptr = keyptr[1:]
|
||||
if bytes_ == 0:
|
||||
(bytes_,) = struct.unpack("!H", keyptr[0:2])
|
||||
keyptr = keyptr[2:]
|
||||
rsa_e = keyptr[0:bytes_]
|
||||
rsa_n = keyptr[bytes_:]
|
||||
return cls(
|
||||
key=rsa.RSAPublicNumbers(
|
||||
int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big")
|
||||
).public_key(default_backend())
|
||||
)
|
||||
|
||||
|
||||
class PrivateRSA(CryptographyPrivateKey):
|
||||
key: rsa.RSAPrivateKey
|
||||
key_cls = rsa.RSAPrivateKey
|
||||
public_cls = PublicRSA
|
||||
default_public_exponent = 65537
|
||||
|
||||
def sign(
|
||||
self,
|
||||
data: bytes,
|
||||
verify: bool = False,
|
||||
deterministic: bool = True,
|
||||
) -> bytes:
|
||||
"""Sign using a private key per RFC 3110, section 3."""
|
||||
signature = self.key.sign(
|
||||
data, padding.PKCS1v15(), self.public_cls.chosen_hash # pyright: ignore
|
||||
)
|
||||
if verify:
|
||||
self.public_key().verify(signature, data)
|
||||
return signature
|
||||
|
||||
@classmethod
|
||||
def generate(cls, key_size: int) -> "PrivateRSA":
|
||||
return cls(
|
||||
key=rsa.generate_private_key(
|
||||
public_exponent=cls.default_public_exponent,
|
||||
key_size=key_size,
|
||||
backend=default_backend(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class PublicRSAMD5(PublicRSA):
|
||||
algorithm = Algorithm.RSAMD5
|
||||
chosen_hash = hashes.MD5()
|
||||
|
||||
|
||||
class PrivateRSAMD5(PrivateRSA):
|
||||
public_cls = PublicRSAMD5
|
||||
|
||||
|
||||
class PublicRSASHA1(PublicRSA):
|
||||
algorithm = Algorithm.RSASHA1
|
||||
chosen_hash = hashes.SHA1()
|
||||
|
||||
|
||||
class PrivateRSASHA1(PrivateRSA):
|
||||
public_cls = PublicRSASHA1
|
||||
|
||||
|
||||
class PublicRSASHA1NSEC3SHA1(PublicRSA):
|
||||
algorithm = Algorithm.RSASHA1NSEC3SHA1
|
||||
chosen_hash = hashes.SHA1()
|
||||
|
||||
|
||||
class PrivateRSASHA1NSEC3SHA1(PrivateRSA):
|
||||
public_cls = PublicRSASHA1NSEC3SHA1
|
||||
|
||||
|
||||
class PublicRSASHA256(PublicRSA):
|
||||
algorithm = Algorithm.RSASHA256
|
||||
chosen_hash = hashes.SHA256()
|
||||
|
||||
|
||||
class PrivateRSASHA256(PrivateRSA):
|
||||
public_cls = PublicRSASHA256
|
||||
|
||||
|
||||
class PublicRSASHA512(PublicRSA):
|
||||
algorithm = Algorithm.RSASHA512
|
||||
chosen_hash = hashes.SHA512()
|
||||
|
||||
|
||||
class PrivateRSASHA512(PrivateRSA):
|
||||
public_cls = PublicRSASHA512
|
||||
71
.venv/Lib/site-packages/dns/dnssectypes.py
Normal file
71
.venv/Lib/site-packages/dns/dnssectypes.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2003-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
"""Common DNSSEC-related types."""
|
||||
|
||||
# This is a separate file to avoid import circularity between dns.dnssec and
|
||||
# the implementations of the DS and DNSKEY types.
|
||||
|
||||
import dns.enum
|
||||
|
||||
|
||||
class Algorithm(dns.enum.IntEnum):
|
||||
RSAMD5 = 1
|
||||
DH = 2
|
||||
DSA = 3
|
||||
ECC = 4
|
||||
RSASHA1 = 5
|
||||
DSANSEC3SHA1 = 6
|
||||
RSASHA1NSEC3SHA1 = 7
|
||||
RSASHA256 = 8
|
||||
RSASHA512 = 10
|
||||
ECCGOST = 12
|
||||
ECDSAP256SHA256 = 13
|
||||
ECDSAP384SHA384 = 14
|
||||
ED25519 = 15
|
||||
ED448 = 16
|
||||
INDIRECT = 252
|
||||
PRIVATEDNS = 253
|
||||
PRIVATEOID = 254
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
return 255
|
||||
|
||||
|
||||
class DSDigest(dns.enum.IntEnum):
|
||||
"""DNSSEC Delegation Signer Digest Algorithm"""
|
||||
|
||||
NULL = 0
|
||||
SHA1 = 1
|
||||
SHA256 = 2
|
||||
GOST = 3
|
||||
SHA384 = 4
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
return 255
|
||||
|
||||
|
||||
class NSEC3Hash(dns.enum.IntEnum):
|
||||
"""NSEC3 hash algorithm"""
|
||||
|
||||
SHA1 = 1
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
return 255
|
||||
116
.venv/Lib/site-packages/dns/e164.py
Normal file
116
.venv/Lib/site-packages/dns/e164.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2006-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
"""DNS E.164 helpers."""
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
import dns.exception
|
||||
import dns.name
|
||||
import dns.resolver
|
||||
|
||||
#: The public E.164 domain.
|
||||
public_enum_domain = dns.name.from_text("e164.arpa.")
|
||||
|
||||
|
||||
def from_e164(
|
||||
text: str, origin: dns.name.Name | None = public_enum_domain
|
||||
) -> dns.name.Name:
|
||||
"""Convert an E.164 number in textual form into a Name object whose
|
||||
value is the ENUM domain name for that number.
|
||||
|
||||
Non-digits in the text are ignored, i.e. "16505551212",
|
||||
"+1.650.555.1212" and "1 (650) 555-1212" are all the same.
|
||||
|
||||
*text*, a ``str``, is an E.164 number in textual form.
|
||||
|
||||
*origin*, a ``dns.name.Name``, the domain in which the number
|
||||
should be constructed. The default is ``e164.arpa.``.
|
||||
|
||||
Returns a ``dns.name.Name``.
|
||||
"""
|
||||
|
||||
parts = [d for d in text if d.isdigit()]
|
||||
parts.reverse()
|
||||
return dns.name.from_text(".".join(parts), origin=origin)
|
||||
|
||||
|
||||
def to_e164(
|
||||
name: dns.name.Name,
|
||||
origin: dns.name.Name | None = public_enum_domain,
|
||||
want_plus_prefix: bool = True,
|
||||
) -> str:
|
||||
"""Convert an ENUM domain name into an E.164 number.
|
||||
|
||||
Note that dnspython does not have any information about preferred
|
||||
number formats within national numbering plans, so all numbers are
|
||||
emitted as a simple string of digits, prefixed by a '+' (unless
|
||||
*want_plus_prefix* is ``False``).
|
||||
|
||||
*name* is a ``dns.name.Name``, the ENUM domain name.
|
||||
|
||||
*origin* is a ``dns.name.Name``, a domain containing the ENUM
|
||||
domain name. The name is relativized to this domain before being
|
||||
converted to text. If ``None``, no relativization is done.
|
||||
|
||||
*want_plus_prefix* is a ``bool``. If True, add a '+' to the beginning of
|
||||
the returned number.
|
||||
|
||||
Returns a ``str``.
|
||||
|
||||
"""
|
||||
if origin is not None:
|
||||
name = name.relativize(origin)
|
||||
dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1]
|
||||
if len(dlabels) != len(name.labels):
|
||||
raise dns.exception.SyntaxError("non-digit labels in ENUM domain name")
|
||||
dlabels.reverse()
|
||||
text = b"".join(dlabels)
|
||||
if want_plus_prefix:
|
||||
text = b"+" + text
|
||||
return text.decode()
|
||||
|
||||
|
||||
def query(
|
||||
number: str,
|
||||
domains: Iterable[dns.name.Name | str],
|
||||
resolver: dns.resolver.Resolver | None = None,
|
||||
) -> dns.resolver.Answer:
|
||||
"""Look for NAPTR RRs for the specified number in the specified domains.
|
||||
|
||||
e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.'])
|
||||
|
||||
*number*, a ``str`` is the number to look for.
|
||||
|
||||
*domains* is an iterable containing ``dns.name.Name`` values.
|
||||
|
||||
*resolver*, a ``dns.resolver.Resolver``, is the resolver to use. If
|
||||
``None``, the default resolver is used.
|
||||
"""
|
||||
|
||||
if resolver is None:
|
||||
resolver = dns.resolver.get_default_resolver()
|
||||
e_nx = dns.resolver.NXDOMAIN()
|
||||
for domain in domains:
|
||||
if isinstance(domain, str):
|
||||
domain = dns.name.from_text(domain)
|
||||
qname = from_e164(number, domain)
|
||||
try:
|
||||
return resolver.resolve(qname, "NAPTR")
|
||||
except dns.resolver.NXDOMAIN as e:
|
||||
e_nx += e
|
||||
raise e_nx
|
||||
591
.venv/Lib/site-packages/dns/edns.py
Normal file
591
.venv/Lib/site-packages/dns/edns.py
Normal file
@@ -0,0 +1,591 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2009-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
"""EDNS Options"""
|
||||
|
||||
import binascii
|
||||
import math
|
||||
import socket
|
||||
import struct
|
||||
from typing import Any, Dict
|
||||
|
||||
import dns.enum
|
||||
import dns.inet
|
||||
import dns.ipv4
|
||||
import dns.ipv6
|
||||
import dns.name
|
||||
import dns.rdata
|
||||
import dns.wire
|
||||
|
||||
|
||||
class OptionType(dns.enum.IntEnum):
|
||||
"""EDNS option type codes"""
|
||||
|
||||
#: NSID
|
||||
NSID = 3
|
||||
#: DAU
|
||||
DAU = 5
|
||||
#: DHU
|
||||
DHU = 6
|
||||
#: N3U
|
||||
N3U = 7
|
||||
#: ECS (client-subnet)
|
||||
ECS = 8
|
||||
#: EXPIRE
|
||||
EXPIRE = 9
|
||||
#: COOKIE
|
||||
COOKIE = 10
|
||||
#: KEEPALIVE
|
||||
KEEPALIVE = 11
|
||||
#: PADDING
|
||||
PADDING = 12
|
||||
#: CHAIN
|
||||
CHAIN = 13
|
||||
#: EDE (extended-dns-error)
|
||||
EDE = 15
|
||||
#: REPORTCHANNEL
|
||||
REPORTCHANNEL = 18
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
return 65535
|
||||
|
||||
|
||||
class Option:
|
||||
"""Base class for all EDNS option types."""
|
||||
|
||||
def __init__(self, otype: OptionType | str):
|
||||
"""Initialize an option.
|
||||
|
||||
*otype*, a ``dns.edns.OptionType``, is the option type.
|
||||
"""
|
||||
self.otype = OptionType.make(otype)
|
||||
|
||||
def to_wire(self, file: Any | None = None) -> bytes | None:
|
||||
"""Convert an option to wire format.
|
||||
|
||||
Returns a ``bytes`` or ``None``.
|
||||
|
||||
"""
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def to_text(self) -> str:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def to_generic(self) -> "GenericOption":
|
||||
"""Creates a dns.edns.GenericOption equivalent of this rdata.
|
||||
|
||||
Returns a ``dns.edns.GenericOption``.
|
||||
"""
|
||||
wire = self.to_wire()
|
||||
assert wire is not None # for mypy
|
||||
return GenericOption(self.otype, wire)
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
|
||||
"""Build an EDNS option object from wire format.
|
||||
|
||||
*otype*, a ``dns.edns.OptionType``, is the option type.
|
||||
|
||||
*parser*, a ``dns.wire.Parser``, the parser, which should be
|
||||
restructed to the option length.
|
||||
|
||||
Returns a ``dns.edns.Option``.
|
||||
"""
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def _cmp(self, other):
|
||||
"""Compare an EDNS option with another option of the same type.
|
||||
|
||||
Returns < 0 if < *other*, 0 if == *other*, and > 0 if > *other*.
|
||||
"""
|
||||
wire = self.to_wire()
|
||||
owire = other.to_wire()
|
||||
if wire == owire:
|
||||
return 0
|
||||
if wire > owire:
|
||||
return 1
|
||||
return -1
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Option):
|
||||
return False
|
||||
if self.otype != other.otype:
|
||||
return False
|
||||
return self._cmp(other) == 0
|
||||
|
||||
def __ne__(self, other):
|
||||
if not isinstance(other, Option):
|
||||
return True
|
||||
if self.otype != other.otype:
|
||||
return True
|
||||
return self._cmp(other) != 0
|
||||
|
||||
def __lt__(self, other):
|
||||
if not isinstance(other, Option) or self.otype != other.otype:
|
||||
return NotImplemented
|
||||
return self._cmp(other) < 0
|
||||
|
||||
def __le__(self, other):
|
||||
if not isinstance(other, Option) or self.otype != other.otype:
|
||||
return NotImplemented
|
||||
return self._cmp(other) <= 0
|
||||
|
||||
def __ge__(self, other):
|
||||
if not isinstance(other, Option) or self.otype != other.otype:
|
||||
return NotImplemented
|
||||
return self._cmp(other) >= 0
|
||||
|
||||
def __gt__(self, other):
|
||||
if not isinstance(other, Option) or self.otype != other.otype:
|
||||
return NotImplemented
|
||||
return self._cmp(other) > 0
|
||||
|
||||
def __str__(self):
|
||||
return self.to_text()
|
||||
|
||||
|
||||
class GenericOption(Option): # lgtm[py/missing-equals]
|
||||
"""Generic Option Class
|
||||
|
||||
This class is used for EDNS option types for which we have no better
|
||||
implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, otype: OptionType | str, data: bytes | str):
|
||||
super().__init__(otype)
|
||||
self.data = dns.rdata.Rdata._as_bytes(data, True)
|
||||
|
||||
def to_wire(self, file: Any | None = None) -> bytes | None:
|
||||
if file:
|
||||
file.write(self.data)
|
||||
return None
|
||||
else:
|
||||
return self.data
|
||||
|
||||
def to_text(self) -> str:
|
||||
return f"Generic {self.otype}"
|
||||
|
||||
def to_generic(self) -> "GenericOption":
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(
|
||||
cls, otype: OptionType | str, parser: "dns.wire.Parser"
|
||||
) -> Option:
|
||||
return cls(otype, parser.get_remaining())
|
||||
|
||||
|
||||
class ECSOption(Option): # lgtm[py/missing-equals]
|
||||
"""EDNS Client Subnet (ECS, RFC7871)"""
|
||||
|
||||
def __init__(self, address: str, srclen: int | None = None, scopelen: int = 0):
|
||||
"""*address*, a ``str``, is the client address information.
|
||||
|
||||
*srclen*, an ``int``, the source prefix length, which is the
|
||||
leftmost number of bits of the address to be used for the
|
||||
lookup. The default is 24 for IPv4 and 56 for IPv6.
|
||||
|
||||
*scopelen*, an ``int``, the scope prefix length. This value
|
||||
must be 0 in queries, and should be set in responses.
|
||||
"""
|
||||
|
||||
super().__init__(OptionType.ECS)
|
||||
af = dns.inet.af_for_address(address)
|
||||
|
||||
if af == socket.AF_INET6:
|
||||
self.family = 2
|
||||
if srclen is None:
|
||||
srclen = 56
|
||||
address = dns.rdata.Rdata._as_ipv6_address(address)
|
||||
srclen = dns.rdata.Rdata._as_int(srclen, 0, 128)
|
||||
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 128)
|
||||
elif af == socket.AF_INET:
|
||||
self.family = 1
|
||||
if srclen is None:
|
||||
srclen = 24
|
||||
address = dns.rdata.Rdata._as_ipv4_address(address)
|
||||
srclen = dns.rdata.Rdata._as_int(srclen, 0, 32)
|
||||
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32)
|
||||
else: # pragma: no cover (this will never happen)
|
||||
raise ValueError("Bad address family")
|
||||
|
||||
assert srclen is not None
|
||||
self.address = address
|
||||
self.srclen = srclen
|
||||
self.scopelen = scopelen
|
||||
|
||||
addrdata = dns.inet.inet_pton(af, address)
|
||||
nbytes = int(math.ceil(srclen / 8.0))
|
||||
|
||||
# Truncate to srclen and pad to the end of the last octet needed
|
||||
# See RFC section 6
|
||||
self.addrdata = addrdata[:nbytes]
|
||||
nbits = srclen % 8
|
||||
if nbits != 0:
|
||||
last = struct.pack("B", ord(self.addrdata[-1:]) & (0xFF << (8 - nbits)))
|
||||
self.addrdata = self.addrdata[:-1] + last
|
||||
|
||||
def to_text(self) -> str:
|
||||
return f"ECS {self.address}/{self.srclen} scope/{self.scopelen}"
|
||||
|
||||
@staticmethod
|
||||
def from_text(text: str) -> Option:
|
||||
"""Convert a string into a `dns.edns.ECSOption`
|
||||
|
||||
*text*, a `str`, the text form of the option.
|
||||
|
||||
Returns a `dns.edns.ECSOption`.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> import dns.edns
|
||||
>>>
|
||||
>>> # basic example
|
||||
>>> dns.edns.ECSOption.from_text('1.2.3.4/24')
|
||||
>>>
|
||||
>>> # also understands scope
|
||||
>>> dns.edns.ECSOption.from_text('1.2.3.4/24/32')
|
||||
>>>
|
||||
>>> # IPv6
|
||||
>>> dns.edns.ECSOption.from_text('2001:4b98::1/64/64')
|
||||
>>>
|
||||
>>> # it understands results from `dns.edns.ECSOption.to_text()`
|
||||
>>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32')
|
||||
"""
|
||||
optional_prefix = "ECS"
|
||||
tokens = text.split()
|
||||
ecs_text = None
|
||||
if len(tokens) == 1:
|
||||
ecs_text = tokens[0]
|
||||
elif len(tokens) == 2:
|
||||
if tokens[0] != optional_prefix:
|
||||
raise ValueError(f'could not parse ECS from "{text}"')
|
||||
ecs_text = tokens[1]
|
||||
else:
|
||||
raise ValueError(f'could not parse ECS from "{text}"')
|
||||
n_slashes = ecs_text.count("/")
|
||||
if n_slashes == 1:
|
||||
address, tsrclen = ecs_text.split("/")
|
||||
tscope = "0"
|
||||
elif n_slashes == 2:
|
||||
address, tsrclen, tscope = ecs_text.split("/")
|
||||
else:
|
||||
raise ValueError(f'could not parse ECS from "{text}"')
|
||||
try:
|
||||
scope = int(tscope)
|
||||
except ValueError:
|
||||
raise ValueError("invalid scope " + f'"{tscope}": scope must be an integer')
|
||||
try:
|
||||
srclen = int(tsrclen)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"invalid srclen " + f'"{tsrclen}": srclen must be an integer'
|
||||
)
|
||||
return ECSOption(address, srclen, scope)
|
||||
|
||||
def to_wire(self, file: Any | None = None) -> bytes | None:
|
||||
value = (
|
||||
struct.pack("!HBB", self.family, self.srclen, self.scopelen) + self.addrdata
|
||||
)
|
||||
if file:
|
||||
file.write(value)
|
||||
return None
|
||||
else:
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(
|
||||
cls, otype: OptionType | str, parser: "dns.wire.Parser"
|
||||
) -> Option:
|
||||
family, src, scope = parser.get_struct("!HBB")
|
||||
addrlen = int(math.ceil(src / 8.0))
|
||||
prefix = parser.get_bytes(addrlen)
|
||||
if family == 1:
|
||||
pad = 4 - addrlen
|
||||
addr = dns.ipv4.inet_ntoa(prefix + b"\x00" * pad)
|
||||
elif family == 2:
|
||||
pad = 16 - addrlen
|
||||
addr = dns.ipv6.inet_ntoa(prefix + b"\x00" * pad)
|
||||
else:
|
||||
raise ValueError("unsupported family")
|
||||
|
||||
return cls(addr, src, scope)
|
||||
|
||||
|
||||
class EDECode(dns.enum.IntEnum):
|
||||
"""Extended DNS Error (EDE) codes"""
|
||||
|
||||
OTHER = 0
|
||||
UNSUPPORTED_DNSKEY_ALGORITHM = 1
|
||||
UNSUPPORTED_DS_DIGEST_TYPE = 2
|
||||
STALE_ANSWER = 3
|
||||
FORGED_ANSWER = 4
|
||||
DNSSEC_INDETERMINATE = 5
|
||||
DNSSEC_BOGUS = 6
|
||||
SIGNATURE_EXPIRED = 7
|
||||
SIGNATURE_NOT_YET_VALID = 8
|
||||
DNSKEY_MISSING = 9
|
||||
RRSIGS_MISSING = 10
|
||||
NO_ZONE_KEY_BIT_SET = 11
|
||||
NSEC_MISSING = 12
|
||||
CACHED_ERROR = 13
|
||||
NOT_READY = 14
|
||||
BLOCKED = 15
|
||||
CENSORED = 16
|
||||
FILTERED = 17
|
||||
PROHIBITED = 18
|
||||
STALE_NXDOMAIN_ANSWER = 19
|
||||
NOT_AUTHORITATIVE = 20
|
||||
NOT_SUPPORTED = 21
|
||||
NO_REACHABLE_AUTHORITY = 22
|
||||
NETWORK_ERROR = 23
|
||||
INVALID_DATA = 24
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
return 65535
|
||||
|
||||
|
||||
class EDEOption(Option): # lgtm[py/missing-equals]
|
||||
"""Extended DNS Error (EDE, RFC8914)"""
|
||||
|
||||
_preserve_case = {"DNSKEY", "DS", "DNSSEC", "RRSIGs", "NSEC", "NXDOMAIN"}
|
||||
|
||||
def __init__(self, code: EDECode | str, text: str | None = None):
|
||||
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
|
||||
extended error.
|
||||
|
||||
*text*, a ``str`` or ``None``, specifying additional information about
|
||||
the error.
|
||||
"""
|
||||
|
||||
super().__init__(OptionType.EDE)
|
||||
|
||||
self.code = EDECode.make(code)
|
||||
if text is not None and not isinstance(text, str):
|
||||
raise ValueError("text must be string or None")
|
||||
self.text = text
|
||||
|
||||
def to_text(self) -> str:
|
||||
output = f"EDE {self.code}"
|
||||
if self.code in EDECode:
|
||||
desc = EDECode.to_text(self.code)
|
||||
desc = " ".join(
|
||||
word if word in self._preserve_case else word.title()
|
||||
for word in desc.split("_")
|
||||
)
|
||||
output += f" ({desc})"
|
||||
if self.text is not None:
|
||||
output += f": {self.text}"
|
||||
return output
|
||||
|
||||
def to_wire(self, file: Any | None = None) -> bytes | None:
|
||||
value = struct.pack("!H", self.code)
|
||||
if self.text is not None:
|
||||
value += self.text.encode("utf8")
|
||||
|
||||
if file:
|
||||
file.write(value)
|
||||
return None
|
||||
else:
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(
|
||||
cls, otype: OptionType | str, parser: "dns.wire.Parser"
|
||||
) -> Option:
|
||||
code = EDECode.make(parser.get_uint16())
|
||||
text = parser.get_remaining()
|
||||
|
||||
if text:
|
||||
if text[-1] == 0: # text MAY be null-terminated
|
||||
text = text[:-1]
|
||||
btext = text.decode("utf8")
|
||||
else:
|
||||
btext = None
|
||||
|
||||
return cls(code, btext)
|
||||
|
||||
|
||||
class NSIDOption(Option):
|
||||
def __init__(self, nsid: bytes):
|
||||
super().__init__(OptionType.NSID)
|
||||
self.nsid = nsid
|
||||
|
||||
def to_wire(self, file: Any = None) -> bytes | None:
|
||||
if file:
|
||||
file.write(self.nsid)
|
||||
return None
|
||||
else:
|
||||
return self.nsid
|
||||
|
||||
def to_text(self) -> str:
|
||||
if all(c >= 0x20 and c <= 0x7E for c in self.nsid):
|
||||
# All ASCII printable, so it's probably a string.
|
||||
value = self.nsid.decode()
|
||||
else:
|
||||
value = binascii.hexlify(self.nsid).decode()
|
||||
return f"NSID {value}"
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(
|
||||
cls, otype: OptionType | str, parser: dns.wire.Parser
|
||||
) -> Option:
|
||||
return cls(parser.get_remaining())
|
||||
|
||||
|
||||
class CookieOption(Option):
|
||||
def __init__(self, client: bytes, server: bytes):
|
||||
super().__init__(OptionType.COOKIE)
|
||||
self.client = client
|
||||
self.server = server
|
||||
if len(client) != 8:
|
||||
raise ValueError("client cookie must be 8 bytes")
|
||||
if len(server) != 0 and (len(server) < 8 or len(server) > 32):
|
||||
raise ValueError("server cookie must be empty or between 8 and 32 bytes")
|
||||
|
||||
def to_wire(self, file: Any = None) -> bytes | None:
|
||||
if file:
|
||||
file.write(self.client)
|
||||
if len(self.server) > 0:
|
||||
file.write(self.server)
|
||||
return None
|
||||
else:
|
||||
return self.client + self.server
|
||||
|
||||
def to_text(self) -> str:
|
||||
client = binascii.hexlify(self.client).decode()
|
||||
if len(self.server) > 0:
|
||||
server = binascii.hexlify(self.server).decode()
|
||||
else:
|
||||
server = ""
|
||||
return f"COOKIE {client}{server}"
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(
|
||||
cls, otype: OptionType | str, parser: dns.wire.Parser
|
||||
) -> Option:
|
||||
return cls(parser.get_bytes(8), parser.get_remaining())
|
||||
|
||||
|
||||
class ReportChannelOption(Option):
|
||||
# RFC 9567
|
||||
def __init__(self, agent_domain: dns.name.Name):
|
||||
super().__init__(OptionType.REPORTCHANNEL)
|
||||
self.agent_domain = agent_domain
|
||||
|
||||
def to_wire(self, file: Any = None) -> bytes | None:
|
||||
return self.agent_domain.to_wire(file)
|
||||
|
||||
def to_text(self) -> str:
|
||||
return "REPORTCHANNEL " + self.agent_domain.to_text()
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(
|
||||
cls, otype: OptionType | str, parser: dns.wire.Parser
|
||||
) -> Option:
|
||||
return cls(parser.get_name())
|
||||
|
||||
|
||||
_type_to_class: Dict[OptionType, Any] = {
|
||||
OptionType.ECS: ECSOption,
|
||||
OptionType.EDE: EDEOption,
|
||||
OptionType.NSID: NSIDOption,
|
||||
OptionType.COOKIE: CookieOption,
|
||||
OptionType.REPORTCHANNEL: ReportChannelOption,
|
||||
}
|
||||
|
||||
|
||||
def get_option_class(otype: OptionType) -> Any:
|
||||
"""Return the class for the specified option type.
|
||||
|
||||
The GenericOption class is used if a more specific class is not
|
||||
known.
|
||||
"""
|
||||
|
||||
cls = _type_to_class.get(otype)
|
||||
if cls is None:
|
||||
cls = GenericOption
|
||||
return cls
|
||||
|
||||
|
||||
def option_from_wire_parser(
|
||||
otype: OptionType | str, parser: "dns.wire.Parser"
|
||||
) -> Option:
|
||||
"""Build an EDNS option object from wire format.
|
||||
|
||||
*otype*, an ``int``, is the option type.
|
||||
|
||||
*parser*, a ``dns.wire.Parser``, the parser, which should be
|
||||
restricted to the option length.
|
||||
|
||||
Returns an instance of a subclass of ``dns.edns.Option``.
|
||||
"""
|
||||
otype = OptionType.make(otype)
|
||||
cls = get_option_class(otype)
|
||||
return cls.from_wire_parser(otype, parser)
|
||||
|
||||
|
||||
def option_from_wire(
|
||||
otype: OptionType | str, wire: bytes, current: int, olen: int
|
||||
) -> Option:
|
||||
"""Build an EDNS option object from wire format.
|
||||
|
||||
*otype*, an ``int``, is the option type.
|
||||
|
||||
*wire*, a ``bytes``, is the wire-format message.
|
||||
|
||||
*current*, an ``int``, is the offset in *wire* of the beginning
|
||||
of the rdata.
|
||||
|
||||
*olen*, an ``int``, is the length of the wire-format option data
|
||||
|
||||
Returns an instance of a subclass of ``dns.edns.Option``.
|
||||
"""
|
||||
parser = dns.wire.Parser(wire, current)
|
||||
with parser.restrict_to(olen):
|
||||
return option_from_wire_parser(otype, parser)
|
||||
|
||||
|
||||
def register_type(implementation: Any, otype: OptionType) -> None:
|
||||
"""Register the implementation of an option type.
|
||||
|
||||
*implementation*, a ``class``, is a subclass of ``dns.edns.Option``.
|
||||
|
||||
*otype*, an ``int``, is the option type.
|
||||
"""
|
||||
|
||||
_type_to_class[otype] = implementation
|
||||
|
||||
|
||||
### BEGIN generated OptionType constants
|
||||
|
||||
NSID = OptionType.NSID
|
||||
DAU = OptionType.DAU
|
||||
DHU = OptionType.DHU
|
||||
N3U = OptionType.N3U
|
||||
ECS = OptionType.ECS
|
||||
EXPIRE = OptionType.EXPIRE
|
||||
COOKIE = OptionType.COOKIE
|
||||
KEEPALIVE = OptionType.KEEPALIVE
|
||||
PADDING = OptionType.PADDING
|
||||
CHAIN = OptionType.CHAIN
|
||||
EDE = OptionType.EDE
|
||||
REPORTCHANNEL = OptionType.REPORTCHANNEL
|
||||
|
||||
### END generated OptionType constants
|
||||
130
.venv/Lib/site-packages/dns/entropy.py
Normal file
130
.venv/Lib/site-packages/dns/entropy.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2009-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
|
||||
class EntropyPool:
|
||||
# This is an entropy pool for Python implementations that do not
|
||||
# have a working SystemRandom. I'm not sure there are any, but
|
||||
# leaving this code doesn't hurt anything as the library code
|
||||
# is used if present.
|
||||
|
||||
def __init__(self, seed: bytes | None = None):
|
||||
self.pool_index = 0
|
||||
self.digest: bytearray | None = None
|
||||
self.next_byte = 0
|
||||
self.lock = threading.Lock()
|
||||
self.hash = hashlib.sha1()
|
||||
self.hash_len = 20
|
||||
self.pool = bytearray(b"\0" * self.hash_len)
|
||||
if seed is not None:
|
||||
self._stir(seed)
|
||||
self.seeded = True
|
||||
self.seed_pid = os.getpid()
|
||||
else:
|
||||
self.seeded = False
|
||||
self.seed_pid = 0
|
||||
|
||||
def _stir(self, entropy: bytes | bytearray) -> None:
|
||||
for c in entropy:
|
||||
if self.pool_index == self.hash_len:
|
||||
self.pool_index = 0
|
||||
b = c & 0xFF
|
||||
self.pool[self.pool_index] ^= b
|
||||
self.pool_index += 1
|
||||
|
||||
def stir(self, entropy: bytes | bytearray) -> None:
|
||||
with self.lock:
|
||||
self._stir(entropy)
|
||||
|
||||
def _maybe_seed(self) -> None:
|
||||
if not self.seeded or self.seed_pid != os.getpid():
|
||||
try:
|
||||
seed = os.urandom(16)
|
||||
except Exception: # pragma: no cover
|
||||
try:
|
||||
with open("/dev/urandom", "rb", 0) as r:
|
||||
seed = r.read(16)
|
||||
except Exception:
|
||||
seed = str(time.time()).encode()
|
||||
self.seeded = True
|
||||
self.seed_pid = os.getpid()
|
||||
self.digest = None
|
||||
seed = bytearray(seed)
|
||||
self._stir(seed)
|
||||
|
||||
def random_8(self) -> int:
|
||||
with self.lock:
|
||||
self._maybe_seed()
|
||||
if self.digest is None or self.next_byte == self.hash_len:
|
||||
self.hash.update(bytes(self.pool))
|
||||
self.digest = bytearray(self.hash.digest())
|
||||
self._stir(self.digest)
|
||||
self.next_byte = 0
|
||||
value = self.digest[self.next_byte]
|
||||
self.next_byte += 1
|
||||
return value
|
||||
|
||||
def random_16(self) -> int:
|
||||
return self.random_8() * 256 + self.random_8()
|
||||
|
||||
def random_32(self) -> int:
|
||||
return self.random_16() * 65536 + self.random_16()
|
||||
|
||||
def random_between(self, first: int, last: int) -> int:
|
||||
size = last - first + 1
|
||||
if size > 4294967296:
|
||||
raise ValueError("too big")
|
||||
if size > 65536:
|
||||
rand = self.random_32
|
||||
max = 4294967295
|
||||
elif size > 256:
|
||||
rand = self.random_16
|
||||
max = 65535
|
||||
else:
|
||||
rand = self.random_8
|
||||
max = 255
|
||||
return first + size * rand() // (max + 1)
|
||||
|
||||
|
||||
pool = EntropyPool()
|
||||
|
||||
system_random: Any | None
|
||||
try:
|
||||
system_random = random.SystemRandom()
|
||||
except Exception: # pragma: no cover
|
||||
system_random = None
|
||||
|
||||
|
||||
def random_16() -> int:
|
||||
if system_random is not None:
|
||||
return system_random.randrange(0, 65536)
|
||||
else:
|
||||
return pool.random_16()
|
||||
|
||||
|
||||
def between(first: int, last: int) -> int:
|
||||
if system_random is not None:
|
||||
return system_random.randrange(first, last + 1)
|
||||
else:
|
||||
return pool.random_between(first, last)
|
||||
113
.venv/Lib/site-packages/dns/enum.py
Normal file
113
.venv/Lib/site-packages/dns/enum.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2003-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import enum
|
||||
from typing import Any, Type, TypeVar
|
||||
|
||||
TIntEnum = TypeVar("TIntEnum", bound="IntEnum")
|
||||
|
||||
|
||||
class IntEnum(enum.IntEnum):
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
cls._check_value(value)
|
||||
val = int.__new__(cls, value) # pyright: ignore
|
||||
val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}"
|
||||
val._value_ = value # pyright: ignore
|
||||
return val
|
||||
|
||||
@classmethod
|
||||
def _check_value(cls, value):
|
||||
max = cls._maximum()
|
||||
if not isinstance(value, int):
|
||||
raise TypeError
|
||||
if value < 0 or value > max:
|
||||
name = cls._short_name()
|
||||
raise ValueError(f"{name} must be an int between >= 0 and <= {max}")
|
||||
|
||||
@classmethod
|
||||
def from_text(cls: Type[TIntEnum], text: str) -> TIntEnum:
|
||||
text = text.upper()
|
||||
try:
|
||||
return cls[text]
|
||||
except KeyError:
|
||||
pass
|
||||
value = cls._extra_from_text(text)
|
||||
if value:
|
||||
return value
|
||||
prefix = cls._prefix()
|
||||
if text.startswith(prefix) and text[len(prefix) :].isdigit():
|
||||
value = int(text[len(prefix) :])
|
||||
cls._check_value(value)
|
||||
return cls(value)
|
||||
raise cls._unknown_exception_class()
|
||||
|
||||
@classmethod
|
||||
def to_text(cls: Type[TIntEnum], value: int) -> str:
|
||||
cls._check_value(value)
|
||||
try:
|
||||
text = cls(value).name
|
||||
except ValueError:
|
||||
text = None
|
||||
text = cls._extra_to_text(value, text)
|
||||
if text is None:
|
||||
text = f"{cls._prefix()}{value}"
|
||||
return text
|
||||
|
||||
@classmethod
|
||||
def make(cls: Type[TIntEnum], value: int | str) -> TIntEnum:
|
||||
"""Convert text or a value into an enumerated type, if possible.
|
||||
|
||||
*value*, the ``int`` or ``str`` to convert.
|
||||
|
||||
Raises a class-specific exception if a ``str`` is provided that
|
||||
cannot be converted.
|
||||
|
||||
Raises ``ValueError`` if the value is out of range.
|
||||
|
||||
Returns an enumeration from the calling class corresponding to the
|
||||
value, if one is defined, or an ``int`` otherwise.
|
||||
"""
|
||||
|
||||
if isinstance(value, str):
|
||||
return cls.from_text(value)
|
||||
cls._check_value(value)
|
||||
return cls(value)
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
@classmethod
|
||||
def _short_name(cls):
|
||||
return cls.__name__.lower()
|
||||
|
||||
@classmethod
|
||||
def _prefix(cls) -> str:
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def _extra_from_text(cls, text: str) -> Any | None: # pylint: disable=W0613
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extra_to_text(cls, value, current_text): # pylint: disable=W0613
|
||||
return current_text
|
||||
|
||||
@classmethod
|
||||
def _unknown_exception_class(cls) -> Type[Exception]:
|
||||
return ValueError
|
||||
169
.venv/Lib/site-packages/dns/exception.py
Normal file
169
.venv/Lib/site-packages/dns/exception.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2003-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
"""Common DNS Exceptions.
|
||||
|
||||
Dnspython modules may also define their own exceptions, which will
|
||||
always be subclasses of ``DNSException``.
|
||||
"""
|
||||
|
||||
|
||||
from typing import Set
|
||||
|
||||
|
||||
class DNSException(Exception):
|
||||
"""Abstract base class shared by all dnspython exceptions.
|
||||
|
||||
It supports two basic modes of operation:
|
||||
|
||||
a) Old/compatible mode is used if ``__init__`` was called with
|
||||
empty *kwargs*. In compatible mode all *args* are passed
|
||||
to the standard Python Exception class as before and all *args* are
|
||||
printed by the standard ``__str__`` implementation. Class variable
|
||||
``msg`` (or doc string if ``msg`` is ``None``) is returned from ``str()``
|
||||
if *args* is empty.
|
||||
|
||||
b) New/parametrized mode is used if ``__init__`` was called with
|
||||
non-empty *kwargs*.
|
||||
In the new mode *args* must be empty and all kwargs must match
|
||||
those set in class variable ``supp_kwargs``. All kwargs are stored inside
|
||||
``self.kwargs`` and used in a new ``__str__`` implementation to construct
|
||||
a formatted message based on the ``fmt`` class variable, a ``string``.
|
||||
|
||||
In the simplest case it is enough to override the ``supp_kwargs``
|
||||
and ``fmt`` class variables to get nice parametrized messages.
|
||||
"""
|
||||
|
||||
msg: str | None = None # non-parametrized message
|
||||
supp_kwargs: Set[str] = set() # accepted parameters for _fmt_kwargs (sanity check)
|
||||
fmt: str | None = None # message parametrized with results from _fmt_kwargs
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._check_params(*args, **kwargs)
|
||||
if kwargs:
|
||||
# This call to a virtual method from __init__ is ok in our usage
|
||||
self.kwargs = self._check_kwargs(**kwargs) # lgtm[py/init-calls-subclass]
|
||||
self.msg = str(self)
|
||||
else:
|
||||
self.kwargs = dict() # defined but empty for old mode exceptions
|
||||
if self.msg is None:
|
||||
# doc string is better implicit message than empty string
|
||||
self.msg = self.__doc__
|
||||
if args:
|
||||
super().__init__(*args)
|
||||
else:
|
||||
super().__init__(self.msg)
|
||||
|
||||
def _check_params(self, *args, **kwargs):
|
||||
"""Old exceptions supported only args and not kwargs.
|
||||
|
||||
For sanity we do not allow to mix old and new behavior."""
|
||||
if args or kwargs:
|
||||
assert bool(args) != bool(
|
||||
kwargs
|
||||
), "keyword arguments are mutually exclusive with positional args"
|
||||
|
||||
def _check_kwargs(self, **kwargs):
|
||||
if kwargs:
|
||||
assert (
|
||||
set(kwargs.keys()) == self.supp_kwargs
|
||||
), f"following set of keyword args is required: {self.supp_kwargs}"
|
||||
return kwargs
|
||||
|
||||
def _fmt_kwargs(self, **kwargs):
|
||||
"""Format kwargs before printing them.
|
||||
|
||||
Resulting dictionary has to have keys necessary for str.format call
|
||||
on fmt class variable.
|
||||
"""
|
||||
fmtargs = {}
|
||||
for kw, data in kwargs.items():
|
||||
if isinstance(data, list | set):
|
||||
# convert list of <someobj> to list of str(<someobj>)
|
||||
fmtargs[kw] = list(map(str, data))
|
||||
if len(fmtargs[kw]) == 1:
|
||||
# remove list brackets [] from single-item lists
|
||||
fmtargs[kw] = fmtargs[kw].pop()
|
||||
else:
|
||||
fmtargs[kw] = data
|
||||
return fmtargs
|
||||
|
||||
def __str__(self):
|
||||
if self.kwargs and self.fmt:
|
||||
# provide custom message constructed from keyword arguments
|
||||
fmtargs = self._fmt_kwargs(**self.kwargs)
|
||||
return self.fmt.format(**fmtargs)
|
||||
else:
|
||||
# print *args directly in the same way as old DNSException
|
||||
return super().__str__()
|
||||
|
||||
|
||||
class FormError(DNSException):
|
||||
"""DNS message is malformed."""
|
||||
|
||||
|
||||
class SyntaxError(DNSException):
|
||||
"""Text input is malformed."""
|
||||
|
||||
|
||||
class UnexpectedEnd(SyntaxError):
|
||||
"""Text input ended unexpectedly."""
|
||||
|
||||
|
||||
class TooBig(DNSException):
|
||||
"""The DNS message is too big."""
|
||||
|
||||
|
||||
class Timeout(DNSException):
|
||||
"""The DNS operation timed out."""
|
||||
|
||||
supp_kwargs = {"timeout"}
|
||||
fmt = "The DNS operation timed out after {timeout:.3f} seconds"
|
||||
|
||||
# We do this as otherwise mypy complains about unexpected keyword argument
|
||||
# idna_exception
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class UnsupportedAlgorithm(DNSException):
|
||||
"""The DNSSEC algorithm is not supported."""
|
||||
|
||||
|
||||
class AlgorithmKeyMismatch(UnsupportedAlgorithm):
|
||||
"""The DNSSEC algorithm is not supported for the given key type."""
|
||||
|
||||
|
||||
class ValidationFailure(DNSException):
|
||||
"""The DNSSEC signature is invalid."""
|
||||
|
||||
|
||||
class DeniedByPolicy(DNSException):
|
||||
"""Denied by DNSSEC policy."""
|
||||
|
||||
|
||||
class ExceptionWrapper:
|
||||
def __init__(self, exception_class):
|
||||
self.exception_class = exception_class
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is not None and not isinstance(exc_val, self.exception_class):
|
||||
raise self.exception_class(str(exc_val)) from exc_val
|
||||
return False
|
||||
123
.venv/Lib/site-packages/dns/flags.py
Normal file
123
.venv/Lib/site-packages/dns/flags.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2001-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
"""DNS Message Flags."""
|
||||
|
||||
import enum
|
||||
from typing import Any
|
||||
|
||||
# Standard DNS flags
|
||||
|
||||
|
||||
class Flag(enum.IntFlag):
|
||||
#: Query Response
|
||||
QR = 0x8000
|
||||
#: Authoritative Answer
|
||||
AA = 0x0400
|
||||
#: Truncated Response
|
||||
TC = 0x0200
|
||||
#: Recursion Desired
|
||||
RD = 0x0100
|
||||
#: Recursion Available
|
||||
RA = 0x0080
|
||||
#: Authentic Data
|
||||
AD = 0x0020
|
||||
#: Checking Disabled
|
||||
CD = 0x0010
|
||||
|
||||
|
||||
# EDNS flags
|
||||
|
||||
|
||||
class EDNSFlag(enum.IntFlag):
|
||||
#: DNSSEC answer OK
|
||||
DO = 0x8000
|
||||
|
||||
|
||||
def _from_text(text: str, enum_class: Any) -> int:
|
||||
flags = 0
|
||||
tokens = text.split()
|
||||
for t in tokens:
|
||||
flags |= enum_class[t.upper()]
|
||||
return flags
|
||||
|
||||
|
||||
def _to_text(flags: int, enum_class: Any) -> str:
|
||||
text_flags = []
|
||||
for k, v in enum_class.__members__.items():
|
||||
if flags & v != 0:
|
||||
text_flags.append(k)
|
||||
return " ".join(text_flags)
|
||||
|
||||
|
||||
def from_text(text: str) -> int:
|
||||
"""Convert a space-separated list of flag text values into a flags
|
||||
value.
|
||||
|
||||
Returns an ``int``
|
||||
"""
|
||||
|
||||
return _from_text(text, Flag)
|
||||
|
||||
|
||||
def to_text(flags: int) -> str:
|
||||
"""Convert a flags value into a space-separated list of flag text
|
||||
values.
|
||||
|
||||
Returns a ``str``.
|
||||
"""
|
||||
|
||||
return _to_text(flags, Flag)
|
||||
|
||||
|
||||
def edns_from_text(text: str) -> int:
|
||||
"""Convert a space-separated list of EDNS flag text values into a EDNS
|
||||
flags value.
|
||||
|
||||
Returns an ``int``
|
||||
"""
|
||||
|
||||
return _from_text(text, EDNSFlag)
|
||||
|
||||
|
||||
def edns_to_text(flags: int) -> str:
|
||||
"""Convert an EDNS flags value into a space-separated list of EDNS flag
|
||||
text values.
|
||||
|
||||
Returns a ``str``.
|
||||
"""
|
||||
|
||||
return _to_text(flags, EDNSFlag)
|
||||
|
||||
|
||||
### BEGIN generated Flag constants
|
||||
|
||||
QR = Flag.QR
|
||||
AA = Flag.AA
|
||||
TC = Flag.TC
|
||||
RD = Flag.RD
|
||||
RA = Flag.RA
|
||||
AD = Flag.AD
|
||||
CD = Flag.CD
|
||||
|
||||
### END generated Flag constants
|
||||
|
||||
### BEGIN generated EDNSFlag constants
|
||||
|
||||
DO = EDNSFlag.DO
|
||||
|
||||
### END generated EDNSFlag constants
|
||||
72
.venv/Lib/site-packages/dns/grange.py
Normal file
72
.venv/Lib/site-packages/dns/grange.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2012-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
"""DNS GENERATE range conversion."""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import dns.exception
|
||||
|
||||
|
||||
def from_text(text: str) -> Tuple[int, int, int]:
|
||||
"""Convert the text form of a range in a ``$GENERATE`` statement to an
|
||||
integer.
|
||||
|
||||
*text*, a ``str``, the textual range in ``$GENERATE`` form.
|
||||
|
||||
Returns a tuple of three ``int`` values ``(start, stop, step)``.
|
||||
"""
|
||||
|
||||
start = -1
|
||||
stop = -1
|
||||
step = 1
|
||||
cur = ""
|
||||
state = 0
|
||||
# state 0 1 2
|
||||
# x - y / z
|
||||
|
||||
if text and text[0] == "-":
|
||||
raise dns.exception.SyntaxError("Start cannot be a negative number")
|
||||
|
||||
for c in text:
|
||||
if c == "-" and state == 0:
|
||||
start = int(cur)
|
||||
cur = ""
|
||||
state = 1
|
||||
elif c == "/":
|
||||
stop = int(cur)
|
||||
cur = ""
|
||||
state = 2
|
||||
elif c.isdigit():
|
||||
cur += c
|
||||
else:
|
||||
raise dns.exception.SyntaxError(f"Could not parse {c}")
|
||||
|
||||
if state == 0:
|
||||
raise dns.exception.SyntaxError("no stop value specified")
|
||||
elif state == 1:
|
||||
stop = int(cur)
|
||||
else:
|
||||
assert state == 2
|
||||
step = int(cur)
|
||||
|
||||
assert step >= 1
|
||||
assert start >= 0
|
||||
if start > stop:
|
||||
raise dns.exception.SyntaxError("start must be <= stop")
|
||||
|
||||
return (start, stop, step)
|
||||
68
.venv/Lib/site-packages/dns/immutable.py
Normal file
68
.venv/Lib/site-packages/dns/immutable.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import collections.abc
|
||||
from typing import Any, Callable
|
||||
|
||||
from dns._immutable_ctx import immutable
|
||||
|
||||
|
||||
@immutable
|
||||
class Dict(collections.abc.Mapping): # lgtm[py/missing-equals]
|
||||
def __init__(
|
||||
self,
|
||||
dictionary: Any,
|
||||
no_copy: bool = False,
|
||||
map_factory: Callable[[], collections.abc.MutableMapping] = dict,
|
||||
):
|
||||
"""Make an immutable dictionary from the specified dictionary.
|
||||
|
||||
If *no_copy* is `True`, then *dictionary* will be wrapped instead
|
||||
of copied. Only set this if you are sure there will be no external
|
||||
references to the dictionary.
|
||||
"""
|
||||
if no_copy and isinstance(dictionary, collections.abc.MutableMapping):
|
||||
self._odict = dictionary
|
||||
else:
|
||||
self._odict = map_factory()
|
||||
self._odict.update(dictionary)
|
||||
self._hash = None
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._odict.__getitem__(key)
|
||||
|
||||
def __hash__(self): # pylint: disable=invalid-hash-returned
|
||||
if self._hash is None:
|
||||
h = 0
|
||||
for key in sorted(self._odict.keys()):
|
||||
h ^= hash(key)
|
||||
object.__setattr__(self, "_hash", h)
|
||||
# this does return an int, but pylint doesn't figure that out
|
||||
return self._hash
|
||||
|
||||
def __len__(self):
|
||||
return len(self._odict)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._odict)
|
||||
|
||||
|
||||
def constify(o: Any) -> Any:
|
||||
"""
|
||||
Convert mutable types to immutable types.
|
||||
"""
|
||||
if isinstance(o, bytearray):
|
||||
return bytes(o)
|
||||
if isinstance(o, tuple):
|
||||
try:
|
||||
hash(o)
|
||||
return o
|
||||
except Exception:
|
||||
return tuple(constify(elt) for elt in o)
|
||||
if isinstance(o, list):
|
||||
return tuple(constify(elt) for elt in o)
|
||||
if isinstance(o, dict):
|
||||
cdict = dict()
|
||||
for k, v in o.items():
|
||||
cdict[k] = constify(v)
|
||||
return Dict(cdict, True)
|
||||
return o
|
||||
195
.venv/Lib/site-packages/dns/inet.py
Normal file
195
.venv/Lib/site-packages/dns/inet.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2003-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
"""Generic Internet address helper functions."""
|
||||
|
||||
import socket
|
||||
from typing import Any, Tuple
|
||||
|
||||
import dns.ipv4
|
||||
import dns.ipv6
|
||||
|
||||
# We assume that AF_INET and AF_INET6 are always defined. We keep
|
||||
# these here for the benefit of any old code (unlikely though that
|
||||
# is!).
|
||||
AF_INET = socket.AF_INET
|
||||
AF_INET6 = socket.AF_INET6
|
||||
|
||||
|
||||
def inet_pton(family: int, text: str) -> bytes:
|
||||
"""Convert the textual form of a network address into its binary form.
|
||||
|
||||
*family* is an ``int``, the address family.
|
||||
|
||||
*text* is a ``str``, the textual address.
|
||||
|
||||
Raises ``NotImplementedError`` if the address family specified is not
|
||||
implemented.
|
||||
|
||||
Returns a ``bytes``.
|
||||
"""
|
||||
|
||||
if family == AF_INET:
|
||||
return dns.ipv4.inet_aton(text)
|
||||
elif family == AF_INET6:
|
||||
return dns.ipv6.inet_aton(text, True)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def inet_ntop(family: int, address: bytes) -> str:
|
||||
"""Convert the binary form of a network address into its textual form.
|
||||
|
||||
*family* is an ``int``, the address family.
|
||||
|
||||
*address* is a ``bytes``, the network address in binary form.
|
||||
|
||||
Raises ``NotImplementedError`` if the address family specified is not
|
||||
implemented.
|
||||
|
||||
Returns a ``str``.
|
||||
"""
|
||||
|
||||
if family == AF_INET:
|
||||
return dns.ipv4.inet_ntoa(address)
|
||||
elif family == AF_INET6:
|
||||
return dns.ipv6.inet_ntoa(address)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def af_for_address(text: str) -> int:
|
||||
"""Determine the address family of a textual-form network address.
|
||||
|
||||
*text*, a ``str``, the textual address.
|
||||
|
||||
Raises ``ValueError`` if the address family cannot be determined
|
||||
from the input.
|
||||
|
||||
Returns an ``int``.
|
||||
"""
|
||||
|
||||
try:
|
||||
dns.ipv4.inet_aton(text)
|
||||
return AF_INET
|
||||
except Exception:
|
||||
try:
|
||||
dns.ipv6.inet_aton(text, True)
|
||||
return AF_INET6
|
||||
except Exception:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def is_multicast(text: str) -> bool:
|
||||
"""Is the textual-form network address a multicast address?
|
||||
|
||||
*text*, a ``str``, the textual address.
|
||||
|
||||
Raises ``ValueError`` if the address family cannot be determined
|
||||
from the input.
|
||||
|
||||
Returns a ``bool``.
|
||||
"""
|
||||
|
||||
try:
|
||||
first = dns.ipv4.inet_aton(text)[0]
|
||||
return first >= 224 and first <= 239
|
||||
except Exception:
|
||||
try:
|
||||
first = dns.ipv6.inet_aton(text, True)[0]
|
||||
return first == 255
|
||||
except Exception:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def is_address(text: str) -> bool:
|
||||
"""Is the specified string an IPv4 or IPv6 address?
|
||||
|
||||
*text*, a ``str``, the textual address.
|
||||
|
||||
Returns a ``bool``.
|
||||
"""
|
||||
|
||||
try:
|
||||
dns.ipv4.inet_aton(text)
|
||||
return True
|
||||
except Exception:
|
||||
try:
|
||||
dns.ipv6.inet_aton(text, True)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def low_level_address_tuple(high_tuple: Tuple[str, int], af: int | None = None) -> Any:
|
||||
"""Given a "high-level" address tuple, i.e.
|
||||
an (address, port) return the appropriate "low-level" address tuple
|
||||
suitable for use in socket calls.
|
||||
|
||||
If an *af* other than ``None`` is provided, it is assumed the
|
||||
address in the high-level tuple is valid and has that af. If af
|
||||
is ``None``, then af_for_address will be called.
|
||||
"""
|
||||
address, port = high_tuple
|
||||
if af is None:
|
||||
af = af_for_address(address)
|
||||
if af == AF_INET:
|
||||
return (address, port)
|
||||
elif af == AF_INET6:
|
||||
i = address.find("%")
|
||||
if i < 0:
|
||||
# no scope, shortcut!
|
||||
return (address, port, 0, 0)
|
||||
# try to avoid getaddrinfo()
|
||||
addrpart = address[:i]
|
||||
scope = address[i + 1 :]
|
||||
if scope.isdigit():
|
||||
return (addrpart, port, 0, int(scope))
|
||||
try:
|
||||
return (addrpart, port, 0, socket.if_nametoindex(scope))
|
||||
except AttributeError: # pragma: no cover (we can't really test this)
|
||||
ai_flags = socket.AI_NUMERICHOST
|
||||
((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags)
|
||||
return tup
|
||||
else:
|
||||
raise NotImplementedError(f"unknown address family {af}")
|
||||
|
||||
|
||||
def any_for_af(af):
|
||||
"""Return the 'any' address for the specified address family."""
|
||||
if af == socket.AF_INET:
|
||||
return "0.0.0.0"
|
||||
elif af == socket.AF_INET6:
|
||||
return "::"
|
||||
raise NotImplementedError(f"unknown address family {af}")
|
||||
|
||||
|
||||
def canonicalize(text: str) -> str:
|
||||
"""Verify that *address* is a valid text form IPv4 or IPv6 address and return its
|
||||
canonical text form. IPv6 addresses with scopes are rejected.
|
||||
|
||||
*text*, a ``str``, the address in textual form.
|
||||
|
||||
Raises ``ValueError`` if the text is not valid.
|
||||
"""
|
||||
try:
|
||||
return dns.ipv6.canonicalize(text)
|
||||
except Exception:
|
||||
try:
|
||||
return dns.ipv4.canonicalize(text)
|
||||
except Exception:
|
||||
raise ValueError
|
||||
76
.venv/Lib/site-packages/dns/ipv4.py
Normal file
76
.venv/Lib/site-packages/dns/ipv4.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2003-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
"""IPv4 helper functions."""
|
||||
|
||||
import struct
|
||||
|
||||
import dns.exception
|
||||
|
||||
|
||||
def inet_ntoa(address: bytes) -> str:
|
||||
"""Convert an IPv4 address in binary form to text form.
|
||||
|
||||
*address*, a ``bytes``, the IPv4 address in binary form.
|
||||
|
||||
Returns a ``str``.
|
||||
"""
|
||||
|
||||
if len(address) != 4:
|
||||
raise dns.exception.SyntaxError
|
||||
return f"{address[0]}.{address[1]}.{address[2]}.{address[3]}"
|
||||
|
||||
|
||||
def inet_aton(text: str | bytes) -> bytes:
|
||||
"""Convert an IPv4 address in text form to binary form.
|
||||
|
||||
*text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
|
||||
|
||||
Returns a ``bytes``.
|
||||
"""
|
||||
|
||||
if not isinstance(text, bytes):
|
||||
btext = text.encode()
|
||||
else:
|
||||
btext = text
|
||||
parts = btext.split(b".")
|
||||
if len(parts) != 4:
|
||||
raise dns.exception.SyntaxError
|
||||
for part in parts:
|
||||
if not part.isdigit():
|
||||
raise dns.exception.SyntaxError
|
||||
if len(part) > 1 and part[0] == ord("0"):
|
||||
# No leading zeros
|
||||
raise dns.exception.SyntaxError
|
||||
try:
|
||||
b = [int(part) for part in parts]
|
||||
return struct.pack("BBBB", *b)
|
||||
except Exception:
|
||||
raise dns.exception.SyntaxError
|
||||
|
||||
|
||||
def canonicalize(text: str | bytes) -> str:
|
||||
"""Verify that *address* is a valid text form IPv4 address and return its
|
||||
canonical text form.
|
||||
|
||||
*text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
|
||||
|
||||
Raises ``dns.exception.SyntaxError`` if the text is not valid.
|
||||
"""
|
||||
# Note that inet_aton() only accepts canonial form, but we still run through
|
||||
# inet_ntoa() to ensure the output is a str.
|
||||
return inet_ntoa(inet_aton(text))
|
||||
217
.venv/Lib/site-packages/dns/ipv6.py
Normal file
217
.venv/Lib/site-packages/dns/ipv6.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2003-2017 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
"""IPv6 helper functions."""
|
||||
|
||||
import binascii
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
import dns.exception
|
||||
import dns.ipv4
|
||||
|
||||
_leading_zero = re.compile(r"0+([0-9a-f]+)")
|
||||
|
||||
|
||||
def inet_ntoa(address: bytes) -> str:
|
||||
"""Convert an IPv6 address in binary form to text form.
|
||||
|
||||
*address*, a ``bytes``, the IPv6 address in binary form.
|
||||
|
||||
Raises ``ValueError`` if the address isn't 16 bytes long.
|
||||
Returns a ``str``.
|
||||
"""
|
||||
|
||||
if len(address) != 16:
|
||||
raise ValueError("IPv6 addresses are 16 bytes long")
|
||||
hex = binascii.hexlify(address)
|
||||
chunks = []
|
||||
i = 0
|
||||
l = len(hex)
|
||||
while i < l:
|
||||
chunk = hex[i : i + 4].decode()
|
||||
# strip leading zeros. we do this with an re instead of
|
||||
# with lstrip() because lstrip() didn't support chars until
|
||||
# python 2.2.2
|
||||
m = _leading_zero.match(chunk)
|
||||
if m is not None:
|
||||
chunk = m.group(1)
|
||||
chunks.append(chunk)
|
||||
i += 4
|
||||
#
|
||||
# Compress the longest subsequence of 0-value chunks to ::
|
||||
#
|
||||
best_start = 0
|
||||
best_len = 0
|
||||
start = -1
|
||||
last_was_zero = False
|
||||
for i in range(8):
|
||||
if chunks[i] != "0":
|
||||
if last_was_zero:
|
||||
end = i
|
||||
current_len = end - start
|
||||
if current_len > best_len:
|
||||
best_start = start
|
||||
best_len = current_len
|
||||
last_was_zero = False
|
||||
elif not last_was_zero:
|
||||
start = i
|
||||
last_was_zero = True
|
||||
if last_was_zero:
|
||||
end = 8
|
||||
current_len = end - start
|
||||
if current_len > best_len:
|
||||
best_start = start
|
||||
best_len = current_len
|
||||
if best_len > 1:
|
||||
if best_start == 0 and (best_len == 6 or best_len == 5 and chunks[5] == "ffff"):
|
||||
# We have an embedded IPv4 address
|
||||
if best_len == 6:
|
||||
prefix = "::"
|
||||
else:
|
||||
prefix = "::ffff:"
|
||||
thex = prefix + dns.ipv4.inet_ntoa(address[12:])
|
||||
else:
|
||||
thex = (
|
||||
":".join(chunks[:best_start])
|
||||
+ "::"
|
||||
+ ":".join(chunks[best_start + best_len :])
|
||||
)
|
||||
else:
|
||||
thex = ":".join(chunks)
|
||||
return thex
|
||||
|
||||
|
||||
_v4_ending = re.compile(rb"(.*):(\d+\.\d+\.\d+\.\d+)$")
|
||||
_colon_colon_start = re.compile(rb"::.*")
|
||||
_colon_colon_end = re.compile(rb".*::$")
|
||||
|
||||
|
||||
def inet_aton(text: str | bytes, ignore_scope: bool = False) -> bytes:
|
||||
"""Convert an IPv6 address in text form to binary form.
|
||||
|
||||
*text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
|
||||
|
||||
*ignore_scope*, a ``bool``. If ``True``, a scope will be ignored.
|
||||
If ``False``, the default, it is an error for a scope to be present.
|
||||
|
||||
Returns a ``bytes``.
|
||||
"""
|
||||
|
||||
#
|
||||
# Our aim here is not something fast; we just want something that works.
|
||||
#
|
||||
if not isinstance(text, bytes):
|
||||
btext = text.encode()
|
||||
else:
|
||||
btext = text
|
||||
|
||||
if ignore_scope:
|
||||
parts = btext.split(b"%")
|
||||
l = len(parts)
|
||||
if l == 2:
|
||||
btext = parts[0]
|
||||
elif l > 2:
|
||||
raise dns.exception.SyntaxError
|
||||
|
||||
if btext == b"":
|
||||
raise dns.exception.SyntaxError
|
||||
elif btext.endswith(b":") and not btext.endswith(b"::"):
|
||||
raise dns.exception.SyntaxError
|
||||
elif btext.startswith(b":") and not btext.startswith(b"::"):
|
||||
raise dns.exception.SyntaxError
|
||||
elif btext == b"::":
|
||||
btext = b"0::"
|
||||
#
|
||||
# Get rid of the icky dot-quad syntax if we have it.
|
||||
#
|
||||
m = _v4_ending.match(btext)
|
||||
if m is not None:
|
||||
b = dns.ipv4.inet_aton(m.group(2))
|
||||
btext = (
|
||||
f"{m.group(1).decode()}:{b[0]:02x}{b[1]:02x}:{b[2]:02x}{b[3]:02x}"
|
||||
).encode()
|
||||
#
|
||||
# Try to turn '::<whatever>' into ':<whatever>'; if no match try to
|
||||
# turn '<whatever>::' into '<whatever>:'
|
||||
#
|
||||
m = _colon_colon_start.match(btext)
|
||||
if m is not None:
|
||||
btext = btext[1:]
|
||||
else:
|
||||
m = _colon_colon_end.match(btext)
|
||||
if m is not None:
|
||||
btext = btext[:-1]
|
||||
#
|
||||
# Now canonicalize into 8 chunks of 4 hex digits each
|
||||
#
|
||||
chunks = btext.split(b":")
|
||||
l = len(chunks)
|
||||
if l > 8:
|
||||
raise dns.exception.SyntaxError
|
||||
seen_empty = False
|
||||
canonical: List[bytes] = []
|
||||
for c in chunks:
|
||||
if c == b"":
|
||||
if seen_empty:
|
||||
raise dns.exception.SyntaxError
|
||||
seen_empty = True
|
||||
for _ in range(0, 8 - l + 1):
|
||||
canonical.append(b"0000")
|
||||
else:
|
||||
lc = len(c)
|
||||
if lc > 4:
|
||||
raise dns.exception.SyntaxError
|
||||
if lc != 4:
|
||||
c = (b"0" * (4 - lc)) + c
|
||||
canonical.append(c)
|
||||
if l < 8 and not seen_empty:
|
||||
raise dns.exception.SyntaxError
|
||||
btext = b"".join(canonical)
|
||||
|
||||
#
|
||||
# Finally we can go to binary.
|
||||
#
|
||||
try:
|
||||
return binascii.unhexlify(btext)
|
||||
except (binascii.Error, TypeError):
|
||||
raise dns.exception.SyntaxError
|
||||
|
||||
|
||||
_mapped_prefix = b"\x00" * 10 + b"\xff\xff"
|
||||
|
||||
|
||||
def is_mapped(address: bytes) -> bool:
|
||||
"""Is the specified address a mapped IPv4 address?
|
||||
|
||||
*address*, a ``bytes`` is an IPv6 address in binary form.
|
||||
|
||||
Returns a ``bool``.
|
||||
"""
|
||||
|
||||
return address.startswith(_mapped_prefix)
|
||||
|
||||
|
||||
def canonicalize(text: str | bytes) -> str:
|
||||
"""Verify that *address* is a valid text form IPv6 address and return its
|
||||
canonical text form. Addresses with scopes are rejected.
|
||||
|
||||
*text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
|
||||
|
||||
Raises ``dns.exception.SyntaxError`` if the text is not valid.
|
||||
"""
|
||||
return inet_ntoa(inet_aton(text))
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user