Source code for acqdp.tensor_network.slicer

from acqdp.tensor_network.local_optimizer import defaultOrderResolver, LocalOptimizer
from acqdp.tensor_network.contraction import ContractionScheme
from multiprocessing import Pool
import copy
import numpy


class Slicer:
    """
    :class:`Slicer` finds a slicing of an unsliced contraction scheme when called by the :class:`SlicedOrderFinder`.

    :ivar num_iter_before: Number of iterations of local optimization before slicing.
    :ivar num_iter_middle: Number of iterations of local optimization in the middle of slicing.
    :ivar num_iter_after: Number of iterations of local optimization after slicing.
    :ivar max_num_slice: Maximum number of edges to be sliced. If set to -1, the constraint is ignored.
    :ivar num_threads: Number of threads for multi-processing.
    :ivar slice_thres: Automatically slice edges that introduce an overhead below this threshold. Set to 0.02 by default.
    """

    def __init__(self,
                 num_iter_before=0,
                 num_iter_middle=20,
                 num_iter_after=100,
                 max_tw=29,
                 max_num_slice=-1,
                 num_threads=28,
                 slice_thres=.02,
                 **kwargs):
        self.num_iter_before = num_iter_before
        self.num_iter_middle = num_iter_middle
        self.num_iter_after = num_iter_after
        self.max_tw = max_tw
        self.max_num_slice = max_num_slice
        self.num_threads = num_threads
        self.slice_thres = slice_thres
        self.local_optimizer = LocalOptimizer(**kwargs.get('local_optimizer_params', {}))
        self.num_suc_candidates = kwargs.get('num_suc_candidates', 10)

    def _slice(self, tn, orders, num_process=0):
        tn = tn.copy()
        tnc = tn
        while True:
            try:
                tn = tnc.copy()
                y = min(orders, key=lambda a: (a.cost, orders.index(a)))
                slice_edges = []
                print(f'Process {num_process} initial cost: {y.cost}', flush=True)
                y = self.local_optimizer.optimize(tn, y, self.num_iter_before)
                while y.cost.t > 2**self.max_tw:
                    y = self.local_optimizer.optimize(tn, y, self.num_iter_middle)
                    k, order = self._biggest_weight_edge(tn, y.order)
                    slice_edges += k
                    for a in k:
                        tn.fix_edge(a)
                    tn.fix()
                    y = defaultOrderResolver.order_to_contraction_scheme(tn, order)
                    y.cost.k = len(slice_edges)
                    if len(slice_edges) >= self.max_num_slice:
                        break
                    if numpy.log2(float(y.cost.t)) - self.max_tw + len(slice_edges) >= self.max_num_slice + 2:
                        # early termination
                        break
                new_y = self.local_optimizer.optimize(tn, y, self.num_iter_after)
                new_y.cost.k = len(slice_edges)
                if new_y.cost.t <= 2**self.max_tw:
                    y = new_y
                if y.cost.t <= 2**self.max_tw:
                    print(f'Process {num_process} succeeded with {y.cost}', flush=True)
                    return ContractionScheme(y.order, slice_edges, cost=y.cost)
                else:
                    return None
            except KeyboardInterrupt:
                return None

    def _biggest_weight_edge(self, tn, order):
        """Find an edge, or a list of edges, each of which can be sliced with an overhead below the threshold given by
        self.slice_thres, or with minimal overhead if self.slice_thres is unattainable.

        The method enumerates edges that appear frequently on the stem of the contraction tree, and tries to introduce
        as little overhead as possible by flipping branches on the stem while slicing the edges.
""" tn_copy = tn.copy() tn_copy.fix() nodes_names = list(tn_copy.nodes_by_name) from acqdp.tensor_network.undirected_contraction_tree import UndirectedContractionTree eq, path, eedd = defaultOrderResolver.order_to_path(tn_copy, order) uct = UndirectedContractionTree(eq, path) se = set() edges_dic = {} ss = [] for i in range(len(uct.stem) - 1): new_se = uct.open_subscripts_at_edge( uct.graph.nodes[uct.stem[i]]['parent'], uct.stem[i]) for a in new_se.difference(se): edges_dic[a] = i for a in se.difference(new_se): ss.append((a, edges_dic[a], i)) se = new_se ss = sorted(ss, key=lambda x: x[1] - x[2])[:10] ss_dic = {} c = uct.cost slice_edges = [] for s in ss: uct_copy = copy.deepcopy(uct) for v in range(uct_copy.n): uct_copy.graph.nodes[v]['subscripts'] = uct_copy.graph.nodes[v][ 'subscripts'].difference({s[0]}) for u, v in uct_copy.graph.edges: uct_copy.graph[u][v].clear() uct_copy.graph[v][u].clear() uct_copy.preprocess_edge(u, v) uct_copy.preprocess_edge(v, u) uct_copy.compute_root_cost() for v in range(uct_copy.n, uct_copy.n * 2 - 2): uct_copy.compute_node_cost(v) res = (uct_copy.cost, s[1], s[2], uct_copy.get_path()) i = s[1] curr_cost = res[0] while i > 5: ii = i while ii > 3: for k in range(3, ii)[::-1]: for l in range(k, ii): uct_copy.switch_branches(l) if uct_copy.cost <= curr_cost: curr_cost = uct_copy.cost ii -= 1 break else: for l in range(k, ii)[::-1]: uct_copy.switch_branches(l) else: break k = ii sss = ii + 1 while k > 5: uct_copy.switch_branches(k) if uct_copy.cost <= curr_cost: curr_cost = uct_copy.cost sss = k k -= 1 ii = sss for k in range(6, ii): uct_copy.switch_branches(k) if ii >= i: break i = ii res = (curr_cost, i, res[2]) j = res[2] while j < len(uct_copy.stem) - 5: jj = j while jj < len(uct_copy.stem) - 3: for k in range(jj + 1, len(uct_copy.stem) - 1): for l in range(jj, k)[::-1]: uct_copy.switch_branches(l) ucost = uct_copy.cost if ucost <= curr_cost: curr_cost = ucost jj += 1 break else: for l in range(jj, k): uct_copy.switch_branches(l) else: break k = jj sss = jj - 1 while k < len(uct_copy.stem) - 5: uct_copy.switch_branches(k) ucost = uct_copy.cost if ucost <= curr_cost: curr_cost = uct_copy.cost sss = k k += 1 jj = sss for k in range(jj + 1, len(uct_copy.stem) - 5)[::-1]: uct_copy.switch_branches(k) if jj <= j: break j = jj if curr_cost < (1 + self.slice_thres) / 2 * c: slice_edges.append(eedd[s[0]]) c = curr_cost uct = uct_copy else: ss_dic[s[0]] = (curr_cost, res[1], j, uct_copy.get_path()) if len(slice_edges) > 0: pp = uct.get_path() else: kk = sorted(ss_dic, key=lambda x: ss_dic[x][0])[0] slice_edges = [eedd[kk]] pp = ss_dic[kk][-1] new_order = [] for i, p in enumerate(pp): new_order.append([[nodes_names[p[0]], nodes_names[p[1]]], order[i][1]]) nodes_names.pop(max(p[0], p[1])) nodes_names.pop(min(p[0], p[1])) nodes_names.append(order[i][1]) return slice_edges, new_order def slice(self, tn, order_gen): orders = [next(order_gen) for _ in range(self.num_suc_candidates)] return self._slice(tn, orders)
# Module-level helper so that Pool.starmap can invoke Slicer._slice in worker processes.
def mpwrapper(slicer, tn, orders, num_process):
    return slicer._slice(tn, orders, num_process)
class MPSlicer(Slicer):
    """Multi-processing slicing, by concurrently trying different slicing routes."""

    def slice(self, tn, order_gen):
        candidates = []
        while len(candidates) <= self.num_suc_candidates:
            with Pool(self.num_threads) as p:
                lk = list((self, tn, [next(order_gen)], num_process) for num_process in range(self.num_threads))
                new_candidates = p.starmap(mpwrapper, lk)
                candidates += [i for i in new_candidates if i is not None]
            print("Num of candidates now: {}".format(len(candidates)))
        res = min(candidates, key=lambda x: (x.cost, candidates.index(x)))
        return res
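
# Sketch of the multi-process variant (illustrative; `tn` and `order_gen` are the same
# placeholders as above). The pool hands each worker a single candidate order drawn from
# `order_gen`, and the cheapest successful ContractionScheme among the candidates is returned.
def _example_mp_slicer_usage(tn, order_gen):
    slicer = MPSlicer(num_threads=8, num_suc_candidates=4, max_tw=27)
    return slicer.slice(tn, order_gen)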
def get_slicer(**kwargs):
    slicers = {'default': Slicer, 'mp': MPSlicer}
    slicer_name = kwargs.get('slicer_name', 'default')
    return (slicers[slicer_name])(**kwargs.get('slicer_params', {}))
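
# Configuration sketch (illustrative): get_slicer dispatches on 'slicer_name' ('default' or 'mp')
# and forwards 'slicer_params' as keyword arguments to the chosen class.
def _example_get_slicer():
    return get_slicer(slicer_name='mp',
                      slicer_params={'max_tw': 27, 'num_threads': 8, 'slice_thres': 0.05})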