Program Listing for File PathInfo.hpp

Return to documentation for file (include/jet/PathInfo.hpp)

#pragma once

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "Abort.hpp"
#include "TensorNetwork.hpp"

namespace Jet {

/**
 * @brief One step of a contraction path.
 *
 * A step is either a leaf (a node of the original tensor network) or the
 * result of contracting two earlier steps. `PathInfo` fills unknown IDs
 * with `MISSING_ID`.
 *
 * Every member has a default member initializer. A default-constructed
 * `PathStepInfo` therefore never holds indeterminate IDs. Scalar IDs
 * default to 0, matching what `PathStepInfo{}` has always produced.
 */
struct PathStepInfo {
    /// Identifier of this step.
    size_t id{};

    /// ID of the step this one is contracted into (`MISSING_ID` if none).
    size_t parent{};

    /// IDs of the two steps contracted to produce this step
    /// (`{MISSING_ID, MISSING_ID}` for a leaf).
    std::pair<size_t, size_t> children{};

    /// Human-readable label of the step.
    std::string name{};

    /// Node indices associated with this step.
    std::vector<std::string> node_indices{};

    /// Indices of the tensor this step produces.
    std::vector<std::string> tensor_indices{};

    /// Tags carried by this step.
    std::vector<std::string> tags{};

    /// Indices summed over when this step was formed (empty for a leaf).
    std::vector<std::string> contracted_indices{};

    /// Sentinel for an absent step ID.
    static constexpr size_t MISSING_ID = std::numeric_limits<size_t>::max();
};

class PathInfo {
  public:
    using NodeID_t = size_t;

    using Path = std::vector<std::pair<NodeID_t, NodeID_t>>;

    using IndexToSizeMap = std::unordered_map<std::string, size_t>;

    using Steps = std::vector<PathStepInfo>;

    PathInfo() : num_leaves_(0) {}

    template <typename Tensor>
    PathInfo(const TensorNetwork<Tensor> &tn, const Path &path) : path_(path)
    {
        const auto &nodes = tn.GetNodes();
        num_leaves_ = nodes.size();

        steps_.reserve(num_leaves_);
        for (const auto &node : nodes) {
            constexpr size_t missing_id = PathStepInfo::MISSING_ID;
            PathStepInfo step{
                node.id,                  // id
                missing_id,               // parent
                {missing_id, missing_id}, // children
                node.name,                // name
                node.indices,             // node_indices
                node.tensor.GetIndices(), // tensor_indices
                node.tags,                // tags
                {},                       // contracted_indices
            };
            steps_.emplace_back(step);
        }

        for (const auto &[index, edge] : tn.GetIndexToEdgeMap()) {
            index_to_size_map_.emplace(index, edge.dim);
        }

        for (const auto &[node_id_1, node_id_2] : path) {
            JET_ABORT_IF_NOT(node_id_1 < steps_.size(),
                             "Node ID 1 in contraction path pair is invalid.");
            JET_ABORT_IF_NOT(node_id_2 < steps_.size(),
                             "Node ID 2 in contraction path pair is invalid.");
            ContractSteps_(node_id_1, node_id_2);
        }
    }

    const IndexToSizeMap &GetIndexSizes() const noexcept
    {
        return index_to_size_map_;
    }

    size_t GetNumLeaves() const noexcept { return num_leaves_; }

    const Path &GetPath() const noexcept { return path_; }

    const Steps &GetSteps() const noexcept { return steps_; }

    double GetPathStepFlops(size_t id) const
    {
        JET_ABORT_IF_NOT(id < steps_.size(), "Step ID is invalid.");

        if (id < num_leaves_) {
            // Tensor network leaves are constructed for free.
            return 0;
        }

        const auto &step = steps_[id];

        // Calculate the number of FLOPs needed for each dot product.
        double muls = 1;
        for (const auto &index : step.contracted_indices) {
            const auto it = index_to_size_map_.find(index);
            muls *= it == index_to_size_map_.end() ? 1 : it->second;
        }
        double adds = muls;

        // Find the number of elements in the tensor.
        double size = 1;
        for (const auto &index : step.tensor_indices) {
            const auto it = index_to_size_map_.find(index);
            size *= it == index_to_size_map_.end() ? 1 : it->second;
        }
        return size * (muls + adds);
    }

    double GetTotalFlops() const noexcept
    {
        double flops = 0;
        for (size_t i = num_leaves_; i < steps_.size(); i++) {
            flops += GetPathStepFlops(i);
        }
        return flops;
    }

    double GetPathStepMemory(size_t id) const
    {
        JET_ABORT_IF_NOT(id < steps_.size(), "Step ID is invalid.");

        const auto &step = steps_[id];
        const auto &indices = step.tensor_indices;

        double memory = 1;
        for (const auto &index : indices) {
            const auto it = index_to_size_map_.find(index);
            memory *= it == index_to_size_map_.end() ? 1 : it->second;
        }
        return memory;
    }

    double GetTotalMemory() const noexcept
    {
        double memory = 0;
        for (size_t id = 0; id < steps_.size(); id++) {
            memory += GetPathStepMemory(id);
        }
        return memory;
    }

  private:
    Path path_;

    Steps steps_;

    size_t num_leaves_;

    IndexToSizeMap index_to_size_map_;

    void ContractSteps_(size_t step_id_1, size_t step_id_2) noexcept
    {
        using namespace Jet::Utilities;
        auto &step_1 = steps_[step_id_1];
        auto &step_2 = steps_[step_id_2];

        const size_t step_3_id = steps_.size();
        const auto step_3_contracted_indices =
            VectorIntersection(step_1.tensor_indices, step_2.tensor_indices);
        const auto step_3_node_indices = VectorSubtraction(
            VectorConcatenation(step_1.node_indices, step_2.node_indices),
            step_3_contracted_indices);
        const auto step_3_name = step_3_node_indices.size()
                                     ? JoinStringVector(step_3_node_indices)
                                     : "_";
        const auto step_3_tensor_indices = VectorDisjunctiveUnion(
            step_1.tensor_indices, step_2.tensor_indices);
        const auto step_3_tags = VectorUnion(step_1.tags, step_2.tags);

        // Assign the parent IDs before references to `steps_` elements are
        // invalidated by `std::vector::emplace_back()`.
        step_1.parent = step_3_id;
        step_2.parent = step_3_id;

        PathStepInfo step_3{
            step_3_id,                 // id
            PathStepInfo::MISSING_ID,  // parent
            {step_id_1, step_id_2},    // children
            step_3_name,               // name
            step_3_node_indices,       // node_indices
            step_3_tensor_indices,     // tensor_indices
            step_3_tags,               // tags
            step_3_contracted_indices, // contracted_indices
        };
        steps_.emplace_back(step_3);
    }
};

}; // namespace Jet