diff --git a/aiodns/__init__.py b/aiodns/__init__.py index d79b655..c5f3ac5 100644 --- a/aiodns/__init__.py +++ b/aiodns/__init__.py @@ -90,6 +90,22 @@ def query(self, host: str, qtype: str, qclass: str=None) -> asyncio.Future: self._channel.query(host, qtype, cb, query_class=qclass) return fut + def search(self, host: str, qtype: str, qclass: str=None) -> asyncio.Future: + try: + qtype = query_type_map[qtype] + except KeyError: + raise ValueError('invalid query type: {}'.format(qtype)) + if qclass is not None: + try: + qclass = query_class_map[qclass] + except KeyError: + raise ValueError('invalid query class: {}'.format(qclass)) + + fut = asyncio.Future(loop=self.loop) # type: asyncio.Future + cb = functools.partial(self._callback, fut) + self._channel.search(host, qtype, cb, query_class=qclass) + return fut + def gethostbyname(self, host: str, family: socket.AddressFamily) -> asyncio.Future: fut = asyncio.Future(loop=self.loop) # type: asyncio.Future cb = functools.partial(self._callback, fut) diff --git a/tests.py b/tests.py index fc0e2b9..5722591 100755 --- a/tests.py +++ b/tests.py @@ -130,6 +130,118 @@ async def coro(self, host, qtype, n=2): self.assertTrue(result) self.loop.run_until_complete(coro(self, 'gmail.com', 'MX')) + def test_search_a(self): + f = self.resolver.search('google.com', 'A') + result = self.loop.run_until_complete(f) + self.assertTrue(result) + + def test_search_async_await(self): + async def f(): + return await self.resolver.search('google.com', 'A') + result = self.loop.run_until_complete(f()) + self.assertTrue(result) + + def test_search_a_bad(self): + f = self.resolver.search('hgf8g2od29hdohid.com', 'A') + try: + self.loop.run_until_complete(f) + except aiodns.error.DNSError as e: + self.assertEqual(e.args[0], aiodns.error.ARES_ENOTFOUND) + + def test_search_aaaa(self): + f = self.resolver.search('ipv6.google.com', 'AAAA') + result = self.loop.run_until_complete(f) + self.assertTrue(result) + + def test_search_cname(self): + f = self.resolver.search('www.amazon.com', 'CNAME') + result = self.loop.run_until_complete(f) + self.assertTrue(result) + + def test_search_mx(self): + f = self.resolver.search('google.com', 'MX') + result = self.loop.run_until_complete(f) + self.assertTrue(result) + + def test_search_ns(self): + f = self.resolver.search('google.com', 'NS') + result = self.loop.run_until_complete(f) + self.assertTrue(result) + + def test_search_txt(self): + f = self.resolver.search('google.com', 'TXT') + result = self.loop.run_until_complete(f) + self.assertTrue(result) + + def test_search_soa(self): + f = self.resolver.search('google.com', 'SOA') + result = self.loop.run_until_complete(f) + self.assertTrue(result) + + def test_search_srv(self): + f = self.resolver.search('_xmpp-server._tcp.jabber.org', 'SRV') + result = self.loop.run_until_complete(f) + self.assertTrue(result) + + def test_search_naptr(self): + f = self.resolver.search('sip2sip.info', 'NAPTR') + result = self.loop.run_until_complete(f) + self.assertTrue(result) + + def test_search_ptr(self): + ip = '8.8.8.8' + f = self.resolver.search(ipaddress.ip_address(ip).reverse_pointer, 'PTR') + result = self.loop.run_until_complete(f) + self.assertTrue(result) + + def test_search_bad_type(self): + self.assertRaises(ValueError, self.resolver.search, 'google.com', 'XXX') + + def test_search_txt_chaos(self): + self.resolver = aiodns.DNSResolver(loop=self.loop) + self.resolver.nameservers = ['1.1.1.1'] + f = self.resolver.search('id.server', 'TXT', 'CHAOS') + result = self.loop.run_until_complete(f) + self.assertTrue(result) + + def test_search_bad_class(self): + self.assertRaises(ValueError, self.resolver.search, 'google.com', 'A', "INVALIDCLASS") + + def test_search_timeout(self): + self.resolver = aiodns.DNSResolver(timeout=0.1, loop=self.loop) + self.resolver.nameservers = ['1.2.3.4'] + f = self.resolver.search('google.com', 'A') + try: + self.loop.run_until_complete(f) + except aiodns.error.DNSError as e: + self.assertEqual(e.args[0], aiodns.error.ARES_ETIMEOUT) + + def test_search_cancel(self): + f = self.resolver.search('google.com', 'A') + self.resolver.cancel() + try: + self.loop.run_until_complete(f) + except aiodns.error.DNSError as e: + self.assertEqual(e.args[0], aiodns.error.ARES_ECANCELLED) + + def test_search_future_cancel(self): + f = self.resolver.search('google.com', 'A') + f.cancel() + async def coro(): + await asyncio.sleep(0.1) + await f + try: + self.loop.run_until_complete(coro()) + except asyncio.CancelledError as e: + self.assertTrue(e) + + def test_search_twice(self): + async def coro(self, host, qtype, n=2): + for i in range(n): + result = await self.resolver.search(host, qtype) + self.assertTrue(result) + self.loop.run_until_complete(coro(self, 'gmail.com', 'MX')) + def test_gethostbyname(self): f = self.resolver.gethostbyname('google.com', socket.AF_INET) result = self.loop.run_until_complete(f)