Program Listing for File PathInfo.hpp
Return to documentation for file (include/jet/PathInfo.hpp)
#pragma once
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "Abort.hpp"
#include "TensorNetwork.hpp"
namespace Jet {
/**
 * @brief Describes one step of a contraction path: either a leaf of the tensor
 *        network or the result of contracting two earlier steps.
 *
 * Steps are built with positional aggregate initialization, so the order of
 * the non-static data members below is part of the interface; do not reorder.
 */
struct PathStepInfo {
    /// Sentinel for a step ID (parent or child) that does not exist.
    static constexpr size_t MISSING_ID = std::numeric_limits<size_t>::max();

    /// ID of this step (`MISSING_ID` until assigned).
    size_t id{MISSING_ID};

    /// ID of the step this step is contracted into, or `MISSING_ID`.
    size_t parent{MISSING_ID};

    /// IDs of the two steps contracted to form this step, or `MISSING_ID`s.
    std::pair<size_t, size_t> children{MISSING_ID, MISSING_ID};

    /// Name of the step.
    std::string name;

    /// Indices of the tensor network node (or combined node) for this step.
    std::vector<std::string> node_indices;

    /// Indices of the tensor associated with this step.
    std::vector<std::string> tensor_indices;

    /// Tags attached to this step.
    std::vector<std::string> tags;

    /// Indices contracted to produce this step.
    std::vector<std::string> contracted_indices;
};
class PathInfo {
public:
using NodeID_t = size_t;
using Path = std::vector<std::pair<NodeID_t, NodeID_t>>;
using IndexToSizeMap = std::unordered_map<std::string, size_t>;
using Steps = std::vector<PathStepInfo>;
PathInfo() : num_leaves_(0) {}
template <typename Tensor>
PathInfo(const TensorNetwork<Tensor> &tn, const Path &path) : path_(path)
{
const auto &nodes = tn.GetNodes();
num_leaves_ = nodes.size();
steps_.reserve(num_leaves_);
for (const auto &node : nodes) {
constexpr size_t missing_id = PathStepInfo::MISSING_ID;
PathStepInfo step{
node.id, // id
missing_id, // parent
{missing_id, missing_id}, // children
node.name, // name
node.indices, // node_indices
node.tensor.GetIndices(), // tensor_indices
node.tags, // tags
{}, // contracted_indices
};
steps_.emplace_back(step);
}
for (const auto &[index, edge] : tn.GetIndexToEdgeMap()) {
index_to_size_map_.emplace(index, edge.dim);
}
for (const auto &[node_id_1, node_id_2] : path) {
JET_ABORT_IF_NOT(node_id_1 < steps_.size(),
"Node ID 1 in contraction path pair is invalid.");
JET_ABORT_IF_NOT(node_id_2 < steps_.size(),
"Node ID 2 in contraction path pair is invalid.");
ContractSteps_(node_id_1, node_id_2);
}
}
const IndexToSizeMap &GetIndexSizes() const noexcept
{
return index_to_size_map_;
}
size_t GetNumLeaves() const noexcept { return num_leaves_; }
const Path &GetPath() const noexcept { return path_; }
const Steps &GetSteps() const noexcept { return steps_; }
double GetPathStepFlops(size_t id) const
{
JET_ABORT_IF_NOT(id < steps_.size(), "Step ID is invalid.");
if (id < num_leaves_) {
// Tensor network leaves are constructed for free.
return 0;
}
const auto &step = steps_[id];
// Calculate the number of FLOPs needed for each dot product.
double muls = 1;
for (const auto &index : step.contracted_indices) {
const auto it = index_to_size_map_.find(index);
muls *= it == index_to_size_map_.end() ? 1 : it->second;
}
double adds = muls;
// Find the number of elements in the tensor.
double size = 1;
for (const auto &index : step.tensor_indices) {
const auto it = index_to_size_map_.find(index);
size *= it == index_to_size_map_.end() ? 1 : it->second;
}
return size * (muls + adds);
}
double GetTotalFlops() const noexcept
{
double flops = 0;
for (size_t i = num_leaves_; i < steps_.size(); i++) {
flops += GetPathStepFlops(i);
}
return flops;
}
double GetPathStepMemory(size_t id) const
{
JET_ABORT_IF_NOT(id < steps_.size(), "Step ID is invalid.");
const auto &step = steps_[id];
const auto &indices = step.tensor_indices;
double memory = 1;
for (const auto &index : indices) {
const auto it = index_to_size_map_.find(index);
memory *= it == index_to_size_map_.end() ? 1 : it->second;
}
return memory;
}
double GetTotalMemory() const noexcept
{
double memory = 0;
for (size_t id = 0; id < steps_.size(); id++) {
memory += GetPathStepMemory(id);
}
return memory;
}
private:
Path path_;
Steps steps_;
size_t num_leaves_;
IndexToSizeMap index_to_size_map_;
void ContractSteps_(size_t step_id_1, size_t step_id_2) noexcept
{
using namespace Jet::Utilities;
auto &step_1 = steps_[step_id_1];
auto &step_2 = steps_[step_id_2];
const size_t step_3_id = steps_.size();
const auto step_3_contracted_indices =
VectorIntersection(step_1.tensor_indices, step_2.tensor_indices);
const auto step_3_node_indices = VectorSubtraction(
VectorConcatenation(step_1.node_indices, step_2.node_indices),
step_3_contracted_indices);
const auto step_3_name = step_3_node_indices.size()
? JoinStringVector(step_3_node_indices)
: "_";
const auto step_3_tensor_indices = VectorDisjunctiveUnion(
step_1.tensor_indices, step_2.tensor_indices);
const auto step_3_tags = VectorUnion(step_1.tags, step_2.tags);
// Assign the parent IDs before references to `steps_` elements are
// invalidated by `std::vector::emplace_back()`.
step_1.parent = step_3_id;
step_2.parent = step_3_id;
PathStepInfo step_3{
step_3_id, // id
PathStepInfo::MISSING_ID, // parent
{step_id_1, step_id_2}, // children
step_3_name, // name
step_3_node_indices, // node_indices
step_3_tensor_indices, // tensor_indices
step_3_tags, // tags
step_3_contracted_indices, // contracted_indices
};
steps_.emplace_back(step_3);
}
};
}; // namespace Jet
api/program_listing_file_include_jet_PathInfo.hpp
Download Python script
Download Notebook
View on GitHub