diff --git a/src/PartialMerkleTree.ts b/src/PartialMerkleTree.ts index fdf5b14..4f6eccf 100644 --- a/src/PartialMerkleTree.ts +++ b/src/PartialMerkleTree.ts @@ -60,6 +60,14 @@ export class PartialMerkleTree { return this._layers[this.levels][0] ?? this._zeros[this.levels] } + get edgeIndex(): number { + return this._edgeLeaf.index + } + + get edgeElement(): Element { + return this._edgeLeaf.data + } + private _buildTree(): void { const edgeLeafIndex = this._edgeLeaf.index this._leaves = [...Array.from({ length: edgeLeafIndex }, () => null), ...this._leavesAfterEdge] @@ -69,7 +77,6 @@ export class PartialMerkleTree { this._layers = [this._leaves] this._buildZeros() this._buildHashes() - } private _buildZeros() { @@ -210,6 +217,25 @@ export class PartialMerkleTree { return this.path(index) } + /** + * Shifts edge of tree to left + * @param edge new TreeEdge below current edge + * @param elements leaves between old and new edge + */ + + shiftEdge(edge: TreeEdge, elements: Element[]) { + if (this._edgeLeaf.index <= edge.edgeIndex) { + throw new Error(`New edgeIndex should be smaller then ${this._edgeLeaf.index}`) + } + if (elements.length !== (this._edgeLeaf.index - edge.edgeIndex)) { + throw new Error(`Elements length should be ${elements.length}`) + } + this._edgeLeafProof = edge.edgePath + this._edgeLeaf = { index: edge.edgeIndex, data: edge.edgeElement } + this._leavesAfterEdge = [...elements, ...this._leavesAfterEdge] + this._buildTree() + } + serialize(): SerializedPartialTreeState { const leaves = this.layers[0].slice(this._edgeLeaf.index) return { diff --git a/test/partialMerkleTree.spec.ts b/test/partialMerkleTree.spec.ts index cd4ea57..83f61f7 100644 --- a/test/partialMerkleTree.spec.ts +++ b/test/partialMerkleTree.spec.ts @@ -224,7 +224,31 @@ describe('PartialMerkleTree', () => { should().throw(call, 'Index 2 is below the edge: 4') }) }) - + describe('#shiftEdge', () => { + it('should work', () => { + const levels = 20 + const elements: Element[] = Array.from({ length: levels ** 2 }, (_, i) => i) + const tree = new MerkleTree(levels, elements) + const edge1 = tree.getTreeEdge(200) + const edge2 = tree.getTreeEdge(100) + const partialTree1 = new PartialMerkleTree(levels, edge1, elements.slice(edge1.edgeIndex), tree.root) + const partialTree2 = new PartialMerkleTree(levels, edge2, elements.slice(edge2.edgeIndex), tree.root) + partialTree1.shiftEdge(edge2, elements.slice(edge2.edgeIndex, partialTree1.edgeIndex)) + assert.deepEqual(partialTree1.path(105), partialTree2.path(105)) + }) + it('should fail if new edge index is over current edge', () => { + const { fullTree, partialTree } = getTestTrees(10, [1, 2, 3, 4, 5, 6, 7, 8, 9], 5) + const newEdge = fullTree.getTreeEdge(6) + const call = () => partialTree.shiftEdge(newEdge, [1, 2]) + should().throw(call, 'New edgeIndex should be smaller then 4') + }) + it('should fail if elements length are incorrect', () => { + const { fullTree, partialTree } = getTestTrees(10, [1, 2, 3, 4, 5, 6, 7, 8, 9], 5) + const newEdge = fullTree.getTreeEdge(4) + const call = () => partialTree.shiftEdge(newEdge, [1, 2]) + should().throw(call, 'Elements length should be 2') + }) + }) describe('#serialize', () => { it('should work', () => { const { partialTree } = getTestTrees(5, [1, 2, 3, 4, 5, 6, 7, 8, 9], 6)