Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSocket Authentication on Client and Server #255

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions software/source/clients/base_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def record_audio(self):
global RECORDING

# Create a temporary WAV file to store the audio data
temp_dir = tempfile.gettempdir()
temp_dir = tempfile.mkdtemp(prefix="audio_")
wav_path = os.path.join(
temp_dir, f"audio_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav"
)
Expand Down Expand Up @@ -237,9 +237,8 @@ def record_audio(self):
"end": True,
}
)

if os.path.exists(wav_path):
os.remove(wav_path)
# Remove the temporary directory and its contents
shutil.rmtree(temp_dir)

def toggle_recording(self, state):
"""Toggle the recording state."""
Expand Down Expand Up @@ -286,11 +285,49 @@ async def message_sender(self, websocket):
await websocket.send(json.dumps(message))
send_queue.task_done()
await asyncio.sleep(0.01)


async def authenticate(self, websocket):
    """Complete the token handshake with the server over ``websocket``.

    Waits for a JSON message whose ``"type"`` is ``"auth_request"``, replies
    with ``{"token": <WS_TOKEN>}`` and then reads the server's verdict.
    Messages received before the auth request are logged and ignored.

    Args:
        websocket: An open client connection exposing awaitable ``recv``,
            ``send`` and ``close`` methods.

    Returns:
        True if the server answered with ``"auth_success"``. False if the
        ``WS_TOKEN`` environment variable is missing/empty or the server
        answered with anything else; the connection is closed in both
        failure cases.

    Raises:
        json.JSONDecodeError: If a received message is not valid JSON.
    """
    while True:
        auth_data = json.loads(await websocket.recv())

        # A payload that is not a JSON object, or that lacks "type", is just
        # another unexpected message; it must not crash the handshake.
        message_type = auth_data.get("type") if isinstance(auth_data, dict) else None
        if message_type != "auth_request":
            logger.warning(f"Unexpected message from the server: {auth_data}")
            continue

        token = os.getenv("WS_TOKEN")
        if not token:
            logger.error("WS_TOKEN not found in environment variables.")
            await websocket.close()
            return False

        await websocket.send(json.dumps({"token": token}))

        # Anything other than an explicit "auth_success" counts as a failure,
        # including a malformed result that has no "type" field.
        result_data = json.loads(await websocket.recv())
        result_type = result_data.get("type") if isinstance(result_data, dict) else None
        if result_type == "auth_success":
            return True

        logger.error("Authentication failed. Closing the connection.")
        await websocket.close()
        return False

async def websocket_communication(self, WS_URL):
show_connection_log = True

async def exec_ws_communication(websocket):

# Perform authentication
if not await self.authenticate(websocket):
return # Authentication failed: stop here without starting the rest of the communication

if CAMERA_ENABLED:
print(
"\nHold the spacebar to start recording. Press 'c' to capture an image from the camera. Press CTRL-C to exit."
Expand Down Expand Up @@ -348,8 +385,8 @@ async def exec_ws_communication(websocket):
# Workaround for Windows 10 not latching to the websocket server.
# See https://github.com/OpenInterpreter/01/issues/197
try:
ws = websockets.connect(WS_URL)
await exec_ws_communication(ws)
async with websockets.connect(WS_URL) as websocket:
await exec_ws_communication(websocket)
except Exception as e:
logger.error(f"Error while attempting to connect: {e}")
else:
Expand Down
27 changes: 27 additions & 0 deletions software/source/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
# Module-level Accumulator instance (the class is defined or imported outside this view).
accumulator = Accumulator()

# FastAPI application object; route decorators such as @app.websocket register handlers on it.
app = FastAPI()
# OAuth2 password-bearer scheme configured with the relative token URL "token".
# NOTE(review): nothing in the visible changes references oauth2_scheme; the
# WebSocket handshake checks the WS_TOKEN environment variable instead. Confirm
# this object is actually needed.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Presumably the per-user data directory for the "01" app (verify user_data_dir's
# provider); the conversation history file is <app_dir>/conversations/user.json.
app_dir = user_data_dir("01")
conversation_history_path = os.path.join(app_dir, "conversations", "user.json")
Expand Down Expand Up @@ -134,10 +135,36 @@ def terminate(self):
async def ping():
    """Respond with the plain-text body ``pong``."""
    pong = PlainTextResponse("pong")
    return pong

async def authenticate(websocket: WebSocket):
    """Run the server side of the token handshake on an accepted WebSocket.

    Sends ``{"type": "auth_request"}``, reads the client's JSON reply and
    compares its ``"token"`` field with the ``WS_TOKEN`` environment variable.

    The check fails closed: if ``WS_TOKEN`` is unset or empty on the server,
    every client is rejected. Previously a client that sent no token matched
    the unset (``None``) expected value and was let in.

    Args:
        websocket: The already-accepted connection to authenticate.

    Returns:
        True after sending ``{"type": "auth_success"}``. False if the client
        disconnected before answering, or after sending
        ``{"type": "auth_failure"}`` and closing the socket because the reply
        was malformed or the token did not match.
    """
    import hmac  # local import: only needed for the constant-time comparison

    await websocket.send_json({"type": "auth_request"})

    try:
        auth_response = await websocket.receive_json()
    except WebSocketDisconnect:
        return False
    except ValueError:
        # Malformed JSON (json.JSONDecodeError is a ValueError): treat it as a
        # failed attempt rather than letting the endpoint crash.
        auth_response = None

    # Only a JSON object can carry a token; lists, strings, etc. are rejected.
    token = auth_response.get("token") if isinstance(auth_response, dict) else None
    expected_token = os.getenv("WS_TOKEN")

    authorized = (
        bool(expected_token)
        and isinstance(token, str)
        # Constant-time comparison so response timing does not leak how much
        # of the token matched. Compare bytes because compare_digest rejects
        # non-ASCII str arguments.
        and hmac.compare_digest(
            token.encode("utf-8"), expected_token.encode("utf-8")
        )
    )
    if not authorized:
        await websocket.send_json({"type": "auth_failure"})
        await websocket.close()
        return False

    await websocket.send_json({"type": "auth_success"})
    return True

@app.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()

# Perform authentication
if not await authenticate(websocket):
return

receive_task = asyncio.create_task(receive_messages(websocket))
send_task = asyncio.create_task(send_messages(websocket))
try:
Expand Down