Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup Kahypar #39

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 101 additions & 146 deletions qtensor/optimisation/kahypar_ordering/use_kahypar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#from qtensor.optimisation.kahypar_ordering import generate_TN
import kahypar as kahypar
from os.path import join, abspath, dirname

import copy
# -- Timing
from contextlib import contextmanager
import time
Expand Down Expand Up @@ -43,7 +43,7 @@ def set_context(**kwargs):

def ka_hg_init(tn): # tn: a dictionary from circ2tn
h = tn.values()
l = list({l for word in h for l in word}) #set of unique edges (eg, v_1)
l = list({l for word in h for l in word}) #set of unique vertexes (eg, '0x7ff77d9dfb50')
l.sort()
nodes = list(range(0, len(l)))
edges = []
Expand Down Expand Up @@ -113,11 +113,18 @@ def recur_partition(tn,**kwargs):
K = int(kwargs.get('K'))
while max([len(x) for x in tn_partite_list[layer]]) > K-1:
layer += 1
tn_partite_list.append([])
result = []
for (count,subgraph) in enumerate(tn_partite_list[layer-1]):
if subgraph != {}: # important
tn_partite_list[layer][2*count:2*count] = subgraph_partition(subgraph,**kwargs)

result.extend(subgraph_partition(subgraph,**kwargs))

if result == tn_partite_list[layer-1]:
return tn_partite_list # for large imbalance
else:
tn_partite_list.append(result)
#tn_partite_list[layer].extend(result)
#tn_partite_list[layer][K*count:K*count] = result

# TODO: Adjust the hyperparameters during the partition

return tn_partite_list
Expand Down Expand Up @@ -154,7 +161,6 @@ def tree2order(tn,tn_partite_list):
all_edge = list(tn.keys())
layer_num = len(tn_partite_list)
order = []
import copy
order_tree = copy.deepcopy(tn_partite_list)
t = 0 # count the temp result in order_tree
for layer in range(layer_num):
Expand Down Expand Up @@ -203,154 +209,103 @@ def tree2order(tn,tn_partite_list):

order = [x for x in order if type(x) != str]
assert len(order) == len(all_edge)
# complete the top of order_tree

#complete the top of order_tree
set_last1 = set(list(tn_partite_list[0][0].keys()))
set_last2 = set(list(tn_partite_list[0][1].keys()))
result = list(set(all_edge) - set_last1 - set_last2)
order_tree= [result] + order_tree

return order,order_tree

# find the order from bottom to top
# correct order_tree,
# TODO: there is some bugs in order insertion (find the index of children)
all_edge = list(tn.keys())
tn_partite_list = tn_partite_list[::-1]
import copy
order_tree = copy.deepcopy(tn_partite_list)
t = 0 # count the temp result in order_tree
layer_num = len(tn_partite_list)
order = []
sub_opt = False # local order search for the bottom graph
def order_tree2ec(order_tree,tn,tn_partite_list):
# There is still some bugs in this function
K = len(tn_partite_list[0])
ec_tree = copy.deepcopy(tn_partite_list)
t = [[]] *(len(ec_tree[-1])*2)
ec_tree.append(t) #contraction tree like Fig in Johnnie's paper

layer_num = len(ec_tree)
for layer in range(layer_num):
if layer == 0: #bottom layer, append order if there is an edge
# if layer == 0:
# parent_graph = order_tree[layer]
# for (count,subgraph) in enumerate(tn_partite_list[layer]):
# add_list=[]; t = 0
# self_node = subgraph.values()
# self_node = list({l for word in self_node for l in word})
# for item in parent_graph:
# parent_node = tn.get(item)
# if parent_node != None:
# if any(i in self_node for i in parent_node):
# add_list.append(item)
# t += 1
# continue
# ec_tree[layer][count] = add_list

if layer < layer_num - 1:
for (count,subgraph) in enumerate(tn_partite_list[layer]):
if sub_opt is True:
# TODO: when the subgraph is small, call other order optimizor)
#order.append(local_search(subgraph))
continue
if subgraph != {}:
add_list=[]; t = 0
self_node = subgraph.values()
self_node = list({l for word in self_node for l in word})
if layer == 0:
parent_graph = order_tree[layer]
else:
parent_ind = find_parent_ind(subgraph,tn_partite_list, layer)
parent_graph = ec_tree[layer-1][parent_ind]

for item in parent_graph:
parent_node = tn.get(item)
if parent_node != None:
if any(i in self_node for i in parent_node):
add_list.append(item)
t += 1
continue

if layer == 0:
ec_tree[layer][count] = add_list
else:
ec_tree[layer][count] = add_list + order_tree[layer][parent_ind]
#eliminate the "temp" ind
ec_tree[layer][count] = [x for x in ec_tree[layer][count] if type(x)!=str]
else:
#if subgraph != {} :
result = list(subgraph.keys())
if result == [] :
result = [f'temp_{t}']
t = t+1
order.extend(result)
order_tree[layer][count]=result
else:
#non-bottom layer, need to insert order
for (count,subgraph) in enumerate(tn_partite_list[layer]):
#left_node_empty = 0
if 1 == 1 :
node_list = list({l for word in list(subgraph.values()) for l in word})
ind_last = []
# find the child node of the subgraph
for (count_last,subgraph_last) in enumerate(tn_partite_list[layer-1]):
if subgraph_last != {}:
node_list_last = list({l for word in list(subgraph_last.values()) for l in word})
check = all(item in node_list for item in node_list_last)
if check is True:
ind_last.append(count_last)

if len(ind_last) == 2:
last_set1 = set(list(tn_partite_list[layer-1][ind_last[0]].keys()))
last_set2 = set(list(tn_partite_list[layer-1][ind_last[1]].keys()))
result = list(set(subgraph) - last_set1 - last_set2)
if result == [] :
result = [f'temp_{t}']
t = t + 1
# for (count_last,subgraph_last) in enumerate(tn_partite_list[layer-1]):
# node_list_last = list({l for word in list(subgraph_last.values()) for l in word})
# check = all(item in node_list for item in node_list_last)
# if check is True:
# child_ind = count_last
child_set1 = list(order_tree[layer-1][ind_last[0]])
child_set2 = list(order_tree[layer-1][ind_last[1]])

exist_order = [order.index(x) for x in list(child_set1) if x in order] + \
[order.index(x) for x in list(child_set2) if x in order]
ind = max(exist_order)+1
order[ind:ind]=result
order_tree[layer][count]=result
ec_tree[layer][count]=[]
elif layer == layer_num - 1:
#TODO: to fix
for (count,_) in enumerate(ec_tree[layer]):
if type(order_tree[layer][count//K]) == list:
ec_tree[layer][count] = order_tree[layer][count//K]
else:
ec_tree[layer][count] = []

### there is a single node partition in the subgraph
if len(ind_last) == 1:
last_set = set(list(tn_partite_list[layer-1][ind_last[0]].keys()))
result = list(set(subgraph) - last_set)
if result == [] :
result = [f'temp_{t}']
t = t + 1
# for (count_last,subgraph_last) in enumerate(tn_partite_list[layer-1]):
# node_list_last = list({l for word in list(subgraph_last.values()) for l in word})
# check = all(item in node_list for item in node_list_last)
# if check is True:
# child_ind = count_last
child_set = list(order_tree[layer-1][ind_last[0]])
exist_order = [order.index(x) for x in list(child_set) if x in order]
ind = max(exist_order)+1
order[ind:ind]=result
order_tree[layer][count]=result

### there are two single node partition in the subgraph
if len(ind_last) == 0:
result = list(subgraph)
if result == [] :
result = [f'temp_{t}']
t = t + 1
for (count_last,subgraph_last) in enumerate(tn_partite_list[layer-1]):
node_list_last = list({l for word in list(subgraph_last.values()) for l in word})
check = all(item in node_list for item in node_list_last)
if check is True:
child_ind = count_last
child_set = list(order_tree[layer-1][child_ind])

exist_order = [order.index(x) for x in list(child_set) if x in order]
ind = max(exist_order)+1
order[ind:ind]=result
order_tree[layer][count] = result

'''
#count = tn_partite_list[layer].index(subgraph)
if count % 2 == 0: # left node
# find the contracted edge of its paired node in the same layer
if len(tn_partite_list[layer][count+1].keys()) < 2:
order.extend(result)
order_tree[layer][count]=result
# both left and right subgraphs are empty, new start
else:
left_node_empty = 1
empty_count = count
left_node_buffer = result
order_tree[layer][count]=result
# follow the order of the right-paired subgraph
else: # right node
if len(tn_partite_list[layer][count-1].keys()) < 2:
order.extend(result)
order_tree[layer][count]=result
# both left and right subgraphs are empty, new start
else:
exist_order = [order.index(x) for x in list(tn_partite_list[layer][count-1]) if x in order]
ind = max(exist_order)+1
order[ind:ind] = result
order_tree[layer][count]=result
# follow the order of the left-paired subgraph
if left_node_empty != 0:
exist_order = [order.index(x) for x in list(tn_partite_list[layer][empty_count+1]) if x in order]
ind = max(exist_order)+1
order[ind:ind] = left_node_buffer
'''

if layer == layer_num - 1:
set_last1 = set(list(tn_partite_list[layer][0].keys()))
set_last2 = set(list(tn_partite_list[layer][1].keys()))
result = list(set(all_edge) - set_last1 - set_last2)
if result == [] :
result = [f'temp_{t}']
t = t + 1
order.extend(result)
order_tree.append(result)
# Count the edge contraction from the ec_tree
# Open edges from two subgraphs - 1
ec=[] #edge contraction
for layer in range(layer_num):
if layer == 0:
temp = []
for i in range(K):
temp = temp + ec_tree[layer][i]
result = len(set(temp))-1
ec.append(result)
elif layer < layer_num - 1:
for (count,subgraph) in enumerate(tn_partite_list[layer-1]):
if subgraph != {}:
child_ind = find_child_ind(subgraph,tn_partite_list, layer-1)
temp = []
#TODO: to fix
if len(child_ind) > 0:
for i in range(len(child_ind)):
temp = temp + ec_tree[layer][child_ind[i]]
result = len(set(temp))-1
else:
result = 0
ec.append(result)
elif layer == layer_num - 1:
for (count,_) in enumerate(ec_tree[layer]):
result = len(set(ec_tree[layer][count//2]+ec_tree[layer][count//K + 1]))
if result > 0:
result = result -1
ec.append(result)

order = [x for x in order if type(x) != str]
assert len(order) == len(all_edge)
order_tree = order_tree[::-1]

return order, order_tre
return max(ec)