Skip to content

Commit

Permalink
github now sets reviewers/team-reviewers/assignees/labels insteads of
Browse files Browse the repository at this point in the history
adding them
  • Loading branch information
lindell committed Feb 1, 2024
1 parent 2c3e1fb commit 630f091
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 29 deletions.
8 changes: 4 additions & 4 deletions cmd/cmd-run.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ func run(cmd *cobra.Command, _ []string) error {
prTitle, _ := flag.GetString("pr-title")
prBody, _ := flag.GetString("pr-body")
commitMessage, _ := flag.GetString("commit-message")
reviewers, _ := flag.GetStringSlice("reviewers")
teamReviewers, _ := flag.GetStringSlice("team-reviewers")
reviewers, _ := stringSlice(flag, "reviewers")
teamReviewers, _ := stringSlice(flag, "team-reviewers")
maxReviewers, _ := flag.GetInt("max-reviewers")
maxTeamReviewers, _ := flag.GetInt("max-team-reviewers")
concurrent, _ := flag.GetInt("concurrent")
Expand All @@ -98,9 +98,9 @@ func run(cmd *cobra.Command, _ []string) error {
authorName, _ := flag.GetString("author-name")
authorEmail, _ := flag.GetString("author-email")
strOutput, _ := flag.GetString("output")
assignees, _ := flag.GetStringSlice("assignees")
assignees, _ := stringSlice(flag, "assignees")
draft, _ := flag.GetBool("draft")
labels, _ := flag.GetStringSlice("labels")
labels, _ := stringSlice(flag, "labels")
repoInclude, _ := flag.GetString("repo-include")
repoExclude, _ := flag.GetString("repo-exclude")

Expand Down
11 changes: 11 additions & 0 deletions cmd/flags.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package cmd

import "github.com/spf13/pflag"

// stringSlice is a wrapped around *pflag.FlagSet.GetStringSlice to allow nil when the flag is not set
func stringSlice(set *pflag.FlagSet, name string) ([]string, error) {
if !set.Changed(name) {
return nil, nil
}
return set.GetStringSlice(name)
}
122 changes: 97 additions & 25 deletions internal/scm/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,15 +435,15 @@ func (g *Github) CreatePullRequest(ctx context.Context, repo scm.Repository, prR
return nil, err
}

if err := g.addReviewers(ctx, r, newPR, pr); err != nil {
if err := g.setReviewers(ctx, r, newPR, pr); err != nil {
return nil, err
}

if err := g.addAssignees(ctx, r, newPR, pr); err != nil {
if err := g.setAssignees(ctx, r, newPR, pr); err != nil {
return nil, err
}

if err := g.addLabels(ctx, r, newPR, pr); err != nil {
if err := g.setLabels(ctx, r, newPR, pr); err != nil {
return nil, err
}

Expand All @@ -465,38 +465,110 @@ func (g *Github) createPullRequest(ctx context.Context, repo repository, prRepo
return pr, err
}

func (g *Github) addReviewers(ctx context.Context, repo repository, newPR scm.NewPullRequest, createdPR *github.PullRequest) error {
if len(newPR.Reviewers) == 0 && len(newPR.TeamReviewers) == 0 {
return nil
func (g *Github) setReviewers(ctx context.Context, repo repository, newPR scm.NewPullRequest, createdPR *github.PullRequest) error {
var addedReviewers, removedReviewers []string
if newPR.Reviewers != nil {
existingReviewers := scm.Map(createdPR.RequestedReviewers, func(user *github.User) string {
return user.GetLogin()
})
addedReviewers, removedReviewers = scm.Diff(existingReviewers, newPR.Reviewers)
}

_, _, err := retry(ctx, func() (*github.PullRequest, *github.Response, error) {
return g.ghClient.PullRequests.RequestReviewers(ctx, repo.ownerName, repo.name, createdPR.GetNumber(), github.ReviewersRequest{
Reviewers: newPR.Reviewers,
TeamReviewers: newPR.TeamReviewers,
var addedTeamReviewers, removedTeamReviewers []string
if newPR.Reviewers != nil {
existingTeamReviewers := scm.Map(createdPR.RequestedTeams, func(team *github.Team) string {
return team.GetSlug()
})
})
return err
addedTeamReviewers, removedTeamReviewers = scm.Diff(existingTeamReviewers, newPR.TeamReviewers)
}

if len(addedReviewers) > 0 || len(addedTeamReviewers) > 0 {
_, _, err := retry(ctx, func() (*github.PullRequest, *github.Response, error) {
return g.ghClient.PullRequests.RequestReviewers(ctx, repo.ownerName, repo.name, createdPR.GetNumber(), github.ReviewersRequest{
Reviewers: addedReviewers,
TeamReviewers: addedTeamReviewers,
})
})
if err != nil {
return err
}
}

if len(removedReviewers) > 0 || len(removedTeamReviewers) > 0 {
_, err := retryWithoutReturn(ctx, func() (*github.Response, error) {
return g.ghClient.PullRequests.RemoveReviewers(ctx, repo.ownerName, repo.name, createdPR.GetNumber(), github.ReviewersRequest{
Reviewers: removedReviewers,
TeamReviewers: removedTeamReviewers,
})
})
if err != nil {
return err
}
}

return nil
}

func (g *Github) addAssignees(ctx context.Context, repo repository, newPR scm.NewPullRequest, createdPR *github.PullRequest) error {
if len(newPR.Assignees) == 0 {
func (g *Github) setAssignees(ctx context.Context, repo repository, newPR scm.NewPullRequest, createdPR *github.PullRequest) error {
if newPR.Assignees == nil {
return nil
}
_, _, err := retry(ctx, func() (*github.Issue, *github.Response, error) {
return g.ghClient.Issues.AddAssignees(ctx, repo.ownerName, repo.name, createdPR.GetNumber(), newPR.Assignees)

existingAssignees := scm.Map(createdPR.Assignees, func(user *github.User) string {
return user.GetLogin()
})
return err
addedAssignees, removedAssignees := scm.Diff(existingAssignees, newPR.Assignees)

if len(addedAssignees) > 0 {
_, _, err := retry(ctx, func() (*github.Issue, *github.Response, error) {
return g.ghClient.Issues.AddAssignees(ctx, repo.ownerName, repo.name, createdPR.GetNumber(), addedAssignees)
})
if err != nil {
return err
}
}

if len(removedAssignees) > 0 {
_, _, err := retry(ctx, func() (*github.Issue, *github.Response, error) {
return g.ghClient.Issues.RemoveAssignees(ctx, repo.ownerName, repo.name, createdPR.GetNumber(), removedAssignees)
})
if err != nil {
return err
}
}

return nil
}

func (g *Github) addLabels(ctx context.Context, repo repository, newPR scm.NewPullRequest, createdPR *github.PullRequest) error {
if len(newPR.Labels) == 0 {
func (g *Github) setLabels(ctx context.Context, repo repository, newPR scm.NewPullRequest, createdPR *github.PullRequest) error {
if newPR.Labels == nil {
return nil
}
_, _, err := retry(ctx, func() ([]*github.Label, *github.Response, error) {
return g.ghClient.Issues.AddLabelsToIssue(ctx, repo.ownerName, repo.name, createdPR.GetNumber(), newPR.Labels)

existingLabels := scm.Map(createdPR.Labels, func(label *github.Label) string {
return label.GetName()
})
return err
addedLabels, removedLabels := scm.Diff(existingLabels, newPR.Labels)

if len(addedLabels) > 0 {
_, _, err := retry(ctx, func() ([]*github.Label, *github.Response, error) {
return g.ghClient.Issues.AddLabelsToIssue(ctx, repo.ownerName, repo.name, createdPR.GetNumber(), addedLabels)
})
if err != nil {
return err
}
}

for _, label := range removedLabels {
_, err := retryWithoutReturn(ctx, func() (*github.Response, error) {
return g.ghClient.Issues.RemoveLabelForIssue(ctx, repo.ownerName, repo.name, createdPR.GetNumber(), label)
})
if err != nil {
return err
}
}

return nil
}

// UpdatePullRequest updates an existing pull request
Expand All @@ -517,15 +589,15 @@ func (g *Github) UpdatePullRequest(ctx context.Context, repo scm.Repository, pul
return nil, err
}

if err := g.addReviewers(ctx, r, updatedPR, ghPR); err != nil {
if err := g.setReviewers(ctx, r, updatedPR, ghPR); err != nil {
return nil, err
}

if err := g.addAssignees(ctx, r, updatedPR, ghPR); err != nil {
if err := g.setAssignees(ctx, r, updatedPR, ghPR); err != nil {
return nil, err
}

if err := g.addLabels(ctx, r, updatedPR, ghPR); err != nil {
if err := g.setLabels(ctx, r, updatedPR, ghPR); err != nil {
return nil, err
}

Expand Down
35 changes: 35 additions & 0 deletions internal/scm/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package scm

// Diff two slices and get the added and removed items compared to s1
func Diff[T comparable](s1, s2 []T) (added, removed []T) {
s1Lookup := map[T]struct{}{}
for _, v := range s1 {
s1Lookup[v] = struct{}{}
}
s2Lookup := map[T]struct{}{}
for _, v := range s2 {
s2Lookup[v] = struct{}{}
}

for _, v := range s2 {
if _, ok := s1Lookup[v]; !ok {
added = append(added, v)
}
}
for _, v := range s1 {
if _, ok := s2Lookup[v]; !ok {
removed = append(removed, v)
}
}

return added, removed
}

// Map runs a function for each value in a slice and returns a slice of all function returns
func Map[T any, K any](vals []T, mapping func(T) K) []K {
newVals := make([]K, len(vals))
for i, v := range vals {
newVals[i] = mapping(v)
}
return newVals
}
56 changes: 56 additions & 0 deletions internal/scm/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package scm

import (
"reflect"
"testing"
)

func TestDiff(t *testing.T) {
tests := []struct {
name string
s1 []int
s2 []int
wantAdded []int
wantRemoved []int
}{
{
name: "same",
s1: []int{1, 2, 3},
s2: []int{1, 2, 3},
wantAdded: nil,
wantRemoved: nil,
},
{
name: "empty s2",
s1: []int{1, 2, 3},
s2: []int{},
wantAdded: nil,
wantRemoved: []int{1, 2, 3},
},
{
name: "empty s1",
s1: []int{},
s2: []int{1, 2, 3},
wantAdded: []int{1, 2, 3},
wantRemoved: nil,
},
{
name: "some overlap",
s1: []int{1, 2, 3},
s2: []int{3, 4, 5},
wantAdded: []int{4, 5},
wantRemoved: []int{1, 2},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotAdded, gotRemoved := Diff(tt.s1, tt.s2)
if !reflect.DeepEqual(gotAdded, tt.wantAdded) {
t.Errorf("Diff() gotAdded = %v, want %v", gotAdded, tt.wantAdded)
}
if !reflect.DeepEqual(gotRemoved, tt.wantRemoved) {
t.Errorf("Diff() gotRemoved = %v, want %v", gotRemoved, tt.wantRemoved)
}
})
}
}

0 comments on commit 630f091

Please sign in to comment.