From bc533ede2df78c567c70ec3e31740fa71ed5368c Mon Sep 17 00:00:00 2001 From: smart_ex Date: Fri, 11 Mar 2022 15:37:20 +1000 Subject: [PATCH] check elements length in constructor --- src/FixedMerkleTree.ts | 4 ++-- src/PartialMerkleTree.ts | 9 +++++---- src/index.ts | 9 +++++---- test/fixedMerkleTree.spec.ts | 34 +++++++++++++++++++++++++++++++++- 4 files changed, 45 insertions(+), 11 deletions(-) diff --git a/src/FixedMerkleTree.ts b/src/FixedMerkleTree.ts index 04f10ae..cd067f9 100644 --- a/src/FixedMerkleTree.ts +++ b/src/FixedMerkleTree.ts @@ -192,7 +192,7 @@ export default class MerkleTree { throw new Error('Element not found') } const edgePath = this.path(edgeIndex) - return { edgePath, edgeElement, edgeIndex } + return { edgePath, edgeElement, edgeIndex, edgeElementsCount: this._layers[0].length } } /** @@ -202,7 +202,7 @@ export default class MerkleTree { getTreeSlices(count = 4): { edge: TreeEdge, elements: Element[] }[] { const length = this._layers[0].length let size = Math.ceil(length / count) - size % 2 && size++ + if (size % 2) size++ const slices = [] for (let i = 0; i < length; i += size) { const edgeLeft = i diff --git a/src/PartialMerkleTree.ts b/src/PartialMerkleTree.ts index 5ee17c4..cc2bfe4 100644 --- a/src/PartialMerkleTree.ts +++ b/src/PartialMerkleTree.ts @@ -32,16 +32,16 @@ export class PartialMerkleTree { edgePath, edgeElement, edgeIndex, + edgeElementsCount, }: TreeEdge, leaves: Element[], { hashFunction, zeroElement }: MerkleTreeOptions = {}) { - hashFunction = hashFunction || defaultHash - const hashFn = (left, right) => (left !== undefined && right !== undefined) ? hashFunction(left, right) : undefined + if (edgeIndex + leaves.length !== edgeElementsCount) throw new Error('Invalid number of elements') this._edgeLeafProof = edgePath this._initialRoot = edgePath.pathRoot this.zeroElement = zeroElement ?? 0 this._edgeLeaf = { data: edgeElement, index: edgeIndex } this._leavesAfterEdge = leaves this.levels = levels - this._hashFn = hashFn + this._hashFn = hashFunction || defaultHash this._createProofMap() this._buildTree() } @@ -257,9 +257,9 @@ export class PartialMerkleTree { serialize(): SerializedPartialTreeState { const leaves = this.layers[0].slice(this._edgeLeaf.index) return { - _initialRoot: this._initialRoot, _edgeLeafProof: this._edgeLeafProof, _edgeLeaf: this._edgeLeaf, + _edgeElementsCount: this._layers[0].length, levels: this.levels, leaves, _zeros: this._zeros, @@ -271,6 +271,7 @@ export class PartialMerkleTree { edgePath: data._edgeLeafProof, edgeElement: data._edgeLeaf.data, edgeIndex: data._edgeLeaf.index, + edgeElementsCount: data._edgeElementsCount, } return new PartialMerkleTree(data.levels, edge, data.leaves, { hashFunction, diff --git a/src/index.ts b/src/index.ts index cd7c219..9794177 100644 --- a/src/index.ts +++ b/src/index.ts @@ -20,11 +20,11 @@ export type SerializedTreeState = { } export type SerializedPartialTreeState = { - levels: number, + levels: number leaves: Element[] - _zeros: Array, - _edgeLeafProof: ProofPath, - _initialRoot: Element, + _edgeElementsCount: number + _zeros: Array + _edgeLeafProof: ProofPath _edgeLeaf: LeafWithIndex } @@ -38,6 +38,7 @@ export type TreeEdge = { edgeElement: Element; edgePath: ProofPath; edgeIndex: number; + edgeElementsCount: number; } export type LeafWithIndex = { index: number, data: Element } diff --git a/test/fixedMerkleTree.spec.ts b/test/fixedMerkleTree.spec.ts index 4e3c96b..7dc42b8 100644 --- a/test/fixedMerkleTree.spec.ts +++ b/test/fixedMerkleTree.spec.ts @@ -1,4 +1,4 @@ -import { MerkleTree, TreeEdge } from '../src' +import { MerkleTree, PartialMerkleTree, TreeEdge } from '../src' import { assert, should } from 'chai' import { buildMimcSponge } from 'circomlibjs' import { createHash } from 'crypto' @@ -292,6 +292,7 @@ describe('MerkleTree', () => { }, edgeElement: 4, edgeIndex: 4, + edgeElementsCount: 6, } const tree = new MerkleTree(4, [0, 1, 2, 3, 4, 5]) assert.deepEqual(tree.getTreeEdge(4), expectedEdge) @@ -302,7 +303,38 @@ describe('MerkleTree', () => { should().throw(call, 'Element not found') }) }) + describe('#getTreeSlices', () => { + let fullTree: MerkleTree + before(() => { + const elements = Array.from({ length: 128 }, (_, i) => i) + fullTree = new MerkleTree(10, elements) + }) + it('should return correct slices count', () => { + const count = 5 + const slicesCount = fullTree.getTreeSlices(5).length + should().equal(count, slicesCount) + }) + it('should be able to create partial tree from last slice', () => { + const lastSlice = fullTree.getTreeSlices().pop() + const partialTree = new PartialMerkleTree(10, lastSlice.edge, lastSlice.elements) + should().equal(partialTree.root, fullTree.root) + }) + + it('should be able to build full tree from slices', () => { + const slices = fullTree.getTreeSlices() + const lastSlice = slices.pop() + const partialTree = new PartialMerkleTree(10, lastSlice.edge, lastSlice.elements) + slices.reverse().forEach(({ edge, elements }) => partialTree.shiftEdge(edge, elements)) + assert.deepEqual(partialTree.layers, fullTree.layers) + }) + + it('should throw if invalid number of elements', () => { + const [firstSlice] = fullTree.getTreeSlices() + const call = () => new PartialMerkleTree(10, firstSlice.edge, firstSlice.elements) + should().throw(call, 'Invalid number of elements') + }) + }) describe('#getters', () => { const elements = [1, 2, 3, 4, 5] const layers = [