Skip to content

Commit

Permalink
core: add support for getaddrinfo()
Browse files Browse the repository at this point in the history
Co-authored-by: boytm <[email protected]>

Closes: #128
  • Loading branch information
saghul committed May 12, 2021
1 parent d9d64be commit dfe00f3
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 0 deletions.
18 changes: 18 additions & 0 deletions docs/channel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,24 @@

The c-ares ``Channel`` provides asynchronous DNS operations.

.. py:method:: getaddrinfo(host, port, callback, family=0, type=0, proto=0, flags=0)
:param string host: Hostname to resolve.

:param string port: Service to resolve. Can be a string, int or None.

:param callable callback: Callback to be called with the result of the query.

The ``family``, ``type`` and ``proto`` arguments can be optionally specified in order to narrow the list of
addresses returned. Passing zero as a value for each of these arguments selects the full range of results.
The ``flags`` argument can be one or several of the ``AI_*`` constants, and will influence how results are
computed and returned. For example, ``AI_NUMERICHOST`` will disable domain name resolution.

Translate the host/port argument into a sequence of 5-tuples that contain all the necessary arguments for
creating a socket connected to that service.

Callback signature: ``callback(result, errorno)``


.. py:method:: gethostbyname(name, family, callback)
Expand Down
57 changes: 57 additions & 0 deletions src/_cffi_src/build_cares.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
#define ARES_ELOADIPHLPAPI ...
#define ARES_EADDRGETNETWORKPARAMS ...
#define ARES_ECANCELLED ...
#define ARES_ESERVICE ...
#define ARES_FLAG_USEVC ...
#define ARES_FLAG_PRIMARY ...
Expand Down Expand Up @@ -207,6 +208,11 @@
int type,
void *data);
typedef void (*ares_addrinfo_callback)(void *arg,
int status,
int timeouts,
struct ares_addrinfo *res);
struct ares_channeldata;
typedef struct ares_channeldata *ares_channel;
Expand Down Expand Up @@ -292,6 +298,44 @@
unsigned int expire;
unsigned int minttl;
};
/*
* Similar to addrinfo, but with extra ttl and missing canonname.
*/
struct ares_addrinfo_node {
int ai_ttl;
int ai_flags;
int ai_family;
int ai_socktype;
int ai_protocol;
ares_socklen_t ai_addrlen;
struct sockaddr *ai_addr;
struct ares_addrinfo_node *ai_next;
};
/*
* alias - label of the resource record.
* name - value (canonical name) of the resource record.
* See RFC2181 10.1.1. CNAME terminology.
*/
struct ares_addrinfo_cname {
int ttl;
char *alias;
char *name;
struct ares_addrinfo_cname *next;
};
struct ares_addrinfo {
struct ares_addrinfo_cname *cnames;
struct ares_addrinfo_node *nodes;
};
struct ares_addrinfo_hints {
int ai_flags;
int ai_family;
int ai_socktype;
int ai_protocol;
};
struct ares_addr_node {
struct ares_addr_node *next;
Expand Down Expand Up @@ -341,6 +385,15 @@
ares_sock_create_callback callback,
void *user_data);
void ares_getaddrinfo(ares_channel channel,
const char* node,
const char* service,
const struct ares_addrinfo_hints* hints,
ares_addrinfo_callback callback,
void* arg);
void ares_freeaddrinfo(struct ares_addrinfo* ai);
void ares_send(ares_channel channel,
const unsigned char *qbuf,
int qlen,
Expand Down Expand Up @@ -512,6 +565,10 @@
int timeouts,
unsigned char *abuf,
int alen);
extern "Python" void _addrinfo_cb(void *arg,
int status,
int timeouts,
struct ares_addrinfo *res);
"""

INCLUDES = """
Expand Down
89 changes: 89 additions & 0 deletions src/pycares/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,19 @@ def _query_cb(arg, status, timeouts, abuf, alen):

callback(result, status)

@_ffi.def_extern()
def _addrinfo_cb(arg, status, timeouts, res):
callback = _ffi.from_handle(arg)
_global_set.discard(arg)

if status != _lib.ARES_SUCCESS:
result = None
else:
result = ares_addrinfo_result(res)
status = None

callback(result, status)

def parse_result(query_type, abuf, alen):
if query_type == _lib.T_A:
addrttls = _ffi.new("struct ares_addrttl[]", PYCARES_ADDRTTL_SIZE)
Expand Down Expand Up @@ -508,6 +521,27 @@ def gethostbyname(self, name, family, callback):
_global_set.add(userdata)
_lib.ares_gethostbyname(self._channel[0], parse_name(name), family, _lib._host_cb, userdata)

def getaddrinfo(self, host, port, callback, family=0, type=0, proto=0, flags=0):
if not callable(callback):
raise TypeError("a callable is required")

if port is None:
service = _ffi.NULL
elif isinstance(port, int):
service = str(port).encode('ascii')
else:
service = ascii_bytes(port)

userdata = _ffi.new_handle(callback)
_global_set.add(userdata)

hints = _ffi.new('struct ares_addrinfo_hints*')
hints.ai_flags = flags
hints.ai_family = family
hints.ai_socktype = type
hints.ai_protocol = proto
_lib.ares_getaddrinfo(self._channel[0], parse_name(host), service, hints, _lib._addrinfo_cb, userdata)

def query(self, name, query_type, callback, query_class=None):
self._do_query(_lib.ares_query, name, query_type, callback, query_class=query_class)

Expand Down Expand Up @@ -735,6 +769,61 @@ def __init__(self, node, service):
self.service = maybe_str(_ffi.string(service)) if service != _ffi.NULL else None


class ares_addrinfo_node_result(AresResult):
__slots__ = ('ttl', 'flags', 'family', 'socktype', 'protocol', 'addr')

def __init__(self, ares_node):
self.ttl = ares_node.ai_ttl
self.flags = ares_node.ai_flags
self.socktype = ares_node.ai_socktype
self.protocol = ares_node.ai_protocol

addr = ares_node.ai_addr
assert addr.sa_family == ares_node.ai_family
ip = _ffi.new("char []", _lib.INET6_ADDRSTRLEN)
if addr.sa_family == socket.AF_INET:
self.family = socket.AF_INET
s = _ffi.cast("struct sockaddr_in*", addr)
if _ffi.NULL != _lib.ares_inet_ntop(s.sin_family, _ffi.addressof(s.sin_addr), ip, _lib.INET6_ADDRSTRLEN):
# (address, port) 2-tuple for AF_INET
self.addr = (_ffi.string(ip, _lib.INET6_ADDRSTRLEN), socket.ntohs(s.sin_port))
elif addr.sa_family == socket.AF_INET6:
self.family = socket.AF_INET6
s = _ffi.cast("struct sockaddr_in6*", addr)
if _ffi.NULL != _lib.ares_inet_ntop(s.sin6_family, _ffi.addressof(s.sin6_addr), ip, _lib.INET6_ADDRSTRLEN):
# (address, port, flow info, scope id) 4-tuple for AF_INET6
self.addr = (_ffi.string(ip, _lib.INET6_ADDRSTRLEN), socket.ntohs(s.sin6_port), s.sin6_flowinfo, s.sin6_scope_id)
else:
raise ValueError("invalid sockaddr family")


class ares_addrinfo_cname_result(AresResult):
__slots__ = ('ttl', 'alias', 'name')

def __init__(self, ares_cname):
self.ttl = ares_cname.ttl
self.alias = maybe_str(_ffi.string(ares_cname.alias))
self.name = maybe_str(_ffi.string(ares_cname.name))


class ares_addrinfo_result(AresResult):
__slots__ = ('cnames', 'nodes')

def __init__(self, ares_addrinfo):
self.cnames = []
self.nodes = []
cname_ptr = ares_addrinfo.cnames
while cname_ptr != _ffi.NULL:
self.cnames.append(ares_addrinfo_cname_result(cname_ptr))
cname_ptr = cname_ptr.next
node_ptr = ares_addrinfo.nodes
while node_ptr != _ffi.NULL:
self.nodes.append(ares_addrinfo_node_result(node_ptr))
node_ptr = node_ptr.ai_next
_lib.ares_freeaddrinfo(ares_addrinfo)



__all__ = exported_pycares_symbols + list(exported_pycares_symbols_map.keys()) + ['AresError', 'Channel', 'errno', '__version__']

del exported_pycares_symbols, exported_pycares_symbols_map
Expand Down
1 change: 1 addition & 0 deletions src/pycares/errno.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
'ARES_ELOADIPHLPAPI',
'ARES_EADDRGETNETWORKPARAMS',
'ARES_ECANCELLED',
'ARES_ESERVICE'
]

errorcode = {}
Expand Down
63 changes: 63 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,69 @@ def assertNoError(self, errorno):
raise unittest.SkipTest('timeout')
self.assertEqual(errorno, None)

@unittest.skipIf(sys.platform == 'win32', 'skipped on Windows')
def test_getaddrinfo(self):
self.result, self.errorno = None, None
def cb(result, errorno):
self.result, self.errorno = result, errorno
self.channel.getaddrinfo('localhost', 80, cb)
self.wait()
self.assertNoError(self.errorno)
self.assertEqual(type(self.result), pycares.ares_addrinfo_result)
self.assertTrue(len(self.result.nodes) > 0)
for node in self.result.nodes:
self.assertEqual(node.addr[1], 80)

@unittest.skipIf(sys.platform == 'win32', 'skipped on Windows')
def test_getaddrinfo2(self):
self.result, self.errorno = None, None
def cb(result, errorno):
self.result, self.errorno = result, errorno
self.channel.getaddrinfo('localhost', 'http', cb)
self.wait()
self.assertNoError(self.errorno)
self.assertEqual(type(self.result), pycares.ares_addrinfo_result)
self.assertTrue(len(self.result.nodes) > 0)
for node in self.result.nodes:
self.assertEqual(node.addr[1], 80)

@unittest.skipIf(sys.platform == 'win32', 'skipped on Windows')
def test_getaddrinfo3(self):
self.result, self.errorno = None, None
def cb(result, errorno):
self.result, self.errorno = result, errorno
self.channel.getaddrinfo('localhost', None, cb)
self.wait()
self.assertNoError(self.errorno)
self.assertEqual(type(self.result), pycares.ares_addrinfo_result)
self.assertTrue(len(self.result.nodes) > 0)
for node in self.result.nodes:
self.assertEqual(node.addr[1], 0)

@unittest.skipIf(sys.platform == 'win32', 'skipped on Windows')
def test_getaddrinfo4(self):
self.result, self.errorno = None, None
def cb(result, errorno):
self.result, self.errorno = result, errorno
self.channel.getaddrinfo('localhost', 'http', cb, family=socket.AF_INET)
self.wait()
self.assertNoError(self.errorno)
self.assertEqual(type(self.result), pycares.ares_addrinfo_result)
self.assertEqual(len(self.result.nodes), 1)
node = self.result.nodes[0]
self.assertEqual(node.addr[0], b'127.0.0.1')
self.assertEqual(node.addr[1], 80)

def test_getaddrinfo5(self):
self.result, self.errorno = None, None
def cb(result, errorno):
self.result, self.errorno = result, errorno
self.channel.getaddrinfo('google.com', 'http', cb)
self.wait()
self.assertNoError(self.errorno)
self.assertEqual(type(self.result), pycares.ares_addrinfo_result)
self.assertTrue(len(self.result.nodes) > 0)

@unittest.skipIf(sys.platform == 'win32', 'skipped on Windows')
def test_gethostbyaddr(self):
self.result, self.errorno = None, None
Expand Down

0 comments on commit dfe00f3

Please sign in to comment.