# Source code for acqdp.tensor_network.tensor_view

import copy
import numpy
from .tensor_valued import TensorValued, DTYPE


class TensorView(TensorValued):
    """
    :class:`TensorView` is a subclass of :class:`TensorValued` representing unary operations over another
    :class:`TensorValued` object that preserves the shape of the tensor. Common examples include element-wise
    conjugation and normalization with respect to the frobenius norm.

    :ivar ref: the underlying `TensorValued` object where the unary operation is performed onto (passed to the
        constructor as `tn`).
    :ivar func: the unary function to be applied.
    :ivar homomorphism: indicator whether the unary function is homomorphic to the addition and multiplication of
        tensors. If so, the unary function can be broadcast to lower-level tensors, enabling potential
        simplification of the tensor network structure.
    :ivar dtype: dtype for the tensor entries.
    """

    def __init__(self, tn, func=numpy.conj, homomorphism=False, dtype=DTYPE):
        """The constructor of a `TensorView` object.

        :param tn: the underlying :class:`TensorValued` object the unary operation is applied to.
        :param func: the unary function applied to the contracted data of `tn`; defaults to `numpy.conj`.
        :param homomorphism: whether `func` may be pushed down onto the nodes of an underlying tensor network.
        :param dtype: dtype for the tensor entries.
        """
        super().__init__(dtype)
        self.ref = tn
        self.func = func
        # Element-wise conjugation is always treated as a homomorphism, whatever flag the caller passed.
        self.homomorphism = bool(func is numpy.conj or homomorphism)

    def __str__(self) -> str:
        data_str = "Data: \n" + str(self.ref)
        func_str = "Func: " + str(self.func)
        return super().__str__() + "\n" + data_str + func_str

    @property
    def shape(self):
        """
        The common property of all :class:`TensorValued` classes, yielding the shape of the object.

        :class:`TensorValued` objects must have compatible shapes in order to be connected together in a
        :class:`TensorNetwork`, or summed over in a :class:`TensorSum`.

        For :class:`TensorView` objects, it refers to the shape of the underlying :class:`TensorValued` object
        where the unary operation is performed onto.
        """
        return self.ref.shape

    @property
    def is_ready(self):
        """The common property of all :class:`TensorValued` classes, indicating whether the current
        :class:`TensorValued` object is ready for contraction, i.e. whether it semantically represents a tensor
        with a definite value. In the process of a program, not all :class:`TensorValued` objects need to be
        ready; however once the `data` property of a certain object is queried, such object must be ready in
        order to successfully yield an :class:`numpy.ndarray` object.

        For :class:`TensorView` objects, it is to indicate whether the underlying :class:`TensorValued` object
        where the unary operation is performed onto, is ready for contraction.
        """
        return self.ref.is_ready

    @property
    def is_valid(self):
        """The common property of all :class:`TensorValued` classes, indicating whether the
        :class:`TensorValued` object is valid or not. In every step of a program, all existing
        :class:`TensorValued` objects must be valid, otherwise an exception should be thrown out; this property
        is for double checking that the current :class:`TensorValued` object is indeed valid.

        For :class:`TensorView` objects, it is to indicate whether the underlying :class:`TensorValued` object
        where the unary operation is performed onto, is valid or not.
        """
        return self.ref.is_valid

    @property
    def raw_data(self):
        """The data of the underlying :class:`TensorValued` object where the unary operation is performed
        onto (i.e. before `func` is applied)."""
        return self.ref.contract()

    def fix_index(self, index, fix_to=0):
        """Fix the given index to the given value. The object after the method would have the same type as the
        original one, with rank 1 smaller than the original.

        :param index: The index to fix.
        :type index: :class:`int`.
        :param fix_to: the value to assign to the given index.
        :type fix_to: :class:`int`.

        :returns: :class:`TensorView` -- The :class:`TensorView` object after fixing the given index (this same
            object, modified in place).
        """
        self.ref = self.ref.fix_index(index, fix_to)
        # Previously nothing was returned even though the docstring promises the fixed view.
        return self

    def expand(self, recursive=False):
        """Commute the unary operation with the underlying tensor network, when the unary operation is a
        homomorphism for tensor network contractions.

        :param recursive: if true, the copied underlying network is expanded recursively before the unary
            operation is pushed onto its nodes.
        :returns: `self` when the operation is not a homomorphism or the underlying object is not a
            :class:`TensorNetwork`; otherwise a copy of the underlying network whose every node is wrapped in a
            :class:`TensorView` with the same function.
        """
        from acqdp.tensor_network import TensorNetwork
        if not self.homomorphism or not isinstance(self.ref, TensorNetwork):
            return self
        k = self.ref.copy()
        if recursive:
            k.expand(recursive=True)
        for node_name in k.nodes_by_name:
            # Propagate the view's dtype so that a cast view stays consistent after expansion.
            k.update_node(node_name,
                          TensorView(k.network.nodes[(0, node_name)]['tensor'],
                                     self.func, self.homomorphism, self.dtype))
        return k

    def cast(self, dtype):
        """Cast the view and its underlying object to `dtype`, returning this view."""
        self.dtype = dtype
        self.ref = self.ref.cast(dtype)
        return self

    def contract(self, **kwargs):
        """Contract the underlying object (forwarding `kwargs`) and apply `func` to the result."""
        return self.func(self.ref.contract(**kwargs))

    def copy(self):
        """Return a new view over a copy of the underlying object, preserving function, homomorphism flag and
        dtype."""
        return TensorView(self.ref.copy(), self.func, self.homomorphism, self.dtype)

    def __deepcopy__(self, memo):
        """Deep-copy support for :func:`copy.deepcopy`; the function object itself is shared, not copied."""
        # Forward `memo` so shared references and cycles inside `ref` are handled correctly.
        result = TensorView(copy.deepcopy(self.ref, memo), self.func, self.homomorphism, self.dtype)
        memo[id(self)] = result
        return result