diff --git a/endpoints/partition.go b/endpoints/partition.go index fcafdb0e..44423776 100644 --- a/endpoints/partition.go +++ b/endpoints/partition.go @@ -4,6 +4,7 @@ package endpoints import ( + "maps" "regexp" ) @@ -36,13 +37,39 @@ func (p Partition) RegionRegex() *regexp.Regexp { return p.regionRegex } +// Regions returns a map of Regions for the partition, indexed by their ID. +func (p Partition) Regions() map[string]Region { + partitionAndRegion, ok := partitionsAndRegions[p.id] + if !ok { + return nil + } + + return maps.Clone(partitionAndRegion.regions) +} + // DefaultPartitions returns a list of the partitions. func DefaultPartitions() []Partition { - partitions := make([]Partition, len(partitionsAndRegions)) + ps := make([]Partition, len(partitionsAndRegions)) for _, v := range partitionsAndRegions { - partitions = append(partitions, v.partition) + ps = append(ps, v.partition) + } + + return ps +} + +// PartitionForRegion returns the first partition which includes the specific Region. +func PartitionForRegion(ps []Partition, regionID string) (Partition, bool) { + for _, p := range ps { + partitionAndRegion, ok := partitionsAndRegions[p.id] + if !ok { + continue + } + + if _, ok := partitionAndRegion.regions[regionID]; ok || partitionAndRegion.partition.regionRegex.MatchString(regionID) { + return p, true + } } - return partitions + return Partition{}, false } diff --git a/endpoints/partition_test.go b/endpoints/partition_test.go index 01db7c7e..9705571c 100644 --- a/endpoints/partition_test.go +++ b/endpoints/partition_test.go @@ -17,3 +17,41 @@ func TestDefaultPartitions(t *testing.T) { t.Fatalf("expected partitions, got none") } } + +func TestPartitionForRegion(t *testing.T) { + t.Parallel() + + testcases := map[string]struct { + expectedFound bool + expectedID string + }{ + "us-east-1": { + expectedFound: true, + expectedID: "aws", + }, + "us-gov-west-1": { + expectedFound: true, + expectedID: "aws-us-gov", + }, + "not-found": { + expectedFound: false, + }, + "us-east-17": { + expectedFound: true, + expectedID: "aws", + }, + } + + ps := endpoints.DefaultPartitions() + for region, testcase := range testcases { + gotID, gotFound := endpoints.PartitionForRegion(ps, region) + + if gotFound != testcase.expectedFound { + t.Errorf("expected PartitionFound %t for Region %q, got %t", testcase.expectedFound, region, gotFound) + } + if gotID.ID() != testcase.expectedID { + t.Errorf("expected PartitionID %q for Region %q, got %q", testcase.expectedID, region, gotID.ID()) + } + } + +}