Add functionality to export proto files

Added a new function, `WriteProtoFiles` in `desc_source.go` which is used to generate .proto files. The process involves resolving symbols from the descriptor source and writing their definitions to a designated output directory. The corresponding flag `--proto-out` has been included in `grpcurl.go` to allow users to specify the directory path.
This commit is contained in:
Eitol 2024-07-09 20:03:17 -04:00
parent a05d48d6dd
commit 876d9a9de3
3 changed files with 102 additions and 13 deletions

View File

@ -159,6 +159,11 @@ grpcurl -protoset my-protos.bin list
# Using proto sources
grpcurl -import-path ../protos -proto my-stuff.proto list
# Export proto files
grpcurl -plaintext -proto-out "out_protos" "192.168.100.1:9200" describe Api.Service
```
The "list" verb also lets you see all methods in a particular service:

View File

@ -151,6 +151,14 @@ var (
file if this option is given. When invoking an RPC and this option is
given, the method being invoked and its transitive dependencies will be
included in the output file.`))
protoOut = flags.String("proto-out", "", prettify(`
The name of a directory where the generated .proto files will be written.
With the list and describe verbs, the listed or described elements and
their transitive dependencies will be written as .proto files in the
specified directory if this option is given. When invoking an RPC and
this option is given, the method being invoked and its transitive
dependencies will be included in the generated .proto files in the
output directory.`))
msgTemplate = flags.Bool("msg-template", false, prettify(`
When describing messages, show a template of input data.`))
verbose = flags.Bool("v", false, prettify(`
@ -645,6 +653,9 @@ func main() {
if err := writeProtoset(descSource, svcs...); err != nil {
fail(err, "Failed to write protoset to %s", *protosetOut)
}
if err := writeProtos(descSource, svcs...); err != nil {
fail(err, "Failed to write protos to %s", *protoOut)
}
} else {
methods, err := grpcurl.ListMethods(descSource, symbol)
if err != nil {
@ -660,6 +671,9 @@ func main() {
if err := writeProtoset(descSource, symbol); err != nil {
fail(err, "Failed to write protoset to %s", *protosetOut)
}
if err := writeProtos(descSource, symbol); err != nil {
fail(err, "Failed to write protos to %s", *protoOut)
}
}
} else if describe {
@ -764,6 +778,9 @@ func main() {
if err := writeProtoset(descSource, symbols...); err != nil {
fail(err, "Failed to write protoset to %s", *protosetOut)
}
if err := writeProtos(descSource, symbol); err != nil {
fail(err, "Failed to write protos to %s", *protoOut)
}
} else {
// Invoke an RPC
@ -923,6 +940,13 @@ func writeProtoset(descSource grpcurl.DescriptorSource, symbols ...string) error
return grpcurl.WriteProtoset(f, descSource, symbols...)
}
func writeProtos(descSource grpcurl.DescriptorSource, symbols ...string) error {
if *protoOut == "" {
return nil
}
return grpcurl.WriteProtoFiles(*protoOut, descSource, symbols...)
}
type optionalBoolFlag struct {
set, val bool
}

View File

@ -4,8 +4,10 @@ import (
"context"
"errors"
"fmt"
"github.com/jhump/protoreflect/desc/protoprint"
"io"
"os"
"path/filepath"
"sync"
"github.com/golang/protobuf/proto" //lint:ignore SA1019 we have to import this because it appears in exported API
@ -258,19 +260,9 @@ func reflectionSupport(err error) error {
// given output. The output will include descriptors for all files in which the
// symbols are defined as well as their transitive dependencies.
func WriteProtoset(out io.Writer, descSource DescriptorSource, symbols ...string) error {
// compute set of file descriptors
filenames := make([]string, 0, len(symbols))
fds := make(map[string]*desc.FileDescriptor, len(symbols))
for _, sym := range symbols {
d, err := descSource.FindSymbol(sym)
if err != nil {
return fmt.Errorf("failed to find descriptor for %q: %v", sym, err)
}
fd := d.GetFile()
if _, ok := fds[fd.GetName()]; !ok {
fds[fd.GetName()] = fd
filenames = append(filenames, fd.GetName())
}
filenames, fds, err := getFileDescriptors(symbols, descSource)
if err != nil {
return err
}
// now expand that to include transitive dependencies in topologically sorted
// order (such that file always appears after its dependencies)
@ -302,3 +294,71 @@ func addFilesToSet(allFiles []*descriptorpb.FileDescriptorProto, expanded map[st
}
return append(allFiles, fd.AsFileDescriptorProto())
}
// WriteProtoFiles will use the given descriptor source to resolve all the given
// symbols and write proto files with their definitions to the given output directory.
func WriteProtoFiles(outProtoDirPath string, descSource DescriptorSource, symbols ...string) error {
filenames, fds, err := getFileDescriptors(symbols, descSource)
if err != nil {
return err
}
// now expand that to include transitive dependencies in topologically sorted
// order (such that file always appears after its dependencies)
expandedFiles := make(map[string]struct{}, len(fds))
allFilesSlice := make([]*desc.FileDescriptor, 0, len(fds))
for _, filename := range filenames {
allFilesSlice = addFilesToFileDescriptorList(allFilesSlice, expandedFiles, fds[filename])
}
pr := protoprint.Printer{}
// now we can serialize to files
for _, fd := range allFilesSlice {
fdFQName := fd.GetFullyQualifiedName()
dirPath := filepath.Dir(fdFQName)
outFilepath := filepath.Join(outProtoDirPath, dirPath)
if err := os.MkdirAll(outFilepath, 0755); err != nil {
return fmt.Errorf("failed to create directory %q: %v", outFilepath, err)
}
fileName := filepath.Base(fdFQName)
filePath := filepath.Join(outFilepath, fileName)
f, err := os.Create(filePath)
defer f.Close()
if err != nil {
return fmt.Errorf("failed to create file %q: %v", filePath, err)
}
if err := pr.PrintProtoFile(fd, f); err != nil {
return fmt.Errorf("failed to write file %q: %v", filePath, err)
}
}
return nil
}
func getFileDescriptors(symbols []string, descSource DescriptorSource) ([]string, map[string]*desc.FileDescriptor, error) {
// compute set of file descriptors
filenames := make([]string, 0, len(symbols))
fds := make(map[string]*desc.FileDescriptor, len(symbols))
for _, sym := range symbols {
d, err := descSource.FindSymbol(sym)
if err != nil {
return nil, nil, fmt.Errorf("failed to find descriptor for %q: %v", sym, err)
}
fd := d.GetFile()
if _, ok := fds[fd.GetName()]; !ok {
fds[fd.GetName()] = fd
filenames = append(filenames, fd.GetName())
}
}
return filenames, fds, nil
}
func addFilesToFileDescriptorList(allFiles []*desc.FileDescriptor, expanded map[string]struct{}, fd *desc.FileDescriptor) []*desc.FileDescriptor {
if _, ok := expanded[fd.GetName()]; ok {
// already seen this one
return allFiles
}
expanded[fd.GetName()] = struct{}{}
// add all dependencies first
for _, dep := range fd.GetDependencies() {
allFiles = addFilesToFileDescriptorList(allFiles, expanded, dep)
}
return append(allFiles, fd)
}