Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add functionality to export proto files #475

Merged
merged 9 commits into from
Jul 12, 2024
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ grpcurl -protoset my-protos.bin list

# Using proto sources
grpcurl -import-path ../protos -proto my-stuff.proto list

# Export proto files
grpcurl -plaintext -proto-out-dir "out_protos" "192.168.100.1:9200" describe Api.Service


```

The "list" verb also lets you see all methods in a particular service:
Expand Down
24 changes: 24 additions & 0 deletions cmd/grpcurl/grpcurl.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ var (
file if this option is given. When invoking an RPC and this option is
given, the method being invoked and its transitive dependencies will be
included in the output file.`))
protoOut = flags.String("proto-out-dir", "", prettify(`
The name of a directory where the generated .proto files will be written.
With the list and describe verbs, the listed or described elements and
their transitive dependencies will be written as .proto files in the
specified directory if this option is given. When invoking an RPC and
this option is given, the method being invoked and its transitive
dependencies will be included in the generated .proto files in the
output directory.`))
msgTemplate = flags.Bool("msg-template", false, prettify(`
When describing messages, show a template of input data.`))
verbose = flags.Bool("v", false, prettify(`
Expand Down Expand Up @@ -645,6 +653,9 @@ func main() {
if err := writeProtoset(descSource, svcs...); err != nil {
fail(err, "Failed to write protoset to %s", *protosetOut)
}
if err := writeProtos(descSource, svcs...); err != nil {
fail(err, "Failed to write protos to %s", *protoOut)
}
} else {
methods, err := grpcurl.ListMethods(descSource, symbol)
if err != nil {
Expand All @@ -660,6 +671,9 @@ func main() {
if err := writeProtoset(descSource, symbol); err != nil {
fail(err, "Failed to write protoset to %s", *protosetOut)
}
if err := writeProtos(descSource, symbol); err != nil {
fail(err, "Failed to write protos to %s", *protoOut)
}
}

} else if describe {
Expand Down Expand Up @@ -764,6 +778,9 @@ func main() {
if err := writeProtoset(descSource, symbols...); err != nil {
fail(err, "Failed to write protoset to %s", *protosetOut)
}
if err := writeProtos(descSource, symbol); err != nil {
fail(err, "Failed to write protos to %s", *protoOut)
}

} else {
// Invoke an RPC
Expand Down Expand Up @@ -923,6 +940,13 @@ func writeProtoset(descSource grpcurl.DescriptorSource, symbols ...string) error
return grpcurl.WriteProtoset(f, descSource, symbols...)
}

func writeProtos(descSource grpcurl.DescriptorSource, symbols ...string) error {
if *protoOut == "" {
return nil
}
return grpcurl.WriteProtoFiles(*protoOut, descSource, symbols...)
}

type optionalBoolFlag struct {
set, val bool
}
Expand Down
90 changes: 77 additions & 13 deletions desc_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import (
"fmt"
"io"
"os"
"path/filepath"
"sync"

"github.com/golang/protobuf/proto" //lint:ignore SA1019 we have to import this because it appears in exported API
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/desc/protoparse"
"github.com/jhump/protoreflect/desc/protoprint"
"github.com/jhump/protoreflect/dynamic"
"github.com/jhump/protoreflect/grpcreflect"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -258,19 +260,9 @@ func reflectionSupport(err error) error {
// given output. The output will include descriptors for all files in which the
// symbols are defined as well as their transitive dependencies.
func WriteProtoset(out io.Writer, descSource DescriptorSource, symbols ...string) error {
// compute set of file descriptors
filenames := make([]string, 0, len(symbols))
fds := make(map[string]*desc.FileDescriptor, len(symbols))
for _, sym := range symbols {
d, err := descSource.FindSymbol(sym)
if err != nil {
return fmt.Errorf("failed to find descriptor for %q: %v", sym, err)
}
fd := d.GetFile()
if _, ok := fds[fd.GetName()]; !ok {
fds[fd.GetName()] = fd
filenames = append(filenames, fd.GetName())
}
filenames, fds, err := getFileDescriptors(symbols, descSource)
if err != nil {
return err
}
// now expand that to include transitive dependencies in topologically sorted
// order (such that file always appears after its dependencies)
Expand Down Expand Up @@ -302,3 +294,75 @@ func addFilesToSet(allFiles []*descriptorpb.FileDescriptorProto, expanded map[st
}
return append(allFiles, fd.AsFileDescriptorProto())
}

// WriteProtoFiles will use the given descriptor source to resolve all the given
// symbols and write proto files with their definitions to the given output directory.
func WriteProtoFiles(outProtoDirPath string, descSource DescriptorSource, symbols ...string) error {
filenames, fds, err := getFileDescriptors(symbols, descSource)
if err != nil {
return err
}
// now expand that to include transitive dependencies in topologically sorted
// order (such that file always appears after its dependencies)
expandedFiles := make(map[string]struct{}, len(fds))
allFilesSlice := make([]*desc.FileDescriptor, 0, len(fds))
for _, filename := range filenames {
allFilesSlice = addFilesToFileDescriptorList(allFilesSlice, expandedFiles, fds[filename])
}
pr := protoprint.Printer{}
// now we can serialize to files
for _, fd := range allFilesSlice {
fdFQName := fd.GetFullyQualifiedName()
dirPath := filepath.Dir(fdFQName)
outFilepath := filepath.Join(outProtoDirPath, dirPath)
if err := os.MkdirAll(outFilepath, 0755); err != nil {
return fmt.Errorf("failed to create directory %q: %v", outFilepath, err)
}
fileName := filepath.Base(fdFQName)
filePath := filepath.Join(outFilepath, fileName)
f, err := os.Create(filePath)
if err != nil {
if f != nil {
_ = f.Close()
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should be able to omit this block, if err is non-nil, f will be nil

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

solved

return fmt.Errorf("failed to create file %q: %v", filePath, err)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

immediately after this block, you should probably defer f.Close()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have to scope it sometimes:

for _, fd := range allFilesSlice {
	fdFQName := fd.GetFullyQualifiedName()
	dirPath := filepath.Dir(fdFQName)
	outFilepath := filepath.Join(outProtoDirPath, dirPath)
	if err := os.MkdirAll(outFilepath, 0755); err != nil {
		return fmt.Errorf("failed to create directory %q: %v", outFilepath, err)
	}
	fileName := filepath.Base(fdFQName)
	filePath := filepath.Join(outFilepath, fileName)
	err := func() error {
		f, err := os.Create(filePath)
		if err != nil {
			return fmt.Errorf("failed to create")
		}
		defer f.Close()
		if err := pr.PrintProtoFile(fd, f); err != nil {
			return fmt.Errorf("failed to write")
		}
		if err := f.Close(); err != nil {
			return fmt.Errorf("failed to close")
		}
		return nil
	}()
	if err != nil {
		return fmt.Errorf("file %q: %w", filePath, err)
	}
}
return nil

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^^ yeah this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ready

if err := pr.PrintProtoFile(fd, f); err != nil {
_ = f.Close()
return fmt.Errorf("failed to write file %q: %v", filePath, err)
}
_ = f.Close()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You actually want to error check the close as well. Something like:

err = pr.PrintProtoFile(fd, f)
if err == nil {
  err = f.Close()
}
// do the error check now

you also don't need the final close if you deferred close earlier

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ready

}
return nil
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do one more pass on this, I have a few changes:

  • simplify filepath calculations and naming
  • improve errors with filename and error wrapping
  • 0777 so the system level umask can take effect
func writeProtoFile(outProtoDirPath string, fd *desc.FileDescriptor, pr *protoprint.Printer) error {
	outFile := filepath.Join(outProtoDirPath, fd.GetFullyQualifiedName())
	outDir := filepath.Dir(outFile)
	if err := os.MkdirAll(outDir, 0777); err != nil {
		return fmt.Errorf("failed to create directory %q: %w", outDir, err)
	}

	f, err := os.Create(outFile)
	if err != nil {
		return fmt.Errorf("failed to create proto file %q: %w", outFile, err)
	}
	defer f.Close()
	if err := pr.PrintProtoFile(fd, f); err != nil {
		return fmt.Errorf("failed to write proto file %q: %w", outFile, err)
	}
	return nil
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ready


func getFileDescriptors(symbols []string, descSource DescriptorSource) ([]string, map[string]*desc.FileDescriptor, error) {
// compute set of file descriptors
filenames := make([]string, 0, len(symbols))
fds := make(map[string]*desc.FileDescriptor, len(symbols))
for _, sym := range symbols {
d, err := descSource.FindSymbol(sym)
if err != nil {
return nil, nil, fmt.Errorf("failed to find descriptor for %q: %v", sym, err)
}
fd := d.GetFile()
if _, ok := fds[fd.GetName()]; !ok {
fds[fd.GetName()] = fd
filenames = append(filenames, fd.GetName())
}
}
return filenames, fds, nil
}

func addFilesToFileDescriptorList(allFiles []*desc.FileDescriptor, expanded map[string]struct{}, fd *desc.FileDescriptor) []*desc.FileDescriptor {
if _, ok := expanded[fd.GetName()]; ok {
// already seen this one
return allFiles
}
expanded[fd.GetName()] = struct{}{}
// add all dependencies first
for _, dep := range fd.GetDependencies() {
allFiles = addFilesToFileDescriptorList(allFiles, expanded, dep)
}
return append(allFiles, fd)
}
Loading