Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Language classifier API #66

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ expand_address('Quatre vingt douze Ave des Champs-Élysées')

from postal.parser import parse_address
parse_address('The Book Club 100-106 Leonard St, Shoreditch, London, Greater London, EC2A 4RH, United Kingdom')

from postal.lang_classifier import classify_lang_address
classify_lang_address('Quatre vingt douze Ave des Champs-Élysées')
```

Installation
Expand Down Expand Up @@ -41,7 +44,7 @@ git clone https://github.com/openvenues/libpostal
cd libpostal
./bootstrap.sh
./configure --datadir=[...some dir with a few GB of space...]
make
make -j4
sudo make install

# On Linux it's probably a good idea to run
Expand Down
15 changes: 15 additions & 0 deletions postal/lang_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Python bindings to libpostal parse_address."""
from postal import _langclassifier
from postal.utils.encoding import safe_decode

def classify_lang_address(address):
"""
Classify the language of an address.

@param address: the address as either Unicode or a UTF-8 encoded string
"""
address = safe_decode(address, 'utf-8')
try:
return _langclassifier.classify_lang_address(address)
except SystemError:
return None
163 changes: 163 additions & 0 deletions postal/pylangclassifier.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
#include <Python.h>
#include <libpostal/libpostal.h>
#include "pyutils.h"

#if PY_MAJOR_VERSION >= 3
#define IS_PY3K
#endif

struct module_state {
PyObject *error;
};


typedef struct language_classifier_response {
Py_ssize_t num_languages;
char **languages;
double *probs;
} language_classifier_response_t;

#ifdef IS_PY3K
#define GETSTATE(m) ((struct module_state*)PyModule_GetState(m))
#else
#define GETSTATE(m) (&_state)
static struct module_state _state;
#endif


static PyObject *py_classify_lang_address(PyObject *self, PyObject *args, PyObject *keywords) {
PyObject *arg_input;

PyObject *result = NULL;

if (!PyArg_ParseTuple(args, "O:pylangclassifier", &arg_input)) {
return 0;
}

char *input = PyObject_to_string(arg_input);

if (input == NULL) {
return NULL;
}

language_classifier_response_t *response = libpostal_classify_language(input);

if (response == NULL) {
goto exit_free_input;
}

result = PyList_New((Py_ssize_t)response->num_languages);
if (!result) {
goto exit_destroy_response;
}

for (int i = 0; i < response->num_languages; i++) {
char *language = response->languages[i];
double prob = response->probs[i];
PyObject *language_unicode = PyUnicode_DecodeUTF8((const char *)language, strlen(language), "strict");
if (language_unicode == NULL) {
Py_DECREF(result);
goto exit_destroy_response;
}

PyObject *tuple = Py_BuildValue("(Od)", language_unicode, prob);
if (tuple == NULL) {
Py_DECREF(language_unicode);
goto exit_destroy_response;
}

// Note: PyList_SetItem steals a reference, so don't worry about DECREF
PyList_SetItem(result, (Py_ssize_t)i, tuple);

Py_DECREF(language_unicode);
}

exit_destroy_response:
libpostal_language_classifier_response_destroy(response);
exit_free_input:
if (input != NULL) {
free(input);
}
return result;
}

static PyMethodDef langclassifier_methods[] = {
{"classify_lang_address", (PyCFunction)py_classify_lang_address, METH_VARARGS | METH_KEYWORDS,
"classify_lang_address(text)"},
{NULL, NULL},
};

#ifdef IS_PY3K

static int langclassifier_traverse(PyObject *m, visitproc visit, void *arg) {
Py_VISIT(GETSTATE(m)->error);
return 0;
}

static int langclassifier_clear(PyObject *m) {
Py_CLEAR(GETSTATE(m)->error);
libpostal_teardown();
libpostal_teardown_language_classifier();
return 0;
}

static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"_langclassifier",
NULL,
sizeof(struct module_state),
langclassifier_methods,
NULL,
langclassifier_traverse,
langclassifier_clear,
NULL
};

#define INITERROR return NULL

PyObject *
PyInit__langclassifier(void) {
#else

#define INITERROR return

void cleanup_libpostal(void) {
libpostal_teardown();
libpostal_teardown_language_classifier();
}

void init_langclassifier(void) {
#endif

#ifdef IS_PY3K
PyObject *module = PyModule_Create(&module_def);
#else
PyObject *module = Py_InitModule("_langclassifier", langclassifier_methods);
#endif

if (module == NULL) {
INITERROR;
}
struct module_state *st = GETSTATE(module);

st->error = PyErr_NewException("_langclassifier.Error", NULL, NULL);
if (st->error == NULL) {
Py_DECREF(module);
INITERROR;
}


if (!libpostal_setup() || !libpostal_setup_language_classifier()) {
PyErr_SetString(PyExc_TypeError,
"Error loading libpostal data");
}

#ifndef IS_PY3K
Py_AtExit(&cleanup_libpostal);
#endif


#ifdef IS_PY3K
return module;
#endif
}
28 changes: 28 additions & 0 deletions postal/tests/test_lang_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
"""Test pypostal address parsing."""

from __future__ import unicode_literals

import unittest
from postal.lang_classifier import classify_lang_address


class TestLangClassfier(unittest.TestCase):
"""Test libpostal language classifier from Python."""
def test_parses(self):
cases = (
('Rua casemiro osorio, 123', {'pt': 1.0}),
('Street Oudenoord, 1234', {'en': 0.76, 'nl': 0.23}),
('Oudenoord, 1234', {'nl': 1.0})
)

"""Language classifier tests."""
for address, lang_expected in cases:
lang = classify_lang_address(address)
# Round probabilities
lang = {k: round(v, 2) for k, v in lang}
self.assertDictEqual(lang, lang_expected)


if __name__ == '__main__':
unittest.main()
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ def main():
library_dirs=['/usr/local/lib'],
extra_compile_args=['-std=c99'],
),
Extension('postal._langclassifier',
sources=['postal/pylangclassifier.c', 'postal/pyutils.c'],
libraries=['postal'],
include_dirs=['/usr/local/include'],
library_dirs=['/usr/local/lib'],
extra_compile_args=['-std=c99'],
),
Extension('postal._token_types',
sources=['postal/pytokentypes.c'],
libraries=['postal'],
Expand Down