Skip to content

Commit

Permalink
Add pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
tinaoberoi committed Oct 30, 2023
1 parent b83b726 commit a143138
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
5 changes: 3 additions & 2 deletions scratchpad/qtensor_MPS/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, tensor_name, N, physical_dim = 1) -> None:
for i in range(N-2):
node = tn.Node(np.array([[[1.0]], *[[[0.0]]]*(physical_dim-1)], dtype=np.complex64), name = tensor_name + str(i+1))
nodes.append(node)
nodes.append(tn.Node(np.array([[1.0], *[[0.0]]*(physical_dim-1)], dtype=np.complex64), name = tensor_name + str(0)))
nodes.append(tn.Node(np.array([[1.0], *[[0.0]]*(physical_dim-1)], dtype=np.complex64), name = tensor_name + str(N-1)))

for i in range(1, N-2):
tn.connect(nodes[i].get_edge(2), nodes[i+1].get_edge(1))
Expand Down Expand Up @@ -234,7 +234,8 @@ def get_norm(self):
"""
Method to calculate norm of mps
"""
return np.sqrt(self.inner_product(self).real)
val = np.sqrt(self.inner_product(self).real)
return val


def left_cannoise(self, i):
Expand Down
40 changes: 28 additions & 12 deletions scratchpad/qtensor_MPS/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,46 +19,62 @@ def test_from_wavefunction_random():
wavefunction = np.random.rand(2**n)
wavefunction = wavefunction / np.linalg.norm(wavefunction, ord = 2)
mps = MPS.construct_mps_from_wavefunction(wavefunction, 'q', n, 2)
assert mps.get_norm() == 1
assert np.isclose(mps.get_norm(), 1.)
assert np.allclose(mps.get_wavefunction(), wavefunction)

def test_apply_one_qubit_mps_operation_xgate():
# q0q1 = 00
# On apply x gate |00> -> |10>
mps = MPS("q", 2, 2)
assert mps.get_norm() == 1
assert np.isclose(mps.get_norm(), 1.)

mps.apply_single_qubit_gate(xgate(), 0)
assert mps.get_norm() == 1
mps.apply_single_qubit_gate(xgate(), 1)
assert np.isclose(mps.get_norm(), 1.)

assert np.allclose(mps.get_wavefunction(), np.array([0.0, 0.0, 1.0, 0.0], dtype=np.complex64))
assert np.allclose(mps.get_wavefunction(), np.array([0.0, 0.0, 0.0, 1.0], dtype=np.complex64))

def test_apply_twoq_cnot_two_qubits():
# In the following tests, the first qubit is always the control qubit.
# Check that CNOT|10> = |11>
mps = MPS("q", 2, 2)
assert mps.get_norm() == 1
assert np.isclose(mps.get_norm(), 1.)

mps.apply_single_qubit_gate(xgate(), 0)
mps.apply_two_qubit_gate(cnot(), [0, 1])
assert mps.get_norm() == 1
assert np.isclose(mps.get_norm(), 1.)
assert np.allclose(mps.get_wavefunction(), np.array([0.0, 0.0, 0.0, 1.0], dtype=np.complex64))

def test_apply_two_twoq_cnot_two_qubits():
# In the following tests, the first qubit is always the control qubit.
# Check that CNOT(0,1)|100> = |110>
# Check that CNOT(1,2)|110> = |111>
mps = MPS("q", 3, 2)
assert np.isclose(mps.get_norm(), 1.)

mps.apply_single_qubit_gate(xgate(), 0)
mps.apply_two_qubit_gate(cnot(), [0, 1])
mps.apply_two_qubit_gate(cnot(), [1, 2])
assert np.isclose(mps.get_norm(), 1.)
assert np.allclose(mps.get_wavefunction(), np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], dtype=np.complex64))

def test_apply_gate_for_bell_circuit():
mps = MPS("q", 2, 2)
assert mps.get_norm() == 1
assert np.isclose(mps.get_norm(), 1.)

mps.apply_single_qubit_gate(hgate(), 0)
mps.apply_two_qubit_gate(cnot(), [0,1])
assert mps.get_norm() == 1
assert np.allclose(mps.get_wavefunction(), np.array([0.707, 0.0, 0.0, 0.707], dtype=np.complex64))
assert np.isclose(mps.get_norm(), 1.)
assert np.allclose(mps.get_wavefunction(), np.array([0.707106, 0.0, 0.0, 0.707106], dtype=np.complex64))

def test_apply_gate_for_ghz_circuit():
mps = MPS("q", 3, 2)
assert mps.get_norm() == 1
assert np.isclose(mps.get_norm(), 1.)

mps.apply_single_qubit_gate(hgate(), 0)
mps.apply_two_qubit_gate(cnot(), [0,1])
mps.apply_two_qubit_gate(cnot(), [1,2])
assert mps.get_norm() == 1
assert np.allclose(mps.get_wavefunction(), np.array([ 0.7071, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7071], dtype=np.complex64))
assert np.isclose(mps.get_norm(), 1.)
assert np.allclose(mps.get_wavefunction(), np.array([ 0.7071, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7071], dtype=np.complex64))


0 comments on commit a143138

Please sign in to comment.