| // Copyright 2018 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // Package internal_gengogrpc is internal to the protobuf module. |
| package internal_gengogrpc |
| |
| import ( |
| "fmt" |
| "strconv" |
| "strings" |
| |
| descpb "github.com/golang/protobuf/protoc-gen-go/descriptor" |
| "github.com/golang/protobuf/v2/protogen" |
| ) |
| |
| // fileInfo pairs a protogen.File with an index of its source-code locations, |
| // which genComment uses to find the comments attached to a descriptor path. |
| type fileInfo struct { |
| *protogen.File |
| // locationMap groups the file's SourceCodeInfo locations by pathKey(loc.Path). |
| locationMap map[string][]*descpb.SourceCodeInfo_Location |
| } |
| |
| // GenerateFile generates a _grpc.pb.go file containing gRPC service definitions. |
| func GenerateFile(gen *protogen.Plugin, f *protogen.File) { |
| if len(f.Services) == 0 { |
| return |
| } |
| filename := f.GeneratedFilenamePrefix + "_grpc.pb.go" |
| g := gen.NewGeneratedFile(filename, f.GoImportPath) |
| g.P("// Code generated by protoc-gen-go-grpc. DO NOT EDIT.") |
| g.P() |
| g.P("package ", f.GoPackageName) |
| g.P() |
| GenerateFileContent(gen, f, g) |
| } |
| |
| // GenerateFileContent generates the gRPC service definitions, excluding the package statement. |
| func GenerateFileContent(gen *protogen.Plugin, f *protogen.File, g *protogen.GeneratedFile) { |
| if len(f.Services) == 0 { |
| return |
| } |
| file := &fileInfo{ |
| File: f, |
| locationMap: make(map[string][]*descpb.SourceCodeInfo_Location), |
| } |
| for _, loc := range file.Proto.GetSourceCodeInfo().GetLocation() { |
| key := pathKey(loc.Path) |
| file.locationMap[key] = append(file.locationMap[key], loc) |
| } |
| |
| // TODO: Remove this. We don't need to include these references any more. |
| g.P("// Reference imports to suppress errors if they are not otherwise used.") |
| g.P("var _ ", ident("context.Context")) |
| g.P("var _ ", ident("grpc.ClientConn")) |
| g.P() |
| |
| g.P("// This is a compile-time assertion to ensure that this generated file") |
| g.P("// is compatible with the grpc package it is being compiled against.") |
| g.P("const _ = ", ident("grpc.SupportPackageIsVersion4")) |
| g.P() |
| for _, service := range file.Services { |
| genService(gen, file, g, service) |
| } |
| } |
| |
| // genService emits everything generated for one service: the <Service>Client |
| // interface, its unexported implementation and New<Service>Client, the client |
| // method bodies, the <Service>Server interface, Register<Service>Server, the |
| // per-method server handlers, and the grpc.ServiceDesc variable |
| // _<Service>_serviceDesc that ties the handlers to method and stream names. |
| func genService(gen *protogen.Plugin, file *fileInfo, g *protogen.GeneratedFile, service *protogen.Service) { |
| clientName := service.GoName + "Client" |
|
| g.P("// ", clientName, " is the client API for ", service.GoName, " service.") |
| g.P("//") |
| g.P("// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.") |
|
| // Client interface. |
| if serviceOptions(gen, service).GetDeprecated() { |
| g.P("//") |
| g.P(deprecationComment) |
| } |
| g.P("type ", clientName, " interface {") |
| for _, method := range service.Methods { |
| genComment(g, file, method.Path) |
| g.P(clientSignature(g, method)) |
| } |
| g.P("}") |
| g.P() |
|
| // Client structure. |
| g.P("type ", unexport(clientName), " struct {") |
| g.P("cc *", ident("grpc.ClientConn")) |
| g.P("}") |
| g.P() |
|
| // NewClient factory. |
| if serviceOptions(gen, service).GetDeprecated() { |
| g.P(deprecationComment) |
| } |
| g.P("func New", clientName, " (cc *", ident("grpc.ClientConn"), ") ", clientName, " {") |
| g.P("return &", unexport(clientName), "{cc}") |
| g.P("}") |
| g.P() |
|
| // streamIndex counts streaming methods in declaration order. That is also the |
| // order of the Streams entries emitted below, so it is each streaming method's |
| // index into _<Service>_serviceDesc.Streams. methodIndex counts unary methods; |
| // genClientMethod returns before reading the index for those. |
| var methodIndex, streamIndex int |
| // Client method implementations. |
| for _, method := range service.Methods { |
| if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { |
| // Unary RPC method |
| genClientMethod(gen, file, g, method, methodIndex) |
| methodIndex++ |
| } else { |
| // Streaming RPC method |
| genClientMethod(gen, file, g, method, streamIndex) |
| streamIndex++ |
| } |
| } |
|
| // Server interface. |
| serverType := service.GoName + "Server" |
| g.P("// ", serverType, " is the server API for ", service.GoName, " service.") |
| if serviceOptions(gen, service).GetDeprecated() { |
| g.P("//") |
| g.P(deprecationComment) |
| } |
| g.P("type ", serverType, " interface {") |
| for _, method := range service.Methods { |
| genComment(g, file, method.Path) |
| g.P(serverSignature(g, method)) |
| } |
| g.P("}") |
| g.P() |
|
| // Server registration. |
| if serviceOptions(gen, service).GetDeprecated() { |
| g.P(deprecationComment) |
| } |
| serviceDescVar := "_" + service.GoName + "_serviceDesc" |
| g.P("func Register", service.GoName, "Server(s *", ident("grpc.Server"), ", srv ", serverType, ") {") |
| g.P("s.RegisterService(&", serviceDescVar, `, srv)`) |
| g.P("}") |
| g.P() |
|
| // Server handler implementations. |
| // handlerNames[i] is the handler emitted for service.Methods[i]; the Methods |
| // and Streams loops below index it in parallel with service.Methods. |
| var handlerNames []string |
| for _, method := range service.Methods { |
| hname := genServerMethod(gen, file, g, method) |
| handlerNames = append(handlerNames, hname) |
| } |
|
| // Service descriptor. |
| g.P("var ", serviceDescVar, " = ", ident("grpc.ServiceDesc"), " {") |
| g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",") |
| g.P("HandlerType: (*", serverType, ")(nil),") |
| // Unary methods only. |
| g.P("Methods: []", ident("grpc.MethodDesc"), "{") |
| for i, method := range service.Methods { |
| if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { |
| continue |
| } |
| g.P("{") |
| g.P("MethodName: ", strconv.Quote(method.GoName), ",") |
| g.P("Handler: ", handlerNames[i], ",") |
| g.P("},") |
| } |
| g.P("},") |
| // Streaming methods only, in the same order used to assign streamIndex above. |
| g.P("Streams: []", ident("grpc.StreamDesc"), "{") |
| for i, method := range service.Methods { |
| if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { |
| continue |
| } |
| g.P("{") |
| g.P("StreamName: ", strconv.Quote(method.GoName), ",") |
| g.P("Handler: ", handlerNames[i], ",") |
| if method.Desc.IsStreamingServer() { |
| g.P("ServerStreams: true,") |
| } |
| if method.Desc.IsStreamingClient() { |
| g.P("ClientStreams: true,") |
| } |
| g.P("},") |
| } |
| g.P("},") |
| g.P("Metadata: \"", file.Desc.Path(), "\",") |
| g.P("}") |
| g.P() |
| } |
| |
| func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string { |
| s := method.GoName + "(ctx " + g.QualifiedGoIdent(ident("context.Context")) |
| if !method.Desc.IsStreamingClient() { |
| s += ", in *" + g.QualifiedGoIdent(method.InputType.GoIdent) |
| } |
| s += ", opts ..." + g.QualifiedGoIdent(ident("grpc.CallOption")) + ") (" |
| if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { |
| s += "*" + g.QualifiedGoIdent(method.OutputType.GoIdent) |
| } else { |
| s += method.ParentService.GoName + "_" + method.GoName + "Client" |
| } |
| s += ", error)" |
| return s |
| } |
| |
| // genClientMethod emits the <service>Client implementation of method. A unary |
| // RPC becomes a method that calls c.cc.Invoke. A streaming RPC becomes a method |
| // that opens a stream with c.cc.NewStream. For streaming RPCs it also emits the |
| // <Service>_<Method>Client interface, its unexported implementing struct, and |
| // the Send/Recv/CloseAndRecv helpers that match the stream's direction. |
| // |
| // For streaming methods, index selects the entry in the generated |
| // _<Service>_serviceDesc.Streams slice. It is not read for unary methods. |
| func genClientMethod(gen *protogen.Plugin, file *fileInfo, g *protogen.GeneratedFile, method *protogen.Method, index int) { |
| service := method.ParentService |
| sname := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name()) |
|
| if methodOptions(gen, method).GetDeprecated() { |
| g.P(deprecationComment) |
| } |
| g.P("func (c *", unexport(service.GoName), "Client) ", clientSignature(g, method), "{") |
| // Unary RPC: a single Invoke call fills and returns the response. |
| if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { |
| g.P("out := new(", method.OutputType.GoIdent, ")") |
| g.P(`err := c.cc.Invoke(ctx, "`, sname, `", in, out, opts...)`) |
| g.P("if err != nil { return nil, err }") |
| g.P("return out, nil") |
| g.P("}") |
| g.P() |
| return |
| } |
| // Streaming RPC: open the stream using the matching Streams descriptor. |
| streamType := unexport(service.GoName) + method.GoName + "Client" |
| serviceDescVar := "_" + service.GoName + "_serviceDesc" |
| g.P("stream, err := c.cc.NewStream(ctx, &", serviceDescVar, ".Streams[", index, `], "`, sname, `", opts...)`) |
| g.P("if err != nil { return nil, err }") |
| g.P("x := &", streamType, "{stream}") |
| // Server-streaming only: send the single request, then close the send side. |
| if !method.Desc.IsStreamingClient() { |
| g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }") |
| g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }") |
| } |
| g.P("return x, nil") |
| g.P("}") |
| g.P() |
|
| // Which helpers the client stream exposes depends on the stream direction. |
| genSend := method.Desc.IsStreamingClient() |
| genRecv := method.Desc.IsStreamingServer() |
| genCloseAndRecv := !method.Desc.IsStreamingServer() |
|
| // Stream auxiliary types and methods. |
| g.P("type ", service.GoName, "_", method.GoName, "Client interface {") |
| if genSend { |
| g.P("Send(*", method.InputType.GoIdent, ") error") |
| } |
| if genRecv { |
| g.P("Recv() (*", method.OutputType.GoIdent, ", error)") |
| } |
| if genCloseAndRecv { |
| g.P("CloseAndRecv() (*", method.OutputType.GoIdent, ", error)") |
| } |
| g.P(ident("grpc.ClientStream")) |
| g.P("}") |
| g.P() |
|
| g.P("type ", streamType, " struct {") |
| g.P(ident("grpc.ClientStream")) |
| g.P("}") |
| g.P() |
|
| if genSend { |
| g.P("func (x *", streamType, ") Send(m *", method.InputType.GoIdent, ") error {") |
| g.P("return x.ClientStream.SendMsg(m)") |
| g.P("}") |
| g.P() |
| } |
| if genRecv { |
| g.P("func (x *", streamType, ") Recv() (*", method.OutputType.GoIdent, ", error) {") |
| g.P("m := new(", method.OutputType.GoIdent, ")") |
| g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }") |
| g.P("return m, nil") |
| g.P("}") |
| g.P() |
| } |
| if genCloseAndRecv { |
| g.P("func (x *", streamType, ") CloseAndRecv() (*", method.OutputType.GoIdent, ", error) {") |
| g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }") |
| g.P("m := new(", method.OutputType.GoIdent, ")") |
| g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }") |
| g.P("return m, nil") |
| g.P("}") |
| g.P() |
| } |
| } |
| |
| func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string { |
| var reqArgs []string |
| ret := "error" |
| if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { |
| reqArgs = append(reqArgs, g.QualifiedGoIdent(ident("context.Context"))) |
| ret = "(*" + g.QualifiedGoIdent(method.OutputType.GoIdent) + ", error)" |
| } |
| if !method.Desc.IsStreamingClient() { |
| reqArgs = append(reqArgs, "*"+g.QualifiedGoIdent(method.InputType.GoIdent)) |
| } |
| if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { |
| reqArgs = append(reqArgs, method.ParentService.GoName+"_"+method.GoName+"Server") |
| } |
| return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret |
| } |
| |
| // genServerMethod emits the grpc handler function _<Service>_<Method>_Handler |
| // for method and returns its name. A unary handler decodes the request and |
| // calls the server implementation, going through the interceptor when one is |
| // set. A streaming handler wraps the grpc.ServerStream in an unexported struct |
| // that implements the <Service>_<Method>Server interface. That interface and |
| // its Send/SendAndClose/Recv helpers are also emitted here. |
| func genServerMethod(gen *protogen.Plugin, file *fileInfo, g *protogen.GeneratedFile, method *protogen.Method) string { |
| service := method.ParentService |
| hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName) |
|
| // Unary RPC: decode the request, then call directly or via the interceptor. |
| if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { |
| g.P("func ", hname, "(srv interface{}, ctx ", ident("context.Context"), ", dec func(interface{}) error, interceptor ", ident("grpc.UnaryServerInterceptor"), ") (interface{}, error) {") |
| g.P("in := new(", method.InputType.GoIdent, ")") |
| g.P("if err := dec(in); err != nil { return nil, err }") |
| g.P("if interceptor == nil { return srv.(", service.GoName, "Server).", method.GoName, "(ctx, in) }") |
| g.P("info := &", ident("grpc.UnaryServerInfo"), "{") |
| g.P("Server: srv,") |
| g.P("FullMethod: ", strconv.Quote(fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())), ",") |
| g.P("}") |
| g.P("handler := func(ctx ", ident("context.Context"), ", req interface{}) (interface{}, error) {") |
| g.P("return srv.(", service.GoName, "Server).", method.GoName, "(ctx, req.(*", method.InputType.GoIdent, "))") |
| g.P("}") |
| g.P("return interceptor(ctx, in, info, handler)") |
| g.P("}") |
| g.P() |
| return hname |
| } |
| streamType := unexport(service.GoName) + method.GoName + "Server" |
| g.P("func ", hname, "(srv interface{}, stream ", ident("grpc.ServerStream"), ") error {") |
| // Server-streaming only: read the single request before calling the implementation. |
| if !method.Desc.IsStreamingClient() { |
| g.P("m := new(", method.InputType.GoIdent, ")") |
| g.P("if err := stream.RecvMsg(m); err != nil { return err }") |
| g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamType, "{stream})") |
| } else { |
| g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamType, "{stream})") |
| } |
| g.P("}") |
| g.P() |
|
| // Which helpers the server stream exposes depends on the stream direction. |
| genSend := method.Desc.IsStreamingServer() |
| genSendAndClose := !method.Desc.IsStreamingServer() |
| genRecv := method.Desc.IsStreamingClient() |
|
| // Stream auxiliary types and methods. |
| g.P("type ", service.GoName, "_", method.GoName, "Server interface {") |
| if genSend { |
| g.P("Send(*", method.OutputType.GoIdent, ") error") |
| } |
| if genSendAndClose { |
| g.P("SendAndClose(*", method.OutputType.GoIdent, ") error") |
| } |
| if genRecv { |
| g.P("Recv() (*", method.InputType.GoIdent, ", error)") |
| } |
| g.P(ident("grpc.ServerStream")) |
| g.P("}") |
| g.P() |
|
| g.P("type ", streamType, " struct {") |
| g.P(ident("grpc.ServerStream")) |
| g.P("}") |
| g.P() |
|
| if genSend { |
| g.P("func (x *", streamType, ") Send(m *", method.OutputType.GoIdent, ") error {") |
| g.P("return x.ServerStream.SendMsg(m)") |
| g.P("}") |
| g.P() |
| } |
| if genSendAndClose { |
| g.P("func (x *", streamType, ") SendAndClose(m *", method.OutputType.GoIdent, ") error {") |
| g.P("return x.ServerStream.SendMsg(m)") |
| g.P("}") |
| g.P() |
| } |
| if genRecv { |
| g.P("func (x *", streamType, ") Recv() (*", method.InputType.GoIdent, ", error) {") |
| g.P("m := new(", method.InputType.GoIdent, ")") |
| g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }") |
| g.P("return m, nil") |
| g.P("}") |
| g.P() |
| } |
|
| return hname |
| } |
| |
| var packages = map[string]protogen.GoImportPath{ |
| "context": "golang.org/x/net/context", |
| "grpc": "google.golang.org/grpc", |
| } |
| |
| func ident(name string) protogen.GoIdent { |
| idx := strings.LastIndex(name, ".") |
| return protogen.GoIdent{ |
| GoImportPath: packages[name[:idx]], |
| GoName: name[idx+1:], |
| } |
| } |
| |
| func genComment(g *protogen.GeneratedFile, file *fileInfo, path []int32) (hasComment bool) { |
| for _, loc := range file.locationMap[pathKey(path)] { |
| if loc.LeadingComments == nil { |
| continue |
| } |
| for _, line := range strings.Split(strings.TrimSuffix(loc.GetLeadingComments(), "\n"), "\n") { |
| hasComment = true |
| g.P("//", line) |
| } |
| break |
| } |
| return hasComment |
| } |
| |
| // deprecationComment is emitted above generated declarations whose service or |
| // method options report GetDeprecated(). |
| const deprecationComment = "// Deprecated: Do not use." |
| |
// pathKey joins the elements of a source-location path with commas, so that
// []int32{6, 0, 2, 1} becomes "6,0,2,1". The result is used as a map key. An
// empty path yields "".
func pathKey(path []int32) string {
	parts := make([]string, len(path))
	for i, elem := range path {
		parts[i] = strconv.FormatInt(int64(elem), 10)
	}
	return strings.Join(parts, ",")
}
| |
// unexport returns s with its first character lower-cased, turning an
// exported Go name such as "FooClient" into "fooClient". The whole first rune
// is lower-cased, not just its first byte, so a multi-byte leading character
// is not split into invalid UTF-8. An empty string is returned unchanged
// instead of panicking.
func unexport(s string) string {
	// Ranging over a string yields rune start offsets; the second offset marks
	// the end of the first rune.
	for i := range s {
		if i > 0 {
			return strings.ToLower(s[:i]) + s[i:]
		}
	}
	// s is empty or a single rune.
	return strings.ToLower(s)
}