Program Listing for File TensorNetwork.hpp
Return to documentation for file (include/jet/TensorNetwork.hpp)
#pragma once
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "Abort.hpp"
#include "Utilities.hpp"
namespace Jet {
template <class Tensor> class TensorNetwork {
public:
using NodeID_t = size_t;
struct Node {
NodeID_t id;
std::string name;
std::vector<std::string> indices;
std::vector<std::string> tags;
bool contracted;
Tensor tensor;
};
struct Edge {
size_t dim;
std::vector<NodeID_t> node_ids;
bool operator==(const Edge &other) const noexcept
{
using set_t = std::unordered_set<size_t>;
const set_t lhs_ids(node_ids.begin(), node_ids.end());
const set_t rhs_ids(other.node_ids.begin(), other.node_ids.end());
return dim == other.dim && lhs_ids == rhs_ids;
}
};
using Nodes = std::vector<Node>;
using IndexToEdgeMap = std::unordered_map<std::string, Edge>;
using TagToNodeIdsMap = std::unordered_multimap<std::string, NodeID_t>;
using Path = std::vector<std::pair<NodeID_t, NodeID_t>>;
const Nodes &GetNodes() const noexcept { return nodes_; }
const IndexToEdgeMap &GetIndexToEdgeMap() const noexcept
{
return index_to_edge_map_;
}
const TagToNodeIdsMap &GetTagToNodesMap() const noexcept
{
return tag_to_nodes_map_;
}
const Path &GetPath() noexcept { return path_; }
size_t NumIndices() const noexcept { return index_to_edge_map_.size(); }
size_t NumTensors() const noexcept { return nodes_.size(); }
NodeID_t AddTensor(const Tensor &tensor,
const std::vector<std::string> &tags) noexcept
{
NodeID_t id = nodes_.size();
nodes_.emplace_back(Node{
id, // id
DeriveNodeName_(tensor.GetIndices()), // name
tensor.GetIndices(), // indices
tags, // tags
false, // contracted
tensor // tensor
});
AddNodeToIndexMap_(nodes_[id]);
AddNodeToTagMap_(nodes_[id]);
return id;
}
void SliceIndices(const std::vector<std::string> &indices,
unsigned long long value)
{
std::unordered_map<size_t, std::vector<size_t>> node_to_index_map;
std::vector<size_t> index_sizes(indices.size());
// Map each node ID to the indexes in `indices` to be sliced.
for (size_t i = 0; i < indices.size(); i++) {
const auto it = index_to_edge_map_.find(indices[i]);
JET_ABORT_IF(it == index_to_edge_map_.end(),
"Sliced index does not exist.");
const auto &edge = it->second;
index_sizes[i] = edge.dim;
for (const auto node_id : edge.node_ids) {
auto &indices_indexes = node_to_index_map[node_id];
indices_indexes.emplace_back(i);
}
}
const auto values = Jet::Utilities::UnravelIndex(value, index_sizes);
// Slice the tensors while updating the node indices and names.
for (const auto &[node_id, indices_indexes] : node_to_index_map) {
auto &node = nodes_[node_id];
auto &tensor = node.tensor;
for (const int indices_index : indices_indexes) {
const auto &sliced_index = indices[indices_index];
const auto &sliced_value = values[indices_index];
// Copy these members to avoid messing with the internal
// representation of the tensor.
auto tensor_indices = tensor.GetIndices();
auto tensor_shape = tensor.GetShape();
// Find the position of the sliced index in the tensor.
const auto it = std::find(tensor_indices.begin(),
tensor_indices.end(), sliced_index);
const auto offset = std::distance(tensor_indices.begin(), it);
tensor_shape.erase(tensor_shape.begin() + offset);
tensor_indices.erase(tensor_indices.begin() + offset);
tensor = Tensor::SliceIndex(tensor, sliced_index, sliced_value);
if (tensor.GetIndices() != tensor_indices)
tensor = tensor.Transpose(tensor_indices);
if (!tensor_indices.empty()) {
tensor = Tensor::Reshape(tensor, tensor_shape);
}
// Erase the sliced index from the tensor.
tensor.InitIndicesAndShape(tensor_indices, tensor_shape);
// Do not erase the sliced index from the node. Instead,
// annotate it with the sliced value.
for (auto &node_index : node.indices) {
if (node_index == sliced_index) {
node_index += '(';
node_index += std::to_string(sliced_value);
node_index += ')';
}
}
}
node.name = DeriveNodeName_(node.indices);
}
// Erase the sliced indices from the index-to-edge map.
for (const auto &index : indices) {
index_to_edge_map_.erase(index);
}
}
const Tensor &Contract(const Path &path = {})
{
JET_ABORT_IF(nodes_.empty(),
"An empty tensor network cannot be contracted.");
if (!path.empty()) {
for (const auto &[node_id_1, node_id_2] : path) {
JET_ABORT_IF_NOT(node_id_1 < nodes_.size(),
"Node ID 1 in contraction pair is invalid.");
JET_ABORT_IF_NOT(node_id_2 < nodes_.size(),
"Node ID 2 in contraction pair is invalid.");
const size_t node_id_3 = ContractNodes_(node_id_1, node_id_2);
const auto &node_1 = nodes_[node_id_1];
const auto &node_2 = nodes_[node_id_2];
const auto &node_3 = nodes_[node_id_3];
UpdateIndexMapAfterContraction_(node_1, node_2, node_3);
}
path_ = path;
}
else {
ContractEdges_();
ContractScalars_();
}
return nodes_.back().tensor;
}
private:
Nodes nodes_;
IndexToEdgeMap index_to_edge_map_;
TagToNodeIdsMap tag_to_nodes_map_;
std::vector<std::pair<size_t, size_t>> path_;
void AddNodeToIndexMap_(const Node &node) noexcept
{
const auto &indices = node.indices;
const auto &shape = node.tensor.GetShape();
for (size_t i = 0; i < indices.size(); i++) {
if (shape[i] < 2) {
continue;
}
const auto it = index_to_edge_map_.find(indices[i]);
if (it != index_to_edge_map_.end()) {
auto &edge = it->second;
edge.node_ids.emplace_back(node.id);
}
else {
const Edge edge{
shape[i], // dim
{node.id}, // node_ids
};
index_to_edge_map_.emplace(indices[i], edge);
}
}
}
void AddNodeToTagMap_(const Node &node) noexcept
{
for (const auto &tag : node.tags) {
tag_to_nodes_map_.emplace(tag, node.id);
}
}
size_t ContractNodes_(size_t node_id_1, size_t node_id_2) noexcept
{
auto &node_1 = nodes_[node_id_1];
auto &node_2 = nodes_[node_id_2];
const auto tensor_3 =
Tensor::ContractTensors(node_1.tensor, node_2.tensor);
node_1.contracted = true;
node_2.contracted = true;
using namespace Jet::Utilities;
const auto node_3_tags = VectorUnion(node_1.tags, node_2.tags);
const auto node_3_indices =
VectorDisjunctiveUnion(node_1.indices, node_2.indices);
const auto node_3_name = DeriveNodeName_(node_3_indices);
Node node_3{
nodes_.size(), // id
node_3_name, // name
node_3_indices, // indices
node_3_tags, // tags
false, // contracted
tensor_3, // tensor
};
nodes_.emplace_back(node_3);
return node_3.id;
}
void UpdateIndexMapAfterContraction_(const Node &node_1, const Node &node_2,
const Node &node_3) noexcept
{
// Replace the IDs of the contracted nodes with the ID of the new node
// in the index-to-edge map.
for (auto &index : node_3.indices) {
const auto it = index_to_edge_map_.find(index);
if (it == index_to_edge_map_.end()) {
continue;
}
Edge &edge = it->second;
for (auto &node_id : edge.node_ids) {
if (node_id == node_1.id || node_id == node_2.id) {
node_id = node_3.id;
}
}
}
// Delete the contracted indices in the index-to-edge map.
const auto contracted_indices = Jet::Utilities::VectorIntersection(
node_1.tensor.GetIndices(), node_2.tensor.GetIndices());
for (const auto &index : contracted_indices) {
index_to_edge_map_.erase(index);
}
}
void ContractEdges_() noexcept
{
// Create a copy of the indices from the index-to-edge map since this
// map will be modified in the next loop.
std::vector<std::string> indices;
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-variable"
for (const auto &[index, _] : index_to_edge_map_) {
#pragma GCC diagnostic pop
indices.emplace_back(index);
}
for (const auto &index : indices) {
const auto it = index_to_edge_map_.find(index);
if (it == index_to_edge_map_.end()) {
continue;
}
const auto &node_ids = it->second.node_ids;
if (node_ids.size() != 2) {
continue;
}
const size_t node_id_0 = node_ids[0];
const size_t node_id_1 = node_ids[1];
const size_t node_id_2 = ContractNodes_(node_id_0, node_id_1);
UpdateIndexMapAfterContraction_(
nodes_[node_id_0], nodes_[node_id_1], nodes_[node_id_2]);
path_.emplace_back(node_id_0, node_id_1);
}
}
void ContractScalars_() noexcept
{
std::vector<size_t> node_ids;
for (size_t i = 0; i < nodes_.size(); i++) {
const bool scalar = nodes_[i].tensor.GetIndices().empty();
if (scalar) {
node_ids.emplace_back(i);
}
}
if (node_ids.size() >= 2) {
// Use `node_id` to track the cumulative contracted tensor.
size_t node_id = node_ids[0];
for (size_t i = 1; i < node_ids.size(); i++) {
const size_t node_id_1 = node_id;
const size_t node_id_2 = node_ids[i];
path_.emplace_back(node_id_1, node_id_2);
node_id = ContractNodes_(node_id_1, node_id_2);
}
}
}
std::string
DeriveNodeName_(const std::vector<std::string> &indices) const noexcept
{
return indices.size() ? Jet::Utilities::JoinStringVector(indices) : "_";
}
};
/**
 * @brief Writes a human-readable summary of a tensor network to a stream.
 *
 * Prints a "Printing Nodes" header followed by one "<id> <name> <tags>" line
 * per node, then a "Printing Edges" header followed by one
 * "<index> <dim> <node IDs>" line per entry of the index-to-edge map.
 *
 * @param out Output stream.
 * @param tn Tensor network to print.
 * @return Reference to `out`.
 */
template <class Tensor>
inline std::ostream &operator<<(std::ostream &out,
                                const TensorNetwork<Tensor> &tn)
{
    // Brings into scope the std::vector stream overload used below for the
    // tags and node IDs (presumably provided by Jet::Utilities).
    using namespace Jet::Utilities;

    // '\n' rather than std::endl: same text, without a flush on every line.
    out << "Printing Nodes" << '\n';
    for (const auto &node : tn.GetNodes()) {
        out << node.id << ' ' << node.name << ' ' << node.tags << '\n';
    }

    out << "Printing Edges" << '\n';
    for (const auto &[index, edge] : tn.GetIndexToEdgeMap()) {
        out << index << ' ' << edge.dim << ' ' << edge.node_ids << '\n';
    }

    return out;
}
}; // namespace Jet
api/program_listing_file_include_jet_TensorNetwork.hpp
Download Python script
Download Notebook
View on GitHub
Downloads