diff --git a/api/api_controller.go b/api/api_controller.go index ae41975d891..30b025810d5 100644 --- a/api/api_controller.go +++ b/api/api_controller.go @@ -1293,7 +1293,7 @@ func (c *Controller) RevertBranchHandler() branches.RevertBranchHandler { ctx := c.Context() switch swag.StringValue(params.Revert.Type) { case models.RevertCreationTypeCommit: - err = cataloger.RollbackCommit(ctx, params.Repository, params.Revert.Commit) + err = cataloger.RollbackCommit(ctx, params.Repository, params.Branch, params.Revert.Commit) case models.RevertCreationTypeCommonPrefix: err = cataloger.ResetEntries(ctx, params.Repository, params.Branch, params.Revert.Path) case models.RevertCreationTypeReset: diff --git a/catalog/cataloger.go b/catalog/cataloger.go index deafd993c81..826009064be 100644 --- a/catalog/cataloger.go +++ b/catalog/cataloger.go @@ -122,7 +122,7 @@ type Cataloger interface { Commit(ctx context.Context, repository, branch string, message string, committer string, metadata Metadata) (*CommitLog, error) GetCommit(ctx context.Context, repository, reference string) (*CommitLog, error) ListCommits(ctx context.Context, repository, branch string, fromReference string, limit int) ([]*CommitLog, bool, error) - RollbackCommit(ctx context.Context, repository, reference string) error + RollbackCommit(ctx context.Context, repository, branch string, reference string) error Diff(ctx context.Context, repository, leftReference string, rightReference string, params DiffParams) (Differences, bool, error) DiffUncommitted(ctx context.Context, repository, branch string, limit int, after string) (Differences, bool, error) diff --git a/catalog/mvcc/cataloger_rollback_commit.go b/catalog/mvcc/cataloger_rollback_commit.go index be2bdc48b62..bb8c95050e3 100644 --- a/catalog/mvcc/cataloger_rollback_commit.go +++ b/catalog/mvcc/cataloger_rollback_commit.go @@ -8,7 +8,7 @@ import ( "github.com/treeverse/lakefs/db" ) -func (c *cataloger) RollbackCommit(ctx context.Context, repository, reference string) error { +func (c *cataloger) RollbackCommit(ctx context.Context, repository, branch, reference string) error { if err := Validate(ValidateFields{ {Name: "repository", IsValid: ValidateRepositoryName(repository)}, {Name: "reference", IsValid: ValidateReference(reference)}, @@ -23,6 +23,9 @@ func (c *cataloger) RollbackCommit(ctx context.Context, repository, reference st if ref.CommitID <= UncommittedID { return catalog.ErrInvalidReference } + if ref.Branch != branch { + return catalog.ErrInvalidReference + } _, err = c.db.Transact(func(tx db.Tx) (interface{}, error) { // extract branch id from reference diff --git a/catalog/mvcc/cataloger_rollback_commit_test.go b/catalog/mvcc/cataloger_rollback_commit_test.go index 806b3f8e107..d6287ce0075 100644 --- a/catalog/mvcc/cataloger_rollback_commit_test.go +++ b/catalog/mvcc/cataloger_rollback_commit_test.go @@ -32,7 +32,7 @@ func TestCataloger_RollbackCommit_Basic(t *testing.T) { for i := 0; i < len(refs); i++ { filesCount := len(refs) - i ref := refs[filesCount-1] - err := c.RollbackCommit(ctx, repository, ref) + err := c.RollbackCommit(ctx, repository, "master", ref) testutil.MustDo(t, "rollback", err) entries, _, err := c.ListEntries(ctx, repository, "master", "", "", "", -1) @@ -79,7 +79,7 @@ func TestCataloger_RollbackCommit_BlockedByBranch(t *testing.T) { testutil.MustDo(t, "merge master to branch1", err) // rollback to initial commit should fail - err = c.RollbackCommit(ctx, repository, masterReference) + err = c.RollbackCommit(ctx, repository, "master", masterReference) if err == nil { t.Fatal("Rollback with blocked branch should fail with error") } @@ -117,7 +117,7 @@ func TestCataloger_RollbackCommit_AfterMerge(t *testing.T) { testutil.MustDo(t, "merge branch1 to master", err) // rollback to first commit - err = c.RollbackCommit(ctx, repository, firstCommit.Reference) + err = c.RollbackCommit(ctx, repository, "master", firstCommit.Reference) testutil.MustDo(t, "rollback to first commit", err) // check we have our original files diff --git a/catalog/rocks/cataloger.go b/catalog/rocks/cataloger.go index b41f874e569..a9f3558da66 100644 --- a/catalog/rocks/cataloger.go +++ b/catalog/rocks/cataloger.go @@ -383,7 +383,7 @@ func (c *cataloger) ListCommits(ctx context.Context, repository string, branch s panic("not implemented") // TODO: Implement } -func (c *cataloger) RollbackCommit(ctx context.Context, repository string, reference string) error { +func (c *cataloger) RollbackCommit(ctx context.Context, repository string, branch string, reference string) error { panic("not implemented") // TODO: Implement }