diff options
Diffstat (limited to 'swarm/pss/pss.go')
-rw-r--r-- | swarm/pss/pss.go | 171 |
1 files changed, 131 insertions, 40 deletions
diff --git a/swarm/pss/pss.go b/swarm/pss/pss.go index e1e24e1f5..d0986d280 100644 --- a/swarm/pss/pss.go +++ b/swarm/pss/pss.go @@ -23,11 +23,13 @@ import ( "crypto/rand" "errors" "fmt" + "hash" "sync" "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" @@ -136,10 +138,10 @@ type Pss struct { symKeyDecryptCacheCapacity int // max amount of symkeys to keep. // message handling - handlers map[Topic]map[*Handler]bool // topic and version based pss payload handlers. See pss.Handle() - handlersMu sync.RWMutex - allowRaw bool - hashPool sync.Pool + handlers map[Topic]map[*handler]bool // topic and version based pss payload handlers. See pss.Handle() + handlersMu sync.RWMutex + hashPool sync.Pool + topicHandlerCaps map[Topic]*handlerCaps // caches capabilities of each topic's handlers (see handlerCap* consts in types.go) // process quitC chan struct{} @@ -180,11 +182,12 @@ func NewPss(k *network.Kademlia, params *PssParams) (*Pss, error) { symKeyDecryptCache: make([]*string, params.SymKeyCacheCapacity), symKeyDecryptCacheCapacity: params.SymKeyCacheCapacity, - handlers: make(map[Topic]map[*Handler]bool), - allowRaw: params.AllowRaw, + handlers: make(map[Topic]map[*handler]bool), + topicHandlerCaps: make(map[Topic]*handlerCaps), + hashPool: sync.Pool{ New: func() interface{} { - return storage.MakeHashFunc(storage.DefaultHash)() + return sha3.NewKeccak256() }, }, } @@ -313,30 +316,54 @@ func (p *Pss) PublicKey() *ecdsa.PublicKey { // // Returns a deregister function which needs to be called to // deregister the handler, -func (p *Pss) Register(topic *Topic, handler Handler) func() { +func (p *Pss) Register(topic *Topic, hndlr *handler) func() { p.handlersMu.Lock() defer p.handlersMu.Unlock() handlers := p.handlers[*topic] if handlers == nil { - handlers = make(map[*Handler]bool) + handlers = make(map[*handler]bool) p.handlers[*topic] = handlers + log.Debug("registered handler", "caps", hndlr.caps) + } + if hndlr.caps == nil { + hndlr.caps = &handlerCaps{} + } + handlers[hndlr] = true + if _, ok := p.topicHandlerCaps[*topic]; !ok { + p.topicHandlerCaps[*topic] = &handlerCaps{} } - handlers[&handler] = true - return func() { p.deregister(topic, &handler) } + if hndlr.caps.raw { + p.topicHandlerCaps[*topic].raw = true + } + if hndlr.caps.prox { + p.topicHandlerCaps[*topic].prox = true + } + return func() { p.deregister(topic, hndlr) } } -func (p *Pss) deregister(topic *Topic, h *Handler) { +func (p *Pss) deregister(topic *Topic, hndlr *handler) { p.handlersMu.Lock() defer p.handlersMu.Unlock() handlers := p.handlers[*topic] - if len(handlers) == 1 { + if len(handlers) > 1 { delete(p.handlers, *topic) + // topic caps might have changed now that a handler is gone + caps := &handlerCaps{} + for h := range handlers { + if h.caps.raw { + caps.raw = true + } + if h.caps.prox { + caps.prox = true + } + } + p.topicHandlerCaps[*topic] = caps return } - delete(handlers, h) + delete(handlers, hndlr) } // get all registered handlers for respective topics -func (p *Pss) getHandlers(topic Topic) map[*Handler]bool { +func (p *Pss) getHandlers(topic Topic) map[*handler]bool { p.handlersMu.RLock() defer p.handlersMu.RUnlock() return p.handlers[topic] @@ -348,12 +375,11 @@ func (p *Pss) getHandlers(topic Topic) map[*Handler]bool { // Only passes error to pss protocol handler if payload is not valid pssmsg func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error { metrics.GetOrRegisterCounter("pss.handlepssmsg", nil).Inc(1) - pssmsg, ok := msg.(*PssMsg) - if !ok { return fmt.Errorf("invalid message type. Expected *PssMsg, got %T ", msg) } + log.Trace("handler", "self", label(p.Kademlia.BaseAddr()), "topic", label(pssmsg.Payload.Topic[:])) if int64(pssmsg.Expire) < time.Now().Unix() { metrics.GetOrRegisterCounter("pss.expire", nil).Inc(1) log.Warn("pss filtered expired message", "from", common.ToHex(p.Kademlia.BaseAddr()), "to", common.ToHex(pssmsg.To)) @@ -365,13 +391,34 @@ func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error { } p.addFwdCache(pssmsg) - if !p.isSelfPossibleRecipient(pssmsg) { - log.Trace("pss was for someone else :'( ... forwarding", "pss", common.ToHex(p.BaseAddr())) + psstopic := Topic(pssmsg.Payload.Topic) + + // raw is simplest handler contingency to check, so check that first + var isRaw bool + if pssmsg.isRaw() { + if !p.topicHandlerCaps[psstopic].raw { + log.Debug("No handler for raw message", "topic", psstopic) + return nil + } + isRaw = true + } + + // check if we can be recipient: + // - no prox handler on message and partial address matches + // - prox handler on message and we are in prox regardless of partial address match + // store this result so we don't calculate again on every handler + var isProx bool + if _, ok := p.topicHandlerCaps[psstopic]; ok { + isProx = p.topicHandlerCaps[psstopic].prox + } + isRecipient := p.isSelfPossibleRecipient(pssmsg, isProx) + if !isRecipient { + log.Trace("pss was for someone else :'( ... forwarding", "pss", common.ToHex(p.BaseAddr()), "prox", isProx) return p.enqueue(pssmsg) } - log.Trace("pss for us, yay! ... let's process!", "pss", common.ToHex(p.BaseAddr())) - if err := p.process(pssmsg); err != nil { + log.Trace("pss for us, yay! ... let's process!", "pss", common.ToHex(p.BaseAddr()), "prox", isProx, "raw", isRaw, "topic", label(pssmsg.Payload.Topic[:])) + if err := p.process(pssmsg, isRaw, isProx); err != nil { qerr := p.enqueue(pssmsg) if qerr != nil { return fmt.Errorf("process fail: processerr %v, queueerr: %v", err, qerr) @@ -384,7 +431,7 @@ func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error { // Entry point to processing a message for which the current node can be the intended recipient. // Attempts symmetric and asymmetric decryption with stored keys. // Dispatches message to all handlers matching the message topic -func (p *Pss) process(pssmsg *PssMsg) error { +func (p *Pss) process(pssmsg *PssMsg, raw bool, prox bool) error { metrics.GetOrRegisterCounter("pss.process", nil).Inc(1) var err error @@ -397,10 +444,8 @@ func (p *Pss) process(pssmsg *PssMsg) error { envelope := pssmsg.Payload psstopic := Topic(envelope.Topic) - if pssmsg.isRaw() { - if !p.allowRaw { - return errors.New("raw message support disabled") - } + + if raw { payload = pssmsg.Payload.Data } else { if pssmsg.isSym() { @@ -422,19 +467,27 @@ func (p *Pss) process(pssmsg *PssMsg) error { return err } } - p.executeHandlers(psstopic, payload, from, asymmetric, keyid) + p.executeHandlers(psstopic, payload, from, raw, prox, asymmetric, keyid) return nil } -func (p *Pss) executeHandlers(topic Topic, payload []byte, from *PssAddress, asymmetric bool, keyid string) { +func (p *Pss) executeHandlers(topic Topic, payload []byte, from *PssAddress, raw bool, prox bool, asymmetric bool, keyid string) { handlers := p.getHandlers(topic) peer := p2p.NewPeer(enode.ID{}, fmt.Sprintf("%x", from), []p2p.Cap{}) - for f := range handlers { - err := (*f)(payload, peer, asymmetric, keyid) + for h := range handlers { + if !h.caps.raw && raw { + log.Warn("norawhandler") + continue + } + if !h.caps.prox && prox { + log.Warn("noproxhandler") + continue + } + err := (h.f)(payload, peer, asymmetric, keyid) if err != nil { - log.Warn("Pss handler %p failed: %v", f, err) + log.Warn("Pss handler failed", "err", err) } } } @@ -445,9 +498,23 @@ func (p *Pss) isSelfRecipient(msg *PssMsg) bool { } // test match of leftmost bytes in given message to node's Kademlia address -func (p *Pss) isSelfPossibleRecipient(msg *PssMsg) bool { +func (p *Pss) isSelfPossibleRecipient(msg *PssMsg, prox bool) bool { local := p.Kademlia.BaseAddr() - return bytes.Equal(msg.To, local[:len(msg.To)]) + + // if a partial address matches we are possible recipient regardless of prox + // if not and prox is not set, we are surely not + if bytes.Equal(msg.To, local[:len(msg.To)]) { + + return true + } else if !prox { + return false + } + + depth := p.Kademlia.NeighbourhoodDepth() + po, _ := p.Kademlia.Pof(p.Kademlia.BaseAddr(), msg.To, 0) + log.Trace("selfpossible", "po", po, "depth", depth) + + return depth <= po } ///////////////////////////////////////////////////////////////////// @@ -684,9 +751,6 @@ func (p *Pss) enqueue(msg *PssMsg) error { // // Will fail if raw messages are disallowed func (p *Pss) SendRaw(address PssAddress, topic Topic, msg []byte) error { - if !p.allowRaw { - return errors.New("Raw messages not enabled") - } pssMsgParams := &msgParams{ raw: true, } @@ -699,7 +763,17 @@ func (p *Pss) SendRaw(address PssAddress, topic Topic, msg []byte) error { pssMsg.Expire = uint32(time.Now().Add(p.msgTTL).Unix()) pssMsg.Payload = payload p.addFwdCache(pssMsg) - return p.enqueue(pssMsg) + err := p.enqueue(pssMsg) + if err != nil { + return err + } + + // if we have a proxhandler on this topic + // also deliver message to ourselves + if p.isSelfPossibleRecipient(pssMsg, true) && p.topicHandlerCaps[topic].prox { + return p.process(pssMsg, true, true) + } + return nil } // Send a message using symmetric encryption @@ -800,7 +874,16 @@ func (p *Pss) send(to []byte, topic Topic, msg []byte, asymmetric bool, key []by pssMsg.To = to pssMsg.Expire = uint32(time.Now().Add(p.msgTTL).Unix()) pssMsg.Payload = envelope - return p.enqueue(pssMsg) + err = p.enqueue(pssMsg) + if err != nil { + return err + } + if _, ok := p.topicHandlerCaps[topic]; ok { + if p.isSelfPossibleRecipient(pssMsg, true) && p.topicHandlerCaps[topic].prox { + return p.process(pssMsg, true, true) + } + } + return nil } // Forwards a pss message to the peer(s) closest to the to recipient address in the PssMsg struct @@ -895,6 +978,10 @@ func (p *Pss) cleanFwdCache() { } } +func label(b []byte) string { + return fmt.Sprintf("%04x", b[:2]) +} + // add a message to the cache func (p *Pss) addFwdCache(msg *PssMsg) error { metrics.GetOrRegisterCounter("pss.addfwdcache", nil).Inc(1) @@ -934,10 +1021,14 @@ func (p *Pss) checkFwdCache(msg *PssMsg) bool { // Digest of message func (p *Pss) digest(msg *PssMsg) pssDigest { - hasher := p.hashPool.Get().(storage.SwarmHash) + return p.digestBytes(msg.serialize()) +} + +func (p *Pss) digestBytes(msg []byte) pssDigest { + hasher := p.hashPool.Get().(hash.Hash) defer p.hashPool.Put(hasher) hasher.Reset() - hasher.Write(msg.serialize()) + hasher.Write(msg) digest := pssDigest{} key := hasher.Sum(nil) copy(digest[:], key[:digestLength]) |