Skip to content

Commit

Permalink
Merge pull request #3611 from greymistcube/refactor/copy-states
Browse files Browse the repository at this point in the history
♻️ Refactored `TrieStateStore.CopyStates()`
  • Loading branch information
greymistcube authored Jan 18, 2024
2 parents ccef9c5 + c61248d commit 3164228
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 95 deletions.
100 changes: 30 additions & 70 deletions Libplanet.Store/Trie/MerkleTrie.cs
Original file line number Diff line number Diff line change
Expand Up @@ -185,69 +185,39 @@ internal IEnumerable<HashDigest<SHA256>> IterateHashNodes()
.Select(pair => ((HashNode)pair.Node).HashDigest);
}

internal IEnumerable<(KeyBytes Key, byte[] Value)> IterateNodeKeyValuePairs()
/// <summary>
/// Iterates over <see cref="KeyBytes"/> and <see cref="byte[]"/> pairs stored
/// necessary to fully represent this <see cref="ITrie"/>.
/// </summary>
/// <returns>An <see cref="IEnumerable"/> of all <see cref="KeyBytes"/> and
/// <see cref="byte[]"/> pairs stored necessary to fully represent
/// this <see cref="ITrie"/>.</returns>
/// <exception cref="NullReferenceException">Thrown when a <see cref="HashNode"/>
/// is encountered that can't be decoded into an <see cref="INode"/>.</exception>
internal IEnumerable<(KeyBytes Key, byte[] Value)> IterateKeyValuePairs()
{
if (Root is null)
{
yield break;
}

var queue =
new Queue<(KeyBytes Key, byte[] Value, ImmutableArray<byte> Path)>();
switch (Root)
{
case ValueNode valueNode:
var value = _codec.Encode(valueNode.ToBencodex());
var key = new KeyBytes(HashDigest<SHA256>.DeriveFrom(value).ByteArray);
yield return (key, value);
yield break;

case HashNode hashNode:
key = new KeyBytes(hashNode.HashDigest.ToByteArray());
queue.Enqueue((key, KeyValueStore.Get(key), ImmutableArray<byte>.Empty));
break;

case FullNode _:
case ShortNode _:
value = _codec.Encode(Root.ToBencodex());
key = new KeyBytes(HashDigest<SHA256>.DeriveFrom(value).ByteArray);
queue.Enqueue((key, value, ImmutableArray<byte>.Empty));
break;
}

bool GuessValueNodeByPath(in ImmutableArray<byte> path)
{
if (path.Length < 2)
{
return false;
}

bool isStartedWithUnderbar = (path[0] << 4) + path[1] == '_';

bool isStatePath = !isStartedWithUnderbar &&
path.Length == Address.Size * 2 * 2;
return isStatePath;
}
var queue = new Queue<INode>();
queue.Enqueue(Root);

while (queue.Count > 0)
{
(KeyBytes key, byte[] value, ImmutableArray<byte> path) =
queue.Dequeue();

// It assumes every length of value nodes is same with Address' hexadecimal
// string's hexadecimal string's size.
bool isValueNode = GuessValueNodeByPath(path);

yield return (key, value);

if (isValueNode)
{
continue;
}

var node = NodeDecoder.Decode(_codec.Decode(value), NodeDecoder.AnyNodeType);
if (isValueNode)
INode node = queue.Dequeue();
if (node is HashNode dequeuedHashNode)
{
var storedKey = new KeyBytes(dequeuedHashNode.HashDigest.ByteArray);
var storedValue = KeyValueStore.Get(storedKey);
var intermediateEncoding = _codec.Decode(storedValue);
queue.Enqueue(
NodeDecoder.Decode(
intermediateEncoding,
NodeDecoder.HashEmbeddedNodeType) ??
throw new NullReferenceException());
yield return (storedKey, storedValue);
continue;
}

Expand All @@ -257,41 +227,31 @@ bool GuessValueNodeByPath(in ImmutableArray<byte> path)
foreach (int index in Enumerable.Range(0, FullNode.ChildrenCount - 1))
{
INode? child = fullNode.Children[index];
if (child is HashNode hashNode)
if (child is HashNode childHashNode)
{
key = new KeyBytes(hashNode.HashDigest.ByteArray);
value = KeyValueStore.Get(key);
queue.Enqueue((key, value, path.Add((byte)index)));
queue.Enqueue(childHashNode);
}
}

switch (fullNode.Value)
if (fullNode.Value is HashNode fullNodeValueHashNode)
{
case HashNode hashNode:
key = new KeyBytes(hashNode.HashDigest.ByteArray);
value = KeyValueStore.Get(key);
queue.Enqueue((key, value, path));
break;
queue.Enqueue(fullNodeValueHashNode);
}

break;

case ShortNode shortNode:
switch (shortNode.Value)
if (shortNode.Value is HashNode shortNodeValueHashNode)
{
case HashNode hashNode:
key = new KeyBytes(hashNode.HashDigest.ByteArray);
value = KeyValueStore.Get(key);
queue.Enqueue((key, value, path.AddRange(shortNode.Key.ByteArray)));
break;
queue.Enqueue(shortNodeValueHashNode);
}

break;

case ValueNode _:
break;

default:
case HashNode _:
throw new InvalidOperationException();
}
}
Expand Down
20 changes: 15 additions & 5 deletions Libplanet.Store/TrieStateStore.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
Expand Down Expand Up @@ -93,29 +94,38 @@ public void PruneStates(IImmutableSet<HashDigest<SHA256>> survivingStateRootHash
/// </summary>
/// <param name="stateRootHashes">The state root hashes of states to copy.</param>
/// <param name="targetStateStore">The target state store to copy state root hashes.</param>
/// <exception cref="ArgumentException">Thrown when a state root cannot be found for
/// any of given <paramref name="stateRootHashes"/>.</exception>
public void CopyStates(
IImmutableSet<HashDigest<SHA256>> stateRootHashes, TrieStateStore targetStateStore)
{
IKeyValueStore targetKeyValueStore = targetStateStore.StateKeyValueStore;
var stopwatch = new Stopwatch();
long count = 0;
_logger.Verbose("Started {MethodName}()", nameof(CopyStates));
stopwatch.Start();

foreach (HashDigest<SHA256> stateRootHash in stateRootHashes)
{
var stateTrie = new MerkleTrie(
StateKeyValueStore,
new HashNode(stateRootHash));
var stateTrie = (MerkleTrie)GetStateRoot(stateRootHash);
if (!stateTrie.Recorded)
{
throw new ArgumentException(
$"Failed to find a state root for given state root hash {stateRootHash}.");
}

foreach (var (key, value) in stateTrie.IterateNodeKeyValuePairs())
foreach (var (key, value) in stateTrie.IterateKeyValuePairs())
{
targetKeyValueStore.Set(key, value);
count++;
}
}

stopwatch.Stop();
_logger.Debug(
"Finished to copy all states {ElapsedMilliseconds} ms",
"Finished copying all states with {Count} key value pairs " +
"in {ElapsedMilliseconds} ms",
count,
stopwatch.ElapsedMilliseconds);
_logger.Verbose("Finished {MethodName}()", nameof(CopyStates));
}
Expand Down
47 changes: 27 additions & 20 deletions Libplanet.Tests/Store/TrieStateStoreTest.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Security.Cryptography;
Expand Down Expand Up @@ -108,34 +110,33 @@ public void PruneStates()
[Fact]
public void CopyStates()
{
var values = ImmutableDictionary<KeyBytes, IValue>.Empty
.Add(new KeyBytes("foo"), (Binary)GetRandomBytes(4096))
.Add(
new KeyBytes("bar"),
(Text)ByteUtil.Hex(GetRandomBytes(2048)))
.Add(new KeyBytes("baz"), (Bencodex.Types.Boolean)false)
.Add(new KeyBytes("qux"), Bencodex.Types.Dictionary.Empty)
.Add(
new KeyBytes("zzz"),
Bencodex.Types.Dictionary.Empty
.Add("binary", GetRandomBytes(4096))
.Add("text", ByteUtil.Hex(GetRandomBytes(2048))));

var stateStore = new TrieStateStore(_stateKeyValueStore);

IKeyValueStore targetStateKeyValueStore = new MemoryKeyValueStore();
var targetStateStore = new TrieStateStore(targetStateKeyValueStore);
ITrie trie = stateStore.Commit(
values.Aggregate(
stateStore.GetStateRoot(null),
(prev, kv) => prev.Set(kv.Key, kv.Value)));
Random random = new Random();
List<(KeyBytes, IValue)> kvs = Enumerable.Range(0, 1_000)
.Select(_ =>
(
new KeyBytes(GetRandomBytes(random.Next(20))),
(IValue)new Binary(GetRandomBytes(20))
))
.ToList();

ITrie trie = stateStore.GetStateRoot(null);
foreach (var kv in kvs)
{
trie = trie.Set(kv.Item1, kv.Item2);
}

trie = stateStore.Commit(trie);
int prevStatesCount = _stateKeyValueStore.ListKeys().Count();

// NOTE: Avoid possible collision of KeyBytes, just in case.
_stateKeyValueStore.Set(
new KeyBytes("alpha"),
new KeyBytes(GetRandomBytes(30)),
ByteUtil.ParseHex("00"));
_stateKeyValueStore.Set(
new KeyBytes("beta"),
new KeyBytes(GetRandomBytes(40)),
ByteUtil.ParseHex("00"));

Assert.Equal(prevStatesCount + 2, _stateKeyValueStore.ListKeys().Count());
Expand All @@ -149,6 +150,12 @@ public void CopyStates()
// FIXME: Bencodex fingerprints also should be tracked.
// https://github.com/planetarium/libplanet/issues/1653
Assert.Equal(prevStatesCount, targetStateKeyValueStore.ListKeys().Count());
Assert.Equal(
trie.IterateNodes().Count(),
targetStateStore.GetStateRoot(trie.Hash).IterateNodes().Count());
Assert.Equal(
trie.IterateValues().Count(),
targetStateStore.GetStateRoot(trie.Hash).IterateValues().Count());
}

[Fact]
Expand Down

0 comments on commit 3164228

Please sign in to comment.