1
#nullable enable
2
using System;
3
using System.Collections.Generic;
4
using System.Collections.Immutable;
5
using System.Linq;
6
using System.Security.Cryptography;
7
using Bencodex;
8
using Bencodex.Types;
9
using Libplanet.Store.Trie.Nodes;
10

11
namespace Libplanet.Store.Trie
12
{
13
    /// <summary>
    /// An <see cref="ITrie"/> implementation backed by a
    /// <see href="https://eth.wiki/fundamentals/patricia-tree">Merkle Patricia Trie</see>.
    /// Keys are split into 4-bit nibbles before being used as paths through the trie.
    /// <see cref="Set(byte[], IValue)"/> and <see cref="Commit(bool)"/> return new
    /// <see cref="MerkleTrie"/> instances rather than modifying this one.
    /// </summary>
    // TODO: implement 'logs' for debugging.
    [Equals]
    public class MerkleTrie : ITrie
    {
        /// <summary>
        /// The root hash of an empty trie: <see cref="Hashcash.Hash"/> applied to the
        /// Bencodex encoding of <see cref="Null"/>.
        /// </summary>
        public static readonly HashDigest<SHA256> EmptyRootHash;

        private static readonly Codec _codec;

        private readonly bool _secure;

        static MerkleTrie()
        {
            _codec = new Codec();
            EmptyRootHash = Hashcash.Hash(
                _codec.Encode(default(Null)!));
        }

        /// <summary>
        /// Creates a <see cref="MerkleTrie"/> whose root is the node stored under
        /// <paramref name="rootHash"/>.
        /// </summary>
        /// <param name="keyValueStore">The <see cref="IKeyValueStore"/> storage to store
        /// nodes.</param>
        /// <param name="rootHash">The root <see cref="ITrie.Hash"/> of the
        /// <see cref="MerkleTrie"/>.  If it equals <see cref="EmptyRootHash"/>, the trie is
        /// treated as empty.</param>
        /// <param name="secure">Whether to use <see cref="MerkleTrie"/> in
        /// secure mode.  If it is turned on, <see cref="MerkleTrie"/> internally stores hashed
        /// keys instead of bare keys.  <see cref="Hashcash.Hash" /> is used to hash them.</param>
        public MerkleTrie(
            IKeyValueStore keyValueStore,
            HashDigest<SHA256> rootHash,
            bool secure = false)
            : this(keyValueStore, new HashNode(rootHash), secure)
        {
        }

        /// <summary>
        /// Creates a <see cref="MerkleTrie"/> rooted at an in-memory node.
        /// </summary>
        /// <param name="keyValueStore">The <see cref="IKeyValueStore"/> storage to store
        /// nodes.</param>
        /// <param name="root">The root node of the <see cref="MerkleTrie"/>.  If it is
        /// <c>null</c>, or a <see cref="HashNode"/> pointing at <see cref="EmptyRootHash"/>,
        /// the trie is treated as empty.</param>
        /// <param name="secure">Whether to use <see cref="MerkleTrie"/> in secure
        /// mode.  If it is <c>true</c>, <see cref="MerkleTrie"/> stores each value under the
        /// <see cref="Hashcash.Hash"/> of the given key instead of the key itself.</param>
        internal MerkleTrie(IKeyValueStore keyValueStore, INode? root = null, bool secure = false)
        {
            KeyValueStore = keyValueStore;

            // Normalize "hash of the empty trie" to a null root so both spellings of an
            // empty trie behave the same way.
            Root = root is HashNode hashNode && hashNode.HashDigest.Equals(EmptyRootHash)
                ? null
                : root;
            _secure = secure;
        }

        /// <inheritdoc/>
        public HashDigest<SHA256> Hash => Root?.Hash() ?? EmptyRootHash;

        /// <summary>
        /// The root node, or <c>null</c> when the trie is empty.
        /// </summary>
        internal INode? Root { get; }

        private IKeyValueStore KeyValueStore { get; }

        // Equals.Fody placeholders: the [Equals] attribute makes the weaver replace these
        // Operator.Weave bodies with the generated equality logic at build time.
        public static bool operator ==(MerkleTrie left, MerkleTrie right) =>
            Operator.Weave(left, right);

        public static bool operator !=(MerkleTrie left, MerkleTrie right) =>
            Operator.Weave(left, right);

        /// <inheritdoc/>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="key"/> or
        /// <paramref name="value"/> is <c>null</c>.</exception>
        public ITrie Set(byte[] key, IValue value)
        {
            if (key is null)
            {
                throw new ArgumentNullException(nameof(key));
            }

            if (value is null)
            {
                throw new ArgumentNullException(nameof(value));
            }

            INode newRootNode = Insert(
                Root,
                ImmutableArray<byte>.Empty,
                ToKey(key).ToImmutableArray(),
                new ValueNode(value));

            return new MerkleTrie(KeyValueStore, newRootNode, _secure);
        }

        /// <inheritdoc/>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="key"/> is
        /// <c>null</c>.</exception>
        public bool TryGet(byte[] key, out IValue? value)
        {
            if (key is null)
            {
                throw new ArgumentNullException(nameof(key));
            }

            return TryGet(
                Root,
                ImmutableArray<byte>.Empty,
                ToKey(key).ToImmutableArray(),
                out value);
        }

        /// <inheritdoc/>
        public ITrie Commit(bool rehearsal = false)
        {
            // The secure flag is carried over so the committed trie keeps hashing keys the
            // same way this one does.
            if (Root is null)
            {
                return new MerkleTrie(KeyValueStore, new HashNode(EmptyRootHash), _secure);
            }

            INode newRoot = Commit(Root, rehearsal);

            // Commit(INode, bool) leaves small nodes un-stored so a parent can embed them;
            // the root has no parent, so an embedded (non-HashNode) root is stored here.
            if (!(newRoot is HashNode) && !rehearsal)
            {
                KeyValueStore.Set(newRoot.Hash().ToByteArray(), newRoot.Serialize());
            }

            return new MerkleTrie(KeyValueStore, newRoot, _secure);
        }

        /// <summary>
        /// Enumerates the digests of all <see cref="HashNode"/>s reachable from
        /// <see cref="Root"/>, in the order produced by <see cref="IterateNodes"/>.
        /// </summary>
        /// <returns>A lazily evaluated sequence of node hashes.</returns>
        internal IEnumerable<HashDigest<SHA256>> IterateHashNodes() =>
            IterateNodes()
                .Select(pair => pair.Node)
                .OfType<HashNode>()
                .Select(hashNode => hashNode.HashDigest);

        /// <summary>
        /// Enumerates every node reachable from <see cref="Root"/> in breadth-first order,
        /// paired with the nibble path leading to it.  A <see cref="HashNode"/> is yielded
        /// itself and then resolved from <see cref="KeyValueStore"/>, so the node it points
        /// to is visited too (with the same path).
        /// </summary>
        /// <returns>A lazily evaluated sequence of (node, path) pairs; empty for an empty
        /// trie.</returns>
        internal IEnumerable<(INode Node, ImmutableArray<byte> Path)> IterateNodes()
        {
            if (Root is null)
            {
                yield break;
            }

            var queue = new Queue<(INode Node, ImmutableArray<byte> Path)>();
            queue.Enqueue((Root, ImmutableArray<byte>.Empty));

            while (queue.Count > 0)
            {
                (INode node, ImmutableArray<byte> path) = queue.Dequeue();
                yield return (node, path);

                switch (node)
                {
                    case FullNode fullNode:
                        // Children 0 .. ChildrenCount - 2 are nibble branches; the node's
                        // value (FullNode.Value) is enqueued separately below.
                        for (int index = 0; index < FullNode.ChildrenCount - 1; index++)
                        {
                            INode? child = fullNode.Children[index];
                            if (!(child is null))
                            {
                                queue.Enqueue((child, path.Add((byte)index)));
                            }
                        }

                        if (!(fullNode.Value is null))
                        {
                            queue.Enqueue((fullNode.Value, path));
                        }

                        break;

                    case ShortNode shortNode:
                        if (!(shortNode.Value is null))
                        {
                            queue.Enqueue((shortNode.Value, path.AddRange(shortNode.Key)));
                        }

                        break;

                    case HashNode hashNode:
                        INode? resolved = GetNode(hashNode.HashDigest);
                        if (!(resolved is null))
                        {
                            queue.Enqueue((resolved, path));
                        }

                        break;
                }
            }
        }

        /// <summary>
        /// Recursively commits <paramref name="node"/>, turning every node whose
        /// serialization exceeds a SHA-256 digest into a <see cref="HashNode"/>.
        /// </summary>
        /// <param name="node">The node to commit.</param>
        /// <param name="rehearsal">If <c>true</c>, compute the result without writing
        /// anything to <see cref="KeyValueStore"/>.</param>
        /// <returns>Either the (embeddable) node itself or a <see cref="HashNode"/> pointing
        /// to its stored form.</returns>
        /// <exception cref="NotSupportedException">Thrown for an unknown node type.
        /// </exception>
        private INode Commit(INode node, bool rehearsal = false) => node switch
        {
            HashNode _ => node,
            FullNode fullNode => CommitFullNode(fullNode, rehearsal),
            ShortNode shortNode => CommitShortNode(shortNode, rehearsal),
            ValueNode valueNode => CommitValueNode(valueNode, rehearsal),
            _ => throw new NotSupportedException("Not supported node came."),
        };

        /// <summary>
        /// Commits every child of <paramref name="fullNode"/>, then embeds or stores the
        /// rebuilt node (see <see cref="EmbedOrStore"/>).
        /// </summary>
        private INode CommitFullNode(FullNode fullNode, bool rehearsal)
        {
            var committedChildren = new INode?[FullNode.ChildrenCount];
            for (int i = 0; i < FullNode.ChildrenCount; ++i)
            {
                INode? child = fullNode.Children[i];
                committedChildren[i] = child is null ? null : Commit(child, rehearsal);
            }

            return EmbedOrStore(new FullNode(committedChildren.ToImmutableArray()), rehearsal);
        }

        /// <summary>
        /// Commits the child of <paramref name="shortNode"/>, then embeds or stores the
        /// rebuilt node (see <see cref="EmbedOrStore"/>).
        /// </summary>
        private INode CommitShortNode(ShortNode shortNode, bool rehearsal)
        {
            // rehearsal must be propagated; otherwise a rehearsal commit would still write
            // this node's descendants to the store.
            INode committedValueNode = Commit(shortNode.Value!, rehearsal);
            return EmbedOrStore(new ShortNode(shortNode.Key, committedValueNode), rehearsal);
        }

        /// <summary>
        /// Embeds or stores <paramref name="valueNode"/> (see <see cref="EmbedOrStore"/>).
        /// </summary>
        private INode CommitValueNode(ValueNode valueNode, bool rehearsal) =>
            EmbedOrStore(valueNode, rehearsal);

        /// <summary>
        /// Returns <paramref name="node"/> unchanged if its serialization is no longer than a
        /// SHA-256 digest, so a parent can embed it directly.  Otherwise writes it to
        /// <see cref="KeyValueStore"/> under its hash (skipped when
        /// <paramref name="rehearsal"/> is <c>true</c>) and returns a <see cref="HashNode"/>
        /// referencing it.
        /// </summary>
        private INode EmbedOrStore(INode node, bool rehearsal)
        {
            var serialized = node.Serialize();
            if (serialized.Length <= HashDigest<SHA256>.Size)
            {
                return node;
            }

            var nodeHash = node.Hash();
            if (!rehearsal)
            {
                KeyValueStore.Set(nodeHash.ToByteArray(), serialized);
            }

            return new HashNode(nodeHash);
        }

        /// <summary>
        /// Returns a new subtree equal to <paramref name="node"/> with
        /// <paramref name="value"/> placed at the nibble path <paramref name="key"/>.
        /// </summary>
        /// <param name="node">The subtree to insert into; <c>null</c> means empty.</param>
        /// <param name="prefix">The nibble path from the trie root down to
        /// <paramref name="node"/>; it is only threaded through recursive calls.</param>
        /// <param name="key">The remaining nibble path below <paramref name="node"/>.</param>
        /// <param name="value">The node to place at the end of <paramref name="key"/>.</param>
        /// <returns>The new subtree root.</returns>
        /// <exception cref="InvalidTrieNodeException">Thrown when <paramref name="node"/> is
        /// of a type that cannot be descended into.</exception>
        private INode Insert(
            INode? node,
            ImmutableArray<byte> prefix,
            ImmutableArray<byte> key,
            INode value)
        {
            // NOTE(review): when one key's nibble path is a proper prefix of another's (e.g.
            // non-secure mode with keys of different lengths) this method does not use the
            // FullNode value slot.  Such inserts either throw (a ValueNode hits the default
            // branch below; InsertShortNode indexes past the shorter key), or, when the path
            // ends on a FullNode, the early return below replaces that whole subtree.
            // Left unchanged because fixing it would alter resulting trie hashes.
            if (key.Length == 0)
            {
                return value;
            }

            switch (node)
            {
                case ShortNode shortNode:
                    return InsertShortNode(shortNode, prefix, key, value);

                case FullNode fullNode:
                    INode newChild = Insert(
                        fullNode.Children[key[0]],
                        prefix.Add(key[0]),
                        key.Skip(1).ToImmutableArray(),
                        value);
                    return fullNode.SetChild(key[0], newChild);

                case null:
                    return new ShortNode(key.ToArray(), value);

                case HashNode hashNode:
                    INode? resolved = GetNode(hashNode.HashDigest);
                    return Insert(resolved, prefix, key, value);

                default:
                    throw new InvalidTrieNodeException("Not supported node came." +
                                                       $" raw: {node.ToBencodex().Inspection}");
            }
        }

        /// <summary>
        /// Inserts <paramref name="value"/> at <paramref name="key"/> below
        /// <paramref name="shortNode"/>.  If the short node's key is fully shared the insert
        /// descends into its child; otherwise a branch (<see cref="FullNode"/>) is created at
        /// the first differing nibble, wrapped in an extension <see cref="ShortNode"/> when
        /// the shared prefix is non-empty.
        /// </summary>
        private INode InsertShortNode(
            ShortNode shortNode,
            ImmutableArray<byte> prefix,
            ImmutableArray<byte> key,
            INode value)
        {
            int commonPrefixLength = CommonPrefixLength(shortNode.Key, key);

            if (commonPrefixLength == shortNode.Key.Length)
            {
                INode newChild = Insert(
                    shortNode.Value,
                    prefix.AddRange(key.Take(commonPrefixLength)),
                    key.Skip(commonPrefixLength).ToImmutableArray(),
                    value);
                return new ShortNode(shortNode.Key, newChild);
            }

            // Split at the first differing nibble: one branch slot for the new key, one for
            // the remainder of the existing short node.
            var branch = new FullNode();
            branch = branch.SetChild(
                key[commonPrefixLength],
                Insert(
                    null,
                    prefix.AddRange(key.Take(commonPrefixLength + 1)),
                    key.Skip(commonPrefixLength + 1).ToImmutableArray(),
                    value));
            branch = branch.SetChild(
                shortNode.Key[commonPrefixLength],
                Insert(
                    null,
                    prefix.AddRange(shortNode.Key.Take(commonPrefixLength + 1)),
                    shortNode.Key.Skip(commonPrefixLength + 1).ToImmutableArray(),
                    shortNode.Value!));

            if (commonPrefixLength == 0)
            {
                return branch;
            }

            // Extension node covering the shared prefix.
            return new ShortNode(key.Take(commonPrefixLength).ToArray(), branch);

            static int CommonPrefixLength(ImmutableArray<byte> a, ImmutableArray<byte> b)
            {
                int length = Math.Min(a.Length, b.Length);
                for (int i = 0; i < length; i++)
                {
                    if (a[i] != b[i])
                    {
                        return i;
                    }
                }

                return length;
            }
        }

        /// <summary>
        /// Looks up the value at the nibble path <paramref name="path"/> below
        /// <paramref name="node"/>.
        /// </summary>
        /// <param name="node">The subtree to search; <c>null</c> means empty.</param>
        /// <param name="prefix">The nibble path already consumed; only threaded through.
        /// </param>
        /// <param name="path">The remaining nibble path to follow.</param>
        /// <param name="value">The found value, or <c>null</c> when absent.</param>
        /// <returns><c>true</c> if a value exists at exactly <paramref name="path"/>.</returns>
        /// <exception cref="InvalidTrieNodeException">Thrown for an unknown node type.
        /// </exception>
        private bool TryGet(
            INode? node,
            ImmutableArray<byte> prefix,
            ImmutableArray<byte> path,
            out IValue? value)
        {
            switch (node)
            {
                case null:
                    value = null;
                    return false;

                case ValueNode valueNode:
                    // A value matches only when the whole path is consumed; leftover nibbles
                    // mean the requested key merely extends a stored key.
                    if (path.Length > 0)
                    {
                        value = null;
                        return false;
                    }

                    value = valueNode.Value;
                    return true;

                case ShortNode shortNode:
                    if (path.Length < shortNode.Key.Length
                        || !path.Take(shortNode.Key.Length).SequenceEqual(shortNode.Key))
                    {
                        value = null;
                        return false;
                    }

                    return TryGet(
                        shortNode.Value,
                        prefix.AddRange(path.Take(shortNode.Key.Length)),
                        path.Skip(shortNode.Key.Length).ToImmutableArray(),
                        out value);

                case FullNode fullNode:
                    if (path.Length == 0)
                    {
                        // The path ends at this branch; only its value slot can hold a match
                        // (indexing path[0] here would throw).
                        return TryGet(fullNode.Value, prefix, path, out value);
                    }

                    return TryGet(
                        fullNode.Children[path[0]],
                        prefix.Add(path[0]),
                        path.Skip(1).ToImmutableArray(),
                        out value);

                case HashNode hashNode:
                    try
                    {
                        INode? resolvedNode = GetNode(hashNode.HashDigest);
                        return TryGet(resolvedNode, prefix, path, out value);
                    }
                    catch (KeyNotFoundException)
                    {
                        // The referenced node is missing from the store; treat as absent.
                        value = null;
                        return false;
                    }

                default:
                    throw new InvalidTrieNodeException(
                        $"Invalid node: raw: {node.ToBencodex().Inspection}");
            }
        }

        /// <summary>
        /// Loads and decodes the node stored under <paramref name="nodeHash"/> in
        /// <see cref="KeyValueStore"/>.
        /// </summary>
        /// <param name="nodeHash">The hash of the node to get.</param>
        /// <returns>The decoded node, as returned by <c>NodeDecoder.Decode</c>.</returns>
        // NOTE(review): a missing hash propagates whatever KeyValueStore.Get throws;
        // TryGet catches KeyNotFoundException for that case, IterateNodes and Insert do not.
        private INode? GetNode(HashDigest<SHA256> nodeHash)
        {
            return NodeDecoder.Decode(
                _codec.Decode(KeyValueStore.Get(nodeHash.ToByteArray())));
        }

        /// <summary>
        /// Converts a user key into the nibble path used inside the trie.  In secure mode
        /// the key is first replaced by its <see cref="Hashcash.Hash"/>.  Each byte is then
        /// split into its high nibble followed by its low nibble.
        /// </summary>
        /// <param name="key">The user-supplied key.</param>
        /// <returns>An array twice the length of the (possibly hashed) key, each element in
        /// the range 0-15.</returns>
        private byte[] ToKey(byte[] key)
        {
            if (_secure)
            {
                key = Hashcash.Hash(key).ToByteArray();
            }

            const int lowNibbleMask = 0b00001111;
            var nibbles = new byte[key.Length * 2];
            for (var i = 0; i < key.Length; ++i)
            {
                nibbles[i * 2] = (byte)(key[i] >> 4);
                nibbles[i * 2 + 1] = (byte)(key[i] & lowNibbleMask);
            }

            return nibbles;
        }
    }
465
}

Read our documentation on viewing source code .

Loading