# Source code for acqdp.tensor_network.tensor_sum

import numpy
import copy
from collections import OrderedDict
from .tensor_valued import TensorValued, DTYPE


class TensorSum(TensorValued):
    """A :class:`TensorSum` object represents the summation of multiple tensors.

    :ivar terms_by_name: a dictionary with key-value pairs, where the key is the name of a summand and the value is the
        corresponding summand :class:`TensorValued` object.
    """

    def __init__(self, terms=None, dtype: type = DTYPE) -> None:
        """The constructor of a `TensorSum` object.

        :param terms: Optional mapping from term names to summands. Every entry is inserted through :meth:`add_term`,
            so values that are not :class:`TensorValued` are wrapped in a `Tensor`, complex summands promote the
            dtype of the sum, and incompatible shapes raise :class:`ValueError`.
        :type terms: mapping or None
        :param dtype: Initial data type of the sum, forwarded to the :class:`TensorValued` constructor.
        """
        super().__init__(dtype)
        # The container must exist before any term is inserted. The previous code only created it when `terms` was
        # None, and otherwise iterated the (not yet existing) container instead of `terms`.
        self.terms_by_name = OrderedDict()
        if terms is not None:
            for term_name in terms:
                self.add_term(term_name, terms[term_name])

    def __str__(self):
        """Return the base-class description followed by the name and value of every term."""
        term_str = "".join("\n" + str(name) + "\n" + str(term) for name, term in self.terms_by_name.items())
        return super().__str__() + "\nTerms:" + term_str

    def _update_shape(self, curr, tmp):
        """Merge the shape `tmp` into the partially known shape `curr`.

        ``None`` entries act as unknown dimensions: an unknown entry of `curr` is filled from `tmp`, and an unknown
        entry of `tmp` is compatible with anything.

        NOTE: `curr` is modified in place when it is a list; pass a copy if the original must survive an error.

        :param curr: The shape accumulated so far, as a list, or None if nothing is known yet.
        :param tmp: The shape to merge in, or None if it is unknown.
        :returns: The merged shape as a list, or `curr` unchanged when `tmp` is None.
        :raises ValueError: If the ranks differ or two known dimensions disagree.
        """
        if tmp is None:
            return curr
        if curr is None:
            return list(tmp)
        if len(curr) != len(tmp):
            raise ValueError('Component shapes do not match')
        for i, dim in enumerate(tmp):
            if curr[i] is None:
                curr[i] = dim
            elif dim is not None and dim != curr[i]:
                raise ValueError('Component shapes do not match')
        return curr

    def _invalidate_shape_cache(self):
        """Drop the cached shape so the next access to :attr:`shape` recomputes it from the terms."""
        if hasattr(self, '_cached_shape'):
            del self._cached_shape

    @property
    def shape(self):
        """The common property of all :class:`TensorValued` classes, yielding the shape of the object.

        :class:`TensorValued` objects must have compatible shapes in order to be connected together in a
        :class:`TensorNetwork`, or summed over in a :class:`TensorSum`.
        """
        if not hasattr(self, '_cached_shape'):
            curr = None
            for tsr in self.terms_by_name.values():
                curr = self._update_shape(curr, tsr.shape)
            # Only cached once every term merged successfully; a ValueError above leaves no cache behind.
            self._cached_shape = curr
        return tuple(self._cached_shape) if self._cached_shape is not None else None

    @property
    def is_valid(self):
        """The common property of all :class:`TensorValued` classes, indicating whether the :class:`TensorValued`
        object is valid or not.

        In every step of a program, all existing :class:`TensorValued` object must be valid, otherwise an exception
        should be thrown out; this property is for double checking that the current :class:`TensorValued` object is
        indeed valid.
        """
        try:
            self.shape
        except ValueError:
            return False
        else:
            return True

    @property
    def is_ready(self):
        """The common property of all :class:`TensorValued` classes, indicating whether the current
        :class:`TensorValued` object is ready for contraction, i.e. whether it semantically represents a tensor with a
        definite value.

        In the process of a program, not all :class:`TensorValued` objects need to be ready; however once the `data`
        property of a certain object is queried, such object must be ready in order to successfully yield an
        :class:`numpy.ndarray` object.
        """
        return all(t.is_ready for t in self.terms_by_name.values()) and self.is_valid

    def add_term(self, term=None, tensor=None):
        """Add a term to the summation.

        :param term: Name of the term to be added. If not given, the `identifier` of the (possibly wrapped) tensor is
            used.
        :type term: hashable
        :param tensor: Value of the term to be added. Values that are not :class:`TensorValued` are wrapped in a
            `Tensor`.
        :type tensor: :class:`TensorValued` or None
        :returns: The name of the newly added term.
        :raises KeyError: If a term with the same name already exists.
        :raises ValueError: If the shape of `tensor` is incompatible with the current shape of the sum.
        """
        from .tensor import Tensor
        if not isinstance(tensor, TensorValued):
            tensor = Tensor(tensor)
        if term is None:
            term = tensor.identifier
        if term in self.terms_by_name:
            raise KeyError("term {} to be added into the tensor network already in the tensor network!".format(term))
        if tensor.shape is not None:
            # `self.shape` initializes the cache and returns a fresh tuple. Merging into a copy means a shape mismatch
            # raised midway cannot leave the cache half-updated for a term that was never added.
            current = self.shape
            self._cached_shape = self._update_shape(None if current is None else list(current), tensor.shape)
        # Promote only once the term is known to be accepted.
        if tensor.dtype == complex:
            self.dtype = complex
        self.terms_by_name[term] = tensor
        return term

    def __iadd__(self, t):
        """Add `t` as a new term under an automatically chosen name (see :meth:`add_term`) and return `self`."""
        self.add_term(tensor=t)
        return self

    def update_term(self, term, tensor=None):
        """Update the value of a term in the summation.

        :param term: Name of the term to be updated.
        :type term: hashable
        :param tensor: New value of the term. Values that are not :class:`TensorValued` are wrapped in a `Tensor`,
            as in :meth:`add_term`.
        :type tensor: :class:`TensorValued`
        :returns: Name of the term to be updated.
        :raises KeyError: If no term with the given name exists.
        """
        from .tensor import Tensor
        if not isinstance(tensor, TensorValued):
            tensor = Tensor(tensor)
        if term not in self.terms_by_name:
            raise KeyError("term {} not in the TensorSum object".format(term))
        # Keep the dtype promotion consistent with add_term.
        if tensor.dtype == complex:
            self.dtype = complex
        self.terms_by_name[term] = tensor
        self._invalidate_shape_cache()
        return term

    def remove_term(self, term):
        """Remove a term from the summation.

        :param term: Name of the term to be removed.
        :type term: hashable
        :returns: :class:`TensorValued` Value of the removed term
        :raises KeyError: If no term with the given name exists.
        """
        pop = self.terms_by_name.pop(term)
        self._invalidate_shape_cache()
        return pop

    def fix_index(self, index, fix_to=0):
        """Fix the given index to the given value.

        The result :class:`TensorValued` object would have the same type as the original one, with rank 1 smaller
        than the original. The original object is not modified; every term of a shallow copy is replaced by the
        result of its own `fix_index`.

        :param index: The index to fix.
        :type index: :class:`int`.
        :param fix_to: The value to assign to the given index.
        :type fix_to: :class:`int`.
        :returns: :class:`TensorValued` -- The :class:`TensorValued` object after fixing the given index.
        :raises: Whatever the `fix_index` of an individual term raises.
        """
        ts = self.copy()
        for term in ts.terms_by_name:
            ts.terms_by_name[term] = ts.terms_by_name[term].fix_index(index, fix_to)
        ts._invalidate_shape_cache()
        return ts

    def cast(self, dtype):
        """Cast the sum and each of its terms to `dtype` in place.

        :param dtype: Target data type.
        :returns: `self`.
        """
        self.dtype = dtype
        for term in self.terms_by_name:
            self.update_term(term, self.terms_by_name[term].cast(dtype))
        return self

    def contract(self, **kwargs):
        """Evaluate the object by summing over all the terms.

        :param kwargs: Forwarded unchanged to the `contract` method of every term.
        :returns: :class:`numpy.ndarray` (the integer 0 if there are no terms)
        """
        return sum(term.contract(**kwargs) for term in self.terms_by_name.values())

    def copy(self):
        """Return a shallow copy: a new :class:`TensorSum` with the same names and dtype, sharing the term objects."""
        ts = TensorSum(dtype=self.dtype)
        for t in self.terms_by_name:
            ts.add_term(t, self.terms_by_name[t])
        return ts

    def __deepcopy__(self, memo):
        """Return a deep copy of the sum in which every term is deep-copied as well.

        :param memo: The memo dictionary of the :mod:`copy` protocol.
        """
        ts = TensorSum(dtype=self.dtype)
        # Register the new object before recursing, and pass the memo down, so that shared or cyclic references
        # resolve to a single copy instead of being duplicated or recursing forever.
        memo[id(self)] = ts
        for t in self.terms_by_name:
            ts.add_term(t, copy.deepcopy(self.terms_by_name[t], memo))
        return ts