From 3e895ea2ea4ce4fffc22fe52857ec513d9f5a13d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Urban?= Date: Mon, 16 Oct 2023 20:39:48 +0200 Subject: [PATCH] fix(backend): improve extract domain function --- backend/internal/handlers/utils.go | 42 ++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/backend/internal/handlers/utils.go b/backend/internal/handlers/utils.go index 3121a43..dc866ef 100644 --- a/backend/internal/handlers/utils.go +++ b/backend/internal/handlers/utils.go @@ -9,23 +9,49 @@ import ( ) func extractMainDomain(u string) (string, error) { - //TODO: This function currently does not work with double tlds like .co.uk + // Prepend http:// if no scheme is provided, this ensures url.Parse succeeds + if !strings.Contains(u, "//") { + u = "http://" + u + } + + // Parse the URL and validate it parsedURL, err := url.Parse(u) if err != nil { - return "", err + return "", fmt.Errorf("error parsing URL: %v", err) } + // Split the hostname into parts parts := strings.Split(parsedURL.Hostname(), ".") - if len(parts) < 2 { - return "", fmt.Errorf("invalid domain") + partsLength := len(parts) + + // Check if the URL has at least a domain and a TLD + if partsLength < 2 { + return "", fmt.Errorf("invalid domain: domain and TLD not found in URL") } - // Extract the main domain and TLD - domain := parts[len(parts)-2] // Second to last part is the main domain - tld := parts[len(parts)-1] // Last part is the TLD + // Handle second-level domains (SLDs) like ".co.uk", ".com.au", etc. + if partsLength > 2 { + // List of common SLDs + secondLevelDomains := map[string]bool{ + "com.au": true, + "co.uk": true, + "com.br": true, + // ... add more second-level domains as needed + } - return fmt.Sprintf(".%s.%s", domain, tld), nil + // Check if the last two parts match a known second-level domain + if secondLevelDomains[parts[partsLength-2]+"."+parts[partsLength-1]] { + if partsLength < 3 { + return "", fmt.Errorf("invalid domain: missing main domain before second-level domain") + } + return fmt.Sprintf("%s.%s.%s", parts[partsLength-3], parts[partsLength-2], parts[partsLength-1]), nil + } + } + + // For non-SLDs, return the domain and TLD + return fmt.Sprintf("%s.%s", parts[partsLength-2], parts[partsLength-1]), nil } + func sendJSONError(w http.ResponseWriter, message string, status int) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status)