// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssh

import (
	"errors"
	"io"
	"sync"
)

// A Channel is an ordered, reliable, duplex stream that is multiplexed over an
// SSH connection. The peer's request to open it is answered with Accept or
// Reject.
type Channel interface {
	// Accept accepts the channel creation request.
	Accept() error
	// Reject rejects the channel creation request. After calling this, no
	// other methods on the Channel may be called; if they are, the peer is
	// likely to signal a protocol error and drop the connection.
	Reject(reason RejectionReason, message string) error

	// ChannelType returns the type of the channel, as supplied by the
	// client.
	ChannelType() string
	// ExtraData returns the arbitrary, type-specific payload supplied by the
	// client together with the channel creation request.
	ExtraData() []byte

	// Read reads from the channel's data stream. Instead of data it may
	// return a ChannelRequest as the error value, reporting an out-of-band
	// request from the peer.
	Read(data []byte) (int, error)
	// Write sends data on the channel's data stream.
	Write(data []byte) (int, error)
	// Close closes the channel.
	Close() error

	// AckRequest replies to a channel request with either an ack (ok is
	// true) or a nack (ok is false).
	AckRequest(ok bool) error
}

// ChannelRequest is an out-of-band request that the peer sent on a channel,
// separate from the channel's byte stream. Channel.Read reports one by
// returning it as its error value.
type ChannelRequest struct {
	// Request is the request type name.
	Request string
	// WantReply reports whether the peer asked for a reply (see AckRequest).
	WantReply bool
	// Payload holds the request-specific data.
	Payload []byte
}

// Error implements the error interface, which lets Read hand a
// ChannelRequest back through its error result.
func (r ChannelRequest) Error() string {
	const msg = "channel request received"
	return msg
}

// RejectionReason is the reason code sent to the peer when a channel creation
// request is rejected. Its values are the codes defined in RFC 4254,
// section 5.1.
type RejectionReason int

// Reason codes accepted by Channel.Reject.
const (
	// Prohibited corresponds to SSH_OPEN_ADMINISTRATIVELY_PROHIBITED.
	Prohibited RejectionReason = 1
	// ConnectionFailed corresponds to SSH_OPEN_CONNECT_FAILED.
	ConnectionFailed RejectionReason = 2
	// UnknownChannelType corresponds to SSH_OPEN_UNKNOWN_CHANNEL_TYPE.
	UnknownChannelType RejectionReason = 3
	// ResourceShortage corresponds to SSH_OPEN_RESOURCE_SHORTAGE.
	ResourceShortage RejectionReason = 4
)

// channel is the server-side implementation of Channel: one SSH channel
// multiplexed over serverConn.
type channel struct {
	// immutable once created
	// chanType and extraData are returned by ChannelType and ExtraData.
	chanType  string
	extraData []byte

	// Stream state flags. theyClosed and theySentEOF are set by handlePacket
	// when the peer sends a channel close or EOF message; weClosed is set by
	// Close. Read returns io.EOF when theySentEOF, theyClosed or dead is set;
	// Write returns io.EOF when dead or weClosed is set.
	// NOTE(review): the code that sets dead is not in this section — confirm.
	theyClosed  bool
	theySentEOF bool
	weClosed    bool
	dead        bool

	// serverConn is the connection this channel is multiplexed over; its
	// writePacket sends this channel's messages.
	// myId and theirId are our and the peer's identifiers for this channel;
	// theirId goes in the PeersId field of every outgoing message.
	// myWindow is the receive window advertised to the peer (sent in Accept,
	// reduced by handleData as data arrives); theirWindow is how many bytes
	// Write may still send before it must wait.
	// maxPacketSize is advertised to the peer in Accept's open confirmation.
	// err, if non-nil, is returned by Read; it is not set in this section.
	serverConn            *ServerConn
	myId, theirId         uint32
	myWindow, theirWindow uint32
	maxPacketSize         uint32
	err                   error

	// pendingRequests queues requests received by handlePacket; Read hands
	// them out in arrival order.
	// pendingData is a circular receive buffer: length unread bytes start at
	// index head and may wrap around the end. Its size bounds how much data
	// handleData will accept.
	pendingRequests []ChannelRequest
	pendingData     []byte
	head, length    int

	// This lock is inferior to serverConn.lock
	// i.e. when both are needed, serverConn.lock is acquired first.
	// cond is waited on by Read and Write while holding lock; presumably it is
	// constructed as sync.NewCond(&lock) — TODO confirm at the construction site.
	lock sync.Mutex
	cond *sync.Cond
}

// Accept answers the peer's channel creation request with a channel open
// confirmation. The confirmation carries our channel id, receive window and
// maximum packet size. If the connection has already failed, its stored error
// is returned and nothing is sent.
func (c *channel) Accept() error {
	conn := c.serverConn
	conn.lock.Lock()
	defer conn.lock.Unlock()

	if err := conn.err; err != nil {
		return err
	}

	msg := marshal(msgChannelOpenConfirm, channelOpenConfirmMsg{
		PeersId:       c.theirId,
		MyId:          c.myId,
		MyWindow:      c.myWindow,
		MaxPacketSize: c.maxPacketSize,
	})
	return conn.writePacket(msg)
}

// Reject answers the peer's channel creation request with a channel open
// failure message carrying the given reason code and an English
// human-readable message. If the connection has already failed, its stored
// error is returned and nothing is sent.
func (c *channel) Reject(reason RejectionReason, message string) error {
	conn := c.serverConn
	conn.lock.Lock()
	defer conn.lock.Unlock()

	if err := conn.err; err != nil {
		return err
	}

	msg := marshal(msgChannelOpenFailure, channelOpenFailureMsg{
		PeersId:  c.theirId,
		Reason:   uint32(reason),
		Message:  message,
		Language: "en",
	})
	return conn.writePacket(msg)
}

// handlePacket records a control message that the peer sent on this channel:
//   - a channel request is queued for Read;
//   - a close sets theyClosed;
//   - an EOF sets theySentEOF.
//
// Any other packet type causes a panic.
//
// Every waiter on c.cond is then woken. Read waits on c.cond for requests,
// data and EOF, and Write waits on it for window space. A single Signal could
// therefore wake a goroutine whose condition is still false and leave the one
// that can make progress blocked. Broadcast lets each waiter re-check its own
// condition.
func (c *channel) handlePacket(packet interface{}) {
	c.lock.Lock()
	defer c.lock.Unlock()

	switch packet := packet.(type) {
	case *channelRequestMsg:
		c.pendingRequests = append(c.pendingRequests, ChannelRequest{
			Request:   packet.Request,
			WantReply: packet.WantReply,
			Payload:   packet.RequestSpecificData,
		})
	case *channelCloseMsg:
		c.theyClosed = true
	case *channelEOFMsg:
		c.theySentEOF = true
	default:
		panic("unknown packet type")
	}

	c.cond.Broadcast()
}

// handleData appends a data payload from the peer to the circular receive
// buffer. The buffer is pendingData; its length unread bytes start at head.
// handleData charges the payload against myWindow and wakes every goroutine
// waiting on c.cond. A payload that does not fit in the buffer's free space
// is dropped.
func (c *channel) handleData(data []byte) {
	c.lock.Lock()
	defer c.lock.Unlock()

	// The other side should never send us more than our window.
	if len(data)+c.length > len(c.pendingData) {
		// TODO(agl): we should tear down the channel with a protocol
		// error.
		return
	}

	c.myWindow -= uint32(len(data))

	// The free region may wrap past the end of the buffer, so the payload is
	// written in at most two contiguous pieces.
	for i := 0; i < 2 && len(data) > 0; i++ {
		tail := c.head + c.length
		// ">=" rather than ">": when tail lands exactly on the end of the
		// buffer, the next write position is index 0. With ">" the copy would
		// target an empty slice and the rest of the payload would be lost.
		if tail >= len(c.pendingData) {
			tail -= len(c.pendingData)
		}
		n := copy(c.pendingData[tail:], data)
		data = data[n:]
		c.length += n
	}

	// Write also waits on c.cond (for window space). Broadcast ensures the
	// reader is woken rather than only one arbitrary waiter.
	c.cond.Broadcast()
}

// Read reads data that the peer sent on the channel into data.
//
// Out-of-band channel requests take priority. If one is queued, Read returns
// (0, req), where req is the ChannelRequest. Otherwise Read returns buffered
// stream data, blocking until some arrives. Read returns io.EOF only when no
// requests and no data remain and the peer has sent EOF, has closed the
// channel, or the channel is dead. This way data received before EOF is not
// discarded. If c.err is set, it is returned immediately.
//
// After consuming data, Read sends the peer a window adjustment once at least
// half of the receive buffer is free but not yet advertised. An error from
// sending that message is returned along with the bytes already read.
func (c *channel) Read(data []byte) (n int, err error) {
	var adjust uint32
	n, adjust, err = c.readBuffered(data)
	if err != nil || adjust == 0 {
		return n, err
	}

	// readBuffered has already released c.lock. c.lock is inferior to
	// serverConn.lock, so it must not be held while serverConn.lock is taken.
	// writePacket is called under serverConn.lock, as the other callers here do.
	packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
		PeersId:         c.theirId,
		AdditionalBytes: adjust,
	})
	c.serverConn.lock.Lock()
	defer c.serverConn.lock.Unlock()
	return n, c.serverConn.writePacket(packet)
}

// readBuffered is the part of Read that runs under c.lock. It waits until a
// request, buffered data, or an end-of-stream condition is available, and
// consumes it. It returns the number of bytes of receive window that Read
// should advertise to the peer (0 for none). myWindow is updated here, under
// the lock, so concurrent readers never advertise the same space twice.
func (c *channel) readBuffered(data []byte) (n int, adjust uint32, err error) {
	c.lock.Lock()
	defer c.lock.Unlock()

	if c.err != nil {
		return 0, 0, c.err
	}

	for len(c.pendingRequests) == 0 && c.length == 0 &&
		!c.theySentEOF && !c.theyClosed && !c.dead {
		c.cond.Wait()
	}

	if len(c.pendingRequests) > 0 {
		req := c.pendingRequests[0]
		// Shift the queue down in place. Clearing the vacated slot releases
		// its payload for garbage collection.
		last := len(c.pendingRequests) - 1
		copy(c.pendingRequests, c.pendingRequests[1:])
		c.pendingRequests[last] = ChannelRequest{}
		c.pendingRequests = c.pendingRequests[:last]
		return 0, 0, req
	}

	if c.length == 0 {
		// Nothing buffered and the stream has ended.
		return 0, 0, io.EOF
	}

	// Unread bytes start at head and may wrap around the end of the buffer.
	// Copy only the contiguous run up to the end; a later Read picks up the
	// wrapped remainder.
	end := c.head + c.length
	if end > len(c.pendingData) {
		end = len(c.pendingData)
	}
	n = copy(data, c.pendingData[c.head:end])
	c.head += n
	c.length -= n
	if c.head == len(c.pendingData) {
		c.head = 0
	}

	// free is the space the peer could fill without overflowing the buffer.
	// Anything beyond myWindow has not been advertised yet. Batch adjustments
	// until at least half the buffer can be granted.
	free := uint32(len(c.pendingData) - c.length)
	if free > c.myWindow && free-c.myWindow >= uint32(len(c.pendingData))/2 {
		adjust = free - c.myWindow
		c.myWindow += adjust
	}
	return n, adjust, nil
}

// Write sends data to the peer as one or more channel data messages. It
// blocks while the peer's window (theirWindow) is exhausted, and no single
// message exceeds the window available at the time it is built.
//
// Write returns the number of bytes sent. The error is one of:
//   - io.EOF, if the channel is dead or has been closed by us;
//   - the error from writing a packet to the connection.
//
// On error, n still counts the bytes sent before the failure.
func (c *channel) Write(data []byte) (n int, err error) {
	for len(data) > 0 {
		c.lock.Lock()
		// Cond.Wait reacquires c.lock before returning, so the wait stays in
		// this inner loop instead of jumping back to the Lock above (which
		// would self-deadlock).
		for c.theirWindow == 0 && !c.dead && !c.weClosed {
			c.cond.Wait()
		}
		if c.dead || c.weClosed {
			c.lock.Unlock()
			return n, io.EOF
		}

		todo := data
		if uint32(len(todo)) > c.theirWindow {
			todo = todo[:c.theirWindow]
		}
		// Spend the window while still holding the lock, so that concurrent
		// writers cannot both use the same bytes of it.
		c.theirWindow -= uint32(len(todo))
		c.lock.Unlock()

		// Layout: message type, recipient channel id (uint32, big endian),
		// data length (uint32, big endian), data. Each value is shifted
		// before being truncated to a byte; converting to byte first would
		// discard the high-order bits and encode them as 0.
		size := uint32(len(todo))
		packet := make([]byte, 1+4+4+len(todo))
		packet[0] = msgChannelData
		packet[1] = byte(c.theirId >> 24)
		packet[2] = byte(c.theirId >> 16)
		packet[3] = byte(c.theirId >> 8)
		packet[4] = byte(c.theirId)
		packet[5] = byte(size >> 24)
		packet[6] = byte(size >> 16)
		packet[7] = byte(size >> 8)
		packet[8] = byte(size)
		copy(packet[9:], todo)

		c.serverConn.lock.Lock()
		err = c.serverConn.writePacket(packet)
		c.serverConn.lock.Unlock()
		if err != nil {
			return n, err
		}

		n += len(todo)
		data = data[len(todo):]
	}

	return n, nil
}

// Close sends a channel close message to the peer. It returns:
//   - the connection's stored error, if the connection has already failed;
//   - an error, if Close was already called.
//
// Writers blocked waiting for window space are woken, so they observe the
// close and return io.EOF instead of waiting forever.
func (c *channel) Close() error {
	c.serverConn.lock.Lock()
	defer c.serverConn.lock.Unlock()

	if c.serverConn.err != nil {
		return c.serverConn.err
	}

	// c.lock is inferior to serverConn.lock, so it is taken second. Write
	// reads weClosed under c.lock, so it must also be written under c.lock.
	c.lock.Lock()
	alreadyClosed := c.weClosed
	c.weClosed = true
	c.cond.Broadcast()
	c.lock.Unlock()

	if alreadyClosed {
		return errors.New("ssh: channel already closed")
	}

	closeMsg := channelCloseMsg{
		PeersId: c.theirId,
	}
	return c.serverConn.writePacket(marshal(msgChannelClose, closeMsg))
}

// AckRequest replies to a channel request from the peer. It sends a channel
// success message when ok is true and a channel failure message otherwise.
// If the connection has already failed, its stored error is returned and
// nothing is sent.
func (c *channel) AckRequest(ok bool) error {
	c.serverConn.lock.Lock()
	defer c.serverConn.lock.Unlock()

	if c.serverConn.err != nil {
		return c.serverConn.err
	}

	if !ok {
		nack := channelRequestFailureMsg{
			PeersId: c.theirId,
		}
		return c.serverConn.writePacket(marshal(msgChannelFailure, nack))
	}

	ack := channelRequestSuccessMsg{
		PeersId: c.theirId,
	}
	return c.serverConn.writePacket(marshal(msgChannelSuccess, ack))
}

// ChannelType reports the channel type supplied with the creation request.
func (c *channel) ChannelType() string { return c.chanType }

// ExtraData reports the type-specific payload supplied with the creation
// request. The returned slice is the channel's own; it is not copied.
func (c *channel) ExtraData() []byte { return c.extraData }
