diff --git a/src/fixedMerkleTree.ts b/src/FixedMerkleTree.ts similarity index 94% rename from src/fixedMerkleTree.ts rename to src/FixedMerkleTree.ts index def3a92..2405e8f 100644 --- a/src/fixedMerkleTree.ts +++ b/src/FixedMerkleTree.ts @@ -1,21 +1,8 @@ -import { - defaultHash, - Element, - HashFunction, - MerkleTreeOptions, - ProofPath, - SerializedTreeState, - TreeEdge, -} from './' +import { defaultHash, Element, HashFunction, MerkleTreeOptions, ProofPath, SerializedTreeState, TreeEdge } from './' export default class MerkleTree { - get layers(): Array { - return this._layers.slice() - } - levels: number - capacity: number private _hashFn: HashFunction private zeroElement: Element private _zeros: Element[] @@ -26,7 +13,6 @@ export default class MerkleTree { zeroElement = 0, }: MerkleTreeOptions = {}) { this.levels = levels - this.capacity = 2 ** levels if (elements.length > this.capacity) { throw new Error('Tree is full') } @@ -39,6 +25,22 @@ export default class MerkleTree { this._rebuild() } + get capacity() { + return this.levels ** 2 + } + + get layers(): Array { + return this._layers.slice() + } + + get zeros(): Element[] { + return this._zeros.slice() + } + + get elements(): Element[] { + return this._layers[0].slice() + } + private _buildZeros() { this._zeros = [this.zeroElement] for (let i = 1; i <= this.levels; i++) { @@ -63,7 +65,7 @@ export default class MerkleTree { /** * Get tree root */ - root(): Element { + get root(): Element { return this._layers[this.levels][0] ?? this._zeros[this.levels] } @@ -180,26 +182,12 @@ export default class MerkleTree { const leaves = this._layers[0] const edgeIndex = leaves.indexOf(edgeElement) if (edgeIndex <= -1) { - return null + throw new Error('Element not found') } const edgePath = this.path(edgeIndex) return { edgePath, edgeElement, edgeIndex } } - /** - * Returns a copy of non-zero tree elements. - */ - get elements() { - return this._layers[0].slice() - } - - /** - * Returns a copy of n-th zero elements array - */ - get zeros() { - return this._zeros.slice() - } - /** * Serialize entire tree state including intermediate layers into a plain object * Deserializing it back will not require to recompute any hashes diff --git a/src/PartialMerkleTree.ts b/src/PartialMerkleTree.ts new file mode 100644 index 0000000..6422e85 --- /dev/null +++ b/src/PartialMerkleTree.ts @@ -0,0 +1,194 @@ +import { defaultHash, Element, HashFunction, MerkleTreeOptions, ProofPath, TreeEdge } from './' + +type LeafWithIndex = { index: number, data: Element } + +export class PartialMerkleTree { + levels: number + private zeroElement: Element + private _zeros: Element[] + private _layers: Array + private _leaves: Element[] + private _leavesAfterEdge: Element[] + private _edgeLeaf: LeafWithIndex + private _initialRoot: Element + private _hashFn: HashFunction + private _edgeLeafProof: ProofPath + + constructor(levels: number, { + edgePath, + edgeElement, + edgeIndex, + }: TreeEdge, leaves: Element[], root: Element, { hashFunction, zeroElement }: MerkleTreeOptions = {}) { + this._edgeLeafProof = edgePath + this.zeroElement = zeroElement ?? 0 + this._edgeLeaf = { data: edgeElement, index: edgeIndex } + this._leavesAfterEdge = leaves + this._initialRoot = root + this.levels = levels + this._hashFn = hashFunction || defaultHash + this._buildTree() + } + + get capacity() { + return this.levels ** 2 + } + + get layers(): Array { + return this._layers.slice() + } + + get zeros(): Element[] { + return this._zeros.slice() + } + + get elements(): Element[] { + return this._layers[0].slice() + } + + private _buildTree(): void { + const edgeLeafIndex = this._edgeLeaf.index + this._leaves = [...Array.from({ length: edgeLeafIndex }, () => null), ...this._leavesAfterEdge] + if (this._edgeLeafProof.pathIndices[0] === 1) { + this._leaves[this._edgeLeafProof.pathPositions[0]] = this._edgeLeafProof.pathElements[0] + } + this._layers = [this._leaves] + this._buildZeros() + this._buildHashes() + + } + + private _buildZeros() { + this._zeros = [this.zeroElement] + for (let i = 1; i <= this.levels; i++) { + this._zeros[i] = this._hashFn(this._zeros[i - 1], this._zeros[i - 1]) + } + } + + _buildHashes() { + for (let level = 1; level <= this.levels; level++) { + this._layers[level] = [] + for (let i = 0; i < Math.ceil(this._layers[level - 1].length / 2); i++) { + const left = this._layers[level - 1][i * 2] + const right = i * 2 + 1 < this._layers[level - 1].length + ? this._layers[level - 1][i * 2 + 1] + : this._zeros[level - 1] + let hash: Element = this._hashFn(left, right) + if (!hash && this._edgeLeafProof.pathPositions[level] === i) hash = this._edgeLeafProof.pathElements[level] + if (level === this.levels) hash = hash || this._initialRoot + this._layers[level][i] = hash + } + } + } + + /** + * Insert new element into the tree + * @param element Element to insert + */ + insert(element: Element) { + if (this._layers[0].length >= this.capacity) { + throw new Error('Tree is full') + } + this.update(this._layers[0].length, element) + } + + /** + * Insert multiple elements into the tree. + * @param {Array} elements Elements to insert + */ + bulkInsert(elements: Element[]): void { + if (!elements.length) { + return + } + + if (this._layers[0].length + elements.length > this.capacity) { + throw new Error('Tree is full') + } + // First we insert all elements except the last one + // updating only full subtree hashes (all layers where inserted element has odd index) + // the last element will update the full path to the root making the tree consistent again + for (let i = 0; i < elements.length - 1; i++) { + this._layers[0].push(elements[i]) + let level = 0 + let index = this._layers[0].length - 1 + while (index % 2 === 1) { + level++ + index >>= 1 + this._layers[level][index] = this._hashFn( + this._layers[level - 1][index * 2], + this._layers[level - 1][index * 2 + 1], + ) + } + } + this.insert(elements[elements.length - 1]) + } + + /** + * Change an element in the tree + * @param {number} index Index of element to change + * @param element Updated element value + */ + update(index: number, element: Element) { + if (isNaN(Number(index)) || index < 0 || index > this._layers[0].length || index >= this.capacity) { + throw new Error('Insert index out of bounds: ' + index) + } + this._layers[0][index] = element + for (let level = 1; level <= this.levels; level++) { + index >>= 1 + const left = this._layers[level - 1][index * 2] + const right = index * 2 + 1 < this._layers[level - 1].length + ? this._layers[level - 1][index * 2 + 1] + : this._zeros[level - 1] + let hash: Element = this._hashFn(left, right) + if (!hash && this._edgeLeafProof.pathPositions[level] === index) { + hash = this._edgeLeafProof.pathElements[level] + } + if (level === this.levels) { + hash = hash || this._initialRoot + } + // console.log({ index, level, left, right, hash }) + this._layers[level][index] = hash + } + } + + path(index: Element): ProofPath { + if (isNaN(Number(index)) || index < 0 || index >= this._layers[0].length) { + throw new Error('Index out of bounds: ' + index) + } + let elIndex = +index + const pathElements: Element[] = [] + const pathIndices: number[] = [] + const pathPositions: number [] = [] + for (let level = 0; level < this.levels; level++) { + pathIndices[level] = elIndex % 2 + const leafIndex = elIndex ^ 1 + if (leafIndex < this._layers[level].length) { + pathElements[level] = this._layers[level][leafIndex] + pathPositions[level] = leafIndex + } else { + pathElements[level] = this._zeros[level] + pathPositions[level] = 0 + } + elIndex >>= 1 + } + return { + pathElements, + pathIndices, + pathPositions, + } + } + + indexOf(element: Element, comparator?: (arg0: T, arg1: T) => boolean): number { + if (comparator) { + return this._layers[0].findIndex((el) => comparator(element, el)) + } else { + return this._layers[0].indexOf(element) + } + } + + /** + * Get tree root + */ + get root(): Element { + return this._layers[this.levels][0] ?? this._zeros[this.levels] + } +} diff --git a/src/index.ts b/src/index.ts index b331226..53af538 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,7 +1,7 @@ import { simpleHash } from './simpleHash' -export { default as MerkleTree } from './fixedMerkleTree' -export { PartialMerkleTree } from './partialMerkleTree' +export { default as MerkleTree } from './FixedMerkleTree' +export { PartialMerkleTree } from './PartialMerkleTree' export { simpleHash } from './simpleHash' export type HashFunction = { @@ -31,4 +31,4 @@ export type TreeEdge = { edgePath: ProofPath; edgeIndex: number } -export const defaultHash = (left: Element, right: Element): string => simpleHash([left, right]) +export const defaultHash = (left: Element, right: Element): string => (left !== null && right !== null) ? simpleHash([left, right]) : null diff --git a/src/partialMerkleTree.ts b/src/partialMerkleTree.ts deleted file mode 100644 index d5e4092..0000000 --- a/src/partialMerkleTree.ts +++ /dev/null @@ -1,64 +0,0 @@ -import { defaultHash, Element, HashFunction, MerkleTreeOptions, ProofPath, TreeEdge } from './' - -type LeafWithIndex = { index: number, data: Element } - -export class PartialMerkleTree { - levels: number - private zeroElement: Element - private _zeros: Element[] - private _layers: Array - private _leaves: Element[] - private _leavesAfterEdge: Element[] - private _edgeLeaf: LeafWithIndex - private _root: Element - private _hashFn: HashFunction - private _edgeLeafProof: ProofPath - - constructor({ - edgePath, - edgeElement, - edgeIndex, - }: TreeEdge, leaves: Element[], root: Element, { hashFunction, zeroElement }: MerkleTreeOptions = {}) { - this._edgeLeafProof = edgePath - this.zeroElement = zeroElement ?? 0 - this._edgeLeaf = { data: edgeElement, index: edgeIndex } - this._leavesAfterEdge = leaves - this._root = root - this._hashFn = hashFunction || defaultHash - this._buildTree() - } - - get capacity() { - return this.levels ** 2 - } - - private _buildTree(): void { - const edgeLeafIndex = this._edgeLeaf.index - this._leaves = [...Array.from({ length: edgeLeafIndex - 1 }, () => null), ...this._leavesAfterEdge] - this._layers = [this._leaves] - this._buildZeros() - this._rebuild() - - } - - private _buildZeros() { - this._zeros = [this.zeroElement] - for (let i = 1; i <= this.levels; i++) { - this._zeros[i] = this._hashFn(this._zeros[i - 1], this._zeros[i - 1]) - } - } - - _rebuild() { - for (let level = 1; level <= this.levels; level++) { - this._layers[level] = [] - for (let i = 0; i < Math.ceil(this._layers[level - 1].length / 2); i++) { - this._layers[level][i] = this._hashFn( - this._layers[level - 1][i * 2], - i * 2 + 1 < this._layers[level - 1].length - ? this._layers[level - 1][i * 2 + 1] - : this._zeros[level - 1], - ) - } - } - } -} diff --git a/test/fixedMerkleTree.spec.ts b/test/fixedMerkleTree.spec.ts index 23bc6c3..16fdb0d 100644 --- a/test/fixedMerkleTree.spec.ts +++ b/test/fixedMerkleTree.spec.ts @@ -1,6 +1,7 @@ import { MerkleTree } from '../src' import { assert, should } from 'chai' import { it } from 'mocha' +import { TreeEdge } from '../lib' describe('MerkleTree', () => { @@ -8,22 +9,22 @@ describe('MerkleTree', () => { it('should have correct zero root', () => { const tree = new MerkleTree(10, []) - return should().equal(tree.root(), '3060353338620102847451617558650138132480') + return should().equal(tree.root, '3060353338620102847451617558650138132480') }) it('should have correct 1 element root', () => { const tree = new MerkleTree(10, [1]) - should().equal(tree.root(), '4059654748770657324723044385589999697920') + should().equal(tree.root, '4059654748770657324723044385589999697920') }) it('should have correct even elements root', () => { const tree = new MerkleTree(10, [1, 2]) - should().equal(tree.root(), '3715471817149864798706576217905179918336') + should().equal(tree.root, '3715471817149864798706576217905179918336') }) it('should have correct odd elements root', () => { const tree = new MerkleTree(10, [1, 2, 3]) - should().equal(tree.root(), '5199180210167621115778229238102210117632') + should().equal(tree.root, '5199180210167621115778229238102210117632') }) it('should be able to create a full tree', () => { @@ -40,19 +41,19 @@ describe('MerkleTree', () => { it('should insert into empty tree', () => { const tree = new MerkleTree(10) tree.insert(42) - should().equal(tree.root(), '750572848877730275626358141391262973952') + should().equal(tree.root, '750572848877730275626358141391262973952') }) it('should insert into odd tree', () => { const tree = new MerkleTree(10, [1]) tree.insert(42) - should().equal(tree.root(), '5008383558940708447763798816817296703488') + should().equal(tree.root, '5008383558940708447763798816817296703488') }) it('should insert into even tree', () => { const tree = new MerkleTree(10, [1, 2]) tree.insert(42) - should().equal(tree.root(), '5005864318873356880627322373636156817408') + should().equal(tree.root, '5005864318873356880627322373636156817408') }) it('should insert last element', () => { @@ -71,7 +72,7 @@ describe('MerkleTree', () => { it('should work', () => { const tree = new MerkleTree(10, [1, 2, 3]) tree.bulkInsert([4, 5, 6]) - should().equal(tree.root(), '4066635800770511602067209448381558554624') + should().equal(tree.root, '4066635800770511602067209448381558554624') }) it('should give the same result as sequential inserts', () => { @@ -95,7 +96,7 @@ describe('MerkleTree', () => { for (const item of inserted) { tree2.insert(item) } - should().equal(tree1.root(), tree2.root()) + should().equal(tree1.root, tree2.root) } } }).timeout(10000) @@ -122,31 +123,31 @@ describe('MerkleTree', () => { it('should update first element', () => { const tree = new MerkleTree(10, [1, 2, 3, 4, 5]) tree.update(0, 42) - should().equal(tree.root(), '3884161948856565981263417078389340635136') + should().equal(tree.root, '3884161948856565981263417078389340635136') }) it('should update last element', () => { const tree = new MerkleTree(10, [1, 2, 3, 4, 5]) tree.update(4, 42) - should().equal(tree.root(), '3564959811529894228734180300843252711424') + should().equal(tree.root, '3564959811529894228734180300843252711424') }) it('should update odd element', () => { const tree = new MerkleTree(10, [1, 2, 3, 4, 5]) tree.update(1, 42) - should().equal(tree.root(), '4576704573778433422699674477203122290688') + should().equal(tree.root, '4576704573778433422699674477203122290688') }) it('should update even element', () => { const tree = new MerkleTree(10, [1, 2, 3, 4, 5]) tree.update(2, 42) - should().equal(tree.root(), '1807994110952186123819489133812038762496') + should().equal(tree.root, '1807994110952186123819489133812038762496') }) it('should update extra element', () => { const tree = new MerkleTree(10, [1, 2, 3, 4]) tree.update(4, 5) - should().equal(tree.root(), '1099080610107164849381389194938128793600') + should().equal(tree.root, '1099080610107164849381389194938128793600') }) it('should fail to update incorrect index', () => { @@ -248,8 +249,22 @@ describe('MerkleTree', () => { }) describe('#getTreeEdge', () => { it('should return correct treeEdge', () => { + const expectedEdge: TreeEdge = { + edgePath: { + pathElements: [ + 5, + '1390935134112885103361924701261056180224', + '1952916572242076545231119328171167580160', + '938972308169430750202858820582946897920' + ], + pathIndices: [ 0, 0, 1, 0 ], + pathPositions: [ 5, 0, 0, 0 ] + }, + edgeElement: 4, + edgeIndex: 4 + } const tree = new MerkleTree(4, [0, 1, 2, 3, 4, 5]) - console.log(tree.getTreeEdge(4)) + assert.deepEqual(tree.getTreeEdge(4), expectedEdge) }) }) describe('#getters', () => { @@ -319,12 +334,12 @@ describe('MerkleTree', () => { const src = new MerkleTree(10, [1, 2, 3, 4, 5, 6, 7, 8, 9]) const data = src.serialize() const dst = MerkleTree.deserialize(data) - should().equal(src.root(), dst.root()) + should().equal(src.root, dst.root) src.insert(10) dst.insert(10) - should().equal(src.root(), dst.root()) + should().equal(src.root, dst.root) }) }) }) diff --git a/test/partialMerkleTree.spec.ts b/test/partialMerkleTree.spec.ts index 794644b..5dd284a 100644 --- a/test/partialMerkleTree.spec.ts +++ b/test/partialMerkleTree.spec.ts @@ -1,17 +1,105 @@ -import { MerkleTree, PartialMerkleTree } from '../src' +import { Element, MerkleTree, PartialMerkleTree } from '../src' import { it } from 'mocha' +import { should } from 'chai' describe('PartialMerkleTree', () => { - + const getTestTrees = (levels: number, elements: Element[], edgeElement: Element) => { + const fullTree = new MerkleTree(levels, elements) + const edge = fullTree.getTreeEdge(edgeElement) + const leavesAfterEdge = elements.slice(edge.edgeIndex) + const partialTree = new PartialMerkleTree(levels, edge, leavesAfterEdge, fullTree.root) + return { fullTree, partialTree } + } describe('#constructor', () => { - const leaves = [1, 2, 3, 4, 5] - const fullTree = new MerkleTree(4, leaves) - const root = fullTree.root() - const edge = fullTree.getTreeEdge(3) - const leavesAfterEdge = leaves.splice(edge.edgeIndex) - it('should initialize merkle tree', () => { - const partialTree = new PartialMerkleTree(edge, leavesAfterEdge, root) - return true + const { fullTree, partialTree } = getTestTrees(20, ['0', '1', '2', '3', '4', '5'], '2') + it('should initialize merkle tree with same root', () => { + should().equal(fullTree.root, partialTree.root) + }) + + it('should initialize merkle tree with same leaves count', () => { + should().equal(fullTree.elements.length, partialTree.elements.length) + }) + }) + + describe('#insert', () => { + + it('should have equal root to full tree after insertion ', () => { + const { fullTree, partialTree } = getTestTrees(10, ['0', '1', '2', '3', '4', '5', '6', '7'], '5') + fullTree.insert('9') + partialTree.insert('9') + should().equal(fullTree.root, partialTree.root) + }) + + it('should fail to insert when tree is full', () => { + const { partialTree } = getTestTrees(3, ['0', '1', '2', '3', '4', '5', '6', '7', '8'], '5') + const call = () => partialTree.insert('9') + should().throw(call, 'Tree is full') + }) + }) + + describe('#bulkInsert', () => { + + it('should work like full tree', () => { + const { fullTree, partialTree } = getTestTrees(20, [1, 2, 3, 4, 5], 3) + partialTree.bulkInsert([6, 7, 8]) + fullTree.bulkInsert([6, 7, 8]) + should().equal(fullTree.root, partialTree.root) + }) + + it('should give the same result as sequential inserts', () => { + const initialArray = [ + [1], + [1, 2], + [1, 2, 3], + [1, 2, 3, 4], + ] + const insertedArray = [ + [11], + [11, 12], + [11, 12, 13], + [11, 12, 13, 14], + ] + for (const initial of initialArray) { + for (const inserted of insertedArray) { + const { partialTree: tree1 } = getTestTrees(10, initial, initial.length > 1 ? initial.length - 1 : initial.length) + const { partialTree: tree2 } = getTestTrees(10, initial, initial.length > 1 ? initial.length - 1 : initial.length) + tree1.bulkInsert(inserted) + for (const item of inserted) { + tree2.insert(item) + } + should().equal(tree1.root, tree2.root) + } + } + }).timeout(10000) + + it('should fail to insert too many elements', () => { + const { fullTree, partialTree } = getTestTrees(2, [1, 2, 3, 4], 3) + const call = () => partialTree.bulkInsert([5, 6, 7]) + should().throw(call, 'Tree is full') + }) + + }) + + describe('#indexOf', () => { + it('should return same result as full tree', () => { + const { fullTree, partialTree } = getTestTrees(10, [1, 2, 3, 4, 5, 6, 7, 8], 4) + should().equal(partialTree.indexOf(5), fullTree.indexOf(5)) + }) + + it('should find index', () => { + const { partialTree } = getTestTrees(10, [1, 2, 3, 4, 5], 3) + should().equal(partialTree.indexOf(3), 2) + }) + + it('should work with comparator', () => { + const { partialTree } = getTestTrees(10, [1, 2, 3, 4, 5], 3) + should().equal(partialTree.indexOf(4, (arg0, arg1) => arg0 === arg1), 3) + }) + + it('should return -1 for non existent element', () => { + const { partialTree } = getTestTrees(10, [1, 2, 3, 4, 5], 3) + should().equal(partialTree.indexOf(42), -1) }) }) }) +