blob: 4a9acdf5a618e54ae7a32b12147cd39a365db8bc [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package setapi implements the setapi open2opaque subcommand, which sets the
// go_api_flag file option in .proto files.
package setapi
import (
"bytes"
"cmp"
"context"
"errors"
"fmt"
"io"
"math"
"os"
"os/exec"
"slices"
"strings"
"flag"
"github.com/google/subcommands"
"golang.org/x/sync/errgroup"
pb "google.golang.org/open2opaque/internal/apiflagdata"
"google.golang.org/open2opaque/internal/o2o/args"
"google.golang.org/open2opaque/internal/protodetect"
"google.golang.org/open2opaque/internal/protoparse"
descpb "google.golang.org/protobuf/types/descriptorpb"
gofeaturespb "google.golang.org/protobuf/types/gofeaturespb"
)
// Cmd implements the setapi subcommand of the open2opaque tool.
type Cmd struct {
apiFlag string
inputFile string
skipCleanup bool
maxProcs uint
protoFmt string
kind string
}
// Name implements subcommand.Command.
func (*Cmd) Name() string { return "setapi" }
// Synopsis implements subcommand.Command.
func (*Cmd) Synopsis() string {
return "Set the Go API level proto option."
}
// Usage implements subcommand.Command.
func (*Cmd) Usage() string {
return `Usage: open2opaque setapi [-api=<` + strings.Join(validApis, "|") + `>] [-input_file=<input-list>] [<path/file.proto>] [<protopkg1>] [<protopkg2.MessageName>]
The setapi subcommand adjusts the Go API level option on the specified proto file(s) / package(s) / message(s).
The setapi subcommand can either read the proto file name(s) / package(s) / message(s)
from a text file (-input_file) or from the command line arguments, or both.
Command-line flag documentation follows:
`
}
// SetFlags implements subcommand.Command.
func (cmd *Cmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&cmd.apiFlag, "api", "OPAQUE", "set Go API level value to this, valid values: OPEN, HYBRID, and OPAQUE")
f.StringVar(&cmd.inputFile, "input_file", "", "file containing list of proto source files / proto packages / proto messages to update (one per line)")
f.BoolVar(&cmd.skipCleanup, "skip_cleanup", false, "skip the cleanup step, which removes the file-level flag if it equals the default and removes message flags if they equal the file-level API")
f.UintVar(&cmd.maxProcs, "max_procs", 32, "max number of files concurrently processed")
protofmtDefault := ""
f.StringVar(&cmd.protoFmt, "protofmt", protofmtDefault, "if non-empty, a formatter program for .proto files")
}
// Execute implements subcommand.Command.
func (cmd *Cmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...any) subcommands.ExitStatus {
if err := cmd.setapi(ctx, f); err != nil {
// Use fmt.Fprintf instead of log.Exit to generate a shorter error
// message: users do not care about the current date/time and the fact
// that our code lives in setapi.go.
fmt.Fprintf(os.Stderr, "%v\n", err)
return subcommands.ExitFailure
}
return subcommands.ExitSuccess
}
// Command returns an initialized Cmd for registration with the subcommands
// package.
func Command() *Cmd {
return &Cmd{}
}
// Task defines the modification of the Go API level of a proto file or a
// particular message.
type Task struct {
// Path of the proto file.
Path string
// The content of the proto file.
Content []byte
// If not empty, set the API level only for the message with this
// package-local or fully-qualified name, e.g. Msg.NestedMsg or
// pkgname.Msg.NestedMsg.
// If empty, set the file-level API level.
Symbol string
TargetAPI gofeaturespb.GoFeatures_APILevel
// If true, skip cleanup steps. If false, perform the following steps:
// - if all messages in a file are on the same API level, set the whole file
// to that level.
// - remove API flags from all messages that don't need it because they are
// on the same API level as the file (or of its parent message for nested
// messages in edition protos that use the new edition-feature API flag).
SkipCleanup bool
// A leading comment before the Go API flag prevents setapi from
// modifying it. If a modification was prevent by this mechanism and
// ErrorOnExempt is true, Process returns an error. Otherwise, the original
// content is returned.
ErrorOnExempt bool
}
func (cmd *Cmd) setapi(ctx context.Context, f *flag.FlagSet) error {
api, err := parseAPIFlag(cmd.apiFlag)
if err != nil {
return err
}
var inputs []string
if cmd.inputFile != "" {
var err error
if inputs, err = readInputFile(cmd.inputFile); err != nil {
return fmt.Errorf("error reading file: %v", err)
}
}
inputs = append(inputs, f.Args()...)
var tasks []Task
kind := cmd.kind
if kind == "" {
kind = "proto_filename"
}
for _, input := range inputs {
_, filename, symbol, err := args.ToProtoFilename(ctx, input, kind)
if err != nil {
return err
}
content, err := os.ReadFile(filename)
if err != nil {
return err
}
tasks = append(tasks, Task{
Path: filename,
Symbol: symbol,
Content: content,
TargetAPI: api,
SkipCleanup: cmd.skipCleanup,
ErrorOnExempt: true,
})
}
if len(tasks) == 0 {
return fmt.Errorf("missing inputs, use either -list (one input per line) and / or pass input file name(s) / package(s) / message(s) as non-flag arguments")
}
// Error out if there are non-unique paths to avoid conflicting inputs like
// editing the file-level flag of a file and a message flag in the same file.
if conflicts := conflictingTasks(tasks); len(conflicts) > 0 {
var l []string
for _, c := range conflicts {
l = append(l, c.Path)
}
return fmt.Errorf("conflicting, non-unique proto files in inputs: %s", strings.Join(l, ", "))
}
outputs := make([][]byte, len(tasks))
protofmt := cmd.protoFmt
if protofmt == "" {
protofmt = "cat"
}
eg, ctx := errgroup.WithContext(ctx)
eg.SetLimit(int(cmd.maxProcs))
for itask, task := range tasks {
itask, task := itask, task
eg.Go(func() error {
var err error
outputs[itask], err = Process(ctx, task, protofmt)
return err
})
}
if err := eg.Wait(); err != nil {
return fmt.Errorf("aborting without writing any proto files: %v", err)
}
// Write the outputs back to the input files.
eg, ctx = errgroup.WithContext(ctx)
eg.SetLimit(int(cmd.maxProcs))
for itask, task := range tasks {
itask, task := itask, task
eg.Go(func() error {
return os.WriteFile(task.Path, outputs[itask], 0644)
})
}
if err := eg.Wait(); err != nil {
return fmt.Errorf("error while writing proto files: %v", err)
}
return nil
}
var (
apiMap = map[string]gofeaturespb.GoFeatures_APILevel{
"OPEN": gofeaturespb.GoFeatures_API_OPEN,
"HYBRID": gofeaturespb.GoFeatures_API_HYBRID,
"OPAQUE": gofeaturespb.GoFeatures_API_OPAQUE,
}
// Don't use the keys of apiMap to report valid values to the user, we want
// in a specific order.
validApis = []string{"OPEN", "HYBRID", "OPAQUE"}
)
func parseAPIFlag(flag string) (gofeaturespb.GoFeatures_APILevel, error) {
if flag == "" {
return gofeaturespb.GoFeatures_API_LEVEL_UNSPECIFIED, fmt.Errorf("missing --api flag value")
}
if api, ok := apiMap[flag]; ok {
return api, nil
}
return gofeaturespb.GoFeatures_API_LEVEL_UNSPECIFIED, fmt.Errorf("invalid --api flag value: %v, valid values: %s", flag, strings.Join(validApis, ", "))
}
func readInputFile(name string) ([]string, error) {
b, err := os.ReadFile(name)
if err != nil {
return nil, err
}
var result []string
for _, line := range strings.Split(string(b), "\n") {
if line := strings.TrimSpace(line); line != "" {
result = append(result, line)
}
}
return result, nil
}
func conflictingTasks(tasks []Task) []Task {
var conflicts []Task
present := make(map[string]bool)
for _, task := range tasks {
if present[task.Path] {
conflicts = append(conflicts, task)
}
present[task.Path] = true
}
return conflicts
}
var logf = func(format string, a ...any) { fmt.Fprintf(os.Stderr, "[setapi] "+format+"\n", a...) }
func parse(path string, content []byte, skipMessages bool) (*protoparse.FileOpt, error) {
parser := protoparse.NewParserWithAccessor(func(string) (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(content)), nil
})
fopt, err := parser.ParseFile(path, skipMessages)
if err != nil {
return nil, fmt.Errorf("protoparse.ParseFile: %v", err)
}
return fopt, nil
}
func parentAPI(path string, content []byte, msgName string) (gofeaturespb.GoFeatures_APILevel, error) {
fopt, err := parse(path, content, false)
if err != nil {
return gofeaturespb.GoFeatures_API_LEVEL_UNSPECIFIED, err
}
msgName = strings.TrimPrefix(msgName, fopt.Package+".")
var mopt *protoparse.MessageOpt
for _, moptLoop := range fopt.MessageOpts {
if o := findMsg(moptLoop, msgName); o != nil {
mopt = o
break
}
}
if mopt == nil {
return gofeaturespb.GoFeatures_API_LEVEL_UNSPECIFIED, fmt.Errorf("cannot find message %q", msgName)
}
result := fopt.GoAPI
if fopt.Syntax == "editions" && mopt.Parent != nil {
result = mopt.Parent.GoAPI
}
return result, nil
}
func traverseMsgTree(opt *protoparse.MessageOpt, f func(*protoparse.MessageOpt) error) error {
if err := f(opt); err != nil {
return err
}
for _, c := range opt.Children {
if err := traverseMsgTree(c, f); err != nil {
return err
}
}
return nil
}
// Process modifies the API level of a proto file or of a particular message in
// a proto file, see the doc comment of the type Task for more details. Before
// returning the modified file content, the file is formatted by executing
// formatter (use "cat" if you don't have a formatter handy). This function
// doesn't modify the []byte task.Content.
func Process(ctx context.Context, task Task, formatter string) ([]byte, error) {
if task.Path == "" {
return nil, fmt.Errorf("path is empty")
}
if len(task.Content) == 0 {
return nil, fmt.Errorf("content is empty")
}
// Clone the input file content in case the caller will use the slice after
// the call.
content := slices.Clone(task.Content)
var err error
if task.Symbol == "" {
content, err = setFileAPI(task.Path, content, task.TargetAPI, task.SkipCleanup, task.ErrorOnExempt)
if err != nil {
return nil, fmt.Errorf("setFileAPI: %v", err)
}
} else {
parentAPI, err := parentAPI(task.Path, content, task.Symbol)
if err != nil {
return nil, fmt.Errorf("parentAPI: %v", err)
}
content, err = setMsgAPI(task.Path, content, task.Symbol, parentAPI, task.TargetAPI, task.SkipCleanup)
if err != nil {
if errors.Is(err, ErrLeadingCommentPreventsEdit) && !task.ErrorOnExempt {
// Don't error out if leading comment exempted edit but ErrorOnExempt
// is false. Could write this a lot more compact, but this makes it easy
// to understand.
} else {
return nil, err
}
}
}
if !task.SkipCleanup {
content, err = cleanup(task.Path, content)
if err != nil {
return nil, fmt.Errorf("cleanup: %v", err)
}
}
content, err = FormatFile(ctx, content, formatter)
if err != nil {
return nil, fmt.Errorf("FormatFile: %v", err)
}
return content, nil
}
func fromFeatureToOld(apiLevel gofeaturespb.GoFeatures_APILevel) pb.GoAPI {
switch apiLevel {
case gofeaturespb.GoFeatures_API_OPEN:
return pb.GoAPI_OPEN_V1
case gofeaturespb.GoFeatures_API_HYBRID:
return pb.GoAPI_OPEN_TO_OPAQUE_HYBRID
case gofeaturespb.GoFeatures_API_OPAQUE:
return pb.GoAPI_OPAQUE_V0
default:
panic(fmt.Sprintf("unknown apilevel %v", apiLevel))
}
}
func byteRangeWithEOLComment(in []byte, tr protoparse.TextRange) (beginByte, endByte int, err error) {
beginByte, endByte, err = tr.ToByteRange(in)
if err != nil {
return -1, -1, fmt.Errorf("tr.ToByteRange: %v", err)
}
if tr.BeginLine == tr.EndLine {
if nextNewlineByte := bytes.IndexByte(in[endByte:], '\n'); nextNewlineByte != -1 {
if lineSuffix := in[endByte : endByte+nextNewlineByte]; bytes.HasPrefix(bytes.TrimSpace(lineSuffix), []byte("//")) {
endByte += nextNewlineByte
}
}
}
return beginByte, endByte, nil
}
func replaceTextRange(in []byte, tr protoparse.TextRange, insert []byte) ([]byte, error) {
beginByte, endByte, err := tr.ToByteRange(in)
if err != nil {
return nil, fmt.Errorf("tr.ToByteRange: %v", err)
}
in = slices.Delete(in, beginByte, endByte)
return slices.Insert(in, beginByte, insert...), nil
}
func levelLineFor(fopt *protoparse.FileOpt, targetAPI gofeaturespb.GoFeatures_APILevel) (string, error) {
if fopt.Syntax == "editions" {
return fmt.Sprintf("option features.(pb.go).api_level = %s;", targetAPI), nil
}
return "", fmt.Errorf("file %s is not using Protobuf Editions (https://protobuf.dev/editions/overview/) yet", fopt.File)
}
func setFileAPI(path string, content []byte, targetAPI gofeaturespb.GoFeatures_APILevel, skipCleanup, errorOnExempt bool) ([]byte, error) {
fopt, err := parse(path, content, true)
if err != nil {
return nil, err
}
if fopt.IsExplicit && fopt.APIInfo == nil {
return nil, fmt.Errorf("BUG: fopt.APIInfo is nil")
}
if defaultAPI := protodetect.DefaultFileLevel(path); defaultAPI == targetAPI {
if !fopt.IsExplicit {
logf("File %s is already on the target API level by default, doing nothing", path)
return content, nil
}
if fopt.APIInfo.HasLeadingComment {
if errorOnExempt {
return nil, fmt.Errorf("API flag of file %s has a leading comment that prevents removing it", path)
}
logf("API flag of file %s has a leading comment that prevents removing it", path)
return content, nil
}
if fopt.GoAPI == targetAPI && skipCleanup {
logf("skipping cleanup: not removing the API flag of file %s although it's at the default API level", path)
return content, nil
}
from, to, err := byteRangeWithEOLComment(content, fopt.APIInfo.TextRange)
if err != nil {
return nil, fmt.Errorf("byteRangeWithEOLComment: %v", err)
}
logf("Removing the API flag of file %s to use the default API level", path)
return slices.Delete(content, from, to), nil
}
levelLine, err := levelLineFor(fopt, targetAPI)
if err != nil {
return nil, err
}
if !fopt.IsExplicit {
logf("Inserting API flag %q into file %s", targetAPI, path)
ln, err := fileOptionLineNumber(fopt.SourceCodeInfo)
if err != nil {
return nil, fmt.Errorf("fileOptionLineNumber: %v", err)
}
return insertLine(content, levelLine, ln), nil
}
// File is currently explicitly set to API value and target API isn't the
// default API.
if fopt.GoAPI == targetAPI {
logf("File %s is already on the target API level, do nothing", path)
return content, nil
}
if fopt.APIInfo.HasLeadingComment {
if errorOnExempt {
return nil, fmt.Errorf("API flag of file %s has a leading comment that prevents replacing it", path)
}
logf("API flag of file %s has a leading comment that prevents replacing it", path)
return content, nil
}
logf("Replacing the API flag of file %s", path)
content, err = replaceTextRange(content, fopt.APIInfo.TextRange, []byte(levelLine))
if err != nil {
return nil, fmt.Errorf("replaceTextRange: %v", err)
}
return content, nil
}
func findMsg(opt *protoparse.MessageOpt, name string) *protoparse.MessageOpt {
if opt.Message == name {
return opt
}
for _, c := range opt.Children {
if childOpt := findMsg(c, name); childOpt != nil {
return childOpt
}
}
return nil
}
// ErrLeadingCommentPreventsEdit signals that no modifications were made because
// a leading comment exempted an API flag from modification. This error is
// ignored if Task.ErrorOnExempt is false.
var ErrLeadingCommentPreventsEdit = errors.New("leading comment prevents edit, check the logs for more details")
func setMsgAPI(path string, content []byte, msgName string, parentAPI, targetAPI gofeaturespb.GoFeatures_APILevel, skipCleanup bool) ([]byte, error) {
// This function is called recursively. Parse content every time because the
// position information changes when parent messages are manipulated.
fopt, err := parse(path, content, false)
if err != nil {
return nil, err
}
msgName = strings.TrimPrefix(msgName, fopt.Package+".")
var mopt *protoparse.MessageOpt
for _, moptLoop := range fopt.MessageOpts {
if o := findMsg(moptLoop, msgName); o != nil {
mopt = o
break
}
}
if mopt == nil {
return nil, fmt.Errorf("cannot find massage %q", msgName)
}
if mopt.IsExplicit && mopt.APIInfo == nil {
return nil, fmt.Errorf("BUG: mopt.APIInfo is nil")
}
if parentAPI == targetAPI {
if !mopt.IsExplicit {
logf("Message %q is already on the target API level, doing nothing", msgName)
return content, nil
}
if mopt.APIInfo.HasLeadingComment {
logf("Changing API level of message %q was prevented by a leading comment of the API flag", msgName)
if mopt.GoAPI != targetAPI {
// Report an error to abort the whole operation if we would have
// changed the API flag instead of just cleaning it up because it's the
// same as the parent. If we didn't, we might change nested messages to
// the wrong API level.
return nil, ErrLeadingCommentPreventsEdit
}
return content, nil
}
if mopt.GoAPI != targetAPI && fopt.Syntax == "editions" {
logf("Before changing API flag of message %q, descending into children to prevent recursive change of API level", msgName)
for _, child := range mopt.Children {
content, err = setMsgAPI(path, content, child.Message, targetAPI, child.GoAPI, skipCleanup)
if err != nil {
return nil, err
}
}
}
if mopt.GoAPI == targetAPI && skipCleanup {
logf("skipping cleanup: not removing the API flag of message %q although it's at the parent API level", msgName)
return content, nil
}
from, to, err := byteRangeWithEOLComment(content, mopt.APIInfo.TextRange)
if err != nil {
return nil, fmt.Errorf("byteRangeWithEOLComment: %v", err)
}
logf("Removing the API flag of message %q to use the parent API level", msgName)
return slices.Delete(content, from, to), nil
}
if !mopt.IsExplicit {
if fopt.Syntax == "editions" {
logf("Before changing API flag of message %q, descending into children to prevent recursive change of API level", msgName)
for _, child := range mopt.Children {
content, err = setMsgAPI(path, content, child.Message, targetAPI, child.GoAPI, skipCleanup)
if err != nil {
return nil, err
}
}
}
logf("Inserting API flag %q into message %q", targetAPI, msgName)
idx, err := msgOptionInsertionByteIdx(content, fopt.SourceCodeInfo, mopt.LocPath)
if err != nil {
return nil, fmt.Errorf("msgOptionInsertionByteIdx: %v", err)
}
insertion := fmt.Sprintf("\noption go_api_flag = %q;", fromFeatureToOld(targetAPI))
if fopt.Syntax == "editions" {
insertion = fmt.Sprintf("\noption features.(pb.go).api_level = %s;", targetAPI)
}
return slices.Concat(content[:idx], []byte(insertion), content[idx:]), nil
}
// Message is currently explicitly set to API value and target API isn't the
// parent API.
if mopt.GoAPI == targetAPI {
logf("Message %q is already on the target API level, do nothing", msgName)
return content, nil
}
if mopt.APIInfo.HasLeadingComment {
logf("Changing API level of message %q was prevented by a leading comment of the API flag", msgName)
return content, ErrLeadingCommentPreventsEdit
}
if fopt.Syntax == "editions" {
logf("Before changing API flag of message %q, descending into children to prevent recursive change of API level", msgName)
for _, child := range mopt.Children {
content, err = setMsgAPI(path, content, child.Message, targetAPI, child.GoAPI, skipCleanup)
if err != nil {
return nil, err
}
}
}
logf("Replacing the API flag of message %q", msgName)
insert := fmt.Sprintf("option go_api_flag = %q;", fromFeatureToOld(targetAPI))
if fopt.Syntax == "editions" {
insert = fmt.Sprintf("option features.(pb.go).api_level = %s;", targetAPI)
}
content, err = replaceTextRange(content, mopt.APIInfo.TextRange, []byte(insert))
if err != nil {
return nil, fmt.Errorf("replaceTextRange: %v", err)
}
return content, nil
}
func cleanup(path string, content []byte) ([]byte, error) {
// Cleanup 1: If all messages are on the same API level, flip the file API
// level to that value. If said level is the default for the path, don't use
// an explicit flag.
var err error
content, err = cleanupAllMsgsToFile(path, content)
if err != nil {
return nil, err
}
// Cleanup 2: Remove the API flag from all messages that have the same API
// level as their parents. The parent is usually the file API level, but in
// case of editions proto that use the new edition feature, the parent of
// a nested message is its parent message.
//
// Parse again after the previous cleanup might have modified content.
content, err = cleanupSameAsParent(path, content)
if err != nil {
return nil, err
}
// Cleanup 3: ensure the go_features.proto import is present if and only if
// any parts of the file set the features.(pb.go).api_level option.
return cleanupGoFeaturesImport(path, content)
}
func cleanupAllMsgsToFile(path string, content []byte) ([]byte, error) {
fopt, err := parse(path, content, false)
if err != nil {
return nil, err
}
apiMap := map[gofeaturespb.GoFeatures_APILevel]struct{}{}
for _, moptLoop := range fopt.MessageOpts {
traverseMsgTree(moptLoop, func(mopt *protoparse.MessageOpt) error {
apiMap[mopt.GoAPI] = struct{}{}
return nil
})
}
if len(apiMap) != 1 {
// Not all messages are on the same API level.
return content, nil
}
targetAPI := fopt.MessageOpts[0].GoAPI
if fopt.GoAPI == targetAPI {
// File is already on the target API level, nothing to do.
return content, nil
}
// All messages have same API level.
// Consider cleanup as a nice to have, don't error out on leading-comment
// exemptions.
const errorOnExempt = false
const skipCleanup = false
content, err = setFileAPI(path, content, targetAPI, skipCleanup, errorOnExempt)
if err != nil {
return nil, fmt.Errorf("setFileAPI: %v", err)
}
return content, nil
}
func cleanupSameAsParent(path string, content []byte) ([]byte, error) {
fopt, err := parse(path, content, false)
if err != nil {
return nil, err
}
// Collect all byte ranges of API flags that should be removed.
type byteRange struct {
from, to int
}
var removeByteRanges []byteRange
for _, moptLoop := range fopt.MessageOpts {
if err := traverseMsgTree(moptLoop, func(mopt *protoparse.MessageOpt) error {
if !mopt.IsExplicit {
return nil
}
if mopt.APIInfo == nil {
return fmt.Errorf("BUG: mopt.APIInfo is nil")
}
if mopt.APIInfo.HasLeadingComment {
return nil
}
parentAPI := fopt.GoAPI
if fopt.Syntax == "editions" && mopt.Parent != nil {
parentAPI = mopt.Parent.GoAPI
}
if mopt.GoAPI != parentAPI {
return nil
}
// Message has same API level as parent (or file) and API flag is
// explicit.
from, to, err := byteRangeWithEOLComment(content, mopt.APIInfo.TextRange)
if err != nil {
return fmt.Errorf("TextRange.ToByteRange: %v", err)
}
removeByteRanges = append(removeByteRanges, byteRange{from: from, to: to})
return nil
}); err != nil {
return nil, err
}
}
if len(removeByteRanges) == 0 {
return content, nil
}
slices.SortFunc(removeByteRanges, func(a, b byteRange) int {
return cmp.Compare(a.from, b.from)
})
for i := 1; i < len(removeByteRanges); i++ {
if removeByteRanges[i].from < removeByteRanges[i-1].to {
return nil, fmt.Errorf("text ranges overlap")
}
}
for _, br := range slices.Backward(removeByteRanges) {
content = slices.Delete(content, br.from, br.to)
}
return content, nil
}
func insertLine(content []byte, insertLine string, ln int32) []byte {
lines := bytes.Split(content, []byte{'\n'})
result := slices.Clone(lines[:ln])
result = append(result, []byte(insertLine))
result = append(result, lines[ln:]...)
return bytes.Join(result, []byte{'\n'})
}
func messageUsesFeature(mopt *protoparse.MessageOpt) bool {
if mopt.IsExplicit {
return true // editions feature set at message level
}
for _, mo := range mopt.Children {
if messageUsesFeature(mo) {
return true
}
}
return false
}
func usesFeature(fopt *protoparse.FileOpt) bool {
if fopt.Syntax != "editions" {
return false // only files on editions can use editions features
}
if fopt.IsExplicit {
return true // editions feature set at file level
}
for _, mo := range fopt.MessageOpts {
if messageUsesFeature(mo) {
return true
}
}
return false
}
func cleanupGoFeaturesImport(path string, content []byte) ([]byte, error) {
const featuresProto = "google/protobuf/go_features.proto"
parser := protoparse.NewParserWithAccessor(func(string) (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(content)), nil
})
fopt, err := parser.ParseFile(path, false)
if err != nil {
return nil, fmt.Errorf("error reading file %s: %w", path, err)
}
desc := fopt.Desc
goFeaturesImported := slices.ContainsFunc(
desc.GetDependency(),
func(path string) bool {
return path == featuresProto
})
usesFeature := usesFeature(fopt)
if usesFeature && !goFeaturesImported {
ln, err := featuresImportLineNumber(desc.GetSourceCodeInfo())
if err != nil {
return nil, err
}
importLine := fmt.Sprintf("import %q;", featuresProto)
return insertLine(content, importLine, ln), nil
}
if !usesFeature && goFeaturesImported {
importNum := 0
for _, loc := range desc.GetSourceCodeInfo().GetLocation() {
path := loc.GetPath()
span := loc.GetSpan()
if len(path) < 1 {
continue
}
if path[0] != importFieldNum {
continue
}
// Found an import. Is it the go_features.proto import?
if max := len(desc.GetDependency()); importNum >= max {
return nil, fmt.Errorf("BUG: too many imports in SourceCodeInfo: got index %d, want at most %d", importNum, max)
}
imported := desc.GetDependency()[importNum]
importNum++
if imported != featuresProto {
continue
}
ln := spanEndLine(span)
lines := bytes.Split(content, []byte{'\n'})
result := slices.Clone(lines[:ln])
// lines[ln] deleted by omitting it.
result = append(result, lines[ln+1:]...)
return bytes.Join(result, []byte{'\n'}), nil
}
return nil, fmt.Errorf("BUG: could not locate import line in SourceCodeInfo")
}
return content, nil
}
func spanEndLine(span []int32) int32 {
// https://github.com/protocolbuffers/protobuf/blob/v29.1/src/google/protobuf/descriptor.proto#L1209
if len(span) == 4 {
return span[2]
}
return span[0]
}
const (
// https://github.com/protocolbuffers/protobuf/blob/v29.1/src/google/protobuf/descriptor.proto#L105-L137
syntaxFieldNum = 12
editionFieldNum = 14
editionDeprFieldNum = 13
packageFieldNum = 2
importFieldNum = 3
optionFieldNum = 8
)
func fieldNumToLine(info *descpb.SourceCodeInfo) map[int32]int32 {
// pos contains a mapping of proto field numbers to line numbers. For syntax
// and package constructs, the line numbers represent where they are declared.
// For import and options, the line number represents the last import/option
// if it exists.
pos := map[int32]int32{}
for _, loc := range info.GetLocation() {
path := loc.GetPath()
span := loc.GetSpan()
if len(path) < 1 {
continue
}
switch fieldNum := path[0]; fieldNum {
case syntaxFieldNum, editionFieldNum, editionDeprFieldNum, packageFieldNum, importFieldNum, optionFieldNum:
endLine := spanEndLine(span)
// If not set, map returns 0.
if endLine > pos[fieldNum] {
pos[fieldNum] = endLine
}
}
}
return pos
}
// featuresImportLineNumber returns the line number to insert the file-level
// go_features.proto import. It uses the following heuristics to determine the
// line number:
// If there are any import lines, return the line number after the last one.
// If there is a package statement line, return the line number after it.
// If there is a syntax statement line, return the line number after it.
// Note that this func assumes that an earlier protoparser.Parser.ParseFile
// invocation will return error already if there is no syntax statement.
func featuresImportLineNumber(info *descpb.SourceCodeInfo) (int32, error) {
pos := fieldNumToLine(info)
if lnum, ok := pos[importFieldNum]; ok {
return lnum + 1, nil
}
if lnum, ok := pos[packageFieldNum]; ok {
return lnum + 1, nil
}
if lnum, ok := pos[syntaxFieldNum]; ok {
return lnum + 1, nil
}
if lnum, ok := pos[editionFieldNum]; ok {
return lnum + 1, nil
}
if lnum, ok := pos[editionDeprFieldNum]; ok {
return lnum + 1, nil
}
return 0, fmt.Errorf("cannot determine line number for go_features.proto import")
}
// fileOptionLineNumber returns the line number to insert the file-level
// go_api_flag option. It uses the following heuristics to determine the
// line number:
// If there are any file option settings, return the line number after the last
// one.
// If there are any import lines, return the line number after the last one.
// If there is a package statement line, return the line number after it.
// If there is a syntax statement line, return the line number after it.
// Note that this func assumes that an earlier protoparser.Parser.ParseFile
// invocation will return error already if there is no syntax statement.
func fileOptionLineNumber(info *descpb.SourceCodeInfo) (int32, error) {
pos := fieldNumToLine(info)
if lnum, ok := pos[optionFieldNum]; ok {
return lnum + 1, nil
}
if lnum, ok := pos[importFieldNum]; ok {
return lnum + 1, nil
}
if lnum, ok := pos[packageFieldNum]; ok {
return lnum + 1, nil
}
if lnum, ok := pos[syntaxFieldNum]; ok {
return lnum + 1, nil
}
if lnum, ok := pos[editionFieldNum]; ok {
return lnum + 1, nil
}
if lnum, ok := pos[editionDeprFieldNum]; ok {
return lnum + 1, nil
}
return 0, fmt.Errorf("cannot determine line number for file-level API flag")
}
func msgOptionInsertionByteIdx(content []byte, info *descpb.SourceCodeInfo, msgPath []int32) (int, error) {
const (
// https://github.com/protocolbuffers/protobuf/blob/v29.1/src/google/protobuf/descriptor.proto#L105
nameFieldNum = 1
)
namePath := append(slices.Clone(msgPath), nameFieldNum)
nameEndIdx := math.MaxInt
for _, loc := range info.GetLocation() {
path := loc.GetPath()
if len(path) <= len(msgPath) {
// Paths in the message (name, fields, option, ...) have to be longer than
// the message path.
continue
}
if slices.Equal(path, namePath) {
var err error
_, nameEndIdx, err = protoparse.SpanToTextRange(loc.GetSpan()).ToByteRange(content)
if err != nil {
return -1, fmt.Errorf("TextRange.ToByteRange: %v", err)
}
break
}
}
if nameEndIdx == math.MaxInt {
return -1, fmt.Errorf("BUG: cannot find message name")
}
// Look for the next opening curly after the name. If this happens to be in a
// comment, this function will produce wrong results.
offset := bytes.IndexByte(content[nameEndIdx:], '{')
if offset == -1 {
return -1, fmt.Errorf(`cannot find "{"`)
}
return nameEndIdx + offset + 1, nil
}
// FormatFile runs formatter on input and returns the formatted result.
func FormatFile(ctx context.Context, input []byte, formatter string) ([]byte, error) {
cmd := exec.CommandContext(ctx, formatter)
cmd.Stdin = bytes.NewReader(input)
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}
cmd.Stdout = stdout
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return nil, fmt.Errorf("formatter error: %v", err)
}
if stderr.Len() > 0 {
return nil, fmt.Errorf("formatter stderr: %s", stderr.String())
}
return stdout.Bytes(), nil
}