Axum integration #36

Open
SirCipher opened this issue Jul 25, 2024 · 10 comments

Comments

@SirCipher
Member

No description provided.

@nakedible-p

FWIW, we implemented a custom axum integration, since we need manual control over headers during negotiation. It was quite straightforward. Deflate extension negotiation was a bit annoying at the lower level, but all in all not a problem.

@SirCipher
Member Author

@nakedible-p, thank you for the integration. I have an open ticket, #16, for supporting user-defined headers, which still needs implementing. Regarding your last point, what was annoying about the deflate extension negotiation? Was it the API itself, or integrating it into your Axum integration?

@nakedible-p

Well, I don't know whether I'm using the API wrong, but getting this right took a bit of trial and error:

            let deflate_config = ratchet_rs::deflate::DeflateConfig::default(); // XXX: tweak config
            let deflate_ext_provider = ratchet_rs::deflate::DeflateExtProvider::with_config(deflate_config);
            let sec_websocket_extensions: Vec<ratchet_rs::Header> = if ENV_CONFIG.disable_ws_compression {
                // XXX: we disable all extensions
                vec![]
            } else {
                parts
                    .headers
                    .get_all(axum::http::header::SEC_WEBSOCKET_EXTENSIONS)
                    .into_iter()
                    .map(|v| ratchet_rs::Header {
                        name: axum::http::header::SEC_WEBSOCKET_EXTENSIONS.as_str(),
                        value: v.as_bytes(),
                    })
                    .collect()
            };
            let (extension, sec_websocket_extensions) = match deflate_ext_provider.negotiate_server(sec_websocket_extensions.as_slice()) {
                Ok(Some((extension, sec_websocket_extensions))) => (
                    ratchet_rs::NegotiatedExtension::from(Some(extension)),
                    Some(sec_websocket_extensions),
                ),
                Ok(None) => (ratchet_rs::NegotiatedExtension::from(None), None),
                Err(e) => {
                    return RpcError::new(RpcError::BAD_REQUEST, anyhow!("failed to negotiate deflate extension: {}", e)).into_response();
                }
            };
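
For context on what this negotiation works with: a client can offer several extensions, spread over one or more Sec-WebSocket-Extensions header lines, and each offer can carry parameters (RFC 6455 §9.1; RFC 7692 covers permessage-deflate). That is why the snippet above collects every header line with get_all. Below is a minimal std-only sketch of parsing those offers and picking the first permessage-deflate one. It is not ratchet's implementation, and the names `ExtensionOffer`, `parse_offers` and `first_deflate_offer` are hypothetical. Quoted parameter values that contain ',' or ';' are deliberately not handled.

```rust
/// One offer from a Sec-WebSocket-Extensions header (hypothetical type), e.g.
/// `permessage-deflate; client_max_window_bits`.
#[derive(Debug, PartialEq)]
struct ExtensionOffer {
    name: String,
    params: Vec<(String, Option<String>)>,
}

/// Simplified parser: each header value is split into offers on ',' and each
/// offer into a name and parameters on ';'. Quoted values containing ',' or ';'
/// are not handled.
fn parse_offers<'a>(header_values: impl IntoIterator<Item = &'a str>) -> Vec<ExtensionOffer> {
    let mut offers = Vec::new();
    for value in header_values {
        for offer in value.split(',') {
            let mut parts = offer.split(';').map(str::trim);
            let name = match parts.next() {
                Some(n) if !n.is_empty() => n.to_ascii_lowercase(),
                _ => continue,
            };
            let params = parts
                .filter(|p| !p.is_empty())
                .map(|p| match p.split_once('=') {
                    // Parameter with a value; surrounding quotes are stripped.
                    Some((k, v)) => (
                        k.trim().to_ascii_lowercase(),
                        Some(v.trim().trim_matches('"').to_string()),
                    ),
                    // Bare parameter such as `client_max_window_bits`.
                    None => (p.to_ascii_lowercase(), None),
                })
                .collect();
            offers.push(ExtensionOffer { name, params });
        }
    }
    offers
}

/// Returns the first permessage-deflate offer, i.e. the one the client lists first.
fn first_deflate_offer(offers: &[ExtensionOffer]) -> Option<&ExtensionOffer> {
    offers.iter().find(|o| o.name == "permessage-deflate")
}

fn main() {
    // Two header lines; the first carries an unrelated, made-up extension.
    let header_values = [
        "x-custom-ext",
        "permessage-deflate; client_max_window_bits, permessage-deflate; server_max_window_bits=10",
    ];
    let offers = parse_offers(header_values);
    match first_deflate_offer(&offers) {
        Some(offer) => println!("accepting {} with {:?}", offer.name, offer.params),
        None => println!("no permessage-deflate offer"),
    }
}
```

A real server would go on to check each parameter of the chosen offer against its own deflate configuration, and would fall back to the next offer or to no compression if they conflict.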

In general, for any WebSocket integration, it would be nice to have a more convenient wrapper for the case where none of the HTTP handling is under ratchet's control. Such an API would take a read-only HeaderMap (or an iterator of headers) and return a status code plus headers, without writing them anywhere. Passing "user defined headers" should then be unnecessary: ratchet would output only the WebSocket-relevant headers. The caller would be responsible for putting them into the response it builds, along with anything else it wants to do with that response.
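
A minimal, dependency-free sketch of that API shape follows. All names are hypothetical: `server_upgrade`, plus hand-rolled `sha1` and `base64` helpers so it needs only std. The function reads request headers as name/value pairs and validates the RFC 6455 upgrade headers. It then returns a status code and the WebSocket response headers, leaving the caller to put them into its own response. Extension and subprotocol negotiation are left out, and this is not ratchet's actual API.

```rust
// Hypothetical, dependency-free sketch: request headers in, (status, headers) out.
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; // fixed GUID from RFC 6455

/// SHA-1 (FIPS 180-4), used only to derive Sec-WebSocket-Accept; not for security.
fn sha1(data: &[u8]) -> [u8; 20] {
    let mut h: [u32; 5] = [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0];
    // Padding: 0x80, zeros up to 56 mod 64, then the bit length as a big-endian u64.
    let mut msg = data.to_vec();
    msg.push(0x80);
    while msg.len() % 64 != 56 {
        msg.push(0);
    }
    msg.extend_from_slice(&(data.len() as u64 * 8).to_be_bytes());
    for chunk in msg.chunks(64) {
        let mut w = [0u32; 80];
        for i in 0..80 {
            w[i] = if i < 16 {
                u32::from_be_bytes(chunk[4 * i..4 * i + 4].try_into().unwrap())
            } else {
                (w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]).rotate_left(1)
            };
        }
        let [mut a, mut b, mut c, mut d, mut e] = h;
        for (i, &wi) in w.iter().enumerate() {
            let (f, k) = match i {
                0..=19 => ((b & c) | (!b & d), 0x5A827999),
                20..=39 => (b ^ c ^ d, 0x6ED9EBA1),
                40..=59 => ((b & c) | (b & d) | (c & d), 0x8F1BBCDC),
                _ => (b ^ c ^ d, 0xCA62C1D6),
            };
            let t = a.rotate_left(5).wrapping_add(f).wrapping_add(e).wrapping_add(k).wrapping_add(wi);
            (e, d, c, b, a) = (d, c, b.rotate_left(30), a, t);
        }
        for (hi, v) in h.iter_mut().zip([a, b, c, d, e]) {
            *hi = hi.wrapping_add(v);
        }
    }
    let mut out = [0u8; 20];
    for (i, v) in h.iter().enumerate() {
        out[4 * i..4 * i + 4].copy_from_slice(&v.to_be_bytes());
    }
    out
}

/// Standard base64 with '=' padding (RFC 4648).
fn base64(bytes: &[u8]) -> String {
    const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let mut out = String::new();
    for chunk in bytes.chunks(3) {
        let n = chunk.iter().enumerate().fold(0u32, |n, (i, &b)| n | (u32::from(b) << (16 - 8 * i)));
        // A chunk of len bytes carries len + 1 significant 6-bit groups; pad the rest.
        for i in 0..4 {
            out.push(if i <= chunk.len() { TABLE[(n >> (18 - 6 * i)) as usize & 63] as char } else { '=' });
        }
    }
    out
}

/// Validates the upgrade headers; returns the status and the headers to add to the response.
fn server_upgrade<'a>(
    headers: impl IntoIterator<Item = (&'a str, &'a str)>,
) -> Result<(u16, Vec<(&'static str, String)>), String> {
    let (mut upgrade, mut connection, mut version_ok, mut key) = (false, false, false, None);
    for (name, value) in headers {
        // Names are case-insensitive; Connection/Upgrade may carry comma-separated tokens.
        let has_token = |t: &str| value.split(',').any(|v| v.trim().eq_ignore_ascii_case(t));
        match name.to_ascii_lowercase().as_str() {
            "upgrade" => upgrade |= has_token("websocket"),
            "connection" => connection |= has_token("upgrade"),
            "sec-websocket-version" => version_ok = value.trim() == "13",
            "sec-websocket-key" => key = Some(value.trim()),
            _ => {}
        }
    }
    if !(upgrade && connection) {
        return Err("not a WebSocket upgrade request".to_string());
    }
    if !version_ok {
        return Err("unsupported Sec-WebSocket-Version".to_string());
    }
    let key = key.ok_or("missing Sec-WebSocket-Key")?;
    let accept = base64(&sha1(format!("{}{}", key, WS_GUID).as_bytes()));
    let headers = vec![
        ("Upgrade", "websocket".to_string()),
        ("Connection", "Upgrade".to_string()),
        ("Sec-WebSocket-Accept", accept),
    ];
    Ok((101, headers))
}

fn main() {
    let request = [("Upgrade", "websocket"), ("Connection", "keep-alive, Upgrade"),
        ("Sec-WebSocket-Version", "13"), ("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")];
    match server_upgrade(request) {
        Ok((status, headers)) => println!("{status} {headers:?}"),
        Err(e) => println!("400 Bad Request: {e}"),
    }
}
```

The key in main is the sample key from RFC 6455. Nothing in the function touches a socket or a framework type, so an axum or hyper integration would only copy the returned status and headers into its own response.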

@huntc

huntc commented Nov 5, 2024

Thanks for the code snippet. Before I dive in further, are there any complete examples of integrating ratchet with axum?

@nakedible-p

The API has since been improved significantly for this use case, so the snippet is outdated.

@huntc

huntc commented Nov 6, 2024

The API does indeed appear to be improved. Here's the rough equivalent of @nakedible-p's code snippet (for convenience, I just pass in all of the request headers, and I've left out the environment variable check):

    let deflate_config = ratchet_rs::deflate::DeflateConfig::default();
    let deflate_ext_provider = ratchet_rs::deflate::DeflateExtProvider::with_config(deflate_config);

    if let Ok(Some((extension, sec_websocket_extensions))) =
        deflate_ext_provider.negotiate_server(&request_headers)
    {
        ...
    }

I'm unsure what to do next, though. It's great to have negotiated the deflate extension and obtained its response headers, but I'm not sure what to call next. The next step is, of course, to upgrade the connection, but I don't have a stream representing the socket connection on which to call accept_with. Any clues? Thanks.

@nakedible-p

Not full code, but this should get you far enough:

            let ratchet_core::server::UpgradeResponseParts {
                response,
                subprotocol,
                extension,
                ..
            } = ratchet_core::server::response_from_headers(&parts.headers, provider, &protocols)
                .map_err(...)?;
            let Some(on_upgrade) = parts.extensions.remove::<hyper::upgrade::OnUpgrade>() else {
                return Err(...);
            };
            tokio::spawn(async move {
                match on_upgrade.await {
                    Ok(upgraded) => {
                        let upgraded = hyper_util::rt::tokio::TokioIo::new(upgraded);
                        let ws = ratchet_core::WebSocket::from_upgraded(
                            ratchet_core::WebSocketConfig::default(),
                            upgraded,
                            extension,
                            bytes::BytesMut::new(),
                            ratchet_core::Role::Server,
                        );
                    }
                    Err(e) => {
                        ...
                    }
                }
            });

            Ok(response)
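
The important detail in this snippet is the ordering. The 101 response has to be handed back to hyper first, and OnUpgrade only resolves once hyper has sent that response. That is why the WebSocket work goes into a spawned task instead of being awaited inside the handler. The toy model below illustrates this ordering with std threads and channels standing in for tokio and hyper. `UpgradedIo`, `handler` and `serve` are hypothetical names, not hyper or ratchet APIs.

```rust
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;

/// Toy stand-in for the upgraded connection (hyper would hand you `Upgraded`).
struct UpgradedIo(&'static str);

/// Handler: must return the 101 status before the upgraded IO exists, so the
/// WebSocket work moves to a separate thread that waits for it
/// (the analogue of `tokio::spawn` plus `on_upgrade.await`).
fn handler(on_upgrade: Receiver<UpgradedIo>, log: Sender<String>) -> u16 {
    thread::spawn(move || {
        if let Ok(io) = on_upgrade.recv() {
            let _ = log.send(format!("serving WebSocket over {}", io.0));
        }
    });
    101
}

/// Server: calls the handler, "writes" the response, and only then completes the upgrade.
fn serve() -> (u16, String) {
    let (upgrade_tx, upgrade_rx) = mpsc::channel();
    let (log_tx, log_rx) = mpsc::channel();
    let status = handler(upgrade_rx, log_tx);
    // ... in hyper, the 101 response is written to the socket at this point ...
    upgrade_tx.send(UpgradedIo("tcp")).unwrap();
    (status, log_rx.recv().unwrap())
}

fn main() {
    let (status, message) = serve();
    println!("{status}: {message}");
}
```

Suppose the handler instead waited on the upgrade before returning. The server would then never receive the response it needs to send first, and both sides would wait on each other indefinitely.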

@huntc

huntc commented Nov 6, 2024

Thanks for the fabulous guidance here. I've been able to get things integrated into Axum nicely. Here's the code I've landed on - I hope that it helps someone else:

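use axum::{
    body::Body,
    http::{HeaderMap, StatusCode},
    response::{IntoResponse, Response},
    Extension,
};
use bytes::BytesMut;
use hyper::upgrade::OnUpgrade;

async fn my_handler(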
    headers: HeaderMap,
    Extension(on_upgrade): Extension<OnUpgrade>,
) -> Result<impl IntoResponse, StatusCode> {
    let deflate_config = ratchet_deflate::DeflateConfig::default();
    let deflate_ext_provider = ratchet_deflate::DeflateExtProvider::with_config(deflate_config);

    let ratchet_core::server::UpgradeResponseParts {
        response,
        extension,
        ..
    } = ratchet_core::server::response_from_headers(
        &headers,
        deflate_ext_provider,
        &ratchet_core::SubprotocolRegistry::default(),
    )
    .map_err(|e| {
        log::error!("Problem with headers: {e}");
        StatusCode::BAD_REQUEST
    })?;

    tokio::spawn(async move {
        match on_upgrade.await {
            Ok(upgraded) => {
                let upgraded = hyper_util::rt::tokio::TokioIo::new(upgraded);
                let mut ws = ratchet_core::WebSocket::from_upgraded(
                    ratchet_core::WebSocketConfig::default(),
                    upgraded,
                    extension,
                    bytes::BytesMut::new(),
                    ratchet_core::Role::Server,
                );
                let mut buf = BytesMut::new();

                while let Ok(msg) = ws.read(&mut buf).await {
                    match msg {
                        ratchet_core::Message::Text => {
                            if ws
                                .write(&mut buf, ratchet_core::PayloadType::Text)
                                .await
                                .is_err()
                            {
                                break;
                            }
                            buf.clear();
                        }
                        ratchet_core::Message::Binary => {
                            if ws
                                .write(&mut buf, ratchet_core::PayloadType::Binary)
                                .await
                                .is_err()
                            {
                                break;
                            }
                            buf.clear();
                        }
                        ratchet_core::Message::Ping(_bytes) => (),
                        ratchet_core::Message::Pong(_bytes) => (),
                        ratchet_core::Message::Close(_) => break,
                    }
                }
            }
            Err(e) => {
                log::error!("Problem upgrading: {e}");
            }
        }
    });

    Ok(Response::from_parts(response.into_parts().0, Body::empty()))
}

BTW, that last line has to re-create the response, because the one ratchet provides has a body of type (). If response_from_headers used Body::empty(), there would be no need for this conversion.

@nakedible-p

nakedible-p commented Nov 6, 2024

Response in ratchet needs to use () because it isn't tied to any particular framework. To convert it for axum, just write:

response.map(|_| axum::body::Body::empty());

@huntc

huntc commented Nov 6, 2024


That's tricky: I didn't realise that map only maps over the body. Intuitively, I'd have expected it to map over both the parts and the body. Anyhow, what you suggest does indeed work. Thanks.
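
Here is a simplified model of why this works. It uses a hypothetical `Response<B>` that mirrors the shape of `http::Response<B>`, meaning head parts plus a generic body. The body's type is the type parameter, so `map` changes only the body, and with it the response's type. The status and headers are carried over untouched. `Body` here is a stand-in for `axum::body::Body`, not the real type.

```rust
/// Head of a response: everything except the body (simplified, hypothetical).
#[derive(Debug, PartialEq)]
struct Parts {
    status: u16,
    headers: Vec<(String, String)>,
}

/// Simplified model of `http::Response<B>`: head parts plus a body of type `B`.
#[derive(Debug)]
struct Response<B> {
    parts: Parts,
    body: B,
}

impl<B> Response<B> {
    /// Same shape as `http::Response::map`: transforms only the body (and so the
    /// type parameter); the head parts are moved across unchanged.
    fn map<F, B2>(self, f: F) -> Response<B2>
    where
        F: FnOnce(B) -> B2,
    {
        Response { parts: self.parts, body: f(self.body) }
    }
}

/// Stand-in for a framework body type such as `axum::body::Body`.
#[derive(Debug, PartialEq)]
enum Body {
    Empty,
}

fn main() {
    // Framework-agnostic response, as ratchet returns it: the body is `()`.
    let upgrade: Response<()> = Response {
        parts: Parts {
            status: 101,
            headers: vec![("Upgrade".to_string(), "websocket".to_string())],
        },
        body: (),
    };
    // The status and headers survive; only the body type changes.
    let converted: Response<Body> = upgrade.map(|()| Body::Empty);
    println!("{:?}", converted);
}
```

This is what lets ratchet stay framework-neutral with a () body: adapting to a framework's body type is a one-line map rather than a rebuild from parts.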

So, if I were to attempt to generalise this axum integration, would there be interest in a PR, @SirCipher?

I'm thinking that an extractor similar to axum's existing WebSocketUpgrade would be the way to go. It would add methods for registering extensions and protocol adapters before on_upgrade is called within the axum handler. Thoughts? Perhaps as an additional crate in this project?
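
Here is a rough design sketch of such an extractor. Every name is hypothetical, and the IO is replaced by plain data so it runs with std only. The handler configures deflate and the supported subprotocols on the extracted value. Then `on_upgrade` computes what was negotiated and hands it to a callback on a separate thread, which stands in for a spawned task awaiting OnUpgrade. It immediately returns the status and the selected subprotocol for the handler's response.

```rust
use std::sync::mpsc;
use std::thread;

/// What was agreed with the client (hypothetical).
#[derive(Debug, PartialEq)]
struct Negotiated {
    deflate: bool,
    subprotocol: Option<String>,
}

/// Hypothetical extractor; a real one would implement axum's `FromRequestParts`
/// and also hold the request's `OnUpgrade` and headers.
struct RatchetUpgrade {
    offered_protocols: Vec<String>,
    offered_deflate: bool,
    enable_deflate: bool,
    supported_protocols: Vec<String>,
}

impl RatchetUpgrade {
    /// Builds the extractor from the raw Sec-WebSocket-Protocol and
    /// Sec-WebSocket-Extensions header values.
    fn from_headers(protocol_header: &str, extensions_header: &str) -> Self {
        RatchetUpgrade {
            offered_protocols: protocol_header
                .split(',')
                .map(str::trim)
                .filter(|p| !p.is_empty())
                .map(String::from)
                .collect(),
            offered_deflate: extensions_header
                .split(',')
                .any(|offer| offer.split(';').next().map(str::trim) == Some("permessage-deflate")),
            enable_deflate: false,
            supported_protocols: Vec::new(),
        }
    }

    /// Builder method: opt in to permessage-deflate if the client offers it.
    fn deflate(mut self, enable: bool) -> Self {
        self.enable_deflate = enable;
        self
    }

    /// Builder method: the subprotocols this server is willing to speak.
    fn protocols<'a>(mut self, protocols: impl IntoIterator<Item = &'a str>) -> Self {
        self.supported_protocols = protocols.into_iter().map(String::from).collect();
        self
    }

    /// Picks the first client-offered subprotocol the server supports, runs the
    /// callback off the handler's thread, and returns what the response needs.
    fn on_upgrade<F>(self, callback: F) -> (u16, Option<String>)
    where
        F: FnOnce(Negotiated) + Send + 'static,
    {
        let subprotocol = self
            .offered_protocols
            .iter()
            .find(|p| self.supported_protocols.contains(*p))
            .cloned();
        let negotiated = Negotiated {
            deflate: self.enable_deflate && self.offered_deflate,
            subprotocol: subprotocol.clone(),
        };
        // Stand-in for `tokio::spawn(async move { /* await OnUpgrade, build the WebSocket */ })`.
        thread::spawn(move || callback(negotiated));
        (101, subprotocol)
    }
}

fn main() {
    let (tx, rx) = mpsc::channel();
    let (status, protocol) = RatchetUpgrade::from_headers("chat, superchat", "permessage-deflate")
        .deflate(true)
        .protocols(["superchat"])
        .on_upgrade(move |negotiated| tx.send(negotiated).unwrap());
    println!("{status} {protocol:?} {:?}", rx.recv().unwrap());
}
```

Returning the selected subprotocol from on_upgrade keeps the handler in charge of building the response, in line with the framework-agnostic approach discussed earlier in the thread.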


3 participants