protoveneer: the protoveneer tool
This tool was written to facilitate writing the Gemini API clients in:
- github.com/googleapis/google-cloud-go/vertexai
- github.com/google/generative-ai-go
Change-Id: Ieb1c3d8abec5f06f2fd10604daa22dcaffa13ca0
Reviewed-on: https://go-review.googlesource.com/c/exp/+/549635
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Eli Bendersky <eliben@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/protoveneer/README.md b/protoveneer/README.md
new file mode 100644
index 0000000..95283c7
--- /dev/null
+++ b/protoveneer/README.md
@@ -0,0 +1,5 @@
+# The protoveneer tool
+
+Protoveneer is an experimental tool that generates idiomatic Go types that
+correspond to protocol buffer messages and enums -- a veneer on top of the proto
+layer.
diff --git a/protoveneer/cmd/protoveneer/config.go b/protoveneer/cmd/protoveneer/config.go
new file mode 100644
index 0000000..0907df0
--- /dev/null
+++ b/protoveneer/cmd/protoveneer/config.go
@@ -0,0 +1,98 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "errors"
+ "fmt"
+ "os"
+
+ "gopkg.in/yaml.v3"
+)
+
+// config holds the configuration for a package.
+type config struct {
+ Package string
+ ProtoImportPath string `yaml:"protoImportPath"`
+ // Import path for the support package needed by the generated code.
+ SupportImportPath string `yaml:"supportImportPath"`
+
+ // The types to process. Only these types and the types they depend
+ // on will be output.
+ // The key is the name of the proto type.
+ Types map[string]*typeConfig
+ // Omit the types in this list, even if they would normally be output.
+ // Elements can be globs.
+ OmitTypes []string `yaml:"omitTypes"`
+ // Converter functions for types not in the proto package.
+ // Each value should be "tofunc, fromfunc"
+ Converters map[string]string
+}
+
+type typeConfig struct {
+ // The name for the veneer type, if different.
+ Name string
+ // The prefix of the proto enum values. It will be removed.
+ ProtoPrefix string `yaml:"protoPrefix"`
+ // The prefix for the veneer enum values, if different from the type name.
+ VeneerPrefix string `yaml:"veneerPrefix"`
+ // Overrides for enum values.
+ ValueNames map[string]string `yaml:"valueNames"`
+ // Overrides for field types. Map key is proto field name.
+ Fields map[string]fieldConfig
+ // Custom conversion functions: "tofunc, fromfunc"
+ ConvertToFrom string `yaml:"convertToFrom"`
+ // Doc string for the type, omitting the initial type name.
+ Doc string
+ // Verb to place after type name in doc. Default: "is".
+ // Ignored if Doc is non-empty.
+ DocVerb string `yaml:"docVerb"`
+}
+
+type fieldConfig struct {
+ Name string // veneer name
+ Type string // veneer type
+ // Omit from output.
+ Omit bool
+}
+
+func (c *config) init() {
+ for protoName, tc := range c.Types {
+ if tc == nil {
+ tc = &typeConfig{Name: protoName}
+ c.Types[protoName] = tc
+ }
+ if tc.Name == "" {
+ tc.Name = protoName
+ }
+ tc.init()
+ }
+}
+
+func (tc *typeConfig) init() {
+ if tc.VeneerPrefix == "" {
+ tc.VeneerPrefix = tc.Name
+ }
+}
+
+func readConfigFile(filename string) (*config, error) {
+ if filename == "" {
+ return nil, errors.New("missing config file")
+ }
+ f, err := os.Open(filename)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+ dec := yaml.NewDecoder(f)
+ dec.KnownFields(true)
+
+ var c config
+ if err := dec.Decode(&c); err != nil {
+ return nil, fmt.Errorf("reading %s: %w", filename, err)
+ }
+ c.init()
+ return &c, nil
+}
diff --git a/protoveneer/cmd/protoveneer/converters.go b/protoveneer/cmd/protoveneer/converters.go
new file mode 100644
index 0000000..a9afd70
--- /dev/null
+++ b/protoveneer/cmd/protoveneer/converters.go
@@ -0,0 +1,143 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import "fmt"
+
+// A converter generates code to convert between a proto type and a veneer type.
+type converter interface {
+ // genFrom returns code to convert from proto to veneer.
+ genFrom(string) string
+ // genTo returns code to convert to proto from veneer.
+ genTo(string) string
+ // These return the function argument to Transform{Slice,MapValues}, or "" if we don't need it.
+ genTransformFrom() string
+ genTransformTo() string
+}
+
+// An identityConverter does no conversion.
+type identityConverter struct{}
+
+func (identityConverter) genFrom(arg string) string { return arg }
+func (identityConverter) genTo(arg string) string { return arg }
+func (identityConverter) genTransformFrom() string { return "" }
+func (identityConverter) genTransformTo() string { return "" }
+
+// A derefConverter converts between T in the veneer and *T in the proto.
+type derefConverter struct{}
+
+func (derefConverter) genFrom(arg string) string { return fmt.Sprintf("support.DerefOrZero(%s)", arg) }
+func (derefConverter) genTo(arg string) string { return fmt.Sprintf("support.AddrOrNil(%s)", arg) }
+func (derefConverter) genTransformFrom() string { panic("can't handle deref slices") }
+func (derefConverter) genTransformTo() string { panic("can't handle deref slices") }
+
+type enumConverter struct {
+ protoName, veneerName string
+}
+
+func (c enumConverter) genFrom(arg string) string {
+ return fmt.Sprintf("%s(%s)", c.veneerName, arg)
+}
+
+func (c enumConverter) genTransformFrom() string {
+ return fmt.Sprintf("func(p pb.%s) %s { return %s }", c.protoName, c.veneerName, c.genFrom("p"))
+}
+
+func (c enumConverter) genTo(arg string) string {
+ return fmt.Sprintf("pb.%s(%s)", c.protoName, arg)
+}
+
+func (c enumConverter) genTransformTo() string {
+ return fmt.Sprintf("func(v %s) pb.%s { return %s }", c.veneerName, c.protoName, c.genTo("v"))
+}
+
+type protoConverter struct {
+ veneerName string
+}
+
+func (c protoConverter) genFrom(arg string) string {
+ return fmt.Sprintf("(%s{}).fromProto(%s)", c.veneerName, arg)
+}
+
+func (c protoConverter) genTransformFrom() string {
+ return fmt.Sprintf("(%s{}).fromProto", c.veneerName)
+}
+
+func (c protoConverter) genTo(arg string) string {
+ return fmt.Sprintf("%s.toProto()", arg)
+}
+
+func (c protoConverter) genTransformTo() string {
+ return fmt.Sprintf("(*%s).toProto", c.veneerName)
+}
+
+type customConverter struct {
+ toFunc, fromFunc string
+}
+
+func (c customConverter) genFrom(arg string) string {
+ return fmt.Sprintf("%s(%s)", c.fromFunc, arg)
+}
+
+func (c customConverter) genTransformFrom() string { return c.fromFunc }
+
+func (c customConverter) genTo(arg string) string {
+ return fmt.Sprintf("%s(%s)", c.toFunc, arg)
+}
+
+func (c customConverter) genTransformTo() string { return c.toFunc }
+
+type sliceConverter struct {
+ eltConverter converter
+}
+
+func (c sliceConverter) genFrom(arg string) string {
+ if fn := c.eltConverter.genTransformFrom(); fn != "" {
+ return fmt.Sprintf("support.TransformSlice(%s, %s)", arg, fn)
+ }
+ return c.eltConverter.genFrom(arg)
+}
+
+func (c sliceConverter) genTo(arg string) string {
+ if fn := c.eltConverter.genTransformTo(); fn != "" {
+ return fmt.Sprintf("support.TransformSlice(%s, %s)", arg, fn)
+ }
+ return c.eltConverter.genTo(arg)
+}
+
+func (c sliceConverter) genTransformTo() string {
+ panic("sliceConverter.genToSlice called")
+}
+
+func (c sliceConverter) genTransformFrom() string {
+ panic("sliceConverter.genFromSlice called")
+}
+
+// Only the values are converted.
+type mapConverter struct {
+ valueConverter converter
+}
+
+func (c mapConverter) genFrom(arg string) string {
+ if fn := c.valueConverter.genTransformFrom(); fn != "" {
+ return fmt.Sprintf("support.TransformMapValues(%s, %s)", arg, fn)
+ }
+ return c.valueConverter.genFrom(arg)
+}
+
+func (c mapConverter) genTo(arg string) string {
+ if fn := c.valueConverter.genTransformTo(); fn != "" {
+ return fmt.Sprintf("support.TransformMapValues(%s, %s)", arg, fn)
+ }
+ return c.valueConverter.genTo(arg)
+}
+
+func (c mapConverter) genTransformTo() string {
+ panic("mapConverter.genToSlice called")
+}
+
+func (c mapConverter) genTransformFrom() string {
+ panic("mapConverter.genFromSlice called")
+}
diff --git a/protoveneer/cmd/protoveneer/protoveneer.go b/protoveneer/cmd/protoveneer/protoveneer.go
new file mode 100644
index 0000000..bfe36ce
--- /dev/null
+++ b/protoveneer/cmd/protoveneer/protoveneer.go
@@ -0,0 +1,856 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// protoveneer generates idiomatic Go types that correspond to protocol
+// buffer messages and enums -- a veneer on top of the proto layer.
+//
+// # Relationship to GAPICs
+//
+// GAPICs and this tool complement each other.
+//
+// GAPICs have client types and idiomatic methods on them that correspond to
+// RPCs. They focus on the RPC part of a service. A GAPIC relies on the
+// underlying protocol buffer types for the request and response types, and all
+// the types that these refer to.
+//
+// protoveener generates Go types that correspond to proto messages and enums,
+// including requests and responses if desired. It doesn't touch the RPC parts
+// of the proto definition.
+//
+// # Configuration
+//
+// protoveneer requires significant configuration to produce good results.
+// See the config type in config.go and the config.yaml files in the testdata
+// subdirectories to understand how to write configuration.
+//
+// # Support functions
+//
+// protoveneer generates code that relies on a few support functions. These live
+// in the support subdirectory. You should copy the contents of this directory
+// to a location of your choice, and add "supportImportPath" to your config to
+// refer to that directory's import path.
+//
+// # Unhandled features
+//
+// There is no support for oneofs. Omit the oneof type and write custom code.
+// However, the types of the individual oneof cases can be generated.
+package main
+
+// TODO:
+// - have omitFields on a TypeConfig, like omitTypes
+// - Instead of parseCustomConverter, accept a list. Users can use the inline form
+// to be compact.
+// - Check that a configured field is actually in the type.
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/parser"
+ "go/token"
+ "io"
+ "log"
+ "os"
+ "path"
+ "path/filepath"
+ "sort"
+ "strings"
+ "text/template"
+ "time"
+ "unicode"
+)
+
+var (
+ outputDir = flag.String("outdir", "", "directory to write to, or '-' for stdout")
+ noFormat = flag.Bool("nofmt", false, "do not format output")
+)
+
+func main() {
+ log.SetPrefix("protoveneer: ")
+ log.SetFlags(0)
+ flag.Usage = func() {
+ out := flag.CommandLine.Output()
+ fmt.Fprintf(out, "usage: protoveneer CONFIG.yaml DIR_WITH_pb.go_FILES\n")
+ flag.PrintDefaults()
+ }
+ flag.Parse()
+
+ if err := run(context.Background(), flag.Arg(0), flag.Arg(1), *outputDir); err != nil {
+ log.Fatal(err)
+ }
+}
+
+func run(ctx context.Context, configFile, pbDir, outDir string) error {
+ config, err := readConfigFile(configFile)
+ if err != nil {
+ return err
+ }
+
+ fset := token.NewFileSet()
+ pkg, err := parseDir(fset, pbDir)
+ if err != nil {
+ return err
+ }
+
+ src, err := generate(config, pkg, fset)
+ if err != nil {
+ return err
+ }
+ if !*noFormat {
+ src, err = format.Source(src)
+ if err != nil {
+ return err
+ }
+ }
+
+ if outDir == "-" {
+ fmt.Printf("%s\n", src)
+ } else {
+ outfile := fmt.Sprintf("%s_veneer.gen.go", pkg.Name)
+ if outDir != "" {
+ outfile = filepath.Join(outDir, outfile)
+ }
+ if err := os.WriteFile(outfile, src, 0660); err != nil {
+ log.Fatal(err)
+ }
+ fmt.Printf("wrote %s\n", outfile)
+ }
+ return nil
+}
+
+func parseDir(fset *token.FileSet, dir string) (*ast.Package, error) {
+ pkgs, err := parser.ParseDir(fset, dir, nil, parser.ParseComments)
+ if err != nil {
+ return nil, err
+ }
+ if len(pkgs) > 2 {
+ return nil, errors.New("too many packages")
+ }
+ var pkg *ast.Package
+ for name, apkg := range pkgs {
+ if !strings.HasSuffix(name, "_test") {
+ pkg = apkg
+ break
+ }
+ }
+ if pkg == nil {
+ return nil, errors.New("no non-test package")
+ }
+ for filename := range pkg.Files {
+ if !strings.HasSuffix(filename, ".pb.go") {
+ return nil, fmt.Errorf("%s is not a .pb.go file", filename)
+ }
+ }
+ return pkg, nil
+}
+
+func generate(conf *config, pkg *ast.Package, fset *token.FileSet) (src []byte, err error) {
+ // Get information about all the types in the proto package.
+ typeInfos, err := collectDecls(pkg)
+ if err != nil {
+ return nil, err
+ }
+
+ // Check that every configured type is present.
+ for protoName := range conf.Types {
+ if typeInfos[protoName] == nil {
+ return nil, fmt.Errorf("configured type %s does not exist in package", protoName)
+ }
+ }
+
+ // Consult the config to determine which types to omit.
+ // If a type isn't matched by a glob in the OmitTypes list, it is output.
+ // If a type is matched, but it also has a config, it is still output.
+ var toWrite []*typeInfo
+ for name, ti := range typeInfos {
+ if !ast.IsExported(name) {
+ continue
+ }
+ omit, err := sliceAnyError(conf.OmitTypes, func(glob string) (bool, error) {
+ return path.Match(glob, name)
+ })
+ if err != nil {
+ return nil, err
+ }
+ if !omit || conf.Types[name] != nil {
+ toWrite = append(toWrite, ti)
+ }
+ }
+
+ // Fill in the configured type names, which we need to do the rest of the work.
+ for _, ti := range toWrite {
+ if tc := conf.Types[ti.protoName]; tc != nil && tc.Name != "" {
+ ti.veneerName = tc.Name
+ }
+ }
+
+ // Sort for determinism.
+ sort.Slice(toWrite, func(i, j int) bool {
+ return toWrite[i].veneerName < toWrite[j].veneerName
+ })
+
+ // Process and configure all the types we care about.
+ // Even if there is no config for a type, there is still work to do.
+ for _, ti := range toWrite {
+ if err := processType(ti, conf.Types[ti.protoName], typeInfos); err != nil {
+ return nil, err
+ }
+ }
+
+ converters, err := buildConverterMap(toWrite, conf)
+ if err != nil {
+ return nil, err
+ }
+
+ // Use the converters map to give every field a converter.
+ for _, ti := range toWrite {
+ for _, f := range ti.fields {
+ f.converter, err = makeConverter(f.af.Type, f.protoType, converters)
+ if err != nil {
+ return nil, fmt.Errorf("%s.%s: %w", ti.protoName, f.protoName, err)
+ }
+ }
+ }
+
+ // Write the generated code.
+ return write(toWrite, conf, fset)
+}
+
+// buildConverterMap builds a map from veneer name to a converter, which writes code that converts between the proto and veneer.
+// This is used for fields when generating conversion functions.
+func buildConverterMap(typeInfos []*typeInfo, conf *config) (map[string]converter, error) {
+ converters := map[string]converter{}
+ // Build a converter for each proto type.
+ for _, ti := range typeInfos {
+ var conv converter
+ // Custom converters on the type take precedence.
+ if tc := conf.Types[ti.protoName]; tc != nil && tc.ConvertToFrom != "" {
+ c, err := parseCustomConverter(ti.veneerName, tc.ConvertToFrom)
+ if err != nil {
+ return nil, err
+ }
+ conv = c
+ } else {
+ switch ti.spec.Type.(type) {
+ case *ast.StructType:
+ conv = protoConverter{veneerName: ti.veneerName}
+ case *ast.Ident:
+ conv = enumConverter{protoName: ti.protoName, veneerName: ti.veneerName}
+ default:
+ conv = identityConverter{}
+ }
+ }
+ converters[ti.veneerName] = conv
+ }
+
+ // Add converters for used external types to the map.
+ for _, et := range externalTypes {
+ if et.used {
+ converters[et.qualifiedName] = customConverter{et.convertTo, et.convertFrom}
+ }
+ }
+
+ // Add custom converters to the map.
+ // These differ from custom converters on the proto types (a few lines above here)
+ // because they are keyed by veneer type, not proto type.
+ // That can matter when the proto type is omitted but there is a corresponding
+ // veneer type.
+ for key, value := range conf.Converters {
+ c, err := parseCustomConverter(key, value)
+ if err != nil {
+ return nil, err
+ }
+ converters[key] = c
+ }
+ return converters, nil
+}
+
+func parseCustomConverter(name, value string) (converter, error) {
+ toFunc, fromFunc, ok := strings.Cut(value, ",")
+ toFunc = strings.TrimSpace(toFunc)
+ fromFunc = strings.TrimSpace(fromFunc)
+ if !ok || toFunc == "" || fromFunc == "" {
+ return nil, fmt.Errorf(`%s: ConvertToFrom = %q, want "toFunc, fromFunc"`, name, value)
+ }
+ return customConverter{toFunc, fromFunc}, nil
+}
+
+// makeConverter constructs a converter for the given type. Not every type is in the map: this
+// function puts together converters for types like pointers, slices and maps, as well as
+// named types.
+func makeConverter(veneerType, protoType ast.Expr, converters map[string]converter) (converter, error) {
+ if c, ok := converters[typeString(veneerType)]; ok {
+ return c, nil
+ }
+ // If there is no converter for this type, look for a converter for a part of the type.
+ switch t := veneerType.(type) {
+ case *ast.Ident:
+ // Handle the case where the veneer type is the dereference of the proto type.
+ if se, ok := protoType.(*ast.StarExpr); ok {
+ if identName(se.X) != t.Name {
+ return nil, fmt.Errorf("veneer type %s does not match dereferenced proto type %s", t.Name, identName(se.X))
+ }
+ return derefConverter{}, nil
+ }
+ return identityConverter{}, nil
+ case *ast.StarExpr:
+ return makeConverter(t.X, protoType.(*ast.StarExpr).X, converters)
+ case *ast.ArrayType:
+ eltc, err := makeConverter(t.Elt, protoType.(*ast.ArrayType).Elt, converters)
+ if err != nil {
+ return nil, err
+ }
+ return sliceConverter{eltc}, nil
+ case *ast.MapType:
+ // Assume the key types are the same.
+ vc, err := makeConverter(t.Value, protoType.(*ast.MapType).Value, converters)
+ if err != nil {
+ return nil, err
+ }
+ return mapConverter{vc}, nil
+ default:
+ return identityConverter{}, nil
+ }
+}
+
+// A typeInfo holds information about a named type.
+type typeInfo struct {
+ // These fields are collected from the proto package.
+ protoName string // name of type in the proto package
+ spec *ast.TypeSpec // the spec for the type, which will be modified
+ decl *ast.GenDecl // the decl holding the spec; not sure we need this
+ values *ast.GenDecl // the list of values for an enum
+
+ // These fields are added later.
+ veneerName string // may be provided by config; else same as protoName
+ fields []*fieldInfo // for structs
+ valueNames []string // to generate String functions
+}
+
+// A fieldInfo holds information about a struct field.
+type fieldInfo struct {
+ protoType ast.Expr
+ af *ast.Field
+ protoName, veneerName string
+ converter converter
+}
+
+// collectDecls collects declaration information from a package.
+// It returns information about every named type in the package in a map
+// keyed by the type's name.
+func collectDecls(pkg *ast.Package) (map[string]*typeInfo, error) {
+ typeInfos := map[string]*typeInfo{} // key is proto name
+
+ getInfo := func(name string) *typeInfo {
+ if info, ok := typeInfos[name]; ok {
+ return info
+ }
+ info := &typeInfo{protoName: name, veneerName: name}
+ typeInfos[name] = info
+ return info
+ }
+
+ for _, file := range pkg.Files {
+ for _, decl := range file.Decls {
+ if gd, ok := decl.(*ast.GenDecl); ok {
+ switch gd.Tok {
+ case token.TYPE:
+ if len(gd.Specs) != 1 {
+ return nil, errors.New("multiple TypeSpecs in a GenDecl not supported")
+ }
+ ts := gd.Specs[0].(*ast.TypeSpec)
+ info := getInfo(ts.Name.Name)
+ info.spec = ts
+ info.decl = gd
+
+ case token.CONST:
+ // Assume consts for an enum type are grouped together, and every one has a type.
+ // That's what the proto compiler generates.
+ vs0 := gd.Specs[0].(*ast.ValueSpec)
+ if len(vs0.Names) != 1 || len(vs0.Values) != 1 {
+ return nil, errors.New("multiple names/values not supported")
+ }
+
+ protoName := identName(vs0.Type)
+ if protoName == "" {
+ continue
+ }
+ for _, s := range gd.Specs {
+ vs := s.(*ast.ValueSpec)
+ if identName(vs.Type) != protoName {
+ return nil, fmt.Errorf("%s: not all same type", protoName)
+ }
+ }
+ info := getInfo(protoName)
+ info.values = gd
+ }
+ }
+ }
+ }
+ return typeInfos, nil
+}
+
+// processType processes a single type, modifying the AST.
+// If it's an enum, just change its name.
+// If it's a struct, modify its name and fields.
+func processType(ti *typeInfo, tconf *typeConfig, typeInfos map[string]*typeInfo) error {
+ ti.spec.Name.Name = ti.veneerName
+ switch t := ti.spec.Type.(type) {
+ case *ast.StructType:
+ // Check that all configured fields are present.
+ exportedFields := map[string]bool{}
+ for _, f := range t.Fields.List {
+ if len(f.Names) > 1 {
+ return fmt.Errorf("%s: multiple names in one field spec not supported: %v", ti.protoName, f.Names)
+ }
+ if f.Names[0].IsExported() {
+ exportedFields[f.Names[0].Name] = true
+ }
+ }
+ if tconf != nil {
+ for name := range tconf.Fields {
+ if !exportedFields[name] {
+ return fmt.Errorf("%s: configured field %s is not present", ti.protoName, name)
+ }
+ }
+ }
+ // Process the fields.
+ fs := t.Fields.List
+ t.Fields.List = t.Fields.List[:0]
+ for _, f := range fs {
+ fi, err := processField(f, tconf, typeInfos)
+ if err != nil {
+ return err
+ }
+ if fi != nil {
+ t.Fields.List = append(t.Fields.List, f)
+ ti.fields = append(ti.fields, fi)
+ }
+ }
+ case *ast.Ident:
+ // Enum type. Nothing else to do with the type itself; but see processEnumValues.
+ default:
+ return fmt.Errorf("unknown type: %+v: protoName=%s", ti.spec, ti.protoName)
+ }
+ processDoc(ti.decl, ti.protoName, tconf)
+ if ti.values != nil {
+ ti.valueNames = processEnumValues(ti.values, tconf)
+ }
+ return nil
+}
+
+// processField processes a struct field.
+func processField(af *ast.Field, tc *typeConfig, typeInfos map[string]*typeInfo) (*fieldInfo, error) {
+ id := af.Names[0]
+ if !id.IsExported() {
+ return nil, nil
+ }
+ fi := &fieldInfo{
+ protoType: af.Type,
+ af: af,
+ protoName: id.Name,
+ veneerName: id.Name,
+ }
+ if tc != nil {
+ if fc, ok := tc.Fields[id.Name]; ok {
+ if fc.Omit {
+ return nil, nil
+ }
+ if fc.Name != "" {
+ id.Name = fc.Name
+ fi.veneerName = fc.Name
+ }
+ if fc.Type != "" {
+ expr, err := parser.ParseExpr(fc.Type)
+ if err != nil {
+ return nil, err
+ }
+ af.Type = expr
+ }
+ }
+ }
+ af.Type = veneerType(af.Type, typeInfos)
+ af.Tag = nil
+ return fi, nil
+}
+
+// veneerType returns a type expression for the veneer type corresponding to the given proto type.
+func veneerType(protoType ast.Expr, typeInfos map[string]*typeInfo) ast.Expr {
+ var wtype func(ast.Expr) ast.Expr
+ wtype = func(protoType ast.Expr) ast.Expr {
+ if et := protoTypeToExternalType[typeString(protoType)]; et != nil {
+ et.used = true
+ return et.typeExpr
+ }
+ switch t := protoType.(type) {
+ case *ast.Ident:
+ if ti := typeInfos[t.Name]; ti != nil {
+ wt := *t
+ wt.Name = ti.veneerName
+ return &wt
+ }
+ case *ast.ParenExpr:
+ wt := *t
+ wt.X = wtype(wt.X)
+ return &wt
+
+ case *ast.StarExpr:
+ wt := *t
+ wt.X = wtype(wt.X)
+ return &wt
+
+ case *ast.ArrayType:
+ wt := *t
+ wt.Elt = wtype(wt.Elt)
+ return &wt
+ }
+ return protoType
+ }
+
+ return wtype(protoType)
+}
+
+// processEnumValues processes enum values.
+// The proto compiler puts all the values for an enum in one GenDecl,
+// and there are no other values in that GenDecl.
+func processEnumValues(d *ast.GenDecl, tc *typeConfig) []string {
+ var valueNames []string
+ for _, s := range d.Specs {
+ vs := s.(*ast.ValueSpec)
+ id := vs.Names[0]
+ protoName := id.Name
+ veneerName := veneerValueName(id.Name, tc)
+ valueNames = append(valueNames, veneerName)
+ id.Name = veneerName
+
+ if tc != nil {
+ vs.Type.(*ast.Ident).Name = tc.Name
+ }
+ modifyCommentGroup(vs.Doc, protoName, veneerName, "means", "")
+ }
+ return valueNames
+}
+
+// veneerValueName returns an idiomatic Go name for a proto enum value.
+func veneerValueName(protoValueName string, tc *typeConfig) string {
+ if tc == nil {
+ return protoValueName
+ }
+ if nn, ok := tc.ValueNames[protoValueName]; ok {
+ return nn
+ }
+ name := strings.TrimPrefix(protoValueName, tc.ProtoPrefix)
+ // Some values have the type name in upper snake case after the prefix.
+ // Example:
+ // proto type: FinishReason
+ // prefix: Candidate_
+ // value: Candidate_FINISH_REASON_UNSPECIFIED
+ prefix := camelToUpperSnakeCase(tc.Name) + "_"
+ name = strings.TrimPrefix(name, prefix)
+ return tc.VeneerPrefix + snakeToCamelCase(name)
+}
+
+func processDoc(gd *ast.GenDecl, protoName string, tc *typeConfig) {
+ doc := ""
+ verb := ""
+ if tc != nil {
+ doc = tc.Doc
+ verb = tc.DocVerb
+ }
+
+ spec := gd.Specs[0]
+ var name string
+ switch spec := spec.(type) {
+ case *ast.TypeSpec:
+ name = spec.Name.Name
+ case *ast.ValueSpec:
+ name = spec.Names[0].Name
+ default:
+ panic("bad spec")
+ }
+ if tc != nil && name != tc.Name {
+ panic(fmt.Errorf("GenDecl name is %q, config name is %q", name, tc.Name))
+ }
+ modifyCommentGroup(gd.Doc, protoName, name, verb, doc)
+}
+
+func modifyCommentGroup(cg *ast.CommentGroup, protoName, veneerName, verb, doc string) {
+ if cg == nil {
+ return
+ }
+ if len(cg.List) == 0 {
+ return
+ }
+ c := cg.List[0]
+ c.Text = "// " + adjustDoc(strings.TrimPrefix(c.Text, "// "), protoName, veneerName, verb, doc)
+}
+
+// adjustDoc takes a doc string with initial comment characters and whitespace removed, and returns
+// a replacement that uses the given veneer name, verb and new doc string.
+func adjustDoc(origDoc, protoName, veneerName, verb, newDoc string) string {
+ // if newDoc is non-empty, completely replace the existing doc.
+ if newDoc != "" {
+ return veneerName + " " + newDoc
+ }
+ // If the doc string starts with the proto name, just replace it with the
+ // veneer name. We can't do anything about the verb because we don't know
+ // where it is in the original doc string. (I guess we could assume it's the
+ // next word, but that might not always work.)
+ if strings.HasPrefix(origDoc, protoName+" ") {
+ return veneerName + origDoc[len(protoName):]
+ }
+
+ // Lowercase the first letter of the given doc if it's not part of an acronym.
+ runes := []rune(origDoc)
+ // It shouldn't be possible for the original doc string to be empty,
+ // but check just in case to avoid panics.
+ if len(runes) == 0 {
+ return origDoc
+ }
+ // Heuristic: an acronym begins with two consecutive uppercase letters.
+ if unicode.IsUpper(runes[0]) && (len(runes) == 1 || !unicode.IsUpper(runes[1])) {
+ runes[0] = unicode.ToLower(runes[0])
+ origDoc = string(runes)
+ }
+
+ if verb == "" {
+ verb = "is"
+ }
+ return fmt.Sprintf("%s %s %s", veneerName, verb, origDoc)
+}
+
+////////////////////////////////////////////////////////////////
+
+const licenseFormat = `// Copyright %d Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+`
+
+func write(typeInfos []*typeInfo, conf *config, fset *token.FileSet) ([]byte, error) {
+ var buf bytes.Buffer
+ pr := func(format string, args ...any) { fmt.Fprintf(&buf, format, args...) }
+ prn := func(format string, args ...any) {
+ pr(format, args...)
+ pr("\n")
+ }
+ // Top of file.
+ pr(licenseFormat, time.Now().Year())
+ prn("")
+ pr("// This file was generated by protoveneer. DO NOT EDIT.\n\n")
+ pr("package %s\n\n", conf.Package)
+ prn("import (")
+ prn(` "fmt"`)
+ pr("\n")
+ prn(` pb "%s"`, conf.ProtoImportPath)
+ if conf.SupportImportPath == "" {
+ return nil, errors.New("missing supportImportPath in config")
+ }
+ prn(` "%s"`, conf.SupportImportPath)
+ for _, et := range externalTypes {
+ if et.used && et.importPath != "" {
+ prn(` "%s"`, et.importPath)
+ }
+ }
+ pr(")\n\n")
+
+ // Types.
+ for _, ti := range typeInfos {
+ for _, decl := range []*ast.GenDecl{ti.decl, ti.values} {
+ if decl != nil {
+ data, err := formatDecl(fset, decl)
+ if err != nil {
+ return nil, err
+ }
+ buf.Write(data)
+ prn("")
+ }
+ }
+ if ti.valueNames != nil {
+ if err := generateEnumStringMethod(&buf, ti.veneerName, ti.valueNames); err != nil {
+ return nil, err
+ }
+ }
+ if _, ok := ti.spec.Type.(*ast.StructType); ok {
+ ti.generateConversionMethods(pr)
+ }
+ }
+
+ return buf.Bytes(), nil
+}
+
+func formatDecl(fset *token.FileSet, gd *ast.GenDecl) ([]byte, error) {
+ var buf bytes.Buffer
+ if err := format.Node(&buf, fset, gd); err != nil {
+ return nil, err
+ }
+ // Remove blank lines that result from deleting unexported struct fields.
+ return bytes.ReplaceAll(buf.Bytes(), []byte("\n\n"), []byte("\n")), nil
+}
+
+////////////////////////////////////////////////////////////////
+
+var stringMethodTemplate = template.Must(template.New("").Parse(`
+ var namesFor{{.Type}} = map[{{.Type}}]string {
+ {{- range .Values}}
+ {{.}}: "{{.}}",
+ {{- end}}
+ }
+
+ func (v {{.Type}}) String() string {
+ if n, ok := namesFor{{.Type}}[v]; ok {
+ return n
+ }
+ return fmt.Sprintf("{{.Type}}(%d)", v)
+ }
+`))
+
+func generateEnumStringMethod(w io.Writer, typeName string, valueNames []string) error {
+ return stringMethodTemplate.Execute(w, struct {
+ Type string
+ Values []string
+ }{typeName, valueNames})
+}
+
+func (ti *typeInfo) generateConversionMethods(pr func(string, ...any)) {
+ ti.generateToProto(pr)
+ pr("\n")
+ ti.generateFromProto(pr)
+}
+
+func (ti *typeInfo) generateToProto(pr func(string, ...any)) {
+ pr("func (v *%s) toProto() *pb.%s {\n", ti.veneerName, ti.protoName)
+ pr(" if v == nil { return nil }\n")
+ pr(" return &pb.%s{\n", ti.protoName)
+ for _, f := range ti.fields {
+ pr(" %s: %s,\n", f.protoName, f.converter.genTo("v."+f.veneerName))
+ }
+ pr(" }\n")
+ pr("}\n")
+}
+
+func (ti *typeInfo) generateFromProto(pr func(string, ...any)) {
+ pr("func (%s) fromProto(p *pb.%s) *%[1]s {\n", ti.veneerName, ti.protoName)
+ pr(" if p == nil { return nil }\n")
+ pr(" return &%s{\n", ti.veneerName)
+ for _, f := range ti.fields {
+ pr(" %s: %s,\n", f.veneerName, f.converter.genFrom("p."+f.protoName))
+ }
+ pr(" }\n")
+ pr("}\n")
+}
+
+////////////////////////////////////////////////////////////////
+
+// externalType holds information about a type that is not part of the proto package.
+type externalType struct {
+ qualifiedName string
+ replaces string
+ importPath string
+ convertTo string
+ convertFrom string
+
+ typeExpr ast.Expr
+ used bool
+}
+
+var externalTypes = []*externalType{
+ {
+ qualifiedName: "civil.Date",
+ replaces: "*date.Date",
+ importPath: "cloud.google.com/go/civil",
+ convertTo: "support.CivilDateToProto",
+ convertFrom: "support.CivilDateFromProto",
+ },
+ {
+ qualifiedName: "map[string]any",
+ replaces: "*structpb.Struct",
+ convertTo: "support.MapToStructPB",
+ convertFrom: "support.MapFromStructPB",
+ },
+}
+
+var protoTypeToExternalType = map[string]*externalType{}
+
+func init() {
+ var err error
+ for _, et := range externalTypes {
+ et.typeExpr, err = parser.ParseExpr(et.qualifiedName)
+ if err != nil {
+ panic(err)
+ }
+ protoTypeToExternalType[et.replaces] = et
+ }
+}
+
+////////////////////////////////////////////////////////////////
+
+var emptyFileSet = token.NewFileSet()
+
+// typeString produces a string for a type expression.
+func typeString(t ast.Expr) string {
+ var buf bytes.Buffer
+ err := format.Node(&buf, emptyFileSet, t)
+ if err != nil {
+ panic(err)
+ }
+ return buf.String()
+}
+
+func identName(x any) string {
+ id, ok := x.(*ast.Ident)
+ if !ok {
+ return ""
+ }
+ return id.Name
+}
+
+func snakeToCamelCase(s string) string {
+ words := strings.Split(s, "_")
+ for i, w := range words {
+ if len(w) == 0 {
+ words[i] = w
+ } else {
+ words[i] = fmt.Sprintf("%c%s", unicode.ToUpper(rune(w[0])), strings.ToLower(w[1:]))
+ }
+ }
+ return strings.Join(words, "")
+}
+
+func camelToUpperSnakeCase(s string) string {
+ var res []rune
+ for i, r := range s {
+ if unicode.IsUpper(r) && i > 0 {
+ res = append(res, '_')
+ }
+ res = append(res, unicode.ToUpper(r))
+ }
+ return string(res)
+}
+
+func sliceAnyError[T any](s []T, f func(T) (bool, error)) (bool, error) {
+ for _, e := range s {
+ b, err := f(e)
+ if err != nil {
+ return false, err
+ }
+ if b {
+ return true, nil
+ }
+ }
+ return false, nil
+}
diff --git a/protoveneer/cmd/protoveneer/protoveneer_test.go b/protoveneer/cmd/protoveneer/protoveneer_test.go
new file mode 100644
index 0000000..75c7d9b
--- /dev/null
+++ b/protoveneer/cmd/protoveneer/protoveneer_test.go
@@ -0,0 +1,136 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "context"
+ "flag"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+var (
+ update = flag.Bool("update", false, "update test goldens")
+ keep = flag.Bool("keep", false, "do not remove generated files")
+)
+
+func TestGeneration(t *testing.T) {
+ ctx := context.Background()
+ entries, err := os.ReadDir("testdata")
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, e := range entries {
+ if e.IsDir() {
+ t.Run(e.Name(), func(t *testing.T) {
+ dir := filepath.Join("testdata", e.Name())
+ configFile := filepath.Join(dir, "config.yaml")
+ goldenFile := filepath.Join(dir, "golden")
+ outFile := filepath.Join(dir, e.Name()+"_veneer.gen.go")
+ if *keep {
+ t.Logf("keeping %s", outFile)
+ } else {
+ defer os.Remove(outFile)
+ }
+ if err := run(ctx, configFile, dir, dir); err != nil {
+ t.Fatal(err)
+ }
+ if *update {
+ if err := os.Remove(goldenFile); err != nil {
+ t.Fatal(err)
+ }
+ if err := os.Rename(outFile, goldenFile); err != nil {
+ t.Fatal(err)
+ }
+ t.Logf("updated golden")
+ } else {
+ if diff := diffFiles(goldenFile, outFile); diff != "" {
+ t.Errorf("diff (-want, +got):\n%s", diff)
+ }
+ }
+ })
+ }
+ }
+}
+
+func diffFiles(wantFile, gotFile string) string {
+ want, err := os.ReadFile(wantFile)
+ if err != nil {
+ return err.Error()
+ }
+ got, err := os.ReadFile(gotFile)
+ if err != nil {
+ return err.Error()
+ }
+ return cmp.Diff(string(want), string(got))
+}
+
+func TestCamelToUpperSnakeCase(t *testing.T) {
+ for _, test := range []struct {
+ in, want string
+ }{
+ {"foo", "FOO"},
+ {"fooBar", "FOO_BAR"},
+ {"aBC", "A_B_C"},
+ {"ABC", "A_B_C"},
+ } {
+ got := camelToUpperSnakeCase(test.in)
+ if got != test.want {
+ t.Errorf("%q: got %q, want %q", test.in, got, test.want)
+ }
+ }
+}
+
+func TestAdjustDoc(t *testing.T) {
+ const protoName = "PName"
+ const veneerName = "VName"
+ for i, test := range []struct {
+ origDoc string
+ verb string
+ newDoc string
+ want string
+ }{
+ {
+ origDoc: "",
+ verb: "foo",
+ newDoc: "",
+ want: "",
+ },
+ {
+ origDoc: "",
+ verb: "",
+ newDoc: "is new doc.",
+ want: "VName is new doc.",
+ },
+ {
+ origDoc: "The harm category is dangerous content.",
+ verb: "means",
+ want: "VName means the harm category is dangerous content.",
+ },
+ {
+ origDoc: "URI for the file.",
+ verb: "is the",
+ want: "VName is the URI for the file.",
+ },
+ {
+ origDoc: "PName is a thing.",
+ newDoc: "contains something else.",
+ want: "VName contains something else.",
+ },
+ {
+ origDoc: "PName is a thing.",
+ verb: "ignored",
+ want: "VName is a thing.",
+ },
+ } {
+ got := adjustDoc(test.origDoc, protoName, veneerName, test.verb, test.newDoc)
+ if got != test.want {
+ t.Errorf("#%d: got %q, want %q", i, got, test.want)
+ }
+ }
+}
diff --git a/protoveneer/cmd/protoveneer/testdata/basic/basic.pb.go b/protoveneer/cmd/protoveneer/testdata/basic/basic.pb.go
new file mode 100755
index 0000000..25f0c75
--- /dev/null
+++ b/protoveneer/cmd/protoveneer/testdata/basic/basic.pb.go
@@ -0,0 +1,122 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package basic
+
+import (
+ _ "google.golang.org/genproto/googleapis/api/annotations"
+ date "google.golang.org/genproto/googleapis/type/date"
+ protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+ "google.golang.org/protobuf/types/known/structpb"
+)
+
+const (
+ // Verify that this generated code is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
+ // Verify that runtime/protoimpl is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
+)
+
+// Harm categories that will block the content.
+type HarmCategory int32
+
+const (
+ // The harm category is unspecified.
+ HarmCategory_HARM_CATEGORY_UNSPECIFIED HarmCategory = 0
+ // The harm category is hate speech.
+ HarmCategory_HARM_CATEGORY_HATE_SPEECH HarmCategory = 1
+ // The harm category is dangerous content.
+ HarmCategory_HARM_CATEGORY_DANGEROUS_CONTENT HarmCategory = 2
+ // The harm category is harassment.
+ HarmCategory_HARM_CATEGORY_HARASSMENT HarmCategory = 3
+ // The harm category is sexually explicit content.
+ HarmCategory_HARM_CATEGORY_SEXUALLY_EXPLICIT HarmCategory = 4
+)
+
+// The reason why the model stopped generating tokens.
+// If empty, the model has not stopped generating the tokens.
+type Candidate_FinishReason int32
+
+const (
+ // The finish reason is unspecified.
+ Candidate_FINISH_REASON_UNSPECIFIED Candidate_FinishReason = 0
+ // Natural stop point of the model or provided stop sequence.
+ Candidate_STOP Candidate_FinishReason = 1
+ // The maximum number of tokens as specified in the request was reached.
+ Candidate_MAX_TOKENS Candidate_FinishReason = 2
+ // The token generation was stopped as the response was flagged for safety
+ // reasons. NOTE: When streaming the Candidate.content will be empty if
+ // content filters blocked the output.
+ Candidate_SAFETY Candidate_FinishReason = 3
+ // The token generation was stopped as the response was flagged for
+ // unauthorized citations.
+ Candidate_RECITATION Candidate_FinishReason = 4
+ // All other reasons that stopped the token generation
+ Candidate_OTHER Candidate_FinishReason = 5
+)
+
+// Raw media bytes.
+//
+// Text should not be sent as raw bytes, use the 'text' field.
+type Blob struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ // Required. The IANA standard MIME type of the source data.
+ MimeType string `protobuf:"bytes,1,opt,name=mime_type,json=mimeType,proto3" json:"mime_type,omitempty"`
+ // Required. Raw bytes for media formats.
+ Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"`
+}
+
+// Generation config.
+type GenerationConfig struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ // Optional. Controls the randomness of predictions.
+ Temperature *float32 `protobuf:"fixed32,1,opt,name=temperature,proto3,oneof" json:"temperature,omitempty"`
+ // Optional. Number of candidates to generate.
+ CandidateCount *int32 `protobuf:"varint,4,opt,name=candidate_count,json=candidateCount,proto3,oneof" json:"candidate_count,omitempty"`
+ // Optional. Stop sequences.
+ StopSequences []string `protobuf:"bytes,6,rep,name=stop_sequences,json=stopSequences,proto3" json:"stop_sequences,omitempty"`
+ HarmCat HarmCategory
+ FinishReason Candidate_FinishReason
+ CitMet *CitationMetadata
+}
+
+// A collection of source attributions for a piece of content.
+type CitationMetadata struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ // Output only. List of citations.
+ Citations []*Citation `protobuf:"bytes,1,rep,name=citations,proto3" json:"citations,omitempty"`
+ CitMap map[string]*Citation
+}
+
+// Source attributions for content.
+type Citation struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ // Output only. Url reference of the attribution.
+ Uri string `protobuf:"bytes,3,opt,name=uri,proto3" json:"uri,omitempty"`
+ // Output only. Publication date of the attribution.
+ PublicationDate *date.Date `protobuf:"bytes,6,opt,name=publication_date,json=publicationDate,proto3" json:"publication_date,omitempty"`
+ Struct *structpb.Struct
+}
diff --git a/protoveneer/cmd/protoveneer/testdata/basic/config.yaml b/protoveneer/cmd/protoveneer/testdata/basic/config.yaml
new file mode 100644
index 0000000..9a54cbe
--- /dev/null
+++ b/protoveneer/cmd/protoveneer/testdata/basic/config.yaml
@@ -0,0 +1,45 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package: basic
+protoImportPath: example.com/basic
+supportImportPath: example.com/protoveneer/support
+
+types:
+ HarmCategory:
+ protoPrefix: HarmCategory_HARM_CATEGORY_
+ docVerb: specifies
+
+ Candidate_FinishReason:
+ name: FinishReason
+ protoPrefix: Candidate_
+
+ Blob:
+ fields:
+ MimeType:
+ name: MIMEType
+ docVerb: contains
+
+ GenerationConfig:
+ fields:
+ Temperature:
+ type: float32
+ CandidateCount:
+ type: int32
+
+ Citation:
+ docVerb: contains
+ fields:
+ Uri:
+ name: URI
diff --git a/protoveneer/cmd/protoveneer/testdata/basic/golden b/protoveneer/cmd/protoveneer/testdata/basic/golden
new file mode 100644
index 0000000..ec4b158
--- /dev/null
+++ b/protoveneer/cmd/protoveneer/testdata/basic/golden
@@ -0,0 +1,223 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file was generated by protoveneer. DO NOT EDIT.
+
+package basic
+
+import (
+ "fmt"
+
+ "cloud.google.com/go/civil"
+ pb "example.com/basic"
+ "example.com/protoveneer/support"
+)
+
+// Blob contains raw media bytes.
+//
+// Text should not be sent as raw bytes, use the 'text' field.
+type Blob struct {
+ // Required. The IANA standard MIME type of the source data.
+ MIMEType string
+ // Required. Raw bytes for media formats.
+ Data []byte
+}
+
+func (v *Blob) toProto() *pb.Blob {
+ if v == nil {
+ return nil
+ }
+ return &pb.Blob{
+ MimeType: v.MIMEType,
+ Data: v.Data,
+ }
+}
+
+func (Blob) fromProto(p *pb.Blob) *Blob {
+ if p == nil {
+ return nil
+ }
+ return &Blob{
+ MIMEType: p.MimeType,
+ Data: p.Data,
+ }
+}
+
+// Citation contains source attributions for content.
+type Citation struct {
+ // Output only. Url reference of the attribution.
+ URI string
+ // Output only. Publication date of the attribution.
+ PublicationDate civil.Date
+ Struct map[string]any
+}
+
+func (v *Citation) toProto() *pb.Citation {
+ if v == nil {
+ return nil
+ }
+ return &pb.Citation{
+ Uri: v.URI,
+ PublicationDate: support.CivilDateToProto(v.PublicationDate),
+ Struct: support.MapToStructPB(v.Struct),
+ }
+}
+
+func (Citation) fromProto(p *pb.Citation) *Citation {
+ if p == nil {
+ return nil
+ }
+ return &Citation{
+ URI: p.Uri,
+ PublicationDate: support.CivilDateFromProto(p.PublicationDate),
+ Struct: support.MapFromStructPB(p.Struct),
+ }
+}
+
+// CitationMetadata is a collection of source attributions for a piece of content.
+type CitationMetadata struct {
+ // Output only. List of citations.
+ Citations []*Citation
+ CitMap map[string]*Citation
+}
+
+func (v *CitationMetadata) toProto() *pb.CitationMetadata {
+ if v == nil {
+ return nil
+ }
+ return &pb.CitationMetadata{
+ Citations: support.TransformSlice(v.Citations, (*Citation).toProto),
+ CitMap: support.TransformMapValues(v.CitMap, (*Citation).toProto),
+ }
+}
+
+func (CitationMetadata) fromProto(p *pb.CitationMetadata) *CitationMetadata {
+ if p == nil {
+ return nil
+ }
+ return &CitationMetadata{
+ Citations: support.TransformSlice(p.Citations, (Citation{}).fromProto),
+ CitMap: support.TransformMapValues(p.CitMap, (Citation{}).fromProto),
+ }
+}
+
+// FinishReason is the reason why the model stopped generating tokens.
+// If empty, the model has not stopped generating the tokens.
+type FinishReason int32
+
+const (
+ // FinishReasonUnspecified means the finish reason is unspecified.
+ FinishReasonUnspecified FinishReason = 0
+ // FinishReasonStop means natural stop point of the model or provided stop sequence.
+ FinishReasonStop FinishReason = 1
+ // FinishReasonMaxTokens means the maximum number of tokens as specified in the request was reached.
+ FinishReasonMaxTokens FinishReason = 2
+ // FinishReasonSafety means the token generation was stopped as the response was flagged for safety
+ // reasons. NOTE: When streaming the Candidate.content will be empty if
+ // content filters blocked the output.
+ FinishReasonSafety FinishReason = 3
+ // FinishReasonRecitation means the token generation was stopped as the response was flagged for
+ // unauthorized citations.
+ FinishReasonRecitation FinishReason = 4
+ // FinishReasonOther means all other reasons that stopped the token generation
+ FinishReasonOther FinishReason = 5
+)
+
+var namesForFinishReason = map[FinishReason]string{
+ FinishReasonUnspecified: "FinishReasonUnspecified",
+ FinishReasonStop: "FinishReasonStop",
+ FinishReasonMaxTokens: "FinishReasonMaxTokens",
+ FinishReasonSafety: "FinishReasonSafety",
+ FinishReasonRecitation: "FinishReasonRecitation",
+ FinishReasonOther: "FinishReasonOther",
+}
+
+func (v FinishReason) String() string {
+ if n, ok := namesForFinishReason[v]; ok {
+ return n
+ }
+ return fmt.Sprintf("FinishReason(%d)", v)
+}
+
+// GenerationConfig is generation config.
+type GenerationConfig struct {
+ // Optional. Controls the randomness of predictions.
+ Temperature float32
+ // Optional. Number of candidates to generate.
+ CandidateCount int32
+ // Optional. Stop sequences.
+ StopSequences []string
+ HarmCat HarmCategory
+ FinishReason FinishReason
+ CitMet *CitationMetadata
+}
+
+func (v *GenerationConfig) toProto() *pb.GenerationConfig {
+ if v == nil {
+ return nil
+ }
+ return &pb.GenerationConfig{
+ Temperature: support.AddrOrNil(v.Temperature),
+ CandidateCount: support.AddrOrNil(v.CandidateCount),
+ StopSequences: v.StopSequences,
+ HarmCat: pb.HarmCategory(v.HarmCat),
+ FinishReason: pb.Candidate_FinishReason(v.FinishReason),
+ CitMet: v.CitMet.toProto(),
+ }
+}
+
+func (GenerationConfig) fromProto(p *pb.GenerationConfig) *GenerationConfig {
+ if p == nil {
+ return nil
+ }
+ return &GenerationConfig{
+ Temperature: support.DerefOrZero(p.Temperature),
+ CandidateCount: support.DerefOrZero(p.CandidateCount),
+ StopSequences: p.StopSequences,
+ HarmCat: HarmCategory(p.HarmCat),
+ FinishReason: FinishReason(p.FinishReason),
+ CitMet: (CitationMetadata{}).fromProto(p.CitMet),
+ }
+}
+
+// HarmCategory specifies harm categories that will block the content.
+type HarmCategory int32
+
+const (
+ // HarmCategoryUnspecified means the harm category is unspecified.
+ HarmCategoryUnspecified HarmCategory = 0
+ // HarmCategoryHateSpeech means the harm category is hate speech.
+ HarmCategoryHateSpeech HarmCategory = 1
+ // HarmCategoryDangerousContent means the harm category is dangerous content.
+ HarmCategoryDangerousContent HarmCategory = 2
+ // HarmCategoryHarassment means the harm category is harassment.
+ HarmCategoryHarassment HarmCategory = 3
+ // HarmCategorySexuallyExplicit means the harm category is sexually explicit content.
+ HarmCategorySexuallyExplicit HarmCategory = 4
+)
+
+var namesForHarmCategory = map[HarmCategory]string{
+ HarmCategoryUnspecified: "HarmCategoryUnspecified",
+ HarmCategoryHateSpeech: "HarmCategoryHateSpeech",
+ HarmCategoryDangerousContent: "HarmCategoryDangerousContent",
+ HarmCategoryHarassment: "HarmCategoryHarassment",
+ HarmCategorySexuallyExplicit: "HarmCategorySexuallyExplicit",
+}
+
+func (v HarmCategory) String() string {
+ if n, ok := namesForHarmCategory[v]; ok {
+ return n
+ }
+ return fmt.Sprintf("HarmCategory(%d)", v)
+}
diff --git a/protoveneer/go.mod b/protoveneer/go.mod
new file mode 100644
index 0000000..03579fb
--- /dev/null
+++ b/protoveneer/go.mod
@@ -0,0 +1,11 @@
+module golang.org/x/exp/protoveneer
+
+go 1.21
+
+require (
+ cloud.google.com/go v0.111.0
+ github.com/google/go-cmp v0.6.0
+ google.golang.org/genproto v0.0.0-20231212172506-995d672761c0
+ google.golang.org/protobuf v1.31.0
+ gopkg.in/yaml.v3 v3.0.1
+)
diff --git a/protoveneer/go.sum b/protoveneer/go.sum
new file mode 100644
index 0000000..b6f598c
--- /dev/null
+++ b/protoveneer/go.sum
@@ -0,0 +1,16 @@
+cloud.google.com/go v0.111.0 h1:YHLKNupSD1KqjDbQ3+LVdQ81h/UJbJyZG203cEfnQgM=
+cloud.google.com/go v0.111.0/go.mod h1:0mibmpKP1TyOOFYQY5izo0LnT+ecvOQ0Sg3OdmMiNRU=
+github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
+github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
+github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+google.golang.org/genproto v0.0.0-20231212172506-995d672761c0 h1:YJ5pD9rF8o9Qtta0Cmy9rdBwkSjrTCT6XTiUQVOtIos=
+google.golang.org/genproto v0.0.0-20231212172506-995d672761c0/go.mod h1:l/k7rMz0vFTBPy+tFSGvXEd3z+BcoG1k7EHbqm+YBsY=
+google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
+google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
+google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/protoveneer/support/support.go b/protoveneer/support/support.go
new file mode 100644
index 0000000..5e74f3b
--- /dev/null
+++ b/protoveneer/support/support.go
@@ -0,0 +1,112 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package support provides support functions for protoveneer.
+package support
+
+import (
+ "fmt"
+ "time"
+
+ "cloud.google.com/go/civil"
+ "google.golang.org/genproto/googleapis/type/date"
+ "google.golang.org/protobuf/types/known/structpb"
+)
+
+// TransformSlice applies f to each element of from and returns
+// a new slice with the results.
+func TransformSlice[From, To any](from []From, f func(From) To) []To {
+ if from == nil {
+ return nil
+ }
+ to := make([]To, len(from))
+ for i, e := range from {
+ to[i] = f(e)
+ }
+ return to
+}
+
+// TransformMapValues applies f to each value of from, returning a new map.
+// It does not change the keys.
+func TransformMapValues[K comparable, VFrom, VTo any](from map[K]VFrom, f func(VFrom) VTo) map[K]VTo {
+ if from == nil {
+ return nil
+ }
+ to := map[K]VTo{}
+ for k, v := range from {
+ to[k] = f(v)
+ }
+ return to
+}
+
+// AddrOrNil returns nil if x is the zero value for T,
+// or &x otherwise.
+func AddrOrNil[T comparable](x T) *T {
+ var z T
+ if x == z {
+ return nil
+ }
+ return &x
+}
+
+// DerefOrZero returns the zero value for T if x is nil,
+// or *x otherwise.
+func DerefOrZero[T any](x *T) T {
+ if x == nil {
+ var z T
+ return z
+ }
+ return *x
+}
+
+// CivilDateToProto converts a civil.Date to a date.Date.
+func CivilDateToProto(d civil.Date) *date.Date {
+ return &date.Date{
+ Year: int32(d.Year),
+ Month: int32(d.Month),
+ Day: int32(d.Day),
+ }
+}
+
+// CivilDateFromProto converts a date.Date to a civil.Date.
+func CivilDateFromProto(p *date.Date) civil.Date {
+ if p == nil {
+ return civil.Date{}
+ }
+ return civil.Date{
+ Year: int(p.Year),
+ Month: time.Month(p.Month),
+ Day: int(p.Day),
+ }
+}
+
+// MapToStructPB converts a map into a structpb.Struct.
+func MapToStructPB(m map[string]any) *structpb.Struct {
+ if m == nil {
+ return nil
+ }
+ s, err := structpb.NewStruct(m)
+ if err != nil {
+ panic(fmt.Errorf("support.MapToProto: %w", err))
+ }
+ return s
+}
+
+// MapFromStructPB converts a structpb.Struct to a map.
+func MapFromStructPB(p *structpb.Struct) map[string]any {
+ if p == nil {
+ return nil
+ }
+ return p.AsMap()
+}
diff --git a/protoveneer/support/support_test.go b/protoveneer/support/support_test.go
new file mode 100644
index 0000000..c3def96
--- /dev/null
+++ b/protoveneer/support/support_test.go
@@ -0,0 +1,35 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package support
+
+import (
+ "reflect"
+ "strconv"
+ "testing"
+)
+
+func TestTransformMapValues(t *testing.T) {
+ var from map[string]int
+ got := TransformMapValues(from, strconv.Itoa)
+ if got != nil {
+ t.Fatalf("got %v, want nil", got)
+ }
+ from = map[string]int{"one": 1, "two": 2}
+ got = TransformMapValues(from, strconv.Itoa)
+ want := map[string]string{"one": "1", "two": "2"}
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("got %v, want %v", got, want)
+ }
+}