// blob: cf8d16a978eb45c8d8f74419029a6e74bcdbc997 [file] [log] [blame]
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"go/ast"
"go/token"
)
// httpserverFix is the fix entry for the http server rewrite. It bundles the
// identifier "httpserver", the httpserver function that performs the rewrite,
// and a description listing the code reviews that changed the
// http.ResponseWriter interface.
var httpserverFix = fix{
"httpserver",
httpserver,
`Adapt http server methods and functions to changes
made to the http ResponseWriter interface.
http://codereview.appspot.com/4245064 Hijacker
http://codereview.appspot.com/4239076 Header
http://codereview.appspot.com/4239077 Flusher
http://codereview.appspot.com/4248075 RemoteAddr, UsingTLS
`,
}
// httpserver rewrites HTTP handler code in f to match the revised
// http.ResponseWriter interface and reports whether it changed anything.
//
// Files that do not import "http" are left alone. Otherwise every function
// declaration accepted by isServeHTTP is examined, and method calls made on
// its ResponseWriter parameter w are rewritten (req is the function's
// *http.Request parameter):
//
//	w.Hijack()         -> w.(http.Hijacker).Hijack()
//	w.Flush()          -> w.(http.Flusher).Flush()
//	w.UsingTLS()       -> req.TLS != nil
//	w.RemoteAddr()     -> req.RemoteAddr
//	w.SetHeader(k, v)  -> w.Header().Set(k, v)
//	w.SetHeader(k, "") -> w.Header().Del(k)
//
// Only calls with zero or two arguments are considered. The Del form is
// chosen when isEmptyString accepts the second argument.
func httpserver(f *ast.File) bool {
	if !imports(f, "http") {
		return false
	}

	fixed := false
	for _, decl := range f.Decls {
		fn, ok := decl.(*ast.FuncDecl)
		if !ok {
			continue
		}
		if fn.Body == nil {
			// Declaration without a body (implemented outside Go):
			// there is nothing to rewrite, and a nil block must not
			// be handed to walk.
			continue
		}
		w, req, ok := isServeHTTP(fn)
		if !ok {
			continue
		}

		walk(fn.Body, func(n interface{}) {
			// Some rewrites replace the whole call expression, which
			// is only possible when the walker hands us a pointer to
			// the expression slot. Remember the pointer and continue
			// with the expression it holds.
			ptr, ok := n.(*ast.Expr)
			if ok {
				n = *ptr
			}

			// Look for method calls on w: Hijack, Flush, UsingTLS,
			// RemoteAddr and SetHeader.
			call, ok := n.(*ast.CallExpr)
			if !ok {
				return
			}
			if nargs := len(call.Args); nargs != 0 && nargs != 2 {
				return
			}
			sel, ok := call.Fun.(*ast.SelectorExpr)
			if !ok || !refersTo(sel.X, w) {
				return
			}

			switch sel.Sel.String() {
			case "Hijack":
				// Replace w with w.(http.Hijacker).
				sel.X = &ast.TypeAssertExpr{
					X:    sel.X,
					Type: ast.NewIdent("http.Hijacker"),
				}
				fixed = true

			case "Flush":
				// Replace w with w.(http.Flusher).
				sel.X = &ast.TypeAssertExpr{
					X:    sel.X,
					Type: ast.NewIdent("http.Flusher"),
				}
				fixed = true

			case "UsingTLS":
				// Replace the call with req.TLS != nil. This needs
				// the pointer to the expression slot; without it the
				// call is left untouched.
				if ptr != nil {
					*ptr = &ast.BinaryExpr{
						X: &ast.SelectorExpr{
							X:   ast.NewIdent(req.String()),
							Sel: ast.NewIdent("TLS"),
						},
						Op: token.NEQ,
						Y:  ast.NewIdent("nil"),
					}
					fixed = true
				}

			case "RemoteAddr":
				// Replace the call with req.RemoteAddr; as above,
				// only possible with a pointer to the expression.
				if ptr != nil {
					*ptr = &ast.SelectorExpr{
						X:   ast.NewIdent(req.String()),
						Sel: ast.NewIdent("RemoteAddr"),
					}
					fixed = true
				}

			case "SetHeader":
				// Replace w.SetHeader with w.Header().Set, or with
				// w.Header().Del (dropping the value argument) when
				// the second argument is "".
				sel.X = &ast.CallExpr{
					Fun: &ast.SelectorExpr{
						X:   ast.NewIdent(w.String()),
						Sel: ast.NewIdent("Header"),
					},
				}
				sel.Sel = ast.NewIdent("Set")
				if len(call.Args) == 2 && isEmptyString(call.Args[1]) {
					sel.Sel = ast.NewIdent("Del")
					call.Args = call.Args[:1]
				}
				fixed = true
			}
		})
	}
	return fixed
}
// isServeHTTP reports whether fn takes both an http.ResponseWriter and a
// *http.Request parameter. When ok is true, w and req are the identifiers
// naming those parameters: the first name of the matching field, and the
// last matching field if several match.
//
// Parameters declared without names, as in
// func(http.ResponseWriter, *http.Request), are skipped. The function body
// has no identifier with which to refer to them, so ok is false.
func isServeHTTP(fn *ast.FuncDecl) (w, req *ast.Ident, ok bool) {
	for _, field := range fn.Type.Params.List {
		if len(field.Names) == 0 {
			// Unnamed parameter: Names is empty, so indexing
			// Names[0] would panic.
			continue
		}
		switch {
		case isPkgDot(field.Type, "http", "ResponseWriter"):
			w = field.Names[0]
		case isPtrPkgDot(field.Type, "http", "Request"):
			req = field.Names[0]
		}
	}
	return w, req, w != nil && req != nil
}