From 6aae002ca749ae939d0dea3b2383f07f296df3e4 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Sat, 30 Mar 2024 18:02:27 -0400 Subject: [PATCH] refactor: add `:` prefix to ports during config unmarshaling --- config/config.go | 7 +++++++ config/config_test.go | 4 ++-- helpertest/helper.go | 16 ++++++++++++---- server/server.go | 12 ++---------- server/server_test.go | 17 +++++++++-------- 5 files changed, 32 insertions(+), 24 deletions(-) diff --git a/config/config.go b/config/config.go index f512e17ce..8890b996e 100644 --- a/config/config.go +++ b/config/config.go @@ -171,6 +171,13 @@ func (l *ListenConfig) UnmarshalText(data []byte) error { *l = strings.Split(addresses, ",") + // Prefix all ports with : + for i, addr := range *l { + if !strings.ContainsRune(addr, ':') { + (*l)[i] = ":" + addr + } + } + return nil } diff --git a/config/config_test.go b/config/config_test.go index be75775d7..dc28c0b78 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -462,7 +462,7 @@ bootstrapDns: err := l.UnmarshalText([]byte("55,:56")) Expect(err).Should(Succeed()) Expect(*l).Should(HaveLen(2)) - Expect(*l).Should(ContainElements("55", ":56")) + Expect(*l).Should(ContainElements(":55", ":56")) }) }) }) @@ -958,7 +958,7 @@ bootstrapDns: }) func defaultTestFileConfig(config *Config) { - Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"})) + Expect(config.Ports.DNS).Should(Equal(ListenConfig{":55553", ":55554", "[::1]:55555"})) Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError)) Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky")) Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3)) diff --git a/helpertest/helper.go b/helpertest/helper.go index cd5415bcc..8a3460875 100644 --- a/helpertest/helper.go +++ b/helpertest/helper.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "net" "net/http" "net/http/httptest" "os" @@ -31,20 +32,27 @@ const ( DS = dns.Type(dns.TypeDS) ) -// GetIntPort returns an port for the current testing +// GetIntPort returns a port for the current testing // process by adding the current ginkgo parallel process to -// the base port and returning it as int +// the base port and returning it as int. func GetIntPort(port int) int { return port + ginkgo.GinkgoParallelProcess() } -// GetStringPort returns an port for the current testing +// GetStringPort returns a port for the current testing // process by adding the current ginkgo parallel process to -// the base port and returning it as string +// the base port and returning it as string. func GetStringPort(port int) string { return fmt.Sprintf("%d", GetIntPort(port)) } +// GetHostPort returns a host:port string for the current testing +// process by adding the current ginkgo parallel process to +// the base port and returning it as string. +func GetHostPort(host string, port int) string { + return net.JoinHostPort(host, GetStringPort(port)) +} + // TempFile creates temp file with passed data func TempFile(data string) *os.File { f, err := os.CreateTemp("", "prefix") diff --git a/server/server.go b/server/server.go index b33b85433..7404541ad 100644 --- a/server/server.go +++ b/server/server.go @@ -60,14 +60,6 @@ func tlsCipherSuites() []uint16 { return tlsCipherSuites } -func getServerAddress(addr string) string { - if !strings.Contains(addr, ":") { - addr = fmt.Sprintf(":%s", addr) - } - - return addr -} - type NewServerFunc func(address string) (*dns.Server, error) func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) { @@ -195,7 +187,7 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error { for _, address := range addresses { - server, err := newServer(getServerAddress(address)) + server, err := newServer(address) if err != nil { return err } @@ -236,7 +228,7 @@ func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listene listeners := make([]net.Listener, 0, len(addresses)) for _, address := range addresses { - listener, err := net.Listen("tcp", getServerAddress(address)) + listener, err := net.Listen("tcp", address) if err != nil { return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err) } diff --git a/server/server_test.go b/server/server_test.go index 60207d0ac..753c87afd 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/base64" + "fmt" "io" "net" "net/http" @@ -43,7 +44,7 @@ var ( ) var _ = BeforeSuite(func() { - baseURL = "http://localhost:" + GetStringPort(httpBasePort) + "/" + baseURL = fmt.Sprintf("http://%s/", GetHostPort("localhost", httpBasePort)) queryURL = baseURL + "dns-query" var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream ctx, cancelFn := context.WithCancel(context.Background()) @@ -146,10 +147,10 @@ var _ = BeforeSuite(func() { }, Ports: config.Ports{ - DNS: config.ListenConfig{GetStringPort(dnsBasePort)}, - TLS: config.ListenConfig{GetStringPort(tlsBasePort)}, - HTTP: config.ListenConfig{GetStringPort(httpBasePort)}, - HTTPS: config.ListenConfig{GetStringPort(httpsBasePort)}, + DNS: config.ListenConfig{GetHostPort("", dnsBasePort)}, + TLS: config.ListenConfig{GetHostPort("", tlsBasePort)}, + HTTP: config.ListenConfig{GetHostPort("", httpBasePort)}, + HTTPS: config.ListenConfig{GetHostPort("", httpsBasePort)}, }, CertFile: certPem.Path, KeyFile: keyPem.Path, @@ -633,7 +634,7 @@ var _ = Describe("Running DNS server", func() { }, Blocking: config.Blocking{BlockType: "zeroIp"}, Ports: config.Ports{ - DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)}, + DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)}, }, }) @@ -677,7 +678,7 @@ var _ = Describe("Running DNS server", func() { }, Blocking: config.Blocking{BlockType: "zeroIp"}, Ports: config.Ports{ - DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)}, + DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)}, }, }) @@ -751,7 +752,7 @@ var _ = Describe("Running DNS server", func() { }) func requestServer(request *dns.Msg) *dns.Msg { - conn, err := net.Dial("udp", ":"+GetStringPort(dnsBasePort)) + conn, err := net.Dial("udp", GetHostPort("", dnsBasePort)) if err != nil { Log().Fatal("could not connect to server: ", err) }