Program Listing for File TensorNetworkIO.hpp

Return to documentation for file (include/jet/TensorNetworkIO.hpp)

#pragma once

#include <algorithm>
#include <complex>
#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "external/nlohmann/json.hpp"

#ifdef CUTENSOR
#include "CudaTensor.hpp"
#endif
#include "PathInfo.hpp"
#include "Tensor.hpp"
#include "TensorNetwork.hpp"

namespace Jet {

// Shorthand for the nlohmann JSON document type used by the tensor network
// (de)serialization code in this header. Declared inside namespace Jet, so it
// is also visible to users as Jet::json.
using json = nlohmann::json;

/**
 * @brief Contents of a deserialized tensor network file.
 *
 * @tparam TensorType Tensor type held by the network.
 */
template <class TensorType> struct TensorNetworkFile {
    /// Contraction path; TensorNetworkSerializer fills this only when the
    /// file contains a "path" key, otherwise it stays empty.
    std::optional<PathInfo> path;

    /// The tensors (with their tags) read from the file.
    TensorNetwork<TensorType> tensors;
};

class TensorFileException : public Exception {
  public:
    explicit TensorFileException(const std::string &what_arg)
        : Exception("Error parsing tensor network file: " + what_arg){};

    explicit TensorFileException(const char *what_arg)
        : TensorFileException(std::string(what_arg)){};
};

template <class TensorType> class TensorNetworkSerializer {
  public:
    TensorNetworkSerializer<TensorType>(int indent = -1)
        : indent(indent), js(json::object())
    {
    }

    std::string operator()(const TensorNetwork<TensorType> &tn,
                           const PathInfo &path)
    {
        js["path"] = path.GetPath();

        return operator()(tn);
    }

    std::string operator()(const TensorNetwork<TensorType> &tn)
    {
        js["tensors"] = json::array();

        for (const auto &node : tn.GetNodes()) {
            js["tensors"].push_back(TensorToJSON_(node.tensor, node.tags));
        }

        std::string ret = js.dump(indent);
        js = json::object();

        return ret;
    }

    TensorNetworkFile<TensorType> operator()(std::string js_str,
                                             bool col_major = false)
    {
        TensorNetworkFile<TensorType> tf;
        LoadAndValidateJSON_(js_str);

        size_t i = 0;
        for (auto &js_tensor : js["tensors"]) {
            auto data = TensorDataFromJSON_<typename TensorType::scalar_type_t>(
                js_tensor[3], i);

            if (col_major) {
                std::vector<std::string> rev_idx(js_tensor[1].rbegin(),
                                                 js_tensor[1].rend());
                std::vector<size_t> rev_shape(js_tensor[2].rbegin(),
                                              js_tensor[2].rend());
                tf.tensors.AddTensor(TensorType(rev_idx, rev_shape, data),
                                     js_tensor[0]);
            }
            else
                tf.tensors.AddTensor(
                    TensorType(js_tensor[1], js_tensor[2], data), js_tensor[0]);
            i++;
        }

        if (js.find("path") != js.end()) {
            tf.path = PathInfo(tf.tensors, js["path"]);
        }

        js = json::object();

        return tf;
    }

  private:
    int indent;

    json js;

    void LoadAndValidateJSON_(const std::string &js_str)
    {
        js = json::parse(js_str); // throws json::exception if invalid json

        if (!js.is_object()) {
            throw TensorFileException("root element must be an object.");
        }

        if (js.find("tensors") == js.end()) {
            throw TensorFileException("root object must contain 'tensors' key");
        }
    }

    static json TensorToJSON_(const TensorType &tensor,
                              const std::vector<std::string> &tags)
    {
        auto js_tensor = json::array();

        js_tensor.push_back(tags);
        if constexpr (std::is_same_v<TensorType,
                                     Jet::Tensor<std::complex<float>>> ||
                      std::is_same_v<TensorType,
                                     Jet::Tensor<std::complex<double>>>) {
            js_tensor.push_back(tensor.GetIndices());
            js_tensor.push_back(tensor.GetShape());
            js_tensor.push_back(TensorDataToJSON_(tensor.GetData()));
        }
        else { // column-major branch
            std::vector<std::string> rev_idx{tensor.GetIndices().rbegin(),
                                             tensor.GetIndices().rend()};
            std::vector<size_t> rev_shape{tensor.GetShape().rbegin(),
                                          tensor.GetShape().rend()};
            js_tensor.push_back(rev_idx);
            js_tensor.push_back(rev_shape);
            js_tensor.push_back(TensorDataToJSON_(tensor.GetHostDataVector()));
        }

        return js_tensor;
    }

    template <typename S>
    static json TensorDataToJSON_(const std::vector<S> &data)
    {
        auto js_data = json::array();
        for (const auto &x : data) {
            js_data.push_back({std::real(x), std::imag(x)});
        }

        return js_data;
    }

    template <typename S>
    static std::vector<S> TensorDataFromJSON_(const json &js_data,
                                              size_t tensor_index)
    {
        std::vector<S> data(js_data.size());

        size_t i = 0;
        try {
            while (i < js_data.size()) {
                data[i] = S{js_data[i].at(0), js_data[i].at(1)};
                i++;
            }
        }
        catch (const json::exception &) {
            throw TensorFileException(
                "Invalid element at index " + std::to_string(i) +
                " of tensor " + std::to_string(tensor_index) +
                ": Could not parse " + js_data[i].dump() + " as complex.");
        }

        return data;
    }
};

}; // namespace Jet