
Standardise code style
Yibo-Chen13 committed Oct 22, 2024
1 parent e152bc8 commit ee25983
Showing 12 changed files with 143 additions and 73 deletions.
3 changes: 3 additions & 0 deletions .flake8
@@ -5,3 +5,6 @@ per-file-ignores =
proton_driver/bufferedreader.pyx: E225, E226, E227, E999
proton_driver/bufferedwriter.pyx: E225, E226, E227, E999
proton_driver/varint.pyx: E225, E226, E227, E999
# ignore example print warning.
example/*: T201,T001
exclude = venv,.conda,build
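
A note on the new entry: T201 and T001 are typically the "print found" codes reported by a print-checker plugin (flake8-print in its newer and older releases, respectively), so ignoring them under example/* lets the example scripts keep their intentional print calls. A minimal illustration, with a hypothetical file name, of the kind of line those codes would otherwise flag:

# example/demo_print.py -- with a print checker such as flake8-print installed,
# this line would normally be reported as T201/T001 ("print found"); the
# per-file-ignores entry above suppresses that for everything under example/.
print("connected to Proton at 127.0.0.1:8463")
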
38 changes: 24 additions & 14 deletions example/bytewax/hackernews.py
@@ -19,11 +19,13 @@ class HNSource(SimplePollingSource):
def next_item(self):
return (
"GLOBAL_ID",
requests.get("https://hacker-news.firebaseio.com/v0/maxitem.json").json(),
requests.get(
"https://hacker-news.firebaseio.com/v0/maxitem.json"
).json(),
)


def get_id_stream(old_max_id, new_max_id) -> Tuple[str,list]:
def get_id_stream(old_max_id, new_max_id) -> Tuple[str, list]:
if old_max_id is None:
# Get the last 150 items on the first run.
old_max_id = new_max_id - 150
@@ -51,12 +53,7 @@ def recurse_tree(metadata, og_metadata=None) -> any:
parent_metadata = download_metadata(parent_id)
return recurse_tree(parent_metadata[1], og_metadata)
except KeyError:
return (metadata["id"],
{
**og_metadata,
"root_id":metadata["id"]
}
)
return (metadata["id"], {**og_metadata, "root_id": metadata["id"]})


def key_on_parent(key__metadata) -> tuple:
@@ -68,19 +65,32 @@ def format(id__metadata):
id, metadata = id__metadata
return json.dumps(metadata)


flow = Dataflow("hn_scraper")
max_id = op.input("in", flow, HNSource(timedelta(seconds=15)))
id_stream = op.stateful_map("range", max_id, lambda: None, get_id_stream).then(
op.flat_map, "strip_key_flatten", lambda key_ids: key_ids[1]).then(
op.redistribute, "redist")
id_stream = \
op.stateful_map("range", max_id, lambda: None, get_id_stream) \
.then(op.flat_map, "strip_key_flatten", lambda key_ids: key_ids[1]) \
.then(op.redistribute, "redist")

id_stream = op.filter_map("meta_download", id_stream, download_metadata)
split_stream = op.branch("split_comments", id_stream, lambda item: item[1]["type"] == "story")
split_stream = op.branch(
"split_comments", id_stream, lambda item: item[1]["type"] == "story"
)
story_stream = split_stream.trues
story_stream = op.map("format_stories", story_stream, format)
comment_stream = split_stream.falses
comment_stream = op.map("key_on_parent", comment_stream, key_on_parent)
comment_stream = op.map("format_comments", comment_stream, format)
op.inspect("stories", story_stream)
op.inspect("comments", comment_stream)
op.output("stories-out", story_stream, ProtonSink("hn_stories_raw", os.environ.get("PROTON_HOST","127.0.0.1")))
op.output("comments-out", comment_stream, ProtonSink("hn_comments_raw", os.environ.get("PROTON_HOST","127.0.0.1")))
op.output(
"stories-out",
story_stream,
ProtonSink("hn_stories_raw", os.environ.get("PROTON_HOST", "127.0.0.1")),
)
op.output(
"comments-out",
comment_stream,
ProtonSink("hn_comments_raw", os.environ.get("PROTON_HOST", "127.0.0.1")),
)
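
The backslash-continued pipeline above relies on bytewax's Stream.then, where stream.then(op_fn, "step_id", *args) is shorthand for op_fn("step_id", stream, *args). A minimal, self-contained sketch of that equivalence (hypothetical step names and data, assuming bytewax 0.18+ with the testing source available):

import bytewax.operators as op
from bytewax.dataflow import Dataflow
from bytewax.testing import TestingSource

flow = Dataflow("then_demo")
nums = op.input("in", flow, TestingSource([1, 2, 3]))

# Chained form, as used in hackernews.py ...
doubled_chained = nums.then(op.map, "double_chained", lambda x: x * 2)
# ... and the equivalent explicit form.
doubled_explicit = op.map("double_explicit", nums, lambda x: x * 2)

op.inspect("chained", doubled_chained)
op.inspect("explicit", doubled_explicit)
# If saved as then_demo.py, run with: python -m bytewax.run then_demo:flow
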
19 changes: 11 additions & 8 deletions example/bytewax/proton.py
@@ -1,4 +1,5 @@
"""Output to Timeplus Proton."""

from bytewax.outputs import DynamicSink, StatelessSinkPartition
from proton_driver import client
import logging
@@ -9,28 +10,30 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class _ProtonSinkPartition(StatelessSinkPartition):
def __init__(self, stream: str, host: str):
self.client=client.Client(host=host, port=8463)
self.stream=stream
sql=f"CREATE STREAM IF NOT EXISTS `{stream}` (raw string)"
self.client = client.Client(host=host, port=8463)
self.stream = stream
sql = f"CREATE STREAM IF NOT EXISTS `{stream}` (raw string)"
logger.debug(sql)
self.client.execute(sql)

def write_batch(self, items):
logger.debug(f"inserting data {items}")
rows=[]
rows = []
for item in items:
rows.append([item]) # single column in each row
rows.append([item]) # single column in each row
sql = f"INSERT INTO `{self.stream}` (raw) VALUES"
logger.debug(f"inserting data {sql}")
self.client.execute(sql,rows)
self.client.execute(sql, rows)


class ProtonSink(DynamicSink):
def __init__(self, stream: str, host: str):
self.stream = stream
self.host = host if host is not None and host != "" else "127.0.0.1"

"""Write each output item to Proton on that worker.
Items consumed from the dataflow must look like a string. Use a
@@ -45,4 +48,4 @@ def __init__(self, stream: str, host: str):

def build(self, worker_index, worker_count):
"""See ABC docstring."""
return _ProtonSinkPartition(self.stream, self.host)
return _ProtonSinkPartition(self.stream, self.host)
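
Once the sink above has written a few batches, the raw rows can be spot-checked with the same proton_driver client (a hypothetical check, assuming a local Proton at 127.0.0.1:8463 and that the hn_stories_raw stream already exists and has data):

from proton_driver import client

c = client.Client(host="127.0.0.1", port=8463)
# table() reads the historical store rather than tailing the unbounded stream.
rows = c.execute("SELECT raw FROM table(hn_stories_raw) LIMIT 5")
for (raw,) in rows:
    print(raw)
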
58 changes: 40 additions & 18 deletions example/descriptive_pipeline/server/main.py
@@ -1,4 +1,11 @@
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, Request, BackgroundTasks
from fastapi import (
FastAPI,
WebSocket,
HTTPException,
WebSocketDisconnect,
Request,
BackgroundTasks,
)
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
import yaml
@@ -10,6 +17,7 @@
from proton_driver import client

from .utils.logging import getLogger

logger = getLogger()


@@ -58,7 +66,11 @@ def pipeline_exist(self, name):
return False

def delete_pipeline(self, name):
updated_pipelines = [pipeline for pipeline in self.config.pipelines if pipeline.name != name]
updated_pipelines = [
pipeline
for pipeline in self.config.pipelines
if pipeline.name != name
]
self.config.pipelines = updated_pipelines
self.save()

@@ -73,11 +85,13 @@ def save(self):
yaml.dump(self.config, yaml_file)

def run_pipeline(self, name):
proton_client = client.Client(host=self.config.host,
port=self.config.port,
database=self.config.db,
user=self.config.user,
password=self.config.password)
proton_client = client.Client(
host=self.config.host,
port=self.config.port,
database=self.config.db,
user=self.config.user,
password=self.config.password,
)
pipeline = self.get_pipeline_by_name(name)
if pipeline is not None:
for query in pipeline.sqls[:-1]:
@@ -93,7 +107,7 @@ def conf(self):
return self.config


class Query():
class Query:
def __init__(self, sql, client):
self.sql = sql
self.lock = threading.Lock()
@@ -198,7 +212,7 @@ async def query_stream(name, request, background_tasks):
async def check_disconnect():
while True:
await asyncio.sleep(1)
disconnected = await request.is_disconnected();
disconnected = await request.is_disconnected()
if disconnected:
query.cancel()
logger.info('Client disconnected')
@@ -215,28 +229,34 @@ async def check_disconnect():
result = {}
for index, (name, t) in enumerate(header):
if t.startswith('date'):
result[name] = str(m[index]) # convert datetime type to string
# convert datetime type to string
result[name] = str(m[index])
else:
result[name] = m[index]
result_str = json.dumps(result).encode("utf-8") + b"\n"
yield result_str
except Exception as e:
query.cancel()
logger.info(f'query cancelled due to {e}' )
logger.info(f'query cancelled due to {e}')
break

if query.is_finshed():
break

await asyncio.sleep(0.1)


@app.get("/queries/{name}")
def query_pipeline(name: str, request: Request , background_tasks: BackgroundTasks):
def query_pipeline(
name: str, request: Request, background_tasks: BackgroundTasks
):
if not config_manager.pipeline_exist(name):
raise HTTPException(status_code=404, detail="pipeline not found")

return StreamingResponse(query_stream(name, request, background_tasks), media_type="application/json")
return StreamingResponse(
query_stream(name, request, background_tasks),
media_type="application/json",
)


@app.websocket("/queries/{name}")
@@ -258,10 +278,11 @@ async def websocket_endpoint(name: str, websocket: WebSocket):
result = {}
for index, (name, t) in enumerate(header):
if t.startswith('date'):
result[name] = str(m[index]) # convert datetime type to string
# convert datetime type to string
result[name] = str(m[index])
else:
result[name] = m[index]

await websocket.send_text(f'{json.dumps(result)}')
except Exception:
hasError = True
@@ -282,6 +303,7 @@ async def websocket_endpoint(name: str, websocket: WebSocket):
except Exception as e:
logger.exception(e)
finally:
query.cancel() # Ensure query cancellation even if an exception is raised
# Ensure query cancellation even if an exception is raised
query.cancel()
await websocket.close()
logger.debug('session closed')
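
Since the GET endpoint streams newline-delimited JSON, it can be consumed with a plain HTTP client. A hypothetical consumer, assuming the FastAPI app is served locally on port 8000 and a pipeline named "my_pipeline" is configured:

import json
import requests

url = "http://localhost:8000/queries/my_pipeline"  # hypothetical pipeline name
with requests.get(url, stream=True, timeout=30) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if line:  # skip keep-alive blank lines
            print(json.loads(line))
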
27 changes: 17 additions & 10 deletions example/pandas/dataframe.py
@@ -8,24 +8,31 @@

# setup the test stream
c.execute("drop stream if exists test")
c.execute("""create stream test (
c.execute(
"""create stream test (
year int16,
first_name string
)""")
)"""
)
# add some data
df = pd.DataFrame.from_records([
{'year': 1994, 'first_name': 'Vova'},
{'year': 1995, 'first_name': 'Anja'},
{'year': 1996, 'first_name': 'Vasja'},
{'year': 1997, 'first_name': 'Petja'},
])
df = pd.DataFrame.from_records(
[
{'year': 1994, 'first_name': 'Vova'},
{'year': 1995, 'first_name': 'Anja'},
{'year': 1996, 'first_name': 'Vasja'},
{'year': 1997, 'first_name': 'Petja'},
]
)
c.insert_dataframe(
'INSERT INTO "test" (year, first_name) VALUES',
df,
settings=dict(use_numpy=True),
)
# or c.execute("INSERT INTO test(year, first_name) VALUES", df.to_dict('records'))
time.sleep(3) # wait for 3 sec to make sure data available in historical store
# or c.execute(
# "INSERT INTO test(year, first_name) VALUES", df.to_dict('records')
# )
# wait for 3 sec to make sure data available in historical store
time.sleep(3)

df = c.query_dataframe('SELECT * FROM table(test)')
print(df)
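
query_dataframe returns an ordinary pandas DataFrame, so the result can be sliced with normal pandas operations. A small, purely illustrative follow-up (assuming the snippet above has already run against a local Proton at 127.0.0.1:8463):

from proton_driver import client

c = client.Client(host='127.0.0.1', port=8463)
df = c.query_dataframe('SELECT * FROM table(test)')  # same query as above
# From here on it is plain pandas:
print(df[df['year'] >= 1996].sort_values('year'))
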
28 changes: 20 additions & 8 deletions example/streaming_query/car.py
@@ -1,34 +1,46 @@
"""
This example uses driver DB API.
In this example, a thread writes a huge list of data of car speed into database,
and another thread reads from the database to figure out which car is speeding.
In this example, a thread writes a huge list of data of car speed into
database, and another thread reads from the database to figure out which
car is speeding.
"""

import datetime
import random
import threading
import time

from proton_driver import connect

account='default:'
account = 'default:'


def create_stream():
with connect(f"proton://{account}@localhost:8463/default") as conn:
with conn.cursor() as cursor:
cursor.execute("drop stream if exists cars")
cursor.execute("create stream if not exists car(id int64, speed float64)")
cursor.execute(
"create stream if not exists car(id int64, speed float64)"
)


def write_data(car_num: int):
car_begin_date = datetime.datetime(2022, 1, 1, 1, 0, 0)
for day in range(100):
car_begin_date += datetime.timedelta(days=1)
data = [(random.randint(0, car_num - 1), random.random() * 20 + 50,
car_begin_date
+ datetime.timedelta(milliseconds=i * 100)) for i in range(300000)]
data = [
(
random.randint(0, car_num - 1),
random.random() * 20 + 50,
car_begin_date + datetime.timedelta(milliseconds=i * 100),
)
for i in range(300000)
]
with connect(f"proton://{account}@localhost:8463/default") as conn:
with conn.cursor() as cursor:
cursor.executemany("insert into car (id, speed, _tp_time) values", data)
cursor.executemany(
"insert into car (id, speed, _tp_time) values", data
)
print(f"row count: {cursor.rowcount}")
time.sleep(10)

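
The hunks shown above cover only the writer; the reading side described in the docstring can be sketched with the same DB API. A simplified, hypothetical version that checks the historical store (so the query terminates) instead of running an unbounded streaming query, assuming the car stream created by create_stream() already has data:

from proton_driver import connect

def find_speeding(threshold: float = 65.0):
    with connect("proton://default:@localhost:8463/default") as conn:
        with conn.cursor() as cursor:
            # table(car) queries the historical store, so fetchall() returns.
            cursor.execute(
                f"select id, speed from table(car) where speed > {threshold} limit 10"
            )
            for row in cursor.fetchall():
                print(row)

find_speeding()
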
