Program Listing for File TensorNetworkIO.hpp
↰ Return to documentation for file (include/jet/TensorNetworkIO.hpp)
#pragma once
#include <complex>
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>
#include "external/nlohmann/json.hpp"
#ifdef CUTENSOR
#include "CudaTensor.hpp"
#endif
#include "PathInfo.hpp"
#include "Tensor.hpp"
#include "TensorNetwork.hpp"
namespace Jet {
using json = nlohmann::json;
/**
 * @brief Aggregate holding the contents of a tensor network file.
 *
 * @tparam Tensor Tensor type stored in the network.
 */
template <typename Tensor> struct TensorNetworkFile {
    /// Contraction path; holds no value unless one has been assigned.
    std::optional<PathInfo> path;

    /// Tensor network made up of the tensors of the file.
    TensorNetwork<Tensor> tensors;
};
/**
 * @brief Exception signalling that a tensor network file could not be parsed.
 *
 * The message handed to the base class is always the supplied description
 * prefixed with "Error parsing tensor network file: ".
 */
class TensorFileException : public Exception {
  public:
    /**
     * @brief Builds the exception from a description held in a std::string.
     *
     * @param what_arg Description of the problem; appended to the prefix.
     */
    explicit TensorFileException(const std::string &what_arg)
        : Exception(
              std::string("Error parsing tensor network file: ").append(what_arg))
    {
    }

    /**
     * @brief Builds the exception from a C string description.
     *
     * Delegates to the std::string overload so that both constructors
     * produce the same message.
     *
     * @param what_arg Description of the problem; appended to the prefix.
     */
    explicit TensorFileException(const char *what_arg)
        : TensorFileException(std::string(what_arg))
    {
    }
};
template <class TensorType> class TensorNetworkSerializer {
public:
TensorNetworkSerializer<TensorType>(int indent = -1)
: indent(indent), js(json::object())
{
}
std::string operator()(const TensorNetwork<TensorType> &tn,
const PathInfo &path)
{
js["path"] = path.GetPath();
return operator()(tn);
}
std::string operator()(const TensorNetwork<TensorType> &tn)
{
js["tensors"] = json::array();
for (const auto &node : tn.GetNodes()) {
js["tensors"].push_back(TensorToJSON_(node.tensor, node.tags));
}
std::string ret = js.dump(indent);
js = json::object();
return ret;
}
TensorNetworkFile<TensorType> operator()(std::string js_str,
bool col_major = false)
{
TensorNetworkFile<TensorType> tf;
LoadAndValidateJSON_(js_str);
size_t i = 0;
for (auto &js_tensor : js["tensors"]) {
auto data = TensorDataFromJSON_<typename TensorType::scalar_type_t>(
js_tensor[3], i);
if (col_major) {
std::vector<std::string> rev_idx(js_tensor[1].rbegin(),
js_tensor[1].rend());
std::vector<size_t> rev_shape(js_tensor[2].rbegin(),
js_tensor[2].rend());
tf.tensors.AddTensor(TensorType(rev_idx, rev_shape, data),
js_tensor[0]);
}
else
tf.tensors.AddTensor(
TensorType(js_tensor[1], js_tensor[2], data), js_tensor[0]);
i++;
}
if (js.find("path") != js.end()) {
tf.path = PathInfo(tf.tensors, js["path"]);
}
js = json::object();
return tf;
}
private:
int indent;
json js;
void LoadAndValidateJSON_(const std::string &js_str)
{
js = json::parse(js_str); // throws json::exception if invalid json
if (!js.is_object()) {
throw TensorFileException("root element must be an object.");
}
if (js.find("tensors") == js.end()) {
throw TensorFileException("root object must contain 'tensors' key");
}
}
static json TensorToJSON_(const TensorType &tensor,
const std::vector<std::string> &tags)
{
auto js_tensor = json::array();
js_tensor.push_back(tags);
if constexpr (std::is_same_v<TensorType,
Jet::Tensor<std::complex<float>>> ||
std::is_same_v<TensorType,
Jet::Tensor<std::complex<double>>>) {
js_tensor.push_back(tensor.GetIndices());
js_tensor.push_back(tensor.GetShape());
js_tensor.push_back(TensorDataToJSON_(tensor.GetData()));
}
else { // column-major branch
std::vector<std::string> rev_idx{tensor.GetIndices().rbegin(),
tensor.GetIndices().rend()};
std::vector<size_t> rev_shape{tensor.GetShape().rbegin(),
tensor.GetShape().rend()};
js_tensor.push_back(rev_idx);
js_tensor.push_back(rev_shape);
js_tensor.push_back(TensorDataToJSON_(tensor.GetHostDataVector()));
}
return js_tensor;
}
template <typename S>
static json TensorDataToJSON_(const std::vector<S> &data)
{
auto js_data = json::array();
for (const auto &x : data) {
js_data.push_back({std::real(x), std::imag(x)});
}
return js_data;
}
template <typename S>
static std::vector<S> TensorDataFromJSON_(const json &js_data,
size_t tensor_index)
{
std::vector<S> data(js_data.size());
size_t i = 0;
try {
while (i < js_data.size()) {
data[i] = S{js_data[i].at(0), js_data[i].at(1)};
i++;
}
}
catch (const json::exception &) {
throw TensorFileException(
"Invalid element at index " + std::to_string(i) +
" of tensor " + std::to_string(tensor_index) +
": Could not parse " + js_data[i].dump() + " as complex.");
}
return data;
}
};
}; // namespace Jet
api/program_listing_file_include_jet_TensorNetworkIO.hpp
Download Python script
Download Notebook
View on GitHub