From f3909f6336f75f57eb47226dd7c08f998035d042 Mon Sep 17 00:00:00 2001 From: Gerson Alexander Pardo Gamez Date: Fri, 5 Feb 2021 14:37:34 -0500 Subject: [PATCH] added client certificate support for tunnel --- client/cli.go | 14 ++++++++++++++ client/config.go | 10 ++++++++++ client/model.go | 25 +++++++++++++++++++++---- client/tls.go | 12 ++++++++++-- server/cli.go | 41 ++++++++++++++++++++++------------------- server/main.go | 14 +++++++------- server/tls.go | 2 +- 7 files changed, 85 insertions(+), 33 deletions(-) diff --git a/client/cli.go b/client/cli.go index 8c0ad15f..cc8f4a90 100644 --- a/client/cli.go +++ b/client/cli.go @@ -48,6 +48,8 @@ type Options struct { inspectaddr string inspectpublic bool tls bool + tlsClientCrt string + tlsClientKey string args []string } @@ -103,6 +105,16 @@ func ParseArgs() (opts *Options, err error) { false, "Use dial for tls port") + tlsClientCrt := flag.String( + "tlsClientCrt", + "", + "Path to a TLS Client CRT file if server requires") + + tlsClientKey := flag.String( + "tlsClientKey", + "", + "Path to a TLS Client Key file if server requires") + inspectaddr := flag.String( "inspectaddr", defaultInspectAddr, @@ -127,6 +139,8 @@ func ParseArgs() (opts *Options, err error) { inspectaddr: *inspectaddr, inspectpublic: *inspectpublic, tls: *tls, + tlsClientCrt: *tlsClientCrt, + tlsClientKey: *tlsClientKey, command: flag.Arg(0), } diff --git a/client/config.go b/client/config.go index dfd3155b..0ba1fa4a 100644 --- a/client/config.go +++ b/client/config.go @@ -24,6 +24,8 @@ type Configuration struct { AuthToken string `yaml:"auth_token,omitempty"` Tunnels map[string]*TunnelConfiguration `yaml:"tunnels,omitempty"` TLS bool `yaml:"tls,omitempty"` + TLSClientCrt string `yaml:"tls_client_crt,omitempty"` + TLSClientKey string `yaml:"tls_client_key,omitempty"` LogTo string `yaml:"-"` Path string `yaml:"-"` } @@ -82,6 +84,14 @@ func LoadConfiguration(opts *Options) (config *Configuration, err error) { config.InspectAddr = opts.inspectaddr } + if opts.tlsClientCrt != "" { + config.TLSClientCrt = opts.tlsClientCrt + } + + if opts.tlsClientKey != "" { + config.TLSClientKey = opts.tlsClientKey + } + if config.InspectAddr == "" { config.InspectAddr = defaultInspectAddr } diff --git a/client/model.go b/client/model.go index 4891f4e0..3ba1e4ff 100644 --- a/client/model.go +++ b/client/model.go @@ -54,6 +54,8 @@ type ClientModel struct { tunnelConfig map[string]*TunnelConfiguration configPath string TLS bool + TLSClientCrt string + TLSClientKey string } func newClientModel(config *Configuration, ctl mvc.Controller) *ClientModel { @@ -104,18 +106,33 @@ func newClientModel(config *Configuration, ctl mvc.Controller) *ClientModel { // TLS for dial port TLS: config.TLS, + + // TLSClientCrt for connect using a client certificate + TLSClientCrt: config.TLSClientCrt, + + // TLSClientKey for connect using a client certificate + TLSClientKey: config.TLSClientKey, } + m.tlsConfig = &tls.Config{} // configure TLS if config.TrustHostRootCerts { m.Info("Trusting host's root certificates") - m.tlsConfig = &tls.Config{} } else { - m.Info("Trusting root CAs: %v", rootCrtPaths) - var err error - if m.tlsConfig, err = LoadTLSConfig(rootCrtPaths); err != nil { + m.Info("using root CAs: %v", rootCrtPaths) + rootCAs, err := LoadTLSRootCAs(rootCrtPaths) + if err != nil { + panic(err) + } + m.tlsConfig.RootCAs = rootCAs + } + + if m.TLSClientCrt != "" { + certificates, err := LoadTLSCertificate(m.TLSClientCrt, m.TLSClientKey) + if err != nil { panic(err) } + m.tlsConfig.Certificates = certificates } // configure TLS SNI diff --git a/client/tls.go b/client/tls.go index 826deb9c..d81d370a 100644 --- a/client/tls.go +++ b/client/tls.go @@ -9,7 +9,7 @@ import ( "pgrok/client/assets" ) -func LoadTLSConfig(rootCertPaths []string) (*tls.Config, error) { +func LoadTLSRootCAs(rootCertPaths []string) (*x509.CertPool, error) { pool := x509.NewCertPool() for _, certPath := range rootCertPaths { @@ -31,5 +31,13 @@ func LoadTLSConfig(rootCertPaths []string) (*tls.Config, error) { pool.AddCert(certs[0]) } - return &tls.Config{RootCAs: pool}, nil + return pool, nil +} + +func LoadTLSCertificate(certPath, certKey string) ([]tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(certPath, certKey) + if err != nil { + return nil, err + } + return []tls.Certificate{cert}, nil } diff --git a/server/cli.go b/server/cli.go index 8588c0df..75a142ae 100644 --- a/server/cli.go +++ b/server/cli.go @@ -5,38 +5,41 @@ import ( ) type Options struct { - httpAddr string - httpsAddr string - tunnelAddr string - domain string - tlsCrt string - tlsKey string - tlsClientCA string - logto string - loglevel string + httpAddr string + httpsAddr string + tunnelAddr string + tunnelTLSClientCA string + domain string + tlsCrt string + tlsKey string + tlsClientCA string + logto string + loglevel string } func parseArgs() *Options { httpAddr := flag.String("httpAddr", ":80", "Public address for HTTP connections, empty string to disable") httpsAddr := flag.String("httpsAddr", ":443", "Public address listening for HTTPS connections, emptry string to disable") tunnelAddr := flag.String("tunnelAddr", ":4443", "Public address listening for pgrok client") + tunnelTLSClientCA := flag.String("tunnelTLSClientCA", "", "Path to a TLS Client CA file if you want enable mutual auth for tunnel") domain := flag.String("domain", "ejemplo.me", "Domain where the tunnels are hosted") tlsCrt := flag.String("tlsCrt", "", "Path to a TLS certificate file") tlsKey := flag.String("tlsKey", "", "Path to a TLS key file") - tlsClientCA := flag.String("tlsClientCA", "", "Path to a TLS Client CA file if you want enable mutual auth") + tlsClientCA := flag.String("tlsClientCA", "", "Path to a TLS Client CA file if you want enable mutual auth for subdomains") logto := flag.String("log", "stdout", "Write log messages to this file. 'stdout' and 'none' have special meanings") loglevel := flag.String("log-level", "DEBUG", "The level of messages to log. One of: DEBUG, INFO, WARNING, ERROR") flag.Parse() return &Options{ - httpAddr: *httpAddr, - httpsAddr: *httpsAddr, - tunnelAddr: *tunnelAddr, - domain: *domain, - tlsCrt: *tlsCrt, - tlsKey: *tlsKey, - tlsClientCA: *tlsClientCA, - logto: *logto, - loglevel: *loglevel, + httpAddr: *httpAddr, + httpsAddr: *httpsAddr, + tunnelAddr: *tunnelAddr, + tunnelTLSClientCA: *tunnelTLSClientCA, + domain: *domain, + tlsCrt: *tlsCrt, + tlsKey: *tlsKey, + tlsClientCA: *tlsClientCA, + logto: *logto, + loglevel: *loglevel, } } diff --git a/server/main.go b/server/main.go index e2496e28..eb4c9660 100644 --- a/server/main.go +++ b/server/main.go @@ -120,12 +120,6 @@ func Main() { // start listeners listeners = make(map[string]*conn.Listener) - // load tls configuration - tlsConfig, err := LoadTLSConfig(opts.tlsCrt, opts.tlsKey) - if err != nil { - panic(err) - } - // listen for http if opts.httpAddr != "" { listeners["http"] = startHttpListener(opts.httpAddr, nil) @@ -133,13 +127,19 @@ func Main() { // listen for https if opts.httpsAddr != "" { - tlsConfigServer, err := LoadTLSConfigServer(opts.tlsCrt, opts.tlsKey, opts.tlsClientCA) + tlsConfigServer, err := LoadTLSConfigWithCA(opts.tlsCrt, opts.tlsKey, opts.tlsClientCA) if err != nil { panic(err) } listeners["https"] = startHttpListener(opts.httpsAddr, tlsConfigServer) } + // load tls configuration + tlsConfig, err := LoadTLSConfigWithCA(opts.tlsCrt, opts.tlsKey, opts.tunnelTLSClientCA) + if err != nil { + panic(err) + } + // pgrok clients tunnelListener(opts.tunnelAddr, tlsConfig) } diff --git a/server/tls.go b/server/tls.go index 69a03875..62b5b80a 100644 --- a/server/tls.go +++ b/server/tls.go @@ -66,7 +66,7 @@ func LoadTLSConfig(crtPath, keyPath string) (tlsConfig *tls.Config, err error) { return } -func LoadTLSConfigServer(crtPath, keyPath, clientCAPath string) (tlsConfig *tls.Config, err error) { +func LoadTLSConfigWithCA(crtPath, keyPath, clientCAPath string) (tlsConfig *tls.Config, err error) { tlsConfig, err = LoadTLSConfig(crtPath, keyPath) if err != nil {