diff --git a/control/beaconing/connect/sender.go b/control/beaconing/connect/sender.go index f3ada995e1..4ffe8708fe 100644 --- a/control/beaconing/connect/sender.go +++ b/control/beaconing/connect/sender.go @@ -34,7 +34,7 @@ func (f *BeaconSenderFactory) NewSender( } dialer := f.Dialer(addr) return &BeaconSender{ - Addr: addr.String(), + Addr: "https://" + addr.SVC.BaseString(), Client: &HTTPClient{ RoundTripper: &http3.RoundTripper{ Dial: dialer.DialEarly, diff --git a/control/beaconing/grpc/creation_server.go b/control/beaconing/grpc/creation_server.go index 2dc3df7e28..9aa611609f 100644 --- a/control/beaconing/grpc/creation_server.go +++ b/control/beaconing/grpc/creation_server.go @@ -45,6 +45,8 @@ type SegmentCreationServer struct { func (s SegmentCreationServer) Beacon(ctx context.Context, req *cppb.BeaconRequest) (*cppb.BeaconResponse, error) { + // Need to patch https://github.com/quic-go/quic-go/blob/9414ea49100d5cf75a2044d85a6becf3985171db/http3/server.go#L578C19-L578C36 + // to get the peer address into the context. gPeer, ok := peer.FromContext(ctx) if !ok { return nil, serrors.New("peer must exist") diff --git a/control/beaconing/happy/sender.go b/control/beaconing/happy/sender.go index ff7e1d2926..254fb18839 100644 --- a/control/beaconing/happy/sender.go +++ b/control/beaconing/happy/sender.go @@ -3,6 +3,7 @@ package happy import ( "context" "net" + "sync" "time" "github.com/scionproto/scion/control/beaconing" @@ -46,36 +47,57 @@ type BeaconSender struct { func (s BeaconSender) Send(ctx context.Context, b *seg.PathSegment) error { abortCtx, cancel := context.WithCancel(ctx) - defer cancel() + var wg sync.WaitGroup + wg.Add(2) - connectCh := make(chan error, 1) - grpcCh := make(chan error, 1) + errs := [2]error{} + successCh := make(chan struct{}, 2) go func() { defer log.HandlePanic() + defer wg.Done() err := s.Connect.Send(abortCtx, b) - if abortCtx.Err() == nil { - log.Debug("Sent beacon via connect") + if err == nil { + successCh <- struct{}{} + log.Info("Sent beacon via connect") + cancel() + } else { + log.Info("Failed to send beacon via connect", "err", err) } - connectCh <- err + errs[0] = err }() go func() { defer log.HandlePanic() - time.Sleep(500 * time.Millisecond) + defer wg.Done() + select { + case <-abortCtx.Done(): + return + case <-time.After(500 * time.Millisecond): + } err := s.Grpc.Send(abortCtx, b) - if abortCtx.Err() == nil { - log.Debug("Sent beacon via gRPC") + if err == nil { + successCh <- struct{}{} + log.Info("Sent beacon via gRPC") + cancel() + } else { + log.Info("Failed to send beacon via gRPC", "err", err) } - grpcCh <- err + errs[1] = err }() - select { - case err := <-connectCh: - return err - case err := <-grpcCh: - return err + wg.Wait() + var combinedErrs serrors.List + for _, err := range errs { + if err != nil { + combinedErrs = append(combinedErrs, err) + } + } + // Only report error if both sends were unsuccessful. + if len(combinedErrs) == 2 { + return combinedErrs.ToError() } + return nil } func (s BeaconSender) Close() error { diff --git a/control/cmd/control/BUILD.bazel b/control/cmd/control/BUILD.bazel index 6d11bcd232..4b0abe0435 100644 --- a/control/cmd/control/BUILD.bazel +++ b/control/cmd/control/BUILD.bazel @@ -86,6 +86,7 @@ go_library( "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//health:go_default_library", "@org_golang_google_grpc//health/grpc_health_v1:go_default_library", + "@org_golang_google_grpc//peer:go_default_library", "@org_golang_x_sync//errgroup:go_default_library", ], ) diff --git a/control/cmd/control/main.go b/control/cmd/control/main.go index a38c58711f..20c06e5b6d 100644 --- a/control/cmd/control/main.go +++ b/control/cmd/control/main.go @@ -16,10 +16,12 @@ package main import ( "context" + "crypto/tls" "crypto/x509" "encoding/json" "errors" "fmt" + "net" "net/http" _ "net/http/pprof" "path/filepath" @@ -38,6 +40,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/health" healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/peer" cpconnect "github.com/scionproto/scion/bufgen/proto/control_plane/v1/control_planeconnect" cs "github.com/scionproto/scion/control" @@ -112,6 +115,12 @@ type loggingHandler struct{ next http.Handler } func (h loggingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fmt.Println(r.Method) fmt.Println(r.URL) + + if addr, ok := r.Context().Value(http3.RemoteAddrContextKey).(net.Addr); ok { + log.Info("HTTP3 request", "remote", r.Context().Value(http3.RemoteAddrContextKey)) + ctx := peer.NewContext(r.Context(), &peer.Peer{Addr: addr}) + r = r.WithContext(ctx) + } h.next.ServeHTTP(w, r) } @@ -847,6 +856,12 @@ func realMain(ctx context.Context) error { Connect: &connect.BeaconSenderFactory{ Dialer: (&squic.EarlyDialerFactory{ Transport: quicStack.InsecureDialer.Transport, + TLSConfig: func() *tls.Config { + cfg := quicStack.InsecureDialer.TLSConfig.Clone() + cfg.NextProtos = []string{"h3", "SCION"} + return cfg + }(), + Rewriter: dialer.Rewriter, }).NewDialer, }, Grpc: &beaconinggrpc.BeaconSenderFactory{ diff --git a/go_deps.bzl b/go_deps.bzl index b1e6e04852..699c057988 100644 --- a/go_deps.bzl +++ b/go_deps.bzl @@ -1162,6 +1162,10 @@ def go_deps(): go_repository( name = "com_github_quic_go_quic_go", importpath = "github.com/quic-go/quic-go", + patch_args = ["-p1"], # keep + patches = [ + "@//patches/com_github_quic_go_quic_go:http3_remote_addr.patch", # keep + ], sum = "h1:GYd1iznlKm7dpHD7pOVpUvItgMPo/jrMgDWZhMCecqw=", version = "v0.40.0", ) diff --git a/patches/com_github_quic_go_quic_go/http3_remote_addr.patch b/patches/com_github_quic_go_quic_go/http3_remote_addr.patch new file mode 100644 index 0000000000..290511f152 --- /dev/null +++ b/patches/com_github_quic_go_quic_go/http3_remote_addr.patch @@ -0,0 +1,21 @@ +diff --git a/http3/server.go b/http3/server.go +index ac2e32a6..a4c4e327 100644 +--- a/http3/server.go ++++ b/http3/server.go +@@ -115,6 +115,8 @@ func (k *contextKey) String() string { return "quic-go/http3 context value " + k + // type *http3.Server. + var ServerContextKey = &contextKey{"http3-server"} + ++var RemoteAddrContextKey = &contextKey{"remote-addr"} ++ + type requestError struct { + err error + streamErr ErrCode +@@ -597,6 +599,7 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q + ctx := str.Context() + ctx = context.WithValue(ctx, ServerContextKey, s) + ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr()) ++ ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr()) + req = req.WithContext(ctx) + r := newResponseWriter(str, conn, s.logger) + if req.Method == http.MethodHead { diff --git a/pkg/connect/BUILD.bazel b/pkg/connect/BUILD.bazel new file mode 100644 index 0000000000..3f0c6e1a0c --- /dev/null +++ b/pkg/connect/BUILD.bazel @@ -0,0 +1,16 @@ +load("//tools/lint:go.bzl", "go_library") + +go_library( + name = "go_default_library", + srcs = ["dialer.go"], + importpath = "github.com/scionproto/scion/pkg/connect", + visibility = ["//visibility:public"], + deps = [ + "//pkg/grpc:go_default_library", + "//pkg/private/common:go_default_library", + "//pkg/private/serrors:go_default_library", + "//pkg/snet:go_default_library", + "//pkg/snet/squic:go_default_library", + "@com_github_quic_go_quic_go//:go_default_library", + ], +) diff --git a/pkg/connect/dialer.go b/pkg/connect/dialer.go new file mode 100644 index 0000000000..967c011b1a --- /dev/null +++ b/pkg/connect/dialer.go @@ -0,0 +1,38 @@ +package conect + +import ( + "context" + "crypto/tls" + + "github.com/quic-go/quic-go" + "github.com/scionproto/scion/pkg/grpc" + "github.com/scionproto/scion/pkg/private/common" + "github.com/scionproto/scion/pkg/private/serrors" + "github.com/scionproto/scion/pkg/snet" + "github.com/scionproto/scion/pkg/snet/squic" +) + +type QUICDialer struct { + Rewriter grpc.AddressRewriter + Transport *quic.Transport + TLSConfig *tls.Config + QUICConfig *quic.Config +} + +func (d *QUICDialer) DialEarly(ctx context.Context, _ string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { + addr, _, err := d.Rewriter.RedirectToQUIC(ctx, addr) + if err != nil { + return nil, serrors.WrapStr("resolving SVC address", err) + } + if _, ok := addr.(*snet.UDPAddr); !ok { + return nil, serrors.New("wrong address type after svc resolution", + "type", common.TypeOf(addr)) + } + dialer := squic.EarlyDialer{ + Addr: addr, + Transport: d.Transport, + TLSConfig: d.TLSConfig, + QUICConfig: d.QUICConfig, + } + return dialer.DialEarly(ctx, "", nil, nil) +} diff --git a/pkg/snet/squic/BUILD.bazel b/pkg/snet/squic/BUILD.bazel index 600d0667c5..aeef75240d 100644 --- a/pkg/snet/squic/BUILD.bazel +++ b/pkg/snet/squic/BUILD.bazel @@ -10,6 +10,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/log:go_default_library", + "//pkg/private/common:go_default_library", "//pkg/private/serrors:go_default_library", "//pkg/snet:go_default_library", "@com_github_quic_go_quic_go//:go_default_library", diff --git a/pkg/snet/squic/early.go b/pkg/snet/squic/early.go index c34108a8d9..1151680199 100644 --- a/pkg/snet/squic/early.go +++ b/pkg/snet/squic/early.go @@ -24,44 +24,68 @@ import ( "github.com/quic-go/quic-go" + "github.com/scionproto/scion/pkg/private/common" "github.com/scionproto/scion/pkg/private/serrors" + "github.com/scionproto/scion/pkg/snet" ) +type AddressRewriter interface { + RedirectToQUIC(ctx context.Context, address net.Addr) (net.Addr, bool, error) +} + type EarlyDialerFactory struct { - Transport *quic.Transport + Transport *quic.Transport + TLSConfig *tls.Config + QUICConfig *quic.Config + Rewriter AddressRewriter } func (f *EarlyDialerFactory) NewDialer(a net.Addr) EarlyDialer { return EarlyDialer{ - Transport: f.Transport, - Addr: a, + Addr: a, + Transport: f.Transport, + TLSConfig: f.TLSConfig, + QUICConfig: f.QUICConfig, + Rewriter: f.Rewriter, } } type EarlyDialer struct { - Transport *quic.Transport - Addr net.Addr + Addr net.Addr + Transport *quic.Transport + TLSConfig *tls.Config + QUICConfig *quic.Config + Rewriter AddressRewriter } -func (d *EarlyDialer) DialEarly(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - serverName := tlsCfg.ServerName +func (d *EarlyDialer) DialEarly(ctx context.Context, _ string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { + addr, _, err := d.Rewriter.RedirectToQUIC(ctx, d.Addr) + if err != nil { + return nil, serrors.WrapStr("resolving SVC address", err) + } + if _, ok := addr.(*snet.UDPAddr); !ok { + return nil, serrors.New("wrong address type after svc resolution", + "type", common.TypeOf(addr)) + } + + serverName := d.TLSConfig.ServerName if serverName == "" { - serverName = computeServerName(d.Addr) + serverName = computeServerName(addr) } var session quic.EarlyConnection for sleep := 2 * time.Millisecond; ctx.Err() == nil; sleep = sleep * 2 { // Clone TLS config to avoid data races. - tlsConfig := tlsCfg.Clone() + tlsConfig := d.TLSConfig.Clone() tlsConfig.ServerName = serverName // Clone QUIC config to avoid data races, if it exists. var quicConfig *quic.Config - if cfg != nil { - quicConfig = cfg.Clone() + if d.QUICConfig != nil { + quicConfig = d.QUICConfig.Clone() } var err error - session, err = d.Transport.DialEarly(ctx, d.Addr, tlsConfig, quicConfig) + session, err = d.Transport.DialEarly(ctx, addr, tlsConfig, quicConfig) if err == nil { break }