diff options
Diffstat (limited to 'p2p/messenger.go')
-rw-r--r-- | p2p/messenger.go | 353 |
1 files changed, 177 insertions, 176 deletions
diff --git a/p2p/messenger.go b/p2p/messenger.go index d42ba1720..7375ecc07 100644 --- a/p2p/messenger.go +++ b/p2p/messenger.go @@ -1,220 +1,221 @@ package p2p import ( + "bufio" + "bytes" "fmt" + "io" + "io/ioutil" + "net" "sync" "time" ) -const ( - handlerTimeout = 1000 -) +type Handlers map[string]func() Protocol -type Handlers map[string](func(p *Peer) Protocol) - -type Messenger struct { - conn *Connection - peer *Peer - handlers Handlers - protocolLock sync.RWMutex - protocols []Protocol - offsets []MsgCode // offsets for adaptive message idss - protocolTable map[string]int - quit chan chan bool - err chan *PeerError - pulse chan bool -} - -func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger { - baseProtocol := NewBaseProtocol(peer) - return &Messenger{ - conn: conn, - peer: peer, - offsets: []MsgCode{baseProtocol.Offset()}, - handlers: handlers, - protocols: []Protocol{baseProtocol}, - protocolTable: make(map[string]int), - err: errchan, - pulse: make(chan bool, 1), - quit: make(chan chan bool, 1), - } +type proto struct { + in chan Msg + maxcode, offset MsgCode + messenger *messenger } -func (self *Messenger) Start() { - self.conn.Open() - go self.messenger() - self.protocolLock.RLock() - defer self.protocolLock.RUnlock() - self.protocols[0].Start() +func (rw *proto) WriteMsg(msg Msg) error { + if msg.Code >= rw.maxcode { + return NewPeerError(InvalidMsgCode, "not handled") + } + return rw.messenger.writeMsg(msg) } -func (self *Messenger) Stop() { - // close pulse to stop ping pong monitoring - close(self.pulse) - self.protocolLock.RLock() - defer self.protocolLock.RUnlock() - for _, protocol := range self.protocols { - protocol.Stop() // could be parallel +func (rw *proto) ReadMsg() (Msg, error) { + msg, ok := <-rw.in + if !ok { + return msg, io.EOF } - q := make(chan bool) - self.quit <- q - <-q - self.conn.Close() + return msg, nil } -func (self *Messenger) messenger() { - in := self.conn.Read() - for { - select { - case payload, ok := <-in: - //dispatches message to the protocol asynchronously - if ok { - go self.handle(payload) - } else { - return - } - case q := <-self.quit: - q <- true - return - } - } +// eofSignal is used to 'lend' the network connection +// to a protocol. when the protocol's read loop has read the +// whole payload, the done channel is closed. +type eofSignal struct { + wrapped io.Reader + eof chan struct{} } -// handles each message by dispatching to the appropriate protocol -// using adaptive message codes -// this function is started as a separate go routine for each message -// it waits for the protocol response -// then encodes and sends outgoing messages to the connection's write channel -func (self *Messenger) handle(payload []byte) { - // send ping to heartbeat channel signalling time of last message - // select { - // case self.pulse <- true: - // default: - // } - self.pulse <- true - // initialise message from payload - msg, err := NewMsgFromBytes(payload) +func (r *eofSignal) Read(buf []byte) (int, error) { + n, err := r.wrapped.Read(buf) if err != nil { - self.err <- NewPeerError(MiscError, " %v", err) - return + close(r.eof) // tell messenger that msg has been consumed } - // retrieves protocol based on message Code - protocol, offset, peerErr := self.getProtocol(msg.Code()) - if err != nil { - self.err <- peerErr - return + return n, err +} + +// messenger represents a message-oriented peer connection. +// It keeps track of the set of protocols understood +// by the remote peer. +type messenger struct { + peer *Peer + handlers Handlers + + // the mutex protects the connection + // so only one protocol can write at a time. + writeMu sync.Mutex + conn net.Conn + bufconn *bufio.ReadWriter + + protocolLock sync.RWMutex + protocols map[string]*proto + offsets map[MsgCode]*proto + protoWG sync.WaitGroup + + err chan error + pulse chan bool +} + +func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger { + return &messenger{ + conn: conn, + bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), + peer: peer, + handlers: handlers, + protocols: make(map[string]*proto), + err: errchan, + pulse: make(chan bool, 1), } - // reset message code based on adaptive offset - msg.Decode(offset) - // dispatches - response := make(chan *Msg) - go protocol.HandleIn(msg, response) - // protocol reponse timeout to prevent leaks - timer := time.After(handlerTimeout * time.Millisecond) +} + +func (m *messenger) Start() { + m.protocols[""] = m.startProto(0, "", &baseProtocol{}) + go m.readLoop() +} + +func (m *messenger) Stop() { + m.conn.Close() + m.protoWG.Wait() +} + +const ( + // maximum amount of time allowed for reading a message + msgReadTimeout = 5 * time.Second + + // messages smaller than this many bytes will be read at + // once before passing them to a protocol. + wholePayloadSize = 64 * 1024 +) + +func (m *messenger) readLoop() { + defer m.closeProtocols() for { - select { - case outgoing, ok := <-response: - // we check if response channel is not closed - if ok { - self.conn.Write() <- outgoing.Encode(offset) - } else { + m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) + msg, err := readMsg(m.bufconn) + if err != nil { + m.err <- err + return + } + // send ping to heartbeat channel signalling time of last message + m.pulse <- true + proto, err := m.getProto(msg.Code) + if err != nil { + m.err <- err + return + } + msg.Code -= proto.offset + if msg.Size <= wholePayloadSize { + // optimization: msg is small enough, read all + // of it and move on to the next message + buf, err := ioutil.ReadAll(msg.Payload) + if err != nil { + m.err <- err return } - case <-timer: - return + msg.Payload = bytes.NewReader(buf) + proto.in <- msg + } else { + pr := &eofSignal{msg.Payload, make(chan struct{})} + msg.Payload = pr + proto.in <- msg + <-pr.eof } } } -// negotiated protocols -// stores offsets needed for adaptive message id scheme - -// based on offsets set at handshake -// get the right protocol to handle the message -func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) { - self.protocolLock.RLock() - defer self.protocolLock.RUnlock() - base := MsgCode(0) - for index, offset := range self.offsets { - if code < offset { - return self.protocols[index], base, nil - } - base = offset +func (m *messenger) closeProtocols() { + m.protocolLock.RLock() + for _, p := range m.protocols { + close(p.in) } - return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code) + m.protocolLock.RUnlock() } -func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) { - fmt.Printf("pingpong keepalive started at %v", time.Now()) +func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto { + proto := &proto{ + in: make(chan Msg), + offset: offset, + maxcode: impl.Offset(), + messenger: m, + } + m.protoWG.Add(1) + go func() { + if err := impl.Start(m.peer, proto); err != nil && err != io.EOF { + logger.Errorf("protocol %q error: %v\n", name, err) + m.err <- err + } + m.protoWG.Done() + }() + return proto +} - timer := time.After(timeout) - pinged := false - for { - select { - case _, ok := <-self.pulse: - if ok { - pinged = false - timer = time.After(timeout) - } else { - // pulse is closed, stop monitoring - return - } - case <-timer: - if pinged { - fmt.Printf("timeout at %v", time.Now()) - timeoutCallback() - return - } else { - fmt.Printf("pinged at %v", time.Now()) - pingCallback() - timer = time.After(gracePeriod) - pinged = true - } +// getProto finds the protocol responsible for handling +// the given message code. +func (m *messenger) getProto(code MsgCode) (*proto, error) { + m.protocolLock.RLock() + defer m.protocolLock.RUnlock() + for _, proto := range m.protocols { + if code >= proto.offset && code < proto.offset+proto.maxcode { + return proto, nil } } + return nil, NewPeerError(InvalidMsgCode, "%d", code) } -func (self *Messenger) AddProtocols(protocols []string) { - self.protocolLock.Lock() - defer self.protocolLock.Unlock() - i := len(self.offsets) - offset := self.offsets[i-1] +// setProtocols starts all subprotocols shared with the +// remote peer. the protocols must be sorted alphabetically. +func (m *messenger) setRemoteProtocols(protocols []string) { + m.protocolLock.Lock() + defer m.protocolLock.Unlock() + offset := baseProtocolOffset for _, name := range protocols { - protocolFunc, ok := self.handlers[name] - if ok { - protocol := protocolFunc(self.peer) - self.protocolTable[name] = i - i++ - offset += protocol.Offset() - fmt.Println("offset ", name, offset) - - self.offsets = append(self.offsets, offset) - self.protocols = append(self.protocols, protocol) - protocol.Start() - } else { - fmt.Println("no ", name) - // protocol not handled + protocolFunc, ok := m.handlers[name] + if !ok { + continue // not handled } + inst := protocolFunc() + m.protocols[name] = m.startProto(offset, name, inst) + offset += inst.Offset() } } -func (self *Messenger) Write(protocol string, msg *Msg) error { - self.protocolLock.RLock() - defer self.protocolLock.RUnlock() - i := 0 - offset := MsgCode(0) - if len(protocol) > 0 { - var ok bool - i, ok = self.protocolTable[protocol] - if !ok { - return fmt.Errorf("protocol %v not handled by peer", protocol) - } - offset = self.offsets[i-1] +// writeProtoMsg sends the given message on behalf of the given named protocol. +func (m *messenger) writeProtoMsg(protoName string, msg Msg) error { + m.protocolLock.RLock() + proto, ok := m.protocols[protoName] + m.protocolLock.RUnlock() + if !ok { + return fmt.Errorf("protocol %s not handled by peer", protoName) } - handler := self.protocols[i] - // checking if protocol status/caps allows the message to be sent out - if handler.HandleOut(msg) { - self.conn.Write() <- msg.Encode(offset) + if msg.Code >= proto.maxcode { + return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) + } + msg.Code += proto.offset + return m.writeMsg(msg) +} + +// writeMsg writes a message to the connection. +func (m *messenger) writeMsg(msg Msg) error { + m.writeMu.Lock() + defer m.writeMu.Unlock() + if err := writeMsg(m.bufconn, msg); err != nil { + return err } - return nil + return m.bufconn.Flush() } |