# Source code for acqdp.tensor_network.contractor

import inspect
import numpy
from multiprocessing import Pool
from functools import lru_cache


def getDeployPlugin(pluginName):
    """Import the plugin module called ``pluginName`` and return it.

    A plugin is simply an installed Python module. The contractor code in this
    module reaches into its ``tensor_networkService`` attribute.
    """
    from importlib import import_module as _import_by_name
    plugin_module = _import_by_name(pluginName)
    return plugin_module

class Contractor:
    """Contractor class for tensor network contraction.

    Takes a :class:`ContractionTask` object and executes it sequentially. For
    :class:`NetworkContractionTask`, multi-processing is available to further
    accelerate the computation.

    :ivar backend: When set to `jax`, large tensor contractions will make use of the
        `jax` backend. `numpy.einsum` is used otherwise.
    :ivar exeEngine: Extension interface for other contraction backends. Set to `None`
        by default (meaning the `numpy` engine unless a module-level ``DEFAULT_ENGINE``
        overrides it). When set to `parallel`, subtasks will be computed
        simultaneously. A module object or a ``'plugin:<module name>'`` string
        delegates to that module's ``tensor_networkService.contractorExecute``.
    :ivar dtype: The requested dtype, normalised through :func:`numpy.dtype`.
    """

    def __init__(self, exeEngine=None, backend='default', dtype=complex, **kwargs):
        self.exeEngine = exeEngine
        self.backend = backend
        self.dtype = numpy.dtype(dtype)

    def execute(self, tasks, lst=None, **kwargs):
        """Execute a contraction task.

        :param tasks: The task to be executed.
        :type tasks: :class:`acqdp.ContractionScheme`
        :param lst: The list of subtasks to be executed. If set to `None`, all subtasks
            are executed and merged together.
        :type lst: :class:`List`
        :param kwargs: Forwarded to :meth:`_execute` for the `numpy` and `parallel`
            engines.
        :returns: :class:`numpy.ndarray` -- Final result expressed as a
            multi-dimensional array.
        :raises ValueError: If the configured engine is not recognised.
        """
        tasks._load_data()
        if lst is None:
            lst = range(tasks.length)
        engine = self.exeEngine
        if engine is None:
            # DEFAULT_ENGINE is not assigned anywhere in this module, so a plain name
            # lookup raised NameError. Still honour a module-level override if one
            # has been installed, but otherwise fall back to the numpy engine.
            engine = globals().get('DEFAULT_ENGINE', 'numpy')
        if engine == 'numpy':
            return tasks._merge(
                {i: self._execute(tasks[i], **kwargs) for i in lst})
        elif inspect.ismodule(engine):
            return engine.tensor_networkService.contractorExecute(tasks)
        elif engine == 'parallel':
            from functools import partial
            indices = list(lst)
            # Each subtask is a single argument to _execute, so use map (starmap
            # would try to unpack the task object as an argument tuple).
            with Pool() as p:
                results = p.map(partial(self._execute, **kwargs),
                                [tasks[i] for i in indices])
            # Hand _merge the same {index: result} mapping as the numpy engine.
            return tasks._merge(dict(zip(indices, results)))
        elif isinstance(engine, str) and engine.startswith('plugin:'):
            plugin = getDeployPlugin(engine[len('plugin:'):])
            return plugin.tensor_networkService.contractorExecute(tasks)
        raise ValueError("Unknown contraction engine: {!r}".format(engine))

    def _execute(self, task, track=False, normalize=False, cnt=None, **kwargs):
        """Run the commands of one contraction task and return its result.

        Each command is indexed as ``(operation, lhs, rhs, cmd_kwargs)``. Tensors are
        held as ``(log_norm, array)`` pairs, and every command writes its result
        pair into ``rhs[0]``:

        * ``'f'`` -- take the tensor pair ``lhs[0]``, move its axes
          ``cmd_kwargs['fix_idx']`` to the front, and index them with the first
          element of each entry of ``cmd_kwargs['fix_to']``.
        * ``'c'`` -- contract the tensors ``slot[0]`` for every ``slot`` in ``lhs``.
          It uses ``cmd_kwargs['expr']`` when present and
          ``numpy.einsum(cmd_kwargs['subscripts'], ...)`` otherwise. The log-norms
          of the operands are summed.
        * ``'n'`` -- apply ``cmd_kwargs['func']`` to ``exp(log_norm) * array`` of
          ``lhs[0]``.

        :param task: Object exposing ``commands`` and ``output`` (presumably a
            ``ContractionTask`` -- confirm).
        :param track: If true, print a memory-usage estimate before each command.
        :param normalize: If true, divide ``'c'``/``'n'`` results by their maximum
            absolute value and add its log to the stored log-norm.
        :param cnt: If given, only the first ``cnt`` commands run and the raw result
            of the last command is returned instead of the task output.
        :param kwargs: Unused; accepted for interface compatibility.
        :returns: :class:`numpy.ndarray`.
        :raises ValueError: If a command has an unknown operation code.
        """
        commands = task.commands if cnt is None else task.commands[:cnt]
        output = task.output
        for command in commands:
            if track:
                print("Current Memory usage = {}".format(
                    self._track_memory(commands) + 4))
            operation = command[0]
            lhs = command[1]
            rhs = command[2]
            cmd_kwargs = command[3]
            try:
                if operation == 'f':
                    fix_idx = cmd_kwargs['fix_idx']
                    moved = numpy.moveaxis(lhs[0][1], fix_idx, range(len(fix_idx)))
                    res = moved[tuple(a[0] for a in cmd_kwargs['fix_to'])]
                    rhs[0] = (lhs[0][0], numpy.array(res))
                else:
                    if operation == 'c':
                        init_norm = sum(slot[0][0] for slot in lhs)
                        operands = [slot[0][1] for slot in lhs]
                        if 'expr' in cmd_kwargs:
                            if self.backend == 'jax':
                                res = cmd_kwargs['expr'](*operands, backend='jax')
                            else:
                                res = cmd_kwargs['expr'](*operands)
                        else:
                            res = numpy.array(
                                numpy.einsum(cmd_kwargs['subscripts'], *operands))
                    elif operation == 'n':
                        init_norm = 0
                        # NOTE(review): func also receives 'func' itself among
                        # cmd_kwargs; presumably it accepts **kwargs -- confirm.
                        res = cmd_kwargs['func'](
                            numpy.exp(lhs[0][0]) * lhs[0][1], **cmd_kwargs)
                    else:
                        # Otherwise init_norm/res would be unbound or stale from an
                        # earlier command.
                        raise ValueError(
                            "Unknown contraction command: {!r}".format(operation))
                    if normalize:
                        import numexpr as ne
                        # numexpr resolves `res` from this frame's locals.
                        norm = ne.evaluate(
                            'max(res.real ** 2 + res.imag ** 2)')**0.5
                        if norm == 0:
                            # zeros_like keeps the tensor's dtype (e.g. complex).
                            rhs[0] = (0, numpy.zeros_like(res))
                        else:
                            res /= norm
                            rhs[0] = (numpy.log(norm) + init_norm, res)
                    else:
                        rhs[0] = (init_norm, res)
            except Exception as e:
                print(e)
                print(command)
                raise
        if cnt is None:
            res = numpy.exp(output[0][0]) * output[0][1]
        return numpy.array(res)

    def _track_memory(self, commands):
        """Return log2 of the largest distinct tensor referenced by ``commands``.

        Tensor arrays are deduplicated by identity. Operand slots that are still
        ``None`` (not yet computed) are skipped. Returns ``0.0`` when no tensor is
        referenced.
        """
        seen = []

        def _register(pair):
            # Skip unfilled slots and arrays already counted (identity, not equality).
            if pair is None:
                return
            arr = pair[1]
            if not any(arr is known for known in seen):
                seen.append(arr)

        for command in commands:
            if command[0] == 'c':
                for slot in command[1]:
                    _register(slot[0])
            elif command[0] == 'f':
                _register(command[1][0])
        return numpy.log2(max((a.size for a in seen), default=1))
defaultContractor = Contractor()
_defaultContractor = None


def getDefault():
    """Return the cached default contractor, resolving it on first use.

    The first call takes ``defaultContractor`` from
    ``acqdp.tensor_network.contractor``. If that import fails, it builds a
    fresh ``Contractor()`` instead. The choice is cached in
    ``_defaultContractor`` until :func:`setDefault` replaces it.
    """
    global _defaultContractor
    if _defaultContractor is not None:
        return _defaultContractor
    try:
        from acqdp.tensor_network import contractor as _canonical
    except ImportError:
        _defaultContractor = Contractor()
    else:
        _defaultContractor = _canonical.defaultContractor
    return _defaultContractor


def setDefault(aContractor: Contractor):
    """Make ``aContractor`` the object returned by :func:`getDefault`."""
    global _defaultContractor
    _defaultContractor = aContractor


def contract(tn, **kwargs):
    """Delegate ``tn`` and ``kwargs`` to the default contractor's ``contract``.

    NOTE(review): ``Contractor`` as defined in this module exposes ``execute``
    but no ``contract`` method, so this presumably only works once
    :func:`setDefault` has installed an object providing ``contract`` -- confirm.
    """
    active = getDefault()
    return active.contract(tn, **kwargs)