Program Listing for File TaskBasedContractor.hpp


#pragma once

#include <complex>
#include <future>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <taskflow/taskflow.hpp>

#include "PathInfo.hpp"
#include "TensorNetwork.hpp"

namespace Jet {

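/**
 * @brief TaskBasedContractor contracts a tensor network by constructing a
 *        Taskflow task graph and running it on a multithreaded executor.
 *
 * @tparam TensorType Type of the tensors to be contracted.
 */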
template <class TensorType> class TaskBasedContractor {
  public:
    /// Maps task names to Taskflow tasks.
    using NameToTaskMap = std::unordered_map<std::string, tf::Task>;

    /// Maps task names to (owned) tensors.
    using NameToTensorMap =
        std::unordered_map<std::string, std::unique_ptr<TensorType>>;

    /// Maps task names to the names of the tasks that consume their tensors.
    using NameToParentsMap =
        std::unordered_map<std::string, std::unordered_set<std::string>>;

    /// Alias for the Taskflow task graph type.
    using TaskFlow = tf::Taskflow;

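    /**
     * @brief Constructs a new `TaskBasedContractor` object.
     *
     * @param num_threads Number of threads made available to the executor.
     */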
    TaskBasedContractor(
        size_t num_threads = std::thread::hardware_concurrency())
        : executor_{num_threads}, memory_(0), flops_(0), reduced_(false)
    {
    }

    /// Returns the map from task names to Taskflow tasks.
    const NameToTaskMap &GetNameToTaskMap() const noexcept
    {
        return name_to_task_map_;
    }

    /// Returns the map from task names to tensors.
    const NameToTensorMap &GetNameToTensorMap() const noexcept
    {
        return name_to_tensor_map_;
    }

    /// Returns the map from task names to the names of the tasks that
    /// consume the associated tensors.
    const NameToParentsMap &GetNameToParentsMap() const noexcept
    {
        return name_to_parents_map_;
    }

    /// Returns the result tensor of each contraction path added so far.
    const std::vector<TensorType> &GetResults() const noexcept
    {
        return results_;
    }

    /// Returns the sum of the result tensors computed by the reduction task.
    const TensorType &GetReductionResult() const noexcept
    {
        return reduction_result_;
    }

    /// Returns the Taskflow task graph used by this contractor.
    const TaskFlow &GetTaskflow() const noexcept { return taskflow_; }

    /// Composes the given taskflow into the task graph as a module task.
    void AddTaskflow(tf::Taskflow &taskflow) noexcept
    {
        taskflow_.composed_of(taskflow);
    }

    /// Returns the floating-point operations required by the scheduled tasks.
    double GetFlops() const noexcept { return flops_; }

    /// Returns the memory required by the scheduled contraction tasks.
    double GetMemory() const noexcept { return memory_; }

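    /**
     * @brief Adds the contraction tasks of a tensor network to the task
     *        graph, following the given contraction path.  Tasks already
     *        present in the task graph are reused rather than scheduled
     *        again.
     *
     * @param tn Tensor network to be contracted.
     * @param path_info Contraction path through the tensor network.
     * @return Number of contraction tasks shared with previous calls to this
     *         method.
     */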
    size_t AddContractionTasks(const TensorNetwork<TensorType> &tn,
                               const PathInfo &path_info) noexcept
    {
        const auto &path = path_info.GetPath();
        const auto &steps = path_info.GetSteps();

        if (path.empty()) {
            return 0;
        }

        const auto &nodes = tn.GetNodes();
        const size_t num_leaves = nodes.size();

        const size_t result_id = results_.size();
        results_.resize(results_.size() + 1);

        size_t shared_tasks = 0;

        for (size_t i = 0; i < path.size(); i++) {
            const auto [step_1_id, step_2_id] = path[i];

            const auto &step_1 = steps[step_1_id];
            const auto &step_2 = steps[step_2_id];
            const auto &step_3 = steps[num_leaves + i];

            const auto name_1 = DeriveTaskName_(step_1);
            const auto name_2 = DeriveTaskName_(step_2);
            auto name_3 = DeriveTaskName_(step_3);

            // Append the result ID to the final contraction task.
            const bool last_step = i == path.size() - 1;
            if (last_step) {
                name_3 += ":results[";
                name_3 += std::to_string(result_id);
                name_3 += ']';
            }

            // The name-to-parents map is used in AddDeletionTasks().
            name_to_parents_map_[name_1].emplace(name_3);
            name_to_parents_map_[name_2].emplace(name_3);

            // Ensure all the tensors have a place in the name-to-tensor map.
            if (step_1_id < num_leaves) {
                const auto &tensor = nodes[step_1_id].tensor;
                name_to_tensor_map_.try_emplace(
                    name_1, std::make_unique<TensorType>(tensor));
            }

            if (step_2_id < num_leaves) {
                const auto &tensor = nodes[step_2_id].tensor;
                name_to_tensor_map_.try_emplace(
                    name_2, std::make_unique<TensorType>(tensor));
            }

            name_to_tensor_map_.try_emplace(name_3, nullptr);

            // Do nothing if this contraction is already tracked.
            if (name_to_task_map_.count(name_3)) {
                shared_tasks++;
                continue;
            }

            flops_ += path_info.GetPathStepFlops(step_3.id);
            memory_ += path_info.GetPathStepMemory(step_3.id);

            AddContractionTask_(name_1, name_2, name_3);

            // Make sure the child tensors exist before the contraction happens.
            if (step_1_id >= num_leaves) {
                auto &task_1 = name_to_task_map_.at(name_1);
                auto &task_3 = name_to_task_map_.at(name_3);
                task_1.precede(task_3);
            }

            if (step_2_id >= num_leaves) {
                auto &task_2 = name_to_task_map_.at(name_2);
                auto &task_3 = name_to_task_map_.at(name_3);
                task_2.precede(task_3);
            }

            // Store the final tensor in the `results_` vector.
            if (last_step) {
                AddStorageTask_(name_3, result_id);
            }
        }
        return shared_tasks;
    }

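    /**
     * @brief Adds a task that sums the result tensors into
     *        `reduction_result_`.
     *
     * @return Number of reduction tasks added (at most one per contractor).
     */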
    size_t AddReductionTask() noexcept
    {
        // Scheduling multiple reduction tasks introduces a race condition.
        if (reduced_) {
            return 0;
        }
        reduced_ = true;

        auto reduce_func = [](const TensorType &a, const TensorType &b) {
            return a.AddTensor(b);
        };

        auto reduce_task = taskflow_
                               .reduce(results_.begin(), results_.end(),
                                       reduction_result_, reduce_func)
                               .name("reduce");

        for (auto &result_task : result_tasks_) {
            result_task.precede(reduce_task);
        }

        return 1;
    }

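    /**
     * @brief Adds tasks that delete each tensor once all of the contraction
     *        tasks that consume it have finished.
     *
     * @return Number of deletion tasks added to the task graph.
     */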
    size_t AddDeletionTasks() noexcept
    {
        size_t delete_tasks = 0;
        for (const auto &[name, parents] : name_to_parents_map_) {
            if (parents.empty()) {
                continue;
            }

            const auto runner = [this, name = name]() {
                name_to_tensor_map_[name] = nullptr;
            };

            const std::string delete_task_name = name + ":delete";
            auto delete_task = taskflow_.emplace(runner).name(delete_task_name);
            ++delete_tasks;

            for (const auto &parent : parents) {
                const auto it = name_to_task_map_.find(parent);
                if (it != name_to_task_map_.end()) {
                    auto &parent_task = it->second;
                    parent_task.precede(delete_task);
                }
            }
        }
        return delete_tasks;
    }

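    /**
     * @brief Submits the task graph to the executor for asynchronous
     *        execution.
     *
     * @return Future that becomes ready once all tasks have finished.
     */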
    std::future<void> Contract() { return executor_.run(taskflow_); }

  private:
    /// Executor that runs the task graph.
    tf::Executor executor_;

    /// Task graph to be executed during a contraction.
    TaskFlow taskflow_;

    /// Map from task names to Taskflow tasks.
    NameToTaskMap name_to_task_map_;

    /// Map from task names to tensors.
    NameToTensorMap name_to_tensor_map_;

    /// Map from task names to the names of the tasks that consume them.
    NameToParentsMap name_to_parents_map_;

    /// Tasks that store the final tensor of each contraction path.
    std::vector<tf::Task> result_tasks_;

    /// Result tensor of each contraction path added to the task graph.
    std::vector<TensorType> results_;

    /// Sum of the result tensors, populated by the reduction task.
    TensorType reduction_result_;

    /// Memory required by the scheduled contraction tasks.
    double memory_;

    /// Floating-point operations required by the scheduled contraction tasks.
    double flops_;

    /// Flag indicating whether a reduction task has been added.
    bool reduced_;

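    /**
     * @brief Derives the name of a task from a path step, in the form
     *        "<step ID>:<step name>".
     */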
    std::string DeriveTaskName_(const PathStepInfo &step) const noexcept
    {
        return std::to_string(step.id) + ":" + step.name;
    }

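    /**
     * @brief Adds a task that contracts the tensors named `name_1` and
     *        `name_2` and stores the result under `name_3`.
     */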
    void AddContractionTask_(const std::string &name_1,
                             const std::string &name_2,
                             const std::string &name_3) noexcept
    {
        const auto runner = [this, name_1, name_2, name_3]() {
            name_to_tensor_map_[name_3] = std::make_unique<TensorType>(
                TensorType::ContractTensors(*name_to_tensor_map_.at(name_1),
                                            *name_to_tensor_map_.at(name_2)));
        };

        const auto task_3 = taskflow_.emplace(runner).name(name_3);
        name_to_task_map_.emplace(name_3, task_3);
    }

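    /**
     * @brief Adds a task that copies the tensor named `name` into the
     *        `results_` vector at index `result_id`.
     */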
    inline void AddStorageTask_(const std::string &name,
                                size_t result_id) noexcept
    {
        const auto runner = [this, result_id, name]() {
            auto &tensor = *name_to_tensor_map_.at(name);
            results_[result_id] = tensor;
        };

        std::string storage_task_name = name;
        storage_task_name += ":storage[";
        storage_task_name += std::to_string(result_id);
        storage_task_name += ']';

        auto storage_task = taskflow_.emplace(runner).name(storage_task_name);

        auto &preceding_task = name_to_task_map_.at(name);
        preceding_task.precede(storage_task);

        result_tasks_.emplace_back(storage_task);
    }
};

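/**
 * @brief Writes a GraphViz DOT representation of the contractor's task graph
 *        to the given output stream.
 */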
template <class TensorType>
inline std::ostream &operator<<(std::ostream &out,
                                const TaskBasedContractor<TensorType> &tbc)
{
    const auto &taskflow = tbc.GetTaskflow();
    taskflow.dump(out);
    return out;
}

} // namespace Jet
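
For context, a minimal usage sketch follows; it is not part of the header above. The helper function name `ContractNetwork` is hypothetical and only exercises the public interface shown in this listing, assuming a `TensorNetwork` and `PathInfo` that have already been constructed (see TensorNetwork.hpp and PathInfo.hpp).

#include "TaskBasedContractor.hpp"

// Hypothetical helper: contract a pre-built tensor network along a known
// contraction path and return the reduced (summed) result.
template <class TensorType>
TensorType ContractNetwork(const Jet::TensorNetwork<TensorType> &tn,
                           const Jet::PathInfo &path_info)
{
    Jet::TaskBasedContractor<TensorType> contractor;

    // Schedule the pairwise contractions, the reduction of the results, and
    // the deletion of intermediate tensors that are no longer needed.
    contractor.AddContractionTasks(tn, path_info);
    contractor.AddReductionTask();
    contractor.AddDeletionTasks();

    // Dispatch the task graph and block until every task has finished.
    contractor.Contract().wait();

    return contractor.GetReductionResult();
}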