diff --git a/README.md b/README.md index 49960186..e2d4b331 100644 --- a/README.md +++ b/README.md @@ -159,6 +159,11 @@ grpcurl -protoset my-protos.bin list # Using proto sources grpcurl -import-path ../protos -proto my-stuff.proto list + +# Export proto files +grpcurl -plaintext -proto-out "out_protos" "192.168.100.1:9200" describe Api.Service + + ``` The "list" verb also lets you see all methods in a particular service: diff --git a/cmd/grpcurl/grpcurl.go b/cmd/grpcurl/grpcurl.go index 70d079bc..73e31f41 100644 --- a/cmd/grpcurl/grpcurl.go +++ b/cmd/grpcurl/grpcurl.go @@ -151,6 +151,14 @@ var ( file if this option is given. When invoking an RPC and this option is given, the method being invoked and its transitive dependencies will be included in the output file.`)) + protoOut = flags.String("proto-out", "", prettify(` + The name of a directory where the generated .proto files will be written. + With the list and describe verbs, the listed or described elements and + their transitive dependencies will be written as .proto files in the + specified directory if this option is given. When invoking an RPC and + this option is given, the method being invoked and its transitive + dependencies will be included in the generated .proto files in the + output directory.`)) msgTemplate = flags.Bool("msg-template", false, prettify(` When describing messages, show a template of input data.`)) verbose = flags.Bool("v", false, prettify(` @@ -645,6 +653,9 @@ func main() { if err := writeProtoset(descSource, svcs...); err != nil { fail(err, "Failed to write protoset to %s", *protosetOut) } + if err := writeProtos(descSource, svcs...); err != nil { + fail(err, "Failed to write protos to %s", *protoOut) + } } else { methods, err := grpcurl.ListMethods(descSource, symbol) if err != nil { @@ -660,6 +671,9 @@ func main() { if err := writeProtoset(descSource, symbol); err != nil { fail(err, "Failed to write protoset to %s", *protosetOut) } + if err := writeProtos(descSource, symbol); err != nil { + fail(err, "Failed to write protos to %s", *protoOut) + } } } else if describe { @@ -764,6 +778,9 @@ func main() { if err := writeProtoset(descSource, symbols...); err != nil { fail(err, "Failed to write protoset to %s", *protosetOut) } + if err := writeProtos(descSource, symbol); err != nil { + fail(err, "Failed to write protos to %s", *protoOut) + } } else { // Invoke an RPC @@ -923,6 +940,13 @@ func writeProtoset(descSource grpcurl.DescriptorSource, symbols ...string) error return grpcurl.WriteProtoset(f, descSource, symbols...) } +func writeProtos(descSource grpcurl.DescriptorSource, symbols ...string) error { + if *protoOut == "" { + return nil + } + return grpcurl.WriteProtoFiles(*protoOut, descSource, symbols...) +} + type optionalBoolFlag struct { set, val bool } diff --git a/desc_source.go b/desc_source.go index d12fae0d..89a7ccc8 100644 --- a/desc_source.go +++ b/desc_source.go @@ -4,8 +4,10 @@ import ( "context" "errors" "fmt" + "github.com/jhump/protoreflect/desc/protoprint" "io" "os" + "path/filepath" "sync" "github.com/golang/protobuf/proto" //lint:ignore SA1019 we have to import this because it appears in exported API @@ -258,19 +260,9 @@ func reflectionSupport(err error) error { // given output. The output will include descriptors for all files in which the // symbols are defined as well as their transitive dependencies. func WriteProtoset(out io.Writer, descSource DescriptorSource, symbols ...string) error { - // compute set of file descriptors - filenames := make([]string, 0, len(symbols)) - fds := make(map[string]*desc.FileDescriptor, len(symbols)) - for _, sym := range symbols { - d, err := descSource.FindSymbol(sym) - if err != nil { - return fmt.Errorf("failed to find descriptor for %q: %v", sym, err) - } - fd := d.GetFile() - if _, ok := fds[fd.GetName()]; !ok { - fds[fd.GetName()] = fd - filenames = append(filenames, fd.GetName()) - } + filenames, fds, err := getFileDescriptors(symbols, descSource) + if err != nil { + return err } // now expand that to include transitive dependencies in topologically sorted // order (such that file always appears after its dependencies) @@ -302,3 +294,71 @@ func addFilesToSet(allFiles []*descriptorpb.FileDescriptorProto, expanded map[st } return append(allFiles, fd.AsFileDescriptorProto()) } + +// WriteProtoFiles will use the given descriptor source to resolve all the given +// symbols and write proto files with their definitions to the given output directory. +func WriteProtoFiles(outProtoDirPath string, descSource DescriptorSource, symbols ...string) error { + filenames, fds, err := getFileDescriptors(symbols, descSource) + if err != nil { + return err + } + // now expand that to include transitive dependencies in topologically sorted + // order (such that file always appears after its dependencies) + expandedFiles := make(map[string]struct{}, len(fds)) + allFilesSlice := make([]*desc.FileDescriptor, 0, len(fds)) + for _, filename := range filenames { + allFilesSlice = addFilesToFileDescriptorList(allFilesSlice, expandedFiles, fds[filename]) + } + pr := protoprint.Printer{} + // now we can serialize to files + for _, fd := range allFilesSlice { + fdFQName := fd.GetFullyQualifiedName() + dirPath := filepath.Dir(fdFQName) + outFilepath := filepath.Join(outProtoDirPath, dirPath) + if err := os.MkdirAll(outFilepath, 0755); err != nil { + return fmt.Errorf("failed to create directory %q: %v", outFilepath, err) + } + fileName := filepath.Base(fdFQName) + filePath := filepath.Join(outFilepath, fileName) + f, err := os.Create(filePath) + defer f.Close() + if err != nil { + return fmt.Errorf("failed to create file %q: %v", filePath, err) + } + if err := pr.PrintProtoFile(fd, f); err != nil { + return fmt.Errorf("failed to write file %q: %v", filePath, err) + } + } + return nil +} + +func getFileDescriptors(symbols []string, descSource DescriptorSource) ([]string, map[string]*desc.FileDescriptor, error) { + // compute set of file descriptors + filenames := make([]string, 0, len(symbols)) + fds := make(map[string]*desc.FileDescriptor, len(symbols)) + for _, sym := range symbols { + d, err := descSource.FindSymbol(sym) + if err != nil { + return nil, nil, fmt.Errorf("failed to find descriptor for %q: %v", sym, err) + } + fd := d.GetFile() + if _, ok := fds[fd.GetName()]; !ok { + fds[fd.GetName()] = fd + filenames = append(filenames, fd.GetName()) + } + } + return filenames, fds, nil +} + +func addFilesToFileDescriptorList(allFiles []*desc.FileDescriptor, expanded map[string]struct{}, fd *desc.FileDescriptor) []*desc.FileDescriptor { + if _, ok := expanded[fd.GetName()]; ok { + // already seen this one + return allFiles + } + expanded[fd.GetName()] = struct{}{} + // add all dependencies first + for _, dep := range fd.GetDependencies() { + allFiles = addFilesToFileDescriptorList(allFiles, expanded, dep) + } + return append(allFiles, fd) +}