diff --git a/README.md b/README.md index a243c8c..5607be0 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,7 @@ - [Filters](#filters) - [Utilities](#utilities) - [`inspect`](#inspect) +- [Associations](#associations) - [What is coming next?](#what-is-coming-next) - [Contributing](#contributing) - [License](#license) @@ -1167,6 +1168,10 @@ Output: The `inspect` function take the following arguments. +### Associations + +Eager + ### What is coming next? 1. Associations diff --git a/dataloom/keys.py b/dataloom/keys.py index a519747..6129f81 100644 --- a/dataloom/keys.py +++ b/dataloom/keys.py @@ -1,4 +1,4 @@ -push = False +push = True class PgConfig: diff --git a/dataloom/loom/subqueries.py b/dataloom/loom/subqueries.py index 963a9b6..5c6d3c9 100644 --- a/dataloom/loom/subqueries.py +++ b/dataloom/loom/subqueries.py @@ -37,15 +37,28 @@ def get_find_by_pk_relations(self, parent: Model, pk, includes: list[Include] = _, parent_pk_name, parent_fks, _ = get_table_fields( parent, dialect=self.dialect ) - _pk = relations[key][re.sub(r'`|"', "", parent_pk_name)] - - relations[key] = { - **relations[key], - **self.get_find_by_pk_relations( - include.model, _pk, includes=include.include - ), - } + if isinstance(relations[key], dict): + _pk = relations[key][re.sub(r'`|"', "", parent_pk_name)] + relations[key] = { + **relations[key], + **self.get_find_by_pk_relations( + include.model, _pk, includes=include.include + ), + } + else: + _pk = ( + relations[key][0][re.sub(r'`|"', "", parent_pk_name)] + if len(relations[key]) != 0 + else None + ) + if _pk is not None: + relations[key] = { + **relations[key], + **self.get_find_by_pk_relations( + include.model, _pk, includes=include.include + ), + } return relations def get_one( diff --git a/dataloom/tests/mysql/test_eager_loading_mysql.py b/dataloom/tests/mysql/test_eager_loading_mysql.py new file mode 100644 index 0000000..5505b5b --- /dev/null +++ b/dataloom/tests/mysql/test_eager_loading_mysql.py @@ -0,0 +1,239 @@ +class TestEagerLoadingOnMySQL: + def test_find_by_pk(self): + from dataloom import ( + Dataloom, + Model, + Column, + PrimaryKeyColumn, + CreatedAtColumn, + TableColumn, + ForeignKeyColumn, + ColumnValue, + Include, + Order, + ) + from dataloom.keys import MySQLConfig + + mysql_loom = Dataloom( + dialect="mysql", + database=MySQLConfig.database, + password=MySQLConfig.password, + user=MySQLConfig.user, + ) + + class User(Model): + __tablename__: TableColumn = TableColumn(name="users") + id = PrimaryKeyColumn(type="int", auto_increment=True) + name = Column(type="text", nullable=False, default="Bob") + username = Column(type="varchar", unique=True, length=255) + tokenVersion = Column(type="int", default=0) + + class Profile(Model): + __tablename__: TableColumn = TableColumn(name="profiles") + id = PrimaryKeyColumn(type="int", auto_increment=True) + avatar = Column(type="text", nullable=False) + userId = ForeignKeyColumn( + User, + maps_to="1-1", + type="int", + required=True, + onDelete="CASCADE", + onUpdate="CASCADE", + ) + + class Post(Model): + __tablename__: TableColumn = TableColumn(name="posts") + id = PrimaryKeyColumn( + type="int", auto_increment=True, nullable=False, unique=True + ) + completed = Column(type="boolean", default=False) + title = Column(type="varchar", length=255, nullable=False) + # timestamps + createdAt = CreatedAtColumn() + # relations + userId = ForeignKeyColumn( + User, + maps_to="1-N", + type="int", + required=True, + onDelete="CASCADE", + onUpdate="CASCADE", + ) + + class Category(Model): + __tablename__: TableColumn = TableColumn(name="categories") + id = PrimaryKeyColumn( + type="int", auto_increment=True, nullable=False, unique=True + ) + type = Column(type="varchar", length=255, nullable=False) + + postId = ForeignKeyColumn( + Post, + maps_to="N-1", + type="int", + required=True, + onDelete="CASCADE", + onUpdate="CASCADE", + ) + + conn, tables = mysql_loom.connect_and_sync( + [User, Profile, Post, Category], drop=True, force=True + ) + + userId = mysql_loom.insert_one( + instance=User, + values=ColumnValue(name="username", value="@miller"), + ) + + userId2 = mysql_loom.insert_one( + instance=User, + values=ColumnValue(name="username", value="bob"), + ) + + profileId = mysql_loom.insert_one( + instance=Profile, + values=[ + ColumnValue(name="userId", value=userId), + ColumnValue(name="avatar", value="hello.jpg"), + ], + ) + for title in ["Hey", "Hello", "What are you doing", "Coding"]: + mysql_loom.insert_one( + instance=Post, + values=[ + ColumnValue(name="userId", value=userId), + ColumnValue(name="title", value=title), + ], + ) + + for cat in ["general", "education", "tech", "sport"]: + mysql_loom.insert_one( + instance=Category, + values=[ + ColumnValue(name="postId", value=1), + ColumnValue(name="type", value=cat), + ], + ) + + profile = mysql_loom.find_by_pk( + instance=Profile, + pk=profileId, + include=[ + Include( + model=User, select=["id", "username", "tokenVersion"], has="one" + ) + ], + ) + assert profile == { + "avatar": "hello.jpg", + "id": 1, + "userId": 1, + "user": {"id": 1, "username": "@miller", "tokenVersion": 0}, + } + + user = mysql_loom.find_by_pk( + instance=User, + pk=userId, + include=[Include(model=Profile, select=["id", "avatar"], has="one")], + ) + assert user == { + "id": 1, + "name": "Bob", + "tokenVersion": 0, + "username": "@miller", + "profile": {"id": 1, "avatar": "hello.jpg"}, + } + + user = mysql_loom.find_by_pk( + instance=User, + pk=userId, + include=[ + Include( + model=Post, + select=["id", "title"], + has="many", + offset=0, + limit=2, + order=[ + Order(column="createdAt", order="DESC"), + Order(column="id", order="DESC"), + ], + ), + Include(model=Profile, select=["id", "avatar"], has="one"), + ], + ) + assert user == { + "id": 1, + "name": "Bob", + "tokenVersion": 0, + "username": "@miller", + "posts": [ + {"id": 4, "title": "Coding"}, + {"id": 3, "title": "What are you doing"}, + ], + "profile": {"id": 1, "avatar": "hello.jpg"}, + } + + post = mysql_loom.find_by_pk( + instance=Post, + pk=1, + select=["title", "id"], + include=[ + Include( + model=User, + select=["id", "username"], + has="one", + include=[ + Include(model=Profile, select=["avatar", "id"], has="one") + ], + ), + Include( + model=Category, + select=["id", "type"], + has="many", + order=[Order(column="id", order="DESC")], + ), + ], + ) + + assert post == { + "title": "Hey", + "id": 1, + "user": { + "id": 1, + "username": "@miller", + "profile": {"avatar": "hello.jpg", "id": 1}, + }, + "categories": [ + {"id": 4, "type": "sport"}, + {"id": 3, "type": "tech"}, + {"id": 2, "type": "education"}, + {"id": 1, "type": "general"}, + ], + } + + user = mysql_loom.find_by_pk( + instance=User, + pk=userId2, + select=["username", "id"], + include=[ + Include( + model=Post, + select=["id", "title"], + has="many", + include=[ + Include( + model=Category, + select=["type", "id"], + has="many", + order=[Order(column="id", order="DESC")], + limit=2, + offset=0, + ) + ], + ), + ], + ) + assert user == {"username": "bob", "id": 2, "posts": []} + + conn.close() diff --git a/dataloom/tests/postgres/test_eager_loading_pg.py b/dataloom/tests/postgres/test_eager_loading_pg.py new file mode 100644 index 0000000..4ed4609 --- /dev/null +++ b/dataloom/tests/postgres/test_eager_loading_pg.py @@ -0,0 +1,239 @@ +class TestEagerLoadingOnPG: + def test_find_by_pk(self): + from dataloom import ( + Dataloom, + Model, + Column, + PrimaryKeyColumn, + CreatedAtColumn, + TableColumn, + ForeignKeyColumn, + ColumnValue, + Include, + Order, + ) + from dataloom.keys import PgConfig + + pg_loom = Dataloom( + dialect="postgres", + database=PgConfig.database, + password=PgConfig.password, + user=PgConfig.user, + ) + + class User(Model): + __tablename__: TableColumn = TableColumn(name="users") + id = PrimaryKeyColumn(type="int", auto_increment=True) + name = Column(type="text", nullable=False, default="Bob") + username = Column(type="varchar", unique=True, length=255) + tokenVersion = Column(type="int", default=0) + + class Profile(Model): + __tablename__: TableColumn = TableColumn(name="profiles") + id = PrimaryKeyColumn(type="int", auto_increment=True) + avatar = Column(type="text", nullable=False) + userId = ForeignKeyColumn( + User, + maps_to="1-1", + type="int", + required=True, + onDelete="CASCADE", + onUpdate="CASCADE", + ) + + class Post(Model): + __tablename__: TableColumn = TableColumn(name="posts") + id = PrimaryKeyColumn( + type="int", auto_increment=True, nullable=False, unique=True + ) + completed = Column(type="boolean", default=False) + title = Column(type="varchar", length=255, nullable=False) + # timestamps + createdAt = CreatedAtColumn() + # relations + userId = ForeignKeyColumn( + User, + maps_to="1-N", + type="int", + required=True, + onDelete="CASCADE", + onUpdate="CASCADE", + ) + + class Category(Model): + __tablename__: TableColumn = TableColumn(name="categories") + id = PrimaryKeyColumn( + type="int", auto_increment=True, nullable=False, unique=True + ) + type = Column(type="varchar", length=255, nullable=False) + + postId = ForeignKeyColumn( + Post, + maps_to="N-1", + type="int", + required=True, + onDelete="CASCADE", + onUpdate="CASCADE", + ) + + conn, tables = pg_loom.connect_and_sync( + [User, Profile, Post, Category], drop=True, force=True + ) + + userId = pg_loom.insert_one( + instance=User, + values=ColumnValue(name="username", value="@miller"), + ) + + userId2 = pg_loom.insert_one( + instance=User, + values=ColumnValue(name="username", value="bob"), + ) + + profileId = pg_loom.insert_one( + instance=Profile, + values=[ + ColumnValue(name="userId", value=userId), + ColumnValue(name="avatar", value="hello.jpg"), + ], + ) + for title in ["Hey", "Hello", "What are you doing", "Coding"]: + pg_loom.insert_one( + instance=Post, + values=[ + ColumnValue(name="userId", value=userId), + ColumnValue(name="title", value=title), + ], + ) + + for cat in ["general", "education", "tech", "sport"]: + pg_loom.insert_one( + instance=Category, + values=[ + ColumnValue(name="postId", value=1), + ColumnValue(name="type", value=cat), + ], + ) + + profile = pg_loom.find_by_pk( + instance=Profile, + pk=profileId, + include=[ + Include( + model=User, select=["id", "username", "tokenVersion"], has="one" + ) + ], + ) + assert profile == { + "avatar": "hello.jpg", + "id": 1, + "userId": 1, + "user": {"id": 1, "username": "@miller", "tokenVersion": 0}, + } + + user = pg_loom.find_by_pk( + instance=User, + pk=userId, + include=[Include(model=Profile, select=["id", "avatar"], has="one")], + ) + assert user == { + "id": 1, + "name": "Bob", + "tokenVersion": 0, + "username": "@miller", + "profile": {"id": 1, "avatar": "hello.jpg"}, + } + + user = pg_loom.find_by_pk( + instance=User, + pk=userId, + include=[ + Include( + model=Post, + select=["id", "title"], + has="many", + offset=0, + limit=2, + order=[ + Order(column="createdAt", order="DESC"), + Order(column="id", order="DESC"), + ], + ), + Include(model=Profile, select=["id", "avatar"], has="one"), + ], + ) + assert user == { + "id": 1, + "name": "Bob", + "tokenVersion": 0, + "username": "@miller", + "posts": [ + {"id": 4, "title": "Coding"}, + {"id": 3, "title": "What are you doing"}, + ], + "profile": {"id": 1, "avatar": "hello.jpg"}, + } + + post = pg_loom.find_by_pk( + instance=Post, + pk=1, + select=["title", "id"], + include=[ + Include( + model=User, + select=["id", "username"], + has="one", + include=[ + Include(model=Profile, select=["avatar", "id"], has="one") + ], + ), + Include( + model=Category, + select=["id", "type"], + has="many", + order=[Order(column="id", order="DESC")], + ), + ], + ) + + assert post == { + "title": "Hey", + "id": 1, + "user": { + "id": 1, + "username": "@miller", + "profile": {"avatar": "hello.jpg", "id": 1}, + }, + "categories": [ + {"id": 4, "type": "sport"}, + {"id": 3, "type": "tech"}, + {"id": 2, "type": "education"}, + {"id": 1, "type": "general"}, + ], + } + + user = pg_loom.find_by_pk( + instance=User, + pk=userId2, + select=["username", "id"], + include=[ + Include( + model=Post, + select=["id", "title"], + has="many", + include=[ + Include( + model=Category, + select=["type", "id"], + has="many", + order=[Order(column="id", order="DESC")], + limit=2, + offset=0, + ) + ], + ), + ], + ) + assert user == {"username": "bob", "id": 2, "posts": []} + + conn.close() diff --git a/dataloom/tests/sqlite3/test_eager_loading_sqlite.py b/dataloom/tests/sqlite3/test_eager_loading_sqlite.py new file mode 100644 index 0000000..b47ef0d --- /dev/null +++ b/dataloom/tests/sqlite3/test_eager_loading_sqlite.py @@ -0,0 +1,233 @@ +class TestEagerLoadingOnSQLite: + def test_find_by_pk(self): + from dataloom import ( + Dataloom, + Model, + Column, + PrimaryKeyColumn, + CreatedAtColumn, + TableColumn, + ForeignKeyColumn, + ColumnValue, + Include, + Order, + ) + + sqlite_loom = Dataloom(dialect="sqlite", database="hi.db") + + class User(Model): + __tablename__: TableColumn = TableColumn(name="users") + id = PrimaryKeyColumn(type="int", auto_increment=True) + name = Column(type="text", nullable=False, default="Bob") + username = Column(type="varchar", unique=True, length=255) + tokenVersion = Column(type="int", default=0) + + class Profile(Model): + __tablename__: TableColumn = TableColumn(name="profiles") + id = PrimaryKeyColumn(type="int", auto_increment=True) + avatar = Column(type="text", nullable=False) + userId = ForeignKeyColumn( + User, + maps_to="1-1", + type="int", + required=True, + onDelete="CASCADE", + onUpdate="CASCADE", + ) + + class Post(Model): + __tablename__: TableColumn = TableColumn(name="posts") + id = PrimaryKeyColumn( + type="int", auto_increment=True, nullable=False, unique=True + ) + completed = Column(type="boolean", default=False) + title = Column(type="varchar", length=255, nullable=False) + # timestamps + createdAt = CreatedAtColumn() + # relations + userId = ForeignKeyColumn( + User, + maps_to="1-N", + type="int", + required=True, + onDelete="CASCADE", + onUpdate="CASCADE", + ) + + class Category(Model): + __tablename__: TableColumn = TableColumn(name="categories") + id = PrimaryKeyColumn( + type="int", auto_increment=True, nullable=False, unique=True + ) + type = Column(type="varchar", length=255, nullable=False) + + postId = ForeignKeyColumn( + Post, + maps_to="N-1", + type="int", + required=True, + onDelete="CASCADE", + onUpdate="CASCADE", + ) + + conn, tables = sqlite_loom.connect_and_sync( + [User, Profile, Post, Category], drop=True, force=True + ) + + userId = sqlite_loom.insert_one( + instance=User, + values=ColumnValue(name="username", value="@miller"), + ) + + userId2 = sqlite_loom.insert_one( + instance=User, + values=ColumnValue(name="username", value="bob"), + ) + + profileId = sqlite_loom.insert_one( + instance=Profile, + values=[ + ColumnValue(name="userId", value=userId), + ColumnValue(name="avatar", value="hello.jpg"), + ], + ) + for title in ["Hey", "Hello", "What are you doing", "Coding"]: + sqlite_loom.insert_one( + instance=Post, + values=[ + ColumnValue(name="userId", value=userId), + ColumnValue(name="title", value=title), + ], + ) + + for cat in ["general", "education", "tech", "sport"]: + sqlite_loom.insert_one( + instance=Category, + values=[ + ColumnValue(name="postId", value=1), + ColumnValue(name="type", value=cat), + ], + ) + + profile = sqlite_loom.find_by_pk( + instance=Profile, + pk=profileId, + include=[ + Include( + model=User, select=["id", "username", "tokenVersion"], has="one" + ) + ], + ) + assert profile == { + "avatar": "hello.jpg", + "id": 1, + "userId": 1, + "user": {"id": 1, "username": "@miller", "tokenVersion": 0}, + } + + user = sqlite_loom.find_by_pk( + instance=User, + pk=userId, + include=[Include(model=Profile, select=["id", "avatar"], has="one")], + ) + assert user == { + "id": 1, + "name": "Bob", + "tokenVersion": 0, + "username": "@miller", + "profile": {"id": 1, "avatar": "hello.jpg"}, + } + + user = sqlite_loom.find_by_pk( + instance=User, + pk=userId, + include=[ + Include( + model=Post, + select=["id", "title"], + has="many", + offset=0, + limit=2, + order=[ + Order(column="createdAt", order="DESC"), + Order(column="id", order="DESC"), + ], + ), + Include(model=Profile, select=["id", "avatar"], has="one"), + ], + ) + assert user == { + "id": 1, + "name": "Bob", + "tokenVersion": 0, + "username": "@miller", + "posts": [ + {"id": 4, "title": "Coding"}, + {"id": 3, "title": "What are you doing"}, + ], + "profile": {"id": 1, "avatar": "hello.jpg"}, + } + + post = sqlite_loom.find_by_pk( + instance=Post, + pk=1, + select=["title", "id"], + include=[ + Include( + model=User, + select=["id", "username"], + has="one", + include=[ + Include(model=Profile, select=["avatar", "id"], has="one") + ], + ), + Include( + model=Category, + select=["id", "type"], + has="many", + order=[Order(column="id", order="DESC")], + ), + ], + ) + + assert post == { + "title": "Hey", + "id": 1, + "user": { + "id": 1, + "username": "@miller", + "profile": {"avatar": "hello.jpg", "id": 1}, + }, + "categories": [ + {"id": 4, "type": "sport"}, + {"id": 3, "type": "tech"}, + {"id": 2, "type": "education"}, + {"id": 1, "type": "general"}, + ], + } + + user = sqlite_loom.find_by_pk( + instance=User, + pk=userId2, + select=["username", "id"], + include=[ + Include( + model=Post, + select=["id", "title"], + has="many", + include=[ + Include( + model=Category, + select=["type", "id"], + has="many", + order=[Order(column="id", order="DESC")], + limit=2, + offset=0, + ) + ], + ), + ], + ) + assert user == {"username": "bob", "id": 2, "posts": []} + + conn.close() diff --git a/hi.db b/hi.db index 887375c..bcfadd7 100644 Binary files a/hi.db and b/hi.db differ diff --git a/playground.py b/playground.py index 79d7b55..2de66f6 100644 --- a/playground.py +++ b/playground.py @@ -194,3 +194,29 @@ class Category(Model): ) print(post) + + +user = mysql_loom.find_by_pk( + instance=User, + pk=userId2, + select=["username", "id"], + include=[ + Include( + model=Post, + select=["id", "title"], + has="many", + include=[ + Include( + model=Category, + select=["type", "id"], + has="many", + order=[Order(column="id", order="DESC")], + limit=2, + offset=0, + ) + ], + ), + ], +) + +print(user)