diff --git a/dataproxy/service.go b/dataproxy/service.go index 758be60e9..948c6a25c 100644 --- a/dataproxy/service.go +++ b/dataproxy/service.go @@ -6,8 +6,15 @@ import ( "encoding/base64" "fmt" "net/url" + "reflect" "time" + "github.com/flyteorg/flyteadmin/pkg/errors" + "google.golang.org/grpc/codes" + + "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "google.golang.org/protobuf/types/known/durationpb" @@ -26,9 +33,10 @@ import ( type Service struct { service.DataProxyServiceServer - cfg config.DataProxyConfig - dataStore *storage.DataStore - shardSelector ioutils.ShardSelector + cfg config.DataProxyConfig + dataStore *storage.DataStore + shardSelector ioutils.ShardSelector + nodeExecutionManager interfaces.NodeExecutionInterface } // CreateUploadLocation creates a temporary signed url to allow callers to upload content. @@ -36,20 +44,20 @@ func (s Service) CreateUploadLocation(ctx context.Context, req *service.CreateUp *service.CreateUploadLocationResponse, error) { if len(req.Project) == 0 || len(req.Domain) == 0 { - return nil, fmt.Errorf("prjoect and domain are required parameters") + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "project and domain are required parameters") } if len(req.ContentMd5) == 0 { - return nil, fmt.Errorf("content_md5 is a required parameter") + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "content_md5 is a required parameter") } if expiresIn := req.ExpiresIn; expiresIn != nil { if !expiresIn.IsValid() { - return nil, fmt.Errorf("expiresIn [%v] is invalid", expiresIn) + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "expiresIn [%v] is invalid", expiresIn) } if expiresIn.AsDuration() > s.cfg.Upload.MaxExpiresIn.Duration { - return nil, fmt.Errorf("expiresIn [%v] cannot exceed max allowed expiration [%v]", + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "expiresIn [%v] cannot exceed max allowed expiration [%v]", expiresIn.AsDuration().String(), s.cfg.Upload.MaxExpiresIn.String()) } } else { @@ -66,7 +74,7 @@ func (s Service) CreateUploadLocation(ctx context.Context, req *service.CreateUp storagePath, err := createStorageLocation(ctx, s.dataStore, s.cfg.Upload, req.Project, req.Domain, urlSafeMd5, req.Filename) if err != nil { - return nil, err + return nil, errors.NewFlyteAdminErrorf(codes.Internal, "failed to create shardedStorageLocation, Error: %v", err) } resp, err := s.dataStore.CreateSignedURL(ctx, storagePath, storage.SignedURLProperties{ @@ -76,7 +84,7 @@ func (s Service) CreateUploadLocation(ctx context.Context, req *service.CreateUp }) if err != nil { - return nil, fmt.Errorf("failed to create a signed url. Error: %w", err) + return nil, errors.NewFlyteAdminErrorf(codes.Internal, "failed to create a signed url. Error: %v", err) } return &service.CreateUploadLocationResponse{ @@ -86,12 +94,57 @@ func (s Service) CreateUploadLocation(ctx context.Context, req *service.CreateUp }, nil } +// CreateDownloadLink retrieves the requested artifact type for a given execution (wf, node, task) as a signed url(s). +func (s Service) CreateDownloadLink(ctx context.Context, req *service.CreateDownloadLinkRequest) ( + resp *service.CreateDownloadLinkResponse, err error) { + if req, err = s.validateCreateDownloadLinkRequest(req); err != nil { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "error while validating request. Error: %v", err) + } + + // Lookup task, node, workflow execution + var nativeURL string + if nodeExecutionIDEnvelope, casted := req.GetSource().(*service.CreateDownloadLinkRequest_NodeExecutionId); casted { + node, err := s.nodeExecutionManager.GetNodeExecution(ctx, admin.NodeExecutionGetRequest{ + Id: nodeExecutionIDEnvelope.NodeExecutionId, + }) + + if err != nil { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to find node execution [%v]. Error: %v", nodeExecutionIDEnvelope.NodeExecutionId, err) + } + + switch req.GetArtifactType() { + case service.ArtifactType_ARTIFACT_TYPE_DECK: + nativeURL = node.Closure.DeckUri + } + } else { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "unsupported source [%v]", reflect.TypeOf(req.GetSource())) + } + + if len(nativeURL) == 0 { + return nil, errors.NewFlyteAdminErrorf(codes.Internal, "no deckUrl found for request [%+v]", req) + } + + signedURLResp, err := s.dataStore.CreateSignedURL(ctx, storage.DataReference(nativeURL), storage.SignedURLProperties{ + Scope: stow.ClientMethodGet, + ExpiresIn: req.ExpiresIn.AsDuration(), + }) + + if err != nil { + return nil, errors.NewFlyteAdminErrorf(codes.Internal, "failed to create a signed url. Error: %v", err) + } + + return &service.CreateDownloadLinkResponse{ + SignedUrl: []string{signedURLResp.URL.String()}, + ExpiresAt: timestamppb.New(time.Now().Add(req.ExpiresIn.AsDuration())), + }, nil +} + // CreateDownloadLocation creates a temporary signed url to allow callers to download content. func (s Service) CreateDownloadLocation(ctx context.Context, req *service.CreateDownloadLocationRequest) ( *service.CreateDownloadLocationResponse, error) { if err := s.validateCreateDownloadLocationRequest(req); err != nil { - return nil, err + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "error while validating request: %v", err) } resp, err := s.dataStore.CreateSignedURL(ctx, storage.DataReference(req.NativeUrl), storage.SignedURLProperties{ @@ -100,7 +153,7 @@ func (s Service) CreateDownloadLocation(ctx context.Context, req *service.Create }) if err != nil { - return nil, fmt.Errorf("failed to create a signed url. Error: %w", err) + return nil, errors.NewFlyteAdminErrorf(codes.Internal, "failed to create a signed url. Error: %v", err) } return &service.CreateDownloadLocationResponse{ @@ -110,22 +163,13 @@ func (s Service) CreateDownloadLocation(ctx context.Context, req *service.Create } func (s Service) validateCreateDownloadLocationRequest(req *service.CreateDownloadLocationRequest) error { - if expiresIn := req.ExpiresIn; expiresIn != nil { - if !expiresIn.IsValid() { - return fmt.Errorf("expiresIn [%v] is invalid", expiresIn) - } - - if expiresIn.AsDuration() < 0 { - return fmt.Errorf("expiresIn [%v] should not less than 0", - expiresIn.AsDuration().String()) - } else if expiresIn.AsDuration() > s.cfg.Download.MaxExpiresIn.Duration { - return fmt.Errorf("expiresIn [%v] cannot exceed max allowed expiration [%v]", - expiresIn.AsDuration().String(), s.cfg.Download.MaxExpiresIn.String()) - } - } else { - req.ExpiresIn = durationpb.New(s.cfg.Download.MaxExpiresIn.Duration) + validatedExpiresIn, err := validateDuration(req.ExpiresIn, s.cfg.Download.MaxExpiresIn.Duration) + if err != nil { + return fmt.Errorf("expiresIn is invalid. Error: %w", err) } + req.ExpiresIn = validatedExpiresIn + if _, err := url.Parse(req.NativeUrl); err != nil { return fmt.Errorf("failed to parse native_url [%v]", req.NativeUrl) @@ -134,6 +178,45 @@ func (s Service) validateCreateDownloadLocationRequest(req *service.CreateDownlo return nil } +func validateDuration(input *durationpb.Duration, maxAllowed time.Duration) (*durationpb.Duration, error) { + if input == nil { + return durationpb.New(maxAllowed), nil + } + + if !input.IsValid() { + return nil, fmt.Errorf("input duration [%v] is invalid", input) + } + + if input.AsDuration() < 0 { + return nil, fmt.Errorf("input duration [%v] should not less than 0", + input.AsDuration().String()) + } else if input.AsDuration() > maxAllowed { + return nil, fmt.Errorf("input duration [%v] cannot exceed max allowed expiration [%v]", + input.AsDuration(), maxAllowed) + } + + return input, nil +} + +func (s Service) validateCreateDownloadLinkRequest(req *service.CreateDownloadLinkRequest) (*service.CreateDownloadLinkRequest, error) { + validatedExpiresIn, err := validateDuration(req.ExpiresIn, s.cfg.Download.MaxExpiresIn.Duration) + if err != nil { + return nil, fmt.Errorf("expiresIn is invalid. Error: %w", err) + } + + req.ExpiresIn = validatedExpiresIn + + if req.GetArtifactType() == service.ArtifactType_ARTIFACT_TYPE_UNDEFINED { + return nil, fmt.Errorf("invalid artifact type [%v]", req.GetArtifactType()) + } + + if req.GetSource() == nil { + return nil, fmt.Errorf("source is required. Provided nil") + } + + return req, nil +} + // createStorageLocation creates a location in storage destination to maximize read/write performance in most // block stores. The final location should look something like: s3:/// func createStorageLocation(ctx context.Context, store *storage.DataStore, @@ -148,7 +231,10 @@ func createStorageLocation(ctx context.Context, store *storage.DataStore, return storagePath, nil } -func NewService(cfg config.DataProxyConfig, dataStore *storage.DataStore) (Service, error) { +func NewService(cfg config.DataProxyConfig, + nodeExec interfaces.NodeExecutionInterface, + dataStore *storage.DataStore) (Service, error) { + // Context is not used in the constructor. Should ideally be removed. selector, err := ioutils.NewBase36PrefixShardSelector(context.TODO()) if err != nil { @@ -156,8 +242,9 @@ func NewService(cfg config.DataProxyConfig, dataStore *storage.DataStore) (Servi } return Service{ - cfg: cfg, - dataStore: dataStore, - shardSelector: selector, + cfg: cfg, + dataStore: dataStore, + shardSelector: selector, + nodeExecutionManager: nodeExec, }, nil } diff --git a/dataproxy/service_test.go b/dataproxy/service_test.go index 8e0d7412a..261b0f086 100644 --- a/dataproxy/service_test.go +++ b/dataproxy/service_test.go @@ -5,6 +5,12 @@ import ( "testing" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/flyteorg/flyteadmin/pkg/manager/mocks" + commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks" stdlibConfig "github.com/flyteorg/flytestdlib/config" @@ -24,9 +30,11 @@ import ( func TestNewService(t *testing.T) { dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) assert.NoError(t, err) + + nodeExecutionManager := &mocks.MockNodeExecutionManager{} s, err := NewService(config.DataProxyConfig{ Upload: config.DataProxyUploadConfig{}, - }, dataStore) + }, nodeExecutionManager, dataStore) assert.NoError(t, err) assert.NotNil(t, s) } @@ -48,7 +56,8 @@ func Test_createStorageLocation(t *testing.T) { func TestCreateUploadLocation(t *testing.T) { dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) assert.NoError(t, err) - s, err := NewService(config.DataProxyConfig{}, dataStore) + nodeExecutionManager := &mocks.MockNodeExecutionManager{} + s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore) assert.NoError(t, err) t.Run("No project/domain", func(t *testing.T) { _, err = s.CreateUploadLocation(context.Background(), &service.CreateUploadLocationRequest{}) @@ -73,9 +82,53 @@ func TestCreateUploadLocation(t *testing.T) { }) } +func TestCreateDownloadLink(t *testing.T) { + dataStore := commonMocks.GetMockStorageClient() + nodeExecutionManager := &mocks.MockNodeExecutionManager{} + nodeExecutionManager.SetGetNodeExecutionFunc(func(ctx context.Context, request admin.NodeExecutionGetRequest) (*admin.NodeExecution, error) { + return &admin.NodeExecution{ + Closure: &admin.NodeExecutionClosure{ + DeckUri: "s3://something/something", + }, + }, nil + }) + + s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore) + assert.NoError(t, err) + + t.Run("Invalid expiry", func(t *testing.T) { + _, err = s.CreateDownloadLink(context.Background(), &service.CreateDownloadLinkRequest{ + ExpiresIn: durationpb.New(-time.Hour), + }) + assert.Error(t, err) + }) + + t.Run("valid config", func(t *testing.T) { + _, err = s.CreateDownloadLink(context.Background(), &service.CreateDownloadLinkRequest{ + ArtifactType: service.ArtifactType_ARTIFACT_TYPE_DECK, + Source: &service.CreateDownloadLinkRequest_NodeExecutionId{ + NodeExecutionId: &core.NodeExecutionIdentifier{}, + }, + ExpiresIn: durationpb.New(time.Hour), + }) + assert.NoError(t, err) + }) + + t.Run("use default ExpiresIn", func(t *testing.T) { + _, err = s.CreateDownloadLink(context.Background(), &service.CreateDownloadLinkRequest{ + ArtifactType: service.ArtifactType_ARTIFACT_TYPE_DECK, + Source: &service.CreateDownloadLinkRequest_NodeExecutionId{ + NodeExecutionId: &core.NodeExecutionIdentifier{}, + }, + }) + assert.NoError(t, err) + }) +} + func TestCreateDownloadLocation(t *testing.T) { dataStore := commonMocks.GetMockStorageClient() - s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, dataStore) + nodeExecutionManager := &mocks.MockNodeExecutionManager{} + s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore) assert.NoError(t, err) t.Run("Invalid expiry", func(t *testing.T) { diff --git a/pkg/server/service.go b/pkg/server/service.go index b1cbbd9b3..7e45fbb03 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -111,13 +111,14 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c } configuration := runtime2.NewConfigurationProvider() - service.RegisterAdminServiceServer(grpcServer, adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, scope.NewSubScope("admin"))) + adminServer := adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, scope.NewSubScope("admin")) + service.RegisterAdminServiceServer(grpcServer, adminServer) if cfg.Security.UseAuth { service.RegisterAuthMetadataServiceServer(grpcServer, authCtx.AuthMetadataService()) service.RegisterIdentityServiceServer(grpcServer, authCtx.IdentityService()) } - dataProxySvc, err := dataproxy.NewService(cfg.DataProxy, dataStorageClient) + dataProxySvc, err := dataproxy.NewService(cfg.DataProxy, adminServer.NodeExecutionManager, dataStorageClient) if err != nil { return nil, fmt.Errorf("failed to initialize dataProxy service. Error: %w", err) }