code formatting
KeplerC committed Apr 6, 2024
1 parent 45a62ca commit 250966b
Showing 12 changed files with 241 additions and 109 deletions.
3 changes: 1 addition & 2 deletions examples/dataloader/tf.py
@@ -1,4 +1,3 @@

import fog_rtx

dataset = fog_rtx.dataset.Dataset(
@@ -25,4 +24,4 @@

# get samples from the dataset
for data in tf_ds:
print(data)
print(data)
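
The dataloader example above ends by iterating over a tf.data.Dataset (tf_ds) built from the fog_rtx dataset. For reference, a minimal self-contained sketch of that iteration pattern, using a hypothetical stand-in dataset rather than the fog_rtx loader:

import tensorflow as tf

# Hypothetical stand-in for the tf_ds produced from the fog_rtx dataset;
# feature names and values are made up for illustration.
tf_ds = tf.data.Dataset.from_tensor_slices(
    {"observation": [[0.0, 1.0], [2.0, 3.0]], "reward": [0.0, 1.0]}
)

# get samples from the dataset
for data in tf_ds:
    print(data)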
2 changes: 1 addition & 1 deletion examples/rtx_example/load.py
@@ -11,4 +11,4 @@
split="train[:1]",
)

dataset.export("/tmp/rtx_export", format="rtx")
dataset.export("/tmp/rtx_export", format="rtx")
1 change: 1 addition & 0 deletions fog_rtx/database/__init__.py
@@ -1,6 +1,7 @@
from fog_rtx.database.db_connector import DatabaseConnector
from fog_rtx.database.db_manager import DatabaseManager
from fog_rtx.database.polars_connector import PolarsConnector

# from fog_rtx.db.postgres import Postgres

__all__ = ["DatabaseConnector", "DatabaseManager"]
23 changes: 18 additions & 5 deletions fog_rtx/database/db_connector.py
@@ -84,7 +84,14 @@ def insert_data(
)
return insert_result.inserted_primary_key[0]

def update_data(self, table_name: str, index: int, data: dict, create_new_column_if_not_exist: bool = False, is_partial_data: bool = False):
def update_data(
self,
table_name: str,
index: int,
data: dict,
create_new_column_if_not_exist: bool = False,
is_partial_data: bool = False,
):
metadata = MetaData()
table = Table(table_name, metadata, autoload_with=self.engine)

@@ -101,11 +108,17 @@ def update_data(self, table_name: str, index: int, data: dict, create_new_colum
# Column(key, type_py2sql(type(value)), nullable=True),
# )
column_type = type_py2sql(type(value))
self.add_column(table_name, Column(key, column_type, nullable=True))
self.add_column(
table_name, Column(key, column_type, nullable=True)
)
metadata.clear()
table = Table(table_name, metadata, autoload_with=self.engine)
logger.info(f"Successfully added column {key} to {table_name}")

table = Table(
table_name, metadata, autoload_with=self.engine
)
logger.info(
f"Successfully added column {key} to {table_name}"
)

# if is_partial_data:
# self.engine.execute(
# table.update().where(table.c.id == index).values(**data)
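The reformatted update_data above follows SQLAlchemy's reflect-and-update pattern: the table is reloaded with Table(..., autoload_with=engine) after a column is added, and rows are updated through table.c. A minimal self-contained sketch of that pattern, assuming a hypothetical SQLite table and column names (this is not the fog_rtx connector itself):

from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine

engine = create_engine("sqlite:///:memory:")

# Create a toy table so there is something to reflect (hypothetical schema).
setup_md = MetaData()
episodes = Table(
    "episodes",
    setup_md,
    Column("id", Integer, primary_key=True),
    Column("note", String, nullable=True),
)
setup_md.create_all(engine)
with engine.begin() as conn:
    conn.execute(episodes.insert().values(id=0, note=None))

# Reflect the table by name, then update one row -- the same
# reflect-and-update pattern used by update_data above.
metadata = MetaData()
table = Table("episodes", metadata, autoload_with=engine)
with engine.begin() as conn:
    conn.execute(table.update().where(table.c.id == 0).values(note="updated"))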
47 changes: 30 additions & 17 deletions fog_rtx/database/db_manager.py
@@ -35,7 +35,9 @@ def initialize_dataset(
dtype=row["Type"], shape=row["Shape"]
)
self.features[row["Feature"]] = feature_type
logger.info(f"Loaded Features: {self.features} with type {feature_type}")
logger.info(
f"Loaded Features: {self.features} with type {feature_type}"
)

else:
self.db_connector.create_table(
@@ -61,17 +63,19 @@ def initialize_episode(
raise ValueError("Dataset not initialized")
if metadata is None:
metadata = {}
logger.info(f"Initializing episode for dataset {self.dataset_name} with metadata {metadata}")
logger.info(
f"Initializing episode for dataset {self.dataset_name} with metadata {metadata}"
)

metadata["episode_id"] = self.current_episode_id
metadata["episode_id"] = self.current_episode_id
metadata["Compacted"] = False

for metadata_key in metadata.keys():
logger.info(f"Adding metadata key {metadata_key} to the database")
self.db_connector.add_column(
self.dataset_name,
metadata_key,
"str", # TODO: support more types
"str", # TODO: support more types
)

# insert episode information to the database
@@ -98,9 +102,11 @@ def add(
f"Feature {feature_name} not in the list of features"
)
if feature_type is None:

feature_type = FeatureType(data = value)
logger.warn(f"feature type not provided, inferring from data type {feature_type}")

feature_type = FeatureType(data=value)
logger.warn(
f"feature type not provided, inferring from data type {feature_type}"
)
self._initialize_feature(feature_name, feature_type)

# insert data into the table
@@ -150,24 +156,28 @@ def get_episode_table(self, episode_id, format: str = "pandas"):
format=format,
)

def _initialize_feature(self, feature_name: str, feature_type: FeatureType):
def _initialize_feature(
self, feature_name: str, feature_type: FeatureType
):
# create a table for the feature
# TODO: need to make the timestamp type as TIMESTAMPTZ
self.db_connector.create_table(
self._get_feature_table_name(feature_name),
)
# {"Timestamp": "int64", feature_name: feature_type.to_sql_type()}

self.db_connector.add_column(
self._get_feature_table_name(feature_name),
"Timestamp",
"int64",
)
logger.info(f"Adding feature {feature_name} to the database with type {feature_type.to_pld_storage_type()}")
logger.info(
f"Adding feature {feature_name} to the database with type {feature_type.to_pld_storage_type()}"
)
self.db_connector.add_column(
self._get_feature_table_name(feature_name),
feature_name,
feature_type.to_pld_storage_type(), #TODO: placeholder
feature_type.to_pld_storage_type(), # TODO: placeholder
)

if feature_type is None:
@@ -188,15 +198,18 @@ def _initialize_feature(self, feature_name: str, feature_type: FeatureType):
"str",
)
self.db_connector.update_data(
table_name = self.dataset_name,
index = self.current_episode_id,
data = {
f"feature_{feature_name}_type": str(self.features[feature_name].dtype),
f"feature_{feature_name}_shape": str(self.features[feature_name].shape),
table_name=self.dataset_name,
index=self.current_episode_id,
data={
f"feature_{feature_name}_type": str(
self.features[feature_name].dtype
),
f"feature_{feature_name}_shape": str(
self.features[feature_name].shape
),
},
)


def _get_feature_table_name(self, feature_name):
if self.dataset_name is None:
raise ValueError("Dataset not initialized")
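In add() above, a missing feature type falls back to FeatureType(data=value), i.e. the dtype and shape are inferred from the value itself. FeatureType's internals are not part of this diff; the snippet below is only a hypothetical numpy-based illustration of that kind of inference:

import numpy as np

def infer_dtype_and_shape(value):
    # Hypothetical stand-in for FeatureType(data=value): derive a dtype
    # string and a shape from the raw value.
    arr = np.asarray(value)
    return str(arr.dtype), arr.shape

print(infer_dtype_and_shape([[1.0, 2.0], [3.0, 4.0]]))  # ('float64', (2, 2))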
57 changes: 43 additions & 14 deletions fog_rtx/database/polars_connector.py
@@ -1,20 +1,27 @@
import logging
from typing import List

import polars as pl

from fog_rtx.database.utils import _datasets_dtype_to_pld

logger = logging.getLogger(__name__)


class PolarsConnector:
def __init__(self, path: str):
# In Polars, data is directly read into DataFrames from files, not through a database engine
self.path = path
self.tables = {} # This will store table names as keys and DataFrames as values
self.tables = (
{}
) # This will store table names as keys and DataFrames as values

def close(self):
# No connection to close in Polars, but we could clear the tables dictionary
for table_name in self.tables.keys():
self.tables[table_name].write_parquet(f"{self.path}/{table_name}.parquet")
self.tables[table_name].write_parquet(
f"{self.path}/{table_name}.parquet"
)

def list_tables(self):
# Listing available DataFrame tables
@@ -29,14 +36,20 @@ def create_table(self, table_name: str):

def add_column(self, table_name: str, column_name: str, column_type):
if column_name in self.tables[table_name].columns:
logger.warning(f"Column {column_name} already exists in table {table_name}.")
logger.warning(
f"Column {column_name} already exists in table {table_name}."
)
return
# Add a new column to an existing DataFrame
if table_name in self.tables:
arrow_type = _datasets_dtype_to_pld(column_type)
# self.tables[table_name] = self.tables[table_name].with_column(pl.lit(None).alias(column_name).cast(column_type))
# self.tables[table_name] = self.tables[table_name].with_columns(pl.lit(None).alias(column_name).cast(arrow_type))
self.tables[table_name] = self.tables[table_name].with_columns(pl.Series(column_name, [None]*len(self.tables[table_name])).cast(arrow_type))
self.tables[table_name] = self.tables[table_name].with_columns(
pl.Series(
column_name, [None] * len(self.tables[table_name])
).cast(arrow_type)
)
logger.info(f"Column {column_name} added to table {table_name}.")
else:
logger.error(f"Table {table_name} does not exist.")
@@ -45,10 +58,16 @@ def insert_data(self, table_name: str, data: dict):
# Inserting a new row into the DataFrame and return the index of the new row
if table_name in self.tables:
# use the schema of the original table
new_row = pl.DataFrame([data], schema=self.tables[table_name].schema)
index_of_new_row = len(self.tables[table_name])
logger.debug(f"Inserting data into {table_name}: {data} with {new_row} to table {self.tables[table_name]}")
self.tables[table_name] = pl.concat([self.tables[table_name], new_row], how = "align")
new_row = pl.DataFrame(
[data], schema=self.tables[table_name].schema
)
index_of_new_row = len(self.tables[table_name])
logger.debug(
f"Inserting data into {table_name}: {data} with {new_row} to table {self.tables[table_name]}"
)
self.tables[table_name] = pl.concat(
[self.tables[table_name], new_row], how="align"
)

return index_of_new_row # Return the index of the inserted row
else:
@@ -58,20 +77,30 @@ def update_data(self, table_name: str, index: int, data: dict):
def update_data(self, table_name: str, index: int, data: dict):
# update data
for column_name, value in data.items():
logger.info(f"updating {column_name} with {value} at index {index}")
logger.info(
f"updating {column_name} with {value} at index {index}"
)
self.tables[table_name][index, column_name] = value

def merge_tables_with_timestamp(self, tables: List[str], output_table: str):
def merge_tables_with_timestamp(
self, tables: List[str], output_table: str
):
for table_name in self.tables.keys():
if table_name not in tables:
continue
self.tables[table_name] = self.tables[table_name].set_sorted("Timestamp")
self.tables[table_name] = self.tables[table_name].set_sorted(
"Timestamp"
)

# Merging tables using timestamps
if len(tables) > 1:
merged_df = self.tables[tables[0]].join_asof(self.tables[tables[1]], on="Timestamp", strategy = "nearest")
if len(tables) > 1:
merged_df = self.tables[tables[0]].join_asof(
self.tables[tables[1]], on="Timestamp", strategy="nearest"
)
for table_name in tables[2:]:
merged_df = merged_df.join_asof(self.tables[table_name], on="Timestamp", strategy = "nearest")
merged_df = merged_df.join_asof(
self.tables[table_name], on="Timestamp", strategy="nearest"
)
logger.info("Tables merged on Timestamp.")
else:
logger.error("Need at least two tables to merge.")
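merge_tables_with_timestamp above marks each per-feature frame as sorted on Timestamp and then chains join_asof calls with strategy="nearest". A minimal self-contained Polars sketch of that merge, using made-up feature tables:

import polars as pl

# Hypothetical per-feature tables keyed by Timestamp.
camera = pl.DataFrame(
    {"Timestamp": [0, 10, 20], "image_id": [1, 2, 3]}
).set_sorted("Timestamp")
gripper = pl.DataFrame(
    {"Timestamp": [1, 9, 21], "gripper_open": [True, False, True]}
).set_sorted("Timestamp")

# Nearest-timestamp asof join, as in merge_tables_with_timestamp.
merged = camera.join_asof(gripper, on="Timestamp", strategy="nearest")
print(merged)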
27 changes: 20 additions & 7 deletions fog_rtx/database/utils.py
@@ -1,7 +1,8 @@
import datetime
import decimal
import pyarrow as pa

import numpy as np
import pyarrow as pa
import sqlalchemy # type: ignore
from polars import datatypes as pld

@@ -58,6 +59,7 @@ def type_np2sql(dtype=None, arr=None):
"""Return the closest sql type for a given numpy dtype"""
return type_py2sql(type_np2py(dtype=dtype, arr=arr))


def _datasets_dtype_to_arrow(datasets_dtype: str) -> pa.DataType:
"""
_datasets_dtype_to_arrow takes a datasets string dtype and converts it to a pyarrow.DataType.
@@ -103,14 +105,19 @@ def _datasets_dtype_to_arrow(datasets_dtype: str) -> pa.DataType:
elif datasets_dtype.startswith("timestamp"):
# Extracting unit and timezone from the string
unit, tz = None, None
if '[' in datasets_dtype and ']' in datasets_dtype:
unit = datasets_dtype[datasets_dtype.find("[")+1 : datasets_dtype.find("]")]
if "[" in datasets_dtype and "]" in datasets_dtype:
unit = datasets_dtype[
datasets_dtype.find("[") + 1 : datasets_dtype.find("]")
]
tz_start = datasets_dtype.find("tz=")
if tz_start != -1:
tz = datasets_dtype[tz_start+3:]
tz = datasets_dtype[tz_start + 3 :]
return pa.timestamp(unit, tz)
else:
raise ValueError(f"Datasets dtype {datasets_dtype} does not have a PyArrow dtype equivalent.")
raise ValueError(
f"Datasets dtype {datasets_dtype} does not have a PyArrow dtype equivalent."
)


def _datasets_dtype_to_pld(datasets_dtype: str) -> pld.DataType:
"""
@@ -150,7 +157,11 @@ def _datasets_dtype_to_pld(datasets_dtype: str) -> pld.DataType:
elif datasets_dtype == "large_binary":
# Polars does not differentiate between binary and large_binary
return pld.Binary
elif datasets_dtype == "string" or datasets_dtype == "str" or datasets_dtype == "large_string":
elif (
datasets_dtype == "string"
or datasets_dtype == "str"
or datasets_dtype == "large_string"
):
# Polars treats all strings as Utf8
return pld.Utf8
elif datasets_dtype == "object":
@@ -159,4 +170,6 @@ def _datasets_dtype_to_pld(datasets_dtype: str) -> pld.DataType:
# Handling timestamps in Polars is a bit different. You might need additional handling based on your use case.
return pld.Datetime
else:
raise ValueError(f"Datasets dtype {datasets_dtype} does not have a Polars dtype equivalent.")
raise ValueError(
f"Datasets dtype {datasets_dtype} does not have a Polars dtype equivalent."
)
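
The timestamp branch of _datasets_dtype_to_arrow above hands the parsed unit and timezone to pa.timestamp. A small sanity-check sketch with example values (the unit and timezone here are illustrative, not taken from the commit):

import pyarrow as pa

# pa.timestamp(unit, tz) is what the timestamp branch returns; the unit and
# timezone below are example inputs.
ts_type = pa.timestamp("us", tz="UTC")
assert str(ts_type) == "timestamp[us, tz=UTC]"
print(ts_type)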