Program Listing for File TensorNetwork.hpp

Return to documentation for file (include/jet/TensorNetwork.hpp)

#pragma once

#include <algorithm>
#include <iterator>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "Abort.hpp"
#include "Utilities.hpp"

namespace Jet {

/**
 * @brief Graph of tensors connected by shared, named indices.
 *
 * Each tensor added to the network becomes a `Node`. Each index label of
 * dimension two or more becomes an `Edge`, which records the nodes that
 * carry it. Nodes are never removed: contracting two nodes marks both as
 * contracted and appends a new node holding the result. A node's ID is its
 * position in the node list.
 *
 * @tparam Tensor Tensor type. This class calls `GetIndices()`, `GetShape()`,
 *         `Transpose()` and `InitIndicesAndShape()` on it, as well as the
 *         static functions `SliceIndex()`, `Reshape()` and
 *         `ContractTensors()`.
 */
template <class Tensor> class TensorNetwork {
  public:
    /// Identifier of a node; equal to the node's position in `GetNodes()`.
    using NodeID_t = size_t;

    /// A tensor in the network together with its bookkeeping data.
    struct Node {
        /// Position of this node in the node list.
        NodeID_t id;

        /// Label derived from `indices` ("_" when there are none).
        std::string name;

        /// Index labels of the node. A sliced index stays in this list,
        /// suffixed with its sliced value (e.g. "a(0)").
        std::vector<std::string> indices;

        /// Tags of the node. A contraction result gets the union of the
        /// tags of its two operands.
        std::vector<std::string> tags;

        /// True once this node has been consumed by a contraction.
        bool contracted;

        /// The tensor held by this node.
        Tensor tensor;
    };

    /// An index label and the nodes that carry it.
    struct Edge {
        /// Dimension of the index.
        size_t dim;

        /// IDs of the nodes carrying this index.
        std::vector<NodeID_t> node_ids;

        /**
         * @brief Two edges are equal when their dimensions match and they
         *        reference the same set of node IDs. Order and duplicates
         *        are ignored.
         */
        bool operator==(const Edge &other) const noexcept
        {
            using set_t = std::unordered_set<NodeID_t>;
            const set_t lhs_ids(node_ids.begin(), node_ids.end());
            const set_t rhs_ids(other.node_ids.begin(), other.node_ids.end());
            return dim == other.dim && lhs_ids == rhs_ids;
        }
    };

    /// All nodes, indexed by node ID.
    using Nodes = std::vector<Node>;

    /// Maps an index label to the edge it defines.
    using IndexToEdgeMap = std::unordered_map<std::string, Edge>;

    /// Maps a tag to every ID of a node carrying that tag.
    using TagToNodeIdsMap = std::unordered_multimap<std::string, NodeID_t>;

    /// Sequence of node-ID pairs, contracted in order.
    using Path = std::vector<std::pair<NodeID_t, NodeID_t>>;

    /// Returns every node added to or created by the network.
    const Nodes &GetNodes() const noexcept { return nodes_; }

    /// Returns the index-to-edge map.
    const IndexToEdgeMap &GetIndexToEdgeMap() const noexcept
    {
        return index_to_edge_map_;
    }

    /// Returns the tag-to-node-IDs map.
    const TagToNodeIdsMap &GetTagToNodesMap() const noexcept
    {
        return tag_to_nodes_map_;
    }

    /// Returns the contraction path recorded by Contract(). The function is
    /// const-qualified, so it can also be used on a const network.
    const Path &GetPath() const noexcept { return path_; }

    /// Returns the number of indices currently tracked as edges.
    size_t NumIndices() const noexcept { return index_to_edge_map_.size(); }

    /// Returns the number of nodes, including contraction results.
    size_t NumTensors() const noexcept { return nodes_.size(); }

    /**
     * @brief Adds a tensor to the network as a new, uncontracted node.
     *
     * @param tensor Tensor to add (copied into the node).
     * @param tags Tags to attach to the node.
     * @return ID of the new node.
     */
    NodeID_t AddTensor(const Tensor &tensor,
                       const std::vector<std::string> &tags) noexcept
    {
        const NodeID_t id = nodes_.size();
        nodes_.emplace_back(Node{
            id,                                   // id
            DeriveNodeName_(tensor.GetIndices()), // name
            tensor.GetIndices(),                  // indices
            tags,                                 // tags
            false,                                // contracted
            tensor                                // tensor
        });

        AddNodeToIndexMap_(nodes_[id]);
        AddNodeToTagMap_(nodes_[id]);

        return id;
    }

    /**
     * @brief Fixes each index in `indices` to a single value and slices
     *        every tensor that carries one of them.
     *
     * `value` is split into one value per index by
     * `Jet::Utilities::UnravelIndex`, using the index dimensions.
     *
     * - Each affected tensor loses the sliced index.
     * - The owning node keeps the label, annotated with its value (e.g. "a"
     *   becomes "a(1)"), and its name is rederived.
     * - The sliced indices are removed from the index-to-edge map.
     *
     * Aborts in two cases:
     * - an index is not an edge of the network;
     * - an index is no longer present in a tensor recorded as carrying it
     *   (e.g. when the same index is listed twice).
     *
     * @param indices Index labels to slice.
     * @param value Combined value selecting the value of every sliced index.
     */
    void SliceIndices(const std::vector<std::string> &indices,
                      unsigned long long value)
    {
        std::unordered_map<NodeID_t, std::vector<size_t>> node_to_index_map;
        std::vector<size_t> index_sizes(indices.size());

        // Map each node ID to the positions (in `indices`) of the indices
        // it carries, and record every index dimension.
        for (size_t i = 0; i < indices.size(); i++) {
            const auto it = index_to_edge_map_.find(indices[i]);
            JET_ABORT_IF(it == index_to_edge_map_.end(),
                         "Sliced index does not exist.");
            const auto &edge = it->second;
            index_sizes[i] = edge.dim;

            for (const auto node_id : edge.node_ids) {
                node_to_index_map[node_id].emplace_back(i);
            }
        }

        const auto values = Jet::Utilities::UnravelIndex(value, index_sizes);

        // Slice the tensors while updating the node indices and names.
        for (const auto &[node_id, indices_indexes] : node_to_index_map) {
            auto &node = nodes_[node_id];
            auto &tensor = node.tensor;

            for (const size_t indices_index : indices_indexes) {
                const auto &sliced_index = indices[indices_index];
                const auto &sliced_value = values[indices_index];

                // Copy these members to avoid messing with the internal
                // representation of the tensor.
                auto tensor_indices = tensor.GetIndices();
                auto tensor_shape = tensor.GetShape();

                // Find the position of the sliced index in the tensor. A
                // missing index would make the erase calls below undefined
                // behavior, so abort instead.
                const auto it = std::find(tensor_indices.begin(),
                                          tensor_indices.end(), sliced_index);
                JET_ABORT_IF(it == tensor_indices.end(),
                             "Sliced index does not exist in tensor.");
                const auto offset = std::distance(tensor_indices.begin(), it);

                tensor_shape.erase(tensor_shape.begin() + offset);
                tensor_indices.erase(tensor_indices.begin() + offset);

                tensor = Tensor::SliceIndex(tensor, sliced_index, sliced_value);

                // Restore the original relative order of the surviving
                // indices if slicing reordered them.
                if (tensor.GetIndices() != tensor_indices) {
                    tensor = tensor.Transpose(tensor_indices);
                }

                if (!tensor_indices.empty()) {
                    tensor = Tensor::Reshape(tensor, tensor_shape);
                }

                // Erase the sliced index from the tensor.
                tensor.InitIndicesAndShape(tensor_indices, tensor_shape);

                // Do not erase the sliced index from the node. Instead,
                // annotate it with the sliced value.
                for (auto &node_index : node.indices) {
                    if (node_index == sliced_index) {
                        node_index += '(';
                        node_index += std::to_string(sliced_value);
                        node_index += ')';
                    }
                }
            }

            node.name = DeriveNodeName_(node.indices);
        }

        // Erase the sliced indices from the index-to-edge map.
        for (const auto &index : indices) {
            index_to_edge_map_.erase(index);
        }
    }

    /**
     * @brief Contracts the network and returns the resulting tensor.
     *
     * With a non-empty `path`:
     * - each pair of node IDs is contracted in order;
     * - `path` replaces the stored path.
     *
     * With an empty `path`:
     * - every edge shared by exactly two nodes is contracted, in the
     *   index-to-edge map's iteration order;
     * - the remaining uncontracted scalar nodes are then contracted
     *   together;
     * - the pairs used are appended to the stored path.
     *
     * Aborts if the network is empty or a node ID in `path` is out of range.
     *
     * @param path Optional contraction path.
     * @return Tensor of the last node in the network.
     */
    const Tensor &Contract(const Path &path = {})
    {
        JET_ABORT_IF(nodes_.empty(),
                     "An empty tensor network cannot be contracted.");

        if (!path.empty()) {
            for (const auto &[node_id_1, node_id_2] : path) {
                JET_ABORT_IF_NOT(node_id_1 < nodes_.size(),
                                 "Node ID 1 in contraction pair is invalid.");
                JET_ABORT_IF_NOT(node_id_2 < nodes_.size(),
                                 "Node ID 2 in contraction pair is invalid.");

                const NodeID_t node_id_3 = ContractNodes_(node_id_1, node_id_2);

                // References are taken only after ContractNodes_() because it
                // appends to `nodes_`, which may reallocate.
                const auto &node_1 = nodes_[node_id_1];
                const auto &node_2 = nodes_[node_id_2];
                const auto &node_3 = nodes_[node_id_3];
                UpdateIndexMapAfterContraction_(node_1, node_2, node_3);
            }
            path_ = path;
        }
        else {
            ContractEdges_();
            ContractScalars_();
        }

        return nodes_.back().tensor;
    }

  private:
    /// All nodes, indexed by node ID.
    Nodes nodes_;

    /// Index label -> edge. Contracted and sliced indices are removed.
    IndexToEdgeMap index_to_edge_map_;

    /// Tag -> IDs of the nodes carrying that tag.
    TagToNodeIdsMap tag_to_nodes_map_;

    /// Contraction path recorded by Contract().
    Path path_;

    /**
     * @brief Registers every index of `node` whose dimension is at least 2
     *        in the index-to-edge map, creating the edge if needed.
     */
    void AddNodeToIndexMap_(const Node &node) noexcept
    {
        const auto &indices = node.indices;
        const auto &shape = node.tensor.GetShape();

        for (size_t i = 0; i < indices.size(); i++) {
            // Indices of dimension below 2 are not tracked as edges.
            if (shape[i] < 2) {
                continue;
            }

            const auto it = index_to_edge_map_.find(indices[i]);
            if (it != index_to_edge_map_.end()) {
                it->second.node_ids.emplace_back(node.id);
            }
            else {
                const Edge edge{
                    shape[i],  // dim
                    {node.id}, // node_ids
                };
                index_to_edge_map_.emplace(indices[i], edge);
            }
        }
    }

    /// Records `node` under each of its tags in the tag-to-node-IDs map.
    void AddNodeToTagMap_(const Node &node) noexcept
    {
        for (const auto &tag : node.tags) {
            tag_to_nodes_map_.emplace(tag, node.id);
        }
    }

    /**
     * @brief Contracts the tensors of two nodes into a new node.
     *
     * Both operands are marked as contracted. The new node's tags are the
     * union of the operands' tags, and its indices are their disjunctive
     * union.
     *
     * @return ID of the newly appended node.
     */
    NodeID_t ContractNodes_(NodeID_t node_id_1, NodeID_t node_id_2) noexcept
    {
        auto &node_1 = nodes_[node_id_1];
        auto &node_2 = nodes_[node_id_2];
        auto tensor_3 = Tensor::ContractTensors(node_1.tensor, node_2.tensor);

        node_1.contracted = true;
        node_2.contracted = true;

        using namespace Jet::Utilities;
        auto node_3_tags = VectorUnion(node_1.tags, node_2.tags);
        auto node_3_indices =
            VectorDisjunctiveUnion(node_1.indices, node_2.indices);
        auto node_3_name = DeriveNodeName_(node_3_indices);

        const NodeID_t node_3_id = nodes_.size();
        // Move the pieces into the new node instead of copying the tensor.
        // `node_1` and `node_2` may dangle after this call, because the
        // vector can reallocate, so they are not used below.
        nodes_.push_back(Node{
            node_3_id,                 // id
            std::move(node_3_name),    // name
            std::move(node_3_indices), // indices
            std::move(node_3_tags),    // tags
            false,                     // contracted
            std::move(tensor_3),       // tensor
        });

        return node_3_id;
    }

    /**
     * @brief Updates the index-to-edge map after `node_1` and `node_2` were
     *        contracted into `node_3`.
     *
     * In the edges of `node_3`'s indices, the IDs of the operands are
     * replaced by `node_3`'s ID. The indices shared by the operands' tensors
     * are erased from the map.
     */
    void UpdateIndexMapAfterContraction_(const Node &node_1, const Node &node_2,
                                         const Node &node_3) noexcept
    {
        // Replace the IDs of the contracted nodes with the ID of the new node
        // in the index-to-edge map.
        for (const auto &index : node_3.indices) {
            const auto it = index_to_edge_map_.find(index);
            if (it == index_to_edge_map_.end()) {
                continue;
            }

            for (auto &node_id : it->second.node_ids) {
                if (node_id == node_1.id || node_id == node_2.id) {
                    node_id = node_3.id;
                }
            }
        }

        // Delete the contracted indices in the index-to-edge map.
        const auto contracted_indices = Jet::Utilities::VectorIntersection(
            node_1.tensor.GetIndices(), node_2.tensor.GetIndices());

        for (const auto &index : contracted_indices) {
            index_to_edge_map_.erase(index);
        }
    }

    /**
     * @brief Contracts every edge that connects exactly two nodes and
     *        appends each contracted pair to the stored path.
     *
     * Edges that no longer exist when they are reached (because an earlier
     * contraction erased them) are skipped. So are edges connecting any
     * number of nodes other than two.
     */
    void ContractEdges_() noexcept
    {
        // Snapshot the index labels, because the map is modified while the
        // edges are being contracted.
        std::vector<std::string> indices;
        indices.reserve(index_to_edge_map_.size());
        for (const auto &entry : index_to_edge_map_) {
            indices.emplace_back(entry.first);
        }

        for (const auto &index : indices) {
            const auto it = index_to_edge_map_.find(index);
            if (it == index_to_edge_map_.end()) {
                continue;
            }

            const auto &node_ids = it->second.node_ids;
            if (node_ids.size() != 2) {
                continue;
            }

            // Copy the IDs now: the update below may erase this edge.
            const NodeID_t node_id_0 = node_ids[0];
            const NodeID_t node_id_1 = node_ids[1];
            const NodeID_t node_id_2 = ContractNodes_(node_id_0, node_id_1);
            UpdateIndexMapAfterContraction_(
                nodes_[node_id_0], nodes_[node_id_1], nodes_[node_id_2]);

            path_.emplace_back(node_id_0, node_id_1);
        }
    }

    /**
     * @brief Contracts all uncontracted scalar (index-free) nodes together,
     *        left to right, and appends each pair to the stored path.
     */
    void ContractScalars_() noexcept
    {
        std::vector<NodeID_t> node_ids;
        for (NodeID_t i = 0; i < nodes_.size(); i++) {
            const auto &node = nodes_[i];
            // Nodes already consumed by a contraction (e.g. during an earlier
            // call to Contract()) must not be contracted a second time.
            if (!node.contracted && node.tensor.GetIndices().empty()) {
                node_ids.emplace_back(i);
            }
        }

        if (node_ids.size() < 2) {
            return;
        }

        // `node_id` tracks the node holding the cumulative contraction.
        NodeID_t node_id = node_ids[0];
        for (size_t i = 1; i < node_ids.size(); i++) {
            path_.emplace_back(node_id, node_ids[i]);
            node_id = ContractNodes_(node_id, node_ids[i]);
        }
    }

    /// Joins `indices` into a node name, or returns "_" if there are none.
    std::string
    DeriveNodeName_(const std::vector<std::string> &indices) const noexcept
    {
        if (indices.empty()) {
            return "_";
        }
        return Jet::Utilities::JoinStringVector(indices);
    }
};

/**
 * @brief Writes a human-readable summary of a tensor network.
 *
 * The output has two sections:
 * - "Printing Nodes": one line per node (ID, name, tags);
 * - "Printing Edges": one line per edge (index, dimension, node IDs).
 *
 * Lines end with '\n' rather than std::endl, so the stream is not flushed
 * after every line.
 *
 * @param out Stream to write to.
 * @param tn Network to print.
 * @return `out`.
 */
template <class Tensor>
inline std::ostream &operator<<(std::ostream &out,
                                const TensorNetwork<Tensor> &tn)
{
    // Brings Jet::Utilities' "<<" overload for std::vector into scope; it is
    // used to print the tags and node IDs.
    using namespace Jet::Utilities;

    out << "Printing Nodes" << '\n';
    for (const auto &node : tn.GetNodes()) {
        out << node.id << ' ' << node.name << ' ' << node.tags << '\n';
    }
    out << "Printing Edges" << '\n';
    for (const auto &[index, edge] : tn.GetIndexToEdgeMap()) {
        out << index << ' ' << edge.dim << ' ' << edge.node_ids << '\n';
    }
    return out;
}

}; // namespace Jet