diff options
Diffstat (limited to 'p2p/handshake.go')
-rw-r--r-- | p2p/handshake.go | 27 |
1 files changed, 10 insertions, 17 deletions
diff --git a/p2p/handshake.go b/p2p/handshake.go index 17f572dea..10ef970dc 100644 --- a/p2p/handshake.go +++ b/p2p/handshake.go @@ -32,14 +32,10 @@ const ( ) type conn struct { - *frameRW + MsgReadWriter *protoHandshake } -func newConn(fd net.Conn, hs *protoHandshake) *conn { - return &conn{newFrameRW(fd, msgWriteTimeout), hs} -} - // encHandshake contains the state of the encryption handshake. type encHandshake struct { remoteID discover.NodeID @@ -115,17 +111,16 @@ func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) ( // Run the protocol handshake using authenticated messages. // TODO: move buffering setup here (out of newFrameRW) - phsrw := newRlpxFrameRW(fd, secrets) - rhs, err := readProtocolHandshake(phsrw, our) + rw := newRlpxFrameRW(fd, secrets) + rhs, err := readProtocolHandshake(rw, our) if err != nil { return nil, err } - if err := writeProtocolHandshake(phsrw, our); err != nil { + // TODO: validate that handshake node ID matches + if err := writeProtocolHandshake(rw, our); err != nil { return nil, fmt.Errorf("protocol write error: %v", err) } - - rw := newFrameRW(fd, msgWriteTimeout) - return &conn{rw, rhs}, nil + return &conn{&lockedRW{wrapped: rw}, rhs}, nil } func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) { @@ -136,20 +131,18 @@ func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, // Run the protocol handshake using authenticated messages. // TODO: move buffering setup here (out of newFrameRW) - phsrw := newRlpxFrameRW(fd, secrets) - if err := writeProtocolHandshake(phsrw, our); err != nil { + rw := newRlpxFrameRW(fd, secrets) + if err := writeProtocolHandshake(rw, our); err != nil { return nil, fmt.Errorf("protocol write error: %v", err) } - rhs, err := readProtocolHandshake(phsrw, our) + rhs, err := readProtocolHandshake(rw, our) if err != nil { return nil, fmt.Errorf("protocol handshake read error: %v", err) } if rhs.ID != dial.ID { return nil, errors.New("dialed node id mismatch") } - - rw := newFrameRW(fd, msgWriteTimeout) - return &conn{rw, rhs}, nil + return &conn{&lockedRW{wrapped: rw}, rhs}, nil } // outboundEncHandshake negotiates a session token on conn. |