From 174a6a10e95dcd6d2b0d2eb806c601efd2c46f63 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 24 Oct 2023 12:10:10 +0700 Subject: [PATCH] allow the client to disable 0-RTT when using a 0-RTT enabled session ticket --- common.go | 2 +- handshake_client.go | 9 +++++---- tls_test.go | 3 ++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/common.go b/common.go index 841c1a4..ba776d7 100644 --- a/common.go +++ b/common.go @@ -739,7 +739,7 @@ type ExtraConfig struct { // Is called when the client uses a session ticket. // Restores the application data that was saved earlier on GetAppDataForSessionTicket. - SetAppDataFromSessionState func([]byte) + SetAppDataFromSessionState func([]byte) (allowEarlyData bool) } // Clone clones. diff --git a/handshake_client.go b/handshake_client.go index a5fdd54..cc9cedd 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -418,8 +418,12 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, } if c.quic != nil && maxEarlyData > 0 { + var earlyData bool + if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil { + earlyData = c.extraConfig.SetAppDataFromSessionState(appData) + } // For 0-RTT, the cipher suite has to match exactly. - if mutualCipherSuiteTLS13(hello.cipherSuites, session.cipherSuite) != nil { + if earlyData && mutualCipherSuiteTLS13(hello.cipherSuites, session.cipherSuite) != nil { hello.earlyData = true } } @@ -449,9 +453,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, return "", nil, nil, nil, err } - if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil { - c.extraConfig.SetAppDataFromSessionState(appData) - } return } diff --git a/tls_test.go b/tls_test.go index 87b5d8a..0958baf 100644 --- a/tls_test.go +++ b/tls_test.go @@ -867,8 +867,9 @@ func TestExtraConfigCloneFuncField(t *testing.T) { called |= 1 << 2 return nil }, - SetAppDataFromSessionState: func([]byte) { + SetAppDataFromSessionState: func([]byte) bool { called |= 1 << 3 + return true }, }