Skip to content

Commit

Permalink
Debug
Browse files Browse the repository at this point in the history
  • Loading branch information
dallonasnes committed Sep 25, 2023
1 parent 0a4a093 commit cc455be
Showing 1 changed file with 43 additions and 48 deletions.
91 changes: 43 additions & 48 deletions qtensor/Simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,54 +156,49 @@ def sample(self):
return _sequence_sample(tn, composer.qubits)

def _sequence_sample(tn: TNAdapter, indices, batch_size=10, batch_fix_sequence=None, dim=2):
"""
Args:
tn: tensor network
indices: list of indices to contract
"""
K = int(np.ceil(len(indices) / batch_size))
if batch_fix_sequence is None:
batch_fix_sequence = [1]*K

slice_dict = {}
cache = {}
samples = [Bs.str('', prob=1., dim=dim)]
z_0 = None
for i in range(K):
for j in range(len(samples)):
bs = samples.pop(0)
res = None
if len(bs)>0:
res = cache.get(bs.to_int())
if res is None:
free_batch_ix = indices[i*batch_size:(i+1)*batch_size]
_fix_indices = indices[: len(bs)]
update = dict(zip(_fix_indices, list(bs)))
slice_dict.update(dict(zip(_fix_indices, list(bs))))
res = contract_tn(tn, slice_dict, free_batch_ix)
res = res.real**2
K = int(np.ceil(len(indices) / batch_size))
if batch_fix_sequence is None:
batch_fix_sequence = [1]*K

slice_dict = {}
cache = {}
samples = [Bs.str('', prob=1., dim=dim)]
z_0 = None
for i in range(K):
for j in range(len(samples)):
bs = samples.pop(0)
res = None
if len(bs)>0:
cache[bs.to_int()] = res

# result should be shaped accourdingly
if z_0 is None:
z_0 = res.sum()
prob_prev = bs._prob
z_n = prob_prev * z_0
z_n = res.sum()
logger.debug('bs {}, Sum res {}, prev_Z {}, prob_prev {}',
bs, res.sum(), prob_prev*z_0, prob_prev
)
pdist = res.flatten() / z_n
logger.debug(f'Prob distribution: {pdist.round(4)}')
indices_bs = np.arange(len(pdist))
batch_ix = np.random.choice(indices_bs, batch_fix_sequence[i], p=pdist)
for ix in batch_ix:
_new_s = bs + Bs.int(ix, width=len(free_batch_ix), prob=pdist[ix], dim=dim)
logger.trace(f'New sample: {_new_s}')
samples.append(_new_s)

return samples
res = cache.get(bs.to_int())
if res is None:
free_batch_ix = indices[i*batch_size:(i+1)*batch_size]
_fix_indices = indices[: len(bs)]
update = dict(zip(_fix_indices, list(bs)))
slice_dict.update(dict(zip(_fix_indices, list(bs))))
res = contract_tn(tn, slice_dict, free_batch_ix)
res = res.real**2
if len(bs)>0:
cache[bs.to_int()] = res

# result should be shaped accourdingly
if z_0 is None:
z_0 = res.sum()
prob_prev = bs._prob
z_n = prob_prev * z_0
z_n = res.sum()
logger.debug('bs {}, Sum res {}, prev_Z {}, prob_prev {}',
bs, res.sum(), prob_prev*z_0, prob_prev
)
pdist = res.flatten() / z_n
logger.debug(f'Prob distribution: {pdist.round(4)}')
indices_bs = np.arange(len(pdist))
batch_ix = np.random.choice(indices_bs, batch_fix_sequence[i], p=pdist)
for ix in batch_ix:
_new_s = bs + Bs.int(ix, width=len(free_batch_ix), prob=pdist[ix], dim=dim)
logger.trace(f'New sample: {_new_s}')
samples.append(_new_s)

return samples


class CirqSimulator(Simulator):
Expand All @@ -225,4 +220,4 @@ def simulate(self, qc, **params):

sim = QAOAQtreeSimulator(composer)

print("hello world")
logger.debug('hello world')

0 comments on commit cc455be

Please sign in to comment.