diff --git a/contracts/token/CSSVToken.sol b/contracts/token/CSSVToken.sol index aca4db31..56dac579 100644 --- a/contracts/token/CSSVToken.sol +++ b/contracts/token/CSSVToken.sol @@ -10,22 +10,27 @@ interface ISSVStaking { contract CSSVToken is ERC20 { error NotSSVStaking(); error ZeroAddress(); + error InvalidRecipient(); - address public immutable ssvStaking; + address public immutable ssvNetwork; modifier onlySSVStaking() { - if (msg.sender != ssvStaking) revert NotSSVStaking(); + if (msg.sender != ssvNetwork) revert NotSSVStaking(); _; } - constructor(address ssvStaking_) ERC20("cSSV", "cSSV") { - if (ssvStaking_ == address(0)) revert ZeroAddress(); - ssvStaking = ssvStaking_; + constructor(address ssvNetwork_) ERC20("cSSV", "cSSV") { + if (ssvNetwork_ == address(0)) revert ZeroAddress(); + ssvNetwork = ssvNetwork_; } function _beforeTokenTransfer(address from, address to, uint256 amount) internal override { - if (from != to && from != address(0) && to != address(0) && msg.sender != ssvStaking && amount > 0) { - ISSVStaking(ssvStaking).onCSSVTransfer(from, to, amount); + if (to == address(this) || to == ssvNetwork) { + revert InvalidRecipient(); + } + + if (from != to && from != address(0) && to != address(0) && msg.sender != ssvNetwork && amount > 0) { + ISSVStaking(ssvNetwork).onCSSVTransfer(from, to, amount); } super._beforeTokenTransfer(from, to, amount); } diff --git a/test/common/errors.ts b/test/common/errors.ts index a47e5315..3e5760e2 100644 --- a/test/common/errors.ts +++ b/test/common/errors.ts @@ -61,5 +61,6 @@ export const Errors = { LEGACY_OPERATOR_FEE_DECLARATION_INVALID: "LegacyOperatorFeeDeclarationInvalid", ORACLE_HAS_ZERO_WEIGHT: "OracleHasZeroWeight", MAX_VALUE_EXCEEDED: "MaxValueExceeded", - MAX_PRECISION_EXCEEDED: "MaxPrecisionExceeded" + MAX_PRECISION_EXCEEDED: "MaxPrecisionExceeded", + INVALID_RECIPIENT: "InvalidRecipient" } as const; diff --git a/test/echidna/CSSVTokenAccessControlEchidna.sol b/test/echidna/CSSVTokenAccessControlEchidna.sol index 674ad094..6b6816ca 100644 --- a/test/echidna/CSSVTokenAccessControlEchidna.sol +++ b/test/echidna/CSSVTokenAccessControlEchidna.sol @@ -65,6 +65,6 @@ contract CSSVTokenAccessControlEchidna is CSSVToken { } function echidna_only_self_is_staking() public view returns (bool) { - return ssvStaking == address(this); + return ssvNetwork == address(this); } } diff --git a/test/echidna/CSSVTokenEchidna.sol b/test/echidna/CSSVTokenEchidna.sol index 1c388773..48016e6d 100644 --- a/test/echidna/CSSVTokenEchidna.sol +++ b/test/echidna/CSSVTokenEchidna.sol @@ -151,7 +151,7 @@ contract CSSVTokenEchidna is CSSVToken { } function echidna_staking_is_self() public view returns (bool) { - return ssvStaking == address(this); + return ssvNetwork == address(this); } function echidna_name_immutable() public view returns (bool) { diff --git a/test/integration/SSVNetwork.test.ts b/test/integration/SSVNetwork.test.ts index 7aec5d4e..dbddcbba 100644 --- a/test/integration/SSVNetwork.test.ts +++ b/test/integration/SSVNetwork.test.ts @@ -3243,6 +3243,51 @@ describe("SSVNetwork full integration tests", () => { await expect(network.connect(randomUser).onCSSVTransfer(randomUser.address, randomUser.address, 123)) .to.be.revertedWithCustomError(network, Errors.NOT_CSSV); }); + + it("CSSV transfer is reverted with 'InvalidRecipient' if trying to send tokens to SSVNetwork contract", async function() { + const { network, ssvToken, cssvToken } = + await networkHelpers.loadFixture(deployFullSSVNetworkFixture); + + await ssvToken.connect(randomUser).approve(await network.getAddress(), connection.ethers.MaxUint256); + await ssvToken.mint(randomUser.address, STAKE_AMOUNT); + + await network.connect(randomUser).stake(STAKE_AMOUNT); + + const oracles = (await connection.ethers.getSigners()).slice(10, 14); + + await network.replaceOracle(1, oracles[0].address); + await network.replaceOracle(2, oracles[1].address); + await network.replaceOracle(3, oracles[2].address); + await network.replaceOracle(4, oracles[3].address); + + const operatorIds = await registerOperators(network, operatorOwner, 4) + const clusters = await registerDefaultClusters( + connection, + network, + operatorIds, + operatorOwner, + 8 + ); + const merkleData = buildEBMerkleForDefaultClusters(connection, clusters, 33); + + const block = await connection.ethers.provider.getBlock('latest'); + const blockNum = block!.number + + for (let i = 0; i < 3; i++) { + await network.connect(oracles[i]).commitRoot(merkleData.root, blockNum); + } + + await updateClusterBalancesForDefaultClusters( + network, + clusters, + merkleData, + blockNum, + 33 + ); + + await expect(cssvToken.connect(randomUser).transfer(await network.getAddress(), 123)) + .to.be.revertedWithCustomError(cssvToken, Errors.INVALID_RECIPIENT); + }); }); describe("Reentrancy Guard Tests", async function () {