Skip to content

Commit

Permalink
Add search method
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitriy Zenko committed Nov 29, 2022
1 parent aac8df5 commit 876b114
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 0 deletions.
16 changes: 16 additions & 0 deletions aiodns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,22 @@ def query(self, host: str, qtype: str, qclass: str=None) -> asyncio.Future:
self._channel.query(host, qtype, cb, query_class=qclass)
return fut

def search(self, host: str, qtype: str, qclass: str=None) -> asyncio.Future:
try:
qtype = query_type_map[qtype]
except KeyError:
raise ValueError('invalid query type: {}'.format(qtype))
if qclass is not None:
try:
qclass = query_class_map[qclass]
except KeyError:
raise ValueError('invalid query class: {}'.format(qclass))

fut = asyncio.Future(loop=self.loop) # type: asyncio.Future
cb = functools.partial(self._callback, fut)
self._channel.search(host, qtype, cb, query_class=qclass)
return fut

def gethostbyname(self, host: str, family: socket.AddressFamily) -> asyncio.Future:
fut = asyncio.Future(loop=self.loop) # type: asyncio.Future
cb = functools.partial(self._callback, fut)
Expand Down
112 changes: 112 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,118 @@ async def coro(self, host, qtype, n=2):
self.assertTrue(result)
self.loop.run_until_complete(coro(self, 'gmail.com', 'MX'))

def test_search_a(self):
f = self.resolver.search('google.com', 'A')
result = self.loop.run_until_complete(f)
self.assertTrue(result)

def test_search_async_await(self):
async def f():
return await self.resolver.search('google.com', 'A')
result = self.loop.run_until_complete(f())
self.assertTrue(result)

def test_search_a_bad(self):
f = self.resolver.search('hgf8g2od29hdohid.com', 'A')
try:
self.loop.run_until_complete(f)
except aiodns.error.DNSError as e:
self.assertEqual(e.args[0], aiodns.error.ARES_ENOTFOUND)

def test_search_aaaa(self):
f = self.resolver.search('ipv6.google.com', 'AAAA')
result = self.loop.run_until_complete(f)
self.assertTrue(result)

def test_search_cname(self):
f = self.resolver.search('www.amazon.com', 'CNAME')
result = self.loop.run_until_complete(f)
self.assertTrue(result)

def test_search_mx(self):
f = self.resolver.search('google.com', 'MX')
result = self.loop.run_until_complete(f)
self.assertTrue(result)

def test_search_ns(self):
f = self.resolver.search('google.com', 'NS')
result = self.loop.run_until_complete(f)
self.assertTrue(result)

def test_search_txt(self):
f = self.resolver.search('google.com', 'TXT')
result = self.loop.run_until_complete(f)
self.assertTrue(result)

def test_search_soa(self):
f = self.resolver.search('google.com', 'SOA')
result = self.loop.run_until_complete(f)
self.assertTrue(result)

def test_search_srv(self):
f = self.resolver.search('_xmpp-server._tcp.jabber.org', 'SRV')
result = self.loop.run_until_complete(f)
self.assertTrue(result)

def test_search_naptr(self):
f = self.resolver.search('sip2sip.info', 'NAPTR')
result = self.loop.run_until_complete(f)
self.assertTrue(result)

def test_search_ptr(self):
ip = '8.8.8.8'
f = self.resolver.search(ipaddress.ip_address(ip).reverse_pointer, 'PTR')
result = self.loop.run_until_complete(f)
self.assertTrue(result)

def test_search_bad_type(self):
self.assertRaises(ValueError, self.resolver.search, 'google.com', 'XXX')

def test_search_txt_chaos(self):
self.resolver = aiodns.DNSResolver(loop=self.loop)
self.resolver.nameservers = ['1.1.1.1']
f = self.resolver.search('id.server', 'TXT', 'CHAOS')
result = self.loop.run_until_complete(f)
self.assertTrue(result)

def test_search_bad_class(self):
self.assertRaises(ValueError, self.resolver.search, 'google.com', 'A', "INVALIDCLASS")

def test_search_timeout(self):
self.resolver = aiodns.DNSResolver(timeout=0.1, loop=self.loop)
self.resolver.nameservers = ['1.2.3.4']
f = self.resolver.search('google.com', 'A')
try:
self.loop.run_until_complete(f)
except aiodns.error.DNSError as e:
self.assertEqual(e.args[0], aiodns.error.ARES_ETIMEOUT)

def test_search_cancel(self):
f = self.resolver.search('google.com', 'A')
self.resolver.cancel()
try:
self.loop.run_until_complete(f)
except aiodns.error.DNSError as e:
self.assertEqual(e.args[0], aiodns.error.ARES_ECANCELLED)

def test_search_future_cancel(self):
f = self.resolver.search('google.com', 'A')
f.cancel()
async def coro():
await asyncio.sleep(0.1)
await f
try:
self.loop.run_until_complete(coro())
except asyncio.CancelledError as e:
self.assertTrue(e)

def test_search_twice(self):
async def coro(self, host, qtype, n=2):
for i in range(n):
result = await self.resolver.search(host, qtype)
self.assertTrue(result)
self.loop.run_until_complete(coro(self, 'gmail.com', 'MX'))

def test_gethostbyname(self):
f = self.resolver.gethostbyname('google.com', socket.AF_INET)
result = self.loop.run_until_complete(f)
Expand Down

0 comments on commit 876b114

Please sign in to comment.