TC-106 tornadoPool refactoring

This commit is contained in:
Drygin 2022-07-21 19:18:49 +03:00
parent 2ad99530f3
commit 57a1db3b48
2 changed files with 55 additions and 42 deletions

View File

@ -25,8 +25,7 @@ import "./MerkleTreeWithHistory.sol";
* and withdrawal from the pool. Project utilizes UTXO model to handle users' funds. * and withdrawal from the pool. Project utilizes UTXO model to handle users' funds.
*/ */
contract TornadoPool is MerkleTreeWithHistory, IERC20Receiver, ReentrancyGuard, CrossChainGuard { contract TornadoPool is MerkleTreeWithHistory, IERC20Receiver, ReentrancyGuard, CrossChainGuard {
int256 public constant MAX_EXT_AMOUNT = 2**248; uint256 public constant MAX_FIELD_UINT = 2**248;
uint256 public constant MAX_FEE = 2**248;
IVerifier public immutable verifier2; IVerifier public immutable verifier2;
IVerifier public immutable verifier16; IVerifier public immutable verifier16;
@ -76,6 +75,11 @@ contract TornadoPool is MerkleTreeWithHistory, IERC20Receiver, ReentrancyGuard,
_; _;
} }
modifier updateLastBalance() {
_;
lastBalance = token.balanceOf(address(this));
}
/** /**
@dev The constructor @dev The constructor
@param _verifier2 the address of SNARK verifier for 2 inputs @param _verifier2 the address of SNARK verifier for 2 inputs
@ -134,10 +138,10 @@ contract TornadoPool is MerkleTreeWithHistory, IERC20Receiver, ReentrancyGuard,
/** @dev Function that allows public deposits without proof verification. /** @dev Function that allows public deposits without proof verification.
*/ */
function publicDeposit(bytes32 pubkey, uint256 depositAmount) public payable { function publicDeposit(bytes32 pubkey, uint256 depositAmount) public payable updateLastBalance {
require(depositAmount <= maximumDepositAmount, "amount is larger than maximumDepositAmount"); require(depositAmount <= maximumDepositAmount, "amount is larger than maximumDepositAmount");
// make sure that that limit the same as in transaction.circom output check // make sure that that limit the same as in transaction.circom output check
require(depositAmount < 2**248, "depositAmount should be inside the field"); require(depositAmount < MAX_FIELD_UINT, "depositAmount should be inside the field");
require(uint256(pubkey) < FIELD_SIZE, "pubkey should be inside the field"); require(uint256(pubkey) < FIELD_SIZE, "pubkey should be inside the field");
token.transferFrom(msg.sender, address(this), depositAmount); token.transferFrom(msg.sender, address(this), depositAmount);
@ -148,10 +152,9 @@ contract TornadoPool is MerkleTreeWithHistory, IERC20Receiver, ReentrancyGuard,
input[2] = bytes32(0); input[2] = bytes32(0);
bytes32 commitment = hasher3.poseidon(input); bytes32 commitment = hasher3.poseidon(input);
bytes memory packedOutput = abi.encodePacked("abi", depositAmount, pubkey);
lastBalance = token.balanceOf(address(this));
_insert(commitment, bytes32(ZERO_VALUE)); // use second empty commitment _insert(commitment, bytes32(ZERO_VALUE)); // use second empty commitment
bytes memory packedOutput = abi.encodePacked("abi", depositAmount, pubkey);
emit NewCommitment(commitment, nextIndex - 2, packedOutput); emit NewCommitment(commitment, nextIndex - 2, packedOutput);
emit NewCommitment(bytes32(ZERO_VALUE), nextIndex - 1, new bytes(0)); emit NewCommitment(bytes32(ZERO_VALUE), nextIndex - 1, new bytes(0));
} }
@ -223,8 +226,8 @@ contract TornadoPool is MerkleTreeWithHistory, IERC20Receiver, ReentrancyGuard,
} }
function calculatePublicAmount(int256 _extAmount, uint256 _fee) public pure returns (uint256) { function calculatePublicAmount(int256 _extAmount, uint256 _fee) public pure returns (uint256) {
require(_fee < MAX_FEE, "Invalid fee"); require(_fee < MAX_FIELD_UINT, "Invalid fee");
require(_extAmount > -MAX_EXT_AMOUNT && _extAmount < MAX_EXT_AMOUNT, "Invalid ext amount"); require(_extAmount > -int256(MAX_FIELD_UINT) && _extAmount < int256(MAX_FIELD_UINT), "Invalid ext amount");
int256 publicAmount = _extAmount - int256(_fee); int256 publicAmount = _extAmount - int256(_fee);
return (publicAmount >= 0) ? uint256(publicAmount) : FIELD_SIZE - uint256(-publicAmount); return (publicAmount >= 0) ? uint256(publicAmount) : FIELD_SIZE - uint256(-publicAmount);
} }
@ -286,7 +289,7 @@ contract TornadoPool is MerkleTreeWithHistory, IERC20Receiver, ReentrancyGuard,
emit PublicKey(_account.owner, _account.publicKey); emit PublicKey(_account.owner, _account.publicKey);
} }
function _transact(Proof memory _args, ExtData memory _extData) internal nonReentrant { function _transact(Proof memory _args, ExtData memory _extData) internal nonReentrant updateLastBalance {
require(isKnownRoot(_args.root), "Invalid merkle root"); require(isKnownRoot(_args.root), "Invalid merkle root");
for (uint256 i = 0; i < _args.inputNullifiers.length; i++) { for (uint256 i = 0; i < _args.inputNullifiers.length; i++) {
require(!isSpent(_args.inputNullifiers[i]), "Input is already spent"); require(!isSpent(_args.inputNullifiers[i]), "Input is already spent");
@ -298,40 +301,51 @@ contract TornadoPool is MerkleTreeWithHistory, IERC20Receiver, ReentrancyGuard,
for (uint256 i = 0; i < _args.inputNullifiers.length; i++) { for (uint256 i = 0; i < _args.inputNullifiers.length; i++) {
nullifierHashes[_args.inputNullifiers[i]] = true; nullifierHashes[_args.inputNullifiers[i]] = true;
} }
if (_extData.extAmount < 0) {
bool isWithdrawAndCall = _extData.withdrawalBytecode.length > 0;
require((_extData.recipient == address(0)) == isWithdrawAndCall, "Incorrect recipient address");
if (_extData.isL1Withdrawal) {
require(!isWithdrawAndCall, "withdrawAndCall for L1 is restricted");
token.transferAndCall(
omniBridge,
uint256(-_extData.extAmount),
abi.encodePacked(l1Unwrapper, abi.encode(_extData.recipient, _extData.l1Fee))
);
} else if (isWithdrawAndCall) {
bytes32 salt = keccak256(abi.encodePacked(_args.inputNullifiers));
bytes32 bytecodeHash = keccak256(_extData.withdrawalBytecode);
address workerAddr = Create2.computeAddress(salt, bytecodeHash);
token.transfer(workerAddr, uint256(-_extData.extAmount));
Create2.deploy(0, salt, _extData.withdrawalBytecode);
} else {
token.transfer(_extData.recipient, uint256(-_extData.extAmount));
}
}
if (_extData.fee > 0) {
token.transfer(_extData.relayer, _extData.fee);
}
lastBalance = token.balanceOf(address(this));
_insert(_args.outputCommitments[0], _args.outputCommitments[1]); _insert(_args.outputCommitments[0], _args.outputCommitments[1]);
emit NewCommitment(_args.outputCommitments[0], nextIndex - 2, _extData.encryptedOutput1); emit NewCommitment(_args.outputCommitments[0], nextIndex - 2, _extData.encryptedOutput1);
emit NewCommitment(_args.outputCommitments[1], nextIndex - 1, _extData.encryptedOutput2); emit NewCommitment(_args.outputCommitments[1], nextIndex - 1, _extData.encryptedOutput2);
for (uint256 i = 0; i < _args.inputNullifiers.length; i++) { for (uint256 i = 0; i < _args.inputNullifiers.length; i++) {
emit NewNullifier(_args.inputNullifiers[i]); emit NewNullifier(_args.inputNullifiers[i]);
} }
if (_extData.extAmount < 0) {
if (_extData.isL1Withdrawal) {
_withdrawL1(_extData);
} else {
_withdrawL2(_extData, _args.inputNullifiers);
}
}
if (_extData.fee > 0) {
token.transfer(_extData.relayer, _extData.fee);
}
}
function _withdrawL1(ExtData memory _extData) internal {
require(_extData.withdrawalBytecode.length == 0, "withdrawAndCall for L1 is restricted");
require(_extData.recipient != address(0), "Incorrect recipient address");
token.transferAndCall(
omniBridge,
uint256(-_extData.extAmount),
abi.encodePacked(l1Unwrapper, abi.encode(_extData.recipient, _extData.l1Fee))
);
}
function _withdrawL2(ExtData memory _extData, bytes32[] memory _inputNullifiers) internal {
if (_extData.withdrawalBytecode.length > 0) {
// withdraw and call
require(_extData.recipient == address(0), "Not zero recipient address");
bytes32 salt = keccak256(abi.encodePacked(_inputNullifiers));
bytes32 bytecodeHash = keccak256(_extData.withdrawalBytecode);
address workerAddr = Create2.computeAddress(salt, bytecodeHash);
token.transfer(workerAddr, uint256(-_extData.extAmount));
Create2.deploy(0, salt, _extData.withdrawalBytecode);
} else {
require(_extData.recipient != address(0), "Zero recipient address");
token.transfer(_extData.recipient, uint256(-_extData.extAmount));
}
} }
function _configureLimits(uint256 _maximumDepositAmount) internal { function _configureLimits(uint256 _maximumDepositAmount) internal {

View File

@ -125,11 +125,10 @@ describe('TornadoPool', function () {
it('constants check', async () => { it('constants check', async () => {
const { tornadoPool } = await loadFixture(fixture) const { tornadoPool } = await loadFixture(fixture)
const maxFee = await tornadoPool.MAX_FEE() const maxFieldUint = await tornadoPool.MAX_FIELD_UINT()
const maxExtAmount = await tornadoPool.MAX_EXT_AMOUNT()
const fieldSize = await tornadoPool.FIELD_SIZE() const fieldSize = await tornadoPool.FIELD_SIZE()
expect(maxExtAmount.add(maxFee)).to.be.lt(fieldSize) expect(maxFieldUint.mul(2)).to.be.lt(fieldSize)
}) })
it('should register and deposit', async function () { it('should register and deposit', async function () {
@ -616,7 +615,7 @@ describe('TornadoPool', function () {
['string', 'uint256', 'bytes32'], ['string', 'uint256', 'bytes32'],
['abi', publicDepositAmount, alicePubkey], ['abi', publicDepositAmount, alicePubkey],
) )
expect(events[0].args.encryptedOutput).to.be.equal(packedOutput) expect(events[2].args.encryptedOutput).to.be.equal(packedOutput)
aliceDepositUtxo = new Utxo({ aliceDepositUtxo = new Utxo({
amount: publicDepositAmount, amount: publicDepositAmount,