From dcdd57df6282a6cd43a6407e8626a5cdcca60482 Mon Sep 17 00:00:00 2001
From: gary rong <garyrong0905@gmail.com>
Date: Wed, 18 Jul 2018 18:41:36 +0800
Subject: core, ethdb: two tiny fixes (#17183)

* ethdb: fix memory database

* core: fix bloombits checking

* core: minor polish
---
 ethdb/database_test.go   | 22 ++++++++++++++++++++++
 ethdb/memory_database.go | 12 ++++++++----
 2 files changed, 30 insertions(+), 4 deletions(-)

(limited to 'ethdb')

diff --git a/ethdb/database_test.go b/ethdb/database_test.go
index 2deb50988..74675cbe6 100644
--- a/ethdb/database_test.go
+++ b/ethdb/database_test.go
@@ -59,6 +59,28 @@ func TestMemoryDB_PutGet(t *testing.T) {
 func testPutGet(db ethdb.Database, t *testing.T) {
 	t.Parallel()
 
+	for _, k := range test_values {
+		err := db.Put([]byte(k), nil)
+		if err != nil {
+			t.Fatalf("put failed: %v", err)
+		}
+	}
+
+	for _, k := range test_values {
+		data, err := db.Get([]byte(k))
+		if err != nil {
+			t.Fatalf("get failed: %v", err)
+		}
+		if len(data) != 0 {
+			t.Fatalf("get returned wrong result, got %q expected nil", string(data))
+		}
+	}
+
+	_, err := db.Get([]byte("non-exist-key"))
+	if err == nil {
+		t.Fatalf("expect to return a not found error")
+	}
+
 	for _, v := range test_values {
 		err := db.Put([]byte(v), []byte(v))
 		if err != nil {
diff --git a/ethdb/memory_database.go b/ethdb/memory_database.go
index f28ff5481..727f2f7ca 100644
--- a/ethdb/memory_database.go
+++ b/ethdb/memory_database.go
@@ -96,7 +96,10 @@ func (db *MemDatabase) NewBatch() Batch {
 
 func (db *MemDatabase) Len() int { return len(db.db) }
 
-type kv struct{ k, v []byte }
+type kv struct {
+	k, v []byte
+	del  bool
+}
 
 type memBatch struct {
 	db     *MemDatabase
@@ -105,13 +108,14 @@ type memBatch struct {
 }
 
 func (b *memBatch) Put(key, value []byte) error {
-	b.writes = append(b.writes, kv{common.CopyBytes(key), common.CopyBytes(value)})
+	b.writes = append(b.writes, kv{common.CopyBytes(key), common.CopyBytes(value), false})
 	b.size += len(value)
 	return nil
 }
 
 func (b *memBatch) Delete(key []byte) error {
-	b.writes = append(b.writes, kv{common.CopyBytes(key), nil})
+	b.writes = append(b.writes, kv{common.CopyBytes(key), nil, true})
+	b.size += 1
 	return nil
 }
 
@@ -120,7 +124,7 @@ func (b *memBatch) Write() error {
 	defer b.db.lock.Unlock()
 
 	for _, kv := range b.writes {
-		if kv.v == nil {
+		if kv.del {
 			delete(b.db.db, string(kv.k))
 			continue
 		}
-- 
cgit