| // Copyright 2012 Google Inc. All rights reserved. |
| // Use of this source code is governed by the Apache 2.0 |
| // license that can be found in the LICENSE file. |
| |
| // +build appengine |
| |
| package socket |
| |
| import ( |
| "fmt" |
| "io" |
| "net" |
| "strconv" |
| "time" |
| |
| "github.com/golang/protobuf/proto" |
| "golang.org/x/net/context" |
| "google.golang.org/appengine/internal" |
| |
| pb "google.golang.org/appengine/internal/socket" |
| ) |
| |
| // Dial connects to the address addr on the network protocol. |
| // The address format is host:port, where host may be a hostname or an IP address. |
| // Known protocols are "tcp" and "udp". |
| // The returned connection satisfies net.Conn, and is valid while ctx is valid; |
| // if the connection is to be used after ctx becomes invalid, invoke SetContext |
| // with the new context. |
| func Dial(ctx context.Context, protocol, addr string) (*Conn, error) { |
| return DialTimeout(ctx, protocol, addr, 0) |
| } |
| |
| var ipFamilies = []pb.CreateSocketRequest_SocketFamily{ |
| pb.CreateSocketRequest_IPv4, |
| pb.CreateSocketRequest_IPv6, |
| } |
| |
| // DialTimeout is like Dial but takes a timeout. |
| // The timeout includes name resolution, if required. |
| func DialTimeout(ctx context.Context, protocol, addr string, timeout time.Duration) (*Conn, error) { |
| dialCtx := ctx // Used for dialing and name resolution, but not stored in the *Conn. |
| if timeout > 0 { |
| var cancel context.CancelFunc |
| dialCtx, cancel = context.WithTimeout(ctx, timeout) |
| defer cancel() |
| } |
| |
| host, portStr, err := net.SplitHostPort(addr) |
| if err != nil { |
| return nil, err |
| } |
| port, err := strconv.Atoi(portStr) |
| if err != nil { |
| return nil, fmt.Errorf("socket: bad port %q: %v", portStr, err) |
| } |
| |
| var prot pb.CreateSocketRequest_SocketProtocol |
| switch protocol { |
| case "tcp": |
| prot = pb.CreateSocketRequest_TCP |
| case "udp": |
| prot = pb.CreateSocketRequest_UDP |
| default: |
| return nil, fmt.Errorf("socket: unknown protocol %q", protocol) |
| } |
| |
| packedAddrs, resolved, err := resolve(dialCtx, ipFamilies, host) |
| if err != nil { |
| return nil, fmt.Errorf("socket: failed resolving %q: %v", host, err) |
| } |
| if len(packedAddrs) == 0 { |
| return nil, fmt.Errorf("no addresses for %q", host) |
| } |
| |
| packedAddr := packedAddrs[0] // use first address |
| fam := pb.CreateSocketRequest_IPv4 |
| if len(packedAddr) == net.IPv6len { |
| fam = pb.CreateSocketRequest_IPv6 |
| } |
| |
| req := &pb.CreateSocketRequest{ |
| Family: fam.Enum(), |
| Protocol: prot.Enum(), |
| RemoteIp: &pb.AddressPort{ |
| Port: proto.Int32(int32(port)), |
| PackedAddress: packedAddr, |
| }, |
| } |
| if resolved { |
| req.RemoteIp.HostnameHint = &host |
| } |
| res := &pb.CreateSocketReply{} |
| if err := internal.Call(dialCtx, "remote_socket", "CreateSocket", req, res); err != nil { |
| return nil, err |
| } |
| |
| return &Conn{ |
| ctx: ctx, |
| desc: res.GetSocketDescriptor(), |
| prot: prot, |
| local: res.ProxyExternalIp, |
| remote: req.RemoteIp, |
| }, nil |
| } |
| |
| // LookupIP returns the given host's IP addresses. |
| func LookupIP(ctx context.Context, host string) (addrs []net.IP, err error) { |
| packedAddrs, _, err := resolve(ctx, ipFamilies, host) |
| if err != nil { |
| return nil, fmt.Errorf("socket: failed resolving %q: %v", host, err) |
| } |
| addrs = make([]net.IP, len(packedAddrs)) |
| for i, pa := range packedAddrs { |
| addrs[i] = net.IP(pa) |
| } |
| return addrs, nil |
| } |
| |
| func resolve(ctx context.Context, fams []pb.CreateSocketRequest_SocketFamily, host string) ([][]byte, bool, error) { |
| // Check if it's an IP address. |
| if ip := net.ParseIP(host); ip != nil { |
| if ip := ip.To4(); ip != nil { |
| return [][]byte{ip}, false, nil |
| } |
| return [][]byte{ip}, false, nil |
| } |
| |
| req := &pb.ResolveRequest{ |
| Name: &host, |
| AddressFamilies: fams, |
| } |
| res := &pb.ResolveReply{} |
| if err := internal.Call(ctx, "remote_socket", "Resolve", req, res); err != nil { |
| // XXX: need to map to pb.ResolveReply_ErrorCode? |
| return nil, false, err |
| } |
| return res.PackedAddress, true, nil |
| } |
| |
| // withDeadline is like context.WithDeadline, except it ignores the zero deadline. |
// withDeadline is like context.WithDeadline, except it ignores the zero deadline.
// For a zero deadline it returns parent itself together with a no-op cancel
// function, so callers can always defer/call cancel unconditionally.
func withDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) {
	if !deadline.IsZero() {
		return context.WithDeadline(parent, deadline)
	}
	noop := func() {}
	return parent, noop
}
| |
| // Conn represents a socket connection. |
| // It implements net.Conn. |
| type Conn struct { |
| ctx context.Context |
| desc string |
| offset int64 |
| |
| prot pb.CreateSocketRequest_SocketProtocol |
| local, remote *pb.AddressPort |
| |
| readDeadline, writeDeadline time.Time // optional |
| } |
| |
| // SetContext sets the context that is used by this Conn. |
| // It is usually used only when using a Conn that was created in a different context, |
| // such as when a connection is created during a warmup request but used while |
| // servicing a user request. |
// The new ctx is the one subsequent Read, Write, Close and KeepAlive calls
// pass to their RPCs.
// NOTE(review): cn.ctx is written without synchronization; presumably callers
// must not call SetContext concurrently with other methods — confirm.
func (cn *Conn) SetContext(ctx context.Context) {
	cn.ctx = ctx
}
| |
| func (cn *Conn) Read(b []byte) (n int, err error) { |
| const maxRead = 1 << 20 |
| if len(b) > maxRead { |
| b = b[:maxRead] |
| } |
| |
| req := &pb.ReceiveRequest{ |
| SocketDescriptor: &cn.desc, |
| DataSize: proto.Int32(int32(len(b))), |
| } |
| res := &pb.ReceiveReply{} |
| if !cn.readDeadline.IsZero() { |
| req.TimeoutSeconds = proto.Float64(cn.readDeadline.Sub(time.Now()).Seconds()) |
| } |
| ctx, cancel := withDeadline(cn.ctx, cn.readDeadline) |
| defer cancel() |
| if err := internal.Call(ctx, "remote_socket", "Receive", req, res); err != nil { |
| return 0, err |
| } |
| if len(res.Data) == 0 { |
| return 0, io.EOF |
| } |
| if len(res.Data) > len(b) { |
| return 0, fmt.Errorf("socket: internal error: read too much data: %d > %d", len(res.Data), len(b)) |
| } |
| return copy(b, res.Data), nil |
| } |
| |
| func (cn *Conn) Write(b []byte) (n int, err error) { |
| const lim = 1 << 20 // max per chunk |
| |
| for n < len(b) { |
| chunk := b[n:] |
| if len(chunk) > lim { |
| chunk = chunk[:lim] |
| } |
| |
| req := &pb.SendRequest{ |
| SocketDescriptor: &cn.desc, |
| Data: chunk, |
| StreamOffset: &cn.offset, |
| } |
| res := &pb.SendReply{} |
| if !cn.writeDeadline.IsZero() { |
| req.TimeoutSeconds = proto.Float64(cn.writeDeadline.Sub(time.Now()).Seconds()) |
| } |
| ctx, cancel := withDeadline(cn.ctx, cn.writeDeadline) |
| defer cancel() |
| if err = internal.Call(ctx, "remote_socket", "Send", req, res); err != nil { |
| // assume zero bytes were sent in this RPC |
| break |
| } |
| n += int(res.GetDataSent()) |
| cn.offset += int64(res.GetDataSent()) |
| } |
| |
| return |
| } |
| |
| func (cn *Conn) Close() error { |
| req := &pb.CloseRequest{ |
| SocketDescriptor: &cn.desc, |
| } |
| res := &pb.CloseReply{} |
| if err := internal.Call(cn.ctx, "remote_socket", "Close", req, res); err != nil { |
| return err |
| } |
| cn.desc = "CLOSED" |
| return nil |
| } |
| |
| func addr(prot pb.CreateSocketRequest_SocketProtocol, ap *pb.AddressPort) net.Addr { |
| if ap == nil { |
| return nil |
| } |
| switch prot { |
| case pb.CreateSocketRequest_TCP: |
| return &net.TCPAddr{ |
| IP: net.IP(ap.PackedAddress), |
| Port: int(*ap.Port), |
| } |
| case pb.CreateSocketRequest_UDP: |
| return &net.UDPAddr{ |
| IP: net.IP(ap.PackedAddress), |
| Port: int(*ap.Port), |
| } |
| } |
| panic("unknown protocol " + prot.String()) |
| } |
| |
| func (cn *Conn) LocalAddr() net.Addr { return addr(cn.prot, cn.local) } |
| func (cn *Conn) RemoteAddr() net.Addr { return addr(cn.prot, cn.remote) } |
| |
| func (cn *Conn) SetDeadline(t time.Time) error { |
| cn.readDeadline = t |
| cn.writeDeadline = t |
| return nil |
| } |
| |
| func (cn *Conn) SetReadDeadline(t time.Time) error { |
| cn.readDeadline = t |
| return nil |
| } |
| |
| func (cn *Conn) SetWriteDeadline(t time.Time) error { |
| cn.writeDeadline = t |
| return nil |
| } |
| |
| // KeepAlive signals that the connection is still in use. |
| // It may be called to prevent the socket being closed due to inactivity. |
| func (cn *Conn) KeepAlive() error { |
| req := &pb.GetSocketNameRequest{ |
| SocketDescriptor: &cn.desc, |
| } |
| res := &pb.GetSocketNameReply{} |
| return internal.Call(cn.ctx, "remote_socket", "GetSocketName", req, res) |
| } |
| |
// init registers the remote_socket service's error-code name table with the
// internal package, presumably so that errors from internal.Call for this
// service can be reported by name — TODO confirm against internal.
func init() {
	internal.RegisterErrorCodeMap("remote_socket", pb.RemoteSocketServiceError_ErrorCode_name)
}