mirror of
https://github.com/sqlmapproject/sqlmap.git
synced 2026-06-28 20:40:58 +00:00
291 lines
11 KiB
Python
291 lines
11 KiB
Python
#!/usr/bin/env python

"""
Copyright (c) 2006-2026 sqlmap developers (https://sqlmap.org)
See the file 'LICENSE' for copying permission

Tests for the DNS server used for DNS-exfiltration (lib/request/dns.py): raw packet parsing
(DNSQuery), fake A-record response crafting, the pop(prefix, suffix) accounting, and
- importantly - resilience: a single malformed packet or a transient send error must
NOT kill the server thread (which would silently lose all further exfiltration).
"""
|
|
|
|
import collections
|
|
import os
|
|
import socket
|
|
import struct
|
|
import sys
|
|
import threading
|
|
import time
|
|
import unittest
|
|
|
|
# Put this file's directory and its parent at the front of the module search path so
# that the "lib" package imported below resolves (presumably the parent directory is
# the repository root - confirm against the checkout layout).
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
|
|
|
|
from lib.core.settings import MAX_DNS_REQUESTS
|
|
from lib.request.dns import DNSQuery, DNSServer
|
|
|
|
|
|
def build_query(name, tid=b"\x12\x34", qtype=1):
    """
    Build a minimal standard (opcode 0) DNS query packet for L{name}

    @param name: dotted host name; empty labels (leading, trailing or doubled dots) are skipped
    @param tid: two-byte transaction id placed at the very start of the packet
    @param qtype: numeric query type (1=A, 28=AAAA, ...)
    @return: raw packet bytes (12-byte header, length-prefixed labels, null terminator,
             big-endian qtype, query class 1)
    """

    # header: transaction id, flags 0x0100, one question, no answer/authority/additional records
    parts = [tid, b"\x01\x00", b"\x00\x01", b"\x00\x00", b"\x00\x00", b"\x00\x00"]

    for label in name.split("."):
        if not label:
            continue

        encoded = label.encode()

        # the length octet has to count the encoded bytes, not the characters
        parts.append(struct.pack("B", len(encoded)))
        parts.append(encoded)

    parts.extend((b"\x00", struct.pack(">H", qtype), b"\x00\x01"))

    return b"".join(parts)
|
|
|
|
|
|
class _HighPortDNSServer(DNSServer):
    """
    Real DNSServer logic, bound on a high port (no root, no :53 probe)

    The parent constructor is not invoked; the attributes this test module relies on
    (_requests, _lock, _socket, _running, _initialized) are set up here directly.
    """

    def __init__(self, port, sock=None, maxlen=MAX_DNS_REQUESTS):
        """
        @param port: UDP port on 127.0.0.1 to bind when no socket is supplied
        @param sock: optional ready-to-use (already bound) socket-like object; used as is
        @param maxlen: upper bound for the number of retained requests
        """

        # bounded deque: once full, appending drops the oldest entry
        self._requests = collections.deque(maxlen=maxlen)
        self._lock = threading.Lock()

        if sock is None:
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

            try:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(("127.0.0.1", port))
            except Exception:
                # do not leak the descriptor when the port can not be bound
                sock.close()
                raise

        self._socket = sock
        self._running = False
        self._initialized = False
|
|
|
|
|
|
class _SendFailOnceSocket(object):
|
|
"""Wraps a real UDP socket; first sendto() raises (simulated transient failure)"""
|
|
def __init__(self, real):
|
|
self._real = real
|
|
self._sends = 0
|
|
|
|
def recvfrom(self, *a, **k):
|
|
return self._real.recvfrom(*a, **k)
|
|
|
|
def sendto(self, *a, **k):
|
|
self._sends += 1
|
|
if self._sends == 1:
|
|
raise RuntimeError("simulated transient sendto failure")
|
|
return self._real.sendto(*a, **k)
|
|
|
|
def __getattr__(self, name):
|
|
return getattr(self._real, name)
|
|
|
|
|
|
class TestDNSQuery(unittest.TestCase):
    """Parsing of raw packets by DNSQuery and the responses it crafts"""

    def test_parses_data_bearing_name(self):
        q = DNSQuery(build_query("pre.deadbeef.suf.exfil.test"))

        self.assertEqual(q._query, b"pre.deadbeef.suf.exfil.test.")

    def test_empty_and_short_packets_do_not_raise(self):
        for raw in (b"", b"\x00", b"\x12", b"\x12\x34", b"\x12\x34\x01\x20"):
            # no exception, empty query (the message tells which packet misbehaved)
            self.assertEqual(DNSQuery(raw)._query, b"", "unexpected query for packet %r" % (raw,))

    def test_unterminated_name_does_not_raise(self):
        # a length byte that runs past the buffer, with no null terminator
        pkt = b"\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00" + b"\x20" + b"abc"

        DNSQuery(pkt)  # must not raise (slicing past end yields b"", ord guards)

    def test_response_is_valid_A_record(self):
        q = DNSQuery(build_query("x.y.z", tid=b"\xab\xcd"))
        resp = q.response("127.0.0.1")

        self.assertEqual(resp[:2], b"\xab\xcd")  # transaction id echoed
        self.assertEqual(resp[2:4], b"\x85\x80")  # standard response, no error

        # bytearray yields integers on both Python 2 and Python 3
        ip = ".".join(str(octet) for octet in bytearray(resp[-4:]))

        self.assertEqual(ip, "127.0.0.1")

    def test_empty_query_yields_empty_response(self):
        self.assertEqual(DNSQuery(b"\x00").response("127.0.0.1"), b"")
|
|
|
|
|
|
class TestDNSServerRoundTrip(unittest.TestCase):
    """Queries sent over UDP to a live server end up recorded and can be popped once"""

    PORT = 5471

    @classmethod
    def setUpClass(cls):
        cls.srv = _HighPortDNSServer(cls.PORT)
        cls.srv.run()

        # wait for the server to flag itself as initialized, but never hang the whole run
        deadline = time.time() + 5

        while not cls.srv._initialized:
            if time.time() > deadline:
                raise RuntimeError("DNS server did not initialize within 5 seconds")

            time.sleep(0.02)

    def _exchange(self, name, qtype=1, timeout=3):
        """
        Send one query for L{name} to the server and wait (at most L{timeout} seconds)
        for a reply; a missing reply is tolerated, the socket is always closed
        """

        c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

        try:
            c.settimeout(timeout)
            c.sendto(build_query(name, qtype=qtype), ("127.0.0.1", self.PORT))

            try:
                c.recvfrom(512)
            except socket.timeout:
                pass
        finally:
            c.close()

    def _send(self, name):
        """
        Send a query for L{name} and poll (up to 100 times, 10ms apart) until it shows
        up among the recorded requests

        @return: True when the name was recorded, False otherwise
        """

        self._exchange(name)

        needle = name.encode()

        for _ in range(100):
            with self.srv._lock:
                if any(needle in r for r in self.srv._requests):
                    return True

            time.sleep(0.01)

        return False

    def test_roundtrip_and_pop(self):
        self.assertTrue(self._send("aaa.cafe.bbb.exfil.test"))
        self.assertIsNone(self.srv.pop("zzz", "yyy"))  # wrong boundaries
        self.assertIsNotNone(self.srv.pop("aaa", "bbb"))  # correct boundaries
        self.assertIsNone(self.srv.pop("aaa", "bbb"))  # consumed only once

    def test_non_a_query_type_still_recorded(self):
        # a DBMS resolver may emit AAAA (28) / TXT (16) lookups - the exfiltrated name is in the
        # labels regardless of qtype, and the server records before crafting the (A) response
        self._exchange("ggg.beef.hhh.exfil.test", qtype=28, timeout=2)

        for _ in range(200):
            if self.srv.pop("ggg", "hhh"):
                return

            time.sleep(0.01)

        self.fail("AAAA-type query was not recorded (exfil would be lost for AAAA-resolving DBMSes)")
|
|
|
|
|
|
class TestDNSServerMemoryBound(unittest.TestCase):
    """The server records every received query (it listens on :53); only matching ones are
    popped. Unrelated/stray traffic and resolver retries must not grow memory without bound."""

    PORT = 5475

    def test_requests_are_bounded_and_recent_kept(self):
        srv = _HighPortDNSServer(self.PORT, maxlen=50)
        srv.run()

        # bounded wait so that a server which never initializes fails the test instead of hanging it
        deadline = time.time() + 5

        while not srv._initialized:
            if time.time() > deadline:
                self.fail("DNS server did not initialize within 5 seconds")

            time.sleep(0.02)

        c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

        try:
            for i in range(200):  # flood well past the bound
                c.sendto(build_query("noise%d.unrelated.test" % i), ("127.0.0.1", self.PORT))
        finally:
            c.close()

        # a legit exfil query right after the flood must still be capturable
        c2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

        try:
            c2.settimeout(2)
            c2.sendto(build_query("ppp.d00d.qqq.exfil.test"), ("127.0.0.1", self.PORT))

            try:
                c2.recvfrom(512)
            except socket.timeout:
                pass
        finally:
            c2.close()

        popped = None

        for _ in range(200):
            popped = srv.pop("ppp", "qqq")

            if popped:
                break

            time.sleep(0.01)

        with srv._lock:
            n = len(srv._requests)

        self.assertLessEqual(n, 50, "request buffer exceeded its bound (%d)" % n)
        self.assertIsNotNone(popped, "a fresh exfil query was lost after a flood of stray traffic")
|
|
|
|
|
|
class TestDNSServerResilience(unittest.TestCase):
    """The server has to keep recording queries after a failed send or a junk packet"""

    def _make(self, port, sock=None):
        """
        Start a server on L{port} (or on the supplied L{sock}) and return it once it
        reports itself as initialized
        """

        srv = _HighPortDNSServer(port, sock=sock)
        srv.run()

        # bounded wait so that a server which never initializes fails the test instead of hanging it
        deadline = time.time() + 5

        while not srv._initialized:
            if time.time() > deadline:
                self.fail("DNS server did not initialize within 5 seconds")

            time.sleep(0.02)

        return srv

    def _query(self, port, name):
        """
        Send one query for L{name} to L{port} and wait up to one second for a reply;
        a missing reply is tolerated, the socket is always closed
        """

        c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

        try:
            c.settimeout(1)
            c.sendto(build_query(name), ("127.0.0.1", port))

            try:
                c.recvfrom(512)
            except socket.timeout:
                pass
        finally:
            c.close()

    def _recorded(self, srv, token, tries=120):
        """
        Poll (L{tries} times, 10ms apart) until L{token} appears in a recorded request

        @return: True when found, False when all tries are used up
        """

        needle = token.encode()

        for _ in range(tries):
            with srv._lock:
                if any(needle in r for r in srv._requests):
                    return True

            time.sleep(0.01)

        return False

    def test_survives_transient_send_error(self):
        port = 5472
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

        try:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("127.0.0.1", port))
        except Exception:
            # do not leak the descriptor when the port can not be bound
            s.close()
            raise

        srv = self._make(port, sock=_SendFailOnceSocket(s))

        self._query(port, "aaa.11.bbb.exfil.test")  # first sendto raises
        self._query(port, "ccc.22.ddd.exfil.test")  # must still be served

        self.assertTrue(self._recorded(srv, "ccc.22.ddd"),
                        "DNS server died after one failing sendto (lost subsequent exfil)")
        self.assertTrue(srv._running)

    def test_survives_malformed_packets(self):
        port = 5473
        srv = self._make(port)
        c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

        try:
            for junk in (b"", b"\x00", b"\xff" * 7, b"\x12\x34\x01\x00\x00\x01" + b"\x20abc"):
                c.sendto(junk, ("127.0.0.1", port))
        finally:
            c.close()

        self._query(port, "ok.33.fine.exfil.test")

        self.assertTrue(self._recorded(srv, "ok.33.fine"),
                        "DNS server died on a malformed packet")
|
|
|
|
|
|
class TestDNSServerConcurrency(unittest.TestCase):
    """Under --threads, many workers fire DNS queries and call pop() while the server thread
    appends - all guarded by one lock. Each worker must get back exactly its own data."""

    PORT = 5474

    @classmethod
    def setUpClass(cls):
        cls.srv = _HighPortDNSServer(cls.PORT)
        cls.srv.run()

        # wait for the server to flag itself as initialized, but never hang the whole run
        deadline = time.time() + 5

        while not cls.srv._initialized:
            if time.time() > deadline:
                raise RuntimeError("DNS server did not initialize within 5 seconds")

            time.sleep(0.02)

    def test_concurrent_send_and_pop_no_crosstalk(self):
        import binascii
        import re

        N = 12
        errors = []

        def boundaries(i):
            # distinct boundary labels per worker (DNS boundary alphabet = letters, no a-f/digits)
            mark = chr(ord("g") + i)
            return "gg" + mark, "mm" + mark

        def exchange(i):
            prefix, suffix = boundaries(i)
            secret = ("worker-%02d-secret" % i).encode()
            host = "%s.%s.%s.exfil.test" % (prefix, binascii.hexlify(secret).decode(), suffix)

            c = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

            try:
                c.settimeout(2)
                c.sendto(build_query(host), ("127.0.0.1", self.PORT))

                try:
                    c.recvfrom(512)
                except socket.timeout:
                    pass
            finally:
                c.close()

            got = None

            for _ in range(200):
                got = self.srv.pop(prefix, suffix)

                if got:
                    break

                time.sleep(0.01)

            if not got:
                errors.append("worker %d: never popped its query" % i)
                return

            m = re.search(r"%s\.(?P<r>.+?)\.%s" % (prefix, suffix), got, re.I)

            if not m or binascii.unhexlify(m.group("r")) != secret:
                errors.append("worker %d: cross-talk/corruption got=%r" % (i, got))

        def worker(i):
            # an exception raised inside a thread never reaches the test runner on its own,
            # hence it is turned into a recorded failure instead of being silently lost
            try:
                exchange(i)
            except Exception as ex:
                errors.append("worker %d: unexpected %s: %s" % (i, type(ex).__name__, ex))

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(N)]

        for t in threads:
            t.start()

        for t in threads:
            t.join()

        self.assertEqual(errors, [], "concurrency failures: %s" % errors)

        # every queued request consumed exactly once -> nothing left behind (for any worker)
        for i in range(N):
            prefix, suffix = boundaries(i)
            self.assertIsNone(self.srv.pop(prefix, suffix), "worker %d: request left behind" % i)
|
|
|
|
|
|
# Run the whole module verbosely when it is executed directly
if __name__ == "__main__":
    unittest.main(verbosity=2)
|