refactor extAmount and fee into a single public input

This commit is contained in:
poma 2021-08-13 20:07:53 +03:00
parent 476668d250
commit f99eb4bd1e
No known key found for this signature in database
GPG Key ID: BA20CB01FE165657
4 changed files with 27 additions and 26 deletions

View File

@ -18,10 +18,10 @@ nullifier = hash(commitment, privKey, merklePath)
template Transaction(levels, nIns, nOuts, zeroLeaf) { template Transaction(levels, nIns, nOuts, zeroLeaf) {
signal input root; signal input root;
signal input newRoot; signal input newRoot;
// external amount used for deposits and withdrawals // extAmount = external amount used for deposits and withdrawals
// correct extAmount range is enforced on the smart contract // correct extAmount range is enforced on the smart contract
signal input extAmount; // publicAmount = fee - extAmount
signal input fee; signal input publicAmount;
signal input extDataHash; signal input extDataHash;
// data for transaction inputs // data for transaction inputs
@ -103,10 +103,6 @@ template Transaction(levels, nIns, nOuts, zeroLeaf) {
sumOuts += outAmount[tx]; sumOuts += outAmount[tx];
} }
// Check that fee fits into 248 bits to prevent overflow
component feeCheck = Num2Bits(248);
feeCheck.in <== fee;
// check that there are no same nullifiers among all inputs // check that there are no same nullifiers among all inputs
component sameNullifiers[nIns * (nIns - 1) / 2]; component sameNullifiers[nIns * (nIns - 1) / 2];
var index = 0; var index = 0;
@ -121,7 +117,7 @@ template Transaction(levels, nIns, nOuts, zeroLeaf) {
} }
// verify amount invariant // verify amount invariant
sumIns + extAmount === sumOuts + fee; sumIns === sumOuts + publicAmount;
// Check merkle tree update with inserted transaction outputs // Check merkle tree update with inserted transaction outputs
component treeUpdater = TreeUpdater(levels, 1 /* log2(nOuts) */, zeroLeaf); component treeUpdater = TreeUpdater(levels, 1 /* log2(nOuts) */, zeroLeaf);

View File

@ -14,9 +14,9 @@ pragma solidity ^0.7.0;
pragma experimental ABIEncoderV2; pragma experimental ABIEncoderV2;
interface IVerifier { interface IVerifier {
function verifyProof(bytes memory _proof, uint256[10] memory _input) external view returns (bool); function verifyProof(bytes memory _proof, uint256[9] memory _input) external view returns (bool);
function verifyProof(bytes memory _proof, uint256[24] memory _input) external view returns (bool); function verifyProof(bytes memory _proof, uint256[23] memory _input) external view returns (bool);
} }
contract TornadoPool { contract TornadoPool {
@ -31,7 +31,9 @@ contract TornadoPool {
struct ExtData { struct ExtData {
address payable recipient; address payable recipient;
uint256 extAmount;
address payable relayer; address payable relayer;
uint256 fee;
bytes encryptedOutput1; bytes encryptedOutput1;
bytes encryptedOutput2; bytes encryptedOutput2;
} }
@ -43,8 +45,7 @@ contract TornadoPool {
bytes32[] inputNullifiers; bytes32[] inputNullifiers;
bytes32[2] outputCommitments; bytes32[2] outputCommitments;
uint256 outPathIndices; uint256 outPathIndices;
uint256 extAmount; uint256 publicAmount;
uint256 fee;
bytes32 extDataHash; bytes32 extDataHash;
} }
@ -81,6 +82,8 @@ contract TornadoPool {
require(uint256(_args.extDataHash) == uint256(keccak256(abi.encode(_extData))) % FIELD_SIZE, "Incorrect external data hash"); require(uint256(_args.extDataHash) == uint256(keccak256(abi.encode(_extData))) % FIELD_SIZE, "Incorrect external data hash");
uint256 cachedCommitmentIndex = currentCommitmentIndex; uint256 cachedCommitmentIndex = currentCommitmentIndex;
require(_args.outPathIndices == cachedCommitmentIndex >> 1, "Invalid merkle tree insert position"); require(_args.outPathIndices == cachedCommitmentIndex >> 1, "Invalid merkle tree insert position");
require(_extData.fee < 2**248, "Invalid fee");
require((_args.publicAmount + _extData.extAmount) % FIELD_SIZE == _extData.fee % FIELD_SIZE, "Invalid public amount");
require(verifyProof(_args), "Invalid transaction proof"); require(verifyProof(_args), "Invalid transaction proof");
currentRoot = _args.newRoot; currentRoot = _args.newRoot;
@ -89,9 +92,9 @@ contract TornadoPool {
nullifierHashes[_args.inputNullifiers[i]] = true; nullifierHashes[_args.inputNullifiers[i]] = true;
} }
int256 extAmount = calculateExternalAmount(_args.extAmount); int256 extAmount = calculateExternalAmount(_extData.extAmount);
if (extAmount > 0) { if (extAmount > 0) {
require(msg.value == uint256(_args.extAmount), "Incorrect amount of ETH sent on deposit"); require(msg.value == uint256(_extData.extAmount), "Incorrect amount of ETH sent on deposit");
} else if (extAmount < 0) { } else if (extAmount < 0) {
require(msg.value == 0, "Sent ETH amount should be 0 for withdrawal"); require(msg.value == 0, "Sent ETH amount should be 0 for withdrawal");
require(_extData.recipient != address(0), "Can't withdraw to zero address"); require(_extData.recipient != address(0), "Can't withdraw to zero address");
@ -100,8 +103,8 @@ contract TornadoPool {
require(msg.value == 0, "Sent ETH amount should be 0 for transaction"); require(msg.value == 0, "Sent ETH amount should be 0 for transaction");
} }
if (_args.fee > 0) { if (_extData.fee > 0) {
_extData.relayer.transfer(_args.fee); _extData.relayer.transfer(_extData.fee);
} }
emit NewCommitment(_args.outputCommitments[0], cachedCommitmentIndex, _extData.encryptedOutput1); emit NewCommitment(_args.outputCommitments[0], cachedCommitmentIndex, _extData.encryptedOutput1);
@ -136,8 +139,7 @@ contract TornadoPool {
[ [
uint256(_args.root), uint256(_args.root),
uint256(_args.newRoot), uint256(_args.newRoot),
_args.extAmount, _args.publicAmount,
_args.fee,
uint256(_args.extDataHash), uint256(_args.extDataHash),
uint256(_args.inputNullifiers[0]), uint256(_args.inputNullifiers[0]),
uint256(_args.inputNullifiers[1]), uint256(_args.inputNullifiers[1]),
@ -153,8 +155,7 @@ contract TornadoPool {
[ [
uint256(_args.root), uint256(_args.root),
uint256(_args.newRoot), uint256(_args.newRoot),
_args.extAmount, _args.publicAmount,
_args.fee,
uint256(_args.extDataHash), uint256(_args.extDataHash),
uint256(_args.inputNullifiers[0]), uint256(_args.inputNullifiers[0]),
uint256(_args.inputNullifiers[1]), uint256(_args.inputNullifiers[1]),

View File

@ -48,7 +48,9 @@ async function getProof({ inputs, outputs, tree, extAmount, fee, recipient, rela
const extData = { const extData = {
recipient: toFixedHex(recipient, 20), recipient: toFixedHex(recipient, 20),
extAmount: toFixedHex(extAmount),
relayer: toFixedHex(relayer, 20), relayer: toFixedHex(relayer, 20),
fee: toFixedHex(fee),
encryptedOutput1: outputs[0].encrypt(), encryptedOutput1: outputs[0].encrypt(),
encryptedOutput2: outputs[1].encrypt(), encryptedOutput2: outputs[1].encrypt(),
} }
@ -59,8 +61,7 @@ async function getProof({ inputs, outputs, tree, extAmount, fee, recipient, rela
newRoot: tree.root(), newRoot: tree.root(),
inputNullifier: inputs.map((x) => x.getNullifier()), inputNullifier: inputs.map((x) => x.getNullifier()),
outputCommitment: outputs.map((x) => x.getCommitment()), outputCommitment: outputs.map((x) => x.getCommitment()),
extAmount, publicAmount: BigNumber.from(fee).sub(extAmount).add(FIELD_SIZE).mod(FIELD_SIZE).toString(),
fee,
extDataHash, extDataHash,
// data for 2 transaction inputs // data for 2 transaction inputs
@ -87,8 +88,7 @@ async function getProof({ inputs, outputs, tree, extAmount, fee, recipient, rela
inputNullifiers: inputs.map((x) => toFixedHex(x.getNullifier())), inputNullifiers: inputs.map((x) => toFixedHex(x.getNullifier())),
outputCommitments: outputs.map((x) => toFixedHex(x.getCommitment())), outputCommitments: outputs.map((x) => toFixedHex(x.getCommitment())),
outPathIndices: toFixedHex(outputIndex >> outputBatchBits), outPathIndices: toFixedHex(outputIndex >> outputBatchBits),
extAmount: toFixedHex(extAmount), publicAmount: toFixedHex(input.publicAmount),
fee: toFixedHex(fee),
extDataHash: toFixedHex(extDataHash), extDataHash: toFixedHex(extDataHash),
} }
// console.log('Solidity args', args) // console.log('Solidity args', args)

View File

@ -13,15 +13,19 @@ const FIELD_SIZE = BigNumber.from(
/** Generate random number of specified byte length */ /** Generate random number of specified byte length */
const randomBN = (nbytes = 31) => BigNumber.from(crypto.randomBytes(nbytes)) const randomBN = (nbytes = 31) => BigNumber.from(crypto.randomBytes(nbytes))
function getExtDataHash({ recipient, relayer, encryptedOutput1, encryptedOutput2 }) { function getExtDataHash({ recipient, extAmount, relayer, fee, encryptedOutput1, encryptedOutput2 }) {
const abi = new ethers.utils.AbiCoder() const abi = new ethers.utils.AbiCoder()
const encodedData = abi.encode( const encodedData = abi.encode(
['tuple(address recipient,address relayer,bytes encryptedOutput1,bytes encryptedOutput2)'], [
'tuple(address recipient,uint256 extAmount,address relayer,uint256 fee,bytes encryptedOutput1,bytes encryptedOutput2)',
],
[ [
{ {
recipient: toFixedHex(recipient, 20), recipient: toFixedHex(recipient, 20),
extAmount: toFixedHex(extAmount),
relayer: toFixedHex(relayer, 20), relayer: toFixedHex(relayer, 20),
fee: toFixedHex(fee),
encryptedOutput1: encryptedOutput1, encryptedOutput1: encryptedOutput1,
encryptedOutput2: encryptedOutput2, encryptedOutput2: encryptedOutput2,
}, },