-
Notifications
You must be signed in to change notification settings - Fork 1
/
client.py
155 lines (131 loc) · 4.9 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from __future__ import annotations

import asyncio
import io
import json
import struct
import time
import uuid
from enum import Enum
from typing import NamedTuple, Optional, Sequence
from urllib import request

import numpy as np
from PIL import Image
from websockets import client as websockets_client
from websockets import exceptions as websockets_exceptions

from eventloop import AsyncApp
from util import client_logger as log
class ClientEvent(Enum):
    """Kinds of event a ``Client`` listener can report."""

    # Members are numbered positionally from 0 in the order listed here;
    # keep the order stable, since the integer values depend on it.
    (
        progress,
        finished,
        interrupted,
        error,
        connected,
        disconnected,
    ) = range(6)
class ClientMessage(NamedTuple):
    """One event reported while listening to the server for a prompt.

    Fields:
        event: which ``ClientEvent`` occurred.
        prompt_id: prompt the event belongs to; defaults to ``""`` for
            events not tied to a prompt.
        progress: progress value for the event (``1`` is used for
            finished events).
        images: images received for the prompt; empty by default.
        result: optional payload attached to the event, if any.
        error: error text for error events, otherwise ``None``.
    """

    event: ClientEvent
    prompt_id: Optional[str] = ""
    progress: float = 0
    # Immutable empty default: a ``[]`` default would be a single list object
    # shared by every message created without images, so mutating it through
    # one message would silently change all the others.
    images: Sequence[Image.Image] = ()
    result: Optional[dict] = None
    error: Optional[str] = None
class Client:
    """Client for a ComfyUI server.

    Workflows are submitted over HTTP with :meth:`enqueue`; results are then
    collected from the server's websocket with :meth:`polling`.
    """

    def __init__(self, ip="http://127.0.0.1", port="8188"):
        # Base URL for HTTP calls; the websocket URL is derived from it by
        # rewriting the leading "http" to "ws".
        self.url = f"{ip}:{port}"
        # Random client id, sent with /prompt and on the websocket query
        # string (presumably so the server can route this client's events).
        self._id = str(uuid.uuid4())
        self._connected = False
        self._async_app = AsyncApp()

    def enqueue(self, workflow):
        """Submit *workflow* to ``POST /prompt`` and return the prompt id.

        Waits for the server to become reachable first if no successful
        health check has been recorded yet.
        """
        if not self._connected:
            self._wait_connection()
        payload = json.dumps({"prompt": workflow, "client_id": self._id}).encode("utf-8")
        req = request.Request(
            f"{self.url}/prompt",
            payload,
            headers={"Content-Type": "application/json"},
        )
        # Use the response as a context manager so the connection is closed.
        with request.urlopen(req) as response:
            reply = json.loads(response.read())
        return reply["prompt_id"]

    def _wait_connection(self, time_out=60):
        """Block until :meth:`health_check` succeeds, retrying once per second.

        Raises:
            TimeoutError: if the server is still unreachable after roughly
                ``time_out`` seconds. ``TimeoutError`` subclasses
                ``Exception``, so existing ``except Exception`` handlers
                still catch it.
        """
        deadline = time.monotonic() + time_out
        while not self.health_check():
            if time.monotonic() > deadline:
                self._connected = False
                raise TimeoutError("Connection timeout")
            print("waiting for connection with ComfyUI...")
            time.sleep(1)
        self._connected = True

    def health_check(self, timeout=10):
        """Return True if ``GET /system_stats`` answers with HTTP 200.

        Any network or HTTP error is reported as ``False``. ``timeout``
        bounds each probe so an unresponsive server cannot block forever.
        """
        req = request.Request(f"{self.url}/system_stats")
        try:
            with request.urlopen(req, timeout=timeout) as response:
                return response.status == 200
        except Exception:
            return False

    async def _listen(self, prompt_id):
        """Yield ``ClientMessage`` events from the server websocket.

        ``websockets_client.connect`` is used as an async iterator, which
        reconnects after a connection fails. Per-connection errors are
        logged and listening resumes on the next connection.
        """
        url = self.url.replace("http", "ws", 1)
        async for websocket in websockets_client.connect(
            f"{url}/ws?clientId={self._id}",
            max_size=2**30,
            read_limit=2**30,
            ping_timeout=60,
        ):
            try:
                async for msg in self._listen_main(websocket, prompt_id):
                    yield msg
            except websockets_exceptions.ConnectionClosedError as e:
                log.warning(f"Websocket connection closed: {str(e)}")
            except OSError as e:
                # This message used to be built and then dropped; log it.
                log.warning("Could not connect to websocket server " + f"at {url}: {str(e)}")
            except (asyncio.CancelledError, GeneratorExit):
                # Close the socket now when the task is cancelled or the
                # consumer stops iterating, then let the signal propagate:
                # swallowing CancelledError would break asyncio cancellation.
                await websocket.close()
                raise
            except Exception as e:
                log.exception(f"Unhandled exception in websocket listener, {e}")

    async def _listen_main(
        self,
        websocket: websockets_client.WebSocketClientProtocol,
        prompt_id: str,
    ):
        """Translate frames from one websocket connection into events.

        Binary frames are decoded to images and accumulated. JSON text frames
        produce ``connected``, ``finished`` and ``error`` events; only
        completion and error frames whose ``prompt_id`` matches are reported.
        """
        images = []
        result = None
        async for raw in websocket:
            if isinstance(raw, bytes):
                image = _extract_message_png_image(memoryview(raw))
                if image is not None:
                    images.append(image)
                continue
            if not isinstance(raw, str):
                continue
            message = json.loads(raw)
            msg_type = message["type"]
            data = message.get("data", {})
            if msg_type == "status":
                yield ClientMessage(ClientEvent.connected)
            elif (
                msg_type == "executing"
                and data["node"] is None
                and data["prompt_id"] == prompt_id
            ):
                yield ClientMessage(ClientEvent.finished, prompt_id, 1, images, result)
            elif msg_type == "executed" and data["prompt_id"] == prompt_id:
                # Keep the node output; previously ``result`` was never set.
                result = data.get("output")
                yield ClientMessage(ClientEvent.finished, prompt_id, 1, images, result)
            elif msg_type == "execution_error" and data.get("prompt_id") == prompt_id:
                # NOTE(review): assumes the server's execution_error payload
                # carries "prompt_id" and "exception_message"; confirm against
                # the ComfyUI version in use.
                yield ClientMessage(
                    ClientEvent.error,
                    prompt_id,
                    error=str(data.get("exception_message", "execution_error")),
                )

    def polling(self, prompt_id):
        """Wait for *prompt_id* to finish and return its first image as an array."""
        return self._async_app.run(self._polling(prompt_id))

    async def _polling(self, prompt_id):
        results = await self._receive_images(prompt_id)
        return np.array(results[0])

    async def _receive_images(self, prompt_id):
        """Return the images reported for *prompt_id*.

        Raises:
            Exception: with the server's error text if an error event arrives
                for this prompt.
            AssertionError: if the event stream ends without a finished event.
        """
        listener = self._listen(prompt_id)
        try:
            async for msg in listener:
                if msg.prompt_id != prompt_id:
                    continue
                if msg.event is ClientEvent.finished:
                    assert msg.images is not None
                    return msg.images
                if msg.event is ClientEvent.error:
                    raise Exception(msg.error)
        finally:
            # ``async for`` does not close an async generator on early exit;
            # close it explicitly so the websocket is released immediately.
            await listener.aclose()
        # Explicit raise rather than ``assert False`` so it still fires under -O.
        raise AssertionError("Connection closed without receiving images")
def _extract_message_png_image(data: memoryview) -> Optional[Image.Image]:
s = struct.calcsize(">II")
if len(data) > s:
byte_data = data[s:].tobytes()
byte_stream = io.BytesIO(byte_data)
image = Image.open(byte_stream)
return image
return None