From 4b6b49d00aafe79e7b4716e0244167e7287c59b3 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Mon, 6 Jan 2025 21:52:25 -0500 Subject: [PATCH] feat(catalog): Add Catalog Registry --- catalog/catalog.go | 52 ++++++++++++++ catalog/glue.go | 89 +++++++++++++----------- catalog/registry.go | 142 +++++++++++++++++++++++++++++++++++++++ catalog/registry_test.go | 74 ++++++++++++++++++++ catalog/rest.go | 53 ++++++++++++--- catalog/rest_test.go | 45 +++++++++++++ 6 files changed, 408 insertions(+), 47 deletions(-) create mode 100644 catalog/registry.go create mode 100644 catalog/registry_test.go diff --git a/catalog/catalog.go b/catalog/catalog.go index 65da7e5..b44a27f 100644 --- a/catalog/catalog.go +++ b/catalog/catalog.go @@ -21,6 +21,8 @@ import ( "context" "crypto/tls" "errors" + "fmt" + "maps" "net/url" "github.com/apache/iceberg-go" @@ -45,6 +47,9 @@ var ( ErrNoSuchTable = errors.New("table does not exist") ErrNoSuchNamespace = errors.New("namespace does not exist") ErrNamespaceAlreadyExists = errors.New("namespace already exists") + ErrTableAlreadyExists = errors.New("table already exists") + ErrCatalogNotFound = errors.New("catalog type not registered") + ErrNamespaceNotEmpty = errors.New("namespace is not empty") ) // WithAwsConfig sets the AWS configuration for the catalog. @@ -194,3 +199,50 @@ func TableNameFromIdent(ident table.Identifier) string { func NamespaceFromIdent(ident table.Identifier) table.Identifier { return ident[:len(ident)-1] } + +func checkForOverlap(removals []string, updates iceberg.Properties) error { + overlap := []string{} + for _, key := range removals { + if _, ok := updates[key]; ok { + overlap = append(overlap, key) + } + } + if len(overlap) > 0 { + return fmt.Errorf("conflict between removals and updates for keys: %v", overlap) + } + return nil +} + +func getUpdatedPropsAndUpdateSummary(currentProps iceberg.Properties, removals []string, updates iceberg.Properties) (iceberg.Properties, PropertiesUpdateSummary, error) { + if err := checkForOverlap(removals, updates); err != nil { + return nil, PropertiesUpdateSummary{}, err + } + + var ( + updatedProps = maps.Clone(currentProps) + removed = make([]string, 0, len(removals)) + updated = make([]string, 0, len(updates)) + ) + + for _, key := range removals { + if _, exists := updatedProps[key]; exists { + delete(updatedProps, key) + removed = append(removed, key) + } + } + + for key, value := range updates { + if updatedProps[key] != value { + updated = append(updated, key) + updatedProps[key] = value + } + } + + summary := PropertiesUpdateSummary{ + Removed: removed, + Updated: updated, + Missing: iceberg.Difference(removals, removed), + } + + return updatedProps, summary, nil +} diff --git a/catalog/glue.go b/catalog/glue.go index c970e5a..f8ddedc 100644 --- a/catalog/glue.go +++ b/catalog/glue.go @@ -21,11 +21,14 @@ import ( "context" "errors" "fmt" + "strconv" "github.com/apache/iceberg-go" "github.com/apache/iceberg-go/io" "github.com/apache/iceberg-go/table" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/glue" "github.com/aws/aws-sdk-go-v2/service/glue/types" ) @@ -54,6 +57,50 @@ var ( _ Catalog = (*GlueCatalog)(nil) ) +func init() { + Register("glue", RegistrarFunc(func(_ string, props iceberg.Properties) (Catalog, error) { + awsConfig, err := toAwsConfig(props) + if err != nil { + return nil, err + } + + return NewGlueCatalog(WithAwsConfig(awsConfig), WithAwsProperties(AwsProperties(props))), nil + })) +} + +func toAwsConfig(p iceberg.Properties) (aws.Config, error) { + opts := make([]func(*config.LoadOptions) error, 0) + + for k, v := range p { + switch k { + case "glue.region": + opts = append(opts, config.WithRegion(v)) + case "glue.endpoint": + opts = append(opts, config.WithBaseEndpoint(v)) + case "glue.max-retries": + maxRetry, err := strconv.Atoi(v) + if err != nil { + return aws.Config{}, err + } + opts = append(opts, config.WithRetryMaxAttempts(maxRetry)) + case "glue.retry-mode": + m, err := aws.ParseRetryMode(v) + if err != nil { + return aws.Config{}, err + } + opts = append(opts, config.WithRetryMode(m)) + } + } + + key, secret, token := p.Get("glue.access-key", ""), p.Get("glue.secret-access-key", ""), p.Get("glue.session-token", "") + if key != "" && secret != "" && token != "" { + opts = append(opts, config.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(key, secret, token))) + } + + return config.LoadDefaultConfig(context.Background(), opts...) +} + type glueAPI interface { CreateTable(ctx context.Context, params *glue.CreateTableInput, optFns ...func(*glue.Options)) (*glue.CreateTableOutput, error) GetTable(ctx context.Context, params *glue.GetTableInput, optFns ...func(*glue.Options)) (*glue.GetTableOutput, error) @@ -353,39 +400,9 @@ func (c *GlueCatalog) UpdateNamespaceProperties(ctx context.Context, namespace t return PropertiesUpdateSummary{}, err } - overlap := []string{} - for _, key := range removals { - if _, exists := updates[key]; exists { - overlap = append(overlap, key) - } - } - if len(overlap) > 0 { - return PropertiesUpdateSummary{}, fmt.Errorf("conflict between removals and updates for keys: %v", overlap) - } - - updatedProperties := make(map[string]string) - if database.Parameters != nil { - for k, v := range database.Parameters { - updatedProperties[k] = v - } - } - - // Removals. - removed := []string{} - for _, key := range removals { - if _, exists := updatedProperties[key]; exists { - delete(updatedProperties, key) - removed = append(removed, key) - } - } - - // Updates. - updated := []string{} - for key, value := range updates { - if updatedProperties[key] != value { - updatedProperties[key] = value - updated = append(updated, key) - } + updatedProperties, propertiesUpdateSummary, err := getUpdatedPropsAndUpdateSummary(database.Parameters, removals, updates) + if err != nil { + return PropertiesUpdateSummary{}, err } _, err = c.glueSvc.UpdateDatabase(ctx, &glue.UpdateDatabaseInput{CatalogId: c.catalogId, Name: aws.String(databaseName), DatabaseInput: &types.DatabaseInput{ @@ -396,12 +413,6 @@ func (c *GlueCatalog) UpdateNamespaceProperties(ctx context.Context, namespace t return PropertiesUpdateSummary{}, fmt.Errorf("failed to update namespace properties %s: %w", databaseName, err) } - propertiesUpdateSummary := PropertiesUpdateSummary{ - Removed: removed, - Updated: updated, - Missing: iceberg.Difference(removals, removed), - } - return propertiesUpdateSummary, nil } diff --git a/catalog/registry.go b/catalog/registry.go new file mode 100644 index 0000000..099852f --- /dev/null +++ b/catalog/registry.go @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package catalog + +import ( + "fmt" + "maps" + "net/url" + "slices" + "strings" + "sync" + + "github.com/apache/iceberg-go" +) + +type registry map[string]Registrar + +func (r registry) getKeys() []string { + regMutex.Lock() + defer regMutex.Unlock() + return slices.Collect(maps.Keys(r)) +} + +func (r registry) set(catalogType string, reg Registrar) { + regMutex.Lock() + defer regMutex.Unlock() + r[catalogType] = reg +} + +func (r registry) get(catalogType string) (Registrar, bool) { + regMutex.Lock() + defer regMutex.Unlock() + reg, ok := r[catalogType] + return reg, ok +} + +func (r registry) remove(catalogType string) { + regMutex.Lock() + defer regMutex.Unlock() + delete(r, catalogType) +} + +var ( + regMutex sync.Mutex + defaultRegistry = registry{} +) + +// Registrar is a factory for creating Catalog instances, used for registering to use +// with LoadCatalog. +type Registrar interface { + GetCatalog(catalogURI string, props iceberg.Properties) (Catalog, error) +} + +type RegistrarFunc func(string, iceberg.Properties) (Catalog, error) + +func (f RegistrarFunc) GetCatalog(catalogURI string, props iceberg.Properties) (Catalog, error) { + return f(catalogURI, props) +} + +// Register adds the new catalog type to the registry. If the catalog type is already registered, it will be replaced. +func Register(catalogType string, reg Registrar) { + if reg == nil { + panic("catalog: RegisterCatalog catalog factory is nil") + } + defaultRegistry.set(catalogType, reg) +} + +// Unregister removes the requested catalog factory from the registry. +func Unregister(catalogType string) { + defaultRegistry.remove(catalogType) +} + +// GetRegsisteredCatalogs returns the list of registered catalog names that can +// be looked up via LoadCatalog. +func GetRegisteredCatalogs() []string { + return defaultRegistry.getKeys() +} + +// Load allows loading a specific catalog by URI and properties. +// +// This is utilized alongside RegisterCatalog/UnregisterCatalog to not only allow +// easier catalog loading but also to allow for custom catalog implementations to +// be registered and loaded external to this module. +// +// The URI is used to determine the catalog type by first checking if it contains +// the string "://" indicating the presence of a scheme. If so, the schema is used +// to lookup the registered catalog. i.e. "glue://..." would return the Glue catalog +// implementation, passing the URI and properties to NewGlueCatalog. If no scheme is +// present, then the URI is used as-is to lookup the catalog factory function. +// +// Currently the following catalogs are registered by default: +// +// - "glue" for AWS Glue Data Catalog, the rest of the URI is ignored, all configuration +// should be provided using the properties. "glue.region", "glue.endpoint", +// "glue.max-retries", etc. Default AWS credentials are used if found, or can be +// overridden by setting "glue.access-key", "glue.secret-access-key", and "glue.session-token". +// +// - "rest" for a REST API catalog, if the properties have a "uri" key, then that will be used +// as the REST endpoint, otherwise the URI is used as the endpoint. The REST catalog also +// registers "http" and "https" so that Load with an http/s URI will automatically +// load the REST Catalog. +// +// - "sql" for SQL catalogs. The registered generic SQL catalog loader looks for the following +// properties to create the connection: The value of "sql.driver" will be used to call `sql.Open`. +// the DSN to pass to `sql.Open` is set by the "uri" property. Finally, the "sql.dialect" property +// will be used which SQL dialect to use for queries and must be one of the supported ones. +// In addition, "catalog.name" can be set to specify the catalog name, otherwise it will just default +// to "sql". +func Load(catalogURI string, props iceberg.Properties) (Catalog, error) { + var catalogType string + if strings.Contains(catalogURI, "://") { + parsed, err := url.Parse(catalogURI) + if err != nil { + return nil, err + } + catalogType = parsed.Scheme + } else { + catalogType = catalogURI + } + + cat, ok := defaultRegistry.get(catalogType) + if !ok { + return nil, fmt.Errorf("%w: %s", ErrCatalogNotFound, catalogType) + } + + return cat.GetCatalog(catalogURI, props) +} diff --git a/catalog/registry_test.go b/catalog/registry_test.go new file mode 100644 index 0000000..d83c47f --- /dev/null +++ b/catalog/registry_test.go @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package catalog_test + +import ( + "testing" + + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/catalog" + "github.com/stretchr/testify/assert" +) + +func TestCatalogRegistry(t *testing.T) { + assert.ElementsMatch(t, []string{ + "rest", + "http", + "https", + "glue", + "sql", + }, catalog.GetRegisteredCatalogs()) + + catalog.Register("foobar", catalog.RegistrarFunc(func(s string, p iceberg.Properties) (catalog.Catalog, error) { + assert.Equal(t, "foobar", s) + assert.Equal(t, "baz", p.Get("foo", "")) + return nil, nil + })) + + assert.ElementsMatch(t, []string{ + "rest", + "http", + "foobar", + "https", + "glue", + "sql", + }, catalog.GetRegisteredCatalogs()) + + c, err := catalog.Load("foobar", iceberg.Properties{"foo": "baz"}) + assert.Nil(t, c) + assert.NoError(t, err) + + catalog.Register("foobar", catalog.RegistrarFunc(func(s string, p iceberg.Properties) (catalog.Catalog, error) { + assert.Equal(t, "foobar://helloworld", s) + assert.Equal(t, "baz", p.Get("foo", "")) + return nil, nil + })) + + c, err = catalog.Load("foobar://helloworld", iceberg.Properties{"foo": "baz"}) + assert.Nil(t, c) + assert.NoError(t, err) + + catalog.Unregister("foobar") + assert.ElementsMatch(t, []string{ + "rest", + "http", + "https", + "glue", + "sql", + }, catalog.GetRegisteredCatalogs()) +} diff --git a/catalog/rest.go b/catalog/rest.go index ef9c332..c2c8fcc 100644 --- a/catalog/rest.go +++ b/catalog/rest.go @@ -21,6 +21,7 @@ import ( "bytes" "context" "crypto/sha256" + "crypto/tls" "encoding/json" "errors" "fmt" @@ -56,6 +57,7 @@ const ( keyRestSigV4Region = "rest.signing-region" keyRestSigV4Service = "rest.signing-name" keyAuthUrl = "rest.authorization-url" + keyTlsSkipVerify = "rest.tls.skip-verify" ) var ( @@ -71,6 +73,16 @@ var ( ErrOAuthError = fmt.Errorf("%w: oauth error", ErrRESTError) ) +func init() { + reg := RegistrarFunc(func(endpoint string, p iceberg.Properties) (Catalog, error) { + return newRestCatalogFromProps(endpoint, p.Get("uri", endpoint), p) + }) + + Register(string(REST), reg) + Register("http", reg) + Register("https", reg) +} + type errorResponse struct { Message string `json:"message"` Type string `json:"type"` @@ -328,6 +340,15 @@ func fromProps(props iceberg.Properties) *options { o.credential = v case keyPrefix: o.prefix = v + case keyTlsSkipVerify: + verify := strings.ToLower(v) == "true" + if o.tlsConfig == nil { + o.tlsConfig = &tls.Config{ + InsecureSkipVerify: verify, + } + } else { + o.tlsConfig.InsecureSkipVerify = verify + } } } return o @@ -367,29 +388,45 @@ type RestCatalog struct { props iceberg.Properties } +func newRestCatalogFromProps(name string, uri string, p iceberg.Properties) (*RestCatalog, error) { + ops := fromProps(p) + + r := &RestCatalog{name: name} + if err := r.init(ops, uri); err != nil { + return nil, err + } + + return r, nil +} + func NewRestCatalog(name, uri string, opts ...Option[RestCatalog]) (*RestCatalog, error) { ops := &options{} for _, o := range opts { o(ops) } - baseuri, err := url.Parse(uri) - if err != nil { + r := &RestCatalog{name: name} + if err := r.init(ops, uri); err != nil { return nil, err } - r := &RestCatalog{ - name: name, - baseURI: baseuri.JoinPath("v1"), + return r, nil +} + +func (r *RestCatalog) init(ops *options, uri string) error { + baseuri, err := url.Parse(uri) + if err != nil { + return err } + r.baseURI = baseuri.JoinPath("v1") if ops, err = r.fetchConfig(ops); err != nil { - return nil, err + return err } cl, err := r.createSession(ops) if err != nil { - return nil, err + return err } r.cl = cl @@ -397,7 +434,7 @@ func NewRestCatalog(name, uri string, opts ...Option[RestCatalog]) (*RestCatalog r.baseURI = r.baseURI.JoinPath(ops.prefix) } r.props = toProps(ops) - return r, nil + return nil } func (r *RestCatalog) fetchAccessToken(cl *http.Client, creds string, opts *options) (string, error) { diff --git a/catalog/rest_test.go b/catalog/rest_test.go index 618c5e0..d1afb04 100644 --- a/catalog/rest_test.go +++ b/catalog/rest_test.go @@ -113,6 +113,39 @@ func (r *RestCatalogSuite) TestToken200() { r.Equal(r.configVals.Get("warehouse"), "s3://some-bucket") } +func (r *RestCatalogSuite) TestLoadRegisteredCatalog() { + r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req *http.Request) { + r.Equal(http.MethodPost, req.Method) + + r.Equal(req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") + + r.Require().NoError(req.ParseForm()) + values := req.PostForm + r.Equal(values.Get("grant_type"), "client_credentials") + r.Equal(values.Get("client_id"), "client") + r.Equal(values.Get("client_secret"), "secret") + r.Equal(values.Get("scope"), "catalog") + + w.WriteHeader(http.StatusOK) + + json.NewEncoder(w).Encode(map[string]any{ + "access_token": TestToken, + "token_type": "Bearer", + "expires_in": 86400, + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + }) + }) + + cat, err := catalog.Load(r.srv.URL, iceberg.Properties{ + "warehouse": "s3://some-bucket", + "credential": TestCreds, + }) + r.NoError(err) + + r.NotNil(cat) + r.Equal(r.configVals.Get("warehouse"), "s3://some-bucket") +} + func (r *RestCatalogSuite) TestToken400() { r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req *http.Request) { r.Equal(http.MethodPost, req.Method) @@ -782,6 +815,18 @@ func (r *RestTLSCatalogSuite) TestSSLFail() { r.ErrorContains(err, "tls: failed to verify certificate") } +func (r *RestTLSCatalogSuite) TestSSLLoadRegisteredCatalog() { + cat, err := catalog.Load(r.srv.URL, iceberg.Properties{ + "warehouse": "s3://some-bucket", + "token": TestToken, + "rest.tls.skip-verify": "true", + }) + r.NoError(err) + + r.NotNil(cat) + r.Equal(r.configVals.Get("warehouse"), "s3://some-bucket") +} + func (r *RestTLSCatalogSuite) TestSSLConfig() { cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken), catalog.WithWarehouseLocation("s3://some-bucket"),