Skip to content

Commit

Permalink
client: separate contract tests and use unittest framework
Browse files Browse the repository at this point in the history
  • Loading branch information
dtebbs committed Jan 7, 2021
1 parent 880137b commit 107f198
Show file tree
Hide file tree
Showing 9 changed files with 269 additions and 317 deletions.
10 changes: 3 additions & 7 deletions client/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,11 @@ syntax: ${PROTOBUF_OUTPUT}
mypy zeth/cli/zeth zeth/helper/zeth_helper
mypy -p tests
mypy -p test_commands
mypy -p test_contracts
pylint zeth.core zeth.cli zeth.helper tests test_commands

test: ${PROTOBUF_OUTPUT}
python -m unittest
python -m unittest discover tests

test_contracts: ${PROTOBUF_OUTPUT}
python -m test_commands.test_altbn128_mixer_base
python -m test_commands.test_bls12_377_contract
python -m test_commands.test_bw6_761_contract
python -m test_commands.test_groth16_bls12_377_contract
python -m test_commands.test_merkle_tree_contract
python -m test_commands.test_mimc_contract
python -m unittest discover test_contracts
60 changes: 0 additions & 60 deletions client/test_commands/test_mimc_contract.py

This file was deleted.

5 changes: 5 additions & 0 deletions client/test_contracts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) 2015-2020 Clearmatics Technologies Ltd
#
# SPDX-License-Identifier: LGPL-3.0+
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from zeth.core.constants import \
JS_INPUTS, ZETH_PUBLIC_UNIT_VALUE, ZETH_MERKLE_TREE_DEPTH
import test_commands.mock as mock
from unittest import TestCase
from typing import Any

# pylint: disable=line-too-long
Expand Down Expand Up @@ -91,58 +92,46 @@
PACKED_PRIMARY_INPUTS = \
[ROOT] + COMMITMENTS + NULLIFIERS + [HSIG] + HTAGS + [RESIDUAL_BITS]


def test_assemble_nullifiers(mixer_instance: Any) -> None:
# Test retrieving nullifiers
print("--- test_assemble_nullifiers")
for i in range(JS_INPUTS):
res = mixer_instance.functions.\
assemble_nullifier_test(i, PACKED_PRIMARY_INPUTS).call()
val = int.from_bytes(res, byteorder="big")
assert val == NULLIFIERS[i], f"expected: {NULLIFIERS[i]}, got: {val}"


def test_assemble_hsig(mixer_instance: Any) -> None:
# Test retrieving hsig
print("--- test_assemble_hsig")
res = mixer_instance.functions.\
assemble_hsig_test(PACKED_PRIMARY_INPUTS).call()
hsig = int.from_bytes(res, byteorder="big")
assert hsig == HSIG, f"expected: {HSIG}, got {hsig}"


def test_assemble_vpub(mixer_instance: Any) -> None:
# Test retrieving public values
print("--- test_assemble_vpub")
v_in, v_out = mixer_instance.functions.assemble_public_values_test(
PACKED_PRIMARY_INPUTS[-1]).call()
v_in_expect = VPUB[0] * ZETH_PUBLIC_UNIT_VALUE
v_out_expect = VPUB[1] * ZETH_PUBLIC_UNIT_VALUE
assert v_in == v_in_expect, f"expected: {v_in_expect}, got: {v_in}"
assert v_out == v_out_expect, f"expected: {v_out_expect}, got: {v_out}"


def main() -> None:
print("Deploying AltBN128MixerBase_test.sol")
_web3, eth = mock.open_test_web3()
deployer_eth_address = eth.accounts[0]
_mixer_interface, mixer_instance = mock.deploy_contract(
eth,
deployer_eth_address,
"AltBN128MixerBase_test",
{
'mk_depth': ZETH_MERKLE_TREE_DEPTH,
})

print("Testing ...")
test_assemble_nullifiers(mixer_instance)
test_assemble_vpub(mixer_instance)
test_assemble_hsig(mixer_instance)

print("========================================")
print("== PASSED ==")
print("========================================")


if __name__ == '__main__':
main()
MIXER_INSTANCE: Any = None


class TestAltBN128MixerBaseContract(TestCase):

@staticmethod
def setUpClass() -> None:
print("Deploying AltBN128MixerBase_test.sol")
_web3, eth = mock.open_test_web3()
deployer_eth_address = eth.accounts[0]
_mixer_interface, mixer_instance = mock.deploy_contract(
eth,
deployer_eth_address,
"AltBN128MixerBase_test",
{
'mk_depth': ZETH_MERKLE_TREE_DEPTH,
})
global MIXER_INSTANCE # pylint: disable=global-statement
MIXER_INSTANCE = mixer_instance

def test_assemble_nullifiers(self) -> None:
# Test retrieving nullifiers
for i in range(JS_INPUTS):
res = MIXER_INSTANCE.functions.\
assemble_nullifier_test(i, PACKED_PRIMARY_INPUTS).call()
val = int.from_bytes(res, byteorder="big")
self.assertEqual(NULLIFIERS[i], val)

def test_assemble_hsig(self) -> None:
# Test retrieving hsig
res = MIXER_INSTANCE.functions.\
assemble_hsig_test(PACKED_PRIMARY_INPUTS).call()
hsig = int.from_bytes(res, byteorder="big")
self.assertEqual(HSIG, hsig)

def test_assemble_vpub(self) -> None:
# Test retrieving public values
v_in, v_out = MIXER_INSTANCE.functions.assemble_public_values_test(
PACKED_PRIMARY_INPUTS[-1]).call()
v_in_expect = VPUB[0] * ZETH_PUBLIC_UNIT_VALUE
v_out_expect = VPUB[1] * ZETH_PUBLIC_UNIT_VALUE
self.assertEqual(v_in_expect, v_in)
self.assertEqual(v_out_expect, v_out)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: LGPL-3.0+

from test_commands import mock
from unittest import TestCase
from typing import Any

# Data generated by libzeth/tests/core/ec_operation_data_test. Statements to be
Expand Down Expand Up @@ -74,53 +75,47 @@ def _b32(hex_value: str) -> bytes:
]


def test_bls12_ecadd(bls12_instance: Any) -> None:
"""
Check that [6] == [2] + [4]
"""
result = bls12_instance.functions.testECAdd(G1_2 + G1_4).call()
assert result == G1_6


def test_bls12_ecmul(bls12_instance: Any) -> None:
"""
Check that [-8] == -2 * [4]
"""
result = bls12_instance.functions.testECMul(G1_4 + FR_MINUS_2).call()
assert result == G1_MINUS_8


def test_bls12_ecpairing(bls12_instance: Any) -> None:
"""
Check that e([6], [4]) * e([3],[8]) * e([4],[4]) * e([-8], [8]) == 1
"""
# Note, return result here is uint256(1) or uint256(0) depending on the
# pairing check result.
points = G1_6 + G2_4 + G1_3 + G2_8 + G1_4 + G2_4 + G1_MINUS_8 + G2_8
result = bls12_instance.functions.testECPairing(points).call()
assert result == 1

points = G1_6 + G2_4 + G1_3 + G2_8 + G1_4 + G2_4 + G1_MINUS_8 + G2_4
result = bls12_instance.functions.testECPairing(points).call()
assert result == 0


def main() -> None:
_web3, eth = mock.open_test_web3()
_bls12_interface, bls12_instance = mock.deploy_contract(
eth,
eth.accounts[0],
"BLS12_377_test",
{})

test_bls12_ecadd(bls12_instance)
test_bls12_ecmul(bls12_instance)
test_bls12_ecpairing(bls12_instance)

print("========================================")
print("== PASSED ==")
print("========================================")


if __name__ == "__main__":
main()
BLS12_INSTANCE: Any = None


class TestBLS12_377Contract(TestCase):

@staticmethod
def setUpClass() -> None:
print("Deploying BLS12_377_test.sol")
_web3, eth = mock.open_test_web3()
_bls12_interface, bls12_instance = mock.deploy_contract(
eth,
eth.accounts[0],
"BLS12_377_test",
{})
global BLS12_INSTANCE # pylint: disable=global-statement
BLS12_INSTANCE = bls12_instance

def test_bls12_ecadd(self) -> None:
"""
Check that [6] == [2] + [4]
"""
result = BLS12_INSTANCE.functions.testECAdd(G1_2 + G1_4).call()
self.assertEqual(G1_6, result)

def test_bls12_ecmul(self) -> None:
"""
Check that [-8] == -2 * [4]
"""
result = BLS12_INSTANCE.functions.testECMul(G1_4 + FR_MINUS_2).call()
self.assertEqual(G1_MINUS_8, result)

def test_bls12_ecpairing(self) -> None:
"""
Check that e([6], [4]) * e([3],[8]) * e([4],[4]) * e([-8], [8]) == 1
"""
# Note, return result here is uint256(1) or uint256(0) depending on the
# pairing check result.
points = G1_6 + G2_4 + G1_3 + G2_8 + G1_4 + G2_4 + G1_MINUS_8 + G2_8
result = BLS12_INSTANCE.functions.testECPairing(points).call()
self.assertEqual(1, result)

points = G1_6 + G2_4 + G1_3 + G2_8 + G1_4 + G2_4 + G1_MINUS_8 + G2_4
result = BLS12_INSTANCE.functions.testECPairing(points).call()
self.assertEqual(0, result)
Loading

0 comments on commit 107f198

Please sign in to comment.