| // Copyright 2014 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package agent |
| |
| import ( |
| "errors" |
| "io" |
| "net" |
| "sync" |
| |
| "golang.org/x/crypto/ssh" |
| ) |
| |
| // RequestAgentForwarding sets up agent forwarding for the session. |
| // ForwardToAgent or ForwardToRemote should be called to route |
| // the authentication requests. |
| func RequestAgentForwarding(session *ssh.Session) error { |
| ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil) |
| if err != nil { |
| return err |
| } |
| if !ok { |
| return errors.New("forwarding request denied") |
| } |
| return nil |
| } |
| |
| // ForwardToAgent routes authentication requests to the given keyring. |
| func ForwardToAgent(client *ssh.Client, keyring Agent) error { |
| channels := client.HandleChannelOpen(channelType) |
| if channels == nil { |
| return errors.New("agent: already have handler for " + channelType) |
| } |
| |
| go func() { |
| for ch := range channels { |
| channel, reqs, err := ch.Accept() |
| if err != nil { |
| continue |
| } |
| go ssh.DiscardRequests(reqs) |
| go func() { |
| ServeAgent(keyring, channel) |
| channel.Close() |
| }() |
| } |
| }() |
| return nil |
| } |
| |
// channelType is the SSH channel type whose incoming opens are claimed by
// ForwardToAgent and ForwardToRemote to carry agent traffic.
const (
	channelType = "auth-agent@openssh.com"
)
| |
| // ForwardToRemote routes authentication requests to the ssh-agent |
| // process serving on the given unix socket. |
| func ForwardToRemote(client *ssh.Client, addr string) error { |
| channels := client.HandleChannelOpen(channelType) |
| if channels == nil { |
| return errors.New("agent: already have handler for " + channelType) |
| } |
| conn, err := net.Dial("unix", addr) |
| if err != nil { |
| return err |
| } |
| conn.Close() |
| |
| go func() { |
| for ch := range channels { |
| channel, reqs, err := ch.Accept() |
| if err != nil { |
| continue |
| } |
| go ssh.DiscardRequests(reqs) |
| go forwardUnixSocket(channel, addr) |
| } |
| }() |
| return nil |
| } |
| |
| func forwardUnixSocket(channel ssh.Channel, addr string) { |
| conn, err := net.Dial("unix", addr) |
| if err != nil { |
| return |
| } |
| |
| var wg sync.WaitGroup |
| wg.Add(2) |
| go func() { |
| io.Copy(conn, channel) |
| conn.(*net.UnixConn).CloseWrite() |
| wg.Done() |
| }() |
| go func() { |
| io.Copy(channel, conn) |
| channel.CloseWrite() |
| wg.Done() |
| }() |
| |
| wg.Wait() |
| conn.Close() |
| channel.Close() |
| } |