Skip to content

Commit

Permalink
atlasaction: allow the SCM client control the comment format
Browse files Browse the repository at this point in the history
  • Loading branch information
giautm committed Dec 17, 2024
1 parent f63071c commit bdd9a91
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 96 deletions.
24 changes: 7 additions & 17 deletions atlasaction/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ type (
}
// SCMClient contains methods for interacting with SCM platforms (GitHub, Gitlab etc...).
SCMClient interface {
// UpsertComment posts or updates a pull request comment.
UpsertComment(ctx context.Context, pr *PullRequest, id, comment string) error
// CommentLint posts or updates a pull request comment.
CommentLint(context.Context, *TriggerContext, *atlasexec.SummaryReport) error
// CommentPlan posts or updates a pull request comment.
CommentPlan(context.Context, *TriggerContext, *atlasexec.SchemaPlan) error
}

// SCMSuggestions contains methods for interacting with SCM platforms (GitHub, Gitlab etc...)
Expand Down Expand Up @@ -134,6 +136,7 @@ type (

// TriggerContext holds the context of the environment the action is running in.
TriggerContext struct {
Act Action // Act is the action that is running.
SCM SCM // SCM is the source control management system.
Repo string // Repo is the repository name. e.g. "ariga/atlas-action".
RepoURL string // RepoURL is full URL of the repository. e.g. "https://github.com/ariga/atlas-action".
Expand Down Expand Up @@ -535,11 +538,7 @@ func (a *Actions) MigrateLint(ctx context.Context) error {
case err != nil:
return err
default:
comment, err := RenderTemplate("migrate-lint.tmpl", &payload)
if err != nil {
return err
}
if err = c.UpsertComment(ctx, tc.PullRequest, dirName, comment); err != nil {
if err = c.CommentLint(ctx, tc, &payload); err != nil {
a.Errorf("failed to comment on the pull request: %v", err)
}
if c, ok := c.(SCMSuggestions); ok {
Expand Down Expand Up @@ -742,16 +741,7 @@ func (a *Actions) SchemaPlan(ctx context.Context) error {
case err != nil:
return err
default:
// Report the schema plan to the user and add a comment to the PR.
comment, err := RenderTemplate("schema-plan.tmpl", map[string]any{
"Plan": plan,
"EnvName": params.Env,
"RerunCommand": tc.RerunCmd,
})
if err != nil {
return fmt.Errorf("failed to generate schema plan comment: %w", err)
}
err = c.UpsertComment(ctx, tc.PullRequest, plan.File.Name, comment)
err = c.CommentPlan(ctx, tc, plan)
if err != nil {
// Don't fail the action if the comment fails.
// It may be due to the missing permissions.
Expand Down
14 changes: 13 additions & 1 deletion atlasaction/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2500,7 +2500,19 @@ func (m *mockSCM) UpsertSuggestion(context.Context, *atlasaction.PullRequest, *a
return nil
}

func (m *mockSCM) UpsertComment(_ context.Context, _ *atlasaction.PullRequest, id string, _ string) error {
func (m *mockSCM) CommentLint(ctx context.Context, tc *atlasaction.TriggerContext, r *atlasexec.SummaryReport) error {
comment, err := atlasaction.RenderTemplate("migrate-lint.tmpl", r)
if err != nil {
return err
}
return m.comment(ctx, tc.PullRequest, "foo", comment)
}

func (m *mockSCM) CommentPlan(ctx context.Context, tc *atlasaction.TriggerContext, p *atlasexec.SchemaPlan) error {
return m.comment(ctx, tc.PullRequest, p.File.Name, "")
}

func (m *mockSCM) comment(_ context.Context, _ *atlasaction.PullRequest, id string, _ string) error {
var (
method = http.MethodPatch
urlPath = "/repos/ariga/atlas-action/issues/comments/1"
Expand Down
5 changes: 3 additions & 2 deletions atlasaction/bitbucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,22 @@ type bbPipe struct {
}

// NewBitBucketPipe returns a new Action for BitBucket.
func NewBitBucketPipe(getenv func(string) string, w io.Writer) Action {
func NewBitBucketPipe(getenv func(string) string, w io.Writer) *bbPipe {
// Disable color output for testing,
// but enable it for non-testing environments.
color.NoColor = testing.Testing()
return &bbPipe{getenv: getenv, coloredLogger: &coloredLogger{w: w}}
}

// GetType implements Action.
func (a *bbPipe) GetType() atlasexec.TriggerType {
func (*bbPipe) GetType() atlasexec.TriggerType {
return atlasexec.TriggerTypeBitbucket
}

// GetTriggerContext implements Action.
func (a *bbPipe) GetTriggerContext(context.Context) (*TriggerContext, error) {
tc := &TriggerContext{
Act: a,
Branch: a.getenv("BITBUCKET_BRANCH"),
Commit: a.getenv("BITBUCKET_COMMIT"),
Repo: a.getenv("BITBUCKET_REPO_FULL_NAME"),
Expand Down
2 changes: 1 addition & 1 deletion atlasaction/comments/schema-plan.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ the database with the desired state. Otherwise, Atlas will report a schema drift

3\. Push the updated plan to the registry using the following command:
```bash
atlas schema plan push --pending --env {{ .EnvName }} --file {{ .Plan.File.Name }}.plan.hcl
atlas schema plan push --pending --file {{ .Plan.File.Name }}.plan.hcl
```

{{- if .RerunCommand }}
Expand Down
95 changes: 57 additions & 38 deletions atlasaction/gh_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ func (a *ghAction) SchemaApply(_ context.Context, r *atlasexec.SchemaApply) {
func (a *ghAction) SchemaPlan(_ context.Context, r *atlasexec.SchemaPlan) {
summary, err := RenderTemplate("schema-plan.tmpl", map[string]any{
"Plan": r,
"EnvName": a.GetInput("env"),
"RerunCommand": fmt.Sprintf("gh run rerun %s", a.Getenv("GITHUB_RUN_ID")),
})
if err != nil {
Expand All @@ -82,7 +81,7 @@ func (a *ghAction) SchemaPlan(_ context.Context, r *atlasexec.SchemaPlan) {
}

// GetType implements the Action interface.
func (a *ghAction) GetType() atlasexec.TriggerType {
func (*ghAction) GetType() atlasexec.TriggerType {
return atlasexec.TriggerTypeGithubAction
}

Expand All @@ -97,6 +96,7 @@ func (a *ghAction) GetTriggerContext(context.Context) (*TriggerContext, error) {
return nil, err
}
tc := &TriggerContext{
Act: a,
SCM: SCM{Type: atlasexec.SCMTypeGithub, APIURL: ctx.APIURL},
Repo: ctx.Repository,
Branch: ctx.HeadRef,
Expand Down Expand Up @@ -176,8 +176,27 @@ type (
}
)

func (g *githubAPI) UpsertComment(ctx context.Context, pr *PullRequest, id, comment string) error {
comments, err := g.getIssueComments(ctx, pr)
func (c *githubAPI) CommentLint(ctx context.Context, tc *TriggerContext, r *atlasexec.SummaryReport) error {
comment, err := RenderTemplate("migrate-lint.tmpl", r)
if err != nil {
return err
}
return c.comment(ctx, tc.PullRequest, tc.Act.GetInput("dir-name"), comment)
}

func (c *githubAPI) CommentPlan(ctx context.Context, tc *TriggerContext, p *atlasexec.SchemaPlan) error {
// Report the schema plan to the user and add a comment to the PR.
comment, err := RenderTemplate("schema-plan.tmpl", map[string]any{
"Plan": p,
})
if err != nil {
return err
}
return c.comment(ctx, tc.PullRequest, p.File.Name, comment)
}

func (c *githubAPI) comment(ctx context.Context, pr *PullRequest, id, comment string) error {
comments, err := c.getIssueComments(ctx, pr)
if err != nil {
return err
}
Expand All @@ -188,43 +207,43 @@ func (g *githubAPI) UpsertComment(ctx context.Context, pr *PullRequest, id, comm
if found := slices.IndexFunc(comments, func(c githubIssueComment) bool {
return strings.Contains(c.Body, marker)
}); found != -1 {
return g.updateIssueComment(ctx, comments[found].ID, body)
return c.updateIssueComment(ctx, comments[found].ID, body)
}
return g.createIssueComment(ctx, pr, body)
return c.createIssueComment(ctx, pr, body)
}

func (g *githubAPI) getIssueComments(ctx context.Context, pr *PullRequest) ([]githubIssueComment, error) {
url := fmt.Sprintf("%v/repos/%v/issues/%v/comments", g.baseURL, g.repo, pr.Number)
func (c *githubAPI) getIssueComments(ctx context.Context, pr *PullRequest) ([]githubIssueComment, error) {
url := fmt.Sprintf("%v/repos/%v/issues/%v/comments", c.baseURL, c.repo, pr.Number)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
res, err := g.client.Do(req)
res, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("error querying github comments with %v/%v, %w", g.repo, pr.Number, err)
return nil, fmt.Errorf("error querying github comments with %v/%v, %w", c.repo, pr.Number, err)
}
defer res.Body.Close()
buf, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("error reading PR issue comments from %v/%v, %v", g.repo, pr.Number, err)
return nil, fmt.Errorf("error reading PR issue comments from %v/%v, %v", c.repo, pr.Number, err)
}
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code %v when calling GitHub API", res.StatusCode)
}
var comments []githubIssueComment
if err = json.Unmarshal(buf, &comments); err != nil {
return nil, fmt.Errorf("error parsing github comments with %v/%v from %v, %w", g.repo, pr.Number, string(buf), err)
return nil, fmt.Errorf("error parsing github comments with %v/%v from %v, %w", c.repo, pr.Number, string(buf), err)
}
return comments, nil
}

func (g *githubAPI) createIssueComment(ctx context.Context, pr *PullRequest, content io.Reader) error {
url := fmt.Sprintf("%v/repos/%v/issues/%v/comments", g.baseURL, g.repo, pr.Number)
func (c *githubAPI) createIssueComment(ctx context.Context, pr *PullRequest, content io.Reader) error {
url := fmt.Sprintf("%v/repos/%v/issues/%v/comments", c.baseURL, c.repo, pr.Number)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, content)
if err != nil {
return err
}
res, err := g.client.Do(req)
res, err := c.client.Do(req)
if err != nil {
return err
}
Expand All @@ -240,13 +259,13 @@ func (g *githubAPI) createIssueComment(ctx context.Context, pr *PullRequest, con
}

// updateIssueComment updates issue comment with the given id.
func (g *githubAPI) updateIssueComment(ctx context.Context, id int, content io.Reader) error {
url := fmt.Sprintf("%v/repos/%v/issues/comments/%v", g.baseURL, g.repo, id)
func (c *githubAPI) updateIssueComment(ctx context.Context, id int, content io.Reader) error {
url := fmt.Sprintf("%v/repos/%v/issues/comments/%v", c.baseURL, c.repo, id)
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, content)
if err != nil {
return err
}
res, err := g.client.Do(req)
res, err := c.client.Do(req)
if err != nil {
return err
}
Expand All @@ -262,11 +281,11 @@ func (g *githubAPI) updateIssueComment(ctx context.Context, id int, content io.R
}

// UpsertSuggestion creates or updates a suggestion review comment on trigger event pull request.
func (g *githubAPI) UpsertSuggestion(ctx context.Context, pr *PullRequest, s *Suggestion) error {
func (c *githubAPI) UpsertSuggestion(ctx context.Context, pr *PullRequest, s *Suggestion) error {
marker := commentMarker(s.ID)
body := fmt.Sprintf("%s\n%s", s.Comment, marker)
// TODO: Listing the comments only once and updating the comment in the same call.
comments, err := g.listReviewComments(ctx, pr)
comments, err := c.listReviewComments(ctx, pr)
if err != nil {
return err
}
Expand All @@ -277,7 +296,7 @@ func (g *githubAPI) UpsertSuggestion(ctx context.Context, pr *PullRequest, s *Su
return c.Path == s.Path && strings.Contains(c.Body, marker)
})
if found != -1 {
if err := g.updateReviewComment(ctx, comments[found].ID, body); err != nil {
if err := c.updateReviewComment(ctx, comments[found].ID, body); err != nil {
return err
}
return nil
Expand All @@ -292,12 +311,12 @@ func (g *githubAPI) UpsertSuggestion(ctx context.Context, pr *PullRequest, s *Su
if err != nil {
return err
}
url := fmt.Sprintf("%v/repos/%v/pulls/%v/comments", g.baseURL, g.repo, pr.Number)
url := fmt.Sprintf("%v/repos/%v/pulls/%v/comments", c.baseURL, c.repo, pr.Number)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(buf))
if err != nil {
return err
}
res, err := g.client.Do(req)
res, err := c.client.Do(req)
if err != nil {
return err
}
Expand All @@ -313,13 +332,13 @@ func (g *githubAPI) UpsertSuggestion(ctx context.Context, pr *PullRequest, s *Su
}

// listReviewComments for the trigger event pull request.
func (g *githubAPI) listReviewComments(ctx context.Context, pr *PullRequest) ([]pullRequestComment, error) {
url := fmt.Sprintf("%v/repos/%v/pulls/%v/comments", g.baseURL, g.repo, pr.Number)
func (c *githubAPI) listReviewComments(ctx context.Context, pr *PullRequest) ([]pullRequestComment, error) {
url := fmt.Sprintf("%v/repos/%v/pulls/%v/comments", c.baseURL, c.repo, pr.Number)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
res, err := g.client.Do(req)
res, err := c.client.Do(req)
if err != nil {
return nil, err
}
Expand All @@ -339,20 +358,20 @@ func (g *githubAPI) listReviewComments(ctx context.Context, pr *PullRequest) ([]
}

// updateReviewComment updates the review comment with the given id.
func (g *githubAPI) updateReviewComment(ctx context.Context, id int, body string) error {
func (c *githubAPI) updateReviewComment(ctx context.Context, id int, body string) error {
type pullRequestUpdate struct {
Body string `json:"body"`
}
b, err := json.Marshal(pullRequestUpdate{Body: body})
if err != nil {
return err
}
url := fmt.Sprintf("%v/repos/%v/pulls/comments/%v", g.baseURL, g.repo, id)
url := fmt.Sprintf("%v/repos/%v/pulls/comments/%v", c.baseURL, c.repo, id)
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, bytes.NewReader(b))
if err != nil {
return err
}
res, err := g.client.Do(req)
res, err := c.client.Do(req)
if err != nil {
return err
}
Expand All @@ -368,13 +387,13 @@ func (g *githubAPI) updateReviewComment(ctx context.Context, id int, body string
}

// ListPullRequestFiles return paths of the files in the trigger event pull request.
func (g *githubAPI) ListPullRequestFiles(ctx context.Context, pr *PullRequest) ([]string, error) {
url := fmt.Sprintf("%v/repos/%v/pulls/%v/files", g.baseURL, g.repo, pr.Number)
func (c *githubAPI) ListPullRequestFiles(ctx context.Context, pr *PullRequest) ([]string, error) {
url := fmt.Sprintf("%v/repos/%v/pulls/%v/files", c.baseURL, c.repo, pr.Number)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
res, err := g.client.Do(req)
res, err := c.client.Do(req)
if err != nil {
return nil, err
}
Expand All @@ -398,19 +417,19 @@ func (g *githubAPI) ListPullRequestFiles(ctx context.Context, pr *PullRequest) (
}

// OpeningPullRequest returns the latest open pull request for the given branch.
func (g *githubAPI) OpeningPullRequest(ctx context.Context, branch string) (*PullRequest, error) {
owner, _, err := g.ownerRepo()
func (c *githubAPI) OpeningPullRequest(ctx context.Context, branch string) (*PullRequest, error) {
owner, _, err := c.ownerRepo()
if err != nil {
return nil, err
}
// Get open pull requests for the branch.
url := fmt.Sprintf("%s/repos/%s/pulls?state=open&head=%s:%s&sort=created&direction=desc&per_page=1&page=1",
g.baseURL, g.repo, owner, branch)
c.baseURL, c.repo, owner, branch)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
res, err := g.client.Do(req)
res, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("error calling GitHub API: %w", err)
}
Expand Down Expand Up @@ -442,8 +461,8 @@ func (g *githubAPI) OpeningPullRequest(ctx context.Context, branch string) (*Pul
}
}

func (g *githubAPI) ownerRepo() (string, string, error) {
s := strings.Split(g.repo, "/")
func (c *githubAPI) ownerRepo() (string, string, error) {
s := strings.Split(c.repo, "/")
if len(s) != 2 {
return "", "", fmt.Errorf("GITHUB_REPOSITORY must be in the format of 'owner/repo'")
}
Expand Down
Loading

0 comments on commit bdd9a91

Please sign in to comment.