Skip to content

Commit

Permalink
release 0.0.64 (#56)
Browse files Browse the repository at this point in the history
* optimize import and cleanup_modules

* fix exc not showing up in run

---------

Co-authored-by: nggit <[email protected]>
  • Loading branch information
nggit and nggit authored Dec 2, 2024
1 parent a41e330 commit 2129c64
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 54 deletions.
10 changes: 9 additions & 1 deletion examples/exc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@

from httpout import run


print('Hi')

raise ValueError

async def main():
raise ValueError


run(main())
2 changes: 1 addition & 1 deletion httpout/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2024 nggit

__version__ = '0.0.63'
__version__ = '0.0.64'
__all__ = ('HTTPOut',)

from .httpout import HTTPOut # noqa: E402
77 changes: 26 additions & 51 deletions httpout/httpout.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import sys

from traceback import TracebackException
from types import ModuleType

from awaiter import MultiThreadExecutor
Expand Down Expand Up @@ -54,10 +53,10 @@ def wait(coro, timeout=None):
return asyncio.run_coroutine_threadsafe(coro, loop).result(timeout)

def load_module(name, globals, level=0):
if globals['__name__'] == '__globals__':
modules = worker['modules']
else:
if '__server__' in globals:
modules = globals['__main__'].__server__['modules']
else:
modules = worker['modules']

if name in modules:
# already imported
Expand All @@ -68,7 +67,7 @@ def load_module(name, globals, level=0):
if module:
logger.info('%s: importing %s', globals['__name__'], name)

if globals['__name__'] != '__globals__':
if '__server__' in globals:
module.__main__ = globals['__main__']
module.__server__ = globals['__main__'].__server__
module.print = globals['__main__'].print
Expand Down Expand Up @@ -114,12 +113,12 @@ def ho_import(name, globals=None, locals=None, fromlist=(), level=0):
return module

if name == 'httpout' or name.startswith('httpout.'):
if globals['__name__'] == '__globals__':
module = worker['modules'][globals['__name__']]
else:
if '__server__' in globals:
module = globals['__main__'].__server__['modules'][
globals['__name__']
]
else:
module = worker['modules'][globals['__name__']]

# handles virtual imports,
# e.g. from httpout import request, response
Expand All @@ -128,14 +127,14 @@ def ho_import(name, globals=None, locals=None, fromlist=(), level=0):
if child in module.__dict__:
continue

if (globals['__name__'] != '__globals__' and
if ('__server__' in globals and
child in module.__server__):
module.__dict__[child] = module.__server__[
child
]
elif child in worker and (
child != 'app' or
globals['__name__'] == '__globals__'):
'__server__' not in globals):
module.__dict__[child] = worker[child]
else:
raise ImportError(
Expand Down Expand Up @@ -223,6 +222,21 @@ async def _on_request(self, **server):

server['request'] = HTTPRequest(request, server)
server['response'] = HTTPResponse(response)

if (request.protocol.options['ws'] and
b'upgrade' in request.headers and
b'connection' in request.headers and
b'sec-websocket-key' in request.headers and
request.headers[b'upgrade'].lower() == b'websocket'):
server['websocket'] = WebSocket(request, response)
else:
server['websocket'] = None

excludes = (server['response'].print,
server['response'].run_coroutine,
g.wait,
*server.values())

server['REQUEST_METHOD'] = request.method.decode('latin-1')
server['SCRIPT_NAME'] = module_path[len(document_root):].replace(
os.sep, '/'
Expand All @@ -235,15 +249,6 @@ async def _on_request(self, **server):
server['REQUEST_SCHEME'] = request.scheme.decode('latin-1')
server['DOCUMENT_ROOT'] = document_root

if (request.protocol.options['ws'] and
b'upgrade' in request.headers and
b'connection' in request.headers and
b'sec-websocket-key' in request.headers and
request.headers[b'upgrade'].lower() == b'websocket'):
server['websocket'] = WebSocket(request, response)
else:
server['websocket'] = None

module = ModuleType('__main__')
module.__file__ = module_path
module.__main__ = module
Expand Down Expand Up @@ -271,40 +276,10 @@ async def _on_request(self, **server):
ctx.module_path = module_path
except BaseException as exc:
await server['response'].join()

if not response.headers_sent():
response.set_status(500, b'Internal Server Error')
response.set_content_type(b'text/html; charset=utf-8')
request.http_keepalive = False

if isinstance(exc, Exception):
if request.protocol.options['debug']:
te = TracebackException.from_exception(exc)
await response.write(
b'<ul><li>%s</li></ul>\n' % b'</li><li>'.join(
html_escape(line)
.encode() for line in te.format()
)
)
else:
await response.write(
f'<ul><li>{exc.__class__.__name__}: '
f'{html_escape(str(exc))}</li></ul>\n'
.encode()
)
elif isinstance(exc, SystemExit):
if exc.code:
await response.write(str(exc.code).encode())
else:
request.protocol.print_exception(exc)
await server['response'].handle_exception(exc)
finally:
await g.executor.submit(
cleanup_modules, server['modules'], (module.print,
module.run,
module.wait,
g,
ctx,
server['response'])
cleanup_modules, server['modules'], excludes
)
await server['response'].join()
server['modules'].clear()
Expand Down
34 changes: 33 additions & 1 deletion httpout/lib/http_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import asyncio
import concurrent.futures

from traceback import TracebackException
from tremolo.utils import html_escape


class HTTPResponse:
def __init__(self, response):
Expand All @@ -24,6 +27,33 @@ async def join(self):
while self.tasks:
await self.tasks.pop()

async def handle_exception(self, exc):
if not self.response.headers_sent():
self.response.set_status(500, b'Internal Server Error')
self.response.set_content_type(b'text/html; charset=utf-8')
self.response.request.http_keepalive = False

if isinstance(exc, Exception):
if self.response.request.protocol.options['debug']:
te = TracebackException.from_exception(exc)
await self.response.write(
b'<ul><li>%s</li></ul>\n' % b'</li><li>'.join(
html_escape(line)
.encode() for line in te.format()
)
)
else:
await self.response.write(
f'<ul><li>{exc.__class__.__name__}: '
f'{html_escape(str(exc))}</li></ul>\n'
.encode()
)
elif isinstance(exc, SystemExit):
if exc.code:
await self.response.write(str(exc.code).encode())
else:
self.response.request.protocol.print_exception(exc)

def run_coroutine(self, coro):
fut = concurrent.futures.Future()

Expand All @@ -35,7 +65,9 @@ async def callback():
fut.set_result(result)
except BaseException as exc:
if not fut.done():
fut.set_exception(exc)
fut.set_result(None)

await self.handle_exception(exc)

self.loop.call_soon_threadsafe(self.create_task, callback())
return fut
Expand Down

0 comments on commit 2129c64

Please sign in to comment.