From f99eb4bd1ed70f1bb40d7d16d0ca28db753b92e9 Mon Sep 17 00:00:00 2001 From: poma Date: Fri, 13 Aug 2021 20:07:53 +0300 Subject: [PATCH] refactor extAmount and fee into a single public input --- circuits/transaction.circom | 12 ++++-------- contracts/TornadoPool.sol | 25 +++++++++++++------------ src/index.js | 8 ++++---- src/utils.js | 8 ++++++-- 4 files changed, 27 insertions(+), 26 deletions(-) diff --git a/circuits/transaction.circom b/circuits/transaction.circom index 45b59c6..bbd47d0 100644 --- a/circuits/transaction.circom +++ b/circuits/transaction.circom @@ -18,10 +18,10 @@ nullifier = hash(commitment, privKey, merklePath) template Transaction(levels, nIns, nOuts, zeroLeaf) { signal input root; signal input newRoot; - // external amount used for deposits and withdrawals + // extAmount = external amount used for deposits and withdrawals // correct extAmount range is enforced on the smart contract - signal input extAmount; - signal input fee; + // publicAmount = fee - extAmount + signal input publicAmount; signal input extDataHash; // data for transaction inputs @@ -103,10 +103,6 @@ template Transaction(levels, nIns, nOuts, zeroLeaf) { sumOuts += outAmount[tx]; } - // Check that fee fits into 248 bits to prevent overflow - component feeCheck = Num2Bits(248); - feeCheck.in <== fee; - // check that there are no same nullifiers among all inputs component sameNullifiers[nIns * (nIns - 1) / 2]; var index = 0; @@ -121,7 +117,7 @@ template Transaction(levels, nIns, nOuts, zeroLeaf) { } // verify amount invariant - sumIns + extAmount === sumOuts + fee; + sumIns === sumOuts + publicAmount; // Check merkle tree update with inserted transaction outputs component treeUpdater = TreeUpdater(levels, 1 /* log2(nOuts) */, zeroLeaf); diff --git a/contracts/TornadoPool.sol b/contracts/TornadoPool.sol index 1f9f164..01f589c 100644 --- a/contracts/TornadoPool.sol +++ b/contracts/TornadoPool.sol @@ -14,9 +14,9 @@ pragma solidity ^0.7.0; pragma experimental ABIEncoderV2; interface IVerifier { - function verifyProof(bytes memory _proof, uint256[10] memory _input) external view returns (bool); + function verifyProof(bytes memory _proof, uint256[9] memory _input) external view returns (bool); - function verifyProof(bytes memory _proof, uint256[24] memory _input) external view returns (bool); + function verifyProof(bytes memory _proof, uint256[23] memory _input) external view returns (bool); } contract TornadoPool { @@ -31,7 +31,9 @@ contract TornadoPool { struct ExtData { address payable recipient; + uint256 extAmount; address payable relayer; + uint256 fee; bytes encryptedOutput1; bytes encryptedOutput2; } @@ -43,8 +45,7 @@ contract TornadoPool { bytes32[] inputNullifiers; bytes32[2] outputCommitments; uint256 outPathIndices; - uint256 extAmount; - uint256 fee; + uint256 publicAmount; bytes32 extDataHash; } @@ -81,6 +82,8 @@ contract TornadoPool { require(uint256(_args.extDataHash) == uint256(keccak256(abi.encode(_extData))) % FIELD_SIZE, "Incorrect external data hash"); uint256 cachedCommitmentIndex = currentCommitmentIndex; require(_args.outPathIndices == cachedCommitmentIndex >> 1, "Invalid merkle tree insert position"); + require(_extData.fee < 2**248, "Invalid fee"); + require((_args.publicAmount + _extData.extAmount) % FIELD_SIZE == _extData.fee % FIELD_SIZE, "Invalid public amount"); require(verifyProof(_args), "Invalid transaction proof"); currentRoot = _args.newRoot; @@ -89,9 +92,9 @@ contract TornadoPool { nullifierHashes[_args.inputNullifiers[i]] = true; } - int256 extAmount = calculateExternalAmount(_args.extAmount); + int256 extAmount = calculateExternalAmount(_extData.extAmount); if (extAmount > 0) { - require(msg.value == uint256(_args.extAmount), "Incorrect amount of ETH sent on deposit"); + require(msg.value == uint256(_extData.extAmount), "Incorrect amount of ETH sent on deposit"); } else if (extAmount < 0) { require(msg.value == 0, "Sent ETH amount should be 0 for withdrawal"); require(_extData.recipient != address(0), "Can't withdraw to zero address"); @@ -100,8 +103,8 @@ contract TornadoPool { require(msg.value == 0, "Sent ETH amount should be 0 for transaction"); } - if (_args.fee > 0) { - _extData.relayer.transfer(_args.fee); + if (_extData.fee > 0) { + _extData.relayer.transfer(_extData.fee); } emit NewCommitment(_args.outputCommitments[0], cachedCommitmentIndex, _extData.encryptedOutput1); @@ -136,8 +139,7 @@ contract TornadoPool { [ uint256(_args.root), uint256(_args.newRoot), - _args.extAmount, - _args.fee, + _args.publicAmount, uint256(_args.extDataHash), uint256(_args.inputNullifiers[0]), uint256(_args.inputNullifiers[1]), @@ -153,8 +155,7 @@ contract TornadoPool { [ uint256(_args.root), uint256(_args.newRoot), - _args.extAmount, - _args.fee, + _args.publicAmount, uint256(_args.extDataHash), uint256(_args.inputNullifiers[0]), uint256(_args.inputNullifiers[1]), diff --git a/src/index.js b/src/index.js index c1d12c3..24f592b 100644 --- a/src/index.js +++ b/src/index.js @@ -48,7 +48,9 @@ async function getProof({ inputs, outputs, tree, extAmount, fee, recipient, rela const extData = { recipient: toFixedHex(recipient, 20), + extAmount: toFixedHex(extAmount), relayer: toFixedHex(relayer, 20), + fee: toFixedHex(fee), encryptedOutput1: outputs[0].encrypt(), encryptedOutput2: outputs[1].encrypt(), } @@ -59,8 +61,7 @@ async function getProof({ inputs, outputs, tree, extAmount, fee, recipient, rela newRoot: tree.root(), inputNullifier: inputs.map((x) => x.getNullifier()), outputCommitment: outputs.map((x) => x.getCommitment()), - extAmount, - fee, + publicAmount: BigNumber.from(fee).sub(extAmount).add(FIELD_SIZE).mod(FIELD_SIZE).toString(), extDataHash, // data for 2 transaction inputs @@ -87,8 +88,7 @@ async function getProof({ inputs, outputs, tree, extAmount, fee, recipient, rela inputNullifiers: inputs.map((x) => toFixedHex(x.getNullifier())), outputCommitments: outputs.map((x) => toFixedHex(x.getCommitment())), outPathIndices: toFixedHex(outputIndex >> outputBatchBits), - extAmount: toFixedHex(extAmount), - fee: toFixedHex(fee), + publicAmount: toFixedHex(input.publicAmount), extDataHash: toFixedHex(extDataHash), } // console.log('Solidity args', args) diff --git a/src/utils.js b/src/utils.js index 45fb17e..f237031 100644 --- a/src/utils.js +++ b/src/utils.js @@ -13,15 +13,19 @@ const FIELD_SIZE = BigNumber.from( /** Generate random number of specified byte length */ const randomBN = (nbytes = 31) => BigNumber.from(crypto.randomBytes(nbytes)) -function getExtDataHash({ recipient, relayer, encryptedOutput1, encryptedOutput2 }) { +function getExtDataHash({ recipient, extAmount, relayer, fee, encryptedOutput1, encryptedOutput2 }) { const abi = new ethers.utils.AbiCoder() const encodedData = abi.encode( - ['tuple(address recipient,address relayer,bytes encryptedOutput1,bytes encryptedOutput2)'], + [ + 'tuple(address recipient,uint256 extAmount,address relayer,uint256 fee,bytes encryptedOutput1,bytes encryptedOutput2)', + ], [ { recipient: toFixedHex(recipient, 20), + extAmount: toFixedHex(extAmount), relayer: toFixedHex(relayer, 20), + fee: toFixedHex(fee), encryptedOutput1: encryptedOutput1, encryptedOutput2: encryptedOutput2, },