diff --git a/pkg/tui/views/psql.go b/pkg/tui/views/psql.go index 133c899..b3dfd5e 100644 --- a/pkg/tui/views/psql.go +++ b/pkg/tui/views/psql.go @@ -3,6 +3,9 @@ package views import ( "context" "fmt" + "io" + "net" + "net/http" "os/exec" tea "github.com/charmbracelet/bubbletea" @@ -76,7 +79,27 @@ func loadDataPSQL(ctx context.Context, in *PSQLInput) (*exec.Cmd, error) { return nil, err } - connectionInfo, err := postgres.NewRepo(c).GetPostgresConnectionInfo(ctx, in.PostgresID) + pgc := postgres.NewRepo(c) + + pg, err := pgc.GetPostgres(ctx, in.PostgresID) + if err != nil { + return nil, err + } + + // only check access if error is nil in case ipify is down + userIP, ok := getUserIP() + if ok { + hasAccess, err := hasAccessToPostgres(pg, userIP) + if err != nil { + return nil, err + } + + if !hasAccess { + return nil, fmt.Errorf("IP address (%s) not in allow list for %s", userIP, pg.Name) + } + } + + connectionInfo, err := pgc.GetPostgresConnectionInfo(ctx, in.PostgresID) if err != nil { return nil, err } @@ -84,6 +107,34 @@ func loadDataPSQL(ctx context.Context, in *PSQLInput) (*exec.Cmd, error) { return exec.Command(string(in.Tool), connectionInfo.ExternalConnectionString), nil } +func hasAccessToPostgres(pg *client.PostgresDetail, userIP net.IP) (bool, error) { + for _, allowedIPs := range pg.IpAllowList { + _, cidr, err := net.ParseCIDR(allowedIPs.CidrBlock) + if err != nil { + return false, err + } + + if cidr.Contains(userIP) { + return true, nil + } + } + return false, nil +} + +func getUserIP() (net.IP, bool) { + userIPRes, err := http.Get("https://api.ipify.org") + if err != nil { + return nil, false + } + + userIPBytes, err := io.ReadAll(userIPRes.Body) + if err != nil { + return nil, false + } + + return net.ParseIP(string(userIPBytes)), true +} + func (v *PSQLView) Init() tea.Cmd { if v.postgresTable != nil { return v.postgresTable.Init()