diff --git a/fed/blocklist.go b/fed/blocklist.go index 36e05b76..3301bcf5 100644 --- a/fed/blocklist.go +++ b/fed/blocklist.go @@ -24,6 +24,7 @@ import ( "math" "os" "path/filepath" + "strings" "sync" "time" ) @@ -134,10 +135,21 @@ func NewBlockList(log *slog.Logger, path string) (*BlockList, error) { // Contains determines if a domain is blocked. func (b *BlockList) Contains(domain string) bool { + domain = strings.Trim(domain, ".") + b.lock.Lock() - _, contains := b.domains[domain] - b.lock.Unlock() - return contains + defer b.lock.Unlock() + + for { + if _, contains := b.domains[domain]; contains { + return true + } + if i := strings.IndexRune(domain, '.'); i == -1 { + return false + } else { + domain = domain[i+1:] + } + } } // Close frees resources. diff --git a/fed/blocklist_test.go b/fed/blocklist_test.go new file mode 100644 index 00000000..a32f54f6 --- /dev/null +++ b/fed/blocklist_test.go @@ -0,0 +1,88 @@ +/* +Copyright 2024 Dima Krasner + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package fed + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestBlockList_NotBlockedDomain(t *testing.T) { + assert := assert.New(t) + + blockList := BlockList{} + blockList.domains = map[string]struct{}{ + "0.0.0.0.com": struct{}{}, + } + + assert.False(blockList.Contains("127.0.0.1.com")) +} + +func TestBlockList_BlockedDomain(t *testing.T) { + assert := assert.New(t) + + blockList := BlockList{} + blockList.domains = map[string]struct{}{ + "0.0.0.0.com": struct{}{}, + } + + assert.True(blockList.Contains("0.0.0.0.com")) +} + +func TestBlockList_BlockedSubdomain(t *testing.T) { + assert := assert.New(t) + + blockList := BlockList{} + blockList.domains = map[string]struct{}{ + "social.0.0.0.0.com": struct{}{}, + } + + assert.True(blockList.Contains("social.0.0.0.0.com")) +} + +func TestBlockList_NotBlockedSubdomain(t *testing.T) { + assert := assert.New(t) + + blockList := BlockList{} + blockList.domains = map[string]struct{}{ + "social.0.0.0.0.com": struct{}{}, + } + + assert.False(blockList.Contains("blog.0.0.0.0.com")) +} + +func TestBlockList_BlockedSubdomainByDomain(t *testing.T) { + assert := assert.New(t) + + blockList := BlockList{} + blockList.domains = map[string]struct{}{ + "0.0.0.0.com": struct{}{}, + } + + assert.True(blockList.Contains("social.0.0.0.0.com")) +} + +func TestBlockList_BlockedSubdomainByDomainEndsWithDot(t *testing.T) { + assert := assert.New(t) + + blockList := BlockList{} + blockList.domains = map[string]struct{}{ + "0.0.0.0.com": struct{}{}, + } + + assert.True(blockList.Contains("social.0.0.0.0.com.")) +}