package gengateway

import (
	"errors"
	"fmt"
	"go/format"
	"path"
	"path/filepath"
	"strings"

	"github.com/golang/glog"
	"github.com/golang/protobuf/proto"
	plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
	"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
	gen "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/generator"
)

var (
	// errNoTargetService is the sentinel error for a file without a target
	// service. Generate treats it as "skip this file" rather than as a failure.
	errNoTargetService = errors.New("no target service defined in the file")
)

// pathType selects how Generate derives the name of each emitted file.
type pathType int

const (
	// pathTypeImport prefixes the output file name with the Go package path
	// of the file when that path is non-empty. It is the zero value and
	// therefore the default.
	pathTypeImport pathType = iota
	// pathTypeSourceRelative keeps the output file name relative to the name
	// of the source file.
	pathTypeSourceRelative
)

// generator produces grpc gateway source files from descriptor files.
type generator struct {
	// reg is used to reserve package aliases, to look up enums, and is passed
	// on to applyTemplate.
	reg *descriptor.Registry
	// baseImports holds the packages that are imported by every generated file.
	baseImports []descriptor.GoPackage
	// useRequestContext, registerFuncSuffix and allowPatchFeature are copied
	// verbatim into the template parameters.
	useRequestContext  bool
	registerFuncSuffix string
	// pathType controls how output file names are built in Generate.
	pathType          pathType
	allowPatchFeature bool
}

// New returns a new generator which generates grpc gateway files.
//
// pathTypeString must be "", "import" or "source_relative"; the empty string
// is equivalent to "import". Any other value is reported through glog.Fatalf.
func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix, pathTypeString string, allowPatchFeature bool) gen.Generator {
	basePkgPaths := []string{
		"context",
		"io",
		"net/http",
		"github.com/grpc-ecosystem/grpc-gateway/runtime",
		"github.com/grpc-ecosystem/grpc-gateway/utilities",
		"github.com/golang/protobuf/descriptor",
		"github.com/golang/protobuf/proto",
		"google.golang.org/grpc",
		"google.golang.org/grpc/codes",
		"google.golang.org/grpc/grpclog",
		"google.golang.org/grpc/status",
	}

	var imports []descriptor.GoPackage
	for _, importPath := range basePkgPaths {
		pkg := descriptor.GoPackage{
			Path: importPath,
			Name: path.Base(importPath),
		}
		// Try to reserve the plain package name first. If the registry
		// rejects it, probe "<name>_0", "<name>_1", ... until one of the
		// candidates is accepted and record it as the alias.
		if reg.ReserveGoPackageAlias(pkg.Name, pkg.Path) != nil {
			for n := 0; ; n++ {
				candidate := fmt.Sprintf("%s_%d", pkg.Name, n)
				if reg.ReserveGoPackageAlias(candidate, pkg.Path) == nil {
					pkg.Alias = candidate
					break
				}
			}
		}
		imports = append(imports, pkg)
	}

	mode := pathTypeImport
	switch pathTypeString {
	case "source_relative":
		mode = pathTypeSourceRelative
	case "", "import":
		// paths=import is the default; nothing to change.
	default:
		glog.Fatalf(`Unknown path type %q: want "import" or "source_relative".`, pathTypeString)
	}

	return &generator{
		reg:                reg,
		baseImports:        imports,
		useRequestContext:  useRequestContext,
		registerFuncSuffix: registerFuncSuffix,
		pathType:           mode,
		allowPatchFeature:  allowPatchFeature,
	}
}

// Generate renders one gofmt-formatted "*.pb.gw.go" file for every target.
//
// Targets for which generation yields errNoTargetService are skipped. Any
// other generation error, or a failure to format the generated code, aborts
// the whole run and is returned with a nil file list.
func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) {
	var files []*plugin.CodeGeneratorResponse_File
	for _, file := range targets {
		glog.V(1).Infof("Processing %s", file.GetName())

		code, err := g.generate(file)
		switch {
		case err == errNoTargetService:
			glog.V(1).Infof("%s: %v", file.GetName(), err)
			continue
		case err != nil:
			return nil, err
		}

		formatted, err := format.Source([]byte(code))
		if err != nil {
			// Log the unformatted code together with the error so the
			// offending output can be inspected.
			glog.Errorf("%v: %s", err, code)
			return nil, err
		}

		name := file.GetName()
		if g.pathType == pathTypeImport && file.GoPkg.Path != "" {
			// In import mode the file is placed under its Go package path,
			// keeping only the base name of the source file.
			name = file.GoPkg.Path + "/" + filepath.Base(name)
		}
		// Replace the source extension with the gateway suffix.
		output := strings.TrimSuffix(name, filepath.Ext(name)) + ".pb.gw.go"

		files = append(files, &plugin.CodeGeneratorResponse_File{
			Name:    proto.String(output),
			Content: proto.String(string(formatted)),
		})
		glog.V(1).Infof("Will emit %s", output)
	}
	return files, nil
}

// generate collects the imports needed by file and renders it through
// applyTemplate, returning the unformatted source code.
//
// The import list starts with the generator's base imports, followed by the
// packages of enum path parameters and of request types of methods that have
// at least one binding. The file's own package and already listed packages are
// left out.
func (g *generator) generate(file *descriptor.File) (string, error) {
	seen := make(map[string]bool)
	imports := append([]descriptor.GoPackage(nil), g.baseImports...)
	for _, base := range g.baseImports {
		seen[base.Path] = true
	}

	for _, svc := range file.Services {
		for _, m := range svc.Methods {
			imports = append(imports, g.addEnumPathParamImports(file, m, seen)...)

			reqPkg := m.RequestType.File.GoPkg
			if len(m.Bindings) > 0 && reqPkg != file.GoPkg && !seen[reqPkg.Path] {
				seen[reqPkg.Path] = true
				imports = append(imports, reqPkg)
			}
		}
	}

	return applyTemplate(param{
		File:               file,
		Imports:            imports,
		UseRequestContext:  g.useRequestContext,
		RegisterFuncSuffix: g.registerFuncSuffix,
		AllowPatchFeature:  g.allowPatchFeature,
	}, g.reg)
}

// addEnumPathParamImports returns the Go packages of the enum types used as
// path parameters in the bindings of m.
//
// Packages equal to the package of file, or already marked in pkgSeen, are
// omitted. Every returned package is marked in pkgSeen.
func (g *generator) addEnumPathParamImports(file *descriptor.File, m *descriptor.Method, pkgSeen map[string]bool) []descriptor.GoPackage {
	var found []descriptor.GoPackage
	for _, binding := range m.Bindings {
		for _, pp := range binding.PathParams {
			enum, err := g.reg.LookupEnum("", pp.Target.GetTypeName())
			if err != nil {
				// The lookup failed, so this parameter contributes no import
				// (presumably its type is not an enum).
				continue
			}
			enumPkg := enum.File.GoPkg
			if enumPkg != file.GoPkg && !pkgSeen[enumPkg.Path] {
				pkgSeen[enumPkg.Path] = true
				found = append(found, enumPkg)
			}
		}
	}
	return found
}