From 12f95d97eeeb8500f2d7c8d38732afd10c52a191 Mon Sep 17 00:00:00 2001 From: Sergei SMART Date: Wed, 2 Mar 2022 12:02:23 +1000 Subject: [PATCH] added serialize / deserialize / proof / toString, coverage increase --- src/FixedMerkleTree.ts | 28 ++++++-- src/PartialMerkleTree.ts | 53 +++++++++++--- src/index.ts | 15 +++- test/fixedMerkleTree.spec.ts | 30 +++++++- test/partialMerkleTree.spec.ts | 125 +++++++++++++++++++++++++++++---- 5 files changed, 222 insertions(+), 29 deletions(-) diff --git a/src/FixedMerkleTree.ts b/src/FixedMerkleTree.ts index 430038d..86ddc4a 100644 --- a/src/FixedMerkleTree.ts +++ b/src/FixedMerkleTree.ts @@ -1,4 +1,15 @@ -import { defaultHash, Element, HashFunction, MerkleTreeOptions, ProofPath, SerializedTreeState, TreeEdge } from './' +import { + Element, + HashFunction, + Index, + MerkleTreeOptions, + ProofPath, + SerializedTreeState, + simpleHash, + TreeEdge, +} from './' + +const defaultHash = (left: Element, right: Element): string => simpleHash([left, right]) export default class MerkleTree { levels: number @@ -21,7 +32,7 @@ export default class MerkleTree { this._layers = [] this._layers[0] = elements.slice() this._buildZeros() - this._rebuild() + this._buildHashes() } get capacity() { @@ -47,7 +58,7 @@ export default class MerkleTree { } } - _rebuild() { + _buildHashes() { for (let level = 1; level <= this.levels; level++) { this._layers[level] = [] for (let i = 0; i < Math.ceil(this._layers[level - 1].length / 2); i++) { @@ -136,7 +147,7 @@ export default class MerkleTree { * @param {number} index Leaf index to generate path for * @returns {{pathElements: Object[], pathIndex: number[]}} An object containing adjacent elements and left-right index */ - path(index: Element): ProofPath { + path(index: Index): ProofPath { if (isNaN(Number(index)) || index < 0 || index >= this._layers[0].length) { throw new Error('Index out of bounds: ' + index) } @@ -177,6 +188,11 @@ export default class MerkleTree { } } + proof(element: Element): ProofPath { + const index = this.indexOf(element) + return this.path(index) + } + getTreeEdge(edgeElement: Element): TreeEdge { const leaves = this._layers[0] const edgeIndex = leaves.indexOf(edgeElement) @@ -208,5 +224,9 @@ export default class MerkleTree { static deserialize(data: SerializedTreeState, hashFunction?: HashFunction): MerkleTree { return new MerkleTree(data.levels, data._layers[0], { hashFunction, zeroElement: data._zeros[0] }) } + + toString() { + return JSON.stringify(this.serialize()) + } } diff --git a/src/PartialMerkleTree.ts b/src/PartialMerkleTree.ts index 9759343..a461d63 100644 --- a/src/PartialMerkleTree.ts +++ b/src/PartialMerkleTree.ts @@ -1,6 +1,15 @@ -import { defaultHash, Element, HashFunction, MerkleTreeOptions, ProofPath, TreeEdge } from './' +import { + Element, + HashFunction, + LeafWithIndex, + MerkleTreeOptions, + ProofPath, + SerializedPartialTreeState, + simpleHash, + TreeEdge, +} from './' -type LeafWithIndex = { index: number, data: Element } +export const defaultHash = (left: Element, right: Element): string => simpleHash([left, right]) export class PartialMerkleTree { levels: number @@ -19,13 +28,15 @@ export class PartialMerkleTree { edgeElement, edgeIndex, }: TreeEdge, leaves: Element[], root: Element, { hashFunction, zeroElement }: MerkleTreeOptions = {}) { + hashFunction = hashFunction || defaultHash + const hashFn = (left, right) => (left !== null && right !== null) ? hashFunction(left, right) : null this._edgeLeafProof = edgePath this.zeroElement = zeroElement ?? 0 this._edgeLeaf = { data: edgeElement, index: edgeIndex } this._leavesAfterEdge = leaves this._initialRoot = root this.levels = levels - this._hashFn = hashFunction || defaultHash + this._hashFn = hashFn this._buildTree() } @@ -45,6 +56,10 @@ export class PartialMerkleTree { return this._layers[0].slice() } + get root(): Element { + return this._layers[this.levels][0] ?? this._zeros[this.levels] + } + private _buildTree(): void { const edgeLeafIndex = this._edgeLeaf.index this._leaves = [...Array.from({ length: edgeLeafIndex }, () => null), ...this._leavesAfterEdge] @@ -148,7 +163,6 @@ export class PartialMerkleTree { if (level === this.levels) { hash = hash || this._initialRoot } - // console.log({ index, level, left, right, hash }) this._layers[level][index] = hash } } @@ -191,10 +205,31 @@ export class PartialMerkleTree { } } - /** - * Get tree root - */ - get root(): Element { - return this._layers[this.levels][0] ?? this._zeros[this.levels] + serialize(): SerializedPartialTreeState { + const leaves = this.layers[0].slice(this._edgeLeaf.index) + return { + _initialRoot: this._initialRoot, + _edgeLeafProof: this._edgeLeafProof, + _edgeLeaf: this._edgeLeaf, + levels: this.levels, + leaves, + _zeros: this._zeros, + } + } + + static deserialize(data: SerializedPartialTreeState, hashFunction?: HashFunction): PartialMerkleTree { + const edge: TreeEdge = { + edgePath: data._edgeLeafProof, + edgeElement: data._edgeLeaf.data, + edgeIndex: data._edgeLeaf.index, + } + return new PartialMerkleTree(data.levels, edge, data.leaves, data._initialRoot, { + hashFunction, + zeroElement: data._zeros[0], + }) + } + + toString() { + return JSON.stringify(this.serialize()) } } diff --git a/src/index.ts b/src/index.ts index 53af538..43867a9 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,5 +1,3 @@ -import { simpleHash } from './simpleHash' - export { default as MerkleTree } from './FixedMerkleTree' export { PartialMerkleTree } from './PartialMerkleTree' export { simpleHash } from './simpleHash' @@ -21,6 +19,15 @@ export type SerializedTreeState = { _layers: Array } +export type SerializedPartialTreeState = { + levels: number, + leaves: Element[] + _zeros: Array, + _edgeLeafProof: ProofPath, + _initialRoot: Element, + _edgeLeaf: LeafWithIndex +} + export type ProofPath = { pathElements: Element[], pathIndices: number[], @@ -31,4 +38,6 @@ export type TreeEdge = { edgePath: ProofPath; edgeIndex: number } -export const defaultHash = (left: Element, right: Element): string => (left !== null && right !== null) ? simpleHash([left, right]) : null +export type Index = Element +export type LeafWithIndex = { index: number, data: Element } + diff --git a/test/fixedMerkleTree.spec.ts b/test/fixedMerkleTree.spec.ts index 0b7eca4..c24424f 100644 --- a/test/fixedMerkleTree.spec.ts +++ b/test/fixedMerkleTree.spec.ts @@ -1,7 +1,10 @@ import { MerkleTree, TreeEdge } from '../src' import { assert, should } from 'chai' +import { createHash } from 'crypto' import { it } from 'mocha' +const sha256Hash = (left, right) => createHash('sha256').update(`${left}${right}`).digest('hex') + describe('MerkleTree', () => { describe('#constructor', () => { @@ -34,6 +37,11 @@ describe('MerkleTree', () => { const call = () => new MerkleTree(2, [1, 2, 3, 4, 5]) should().throw(call, 'Tree is full') }) + + it('should work with optional hash function and zero element', () => { + const tree = new MerkleTree(10, [1, 2, 3, 4, 5, 6], { hashFunction: sha256Hash, zeroElement: 'zero' }) + should().equal(tree.root, 'a377b9fa0ed41add83e56f7e1d0e2ebdb46550b9d8b26b77dece60cb67283f19') + }) }) describe('#insert', () => { @@ -110,6 +118,7 @@ describe('MerkleTree', () => { const call = () => tree.bulkInsert([3, 4, 5]) should().throw(call, 'Tree is full') }) + it('should bypass empty elements', () => { const elements = [1, 2, 3, 4] const tree = new MerkleTree(2, elements) @@ -216,7 +225,6 @@ describe('MerkleTree', () => { '4986731814143931240516913804278285467648', '1918547053077726613961101558405545328640', '5444383861051812288142814494928935059456', - ]) }) @@ -246,6 +254,12 @@ describe('MerkleTree', () => { ]) }) }) + describe('#proof', () => { + it('should return proof for leaf', () => { + const tree = new MerkleTree(10, [1, 2, 3, 4, 5]) + assert.deepEqual(tree.proof(4), tree.path(3)) + }) + }) describe('#getTreeEdge', () => { it('should return correct treeEdge', () => { @@ -348,4 +362,18 @@ describe('MerkleTree', () => { should().equal(src.root, dst.root) }) }) + describe('#toString', () => { + it('should return correct stringified representation', () => { + const src = new MerkleTree(10, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + const str = src.toString() + const dst = MerkleTree.deserialize(JSON.parse(str)) + should().equal(src.root, dst.root) + + src.insert(10) + dst.insert(10) + + should().equal(src.root, dst.root) + + }) + }) }) diff --git a/test/partialMerkleTree.spec.ts b/test/partialMerkleTree.spec.ts index 8abcb12..d33292b 100644 --- a/test/partialMerkleTree.spec.ts +++ b/test/partialMerkleTree.spec.ts @@ -1,13 +1,17 @@ -import { Element, MerkleTree, PartialMerkleTree } from '../src' +import { Element, MerkleTree, MerkleTreeOptions, PartialMerkleTree } from '../src' import { it } from 'mocha' -import { assert, should } from 'chai' +import { should } from 'chai' +import * as assert from 'assert' +import { createHash } from 'crypto' + +const sha256Hash = (left, right) => createHash('sha256').update(`${left}${right}`).digest('hex') describe('PartialMerkleTree', () => { - const getTestTrees = (levels: number, elements: Element[], edgeElement: Element) => { - const fullTree = new MerkleTree(levels, elements) + const getTestTrees = (levels: number, elements: Element[], edgeElement: Element, treeOptions: MerkleTreeOptions = {}) => { + const fullTree = new MerkleTree(levels, elements, treeOptions) const edge = fullTree.getTreeEdge(edgeElement) const leavesAfterEdge = elements.slice(edge.edgeIndex) - const partialTree = new PartialMerkleTree(levels, edge, leavesAfterEdge, fullTree.root) + const partialTree = new PartialMerkleTree(levels, edge, leavesAfterEdge, fullTree.root, treeOptions) return { fullTree, partialTree } } describe('#constructor', () => { @@ -19,6 +23,14 @@ describe('PartialMerkleTree', () => { it('should initialize merkle tree with same leaves count', () => { should().equal(fullTree.elements.length, partialTree.elements.length) }) + + it('should work with optional hash function and zero element', () => { + const { partialTree, fullTree } = getTestTrees(10, [1, 2, 3, 4, 5, 6], 4, { + hashFunction: sha256Hash, + zeroElement: 'zero', + }) + should().equal(partialTree.root, fullTree.root) + }) }) describe('#insert', () => { @@ -77,9 +89,58 @@ describe('PartialMerkleTree', () => { const call = () => partialTree.bulkInsert([5, 6, 7]) should().throw(call, 'Tree is full') }) - + it('should bypass empty elements', () => { + const elements = [1, 2, 3, 4] + const { partialTree } = getTestTrees(2, elements, 3) + partialTree.bulkInsert([]) + should().equal(partialTree.elements.length, elements.length, 'No elements inserted') + }) }) + describe('#update', () => { + it('should update last element', () => { + const { fullTree, partialTree } = getTestTrees(10, [1, 2, 3, 4, 5], 3) + partialTree.update(4, 42) + fullTree.update(4, 42) + should().equal(partialTree.root, fullTree.root) + }) + + it('should update odd element', () => { + const { fullTree, partialTree } = getTestTrees(10, [1, 2, 3, 4, 5, 6, 7, 8], 3) + partialTree.update(4, 42) + fullTree.update(4, 42) + should().equal(partialTree.root, fullTree.root) + }) + + it('should update even element', () => { + const { fullTree, partialTree } = getTestTrees(10, [1, 2, 3, 4, 5, 6, 7, 8], 3) + partialTree.update(3, 42) + fullTree.update(3, 42) + should().equal(partialTree.root, fullTree.root) + }) + + it('should update extra element', () => { + const { fullTree, partialTree } = getTestTrees(10, [1, 2, 3, 4, 5], 3) + partialTree.update(5, 6) + fullTree.update(5, 6) + should().equal(fullTree.root, partialTree.root) + }) + + it('should fail to update incorrect index', () => { + const { partialTree } = getTestTrees(10, [1, 2, 3, 4, 5], 4) + should().throw((() => partialTree.update(-1, 42)), 'Insert index out of bounds: -1') + should().throw((() => partialTree.update(6, 42)), 'Insert index out of bounds: 6') + should().throw((() => partialTree.update(2, 42)), 'Index 2 is below the edge: 3') + // @ts-ignore + should().throw((() => partialTree.update('qwe', 42)), 'Insert index out of bounds: qwe') + }) + + it('should fail to update over capacity', () => { + const { partialTree } = getTestTrees(2, [1, 2, 3, 4], 2) + const call = () => partialTree.update(4, 42) + should().throw(call, 'Insert index out of bounds: 4') + }) + }) describe('#indexOf', () => { it('should return same result as full tree', () => { const { fullTree, partialTree } = getTestTrees(10, [1, 2, 3, 4, 5, 6, 7, 8], 4) @@ -112,21 +173,34 @@ describe('PartialMerkleTree', () => { }) it('should return same elements count as full tree', () => { - const { fullTree, partialTree } = getTestTrees(10, [1, 2, 3, 4, 5], 3) + const levels = 20 + const capacity = levels ** 2 + const elements = Array.from({ length: capacity }, (_, i) => i) + const { fullTree, partialTree } = getTestTrees(levels, elements, 200) should().equal(partialTree.elements.length, fullTree.elements.length) }) - it('should return same layers count as full tree', () => { - const { fullTree, partialTree } = getTestTrees(10, [1, 2, 3, 4, 5], 3) - should().equal(partialTree.layers.length, fullTree.layers.length) + it('should return copy of layers', () => { + const { partialTree } = getTestTrees(10, [1, 2, 3, 4, 5], 3) + const layers = partialTree.layers + should().not.equal(layers, partialTree.layers) + }) + + it('should return copy of zeros', () => { + const { partialTree } = getTestTrees(10, [1, 2, 3, 4, 5], 3) + const zeros = partialTree.zeros + should().not.equal(zeros, partialTree.zeros) }) }) describe('#path', () => { it('should return path for known nodes', () => { - const { fullTree, partialTree } = getTestTrees(10, [1, 2, 3, 4, 5, 6, 7, 8, 9], 5) - assert.deepEqual(fullTree.path(4), partialTree.path(4)) + const levels = 20 + const capacity = levels ** 2 + const elements = Array.from({ length: capacity }, (_, i) => i) + const { fullTree, partialTree } = getTestTrees(levels, elements, 250) + assert.deepEqual(fullTree.path(250), partialTree.path(250)) }) it('should fail on incorrect index', () => { @@ -142,5 +216,32 @@ describe('PartialMerkleTree', () => { should().throw(call, 'Index 2 is below the edge: 4') }) }) + + describe('#serialize', () => { + it('should work', () => { + const { partialTree } = getTestTrees(5, [1, 2, 3, 4, 5, 6, 7, 8, 9], 6) + const data = partialTree.serialize() + const dst = PartialMerkleTree.deserialize(data) + should().equal(partialTree.root, dst.root) + + partialTree.insert(10) + dst.insert(10) + + should().equal(partialTree.root, dst.root) + }) + }) + describe('#toString', () => { + it('should return correct stringified representation', () => { + const { partialTree } = getTestTrees(5, [1, 2, 3, 4, 5, 6, 7, 8, 9], 6) + const str = partialTree.toString() + const dst = PartialMerkleTree.deserialize(JSON.parse(str)) + should().equal(partialTree.root, dst.root) + + partialTree.insert(10) + dst.insert(10) + + should().equal(partialTree.root, dst.root) + }) + }) })