Skip to content

Commit

Permalink
Change from on display data to on result
Browse files Browse the repository at this point in the history
  • Loading branch information
jakubno committed Apr 3, 2024
1 parent bc0bec9 commit 974b1b0
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 49 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ print("world")
"""

with CodeInterpreter() as sandbox:
sandbox.notebook.exec_cell(code, on_stdout=print, on_stderr=print, on_display_data=(lambda data: print(data.text)))
sandbox.notebook.exec_cell(code, on_stdout=print, on_stderr=print, on_result=(lambda result: print(result.text)))

```

Expand All @@ -175,7 +175,7 @@ const sandbox = await CodeInterpreter.create()
await sandbox.notebook.execCell(code, {
onStdout: (out) => console.log(out),
onStderr: (outErr) => console.error(outErr),
onDisplayData: (outData) => console.log(outData.text)
onResult: (result) => console.log(result.text)
})

await sandbox.close()
Expand Down
2 changes: 1 addition & 1 deletion js/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ const sandbox = await CodeInterpreter.create()
await sandbox.notebook.execCell(code, {
onStdout: (out) => console.log(out),
onStderr: (outErr) => console.error(outErr),
onDisplayData: (outData) => console.log(outData.text)
onResult: (result) => console.log(result.text)
})

await sandbox.close()
Expand Down
12 changes: 6 additions & 6 deletions js/src/code-interpreter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export class JupyterExtension {
* @param kernelID The ID of the kernel to execute the code on. If not provided, the default kernel is used.
* @param onStdout A callback function to handle standard output messages from the code execution.
* @param onStderr A callback function to handle standard error messages from the code execution.
* @param onDisplayData A callback function to handle display data messages from the code execution.
* @param onResult A callback function to handle display data messages from the code execution.
* @param timeout The maximum time to wait for the code execution to complete, in milliseconds.
* @returns A promise that resolves with the result of the code execution.
*/
Expand All @@ -73,13 +73,13 @@ export class JupyterExtension {
kernelID,
onStdout,
onStderr,
onDisplayData,
onResult,
timeout
}: {
kernelID?: string
onStdout?: (msg: ProcessMessage) => Promise<void> | void
onStderr?: (msg: ProcessMessage) => Promise<void> | void
onDisplayData?: (data: Result) => Promise<void> | void
onStdout?: (msg: ProcessMessage) => any
onStderr?: (msg: ProcessMessage) => any
onResult?: (data: Result) => any
timeout?: number
} = {}
): Promise<Execution> {
Expand All @@ -92,7 +92,7 @@ export class JupyterExtension {
code,
onStdout,
onStderr,
onDisplayData,
onResult,
timeout
)
}
Expand Down
34 changes: 19 additions & 15 deletions js/src/messaging.ts
Original file line number Diff line number Diff line change
Expand Up @@ -186,20 +186,20 @@ export class Execution {
*/
class CellExecution {
execution: Execution
onStdout?: (out: ProcessMessage) => Promise<void> | void
onStderr?: (out: ProcessMessage) => Promise<void> | void
onDisplayData?: (data: Result) => Promise<void> | void
onStdout?: (out: ProcessMessage) => any
onStderr?: (out: ProcessMessage) => any
onResult?: (data: Result) => any
inputAccepted: boolean = false

constructor(
onStdout?: (out: ProcessMessage) => Promise<void> | void,
onStderr?: (out: ProcessMessage) => Promise<void> | void,
onDisplayData?: (data: Result) => Promise<void> | void
onStdout?: (out: ProcessMessage) => any,
onStderr?: (out: ProcessMessage) => any,
onResult?: (data: Result) => any
) {
this.execution = new Execution([], { stdout: [], stderr: [] })
this.onStdout = onStdout
this.onStderr = onStderr
this.onDisplayData = onDisplayData
this.onResult = onResult
}
}

Expand Down Expand Up @@ -295,11 +295,15 @@ export class JupyterKernelWebSocket {
} else if (message.msg_type == 'display_data') {
const result = new Result(message.content.data, false)
execution.results.push(result)
if (cell.onDisplayData) {
cell.onDisplayData(result)
if (cell.onResult) {
cell.onResult(result)
}
} else if (message.msg_type == 'execute_result') {
execution.results.push(new Result(message.content.data, true))
const result = new Result(message.content.data, true)
execution.results.push(result)
if (cell.onResult) {
cell.onResult(result)
}
} else if (message.msg_type == 'status') {
if (message.content.execution_state == 'idle') {
if (cell.inputAccepted) {
Expand Down Expand Up @@ -337,15 +341,15 @@ export class JupyterKernelWebSocket {
* @param code Code to be executed.
* @param onStdout Callback for stdout messages.
* @param onStderr Callback for stderr messages.
* @param onDisplayData Callback for display data messages.
* @param onResult Callback function to handle the result and display calls of the code execution.
* @param timeout Time in milliseconds to wait for response.
* @returns Promise with execution result.
*/
public sendExecutionMessage(
code: string,
onStdout?: (out: ProcessMessage) => Promise<void> | void,
onStderr?: (out: ProcessMessage) => Promise<void> | void,
onDisplayData?: (data: Result) => Promise<void> | void,
onStdout?: (out: ProcessMessage) => any,
onStderr?: (out: ProcessMessage) => any,
onResult?: (data: Result) => any,
timeout?: number
) {
return new Promise<Execution>((resolve, reject) => {
Expand All @@ -367,7 +371,7 @@ export class JupyterKernelWebSocket {
}

// expect response
this.cells[msg_id] = new CellExecution(onStdout, onStderr, onDisplayData)
this.cells[msg_id] = new CellExecution(onStdout, onStderr, onResult)
this.idAwaiter[msg_id] = (responseData: Execution) => {
// stop timeout
clearInterval(timeoutSet as number)
Expand Down
51 changes: 51 additions & 0 deletions js/tests/streaming.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import { ProcessMessage } from 'e2b'
import { CodeInterpreter, Result } from '../src'

import { expect, test } from 'vitest'

test('streaming output', async () => {
const out: ProcessMessage[] = []
const sandbox = await CodeInterpreter.create()
await sandbox.notebook.execCell('print(1)', {
onStdout: (msg) => out.push(msg)
})

expect(out.length).toEqual(1)
expect(out[0].line).toEqual('1\n')
await sandbox.close()
})

test('streaming error', async () => {
const out: ProcessMessage[] = []
const sandbox = await CodeInterpreter.create()
await sandbox.notebook.execCell('import sys;print(1, file=sys.stderr)', {
onStderr: (msg) => out.push(msg)
})

expect(out.length).toEqual(1)
expect(out[0].line).toEqual('1\n')
await sandbox.close()
})

test('streaming result', async () => {
const out: Result[] = []
const sandbox = await CodeInterpreter.create()
const code = `
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 20, 100)
y = np.sin(x)
plt.plot(x, y)
plt.show()
x
`
await sandbox.notebook.execCell(code, {
onResult: (result) => out.push(result)
})

expect(out.length).toEqual(2)
await sandbox.close()
})
2 changes: 1 addition & 1 deletion python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ print("world")
"""

with CodeInterpreter() as sandbox:
sandbox.notebook.exec_cell(code, on_stdout=print, on_stderr=print, on_display_data=(lambda data: print(data.text)))
sandbox.notebook.exec_cell(code, on_stdout=print, on_stderr=print, on_result=(lambda result: print(result.text)))
```

### Pre-installed Python packages inside the sandbox
Expand Down
11 changes: 4 additions & 7 deletions python/e2b_code_interpreter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from e2b.constants import TIMEOUT

from e2b_code_interpreter.messaging import JupyterKernelWebSocket
from e2b_code_interpreter.models import KernelException, Execution

from e2b_code_interpreter.models import KernelException, Execution, Result

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -66,7 +65,7 @@ def exec_cell(
kernel_id: Optional[str] = None,
on_stdout: Optional[Callable[[ProcessMessage], Any]] = None,
on_stderr: Optional[Callable[[ProcessMessage], Any]] = None,
on_display_data: Optional[Callable[[Dict[str, Any]], Any]] = None,
on_result: Optional[Callable[[Result], Any]] = None,
timeout: Optional[float] = TIMEOUT,
) -> Execution:
"""
Expand All @@ -76,7 +75,7 @@ def exec_cell(
:param kernel_id: The ID of the kernel to execute the code on. If not provided, the default kernel is used.
:param on_stdout: A callback function to handle standard output messages from the code execution.
:param on_stderr: A callback function to handle standard error messages from the code execution.
:param on_display_data: A callback function to handle display data messages from the code execution.
:param on_result: A callback function to handle the result and display calls of the code execution.
:param timeout: Timeout for the call
:return: Result of the execution
Expand All @@ -93,9 +92,7 @@ def exec_cell(
logger.debug(f"Creating new websocket connection to kernel {kernel_id}")
ws = self._connect_to_kernel_ws(kernel_id, timeout=timeout)

session_id = ws.send_execution_message(
code, on_stdout, on_stderr, on_display_data
)
session_id = ws.send_execution_message(code, on_stdout, on_stderr, on_result)
logger.debug(
f"Sent execution message to kernel {kernel_id}, session_id: {session_id}"
)
Expand Down
35 changes: 18 additions & 17 deletions python/e2b_code_interpreter/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from e2b.utils.future import DeferredFuture
from pydantic import ConfigDict, PrivateAttr, BaseModel

from e2b_code_interpreter.models import Execution, Result, Error, MIMEType
from e2b_code_interpreter.models import Execution, Result, Error

logger = logging.getLogger(__name__)

Expand All @@ -26,21 +26,21 @@ class CellExecution:
"""

input_accepted: bool = False
on_stdout: Optional[Callable[[ProcessMessage], None]] = None
on_stderr: Optional[Callable[[ProcessMessage], None]] = None
on_display_data: Optional[Callable[[Dict[MIMEType, str]], None]] = None
on_stdout: Optional[Callable[[ProcessMessage], Any]] = None
on_stderr: Optional[Callable[[ProcessMessage], Any]] = None
on_result: Optional[Callable[[Result], Any]] = None

def __init__(
self,
on_stdout: Optional[Callable[[ProcessMessage], None]] = None,
on_stderr: Optional[Callable[[ProcessMessage], None]] = None,
on_display_data: Optional[Callable[[Dict[MIMEType, str]], None]] = None,
on_stdout: Optional[Callable[[ProcessMessage], Any]] = None,
on_stderr: Optional[Callable[[ProcessMessage], Any]] = None,
on_result: Optional[Callable[[Result], Any]] = None,
):
self.partial_result = Execution()
self.execution = Future()
self.on_stdout = on_stdout
self.on_stderr = on_stderr
self.on_display_data = on_display_data
self.on_result = on_result


class JupyterKernelWebSocket(BaseModel):
Expand Down Expand Up @@ -129,17 +129,17 @@ def _get_execute_request(msg_id: str, code: str) -> str:
def send_execution_message(
self,
code: str,
on_stdout: Optional[Callable[[ProcessMessage], None]] = None,
on_stderr: Optional[Callable[[ProcessMessage], None]] = None,
on_display_data: Optional[Callable[[Dict[MIMEType, str]], None]] = None,
on_stdout: Optional[Callable[[ProcessMessage], Any]] = None,
on_stderr: Optional[Callable[[ProcessMessage], Any]] = None,
on_result: Optional[Callable[[Result], Any]] = None,
) -> str:
message_id = str(uuid.uuid4())
logger.debug(f"Sending execution message: {message_id}")

self._cells[message_id] = CellExecution(
on_stdout=on_stdout,
on_stderr=on_stderr,
on_display_data=on_display_data,
on_result=on_result,
)
request = self._get_execute_request(message_id, code)
self._queue_in.put(request)
Expand Down Expand Up @@ -204,12 +204,13 @@ def _receive_message(self, data: dict):
elif data["msg_type"] in "display_data":
result = Result(is_main_result=False, data=data["content"]["data"])
execution.results.append(result)
if cell.on_display_data:
cell.on_display_data(result)
if cell.on_result:
cell.on_result(result)
elif data["msg_type"] == "execute_result":
execution.results.append(
Result(is_main_result=True, data=data["content"]["data"])
)
result = Result(is_main_result=True, data=data["content"]["data"])
execution.results.append(result)
if cell.on_result:
cell.on_result(result)
elif data["msg_type"] == "status":
if data["content"]["execution_state"] == "idle":
if cell.input_accepted:
Expand Down
44 changes: 44 additions & 0 deletions python/tests/test_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from e2b_code_interpreter.main import CodeInterpreter


def test_streaming_output():
out = []
with CodeInterpreter() as sandbox:
def test(line) -> int:
out.append(line)
return 1
sandbox.notebook.exec_cell("print(1)", on_stdout=test)

assert len(out) == 1
assert out[0].line == "1\n"


def test_streaming_error():
out = []
with CodeInterpreter() as sandbox:
sandbox.notebook.exec_cell("import sys;print(1, file=sys.stderr)", on_stderr=out.append)

assert len(out) == 1
assert out[0].line == "1\n"


def test_streaming_result():
code = """
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 20, 100)
y = np.sin(x)
plt.plot(x, y)
plt.show()
x
"""

out = []
with CodeInterpreter() as sandbox:
sandbox.notebook.exec_cell(code, on_result=out.append)

assert len(out) == 2

0 comments on commit 974b1b0

Please sign in to comment.