-
Notifications
You must be signed in to change notification settings - Fork 0
/
orm.py
132 lines (111 loc) · 3.85 KB
/
orm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import functools
import sqlite3
# Lazily-opened, process-wide SQLite connection; populated by get_db().
# DB_NAME is the database path handed to sqlite3.connect(); it starts unset and
# presumably must be assigned by the application before the first get_db() call.
db = DB_NAME = None
@functools.total_ordering
class Table(object):
    """Minimal active-record base class mapping one instance to one table row.

    Subclasses set ``_columns`` (column names, in SQL order) and ``_table``
    (the table name). Instances store each column as an attribute of the same
    name. Equality and ordering compare the column values as a tuple in
    ``_columns`` order; ``functools.total_ordering`` derives ``<=``, ``>``
    and ``>=`` from ``__lt__`` and ``__eq__``.
    """

    _columns = ()
    _table = None

    def __init__(self, *args, **kwargs):
        """Populate columns from positional args (in ``_columns`` order) and kwargs.

        Columns supplied neither way fall back to a class-level attribute of the
        same name if one exists, otherwise to ``None`` (so a fresh row with no
        ``id`` can still be saved).

        Raises:
            TypeError: too many positional args, a column given both positionally
                and by keyword, or a keyword that is not a column.
        """
        if len(args) > len(self._columns):
            raise TypeError('Got %i arguments. Only %i expected'
                            % (len(args), len(self._columns)))
        for name, value in zip(self._columns, args):
            if name in kwargs:
                raise TypeError("Got multiple values for keyword argument '%s'"
                                % name)
            setattr(self, name, value)
        # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
        for name, value in kwargs.items():
            if name not in self._columns:
                raise TypeError("Unexpected keyword argument '%s'" % name)
            setattr(self, name, value)
        # Without this, an omitted column makes save()/comparisons raise
        # AttributeError; hasattr() keeps any class-level default intact.
        for name in self._columns:
            if not hasattr(self, name):
                setattr(self, name, None)

    def _column_values(self):
        """Return this row's column values as a tuple, in ``_columns`` order."""
        return tuple(getattr(self, name) for name in self._columns)

    def __eq__(self, other):
        # Returning (not raising) NotImplemented lets Python fall back to the
        # reflected operation / identity, so e.g. ``row == None`` is False.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._column_values() == other._column_values()

    def __ne__(self, other):
        # Python 2 does not derive != from ==; define it explicitly.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __lt__(self, other):
        # Lexicographic comparison over the columns: the first differing column
        # decides. (Comparing every column with < alone is not an ordering.)
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._column_values() < other._column_values()

    def save(self):
        """Write this row with ``REPLACE INTO`` and commit.

        ``REPLACE`` overwrites an existing row that conflicts on a unique key
        (such as the primary key); otherwise it inserts a new row.
        """
        db = get_db()
        query = 'REPLACE INTO "%s" (%s) VALUES (%s);' % (
            self._table,
            ', '.join('"%s"' % column for column in self._columns),
            ', '.join('?' for _ in self._columns))
        c = db.cursor()
        try:
            c.execute(query, list(self._column_values()))
        finally:
            c.close()
        db.commit()

    @classmethod
    def get_all_where(cls, where_clause, parameters=()):
        """Lazily yield one instance per row matching *where_clause*.

        *where_clause* is interpolated verbatim after ``WHERE`` (so it may also
        carry ``ORDER BY`` etc.). Never build it from untrusted text; pass
        values through ``?`` placeholders and *parameters* instead. Being a
        generator, no query runs until the first item is requested.
        """
        db = get_db()
        c = db.cursor()
        try:
            c.execute(
                'SELECT %s FROM "%s" WHERE %s;' % (
                    ', '.join('"%s"' % column for column in cls._columns),
                    cls._table,
                    where_clause
                ), parameters)
            for row in c:
                yield cls(*row)
        finally:
            # Runs on exhaustion, on error, and when the generator is closed early.
            c.close()

    @classmethod
    def get_where(cls, where_clause, parameters=()):
        """Return the first instance matching *where_clause*, or ``None``."""
        rows = cls.get_all_where(where_clause, parameters)
        try:
            return next(rows, None)
        finally:
            # Release the underlying cursor instead of leaving it to the GC.
            rows.close()
class Job(Table):
    """A row of the ``job`` table.

    Each job holds a ``source`` text, its target ``lang``, the ``translation``
    and a ``status``, and references its parent ``order`` via ``order_id``.
    """

    _columns = ('id', 'order_id', 'lang', 'source', 'translation', 'status')
    _table = 'job'

    @classmethod
    def create_table(cls, cursor):
        """Create the ``job`` table plus its (lang, source) and status indexes."""
        statements = (
            'CREATE TABLE job (\n'
            'id INTEGER PRIMARY KEY,\n'
            'order_id INTEGER REFERENCES "order" (id),\n'
            'lang TEXT,\n'
            'source TEXT,\n'
            'translation TEXT,\n'
            'status TEXT\n'
            ');',
            'CREATE INDEX job_lang_string ON job (lang, source);',
            'CREATE INDEX job_status ON job (status);',
        )
        for sql in statements:
            cursor.execute(sql)

    @classmethod
    def find(cls, lang, source):
        """Return the first job for this language and source text, or None."""
        params = (lang, source)
        return cls.get_where('lang = ? AND source = ?', params)

    @classmethod
    def get_in_progress(cls):
        """Lazily iterate jobs whose status is neither approved nor canceled."""
        where = "status NOT IN ('approved', 'canceled')"
        return cls.get_all_where(where)

    @classmethod
    def get_reviewable(cls):
        """Lazily iterate reviewable jobs ordered by language, then id.

        The ORDER BY travels inside the WHERE argument because get_all_where
        splices that text verbatim into the query.
        """
        where = "status = 'reviewable' ORDER BY lang, id"
        return cls.get_all_where(where)
class Order(Table):
    """A row of the ``order`` table (an id plus a ``created`` integer).

    ``order`` is an SQL keyword, so raw SQL must double-quote the table name.
    """

    _columns = ('id', 'created')
    _table = 'order'

    @classmethod
    def create_table(cls, cursor):
        """Create the ``order`` table on *cursor*."""
        ddl = ('CREATE TABLE "order"\n'
               '(id INTEGER PRIMARY KEY, created INTEGER);')
        cursor.execute(ddl)

    @classmethod
    def get_latest(cls):
        """Return an order with the largest ``created`` value, or None if empty."""
        newest = 'created = (SELECT MAX(created) FROM "order")'
        return cls.get_where(newest)
def get_db():
    """Return the shared SQLite connection, opening it on first use.

    The connection is cached in the module-level ``db``. If the file named by
    ``DB_NAME`` does not exist yet, the ``order`` and ``job`` tables are created
    before the connection is cached and returned.

    Raises:
        TypeError: if ``DB_NAME`` has not been set.
        Exception: anything raised while connecting or creating the schema
            propagates. In that case nothing is cached and a database file
            created by this call is removed, so the next call starts over.
    """
    global db
    if db is not None:
        return db
    if DB_NAME is None:
        raise TypeError('DB_NAME must be set before calling get_db()')
    create_tables = not os.path.exists(DB_NAME)
    conn = sqlite3.connect(DB_NAME)
    if create_tables:
        try:
            cursor = conn.cursor()
            Order.create_table(cursor)
            Job.create_table(cursor)
            conn.commit()
        except Exception:
            conn.close()
            # The file did not exist before this call, so it is ours to delete.
            # Leaving it would make later calls skip schema creation and run
            # against a missing or partial schema.
            try:
                os.remove(DB_NAME)
            except OSError:
                pass  # best effort; the original error is re-raised below
            raise
    # Cache only a fully initialised connection.
    db = conn
    return db