diff --git a/postgres/__init__.py b/postgres/__init__.py index d16bd7c..9e46364 100644 --- a/postgres/__init__.py +++ b/postgres/__init__.py @@ -180,6 +180,7 @@ from postgres.cursors import SimpleTupleCursor, SimpleNamedTupleCursor from postgres.cursors import SimpleDictCursor, SimpleCursorBase from postgres.orm import Model +from psycopg2 import DataError from psycopg2.extras import register_composite, CompositeCaster from psycopg2.pool import ThreadedConnectionPool as ConnectionPool @@ -824,7 +825,22 @@ def make_DelegatingCaster(postgres): """ class DelegatingCaster(CompositeCaster): + + def parse(self, s, curs, retry=True): + # Override to protect against race conditions: + # https://github.com/gratipay/postgres.py/issues/26 + + try: + return super(DelegatingCaster, self).parse(s, curs) + except (DataError, ValueError): + if not retry: + raise + # Re-fetch the type info and retry once + self._refetch_type_info(curs) + return self.parse(s, curs, False) + def make(self, values): + # Override to delegate to the model registry. if self.name not in postgres.model_registry: # This is probably a bug, not a normal user error. It means @@ -838,6 +854,12 @@ def make(self, values): instance = ModelSubclass(record) return instance + def _refetch_type_info(self, curs): + """Given a cursor, update the current object with a fresh type definition. + """ + new_self = self._from_db(self.name, curs) + self.__dict__.update(new_self.__dict__) + return DelegatingCaster diff --git a/tests.py b/tests.py index 1f628d7..fee7b4f 100644 --- a/tests.py +++ b/tests.py @@ -8,7 +8,7 @@ from postgres.cursors import TooFew, TooMany, SimpleDictCursor from postgres.orm import ReadOnly, Model from psycopg2 import InterfaceError, ProgrammingError -from pytest import raises +from pytest import mark, raises DATABASE_URL = os.environ['DATABASE_URL'] @@ -334,6 +334,32 @@ def test_unregister_unregisters_multiple(self): self.db.unregister_model(self.MyModel) assert self.db.model_registry == {} + def test_add_column_doesnt_break_anything(self): + self.db.run("ALTER TABLE foo ADD COLUMN boo text") + one = self.db.one("SELECT foo.*::foo FROM foo WHERE bar='baz'") + assert one.boo is None + + def test_replace_column_different_type(self): + self.db.run("CREATE TABLE grok (bar int)") + self.db.run("INSERT INTO grok VALUES (0)") + class EmptyModel(Model): pass + self.db.register_model(EmptyModel, 'grok') + # Add a new column then drop the original one + self.db.run("ALTER TABLE grok ADD COLUMN biz text NOT NULL DEFAULT 'x'") + self.db.run("ALTER TABLE grok DROP COLUMN bar") + # The number of columns hasn't changed but the names and types have + one = self.db.one("SELECT grok.*::grok FROM grok LIMIT 1") + assert one.biz == 'x' + assert not hasattr(one, 'bar') + + @mark.xfail(raises=AttributeError) + def test_replace_column_same_type_different_name(self): + self.db.run("ALTER TABLE foo ADD COLUMN biz text NOT NULL DEFAULT 0") + self.db.run("ALTER TABLE foo DROP COLUMN bar") + one = self.db.one("SELECT foo.*::foo FROM foo LIMIT 1") + assert one.biz == 0 + assert not hasattr(one, 'bar') + # cursor_factory # ==============