Skip to content

Commit

Permalink
Refactor flags to differentiate existing_scala_rule into subtypes (bi… (
Browse files Browse the repository at this point in the history
#110)

* Refactor flags to differentiate existing_scala_rule into subtypes (binary/library/test)
* update docs
  • Loading branch information
pcj authored Dec 5, 2023
1 parent 021c30d commit 04527ef
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 30 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,13 @@ To enable a provider, instantiate a "rule provider config":
You may have your own scala rule macros that look like a `scala_library` or
`scala_binary`, but have their own rule kinds and loads. To register these
rules/macros as provider implementations, use the
`-existing_scala_rule=LOAD%KIND` flag. For example:
`-existing_scala_{type}_rule=LOAD%KIND` flag (where type is one of `binary|library|test`). For example:

```bazel
gazelle(
name = "gazelle",
args = [
"-existing_scala_rule=//bazel_tools:scala.bzl%scala_app",
"-existing_scala_library_rule=//bazel_tools:scala.bzl%scala_app",
...
],
...
Expand Down
22 changes: 14 additions & 8 deletions language/scala/existing_scala_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ import (
)

func init() {
mustRegister := func(load, kind string) {
mustRegister := func(load, kind string, isBinary, isLibrary, isTest bool) {
fqn := load + "%" + kind
if err := scalarule.
GlobalProviderRegistry().
RegisterProvider(fqn, &existingScalaRuleProvider{load, kind}); err != nil {
RegisterProvider(fqn, &existingScalaRuleProvider{load, kind, isBinary, isLibrary, isTest}); err != nil {
log.Fatalf("registering scala_rule providers: %v", err)
}
}

mustRegister("@io_bazel_rules_scala//scala:scala.bzl", "scala_binary")
mustRegister("@io_bazel_rules_scala//scala:scala.bzl", "scala_library")
mustRegister("@io_bazel_rules_scala//scala:scala.bzl", "scala_macro_library")
mustRegister("@io_bazel_rules_scala//scala:scala.bzl", "scala_test")
mustRegister("@io_bazel_rules_scala//scala:scala.bzl", "scala_binary", true, false, false)
mustRegister("@io_bazel_rules_scala//scala:scala.bzl", "scala_library", false, true, false)
mustRegister("@io_bazel_rules_scala//scala:scala.bzl", "scala_macro_library", false, true, false)
mustRegister("@io_bazel_rules_scala//scala:scala.bzl", "scala_test", false, false, true)
}

// existingScalaRuleProvider implements RuleResolver for scala-like rules that
Expand All @@ -35,6 +35,9 @@ func init() {
// optionally, exports).
type existingScalaRuleProvider struct {
load, name string
isBinary bool
isLibrary bool
isTest bool
}

// Name implements part of the scalarule.Provider interface.
Expand Down Expand Up @@ -79,7 +82,7 @@ func (s *existingScalaRuleProvider) ResolveRule(cfg *scalarule.Config, pkg scala

r.SetPrivateAttr(config.GazelleImportsKey, scalaRule)

return &existingScalaRule{cfg, pkg, r, scalaRule}
return &existingScalaRule{cfg, pkg, r, scalaRule, s.isBinary, s.isLibrary, s.isTest}
}

// existingScalaRule implements scalarule.RuleProvider for existing scala rules.
Expand All @@ -88,6 +91,9 @@ type existingScalaRule struct {
pkg scalarule.Package
rule *rule.Rule
scalaRule scalarule.Rule
isBinary bool
isLibrary bool
isTest bool
}

// Kind implements part of the ruleProvider interface.
Expand Down Expand Up @@ -142,7 +148,7 @@ func (s *existingScalaRule) Resolve(rctx *scalarule.ResolveContext, importsRaw i
}

// part 1b: exports
if strings.HasSuffix(r.Kind(), "_library") {
if s.isLibrary {
newExports := exports.Deps(sc.maybeRewrite(r.Kind(), rctx.From))
exportLabels := sc.cleanExports(rctx.From, r.Attr("exports"), newExports)
mergeDeps(r.Kind(), exportLabels, newExports)
Expand Down
80 changes: 72 additions & 8 deletions language/scala/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ import (
const (
scalaSymbolProviderFlagName = "scala_symbol_provider"
scalaConflictResolverFlagName = "scala_conflict_resolver"
existingScalaRulesFlagName = "existing_scala_rule"
existingScalaBinaryRuleFlagName = "existing_scala_binary_rule"
existingScalaLibraryRuleFlagName = "existing_scala_library_rule"
existingScalaTestRuleFlagName = "existing_scala_test_rule"
scalaGazelleCacheFileFlagName = "scala_gazelle_cache_file"
scalaGazelleDebugProcessFileFlagName = "scala_gazelle_debug_process"
scalaGazelleCacheKeyFlagName = "scala_gazelle_cache_key"
Expand All @@ -38,7 +40,9 @@ func (sl *scalaLang) RegisterFlags(flags *flag.FlagSet, cmd string, c *config.Co
flags.StringVar(&sl.memprofileFlagValue, memprofileFileFlagName, "", "optional path a memory profile file (.prof)")
flags.Var(&sl.symbolProviderNamesFlagValue, scalaSymbolProviderFlagName, "name of a symbol provider implementation to enable")
flags.Var(&sl.conflictResolverNamesFlagValue, scalaConflictResolverFlagName, "name of a conflict resolver implementation to enable")
flags.Var(&sl.existingScalaRulesFlagValue, existingScalaRulesFlagName, "LOAD%NAME mapping for a custom existing_scala_rule implementation (e.g. '@io_bazel_rules_scala//scala:scala.bzl%scala_library'")
flags.Var(&sl.existingScalaBinaryRulesFlagValue, existingScalaBinaryRuleFlagName, "LOAD%NAME mapping for a custom existing scala binary rule implementation (e.g. '@io_bazel_rules_scala//scala:scala.bzl%scalabinary'")
flags.Var(&sl.existingScalaLibraryRulesFlagValue, existingScalaLibraryRuleFlagName, "LOAD%NAME mapping for a custom existing scala library rule implementation (e.g. '@io_bazel_rules_scala//scala:scala.bzl%scala_library'")
flags.Var(&sl.existingScalaTestRulesFlagValue, existingScalaTestRuleFlagName, "LOAD%NAME mapping for a custom existing scala test rule implementation (e.g. '@io_bazel_rules_scala//scala:scala.bzl%scala_test'")

sl.registerSymbolProviders(flags, cmd, c)
sl.registerConflictResolvers(flags, cmd, c)
Expand Down Expand Up @@ -78,7 +82,13 @@ func (sl *scalaLang) CheckFlags(flags *flag.FlagSet, c *config.Config) error {
if err := sl.setupConflictResolvers(flags, c, sl.conflictResolverNamesFlagValue); err != nil {
return err
}
if err := sl.setupExistingScalaRules(sl.existingScalaRulesFlagValue); err != nil {
if err := sl.setupExistingScalaBinaryRules(sl.existingScalaBinaryRulesFlagValue); err != nil {
return err
}
if err := sl.setupExistingScalaLibraryRules(sl.existingScalaLibraryRulesFlagValue); err != nil {
return err
}
if err := sl.setupExistingScalaTestRules(sl.existingScalaTestRulesFlagValue); err != nil {
return err
}
if err := sl.setupCache(); err != nil {
Expand Down Expand Up @@ -122,21 +132,75 @@ func (sl *scalaLang) setupConflictResolvers(flags *flag.FlagSet, c *config.Confi
return nil
}

func (sl *scalaLang) setupExistingScalaRules(rules []string) error {
func (sl *scalaLang) setupExistingScalaBinaryRules(rules []string) error {
for _, fqn := range rules {
parts := strings.SplitN(fqn, "%", 2)
if len(parts) != 2 {
return fmt.Errorf("invalid -existing_scala_rule flag value: wanted '%%' separated string, got %q", fqn)
return fmt.Errorf("invalid -existing_scala_binary_rule flag value: wanted '%%' separated string, got %q", fqn)
}
if err := sl.setupExistingScalaRule(fqn, parts[0], parts[1]); err != nil {
if err := sl.setupExistingScalaBinaryRule(fqn, parts[0], parts[1]); err != nil {
return err
}
}
return nil
}

func (sl *scalaLang) setupExistingScalaRule(fqn, load, kind string) error {
provider := &existingScalaRuleProvider{load, kind}
func (sl *scalaLang) setupExistingScalaLibraryRules(rules []string) error {
for _, fqn := range rules {
parts := strings.SplitN(fqn, "%", 2)
if len(parts) != 2 {
return fmt.Errorf("invalid -existing_scala_library_rule flag value: wanted '%%' separated string, got %q", fqn)
}
if err := sl.setupExistingScalaLibraryRule(fqn, parts[0], parts[1]); err != nil {
return err
}
}
return nil
}

func (sl *scalaLang) setupExistingScalaTestRules(rules []string) error {
for _, fqn := range rules {
parts := strings.SplitN(fqn, "%", 2)
if len(parts) != 2 {
return fmt.Errorf("invalid -existing_scala_test_rule flag value: wanted '%%' separated string, got %q", fqn)
}
if err := sl.setupExistingScalaTestRule(fqn, parts[0], parts[1]); err != nil {
return err
}
}
return nil
}

func (sl *scalaLang) setupExistingScalaBinaryRule(fqn, load, kind string) error {
provider := &existingScalaRuleProvider{
load: load,
name: kind,
isBinary: true,
isLibrary: false,
isTest: false,
}
return sl.ruleProviderRegistry.RegisterProvider(fqn, provider)
}

func (sl *scalaLang) setupExistingScalaLibraryRule(fqn, load, kind string) error {
provider := &existingScalaRuleProvider{
load: load,
name: kind,
isBinary: false,
isLibrary: true,
isTest: false,
}
return sl.ruleProviderRegistry.RegisterProvider(fqn, provider)
}

func (sl *scalaLang) setupExistingScalaTestRule(fqn, load, kind string) error {
provider := &existingScalaRuleProvider{
load: load,
name: kind,
isBinary: false,
isLibrary: false,
isTest: true,
}
return sl.ruleProviderRegistry.RegisterProvider(fqn, provider)
}

Expand Down
4 changes: 2 additions & 2 deletions language/scala/flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func TestParseScalaExistingRules(t *testing.T) {
"degenerate": {},
"invalid flag value": {
providerNames: []string{"@io_bazel_rules_scala//scala:scala.bzl#scala_binary"},
wantErr: fmt.Errorf(`invalid -existing_scala_rule flag value: wanted '%%' separated string, got "@io_bazel_rules_scala//scala:scala.bzl#scala_binary"`),
wantErr: fmt.Errorf(`invalid -existing_scala_binary_rule flag value: wanted '%%' separated string, got "@io_bazel_rules_scala//scala:scala.bzl#scala_binary"`),
},
"valid flag value": {
providerNames: []string{"//custom/scala:scala.bzl%scala_binary"},
Expand All @@ -175,7 +175,7 @@ func TestParseScalaExistingRules(t *testing.T) {
t.Run(name, func(t *testing.T) {
lang := NewLanguage().(*scalaLang)
lang.ruleProviderRegistry = scalarule.NewProviderRegistryMap() // don't use global one
if testutil.ExpectError(t, tc.wantErr, lang.setupExistingScalaRules(tc.providerNames)) {
if testutil.ExpectError(t, tc.wantErr, lang.setupExistingScalaBinaryRules(tc.providerNames)) {
return
}
if tc.check != nil {
Expand Down
16 changes: 11 additions & 5 deletions language/scala/language.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,17 @@ type scalaLang struct {
// conflictResolverNamesFlagValue is a repeatable list of conflict resolver
// to enable
conflictResolverNamesFlagValue collections.StringSlice
// existingScalaRulesFlagValue is the value of the existing_scala_rule
// repeatable flag
existingScalaRulesFlagValue collections.StringSlice
cpuprofileFlagValue string
memprofileFlagValue string
// existingScalaLibraryRulesFlagValue is the value of the
// existing_scala_binary_rule repeatable flag
existingScalaBinaryRulesFlagValue collections.StringSlice
// existingScalaLibraryRulesFlagValue is the value of the
// existing_scala_library_rule repeatable flag
existingScalaLibraryRulesFlagValue collections.StringSlice
// existingScalaLibraryRulesFlagValue is the value of the
// existing_scala_test_rule repeatable flag
existingScalaTestRulesFlagValue collections.StringSlice
cpuprofileFlagValue string
memprofileFlagValue string
// cache is the loaded cache, if configured
cache scpb.Cache
// ruleProviderRegistry is the rule registry implementation. This holds the
Expand Down
8 changes: 4 additions & 4 deletions language/scala/scala_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ func TestScalaConfigParseRuleDirective(t *testing.T) {
},
"bad format": {
directives: []rule.Directive{
{Key: scalaRuleDirective, Value: "myrule existing_scala_rule"},
{Key: scalaRuleDirective, Value: "myrule existing_scala_library_rule"},
},
wantErr: fmt.Errorf(`invalid directive: "gazelle:scala_rule myrule existing_scala_rule": expected three or more fields, got 2`),
wantErr: fmt.Errorf(`invalid directive: "gazelle:scala_rule myrule existing_scala_library_rule": expected three or more fields, got 2`),
},
"example": {
directives: []rule.Directive{
{Key: scalaRuleDirective, Value: "myrule implementation existing_scala_rule"},
{Key: scalaRuleDirective, Value: "myrule implementation existing_scala_library_rule"},
{Key: scalaRuleDirective, Value: "myrule deps @maven//:a"},
{Key: scalaRuleDirective, Value: "myrule +deps @maven//:b"},
{Key: scalaRuleDirective, Value: "myrule -deps @maven//:c"},
Expand All @@ -122,7 +122,7 @@ func TestScalaConfigParseRuleDirective(t *testing.T) {
"myrule": {
Config: config.New(),
Name: "myrule",
Implementation: "existing_scala_rule",
Implementation: "existing_scala_library_rule",
Deps: map[string]bool{
"@maven//:a": true,
"@maven//:b": true,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
-scala_symbol_provider=source
-existing_scala_rule=//rules:scala.bzl%scala_helper_library
-existing_scala_library_rule=//rules:scala.bzl%scala_helper_library

0 comments on commit 04527ef

Please sign in to comment.