From ac6cd9b8425866b31905fe5e3a4f2a4969b43e60 Mon Sep 17 00:00:00 2001 From: Ignacy Osetek Date: Wed, 15 May 2024 15:42:19 +0200 Subject: [PATCH] Add AccessControl subpackage in Azure VPC code This change refactors the code and adds a separate package in Azure package accesscontrol for better control of Networking Security Group Rules in Azure. This code additionally implements App Connectivity in Azure based on labels. --- azure/accessControl.go | 226 ++++++---- azure/accessControl/name.go | 74 ++++ azure/accessControl/priorityPool.go | 119 ++++++ azure/accessControl/rule.go | 161 +++++++ azure/accessControl/ruleSet.go | 230 ++++++++++ azure/accessControl/securityGroup.go | 69 +++ azure/azure.go | 13 + azure/helper.go | 47 +++ azure/instances.go | 6 +- azure/peering.go | 24 -- azure/securityGroup.go | 601 +-------------------------- azure/securityGroupAccessControl.go | 201 +++++++++ azure/subnetAccessControl.go | 153 +++++++ azure/subnets.go | 19 + azure/vpc.go | 158 ++----- azure/vpcAccessControl.go | 228 ++++++++++ azure/vpcConnection.go | 140 +++++++ connector/helper/set.go | 11 + 18 files changed, 1635 insertions(+), 845 deletions(-) create mode 100644 azure/accessControl/name.go create mode 100644 azure/accessControl/priorityPool.go create mode 100644 azure/accessControl/rule.go create mode 100644 azure/accessControl/ruleSet.go create mode 100644 azure/accessControl/securityGroup.go create mode 100644 azure/securityGroupAccessControl.go create mode 100644 azure/subnetAccessControl.go create mode 100644 azure/vpcAccessControl.go create mode 100644 azure/vpcConnection.go diff --git a/azure/accessControl.go b/azure/accessControl.go index 5ad3271..5b00b21 100644 --- a/azure/accessControl.go +++ b/azure/accessControl.go @@ -20,74 +20,14 @@ package azure import ( "context" "fmt" - "strings" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + accesscontrol "github.com/app-net-interface/awi-infra-guard/azure/accessControl" + "github.com/app-net-interface/awi-infra-guard/connector/helper" + "github.com/app-net-interface/awi-infra-guard/grpc/go/infrapb" "github.com/app-net-interface/awi-infra-guard/types" ) -type vpcPolicy string - -const ( - vpcPolicyAllow = "allow" - vpcPolicyDeny = "deny" -) - -// func (c *Client) refreshVnetSubnetsWithVPCPolicy( -// ctx context.Context, -// vnet armnetwork.VirtualNetwork, -// inboundVnet string, -// policy vpcPolicy, -// ) error { -// c.logger.Trace( -// "updating virtual network '%s' subnets with VPC Policy %s", -// vnet, -// ) -// if vnet.Properties == nil { -// c.logger.Warnf( -// "virtual network '%s' has no properties - skipping policy update", -// helper.StringPointerToString(vnet.ID), -// ) -// return nil -// } - -// for i := range vnet.Properties.Subnets { -// if vnet.Properties.Subnets[i] == nil { -// c.logger.Warnf( -// "virtual network '%s' has a nil subnet pointer - skipping subnet entry", -// helper.StringPointerToString(vnet.ID), -// ) -// continue -// } -// if vnet.Properties.Subnets[i].Properties == nil { -// c.logger.Warnf( -// "virtual network '%s' has a subnet %s with no properties - skipping subnet entry", -// helper.StringPointerToString(vnet.ID), -// helper.StringPointerToString(vnet.Properties.Subnets[i].ID), -// ) -// continue -// } -// } - -// return nil -// } - -func getVnetSourceIDFromAWITags(tags map[string]string) (string, error) { - tagValue, ok := tags["awi"] - if !ok { - return "", fmt.Errorf( - "expected request key tag 'awi' with source ID but found none. Got tags: %v", - tags, - ) - } - if !strings.HasPrefix(tagValue, "default-") { - return "", fmt.Errorf( - "the value of 'awi' tag from request has invalid prefix. Expected 'default-' but got: %s", - tagValue, - ) - } - return strings.TrimPrefix(tagValue, "default-"), nil -} - // AccessControl interface implementation func (c *Client) AddInboundAllowRuleInVPC( ctx context.Context, @@ -98,7 +38,6 @@ func (c *Client) AddInboundAllowRuleInVPC( ruleName string, tags map[string]string, ) error { - vnet, vnetAccount, err := c.getVPC( ctx, destinationVpcID, region, ) @@ -113,38 +52,150 @@ func (c *Client) AddInboundAllowRuleInVPC( account = vnetAccount } - sourceID, err := getVnetSourceIDFromAWITags(tags) - if err != nil { - return fmt.Errorf( - "failed to obtain the ID of Source VPC: %w", err, - ) - } + ruleset := accesscontrol.AccessControlRuleSet{} + ruleset.NewDirectedVPCRules( + accesscontrol.CustomRuleName(ruleName), + accesscontrol.AccessAllow, + cidrsToAllow, + ) - err = c.refreshSubnetSecurityGroupWithVPCInbound( + err = c.ApplyAccessRulesToVPC( ctx, account, - region, - cidrsToAllow, - vpcPolicyAllow, vnet, - sourceID, - ruleName, + ruleset, ) if err != nil { return fmt.Errorf( - "failed to refresh Security Groups for subnets from VNet %s: %w", - destinationVpcID, err, + "failed to apply rules %v to VPC %s: %w", + ruleset, destinationVpcID, err, ) } return nil } +func (c *Client) getSubnetsFromInstances( + ctx context.Context, instances []types.Instance, +) ([]armnetwork.Subnet, error) { + + type subnetInfo struct { + VNetID string + SubnetID string + } + + subnetInfos := helper.Set[subnetInfo]{} + + for _, instance := range instances { + subnetInfos.Set(subnetInfo{ + VNetID: instance.VPCID, + SubnetID: instance.SubnetID, + }) + } + + infos := subnetInfos.Keys() + subnets := make([]armnetwork.Subnet, 0, len(infos)) + + for _, info := range infos { + subnet, _, err := c.getSubnet( + ctx, + parseResourceGroupName(info.SubnetID), + parseResourceName(info.VNetID), + parseResourceName(info.SubnetID), + ) + if err != nil { + return nil, fmt.Errorf( + "failed to get subnet %s: %w", + info.SubnetID, err, + ) + } + subnets = append(subnets, subnet) + } + + return subnets, nil +} + +func (c *Client) prepareCustomAccessRules( + instances []types.Instance, + ruleName string, + cidrsToAllow []string, + protocolsAndPorts types.ProtocolsAndPorts, +) (accesscontrol.AccessControlRuleSet, error) { + ruleset := accesscontrol.AccessControlRuleSet{} + + for _, instance := range instances { + err := ruleset.NewCustomRules( + accesscontrol.CustomRuleName(ruleName), + accesscontrol.AccessAllow, + []string{instance.SubnetID}, + cidrsToAllow, + []string{instance.PrivateIP}, + protocolsAndPorts, + ) + if err != nil { + return accesscontrol.AccessControlRuleSet{}, fmt.Errorf( + "failed to create custom rule: %w", err, + ) + } + } + + return ruleset, nil +} + func (c *Client) AddInboundAllowRuleByLabelsMatch(ctx context.Context, account, region string, vpcID string, ruleName string, labels map[string]string, cidrsToAllow []string, protocolsAndPorts types.ProtocolsAndPorts) (ruleId string, instances []types.Instance, err error) { - // TBD - return "", nil, nil + + instances, err = c.ListInstances(ctx, &infrapb.ListInstancesRequest{ + VpcId: vpcID, + Zone: region, + AccountId: account, + Labels: labels, + Region: region, + }) + if err != nil { + return "", nil, fmt.Errorf( + "failed to list Instances: %w", err, + ) + } + + subnets, err := c.getSubnetsFromInstances(ctx, instances) + if err != nil { + return "", nil, fmt.Errorf( + "failed to extract subnets associated with matched instances: %w", err, + ) + } + + ruleset, err := c.prepareCustomAccessRules( + instances, + ruleName, + cidrsToAllow, + protocolsAndPorts, + ) + if err != nil { + return "", nil, fmt.Errorf( + "failed to prepare custom access rules: %w", err, + ) + } + + for _, subnet := range subnets { + err = c.ApplyAccessRulesToSubnet( + ctx, + account, + region, + subnet, + ruleset, + ) + if err != nil { + return "", nil, fmt.Errorf( + "failed to apply access rules to subnet %s: %w", + helper.StringPointerToString(subnet.ID), + err, + ) + } + } + + return ruleName, instances, nil } func (c *Client) AddInboundAllowRuleBySubnetMatch(ctx context.Context, account, region string, @@ -168,7 +219,7 @@ func (c *Client) AddInboundAllowRuleForLoadBalancerByDNS(ctx context.Context, ac } func (c *Client) RemoveInboundAllowRuleFromVPCByName(ctx context.Context, account, region string, vpcID string, ruleName string) error { - vnet, vnetAccount, err := c.getVPC( + vnet, _, err := c.getVPC( ctx, vpcID, region, ) if err != nil { @@ -177,16 +228,13 @@ func (c *Client) RemoveInboundAllowRuleFromVPCByName(ctx context.Context, accoun vpcID, err, ) } - if account == "" { - account = vnetAccount - } - err = c.deleteVPCInboundFromSubnets( + err = c.DeleteAccessRulesFromVPC( ctx, - account, - region, vnet, - ruleName, + accesscontrol.RuleNames{ + accesscontrol.CustomRuleName(ruleName), + }, ) if err != nil { return fmt.Errorf( diff --git a/azure/accessControl/name.go b/azure/accessControl/name.go new file mode 100644 index 0000000..b747423 --- /dev/null +++ b/azure/accessControl/name.go @@ -0,0 +1,74 @@ +package accesscontrol + +import ( + "crypto/sha256" + "fmt" + "slices" + "strings" +) + +// ruleName is unexported string created to +// enforce using exported functions from the +// package when setting up names for Access +// Control resources outside of this package. +type ruleName string + +// RuleNames is an exported slice of ruleName used +// mainly to specify rules that should be removed. +type RuleNames = []ruleName + +// VPCRuleName generates proper name identifier +// based on source and destination VPCs. The VPC +// rule acts bidirectional and so the order of +// VPC names will be picked by the function +// (names are sorted to keep it deterministic). +// +// TODO: Currently the name is a hash of vpc IDs, +// to keep the length of generated name fixed and +// not over accepted Azure limits, however it is +// not collision-proof. Name collision must be +// handled properly. +func VPCRuleName(vpcId1, vpcId2 string) ruleName { + ids := []string{vpcId1, vpcId2} + slices.Sort(ids) + + hasher := sha256.New() + hasher.Write([]byte(strings.Join(ids, ":"))) + hashBytes := hasher.Sum(nil) + + return ruleName(fmt.Sprintf("%x", hashBytes)) +} + +// CustomRuleName accepts a regular name provided +// by the external entity and hashes it to keep +// the length name consistent. +// +// TODO: Currently the name is a hash of a given +// string, to keep the length of generated nam +// fixed and not over accepted Azure limits, +// however it is not collision-proof. Name +// collision must be handled properly. +func CustomRuleName(name string) ruleName { + hasher := sha256.New() + hasher.Write([]byte(name)) + hashBytes := hasher.Sum(nil) + + return ruleName(fmt.Sprintf("%x", hashBytes)) +} + +// nameWithPriority combines Rule name with fixed-length priority string. +// The priority always uses ":" character and 4 digits. For priorities +// lower than 1000, the actual priority is preceeded with 0s to match +// 4 characters length. +// +// The priority acts as a name distinguisher between rules inside the +// same Network Security Group as the name prefix may be equal but +// priority ensures uniqueness. +func nameWithPriority(name ruleName, priority uint) (string, error) { + if priority >= 10000 { + return "", fmt.Errorf( + "unexpected priority value - expected 4 digits at max: %d", priority, + ) + } + return string(name) + fmt.Sprintf(":%04d", priority), nil +} diff --git a/azure/accessControl/priorityPool.go b/azure/accessControl/priorityPool.go new file mode 100644 index 0000000..6ba116a --- /dev/null +++ b/azure/accessControl/priorityPool.go @@ -0,0 +1,119 @@ +package accesscontrol + +import ( + "errors" + "fmt" + + "github.com/app-net-interface/awi-infra-guard/connector/helper" +) + +// generatePriorities generates a slice of integers that +// can be used as priorities by AccessControlRuleSet. +// +// It accepts the number of requested VPC Rules, Directed +// Rules and Custom Rules and generates priorities to fit +// already used priorities. +// +// For instance, requesting priorities for 3 VPC Rules, +// 2 Directed Rules and 4 Custom Rules with following +// priorities already in use: [3600, 3602, 100, 200] will +// generate following output: +// +// [3601, 3603, 3604, 2800, 2801, 101, 102, 103] +// +// The order of rules goes as follows: +// 1. VPC Rules +// 2. Directed Rules +// 3. Custom Rules +func generatePriorities( + prioritiesInUse helper.Set[uint], + numberOfVPCRules, + numberOfDirectedRules, + numberOfCustomRules uint, +) (priorityPool, error) { + vpcPriorities, err := generatePrioritiesForRuleGroup( + 3600, 4096, numberOfVPCRules, prioritiesInUse, + ) + if err != nil { + return priorityPool{}, fmt.Errorf( + "failed to generate priorities for VPC Rules: %w", err, + ) + } + directedPriorities, err := generatePrioritiesForRuleGroup( + 2800, 3599, numberOfDirectedRules, prioritiesInUse, + ) + if err != nil { + return priorityPool{}, fmt.Errorf( + "failed to generate priorities for Directed VPC Rules: %w", err, + ) + } + customPriorities, err := generatePrioritiesForRuleGroup( + 100, 2799, numberOfCustomRules, prioritiesInUse, + ) + if err != nil { + return priorityPool{}, fmt.Errorf( + "failed to generate priorities for Custom Rules: %w", err, + ) + } + + priorities := make([]uint, 0, len(vpcPriorities)+len(directedPriorities)+len(customPriorities)) + priorities = append(priorities, vpcPriorities...) + priorities = append(priorities, directedPriorities...) + priorities = append(priorities, customPriorities...) + + return newPriorityPool(priorities), nil +} + +func generatePrioritiesForRuleGroup( + minPriority, maxPriority, numberOfPriorities uint, + prioritiesInUse helper.Set[uint], +) ([]uint, error) { + current := minPriority + priorities := make([]uint, 0, numberOfPriorities) + + for current < maxPriority { + if len(priorities) == int(numberOfPriorities) { + break + } + + if prioritiesInUse.Has(current) { + continue + } + priorities = append(priorities, current) + current++ + } + + if len(priorities) != int(numberOfPriorities) { + return nil, fmt.Errorf( + "could not produce %d priorities as the slots are already taken", + numberOfPriorities, + ) + } + + return priorities, nil +} + +// priorityPool is a simple structure for picking a next +// available priority without using priority already in use. +type priorityPool struct { + current uint + availablePriorities []uint +} + +func newPriorityPool(availablePriorities []uint) priorityPool { + return priorityPool{ + current: 0, + availablePriorities: availablePriorities, + } +} + +func (p *priorityPool) Next() (uint, error) { + if int(p.current) >= len(p.availablePriorities) { + return 0, errors.New( + "no more priorities left", + ) + } + priority := p.availablePriorities[p.current] + p.current++ + return priority, nil +} diff --git a/azure/accessControl/rule.go b/azure/accessControl/rule.go new file mode 100644 index 0000000..3ed408b --- /dev/null +++ b/azure/accessControl/rule.go @@ -0,0 +1,161 @@ +package accesscontrol + +import ( + "fmt" + "slices" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" +) + +type rule interface { + Name() ruleName + ToNetworkSecurityGroupRule(priority uint) (armnetwork.SecurityRule, error) +} + +type access string + +const ( + AccessAllow access = "ALLOW" + AccessDeny access = "DENY" +) + +type VPCRule struct { + namePrefix ruleName + access access + sourceCIDR string +} + +func (r VPCRule) Name() ruleName { + return r.namePrefix +} + +func (r VPCRule) ToNetworkSecurityGroupRule(priority uint) (armnetwork.SecurityRule, error) { + name, err := nameWithPriority(r.namePrefix, priority) + if err != nil { + return armnetwork.SecurityRule{}, fmt.Errorf( + "failed to generate a name for Security Rule %v: %w", + r, err, + ) + } + return armnetwork.SecurityRule{ + Name: to.Ptr(name), + Properties: newSecurityRuleProperties( + r.access, + r.sourceCIDR, + "", + priority, + armnetwork.SecurityRuleProtocolAsterisk, + []string{}, + ), + }, nil +} + +type DirectedVPCRule struct { + namePrefix ruleName + access access + sourceCIDR string +} + +func (r DirectedVPCRule) Name() ruleName { + return r.namePrefix +} + +func (r DirectedVPCRule) ToNetworkSecurityGroupRule(priority uint) (armnetwork.SecurityRule, error) { + name, err := nameWithPriority(r.namePrefix, priority) + if err != nil { + return armnetwork.SecurityRule{}, fmt.Errorf( + "failed to generate a name for Security Rule %v: %w", + r, err, + ) + } + return armnetwork.SecurityRule{ + Name: to.Ptr(name), + Properties: newSecurityRuleProperties( + r.access, + r.sourceCIDR, + "", + priority, + armnetwork.SecurityRuleProtocolAsterisk, + []string{}, + ), + }, nil +} + +type CustomRule struct { + namePrefix ruleName + access access + sourceCIDR string + destinationCIDR string + protocol armnetwork.SecurityRuleProtocol + ports []string + // Subnets are not a part of actual Network Security + // Group Rule but its a helper to provide a context + // if the rules should be applied for a particular + // NSG. + subnets []string +} + +func (r CustomRule) Name() ruleName { + return r.namePrefix +} + +func (r CustomRule) MatchesSubnet(subnet string) bool { + return slices.Contains(r.subnets, subnet) +} + +func (r CustomRule) ToNetworkSecurityGroupRule(priority uint) (armnetwork.SecurityRule, error) { + name, err := nameWithPriority(r.namePrefix, priority) + if err != nil { + return armnetwork.SecurityRule{}, fmt.Errorf( + "failed to generate a name for Security Rule %v: %w", + r, err, + ) + } + return armnetwork.SecurityRule{ + Name: to.Ptr(name), + Properties: newSecurityRuleProperties( + r.access, + r.sourceCIDR, + r.destinationCIDR, + priority, + r.protocol, + r.ports, + ), + }, nil +} + +func newSecurityRuleProperties( + access access, + sourceCIDR string, + destinationCIDR string, + priority uint, + protocol armnetwork.SecurityRuleProtocol, + ports []string, +) *armnetwork.SecurityRulePropertiesFormat { + ruleAccess := armnetwork.SecurityRuleAccessAllow + if access == AccessDeny { + ruleAccess = armnetwork.SecurityRuleAccessDeny + } + + var destCIDR string + if destinationCIDR != "" { + destCIDR = destinationCIDR + } + + azurePriority := int32(priority) + + portRanges := make([]*string, len(ports)) + for i := range ports { + portRanges[i] = &ports[i] + } + + return &armnetwork.SecurityRulePropertiesFormat{ + Access: &ruleAccess, + Priority: &azurePriority, + SourceAddressPrefix: &sourceCIDR, + DestinationAddressPrefix: &destCIDR, + Protocol: &protocol, + DestinationPortRanges: portRanges, + } +} diff --git a/azure/accessControl/ruleSet.go b/azure/accessControl/ruleSet.go new file mode 100644 index 0000000..0d280b9 --- /dev/null +++ b/azure/accessControl/ruleSet.go @@ -0,0 +1,230 @@ +package accesscontrol + +import ( + "fmt" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + "github.com/app-net-interface/awi-infra-guard/connector/helper" + "github.com/app-net-interface/awi-infra-guard/types" +) + +// AccessControlRuleSet is a helper structure for +// preparing desired set of Network Security Group +// Rules that will be added to proper Network +// Security Groups. +type AccessControlRuleSet struct { + VPCRules []VPCRule + DirectedVPCRules []DirectedVPCRule + CustomRules []CustomRule +} + +func (a *AccessControlRuleSet) NewVPCRules( + name ruleName, + access access, + sourceCIDRs []string, +) { + for _, cidr := range sourceCIDRs { + a.VPCRules = append(a.VPCRules, VPCRule{ + namePrefix: name, + access: access, + sourceCIDR: cidr, + }) + } +} + +func (a *AccessControlRuleSet) NewDirectedVPCRules( + name ruleName, + access access, + sourceCIDRs []string, +) { + for _, cidr := range sourceCIDRs { + a.DirectedVPCRules = append(a.DirectedVPCRules, DirectedVPCRule{ + namePrefix: name, + access: access, + sourceCIDR: cidr, + }) + } +} + +func translateProtocolToAzureProtocol(protocol string) (armnetwork.SecurityRuleProtocol, error) { + p := strings.ToLower(protocol) + if p == strings.ToLower(string(armnetwork.SecurityRuleProtocolAh)) { + return armnetwork.SecurityRuleProtocolAh, nil + } + if p == strings.ToLower(string(armnetwork.SecurityRuleProtocolEsp)) { + return armnetwork.SecurityRuleProtocolEsp, nil + } + if p == strings.ToLower(string(armnetwork.SecurityRuleProtocolIcmp)) { + return armnetwork.SecurityRuleProtocolIcmp, nil + } + if p == strings.ToLower(string(armnetwork.SecurityRuleProtocolTCP)) { + return armnetwork.SecurityRuleProtocolTCP, nil + } + if p == strings.ToLower(string(armnetwork.SecurityRuleProtocolUDP)) { + return armnetwork.SecurityRuleProtocolUDP, nil + } + if p == strings.ToLower(string(armnetwork.SecurityRuleProtocolAsterisk)) || p == "" || p == "-1" { + return armnetwork.SecurityRuleProtocolAsterisk, nil + } + return armnetwork.SecurityRuleProtocol(""), fmt.Errorf( + "unsupported protocol '%s'. Azure supports following protocols: ["+ + "%s, %s, %s, %s, %s, %s]", + protocol, + string(armnetwork.SecurityRuleProtocolAh), + string(armnetwork.SecurityRuleProtocolEsp), + string(armnetwork.SecurityRuleProtocolIcmp), + string(armnetwork.SecurityRuleProtocolTCP), + string(armnetwork.SecurityRuleProtocolUDP), + string(armnetwork.SecurityRuleProtocolAsterisk), + ) +} + +func (a *AccessControlRuleSet) NewCustomRules( + name ruleName, + access access, + subnets []string, + sourceCIDRs []string, + destinationCIDRs []string, + protocolsAndPorts types.ProtocolsAndPorts, +) error { + // TODO: Verify if Azure Security Group Rule can refer to multiple + // source/destination CIDRs so that the entire slice below can be + // merged into single rule. + for _, sourceCIDR := range sourceCIDRs { + for _, destinationCIDR := range destinationCIDRs { + // TODO: Verify if these can be merged into single rule. + for protocol, ports := range protocolsAndPorts { + azProtocol, err := translateProtocolToAzureProtocol(protocol) + + if err != nil { + return fmt.Errorf( + "failed to translate given protocol %s: %w", + protocol, err, + ) + } + + a.CustomRules = append(a.CustomRules, CustomRule{ + namePrefix: name, + access: access, + sourceCIDR: sourceCIDR, + destinationCIDR: destinationCIDR, + subnets: subnets, + protocol: azProtocol, + ports: ports, + }) + } + } + } + + return nil +} + +func (a *AccessControlRuleSet) GenerateSecurityGroupRulesForVPC( + prioritiesInUse helper.Set[uint], +) ([]armnetwork.SecurityRule, error) { + return a.GenerateSecurityGroupRulesForSubnet(prioritiesInUse, nil) +} + +func (a *AccessControlRuleSet) GenerateSecurityGroupRulesForSubnet( + prioritiesInUse helper.Set[uint], + subnet *string, +) ([]armnetwork.SecurityRule, error) { + rules := a.collectAllRulesTogether(subnet) + + customRules := a.customRulesForSubnet(subnet) + + priorities, err := generatePriorities( + prioritiesInUse, + uint(len(a.VPCRules)), + uint(len(a.DirectedVPCRules)), + uint(len(customRules)), + ) + + if err != nil { + return nil, fmt.Errorf( + "failed to generate available priorities for rules: %w", err, + ) + } + + securityGroupRules := make([]armnetwork.SecurityRule, 0, len(rules)) + + for i := range rules { + priority, err := priorities.Next() + if err != nil { + return nil, fmt.Errorf( + "failed to pick a priority for a rule: %w", err, + ) + } + rule, err := rules[i].ToNetworkSecurityGroupRule(priority) + if err != nil { + return nil, fmt.Errorf( + "failed to generate rule: %w", err, + ) + } + + securityGroupRules = append(securityGroupRules, rule) + } + + return securityGroupRules, nil +} + +func (a *AccessControlRuleSet) RuleNamesForVPC() RuleNames { + names := helper.Set[ruleName]{} + rules := a.collectAllRulesTogether(nil) + + for _, rule := range rules { + names.Set(rule.Name()) + } + + return names.Keys() +} + +func (a *AccessControlRuleSet) RuleNamesForSubnet(subnet string) RuleNames { + names := helper.Set[ruleName]{} + rules := a.collectAllRulesTogether(&subnet) + + for _, rule := range rules { + names.Set(rule.Name()) + } + + return names.Keys() +} + +func (a *AccessControlRuleSet) customRulesForSubnet(subnet *string) []rule { + if subnet == nil { + return nil + } + + rules := []rule{} + + for _, rule := range a.CustomRules { + if rule.MatchesSubnet(*subnet) { + rules = append(rules, rule) + } + } + + return rules +} + +func (a *AccessControlRuleSet) collectAllRulesTogether( + subnet *string, +) []rule { + var customRulesForSubnet []rule + if subnet != nil { + customRulesForSubnet = a.customRulesForSubnet(subnet) + } + + rules := make( + []rule, 0, len(a.VPCRules)+len(a.DirectedVPCRules)+len(customRulesForSubnet), + ) + + for _, r := range a.VPCRules { + rules = append(rules, r) + } + for _, r := range a.DirectedVPCRules { + rules = append(rules, r) + } + + return append(rules, customRulesForSubnet...) +} diff --git a/azure/accessControl/securityGroup.go b/azure/accessControl/securityGroup.go new file mode 100644 index 0000000..daa14ba --- /dev/null +++ b/azure/accessControl/securityGroup.go @@ -0,0 +1,69 @@ +package accesscontrol + +import ( + "errors" + "fmt" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + "github.com/app-net-interface/awi-infra-guard/connector/helper" +) + +// RemoveRulesFromSecurityGroup goes through all Network Security Rules +// attached to that particular Security Group and removes these entries +// which match given Rule Names. +// +// The function does not send updating request - it prepares SecurityGroup +// object making it ready to pass to the Azure Client. +// +// The function returns a boolean informing if any rule was actually +// removed. If not, it is pointless to update the Security Group. +func RemoveRulesFromSecurityGroup( + nsg *armnetwork.SecurityGroup, rules RuleNames, +) (bool, error) { + if nsg == nil { + return false, errors.New( + "cannot remove rules from nil Network Security Group", + ) + } + if nsg.Properties == nil { + return false, fmt.Errorf( + "cannot remove rules from nil Network Security Group %s due to nil properties", + helper.StringPointerToString(nsg.ID), + ) + } + + securityRulesWithoutVPCRules := make([]*armnetwork.SecurityRule, 0, len(nsg.Properties.SecurityRules)) + + anyRuleRemoved := false + + for i := range nsg.Properties.SecurityRules { + if nsg.Properties.SecurityRules[i] == nil { + continue + } + if !securityRuleMatchesRuleGroup(nsg.Properties.SecurityRules[i], rules) { + // We want to preserve only rules that are not matching our expected name. + securityRulesWithoutVPCRules = append(securityRulesWithoutVPCRules, nsg.Properties.SecurityRules[i]) + continue + } + anyRuleRemoved = true + } + + if anyRuleRemoved { + nsg.Properties.SecurityRules = securityRulesWithoutVPCRules + } + + return anyRuleRemoved, nil +} + +func securityRuleMatchesRuleGroup(rule *armnetwork.SecurityRule, rules []ruleName) bool { + if rule == nil || rule.Name == nil { + return false + } + for _, r := range rules { + if strings.HasPrefix(*rule.Name, string(r)) { + return true + } + } + return false +} diff --git a/azure/azure.go b/azure/azure.go index cb45b4a..2042122 100644 --- a/azure/azure.go +++ b/azure/azure.go @@ -36,6 +36,7 @@ type ResourceClient struct { VNET armnetwork.VirtualNetworksClient VNETPeering armnetwork.VirtualNetworkPeeringsClient NSG armnetwork.SecurityGroupsClient + Subnet armnetwork.SubnetsClient Tag armresources.TagsClient } @@ -86,11 +87,23 @@ func NewResourceClient( "failed to create Tag Client. Got empty client", ) } + subnetClient, err := armnetwork.NewSubnetsClient(accountID, credentials, nil) + if err != nil { + return nil, fmt.Errorf( + "failed to create Subnet Client: %v", err, + ) + } + if subnetClient == nil { + return nil, errors.New( + "failed to create Subnet Client. Got empty client", + ) + } return &ResourceClient{ VNET: *vnetClient, VNETPeering: *vnetPeeringClient, NSG: *nsgClient, Tag: *tagClient, + Subnet: *subnetClient, }, nil } diff --git a/azure/helper.go b/azure/helper.go index 89e229a..f726026 100644 --- a/azure/helper.go +++ b/azure/helper.go @@ -174,6 +174,34 @@ func parseResourceGroupName(resourceID string) string { return "" } +// Returns true if the given string consists of '/' characters +// indicating the form of Azure Resource ID rather than being +// regular resource name. +// +// This method is to distinguish ResourceID and ResourceName +// whenever it is important to use any of those. +func isResourceID(id string) bool { + chunks := strings.SplitN(id, "/", 2) + return len(chunks) > 1 +} + +// Returns true if the given string consists of zero '/' +// characters indicating the form of Azure Resource name +// rather than being resource ID. +// +// This method is to distinguish ResourceID and ResourceName +// whenever it is important to use any of those. +// +// Using negation of "isResourceID" is not valid due to the +// possibility of empty string which should return false for +// both of these methods. +func isResourceName(id string) bool { + if id == "" { + return false + } + return !isResourceID(id) +} + func parseResourceName(resourceID string) string { parts := strings.Split(resourceID, "/") return parts[len(parts)-1] @@ -244,3 +272,22 @@ func extractLastSegment(resourceID string) string { segments := strings.Split(resourceID, "/") return segments[len(segments)-1] } + +// allLabelsMatch returns true if all required label keys are +// present within tags and if their values correspond to +// those from tags. +func allLabelsMatch(tags map[string]*string, requiredTags map[string]string) bool { + for k, v := range requiredTags { + actual, ok := tags[k] + if !ok { + return false + } + if actual == nil { + return false + } + if v != *actual { + return false + } + } + return true +} diff --git a/azure/instances.go b/azure/instances.go index add2800..563b89a 100644 --- a/azure/instances.go +++ b/azure/instances.go @@ -61,9 +61,13 @@ func (c *Client) ListInstances(ctx context.Context, input *infrapb.ListInstances return nil, fmt.Errorf("failed to get the next page of VMs: %w", err) } for _, vm := range vmResult.Value { - if vm.Properties.NetworkProfile == nil { + if vm.Properties == nil || vm.Properties.NetworkProfile == nil { continue } + if !allLabelsMatch(vm.Tags, input.Labels) { + continue + } + for _, nicRef := range vm.Properties.NetworkProfile.NetworkInterfaces { nic, err := nicClient.Get(ctx, parseResourceGroupName(*nicRef.ID), parseResourceName(*nicRef.ID), nil) //nic.Interface.Properties.NetworkSecurityGroup. diff --git a/azure/peering.go b/azure/peering.go index 07cf9b2..efd20d7 100644 --- a/azure/peering.go +++ b/azure/peering.go @@ -25,30 +25,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" ) -func (c *Client) getVnetPeering( - ctx context.Context, - accountID string, - resourceGroup string, - vnetName string, - vnetPeeringName string, -) (armnetwork.VirtualNetworkPeering, error) { - client, ok := c.accountClients[accountID] - if !ok { - return armnetwork.VirtualNetworkPeering{}, fmt.Errorf( - "account ID '%s' is not associated with any clients", accountID, - ) - } - peering, err := client.VNETPeering.Get( - ctx, resourceGroup, vnetName, vnetPeeringName, nil, - ) - if err != nil { - return armnetwork.VirtualNetworkPeering{}, fmt.Errorf( - "failed to get VNet Peering '%s': %w", vnetPeeringName, err, - ) - } - return peering.VirtualNetworkPeering, nil -} - func (c *Client) getVnetPeeringFromVnet(vnet armnetwork.VirtualNetwork, destVNetID string) string { if vnet.Properties == nil { return "" diff --git a/azure/securityGroup.go b/azure/securityGroup.go index 88bff48..b187018 100644 --- a/azure/securityGroup.go +++ b/azure/securityGroup.go @@ -19,11 +19,9 @@ package azure import ( "context" - "errors" "fmt" "strings" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" "github.com/app-net-interface/awi-infra-guard/connector/helper" @@ -112,31 +110,6 @@ func (c *Client) ListSecurityGroups(ctx context.Context, input *infrapb.ListSecu return secGroups, nil } -func (c *Client) deleteNetworkSecurityGroup( - ctx context.Context, - accountID string, - resourceGroup string, - vnetName string, - vnetPeeringName string, -) error { - client, ok := c.accountClients[accountID] - if !ok { - return fmt.Errorf( - "account ID '%s' is not associated with any clients", accountID, - ) - } - future, err := client.VNETPeering.BeginDelete(ctx, resourceGroup, vnetName, vnetPeeringName, nil) - if err != nil { - return fmt.Errorf("cannot delete VNet peering '%s': %w", vnetPeeringName, err) - } - if _, err = future.PollUntilDone(ctx, nil); err != nil { - return fmt.Errorf( - "cannot get the VNet Peering delete future response: %w", - err) - } - return nil -} - func (c *Client) getNSG(ctx context.Context, id, region string) ( armnetwork.SecurityGroup, string, error, ) { @@ -167,451 +140,6 @@ func (c *Client) getNSG(ctx context.Context, id, region string) ( ) } -func (c *Client) deleteVPCInboundFromSubnets( - ctx context.Context, - account string, - region string, - vnet armnetwork.VirtualNetwork, - connectionTag string, -) error { - if vnet.Properties == nil { - c.logger.Warnf( - "cannot update vnet subnets as vnet '%s' has no properties", - helper.StringPointerToString(vnet.ID), - ) - return nil - } - for i := range vnet.Properties.Subnets { - if vnet.Properties.Subnets[i] == nil || vnet.Properties.Subnets[i].Properties == nil { - continue - } - subnetProps := vnet.Properties.Subnets[i].Properties - if subnetProps.NetworkSecurityGroup == nil { - continue - } - err := c.deleteVPCInboundFromNSG( - ctx, - account, - region, - helper.StringPointerToString(subnetProps.NetworkSecurityGroup.ID), - connectionTag, - ) - if err != nil { - return fmt.Errorf( - "failed to remove Inbound VPC Policy from NSG '%s' from subnet '%s': %w", - helper.StringPointerToString(subnetProps.NetworkSecurityGroup.ID), - helper.StringPointerToString(vnet.Properties.Subnets[i].ID), - err, - ) - } - } - return nil -} - -func (c *Client) removeVPCPolicyRulesFromNSG( - nsg *armnetwork.SecurityGroup, - connectionTag string, -) error { - if nsg == nil || nsg.Properties == nil { - return errors.New( - "cannot remove VPC Policies as nsg is nil or has no properties", - ) - } - - vpcRuleName, ok := nsg.Tags[connectionTag] - if !ok { - return fmt.Errorf( - "cannot remove VPC Policy Rules from NSG '%s' as there is no "+ - "tag '%s' which would tell what NSG Rules belonged to that policy", - helper.StringPointerToString(nsg.ID), connectionTag, - ) - } - if vpcRuleName == nil { - return fmt.Errorf( - "cannot remove VPC Policy Rules from NSG '%s' as the value for "+ - "tag '%s' is nil", - helper.StringPointerToString(nsg.ID), connectionTag, - ) - } - - securityRulesWithoutVPCRules := make([]*armnetwork.SecurityRule, 0, len(nsg.Properties.SecurityRules)) - - for i := range nsg.Properties.SecurityRules { - if nsg.Properties.SecurityRules[i] == nil { - continue - } - ruleName, err := extractAwiNSGRuleName( - helper.StringPointerToString(nsg.Properties.SecurityRules[i].Name), - ) - if err != nil || ruleName != *vpcRuleName { - // Regular rules will not match our expected form so an error simply - // indicates it is a different rule. - // - // We want to preserve only rules that are not matching our expected name. - securityRulesWithoutVPCRules = append(securityRulesWithoutVPCRules, nsg.Properties.SecurityRules[i]) - } - } - - nsg.Properties.SecurityRules = securityRulesWithoutVPCRules - - return nil -} - -func (c *Client) deleteVPCInboundFromNSG( - ctx context.Context, - account string, - region string, - nsgID string, - connectionTag string, -) error { - nsg, account, err := c.getNSG(ctx, nsgID, region) - if err != nil { - return fmt.Errorf( - "failed to get NSG for update %s: %w", - nsgID, err, - ) - } - err = c.removeVPCPolicyRulesFromNSG( - &nsg, - connectionTag, - ) - if err != nil { - return fmt.Errorf( - "failed to update VPC Policy Rules for NSG '%s': %w", - nsgID, err, - ) - } - - if nsg.Tags != nil { - delete(nsg.Tags, connectionTag) - } - - // TODO: Delete Network Security Group if it was created by AWI and - // there are no other policies. - return c.createNetworkSecurityGroup( - ctx, - helper.StringPointerToString(nsg.Name), - region, - account, - parseResourceGroupName(helper.StringPointerToString(nsg.ID)), - nsg, - ) -} - -// refreshSubnetSecurityGroupWithVPCInbound checks if there is a Security -// Group created for Subnet - if there is no NSG, it will create one. -// -// After ensuring a Network Security Group exists, it will either create -// or update its rules to allow/block the inbound traffic from other VNet -// depending on the policy. If the policy is already properly configured, -// nothing will happen. -func (c *Client) refreshSubnetSecurityGroupWithVPCInbound( - ctx context.Context, - account string, - region string, - inboundCIDRs []string, - policy vpcPolicy, - vnet armnetwork.VirtualNetwork, - sourceVnetID string, - connectionTag string, -) error { - if vnet.Properties == nil { - c.logger.Warnf( - "cannot update vnet subnets as vnet '%s' has no properties", - helper.StringPointerToString(vnet.ID), - ) - return nil - } - for i := range vnet.Properties.Subnets { - if vnet.Properties.Subnets[i] == nil || vnet.Properties.Subnets[i].Properties == nil { - continue - } - subnetProps := vnet.Properties.Subnets[i].Properties - if subnetProps.NetworkSecurityGroup == nil { - err := c.createNewNetworkSecurityGroup( - ctx, - account, - region, - helper.StringPointerToString( - vnet.Properties.Subnets[i].ID, - ), - sourceVnetID, - inboundCIDRs, - policy, - connectionTag, - ) - if err != nil { - return fmt.Errorf( - "failed to create a new Network Security Group for a Subnet '%s' in VNet '%s'", - helper.StringPointerToString(vnet.Properties.Subnets[i].ID), - helper.StringPointerToString(vnet.ID), - ) - } - continue - } - err := c.updateNetworkSecurityGroup( - ctx, - region, - helper.StringPointerToString(subnetProps.NetworkSecurityGroup.ID), - sourceVnetID, - inboundCIDRs, - policy, - connectionTag, - ) - if err != nil { - return fmt.Errorf( - "failed to update Network Security Group for a Subnet '%s' in VNet '%s'", - helper.StringPointerToString(vnet.Properties.Subnets[i].ID), - helper.StringPointerToString(vnet.ID), - ) - } - } - return nil -} - -const ( - awiNSGRulesTagPrefix string = "awi-nsg-" - awiNSGNamePrefix string = "awi-nsg-" - awiNSGRulesNamePrefix string = "awi-nsg-vpc-rules-" - awiVPCRuleStartingPriority int32 = 2000 -) - -// func getAwiNSGNameTag(inboundVPC string) string { -// return awiNSGRulesTagPrefix + inboundVPC -// } - -func getAwiNSGName(subnetName, inboundVPC string) string { - return fmt.Sprintf( - "%s%s-%s", - awiNSGNamePrefix, - parseResourceName(subnetName), - parseResourceName(inboundVPC), - ) -} - -func getAwiNSGRuleName(inboundVPC string) string { - return awiNSGRulesNamePrefix + parseResourceName(inboundVPC) -} - -func getAwiNSGRuleNameWithIDSuffix(inboundVPC string, ruleID int) (string, error) { - if ruleID >= 10000 { - return "", fmt.Errorf( - "cannot generate NSG Rule with a ruleID with more than 4 digits: %d", ruleID, - ) - } - return awiNSGRulesNamePrefix + parseResourceName(inboundVPC) + fmt.Sprintf(":%04d", ruleID), nil -} - -func takenPriorities(sg armnetwork.SecurityGroup) helper.Set[int32] { - priorities := helper.Set[int32]{} - - if sg.Properties == nil { - return priorities - } - - for i := range sg.Properties.SecurityRules { - if sg.Properties.SecurityRules[i] == nil || sg.Properties.SecurityRules[i].Properties == nil { - continue - } - if sg.Properties.SecurityRules[i].Properties.Priority != nil { - priorities.Set(*sg.Properties.SecurityRules[i].Properties.Priority) - } - } - - return priorities -} - -func (c *Client) deletePreviousVPCPolicyRules( - nsg *armnetwork.SecurityGroup, - sourceVnetID string, -) { - if nsg == nil || nsg.Properties == nil { - return - } - - securityRulesWithoutVPCRules := make([]*armnetwork.SecurityRule, 0, len(nsg.Properties.SecurityRules)) - - for i := range nsg.Properties.SecurityRules { - if nsg.Properties.SecurityRules[i] == nil { - continue - } - if strings.HasPrefix(helper.StringPointerToString(nsg.Properties.SecurityRules[i].Name)) { - // We want to preserve only rules that are not matching our expected name. - securityRulesWithoutVPCRules = append(securityRulesWithoutVPCRules, nsg.Properties.SecurityRules[i]) - } - } - - nsg.Properties.SecurityRules = securityRulesWithoutVPCRules -} - -func (c *Client) addVPCPolicyRulesToNSG( - nsg *armnetwork.SecurityGroup, - inboundCIDRs []string, - sourceVnetID string, - policy vpcPolicy, - connectionTag string, -) error { - if nsg == nil { - return errors.New("cannot add VNet Policy rules to nil NSG") - } - - c.deletePreviousVPCPolicyRules(nsg, sourceVnetID) - - prioritiesInUse := takenPriorities(*nsg) - - securityRules := make([]*armnetwork.SecurityRule, 0, len(inboundCIDRs)) - currentPriority := awiVPCRuleStartingPriority - - access := armnetwork.SecurityRuleAccessDeny - if policy == vpcPolicyAllow { - access = armnetwork.SecurityRuleAccessAllow - } - - ruleID := 1 - - for _, cidr := range inboundCIDRs { - // Priorities cannot overlap with already existing - // rules so find first non used priority value. - for prioritiesInUse.Has(currentPriority) { - currentPriority++ - if currentPriority > 4096 { - return fmt.Errorf( - "cannot attach VNet Policy rules to NSG '%s' as there are no more "+ - "available rule slots", - helper.StringPointerToString(nsg.ID), - ) - } - } - - ruleName, err := getAwiNSGRuleNameWithIDSuffix(sourceVnetID, ruleID) - if err != nil { - return fmt.Errorf( - "failed to prepare a rule name for NSG '%s': %v", - helper.StringPointerToString(nsg.ID), - err, - ) - } - - securityRules = append( - securityRules, - &armnetwork.SecurityRule{ - Name: to.Ptr(ruleName), - Properties: &armnetwork.SecurityRulePropertiesFormat{ - Priority: to.Ptr(currentPriority), - Protocol: to.Ptr(armnetwork.SecurityRuleProtocolAsterisk), - SourceAddressPrefix: to.Ptr(cidr), - DestinationAddressPrefix: to.Ptr("*"), - Access: to.Ptr(access), - Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), - SourcePortRange: to.Ptr("*"), - DestinationPortRange: to.Ptr("*"), - }, - }, - ) - currentPriority++ - } - - nsg.Properties.SecurityRules = securityRules - - return nil -} - -func (c *Client) updateSomethingNetworkSecurityGroup( - ctx context.Context, - region string, - nsgID string, - sourceVnetID string, - inboundCIDRs []string, - policy vpcPolicy, - connectionTag string, -) error { - nsg, account, err := c.getNSG(ctx, nsgID, region) - if err != nil { - return fmt.Errorf( - "failed to get NSG for update %s: %w", - nsgID, err, - ) - } - err = c.addVPCPolicyRulesToNSG( - &nsg, - inboundCIDRs, - sourceVnetID, - policy, - connectionTag, - ) - if err != nil { - return fmt.Errorf( - "failed to update VPC Policy Rules for NSG '%s': %w", - nsgID, err, - ) - } - - if nsg.Tags == nil { - nsg.Tags = map[string]*string{ - connectionTag: to.Ptr(getAwiNSGRuleName(sourceVnetID)), - } - } else { - nsg.Tags[connectionTag] = to.Ptr(getAwiNSGRuleName(sourceVnetID)) - } - - return c.createNetworkSecurityGroup( - ctx, - helper.StringPointerToString(nsg.Name), - region, - account, - parseResourceGroupName(helper.StringPointerToString(nsg.ID)), - nsg, - ) -} - -func (c *Client) createNewNetworkSecurityGroup( - ctx context.Context, - account string, - region string, - subnetID string, - sourceVnetID string, - inboundCIDRs []string, - policy vpcPolicy, - connectionTag string, -) error { - ngsName := getAwiNSGName(subnetID, sourceVnetID) - - nsg := armnetwork.SecurityGroup{ - Name: to.Ptr(ngsName), - Location: ®ion, - Tags: map[string]*string{ - connectionTag: to.Ptr(getAwiNSGRuleName(sourceVnetID)), - }, - Properties: &armnetwork.SecurityGroupPropertiesFormat{ - Subnets: []*armnetwork.Subnet{ - { - ID: &subnetID, - }, - }, - }, - } - - c.addVPCPolicyRulesToNSG( - &nsg, inboundCIDRs, sourceVnetID, policy, connectionTag, - ) - - resGroup := parseResourceGroupName(sourceVnetID) - if resGroup == "" { - return fmt.Errorf( - "failed to process Resource Group from Resource ID '%s'", sourceVnetID, - ) - } - - return c.createNetworkSecurityGroup( - ctx, - ngsName, - region, - account, - resGroup, - nsg, - ) -} - func (c *Client) updateNetworkSecurityGroup( ctx context.Context, account string, @@ -648,6 +176,8 @@ func (c *Client) createNetworkSecurityGroup( ) } + sg.Location = &location + future, err := client.NSG.BeginCreateOrUpdate( ctx, resourceGroup, @@ -668,130 +198,3 @@ func (c *Client) createNetworkSecurityGroup( return nil } - -func peeringBlockNSGTag(peeredVNetID string) string { - return "awi-vpc-block-" + peeredVNetID -} - -// addPeeringBlockRulesToNSG creates rules to prevent -// the traffic from VNet to the subnet associated with -// this NSG which is presumably in other VNet peered -// with that VNet. -func (c *Client) setPeeringBlockRulesToNSG( - ctx context.Context, - nsg *armnetwork.SecurityGroup, - peeredVNet armnetwork.VirtualNetwork, -) { - if nsg == nil || nsg.Properties == nil { - c.logger.Warnf( - "cannot set peering block rules for a Network Security Group. " + - "The NSG is nil or its properties are nil", - ) - return - } - peeredVNetID := helper.StringPointerToString(peeredVNet.ID) - if _, ok := nsg.Tags[peeredVNetID]; ok { - c.logger.Infof( - "Rules for blocking the traffic from VNet %s to Subnet with NSG %s "+ - "already exist. Nothing to do", - peeredVNetID, helper.StringPointerToString(nsg.ID), - ) - return - } - - vnetCIDRs := c.getVNetCIDRs(peeredVNet) - -} - -func (c *Client) addRulesWithNamePrefixToNSG( - nsg *armnetwork.SecurityGroup, - namePrefix string, - addresses []string, - allow bool, - priorityRangeStart int32, - priorityRangeEnd int32, -) error { - prioritiesInUse := takenPriorities(*nsg) - - securityRules := make([]*armnetwork.SecurityRule, 0, len(addresses)) - - currentPriority := priorityRangeStart - - access := armnetwork.SecurityRuleAccessDeny - if allow { - access = armnetwork.SecurityRuleAccessAllow - } - - ruleID := 1 - - for _, cidr := range addresses { - // Priorities cannot overlap with already existing - // rules so find first non used priority value. - for prioritiesInUse.Has(currentPriority) { - currentPriority++ - if currentPriority > priorityRangeEnd { - return fmt.Errorf( - "cannot attach VNet Policy rules to NSG '%s' as there are no more "+ - "available rule slots in the range %d:%d", - helper.StringPointerToString(nsg.ID), - priorityRangeStart, - priorityRangeEnd, - ) - } - } - - ruleName, err := getAwiNSGRuleNameWithIDSuffix(namePrefix, ruleID) - if err != nil { - return fmt.Errorf( - "failed to prepare a rule name for NSG '%s': %v", - helper.StringPointerToString(nsg.ID), - err, - ) - } - - securityRules = append( - securityRules, - &armnetwork.SecurityRule{ - Name: to.Ptr(ruleName), - Properties: &armnetwork.SecurityRulePropertiesFormat{ - Priority: to.Ptr(currentPriority), - Protocol: to.Ptr(armnetwork.SecurityRuleProtocolAsterisk), - SourceAddressPrefix: to.Ptr(cidr), - DestinationAddressPrefix: to.Ptr("*"), - Access: to.Ptr(access), - Direction: to.Ptr(armnetwork.SecurityRuleDirectionInbound), - SourcePortRange: to.Ptr("*"), - DestinationPortRange: to.Ptr("*"), - }, - }, - ) - currentPriority++ - ruleID++ - } - - nsg.Properties.SecurityRules = append(nsg.Properties.SecurityRules, securityRules...) - return nil -} - -func (c *Client) deleteRulesWithNamePrefixFromNSG( - nsg *armnetwork.SecurityGroup, - namePrefix string, -) { - if nsg == nil || nsg.Properties == nil { - return - } - - securityRulesWithoutVPCRules := make([]*armnetwork.SecurityRule, 0, len(nsg.Properties.SecurityRules)) - - for i := range nsg.Properties.SecurityRules { - if nsg.Properties.SecurityRules[i] == nil { - continue - } - if strings.HasPrefix(helper.StringPointerToString(nsg.Properties.SecurityRules[i].Name)) { - // We want to preserve only rules that are not matching our expected name. - securityRulesWithoutVPCRules = append(securityRulesWithoutVPCRules, nsg.Properties.SecurityRules[i]) - } - } - - nsg.Properties.SecurityRules = securityRulesWithoutVPCRules -} diff --git a/azure/securityGroupAccessControl.go b/azure/securityGroupAccessControl.go new file mode 100644 index 0000000..57dcf86 --- /dev/null +++ b/azure/securityGroupAccessControl.go @@ -0,0 +1,201 @@ +// Copyright (c) 2024 Cisco Systems, Inc. and its affiliates +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http:www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package azure + +import ( + "context" + "crypto/sha256" + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + accesscontrol "github.com/app-net-interface/awi-infra-guard/azure/accessControl" + "github.com/app-net-interface/awi-infra-guard/connector/helper" +) + +// ApplyAccessRulesToNSG creates access rules from the provided +// AccessControlRuleSet. This operation is preceeded with the +// removal of older entries associated with the same name to +// avoid rule duplication and to refresh possibly stale rules. +func (c *Client) ApplyAccessRulesToNSG( + ctx context.Context, + account string, + region string, + nsg armnetwork.SecurityGroup, + rules accesscontrol.AccessControlRuleSet, +) error { + if nsg.Properties == nil { + return fmt.Errorf( + "cannot add Access Rules to Network Security Group %s as it has no properties", + helper.StringPointerToString(nsg.ID), + ) + } + + _, err := accesscontrol.RemoveRulesFromSecurityGroup(&nsg, rules.RuleNamesForVPC()) + if err != nil { + return fmt.Errorf( + "failed to remove rules '%v' from Network Security Group '%s': %w", + rules, helper.StringPointerToString(nsg.ID), err, + ) + } + + prioritiesInUse := takenPriorities(nsg) + + securityRules, err := rules.GenerateSecurityGroupRulesForVPC( + prioritiesInUse, + ) + if err != nil { + return fmt.Errorf( + "failed to generate Security Rules: %w", err, + ) + } + + if nsg.Properties.SecurityRules == nil { + nsg.Properties.SecurityRules = make([]*armnetwork.SecurityRule, 0, len(securityRules)) + } + + for i := range securityRules { + nsg.Properties.SecurityRules = append( + nsg.Properties.SecurityRules, + &securityRules[i], + ) + } + + err = c.updateNetworkSecurityGroup( + ctx, + account, + nsg, + ) + if err != nil { + return fmt.Errorf( + "failed to update Network Security Group '%s' with new Security Group Rules: %w", + helper.StringPointerToString(nsg.ID), err, + ) + } + + return nil +} + +func (c *Client) DeleteAccessRulesFromNSG( + ctx context.Context, + account string, + region string, + nsg armnetwork.SecurityGroup, + rules accesscontrol.RuleNames, +) error { + if nsg.Properties == nil { + return fmt.Errorf( + "cannot add Access Rules to Network Security Group %s as it has no properties", + helper.StringPointerToString(nsg.ID), + ) + } + + changed, err := accesscontrol.RemoveRulesFromSecurityGroup(&nsg, rules) + if err != nil { + return fmt.Errorf( + "failed to remove rules '%v' from Network Security Group '%s': %w", + rules, helper.StringPointerToString(nsg.ID), err, + ) + } + if !changed { + c.logger.Debugf( + "Network Security Group %s had no rules associated with names "+ + "%v. Update is redundant", + helper.StringPointerToString(nsg.ID), rules, + ) + return nil + } + + err = c.updateNetworkSecurityGroup( + ctx, + account, + nsg, + ) + if err != nil { + return fmt.Errorf( + "failed to update Network Security Group '%s' with removed Security Group Rules: %w", + helper.StringPointerToString(nsg.ID), err, + ) + } + + return nil +} + +func awiNSGName(subnetID string) string { + hasher := sha256.New() + hasher.Write([]byte(subnetID)) + hashBytes := hasher.Sum(nil) + + return fmt.Sprintf("awi-nsg-%x", hashBytes) +} + +func (c *Client) createAWINetworkSecurityGroup( + ctx context.Context, + account string, + region string, + subnetID string, +) error { + ngsName := awiNSGName(subnetID) + + nsg := armnetwork.SecurityGroup{ + Name: to.Ptr(ngsName), + Location: ®ion, + Properties: &armnetwork.SecurityGroupPropertiesFormat{ + Subnets: []*armnetwork.Subnet{ + { + ID: &subnetID, + }, + }, + }, + } + + resGroup := parseResourceGroupName(subnetID) + if resGroup == "" { + return fmt.Errorf( + "failed to process Resource Group from Resource ID '%s'", subnetID, + ) + } + + return c.createNetworkSecurityGroup( + ctx, + ngsName, + region, + account, + resGroup, + nsg, + ) +} + +func takenPriorities(sg armnetwork.SecurityGroup) helper.Set[uint] { + priorities := helper.Set[uint]{} + + if sg.Properties == nil { + return priorities + } + + for i := range sg.Properties.SecurityRules { + if sg.Properties.SecurityRules[i] == nil || sg.Properties.SecurityRules[i].Properties == nil { + continue + } + if sg.Properties.SecurityRules[i].Properties.Priority != nil { + priorities.Set(uint(*sg.Properties.SecurityRules[i].Properties.Priority)) + } + } + + return priorities +} diff --git a/azure/subnetAccessControl.go b/azure/subnetAccessControl.go new file mode 100644 index 0000000..66b97a7 --- /dev/null +++ b/azure/subnetAccessControl.go @@ -0,0 +1,153 @@ +// Copyright (c) 2024 Cisco Systems, Inc. and its affiliates +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http:www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package azure + +import ( + "context" + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + accesscontrol "github.com/app-net-interface/awi-infra-guard/azure/accessControl" + "github.com/app-net-interface/awi-infra-guard/connector/helper" +) + +func (c *Client) ApplyAccessRulesToSubnet( + ctx context.Context, + account string, + region string, + subnet armnetwork.Subnet, + rules accesscontrol.AccessControlRuleSet, +) error { + if subnet.Properties == nil { + c.logger.Warnf( + "cannot update subnet '%s' as it has no properties", + helper.StringPointerToString(subnet.ID), + ) + return nil + } + + // TODO: This situation may happen if a new Subnet + // was created during processing this method. AWI ensures + // first that all existing subnets are associated with any + // NSG but this is something that can happen anyway. + // + // Solving this is a part of greater issue - updating subnets + // that were created after creating a connection. + if subnet.Properties.NetworkSecurityGroup == nil { + c.logger.Warnf( + "the subnet '%s' has no NSG network attached. "+ + "Cannot attach following rules there: %v", + helper.StringPointerToString(subnet.ID), rules, + ) + return nil + } + + nsg, nsgAcc, err := c.getNSG( + ctx, + helper.StringPointerToString(subnet.Properties.NetworkSecurityGroup.ID), + helper.StringPointerToString(®ion), + ) + if err != nil { + return fmt.Errorf( + "cannot get Network Security Group %s associated with subnet %s: %w", + helper.StringPointerToString(subnet.Properties.NetworkSecurityGroup.ID), + helper.StringPointerToString(subnet.ID), + err, + ) + } + + err = c.ApplyAccessRulesToNSG( + ctx, + nsgAcc, + region, + nsg, + rules, + ) + if err != nil { + return fmt.Errorf( + "failed to apply rules %v to associated Network Security Group %s: %w", + rules, + helper.StringPointerToString(subnet.Properties.NetworkSecurityGroup.ID), + err, + ) + } + + return nil +} + +func (c *Client) DeleteAccessRulesFromSubnet( + ctx context.Context, + region string, + subnet armnetwork.Subnet, + rules accesscontrol.RuleNames, +) error { + if subnet.Properties == nil { + c.logger.Warnf( + "cannot update subnet '%s' as it has no properties", + helper.StringPointerToString(subnet.ID), + ) + return nil + } + + // TODO: This situation may happen if a new Subnet + // was created during processing this method. AWI ensures + // first that all existing subnets are associated with any + // NSG but this is something that can happen anyway. + // + // Solving this is a part of greater issue - updating subnets + // that were created after creating a connection. + if subnet.Properties.NetworkSecurityGroup == nil { + c.logger.Warnf( + "the subnet '%s' has no NSG network attached. "+ + "Cannot attach following rules there: %v", + helper.StringPointerToString(subnet.ID), rules, + ) + return nil + } + + nsg, nsgAcc, err := c.getNSG( + ctx, + helper.StringPointerToString(subnet.Properties.NetworkSecurityGroup.ID), + helper.StringPointerToString(®ion), + ) + if err != nil { + return fmt.Errorf( + "cannot get Network Security Group %s associated with subnet %s: %w", + helper.StringPointerToString(subnet.Properties.NetworkSecurityGroup.ID), + helper.StringPointerToString(subnet.ID), + err, + ) + } + err = c.DeleteAccessRulesFromNSG( + ctx, + nsgAcc, + region, + nsg, + rules, + ) + if err != nil { + return fmt.Errorf( + "failed to delete rules %v from associated Network Security Group %s: %w", + rules, + helper.StringPointerToString(subnet.Properties.NetworkSecurityGroup.ID), + err, + ) + } + + return nil +} diff --git a/azure/subnets.go b/azure/subnets.go index 254a4f2..9493379 100644 --- a/azure/subnets.go +++ b/azure/subnets.go @@ -92,6 +92,25 @@ func (c *Client) ListSubnets(ctx context.Context, params *infrapb.ListSubnetsReq return subnets, nil } +func (c *Client) getSubnet(ctx context.Context, resGroupName, vnetName, subnetName string) ( + armnetwork.Subnet, string, error, +) { + for account, client := range c.accountClients { + subnet, err := client.Subnet.Get( + ctx, resGroupName, vnetName, subnetName, nil, + ) + if err != nil { + return armnetwork.Subnet{}, account, fmt.Errorf( + "failed to get Subnet '%s:%s': %w", vnetName, subnetName, err) + } + return subnet.Subnet, account, nil + } + + return armnetwork.Subnet{}, "", fmt.Errorf( + "subnet '%s:%s' not found", vnetName, subnetName, + ) +} + // Helper function to extract the resource group name from a resource ID func ExtractResourceGroupName(resourceID string) string { parts := strings.Split(resourceID, "/") diff --git a/azure/vpc.go b/azure/vpc.go index 1d91a42..36ce0b4 100644 --- a/azure/vpc.go +++ b/azure/vpc.go @@ -20,9 +20,9 @@ package azure import ( "context" "fmt" - "strings" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + "github.com/app-net-interface/awi-infra-guard/connector/helper" "github.com/app-net-interface/awi-infra-guard/grpc/go/infrapb" "github.com/app-net-interface/awi-infra-guard/types" ) @@ -117,16 +117,6 @@ func (c *Client) getVPC(ctx context.Context, id, region string) ( ) } -// VPCConnector interface implementation -func (c *Client) ConnectVPC(ctx context.Context, input types.SingleVPCConnectionParams) (types.SingleVPCConnectionOutput, error) { - - return types.SingleVPCConnectionOutput{}, nil -} - -func getVPCSGName(connectionName string) string { - return fmt.Sprintf("awi-%s-sg", strings.Replace(connectionName, " ", "-", -1)) -} - func (c *Client) getVNetCIDRs(vnet armnetwork.VirtualNetwork) []string { if vnet.Properties == nil { c.logger.Infof( @@ -153,133 +143,37 @@ func (c *Client) getVNetCIDRs(vnet armnetwork.VirtualNetwork) []string { return prefixes } -func (c *Client) createDenyingVPCPolicy( +func (c *Client) EnsureEverySubnetInVPCHasNetworkSecurityGroup( ctx context.Context, account string, - region string, - sourceVNET armnetwork.VirtualNetwork, - destinationVNET armnetwork.VirtualNetwork, - trafficFromSourceAllowed bool, - connectionName string, + vnet armnetwork.VirtualNetwork, ) error { - destinationPrefixes := c.getVNetCIDRs(destinationVNET) - - err = c.refreshSubnetSecurityGroupWithVPCInbound( - ctx, - account, - region, - destinationPrefixes, - vpcPolicyAllow, - vnet, - sourceID, - ruleName, - ) - if err != nil { - return fmt.Errorf( - "failed to refresh Security Groups for subnets from VNet %s: %w", - destinationVpcID, err, - ) - } -} - -func (c *Client) ConnectVPCs(ctx context.Context, input types.VPCConnectionParams) (types.VPCConnectionOutput, error) { - vnet1, accountID1, err := c.getVPC(ctx, input.Vpc1ID, input.Region1) - if err != nil { - return types.VPCConnectionOutput{}, fmt.Errorf( - "failed to get VPC '%s' in region '%s'", input.Vpc1ID, input.Region1, - ) - } - vnet2, accountID2, err := c.getVPC(ctx, input.Vpc2ID, input.Region2) - if err != nil { - return types.VPCConnectionOutput{}, fmt.Errorf( - "failed to get VPC '%s' in region '%s'", input.Vpc2ID, input.Region2, + if vnet.Properties == nil { + c.logger.Warnf( + "cannot update vnet subnets as vnet '%s' has no properties", + helper.StringPointerToString(vnet.ID), ) + return nil } - - if !input.AllowAllTraffic { - err = c.createDenyingVPCPolicy(ctx, input.ConnName) - if err != nil { - return types.VPCConnectionOutput{}, fmt.Errorf( - "failed to create blocking policy across VPCs %s:%s due to %w", - input.Region1, *vnet1.ID, input.Region2, *vnet2.ID, err, + for i := range vnet.Properties.Subnets { + if vnet.Properties.Subnets[i] == nil || vnet.Properties.Subnets[i].Properties == nil { + continue + } + subnetProps := vnet.Properties.Subnets[i].Properties + if subnetProps.NetworkSecurityGroup == nil { + err := c.createAWINetworkSecurityGroup( + ctx, + account, + helper.StringPointerToString(vnet.Location), + helper.StringPointerToString(vnet.Properties.Subnets[i].ID), ) + if err != nil { + return fmt.Errorf( + "failed to create AWI Network Security Group for Subnet %s: %w", + helper.StringPointerToString(vnet.Properties.Subnets[i].ID), err, + ) + } } } - - if err = c.createVnetPeering(ctx, *vnet1.ID, *vnet2.ID, accountID1); err != nil { - return types.VPCConnectionOutput{}, fmt.Errorf( - "failed to create a VPC Peering from %s:%s to %s:%s due to %w", - input.Region1, *vnet1.ID, input.Region2, *vnet2.ID, err, - ) - } - - if err = c.createVnetPeering(ctx, *vnet2.ID, *vnet1.ID, accountID2); err != nil { - return types.VPCConnectionOutput{}, fmt.Errorf( - "failed to create a VPC Peering from %s:%s to %s:%s due to %w", - input.Region2, *vnet2.ID, input.Region1, *vnet1.ID, err, - ) - } - - return types.VPCConnectionOutput{ - Region1: input.Region1, - Region2: input.Region2, - }, nil -} - -func (c *Client) DisconnectVPC(ctx context.Context, input types.SingleVPCDisconnectionParams) (types.VPCDisconnectionOutput, error) { - // TBD - return types.VPCDisconnectionOutput{}, nil -} - -// func getNSGNameForPeeredVPCs(sourceVNET, destinationVNET string) string { - -// } - -func (c *Client) DisconnectVPCs(ctx context.Context, input types.VPCDisconnectionParams) (types.VPCDisconnectionOutput, error) { - vnet1, accountID1, err := c.getVPC(ctx, input.Vpc1ID, input.Region1) - if err != nil { - return types.VPCDisconnectionOutput{}, fmt.Errorf( - "failed to get VPC '%s' in region '%s'", input.Vpc1ID, input.Region1, - ) - } - vnet2, accountID2, err := c.getVPC(ctx, input.Vpc2ID, input.Region2) - if err != nil { - return types.VPCDisconnectionOutput{}, fmt.Errorf( - "failed to get VPC '%s' in region '%s'", input.Vpc2ID, input.Region2, - ) - } - - peering1 := c.getVnetPeeringFromVnet(vnet1, *vnet2.ID) - if peering1 != "" { - c.deleteVnetPeering( - ctx, - accountID1, - parseResourceGroupName(*vnet1.ID), - *vnet1.Name, - vnetPeeringName(*vnet1.ID, *vnet2.ID), - ) - } else { - c.logger.Infof( - "VNet Peering %s not found. Skipping it", - vnetPeeringName(*vnet1.ID, *vnet2.ID), - ) - } - - peering2 := c.getVnetPeeringFromVnet(vnet2, *vnet1.ID) - if peering2 != "" { - c.deleteVnetPeering( - ctx, - accountID2, - parseResourceGroupName(*vnet2.ID), - *vnet2.Name, - vnetPeeringName(*vnet2.ID, *vnet1.ID), - ) - } else { - c.logger.Infof( - "VNet Peering %s not found. Skipping it", - vnetPeeringName(*vnet2.ID, *vnet1.ID), - ) - } - - return types.VPCDisconnectionOutput{}, nil + return nil } diff --git a/azure/vpcAccessControl.go b/azure/vpcAccessControl.go new file mode 100644 index 0000000..ff4ca92 --- /dev/null +++ b/azure/vpcAccessControl.go @@ -0,0 +1,228 @@ +// Copyright (c) 2024 Cisco Systems, Inc. and its affiliates +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http:www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package azure + +import ( + "context" + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + accesscontrol "github.com/app-net-interface/awi-infra-guard/azure/accessControl" + "github.com/app-net-interface/awi-infra-guard/connector/helper" +) + +func (c *Client) ApplyAccessRulesToVPC( + ctx context.Context, + account string, + vnet armnetwork.VirtualNetwork, + rules accesscontrol.AccessControlRuleSet, +) error { + if vnet.Properties == nil { + c.logger.Warnf( + "cannot update vnet subnets as vnet '%s' has no properties", + helper.StringPointerToString(vnet.ID), + ) + return nil + } + + err := c.EnsureEverySubnetInVPCHasNetworkSecurityGroup( + ctx, account, vnet, + ) + if err != nil { + return fmt.Errorf( + "failed to set NSG for every subnet in VNET %s: %w", + helper.StringPointerToString(vnet.ID), + err, + ) + } + + for i := range vnet.Properties.Subnets { + subnet := vnet.Properties.Subnets[i] + if subnet == nil { + continue + } + + err := c.ApplyAccessRulesToSubnet( + ctx, + account, + helper.StringPointerToString(vnet.Location), + *subnet, + rules, + ) + if err != nil { + return fmt.Errorf( + "failed to apply rules %v to subnet %s: %w", + rules, helper.StringPointerToString(subnet.ID), err, + ) + } + } + + return nil +} + +func (c *Client) DeleteAccessRulesFromVPC( + ctx context.Context, + vnet armnetwork.VirtualNetwork, + rules accesscontrol.RuleNames, +) error { + if vnet.Properties == nil { + c.logger.Warnf( + "cannot update vnet subnets as vnet '%s' has no properties", + helper.StringPointerToString(vnet.ID), + ) + return nil + } + + for i := range vnet.Properties.Subnets { + subnet := vnet.Properties.Subnets[i] + if subnet == nil { + continue + } + + err := c.DeleteAccessRulesFromSubnet( + ctx, + helper.StringPointerToString(vnet.Location), + *subnet, + rules, + ) + if err != nil { + return fmt.Errorf( + "failed to apply rules %v to subnet %s: %w", + rules, helper.StringPointerToString(subnet.ID), err, + ) + } + } + + return nil +} + +func (c *Client) BlockTrafficFromVPCToVPC( + ctx context.Context, + destinationVNET armnetwork.VirtualNetwork, + sourceVNET armnetwork.VirtualNetwork, + destinationVNETAccount string, +) error { + ruleName := accesscontrol.VPCRuleName( + helper.StringPointerToString(destinationVNET.ID), + helper.StringPointerToString(sourceVNET.ID), + ) + + cidrsToBlock := c.getVNetCIDRs(sourceVNET) + + ruleset := accesscontrol.AccessControlRuleSet{} + ruleset.NewDirectedVPCRules( + ruleName, + accesscontrol.AccessDeny, + cidrsToBlock, + ) + + err := c.ApplyAccessRulesToVPC( + ctx, + destinationVNETAccount, + destinationVNET, + ruleset, + ) + if err != nil { + return fmt.Errorf( + "failed to apply VPC Access rules %v to VNET %s: %w", + ruleset, + helper.StringPointerToString(destinationVNET.ID), + err, + ) + } + + return nil +} + +func (c *Client) UnblockTrafficFromVPCToVPC( + ctx context.Context, + destinationVNET armnetwork.VirtualNetwork, + sourceVNET armnetwork.VirtualNetwork, + destinationVNETAccount string, +) error { + ruleName := accesscontrol.VPCRuleName( + helper.StringPointerToString(destinationVNET.ID), + helper.StringPointerToString(sourceVNET.ID), + ) + + err := c.DeleteAccessRulesFromVPC( + ctx, + destinationVNET, + accesscontrol.RuleNames{ + ruleName, + }, + ) + if err != nil { + return fmt.Errorf( + "failed to delete VPC Access rule %v from VNET %s: %w", + ruleName, + helper.StringPointerToString(destinationVNET.ID), + err, + ) + } + + return nil +} + +func (c *Client) BlockTrafficBetweenVPCs( + ctx context.Context, + vnet1, vnet2 armnetwork.VirtualNetwork, + acc1, acc2 string, +) error { + if err := c.BlockTrafficFromVPCToVPC(ctx, vnet1, vnet2, acc1); err != nil { + return fmt.Errorf( + "failed to block traffic from VNET %s to VNET %s: %w", + helper.StringPointerToString(vnet2.ID), + helper.StringPointerToString(vnet1.ID), + err, + ) + } + if err := c.BlockTrafficFromVPCToVPC(ctx, vnet2, vnet1, acc2); err != nil { + return fmt.Errorf( + "failed to block traffic from VNET %s to VNET %s: %w", + helper.StringPointerToString(vnet1.ID), + helper.StringPointerToString(vnet2.ID), + err, + ) + } + return nil +} + +func (c *Client) UnblockTrafficBetweenVPCs( + ctx context.Context, + vnet1, vnet2 armnetwork.VirtualNetwork, + acc1, acc2 string, +) error { + if err := c.UnblockTrafficFromVPCToVPC(ctx, vnet1, vnet2, acc1); err != nil { + return fmt.Errorf( + "failed to unblock traffic from VNET %s to VNET %s: %w", + helper.StringPointerToString(vnet2.ID), + helper.StringPointerToString(vnet1.ID), + err, + ) + } + if err := c.UnblockTrafficFromVPCToVPC(ctx, vnet2, vnet1, acc2); err != nil { + return fmt.Errorf( + "failed to unblock traffic from VNET %s to VNET %s: %w", + helper.StringPointerToString(vnet1.ID), + helper.StringPointerToString(vnet2.ID), + err, + ) + } + return nil +} diff --git a/azure/vpcConnection.go b/azure/vpcConnection.go new file mode 100644 index 0000000..05081d6 --- /dev/null +++ b/azure/vpcConnection.go @@ -0,0 +1,140 @@ +// Copyright (c) 2024 Cisco Systems, Inc. and its affiliates +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http:www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package azure + +import ( + "context" + "fmt" + + "github.com/app-net-interface/awi-infra-guard/types" +) + +// VPCConnector interface implementation +func (c *Client) ConnectVPC(ctx context.Context, input types.SingleVPCConnectionParams) (types.SingleVPCConnectionOutput, error) { + + return types.SingleVPCConnectionOutput{}, nil +} + +func (c *Client) ConnectVPCs(ctx context.Context, input types.VPCConnectionParams) (types.VPCConnectionOutput, error) { + input.Region1 = "westus2" + input.Region2 = "westus2" + + vnet1, accountID1, err := c.getVPC(ctx, input.Vpc1ID, input.Region1) + if err != nil { + return types.VPCConnectionOutput{}, fmt.Errorf( + "failed to get VPC '%s' in region '%s'", input.Vpc1ID, input.Region1, + ) + } + vnet2, accountID2, err := c.getVPC(ctx, input.Vpc2ID, input.Region2) + if err != nil { + return types.VPCConnectionOutput{}, fmt.Errorf( + "failed to get VPC '%s' in region '%s'", input.Vpc2ID, input.Region2, + ) + } + + if err = c.BlockTrafficBetweenVPCs(ctx, vnet1, vnet2, accountID1, accountID2); err != nil { + return types.VPCConnectionOutput{}, fmt.Errorf( + "failed to create a blocking rule between %s:%s and %s:%s due to %w", + input.Region1, *vnet1.ID, input.Region2, *vnet2.ID, err, + ) + } + + if err = c.createVnetPeering(ctx, *vnet1.ID, *vnet2.ID, accountID1); err != nil { + return types.VPCConnectionOutput{}, fmt.Errorf( + "failed to create a VPC Peering from %s:%s to %s:%s due to %w", + input.Region1, *vnet1.ID, input.Region2, *vnet2.ID, err, + ) + } + + if err = c.createVnetPeering(ctx, *vnet2.ID, *vnet1.ID, accountID2); err != nil { + return types.VPCConnectionOutput{}, fmt.Errorf( + "failed to create a VPC Peering from %s:%s to %s:%s due to %w", + input.Region2, *vnet2.ID, input.Region1, *vnet1.ID, err, + ) + } + + return types.VPCConnectionOutput{ + Region1: input.Region1, + Region2: input.Region2, + }, nil +} + +func (c *Client) DisconnectVPC(ctx context.Context, input types.SingleVPCDisconnectionParams) (types.VPCDisconnectionOutput, error) { + // TBD + return types.VPCDisconnectionOutput{}, nil +} + +// func getNSGNameForPeeredVPCs(sourceVNET, destinationVNET string) string { + +// } + +func (c *Client) DisconnectVPCs(ctx context.Context, input types.VPCDisconnectionParams) (types.VPCDisconnectionOutput, error) { + vnet1, accountID1, err := c.getVPC(ctx, input.Vpc1ID, input.Region1) + if err != nil { + return types.VPCDisconnectionOutput{}, fmt.Errorf( + "failed to get VPC '%s' in region '%s'", input.Vpc1ID, input.Region1, + ) + } + vnet2, accountID2, err := c.getVPC(ctx, input.Vpc2ID, input.Region2) + if err != nil { + return types.VPCDisconnectionOutput{}, fmt.Errorf( + "failed to get VPC '%s' in region '%s'", input.Vpc2ID, input.Region2, + ) + } + + peering1 := c.getVnetPeeringFromVnet(vnet1, *vnet2.ID) + if peering1 != "" { + c.deleteVnetPeering( + ctx, + accountID1, + parseResourceGroupName(*vnet1.ID), + *vnet1.Name, + vnetPeeringName(*vnet1.ID, *vnet2.ID), + ) + } else { + c.logger.Infof( + "VNet Peering %s not found. Skipping it", + vnetPeeringName(*vnet1.ID, *vnet2.ID), + ) + } + + peering2 := c.getVnetPeeringFromVnet(vnet2, *vnet1.ID) + if peering2 != "" { + c.deleteVnetPeering( + ctx, + accountID2, + parseResourceGroupName(*vnet2.ID), + *vnet2.Name, + vnetPeeringName(*vnet2.ID, *vnet1.ID), + ) + } else { + c.logger.Infof( + "VNet Peering %s not found. Skipping it", + vnetPeeringName(*vnet2.ID, *vnet1.ID), + ) + } + + if err = c.UnblockTrafficBetweenVPCs(ctx, vnet1, vnet2, accountID1, accountID2); err != nil { + return types.VPCDisconnectionOutput{}, fmt.Errorf( + "failed to remove a blocking rule between %s:%s and %s:%s due to %w", + input.Region1, *vnet1.ID, input.Region2, *vnet2.ID, err, + ) + } + + return types.VPCDisconnectionOutput{}, nil +} diff --git a/connector/helper/set.go b/connector/helper/set.go index b427e59..6f0067e 100644 --- a/connector/helper/set.go +++ b/connector/helper/set.go @@ -30,9 +30,20 @@ func SetFromSlice[T comparable](s []T) Set[T] { } func (s *Set[T]) Set(v T) { + if s.values == nil { + s.values = map[T]struct{}{} + } s.values[v] = struct{}{} } +func (s *Set[T]) Keys() []T { + keys := make([]T, 0, len(s.values)) + for k := range s.values { + keys = append(keys, k) + } + return keys +} + func (s *Set[T]) Has(v T) bool { _, ok := s.values[v] return ok