Skip to content

Commit

Permalink
Fix: gRPC wrapper does not generate the correct gRPC handlers with co…
Browse files Browse the repository at this point in the history
…ntext (#16)
  • Loading branch information
coolwednesday authored Dec 27, 2024
1 parent 5e5637e commit 6e5d42d
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 12 deletions.
51 changes: 42 additions & 9 deletions wrap/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ import (
const filePerm = 0644

var (
ErrNoProtoFile = errors.New("proto file path is required")
ErrOpeningProtoFile = errors.New("error opening the proto file")
ErrFailedToParseProto = errors.New("failed to parse proto file")
ErrGeneratingWrapper = errors.New("error generating the wrapper code from the proto file")
ErrWritingWrapperFile = errors.New("error writing the generated wrapper to the file")
ErrNoProtoFile = errors.New("proto file path is required")
ErrOpeningProtoFile = errors.New("error opening the proto file")
ErrFailedToParseProto = errors.New("failed to parse proto file")
ErrGeneratingWrapper = errors.New("error generating the wrapper code from the proto file")
ErrWritingWrapperFile = errors.New("error writing the generated wrapper to the file")
ErrGeneratingServerTemplate = errors.New("error generating the gRPC server file template")
ErrWritingServerTemplate = errors.New("error writing the generated server template to the file")
)

// ServiceMethod represents a method in a proto service.
Expand Down Expand Up @@ -71,7 +73,7 @@ func GenerateWrapper(ctx *gofr.Context) (any, error) {

var (
// Extracting package and project path from go_package option.
packageName, projectPath = getPackageAndProject(definition)
projectPath, packageName = getPackageAndProject(definition)
// Extract the services.
services = getServices(definition)
)
Expand All @@ -89,7 +91,7 @@ func GenerateWrapper(ctx *gofr.Context) (any, error) {
return nil, ErrGeneratingWrapper
}

outputFilePath := fmt.Sprintf("%s/%s.gofr.go", projectPath, strings.ToLower(service.Name))
outputFilePath := path.Join(projectPath, fmt.Sprintf("%s.gofr.go", strings.ToLower(service.Name)))

err := os.WriteFile(outputFilePath, []byte(generatedCode), filePerm)
if err != nil {
Expand All @@ -99,6 +101,22 @@ func GenerateWrapper(ctx *gofr.Context) (any, error) {
}

fmt.Printf("Generated wrapper for service %s at %s\n", service.Name, outputFilePath)

generatedgRPCCode := generategRPCCode(ctx, &wrapperData)
if generatedgRPCCode == "" {
return nil, ErrGeneratingServerTemplate
}

outputFilePath = path.Join(projectPath, fmt.Sprintf("%sServer.go", strings.ToLower(service.Name)))

err = os.WriteFile(outputFilePath, []byte(generatedgRPCCode), filePerm)
if err != nil {
ctx.Errorf("Failed to write file %s: %v", outputFilePath, err)

return nil, ErrWritingServerTemplate
}

fmt.Printf("Generated server template for service %s at %s\n", service.Name, outputFilePath)
}

return "Successfully generated all wrappers for gRPC services", nil
Expand Down Expand Up @@ -126,7 +144,7 @@ func uniqueRequestTypes(methods []ServiceMethod) []string {
func generateWrapperCode(ctx *gofr.Context, data *WrapperData) string {
var buf bytes.Buffer

tmplInstance := template.Must(template.New("wrapper").Parse(tmpl))
tmplInstance := template.Must(template.New("wrapper").Parse(wrapperTemplate))

err := tmplInstance.Execute(&buf, data)
if err != nil {
Expand All @@ -138,11 +156,26 @@ func generateWrapperCode(ctx *gofr.Context, data *WrapperData) string {
return buf.String()
}

// Generate wrapper code using the template.
func generategRPCCode(ctx *gofr.Context, data *WrapperData) string {
var buf bytes.Buffer

tmplInstance := template.Must(template.New("wrapper").Parse(serverTemplate))

err := tmplInstance.Execute(&buf, data)
if err != nil {
ctx.Errorf("Template execution failed: %v", err)
return ""
}

return buf.String()
}

func getPackageAndProject(definition *proto.Proto) (projectPath, packageName string) {
proto.Walk(definition,
proto.WithOption(func(opt *proto.Option) {
if opt.Name == "go_package" {
projectPath = opt.Constant.Source[:len(opt.Constant.Source)-1]
projectPath = opt.Constant.Source
packageName = path.Base(opt.Constant.Source)
}
}),
Expand Down
36 changes: 33 additions & 3 deletions wrap/template.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package wrap

const tmpl = `// Code generated by gofr.dev/cli/gofr. DO NOT EDIT.
const (
wrapperTemplate = `// Code generated by gofr.dev/cli/gofr. DO NOT EDIT.
package {{ .Package }}
import (
Expand Down Expand Up @@ -90,8 +91,8 @@ func (h *{{ $request }}Wrapper) Bind(p interface{}) error {
return fmt.Errorf("expected a pointer, got %T", p)
}
hValue := reflect.ValueOf(h.InfoRequest).Elem()
ptrValue := reflect.ValueOf(ptr).Elem()
hValue := reflect.ValueOf(h.{{ $request }}).Elem()
ptrValue := ptr.Elem()
// Ensure we can set exported fields (skip unexported fields)
for i := 0; i < hValue.NumField(); i++ {
Expand Down Expand Up @@ -119,3 +120,32 @@ func (h *{{ $request }}Wrapper) Params(s string) []string {
{{- end }}
`

serverTemplate = `package {{ .Package }}
import "gofr.dev/pkg/gofr"
// Register the gRPC service in your app using the following code in your main.go:
//
// grpc.Register{{ $.Service }}ServerWithGofr(app, &grpc.{{ $.Service }}GoFrServer{})
//
// {{ $.Service }}GoFrServer defines the gRPC server implementation.
// Customize the struct with required dependencies and fields as needed.
type {{ $.Service }}GoFrServer struct {
}
{{- range .Methods }}
func (s *{{ $.Service }}GoFrServer) {{ .Name }}(ctx *gofr.Context) (any, error) {
// Uncomment and use the following code if you need to bind the request payload
// request := {{ .Request }}{}
// err := ctx.Bind(&request)
// if err != nil {
// return nil, err
// }
return &{{ .Response }}{}, nil
}
{{- end }}
`
)

0 comments on commit 6e5d42d

Please sign in to comment.