diff --git a/applicationset/services/repo_service.go b/applicationset/services/repo_service.go index 93bc765a455e11..ebace7ec3bb3de 100644 --- a/applicationset/services/repo_service.go +++ b/applicationset/services/repo_service.go @@ -19,8 +19,9 @@ type argoCDService struct { listRepositories func(ctx context.Context) ([]*v1alpha1.Repository, error) storecreds git.CredsStore submoduleEnabled bool - repoServerClientSet apiclient.Clientset newFileGlobbingEnabled bool + getGitFiles func(ctx context.Context, req *apiclient.GitFilesRequest) (*apiclient.GitFilesResponse, error) + getGitDirectories func(ctx context.Context, req *apiclient.GitDirectoriesRequest) (*apiclient.GitDirectoriesResponse, error) } type Repos interface { @@ -35,8 +36,23 @@ func NewArgoCDService(listRepositories func(ctx context.Context) ([]*v1alpha1.Re return &argoCDService{ listRepositories: listRepositories, submoduleEnabled: submoduleEnabled, - repoServerClientSet: repoClientset, newFileGlobbingEnabled: newFileGlobbingEnabled, + getGitFiles: func(ctx context.Context, fileRequest *apiclient.GitFilesRequest) (*apiclient.GitFilesResponse, error) { + closer, client, err := repoClientset.NewRepoServerClient() + if err != nil { + return nil, fmt.Errorf("error initializing new repo server client: %w", err) + } + defer io.Close(closer) + return client.GetGitFiles(ctx, fileRequest) + }, + getGitDirectories: func(ctx context.Context, dirRequest *apiclient.GitDirectoriesRequest) (*apiclient.GitDirectoriesResponse, error) { + closer, client, err := repoClientset.NewRepoServerClient() + if err != nil { + return nil, fmt.Errorf("error initialising new repo server client: %w", err) + } + defer io.Close(closer) + return client.GetGitDirectories(ctx, dirRequest) + }, }, nil } @@ -60,13 +76,7 @@ func (a *argoCDService) GetFiles(ctx context.Context, repoURL, revision, project NoRevisionCache: noRevisionCache, VerifyCommit: verifyCommit, } - closer, client, err := a.repoServerClientSet.NewRepoServerClient() - if err != nil { - return nil, fmt.Errorf("error initialising new repo server client: %w", err) - } - defer io.Close(closer) - - fileResponse, err := client.GetGitFiles(ctx, fileRequest) + fileResponse, err := a.getGitFiles(ctx, fileRequest) if err != nil { return nil, fmt.Errorf("error retrieving Git files: %w", err) } @@ -92,13 +102,7 @@ func (a *argoCDService) GetDirectories(ctx context.Context, repoURL, revision, p VerifyCommit: verifyCommit, } - closer, client, err := a.repoServerClientSet.NewRepoServerClient() - if err != nil { - return nil, fmt.Errorf("error initialising new repo server client: %w", err) - } - defer io.Close(closer) - - dirResponse, err := client.GetGitDirectories(ctx, dirRequest) + dirResponse, err := a.getGitDirectories(ctx, dirRequest) if err != nil { return nil, fmt.Errorf("error retrieving Git Directories: %w", err) } @@ -110,17 +114,29 @@ func getRepo(repos []*v1alpha1.Repository, repoURL string, project string) (*v1a if err != nil { if errors.Is(err, status.Error(codes.PermissionDenied, "permission denied")) { // No repo found with a matching URL - attempt fallback without any actual credentials - repo = &v1alpha1.Repository{Repo: repoURL} - } else { - // This is the final fallback - ensure that at least one repo cred with an unset project is present. - for _, repo = range repos { - if git.SameURL(repo.Repo, repoURL) && repo.Project == "" { - return repo, nil + return &v1alpha1.Repository{Repo: repoURL}, nil + } else if project == "" { + for _, r := range repos { + if git.SameURL(r.Repo, repoURL) { + // Prioritize using a repository with an unset project. + if r.Project == "" { + return r, nil + } + + if repo == nil { + repo = r + } } } - return nil, fmt.Errorf("no matching repository found for url %s, ensure that you have a repo credential with an unset project", repoURL) + // Try any repo matching the same repoURL + if repo != nil { + return repo, nil + } + + // No repo found with a matching URL - attempt fallback without any actual credentials + return &v1alpha1.Repository{Repo: repoURL}, nil } } - return repo, nil + return repo, err } diff --git a/applicationset/services/repo_service_test.go b/applicationset/services/repo_service_test.go index 149668748b715a..35d9af6ac655d6 100644 --- a/applicationset/services/repo_service_test.go +++ b/applicationset/services/repo_service_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/argoproj/argo-cd/v2/reposerver/apiclient" @@ -18,16 +17,15 @@ import ( func TestGetDirectories(t *testing.T) { type fields struct { - storecreds git.CredsStore - submoduleEnabled bool - listRepositories func(ctx context.Context) ([]*v1alpha1.Repository, error) - repoServerClientFuncs []func(*repo_mocks.RepoServerServiceClient) + storecreds git.CredsStore + submoduleEnabled bool + listRepositories func(ctx context.Context) ([]*v1alpha1.Repository, error) + getGitDirectories func(ctx context.Context, req *apiclient.GitDirectoriesRequest) (*apiclient.GitDirectoriesResponse, error) } type args struct { ctx context.Context repoURL string revision string - project string noRevisionCache bool verifyCommit bool } @@ -47,10 +45,8 @@ func TestGetDirectories(t *testing.T) { listRepositories: func(ctx context.Context) ([]*v1alpha1.Repository, error) { return []*v1alpha1.Repository{{}}, nil }, - repoServerClientFuncs: []func(*repo_mocks.RepoServerServiceClient){ - func(client *repo_mocks.RepoServerServiceClient) { - client.On("GetGitDirectories", mock.Anything, mock.Anything).Return(nil, fmt.Errorf("unable to get dirs")) - }, + getGitDirectories: func(ctx context.Context, req *apiclient.GitDirectoriesRequest) (*apiclient.GitDirectoriesResponse, error) { + return nil, fmt.Errorf("unable to get dirs") }, }, args: args{}, want: nil, wantErr: assert.Error}, {name: "HappyCase", fields: fields{ @@ -59,12 +55,10 @@ func TestGetDirectories(t *testing.T) { Repo: "foo", }}, nil }, - repoServerClientFuncs: []func(*repo_mocks.RepoServerServiceClient){ - func(client *repo_mocks.RepoServerServiceClient) { - client.On("GetGitDirectories", mock.Anything, mock.Anything).Return(&apiclient.GitDirectoriesResponse{ - Paths: []string{"foo", "foo/bar", "bar/foo"}, - }, nil) - }, + getGitDirectories: func(ctx context.Context, req *apiclient.GitDirectoriesRequest) (*apiclient.GitDirectoriesResponse, error) { + return &apiclient.GitDirectoriesResponse{ + Paths: []string{"foo", "foo/bar", "bar/foo"}, + }, nil }, }, args: args{ repoURL: "foo", @@ -73,28 +67,20 @@ func TestGetDirectories(t *testing.T) { listRepositories: func(ctx context.Context) ([]*v1alpha1.Repository, error) { return []*v1alpha1.Repository{{}}, nil }, - repoServerClientFuncs: []func(*repo_mocks.RepoServerServiceClient){ - func(client *repo_mocks.RepoServerServiceClient) { - client.On("GetGitDirectories", mock.Anything, mock.Anything).Return(nil, fmt.Errorf("revision HEAD is not signed")) - }, + getGitDirectories: func(ctx context.Context, req *apiclient.GitDirectoriesRequest) (*apiclient.GitDirectoriesResponse, error) { + return nil, fmt.Errorf("revision HEAD is not signed") }, }, args: args{}, want: nil, wantErr: assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockRepoClient := &repo_mocks.RepoServerServiceClient{} - // decorate the mocks - for i := range tt.fields.repoServerClientFuncs { - tt.fields.repoServerClientFuncs[i](mockRepoClient) - } - a := &argoCDService{ - listRepositories: tt.fields.listRepositories, - storecreds: tt.fields.storecreds, - submoduleEnabled: tt.fields.submoduleEnabled, - repoServerClientSet: &repo_mocks.Clientset{RepoServerServiceClient: mockRepoClient}, + listRepositories: tt.fields.listRepositories, + storecreds: tt.fields.storecreds, + submoduleEnabled: tt.fields.submoduleEnabled, + getGitDirectories: tt.fields.getGitDirectories, } - got, err := a.GetDirectories(tt.args.ctx, tt.args.repoURL, tt.args.revision, tt.args.project, tt.args.noRevisionCache, tt.args.verifyCommit) + got, err := a.GetDirectories(tt.args.ctx, tt.args.repoURL, tt.args.revision, "", tt.args.noRevisionCache, tt.args.verifyCommit) if !tt.wantErr(t, err, fmt.Sprintf("GetDirectories(%v, %v, %v, %v)", tt.args.ctx, tt.args.repoURL, tt.args.revision, tt.args.noRevisionCache)) { return } @@ -105,10 +91,10 @@ func TestGetDirectories(t *testing.T) { func TestGetFiles(t *testing.T) { type fields struct { - storecreds git.CredsStore - submoduleEnabled bool - repoServerClientFuncs []func(*repo_mocks.RepoServerServiceClient) - listRepositories func(ctx context.Context) ([]*v1alpha1.Repository, error) + storecreds git.CredsStore + submoduleEnabled bool + listRepositories func(ctx context.Context) ([]*v1alpha1.Repository, error) + getGitFiles func(ctx context.Context, req *apiclient.GitFilesRequest) (*apiclient.GitFilesResponse, error) } type args struct { ctx context.Context @@ -134,10 +120,8 @@ func TestGetFiles(t *testing.T) { listRepositories: func(ctx context.Context) ([]*v1alpha1.Repository, error) { return []*v1alpha1.Repository{{}}, nil }, - repoServerClientFuncs: []func(*repo_mocks.RepoServerServiceClient){ - func(client *repo_mocks.RepoServerServiceClient) { - client.On("GetGitFiles", mock.Anything, mock.Anything).Return(nil, fmt.Errorf("unable to get files")) - }, + getGitFiles: func(ctx context.Context, req *apiclient.GitFilesRequest) (*apiclient.GitFilesResponse, error) { + return nil, fmt.Errorf("unable to get files") }, }, args: args{}, want: nil, wantErr: assert.Error}, {name: "HappyCase", fields: fields{ @@ -146,15 +130,13 @@ func TestGetFiles(t *testing.T) { Repo: "foo", }}, nil }, - repoServerClientFuncs: []func(*repo_mocks.RepoServerServiceClient){ - func(client *repo_mocks.RepoServerServiceClient) { - client.On("GetGitFiles", mock.Anything, mock.Anything).Return(&apiclient.GitFilesResponse{ - Map: map[string][]byte{ - "foo.json": []byte("hello: world!"), - "bar.yaml": []byte("yay: appsets"), - }, - }, nil) - }, + getGitFiles: func(ctx context.Context, req *apiclient.GitFilesRequest) (*apiclient.GitFilesResponse, error) { + return &apiclient.GitFilesResponse{ + Map: map[string][]byte{ + "foo.json": []byte("hello: world!"), + "bar.yaml": []byte("yay: appsets"), + }, + }, nil }, }, args: args{ repoURL: "foo", @@ -162,30 +144,74 @@ func TestGetFiles(t *testing.T) { "foo.json": []byte("hello: world!"), "bar.yaml": []byte("yay: appsets"), }, wantErr: assert.NoError}, + {name: "NoRepoFoundFallback", fields: fields{ + listRepositories: func(ctx context.Context) ([]*v1alpha1.Repository, error) { + return []*v1alpha1.Repository{}, nil + }, + getGitFiles: func(ctx context.Context, req *apiclient.GitFilesRequest) (*apiclient.GitFilesResponse, error) { + require.Equal(t, &v1alpha1.Repository{Repo: "foo"}, req.Repo) + return &apiclient.GitFilesResponse{ + Map: map[string][]byte{}, + }, nil + }, + }, args: args{ + repoURL: "foo", + }, want: map[string][]byte{}, wantErr: assert.NoError}, + {name: "RepoWithProjectFoundFallback", fields: fields{ + listRepositories: func(ctx context.Context) ([]*v1alpha1.Repository, error) { + return []*v1alpha1.Repository{{Repo: "foo", Project: "default"}}, nil + }, + getGitFiles: func(ctx context.Context, req *apiclient.GitFilesRequest) (*apiclient.GitFilesResponse, error) { + require.Equal(t, &v1alpha1.Repository{Repo: "foo", Project: "default"}, req.Repo) + return &apiclient.GitFilesResponse{ + Map: map[string][]byte{}, + }, nil + }, + }, args: args{ + repoURL: "foo", + }, want: map[string][]byte{}, wantErr: assert.NoError}, + {name: "MultipleReposWithEmptyProjectFoundFallback", fields: fields{ + listRepositories: func(ctx context.Context) ([]*v1alpha1.Repository, error) { + return []*v1alpha1.Repository{{Repo: "foo", Project: "default"}, {Repo: "foo", Project: ""}}, nil + }, + getGitFiles: func(ctx context.Context, req *apiclient.GitFilesRequest) (*apiclient.GitFilesResponse, error) { + require.Equal(t, &v1alpha1.Repository{Repo: "foo", Project: ""}, req.Repo) + return &apiclient.GitFilesResponse{ + Map: map[string][]byte{}, + }, nil + }, + }, args: args{ + repoURL: "foo", + }, want: map[string][]byte{}, wantErr: assert.NoError}, + {name: "MultipleReposFoundFallback", fields: fields{ + listRepositories: func(ctx context.Context) ([]*v1alpha1.Repository, error) { + return []*v1alpha1.Repository{{Repo: "foo", Project: "default"}, {Repo: "foo", Project: "bar"}}, nil + }, + getGitFiles: func(ctx context.Context, req *apiclient.GitFilesRequest) (*apiclient.GitFilesResponse, error) { + require.Equal(t, &v1alpha1.Repository{Repo: "foo", Project: "default"}, req.Repo) + return &apiclient.GitFilesResponse{ + Map: map[string][]byte{}, + }, nil + }, + }, args: args{ + repoURL: "foo", + }, want: map[string][]byte{}, wantErr: assert.NoError}, {name: "ErrorVerifyingCommit", fields: fields{ listRepositories: func(ctx context.Context) ([]*v1alpha1.Repository, error) { return []*v1alpha1.Repository{{}}, nil }, - repoServerClientFuncs: []func(*repo_mocks.RepoServerServiceClient){ - func(client *repo_mocks.RepoServerServiceClient) { - client.On("GetGitFiles", mock.Anything, mock.Anything).Return(nil, fmt.Errorf("revision HEAD is not signed")) - }, + getGitFiles: func(ctx context.Context, req *apiclient.GitFilesRequest) (*apiclient.GitFilesResponse, error) { + return nil, fmt.Errorf("revision HEAD is not signed") }, }, args: args{}, want: nil, wantErr: assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockRepoClient := &repo_mocks.RepoServerServiceClient{} - // decorate the mocks - for i := range tt.fields.repoServerClientFuncs { - tt.fields.repoServerClientFuncs[i](mockRepoClient) - } - a := &argoCDService{ - listRepositories: tt.fields.listRepositories, - storecreds: tt.fields.storecreds, - submoduleEnabled: tt.fields.submoduleEnabled, - repoServerClientSet: &repo_mocks.Clientset{RepoServerServiceClient: mockRepoClient}, + listRepositories: tt.fields.listRepositories, + storecreds: tt.fields.storecreds, + submoduleEnabled: tt.fields.submoduleEnabled, + getGitFiles: tt.fields.getGitFiles, } got, err := a.GetFiles(tt.args.ctx, tt.args.repoURL, tt.args.revision, tt.args.pattern, "", tt.args.noRevisionCache, tt.args.verifyCommit) if !tt.wantErr(t, err, fmt.Sprintf("GetFiles(%v, %v, %v, %v, %v)", tt.args.ctx, tt.args.repoURL, tt.args.revision, tt.args.pattern, tt.args.noRevisionCache)) {