Add support for using SQLite as a DB (#271)
* Implement support for Sqlite DB. Remove unused query method.

* Document how to create SQLite DB

* DB tests using SQLite

* [MegaLinter] Apply linters fixes

* Make sqlite-detection (by file-extension) a little more generic

* [MegaLinter] Apply linters fixes

* Script to convert docker/postgresql DB to sqlite

* Create a release with address_principals.sqllite

* Fix sqlite output filename

* Ignore the sqlite DB from git

* Fix a couple of documentation references to SQLite
lyricnz authored Oct 1, 2023
1 parent d86081d commit d16cb18
Showing 8 changed files with 237 additions and 64 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/publish-db-image.yml
@@ -59,3 +59,13 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
build-args: |
GNAF_LOADER_TAG=${{ steps.version.outputs.GNAF_LOADER_TAG }}
- name: Convert the Postgres DB to SQLite
run: ./extra/db/docker2sqlite.sh

- name: Release
uses: softprops/action-gh-release@v1
with:
tag_name: sqlite-db-${{ steps.version.outputs.GNAF_LOADER_TAG }}
body: SQLite DB for the cutdown version of the GNAF address database
files: address_principals.sqlite
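
Once published, the SQLite DB can be fetched straight from the GitHub release. A minimal sketch (the repository path and tag value here are assumptions for illustration; the real tag is derived from GNAF_LOADER_TAG as in the workflow above):

```
import urllib.request

# Hypothetical values: the workflow tags releases as sqlite-db-${GNAF_LOADER_TAG}.
tag = "sqlite-db-202308"
url = f"https://github.com/LukePrior/nbn-upgrade-map/releases/download/{tag}/address_principals.sqlite"
urllib.request.urlretrieve(url, "address_principals.sqlite")
```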
1 change: 1 addition & 0 deletions .gitignore
@@ -2,3 +2,4 @@
cache
code/__pycache__
megalinter-reports/
address_principals.sqlite
147 changes: 83 additions & 64 deletions code/db.py
@@ -1,36 +1,29 @@
import itertools
import logging
import sqlite3
from abc import ABC, abstractmethod
from argparse import ArgumentParser, Namespace

import data
import psycopg2
from psycopg2.extras import NamedTupleCursor

SQLITE_FILE_EXTENSIONS = {"db", "sqlite", "sqlite3", "db3", "s3db", "sl3"}

class AddressDB:
"""Connect to the GNAF Postgres database and query for addresses. See https://github.com/minus34/gnaf-loader"""

def __init__(self, database: str, host: str, port: str, user: str, password: str, create_index: bool = True):
"""Connect to the database"""
conn = psycopg2.connect(
database=database, host=host, port=port, user=user, password=password, cursor_factory=NamedTupleCursor
)
class DbDriver(ABC):
"""Abstract class for DB connections."""

self.cur = conn.cursor()
@abstractmethod
def execute(self, query, vars=None):
"""Return a list of Namespace objects for the provided query."""
pass

# detect the schema used by the DB
self.cur.execute("SELECT schema_name FROM information_schema.schemata where schema_name like 'gnaf_%'")
db_schema = self.cur.fetchone().schema_name
self.cur.execute(f"SET search_path TO {db_schema}")
conn.commit()

# optionally create a DB index
if create_index:
logging.info("Creating DB index...")
self.cur.execute(
"CREATE INDEX IF NOT EXISTS address_name_state ON address_principals (locality_name, state)"
)
conn.commit()
class AddressDB:
"""Connect to our cut-down version of the GNAF Postgres database and query for addresses."""

def __init__(self, db: DbDriver):
self.db = db

def get_addresses(self, target_suburb: str, target_state: str) -> data.AddressList:
"""Return a list of Address for the provided suburb+state from the database."""
@@ -40,56 +33,27 @@ def get_addresses(self, target_suburb: str, target_state: str) -> data.AddressList:
WHERE locality_name = %s AND state = %s
LIMIT 100000"""

self.cur.execute(query, (target_suburb, target_state))

return [
data.Address(
name=f"{row.address} {target_suburb} {row.postcode}",
gnaf_pid=row.gnaf_pid,
longitude=float(row.longitude),
latitude=float(row.latitude),
)
for row in self.cur.fetchall()
for row in self.db.execute(query, (target_suburb, target_state))
]

def get_list_vs_total(self, suburbs_states: dict) -> dict:
"""Calculate which fraction of the entire dataset is represented by the given list of state+suburb."""
self.cur.execute("SELECT state, COUNT(*) FROM address_principals GROUP BY state")
states = {row.state: {"total": row.count} for row in self.cur.fetchall()}

query_parts = ["(state = %s AND locality_name IN %s)\n"] * len(suburbs_states)
values = [[state, tuple(suburbs)] for state, suburbs in suburbs_states.items()]
all_values = tuple(itertools.chain.from_iterable(values))

query = f"""
SELECT state, COUNT(*)
FROM address_principals
WHERE\n{" OR ".join(query_parts)}
GROUP BY state
"""
self.cur.execute(query, all_values) # takes ~2 minutes
for row in self.cur.fetchall():
states[row.state]["completed"] = row.count

# add a totals row
total_completed = sum(sp.get("completed", 0) for sp in states.values())
total = sum(sp.get("total", 0) for sp in states.values())
states["total"] = {"completed": total_completed, "total": total}

return states

def get_counts_by_suburb(self) -> dict[str, dict[str, int]]:
"""return a tally of addresses by state and suburb"""
query = """
SELECT locality_name, state, COUNT(*)
SELECT locality_name, state, COUNT(*) as count
FROM address_principals
GROUP BY locality_name, state
ORDER BY state, locality_name
"""
self.cur.execute(query)

results = {}
for record in self.cur.fetchall():
for record in self.db.execute(query):
if record.state not in results:
results[record.state] = {}
results[record.state][record.locality_name] = record.count
@@ -108,10 +72,9 @@ def get_extents_by_suburb(self) -> dict:
GROUP BY locality_name, state
ORDER BY state, locality_name
"""
self.cur.execute(query)

results = {}
for record in self.cur.fetchall():
for record in self.db.execute(query):
if record.state not in results:
results[record.state] = {}
results[record.state][record.locality_name] = (
@@ -131,7 +94,9 @@ def add_db_arguments(parser: ArgumentParser):
help="The password for the database user",
default="password",
)
parser.add_argument("-H", "--dbhost", help="The hostname for the database", default="localhost")
parser.add_argument(
"-H", "--dbhost", help="The hostname for the database (or file-path for Sqlite)", default="localhost"
)
parser.add_argument("-P", "--dbport", help="The port number for the database", default="5433")
parser.add_argument(
"-i",
@@ -141,13 +106,67 @@ def add_db_arguments(parser: ArgumentParser):
)


class PostgresDb(DbDriver):
"""Class that implements Postgresql DB connection."""

def __init__(self, database: str, host: str, port: str, user: str, password: str, create_index: bool = True):
"""Connect to the database"""
conn = psycopg2.connect(
database=database, host=host, port=port, user=user, password=password, cursor_factory=NamedTupleCursor
)

self.cur = conn.cursor()

# detect the schema used by the DB
self.cur.execute("SELECT schema_name FROM information_schema.schemata where schema_name like 'gnaf_%'")
db_schema = self.cur.fetchone().schema_name
self.cur.execute(f"SET search_path TO {db_schema}")
conn.commit()

# optionally create a DB index
if create_index:
logging.info("Creating DB index...")
self.cur.execute(
"CREATE INDEX IF NOT EXISTS address_name_state ON address_principals (locality_name, state)"
)
conn.commit()

def execute(self, query, vars=None):
"""Return a list of Namespace objects for the provided query."""
self.cur.execute(query, vars)
return self.cur.fetchall()


class SqliteDb(DbDriver):
"""Class that implements Sqlite DB connection (to a file). Pass the filename as the dbhost."""

def __init__(self, database_file: str):
"""Connect to the database"""
conn = sqlite3.connect(database_file)
conn.row_factory = sqlite3.Row
self.cur = conn.cursor()

def execute(self, query, vars=None):
"""Return a list of Namespace objects for the provided query."""
query = query.replace("%s", "?")
if vars is None:
vars = {}
self.cur.execute(query, vars)
# sqlite doesn't support NamedTupleCursor, so we need to manually add the column names
return [Namespace(**dict(zip(x.keys(), x))) for x in self.cur.fetchall()]


def connect_to_db(args: Namespace) -> AddressDB:
"""return a DB connection based on the provided args"""
return AddressDB(
"postgres",
args.dbhost,
args.dbport,
args.dbuser,
args.dbpassword,
args.create_index,
)
if args.dbhost.split(".")[-1] in SQLITE_FILE_EXTENSIONS:
db = SqliteDb(args.dbhost)
else:
db = PostgresDb(
"postgres",
args.dbhost,
args.dbport,
args.dbuser,
args.dbpassword,
args.create_index,
)
return AddressDB(db)
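
A minimal usage sketch of the new driver selection (connection values illustrative): a dbhost whose file extension is in SQLITE_FILE_EXTENSIONS selects the SQLite driver, anything else is treated as a Postgres hostname.

```
from argparse import Namespace

import db

# "sqlite" is in SQLITE_FILE_EXTENSIONS, so this selects SqliteDb.
address_db = db.connect_to_db(Namespace(dbhost="address_principals.sqlite"))

# Any other dbhost is treated as a Postgres hostname (needs a running DB).
address_db = db.connect_to_db(
    Namespace(dbhost="localhost", dbport="5433", dbuser="postgres",
              dbpassword="password", create_index=False)
)

print(address_db.get_counts_by_suburb())
```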
26 changes: 26 additions & 0 deletions extra/db/README.md
@@ -24,6 +24,32 @@ REPOSITORY TAG IMAGE ID CREATED SIZE
mydb latest 84af660a3493 39 seconds ago 3.73GB
minus34/gnafloader latest d2c552c72a0a 10 days ago 32GB
```
## SQLite Version

To create a SQLite DB from the full CSV file (as used in the Dockerfile), use:

```
sqlite3 address_principals.sqlite
CREATE TABLE address_principals
(
gnaf_pid text NOT NULL,
address text NOT NULL,
locality_name text NOT NULL,
postcode INTEGER NULL,
state text NOT NULL,
latitude numeric(10,8) NOT NULL,
longitude numeric(11,8) NOT NULL
);
CREATE INDEX address_name_state ON address_principals(locality_name, state);
.mode csv
.import address_principals.csv address_principals
.exit
```

This will create a 1.5GB file (about 400MB of which is the index).
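
To sanity-check the import, a quick row count from Python (a sketch, assuming the filename above):

```
import sqlite3

conn = sqlite3.connect("address_principals.sqlite")
print(conn.execute("SELECT COUNT(*) FROM address_principals").fetchone()[0])
```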

## References

Expand Down
42 changes: 42 additions & 0 deletions extra/db/docker2sqlite.sh
@@ -0,0 +1,42 @@
#!/bin/bash

set -ex

# Extract CSV from the DB if we don't have it already.
# It's also available as part of the docker-build process, but this is a bit more flexible.
CSV_FILENAME=address_principals.csv
if [ -f $CSV_FILENAME ]; then
echo "CSV file already exists, skipping extract..."
else
docker run -d --name db --publish=5433:5432 lukeprior/nbn-upgrade-map-db:latest
sleep 5 # it takes a few seconds to be ready
psql -h localhost -p 5433 -U postgres -c 'COPY gnaf_cutdown.address_principals TO stdout WITH CSV HEADER' > $CSV_FILENAME
docker rm -f db
fi

# Create a new sqlite DB with the contents of the CSV
DB_FILENAME=address_principals.sqlite
if [ -f $DB_FILENAME ]; then
echo "SQLite file $DB_FILENAME already exists, skipping creation..."
else
sqlite3 $DB_FILENAME <<EOF
CREATE TABLE address_principals
(
gnaf_pid text NOT NULL,
address text NOT NULL,
locality_name text NOT NULL,
postcode INTEGER NULL,
state text NOT NULL,
latitude numeric(10,8) NOT NULL,
longitude numeric(11,8) NOT NULL
);
CREATE INDEX address_name_state ON address_principals(locality_name, state);
.mode csv
.import $CSV_FILENAME address_principals
.exit
EOF

fi
26 changes: 26 additions & 0 deletions tests/data/README.md
@@ -0,0 +1,26 @@
To create sample data in SQLite, use the following process:

- Create an empty DB per the process described in `extra/db/README.md`, then attach the full DB and import a subset:

```
sqlite3 tests/data/sample-addresses.sqlite
-- create table and index per process described in DB
CREATE TABLE address_principals
(
gnaf_pid text NOT NULL,
address text NOT NULL,
locality_name text NOT NULL,
postcode INTEGER NULL,
state text NOT NULL,
latitude numeric(10,8) NOT NULL,
longitude numeric(11,8) NOT NULL
);
CREATE INDEX address_name_state ON address_principals(locality_name, state);
-- attach and import a subset of the data
attach database './extra/db/address_principals.sqlite' as full_db;
INSERT INTO main.address_principals SELECT * FROM full_db.address_principals WHERE locality_name like '%SOMER%' ORDER BY RANDOM() LIMIT 100;
```
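
To inspect the resulting sample DB, a short sketch using Python's built-in sqlite3 module:

```
import sqlite3

conn = sqlite3.connect("tests/data/sample-addresses.sqlite")
query = """
    SELECT state, locality_name, COUNT(*)
    FROM address_principals
    GROUP BY state, locality_name
    ORDER BY state, locality_name
"""
for state, suburb, count in conn.execute(query):
    print(state, suburb, count)
```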

Binary file added tests/data/sample-addresses.sqlite
49 changes: 49 additions & 0 deletions tests/test_db.py
@@ -0,0 +1,49 @@
import os
from argparse import ArgumentParser, Namespace

import db

SAMPLE_ADDRESSES_DB_FILE = f"{os.path.dirname(os.path.realpath(__file__))}/data/sample-addresses.sqlite"


def test_get_address():
address_db = db.connect_to_db(Namespace(dbhost=SAMPLE_ADDRESSES_DB_FILE))
addresses = address_db.get_addresses("SOMERVILLE", "VIC")
assert len(addresses) == 30
assert addresses[0].name == "83 GUELPH STREET SOMERVILLE 3912"
assert addresses[0].gnaf_pid == "GAVIC421048228"


def test_get_counts_by_suburb():
address_db = db.connect_to_db(Namespace(dbhost=SAMPLE_ADDRESSES_DB_FILE))
counts = address_db.get_counts_by_suburb()
assert counts["VIC"]["SOMERVILLE"] == 30
assert counts["VIC"]["SOMERS"] == 10
assert counts["VIC"]["SOMERTON"] == 1
assert len(counts["NSW"]) == 2
assert len(counts["SA"]) == 1
assert len(counts["TAS"]) == 1
assert len(counts["WA"]) == 1


def test_get_extents_by_suburb():
address_db = db.connect_to_db(Namespace(dbhost=SAMPLE_ADDRESSES_DB_FILE))
extents = address_db.get_extents_by_suburb()
assert extents["VIC"]["SOMERVILLE"] == (
(-38.23846838, 145.162399),
(-38.21306546, 145.22678832),
)


def test_add_db_arguments():
parser = ArgumentParser()
db.add_db_arguments(parser)
args = parser.parse_args([])
assert args.dbuser == "postgres"
assert args.dbpassword == "password"
assert args.dbhost == "localhost"
assert args.dbport == "5433"
assert args.create_index


# TODO: test postgres with mocks
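
A minimal sketch of what that mock-based Postgres test could look like (the mock wiring and schema name below are assumptions, not part of this commit):

```
from unittest import mock

import db


def test_postgres_driver_with_mocks():
    with mock.patch("db.psycopg2.connect") as connect:
        cur = connect.return_value.cursor.return_value
        # Schema detection reads .schema_name from the first row (name hypothetical).
        cur.fetchone.return_value = mock.Mock(schema_name="gnaf_202308")
        pg = db.PostgresDb("postgres", "localhost", "5433", "postgres", "password", create_index=False)
        cur.fetchall.return_value = []
        assert pg.execute("SELECT 1") == []
```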
