aboutsummaryrefslogtreecommitdiffstats
path: root/p2p/netutil
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/netutil')
-rw-r--r--p2p/netutil/net.go131
-rw-r--r--p2p/netutil/net_test.go89
2 files changed, 220 insertions, 0 deletions
diff --git a/p2p/netutil/net.go b/p2p/netutil/net.go
index f6005afd2..656abb682 100644
--- a/p2p/netutil/net.go
+++ b/p2p/netutil/net.go
@@ -18,8 +18,11 @@
package netutil
import (
+ "bytes"
"errors"
+ "fmt"
"net"
+ "sort"
"strings"
)
@@ -189,3 +192,131 @@ func CheckRelayIP(sender, addr net.IP) error {
}
return nil
}
+
+// SameNet reports whether two IP addresses have an equal prefix of the given bit length.
+func SameNet(bits uint, ip, other net.IP) bool {
+ ip4, other4 := ip.To4(), other.To4()
+ switch {
+ case (ip4 == nil) != (other4 == nil):
+ return false
+ case ip4 != nil:
+ return sameNet(bits, ip4, other4)
+ default:
+ return sameNet(bits, ip.To16(), other.To16())
+ }
+}
+
+func sameNet(bits uint, ip, other net.IP) bool {
+ nb := int(bits / 8)
+ mask := ^byte(0xFF >> (bits % 8))
+ if mask != 0 && nb < len(ip) && ip[nb]&mask != other[nb]&mask {
+ return false
+ }
+ return nb <= len(ip) && bytes.Equal(ip[:nb], other[:nb])
+}
+
+// DistinctNetSet tracks IPs, ensuring that at most N of them
+// fall into the same network range.
+type DistinctNetSet struct {
+ Subnet uint // number of common prefix bits
+ Limit uint // maximum number of IPs in each subnet
+
+ members map[string]uint
+ buf net.IP
+}
+
+// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the
+// number of existing IPs in the defined range exceeds the limit.
+func (s *DistinctNetSet) Add(ip net.IP) bool {
+ key := s.key(ip)
+ n := s.members[string(key)]
+ if n < s.Limit {
+ s.members[string(key)] = n + 1
+ return true
+ }
+ return false
+}
+
+// Remove removes an IP from the set.
+func (s *DistinctNetSet) Remove(ip net.IP) {
+ key := s.key(ip)
+ if n, ok := s.members[string(key)]; ok {
+ if n == 1 {
+ delete(s.members, string(key))
+ } else {
+ s.members[string(key)] = n - 1
+ }
+ }
+}
+
+// Contains whether the given IP is contained in the set.
+func (s DistinctNetSet) Contains(ip net.IP) bool {
+ key := s.key(ip)
+ _, ok := s.members[string(key)]
+ return ok
+}
+
+// Len returns the number of tracked IPs.
+func (s DistinctNetSet) Len() int {
+ n := uint(0)
+ for _, i := range s.members {
+ n += i
+ }
+ return int(n)
+}
+
+// key encodes the map key for an address into a temporary buffer.
+//
+// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types.
+// The remainder of the key is the IP, truncated to the number of bits.
+func (s *DistinctNetSet) key(ip net.IP) net.IP {
+ // Lazily initialize storage.
+ if s.members == nil {
+ s.members = make(map[string]uint)
+ s.buf = make(net.IP, 17)
+ }
+ // Canonicalize ip and bits.
+ typ := byte('6')
+ if ip4 := ip.To4(); ip4 != nil {
+ typ, ip = '4', ip4
+ }
+ bits := s.Subnet
+ if bits > uint(len(ip)*8) {
+ bits = uint(len(ip) * 8)
+ }
+ // Encode the prefix into s.buf.
+ nb := int(bits / 8)
+ mask := ^byte(0xFF >> (bits % 8))
+ s.buf[0] = typ
+ buf := append(s.buf[:1], ip[:nb]...)
+ if nb < len(ip) && mask != 0 {
+ buf = append(buf, ip[nb]&mask)
+ }
+ return buf
+}
+
+// String implements fmt.Stringer
+func (s DistinctNetSet) String() string {
+ var buf bytes.Buffer
+ buf.WriteString("{")
+ keys := make([]string, 0, len(s.members))
+ for k := range s.members {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+ for i, k := range keys {
+ var ip net.IP
+ if k[0] == '4' {
+ ip = make(net.IP, 4)
+ } else {
+ ip = make(net.IP, 16)
+ }
+ copy(ip, k[1:])
+ fmt.Fprintf(&buf, "%vĂ—%d", ip, s.members[k])
+ if i != len(keys)-1 {
+ buf.WriteString(" ")
+ }
+ }
+ buf.WriteString("}")
+ return buf.String()
+}
diff --git a/p2p/netutil/net_test.go b/p2p/netutil/net_test.go
index 1ee1fcb4d..3a6aa081f 100644
--- a/p2p/netutil/net_test.go
+++ b/p2p/netutil/net_test.go
@@ -17,9 +17,11 @@
package netutil
import (
+ "fmt"
"net"
"reflect"
"testing"
+ "testing/quick"
"github.com/davecgh/go-spew/spew"
)
@@ -171,3 +173,90 @@ func BenchmarkCheckRelayIP(b *testing.B) {
CheckRelayIP(sender, addr)
}
}
+
+func TestSameNet(t *testing.T) {
+ tests := []struct {
+ ip, other string
+ bits uint
+ want bool
+ }{
+ {"0.0.0.0", "0.0.0.0", 32, true},
+ {"0.0.0.0", "0.0.0.1", 0, true},
+ {"0.0.0.0", "0.0.0.1", 31, true},
+ {"0.0.0.0", "0.0.0.1", 32, false},
+ {"0.33.0.1", "0.34.0.2", 8, true},
+ {"0.33.0.1", "0.34.0.2", 13, true},
+ {"0.33.0.1", "0.34.0.2", 15, false},
+ }
+
+ for _, test := range tests {
+ if ok := SameNet(test.bits, parseIP(test.ip), parseIP(test.other)); ok != test.want {
+ t.Errorf("SameNet(%d, %s, %s) == %t, want %t", test.bits, test.ip, test.other, ok, test.want)
+ }
+ }
+}
+
+func ExampleSameNet() {
+ // This returns true because the IPs are in the same /24 network:
+ fmt.Println(SameNet(24, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 3}))
+ // This call returns false:
+ fmt.Println(SameNet(24, net.IP{127, 3, 0, 1}, net.IP{127, 5, 0, 3}))
+ // Output:
+ // true
+ // false
+}
+
+func TestDistinctNetSet(t *testing.T) {
+ ops := []struct {
+ add, remove string
+ fails bool
+ }{
+ {add: "127.0.0.1"},
+ {add: "127.0.0.2"},
+ {add: "127.0.0.3", fails: true},
+ {add: "127.32.0.1"},
+ {add: "127.32.0.2"},
+ {add: "127.32.0.3", fails: true},
+ {add: "127.33.0.1", fails: true},
+ {add: "127.34.0.1"},
+ {add: "127.34.0.2"},
+ {add: "127.34.0.3", fails: true},
+ // Make room for an address, then add again.
+ {remove: "127.0.0.1"},
+ {add: "127.0.0.3"},
+ {add: "127.0.0.3", fails: true},
+ }
+
+ set := DistinctNetSet{Subnet: 15, Limit: 2}
+ for _, op := range ops {
+ var desc string
+ if op.add != "" {
+ desc = fmt.Sprintf("Add(%s)", op.add)
+ if ok := set.Add(parseIP(op.add)); ok != !op.fails {
+ t.Errorf("%s == %t, want %t", desc, ok, !op.fails)
+ }
+ } else {
+ desc = fmt.Sprintf("Remove(%s)", op.remove)
+ set.Remove(parseIP(op.remove))
+ }
+ t.Logf("%s: %v", desc, set)
+ }
+}
+
+func TestDistinctNetSetAddRemove(t *testing.T) {
+ cfg := &quick.Config{}
+ fn := func(ips []net.IP) bool {
+ s := DistinctNetSet{Limit: 3, Subnet: 2}
+ for _, ip := range ips {
+ s.Add(ip)
+ }
+ for _, ip := range ips {
+ s.Remove(ip)
+ }
+ return s.Len() == 0
+ }
+
+ if err := quick.Check(fn, cfg); err != nil {
+ t.Fatal(err)
+ }
+}