diff --git a/adapter/outbound/trojan.go b/adapter/outbound/trojan.go index b8ea49e65e..8bea0cc8ff 100644 --- a/adapter/outbound/trojan.go +++ b/adapter/outbound/trojan.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "fmt" "net" + "net/http" "strconv" "github.com/Dreamacro/clash/component/dialer" @@ -18,6 +19,7 @@ import ( type Trojan struct { *Base instance *trojan.Trojan + option *TrojanOption // for gun mux gunTLSConfig *tls.Config @@ -36,6 +38,34 @@ type TrojanOption struct { UDP bool `proxy:"udp,omitempty"` Network string `proxy:"network,omitempty"` GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"` + WSOpts WSOptions `proxy:"ws-opts,omitempty"` +} + +func (t *Trojan) plainStream(c net.Conn) (net.Conn, error) { + if t.option.Network == "ws" { + host, port, _ := net.SplitHostPort(t.addr) + wsOpts := &trojan.WebsocketOption{ + Host: host, + Port: port, + Path: t.option.WSOpts.Path, + } + + if t.option.SNI != "" { + wsOpts.Host = t.option.SNI + } + + if len(t.option.WSOpts.Headers) != 0 { + header := http.Header{} + for key, value := range t.option.WSOpts.Headers { + header.Add(key, value) + } + wsOpts.Headers = header + } + + return t.instance.StreamWebsocketConn(c, wsOpts) + } + + return t.instance.StreamConn(c) } // StreamConn implements C.ProxyAdapter @@ -44,7 +74,7 @@ func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) if t.transport != nil { c, err = gun.StreamGunWithConn(c, t.gunTLSConfig, t.gunConfig) } else { - c, err = t.instance.StreamConn(c) + c, err = t.plainStream(c) } if err != nil { @@ -106,7 +136,7 @@ func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata) } defer safeConnClose(c, err) tcpKeepAlive(c) - c, err = t.instance.StreamConn(c) + c, err = t.plainStream(c) if err != nil { return nil, fmt.Errorf("%s connect error: %w", t.addr, err) } @@ -143,6 +173,7 @@ func NewTrojan(option TrojanOption) (*Trojan, error) { udp: option.UDP, }, instance: trojan.New(tOption), + option: &option, } if option.Network == "grpc" { diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index 209084ca8b..b4db00ad93 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -105,8 +105,16 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { if v.option.TLS { wsOpts.TLS = true - wsOpts.SkipCertVerify = v.option.SkipCertVerify - wsOpts.ServerName = v.option.ServerName + wsOpts.TLSConfig = &tls.Config{ + ServerName: host, + InsecureSkipVerify: v.option.SkipCertVerify, + NextProtos: []string{"http/1.1"}, + } + if v.option.ServerName != "" { + wsOpts.TLSConfig.ServerName = v.option.ServerName + } else if host := wsOpts.Headers.Get("Host"); host != "" { + wsOpts.TLSConfig.ServerName = host + } } c, err = vmess.StreamWebsocketConn(c, wsOpts) case "http": diff --git a/test/clash_test.go b/test/clash_test.go index a3303a2397..5eb9d5bda5 100644 --- a/test/clash_test.go +++ b/test/clash_test.go @@ -31,6 +31,7 @@ const ( ImageShadowsocksRust = "ghcr.io/shadowsocks/ssserver-rust:latest" ImageVmess = "v2fly/v2fly-core:latest" ImageTrojan = "trojangfw/trojan:latest" + ImageTrojanGo = "p4gefau1t/trojan-go:latest" ImageSnell = "icpz/snell-server:latest" ImageXray = "teddysun/xray:latest" ) @@ -99,6 +100,7 @@ func init() { ImageShadowsocksRust, ImageVmess, ImageTrojan, + ImageTrojanGo, ImageSnell, ImageXray, } diff --git a/test/config/trojan-ws.json b/test/config/trojan-ws.json new file mode 100644 index 0000000000..efc0acbd06 --- /dev/null +++ b/test/config/trojan-ws.json @@ -0,0 +1,20 @@ +{ + "run_type": "server", + "local_addr": "0.0.0.0", + "local_port": 10002, + "disable_http_check": true, + "password": [ + "example" + ], + "websocket": { + "enabled": true, + "path": "/", + "host": "example.org" + }, + "ssl": { + "verify": true, + "cert": "/fullchain.pem", + "key": "/privkey.pem", + "sni": "example.org" + } +} \ No newline at end of file diff --git a/test/trojan_test.go b/test/trojan_test.go index a57dab99fe..d1ab2a0068 100644 --- a/test/trojan_test.go +++ b/test/trojan_test.go @@ -93,6 +93,44 @@ func TestClash_TrojanGrpc(t *testing.T) { testSuit(t, proxy) } +func TestClash_TrojanWebsocket(t *testing.T) { + cfg := &container.Config{ + Image: ImageTrojanGo, + ExposedPorts: defaultExposedPorts, + } + hostCfg := &container.HostConfig{ + PortBindings: defaultPortBindings, + Binds: []string{ + fmt.Sprintf("%s:/etc/trojan-go/config.json", C.Path.Resolve("trojan-ws.json")), + fmt.Sprintf("%s:/fullchain.pem", C.Path.Resolve("example.org.pem")), + fmt.Sprintf("%s:/privkey.pem", C.Path.Resolve("example.org-key.pem")), + }, + } + + id, err := startContainer(cfg, hostCfg, "trojan-ws") + if err != nil { + assert.FailNow(t, err.Error()) + } + defer cleanContainer(id) + + proxy, err := outbound.NewTrojan(outbound.TrojanOption{ + Name: "trojan", + Server: localIP.String(), + Port: 10002, + Password: "example", + SNI: "example.org", + SkipCertVerify: true, + UDP: true, + Network: "ws", + }) + if err != nil { + assert.FailNow(t, err.Error()) + } + + time.Sleep(waitTime) + testSuit(t, proxy) +} + func Benchmark_Trojan(b *testing.B) { cfg := &container.Config{ Image: ImageTrojan, diff --git a/transport/trojan/trojan.go b/transport/trojan/trojan.go index d39cbec1ee..26e7adc7d5 100644 --- a/transport/trojan/trojan.go +++ b/transport/trojan/trojan.go @@ -8,10 +8,12 @@ import ( "errors" "io" "net" + "net/http" "sync" "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/transport/socks5" + "github.com/Dreamacro/clash/transport/vmess" ) const ( @@ -38,6 +40,13 @@ type Option struct { SkipCertVerify bool } +type WebsocketOption struct { + Host string + Port string + Path string + Headers http.Header +} + type Trojan struct { option *Option hexPassword []byte @@ -64,6 +73,29 @@ func (t *Trojan) StreamConn(conn net.Conn) (net.Conn, error) { return tlsConn, nil } +func (t *Trojan) StreamWebsocketConn(conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) { + alpn := defaultALPN + if len(t.option.ALPN) != 0 { + alpn = t.option.ALPN + } + + tlsConfig := &tls.Config{ + NextProtos: alpn, + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: t.option.SkipCertVerify, + ServerName: t.option.ServerName, + } + + return vmess.StreamWebsocketConn(conn, &vmess.WebsocketConfig{ + Host: wsOptions.Host, + Port: wsOptions.Port, + Path: wsOptions.Path, + Headers: wsOptions.Headers, + TLS: true, + TLSConfig: tlsConfig, + }) +} + func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error { buf := pool.GetBuffer() defer pool.PutBuffer(buf) diff --git a/transport/v2ray-plugin/websocket.go b/transport/v2ray-plugin/websocket.go index 317c172fc7..7591b4a865 100644 --- a/transport/v2ray-plugin/websocket.go +++ b/transport/v2ray-plugin/websocket.go @@ -1,6 +1,7 @@ package obfs import ( + "crypto/tls" "net" "net/http" @@ -26,12 +27,22 @@ func NewV2rayObfs(conn net.Conn, option *Option) (net.Conn, error) { } config := &vmess.WebsocketConfig{ - Host: option.Host, - Port: option.Port, - Path: option.Path, - TLS: option.TLS, - Headers: header, - SkipCertVerify: option.SkipCertVerify, + Host: option.Host, + Port: option.Port, + Path: option.Path, + Headers: header, + } + + if option.TLS { + config.TLS = true + config.TLSConfig = &tls.Config{ + ServerName: option.Host, + InsecureSkipVerify: option.SkipCertVerify, + NextProtos: []string{"http/1.1"}, + } + if host := config.Headers.Get("Host"); host != "" { + config.TLSConfig.ServerName = host + } } var err error diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index 951fb5d0ae..f769dcce84 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -45,8 +45,7 @@ type WebsocketConfig struct { Path string Headers http.Header TLS bool - SkipCertVerify bool - ServerName string + TLSConfig *tls.Config MaxEarlyData int EarlyDataHeaderName string } @@ -254,17 +253,7 @@ func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buf scheme := "ws" if c.TLS { scheme = "wss" - dialer.TLSClientConfig = &tls.Config{ - ServerName: c.Host, - InsecureSkipVerify: c.SkipCertVerify, - NextProtos: []string{"http/1.1"}, - } - - if c.ServerName != "" { - dialer.TLSClientConfig.ServerName = c.ServerName - } else if host := c.Headers.Get("Host"); host != "" { - dialer.TLSClientConfig.ServerName = host - } + dialer.TLSClientConfig = c.TLSConfig } uri := url.URL{