mirror of
https://github.com/microsoft/vcpkg.git
synced 2024-12-18 19:40:18 +08:00
277 lines
14 KiB
Diff
277 lines
14 KiB
Diff
|
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
|
||
|
index 0e126ff..4e6048e 100644
|
||
|
--- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp
|
||
|
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
|
||
|
@@ -123,29 +123,29 @@ FlatbufferLoader::FlatbufferLoader()
|
||
|
: mcu_(std::make_shared<mobile::CompilationUnit>()),
|
||
|
cu_(std::make_shared<CompilationUnit>()),
|
||
|
ivalue_parsers_{nullptr} {
|
||
|
- registerIValueParser(mobile::serialization::IValueUnion::NONE, &parseBasic);
|
||
|
- registerIValueParser(mobile::serialization::IValueUnion::Int, &parseBasic);
|
||
|
- registerIValueParser(mobile::serialization::IValueUnion::Bool, &parseBasic);
|
||
|
- registerIValueParser(mobile::serialization::IValueUnion::Double, &parseBasic);
|
||
|
+ registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_NONE, &parseBasic);
|
||
|
+ registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_Int, &parseBasic);
|
||
|
+ registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_Bool, &parseBasic);
|
||
|
+ registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_Double, &parseBasic);
|
||
|
registerIValueParser(
|
||
|
- mobile::serialization::IValueUnion::ComplexDouble, &parseBasic);
|
||
|
+ mobile::serialization::IValueUnion::IValueUnion_ComplexDouble, &parseBasic);
|
||
|
registerIValueParser(
|
||
|
- mobile::serialization::IValueUnion::TensorMetadata, &parseTensor);
|
||
|
- registerIValueParser(mobile::serialization::IValueUnion::String, &parseBasic);
|
||
|
- registerIValueParser(mobile::serialization::IValueUnion::List, &parseList);
|
||
|
+ mobile::serialization::IValueUnion::IValueUnion_TensorMetadata, &parseTensor);
|
||
|
+ registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_String, &parseBasic);
|
||
|
+ registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_List, &parseList);
|
||
|
registerIValueParser(
|
||
|
- mobile::serialization::IValueUnion::IntList, &parseIntList);
|
||
|
+ mobile::serialization::IValueUnion::IValueUnion_IntList, &parseIntList);
|
||
|
registerIValueParser(
|
||
|
- mobile::serialization::IValueUnion::DoubleList, &parseDoubleList);
|
||
|
+ mobile::serialization::IValueUnion::IValueUnion_DoubleList, &parseDoubleList);
|
||
|
registerIValueParser(
|
||
|
- mobile::serialization::IValueUnion::BoolList, &parseBoolList);
|
||
|
- registerIValueParser(mobile::serialization::IValueUnion::Tuple, &parseTuple);
|
||
|
- registerIValueParser(mobile::serialization::IValueUnion::Dict, &parseDict);
|
||
|
+ mobile::serialization::IValueUnion::IValueUnion_BoolList, &parseBoolList);
|
||
|
+ registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_Tuple, &parseTuple);
|
||
|
+ registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_Dict, &parseDict);
|
||
|
registerIValueParser(
|
||
|
- mobile::serialization::IValueUnion::Object, &parseObject);
|
||
|
- registerIValueParser(mobile::serialization::IValueUnion::Device, &parseBasic);
|
||
|
+ mobile::serialization::IValueUnion::IValueUnion_Object, &parseObject);
|
||
|
+ registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_Device, &parseBasic);
|
||
|
registerIValueParser(
|
||
|
- mobile::serialization::IValueUnion::EnumValue, &parseEnum);
|
||
|
+ mobile::serialization::IValueUnion::IValueUnion_EnumValue, &parseEnum);
|
||
|
internal_registerTypeResolver(&resolveType);
|
||
|
}
|
||
|
|
||
|
@@ -334,21 +334,21 @@ IValue parseBasic(
|
||
|
FlatbufferLoader&,
|
||
|
const mobile::serialization::IValue& ivalue) {
|
||
|
switch (ivalue.val_type()) {
|
||
|
- case mobile::serialization::IValueUnion::NONE:
|
||
|
+ case mobile::serialization::IValueUnion::IValueUnion_NONE:
|
||
|
return {};
|
||
|
- case mobile::serialization::IValueUnion::Int:
|
||
|
+ case mobile::serialization::IValueUnion::IValueUnion_Int:
|
||
|
return ivalue.val_as_Int()->int_val();
|
||
|
- case mobile::serialization::IValueUnion::Bool:
|
||
|
+ case mobile::serialization::IValueUnion::IValueUnion_Bool:
|
||
|
return ivalue.val_as_Bool()->bool_val();
|
||
|
- case mobile::serialization::IValueUnion::Double:
|
||
|
+ case mobile::serialization::IValueUnion::IValueUnion_Double:
|
||
|
return ivalue.val_as_Double()->double_val();
|
||
|
- case mobile::serialization::IValueUnion::ComplexDouble: {
|
||
|
+ case mobile::serialization::IValueUnion::IValueUnion_ComplexDouble: {
|
||
|
const auto* comp = ivalue.val_as_ComplexDouble();
|
||
|
return c10::complex<double>(comp->real(), comp->imag());
|
||
|
}
|
||
|
- case mobile::serialization::IValueUnion::String:
|
||
|
+ case mobile::serialization::IValueUnion::IValueUnion_String:
|
||
|
return ivalue.val_as_String()->data()->str();
|
||
|
- case mobile::serialization::IValueUnion::Device: {
|
||
|
+ case mobile::serialization::IValueUnion::IValueUnion_Device: {
|
||
|
return c10::Device(ivalue.val_as_Device()->str()->str());
|
||
|
}
|
||
|
default:
|
||
|
@@ -504,7 +504,7 @@ ClassTypePtr FlatbufferLoader::getOrCreateClassTypeForObject(
|
||
|
TORCH_CHECK(object->type_index() < all_ivalues_.size());
|
||
|
all_types_[object->type_index()] = cls;
|
||
|
|
||
|
- if (obj_type->type() == mobile::serialization::TypeType::CLASS_WITH_FIELD) {
|
||
|
+ if (obj_type->type() == mobile::serialization::TypeType::TypeType_CLASS_WITH_FIELD) {
|
||
|
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
|
||
|
IValue val = getIValue(object->attrs()->Get(i));
|
||
|
// Need to use concrete object's field's type to set type of field.
|
||
|
@@ -529,7 +529,7 @@ IValue parseObject(
|
||
|
auto cls = loader.getOrCreateClassTypeForObject(object);
|
||
|
Stack stack;
|
||
|
switch (obj_type->type()) {
|
||
|
- case mobile::serialization::TypeType::CLASS_WITH_FIELD: {
|
||
|
+ case mobile::serialization::TypeType::TypeType_CLASS_WITH_FIELD: {
|
||
|
auto obj = c10::ivalue::Object::create(
|
||
|
at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
|
||
|
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
|
||
|
@@ -538,7 +538,7 @@ IValue parseObject(
|
||
|
}
|
||
|
return obj;
|
||
|
}
|
||
|
- case mobile::serialization::TypeType::CLASS_WITH_SETSTATE: {
|
||
|
+ case mobile::serialization::TypeType::TypeType_CLASS_WITH_SETSTATE: {
|
||
|
IValue input = loader.getIValue(object->state());
|
||
|
mobile::Function* setstate = loader.getFunction(object->setstate_func());
|
||
|
auto obj =
|
||
|
@@ -548,7 +548,7 @@ IValue parseObject(
|
||
|
setstate->run(stack);
|
||
|
return obj;
|
||
|
}
|
||
|
- case mobile::serialization::TypeType::CUSTOM_CLASS: {
|
||
|
+ case mobile::serialization::TypeType::TypeType_CUSTOM_CLASS: {
|
||
|
auto custom_class_type =
|
||
|
torch::jit::getCustomClass(cls->name()->qualifiedName());
|
||
|
IValue input = loader.getIValue(object->state());
|
||
|
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h
|
||
|
index 7e88326..dc95202 100644
|
||
|
--- a/torch/csrc/jit/mobile/flatbuffer_loader.h
|
||
|
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.h
|
||
|
@@ -151,7 +151,7 @@ class TORCH_API FlatbufferLoader {
|
||
|
std::vector<IValue> all_ivalues_;
|
||
|
std::array<
|
||
|
IValueParser,
|
||
|
- static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
|
||
|
+ static_cast<uint8_t>(mobile::serialization::IValueUnion::IValueUnion_MAX) + 1>
|
||
|
ivalue_parsers_;
|
||
|
TypeResolver type_resolver_ = nullptr;
|
||
|
mobile::serialization::Module* module_ = nullptr;
|
||
|
diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp
|
||
|
index 066a78c..a9e19bb 100644
|
||
|
--- a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp
|
||
|
+++ b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp
|
||
|
@@ -336,7 +336,7 @@ flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule(
|
||
|
mcu_ = &module.compilation_unit();
|
||
|
|
||
|
// first element is None.
|
||
|
- insertIValue(CreateIValue(fbb, mobile::serialization::IValueUnion::NONE, 0));
|
||
|
+ insertIValue(CreateIValue(fbb, mobile::serialization::IValueUnion::IValueUnion_NONE, 0));
|
||
|
|
||
|
auto methods = module.get_methods();
|
||
|
std::vector<uint32_t> functions_index;
|
||
|
@@ -459,7 +459,7 @@ flatbuffers::Offset<mobile::serialization::Dict> FlatbufferSerializer::dictToFB(
|
||
|
flatbuffers::Offset<mobile::serialization::ObjectType> FlatbufferSerializer::
|
||
|
classTypeToFB(FlatBufferBuilder& fbb, ClassTypePtr class_ptr) {
|
||
|
mobile::serialization::TypeType typetype =
|
||
|
- mobile::serialization::TypeType::UNSET;
|
||
|
+ mobile::serialization::TypeType::TypeType_UNSET;
|
||
|
|
||
|
flatbuffers::Offset<
|
||
|
flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
|
||
|
@@ -469,11 +469,11 @@ flatbuffers::Offset<mobile::serialization::ObjectType> FlatbufferSerializer::
|
||
|
const mobile::Function* setstate = mcu_->find_function(setstate_name);
|
||
|
const mobile::Function* getstate = mcu_->find_function(getstate_name);
|
||
|
if (setstate != nullptr && getstate != nullptr) {
|
||
|
- typetype = mobile::serialization::TypeType::CLASS_WITH_SETSTATE;
|
||
|
+ typetype = mobile::serialization::TypeType::TypeType_CLASS_WITH_SETSTATE;
|
||
|
} else if (
|
||
|
class_ptr->findMethod("__setstate__") &&
|
||
|
class_ptr->findMethod("__getstate__")) {
|
||
|
- typetype = mobile::serialization::TypeType::CUSTOM_CLASS;
|
||
|
+ typetype = mobile::serialization::TypeType::TypeType_CUSTOM_CLASS;
|
||
|
} else {
|
||
|
size_t num_attr = class_ptr->numAttributes();
|
||
|
std::vector<flatbuffers::Offset<flatbuffers::String>> names;
|
||
|
@@ -482,7 +482,7 @@ flatbuffers::Offset<mobile::serialization::ObjectType> FlatbufferSerializer::
|
||
|
names.push_back(fbb.CreateSharedString(class_ptr->getAttributeName(i)));
|
||
|
}
|
||
|
names_offset = fbb.CreateVector(names);
|
||
|
- typetype = mobile::serialization::TypeType::CLASS_WITH_FIELD;
|
||
|
+ typetype = mobile::serialization::TypeType::TypeType_CLASS_WITH_FIELD;
|
||
|
}
|
||
|
|
||
|
auto name_offset = fbb.CreateString(class_ptr->name()->qualifiedName());
|
||
|
@@ -500,7 +500,7 @@ uint32_t FlatbufferSerializer::storeFunctionAndGetIndex(
|
||
|
|
||
|
auto offset = CreateIValue(
|
||
|
fbb,
|
||
|
- mobile::serialization::IValueUnion::Function,
|
||
|
+ mobile::serialization::IValueUnion::IValueUnion_Function,
|
||
|
functionToFB(fbb, qn, function).Union());
|
||
|
|
||
|
uint32_t index = insertIValue(offset);
|
||
|
@@ -662,68 +662,68 @@ flatbuffers::Offset<mobile::serialization::IValue> FlatbufferSerializer::
|
||
|
iValueToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) {
|
||
|
using mobile::serialization::IValueUnion;
|
||
|
|
||
|
- IValueUnion ivalue_type = IValueUnion::NONE;
|
||
|
+ IValueUnion ivalue_type = IValueUnion::IValueUnion_NONE;
|
||
|
flatbuffers::Offset<void> offset = 0;
|
||
|
|
||
|
if (ivalue.isTensor()) {
|
||
|
- ivalue_type = IValueUnion::TensorMetadata;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_TensorMetadata;
|
||
|
offset = tensorToFB(fbb, ivalue).Union();
|
||
|
} else if (ivalue.isTuple()) {
|
||
|
- ivalue_type = IValueUnion::Tuple;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_Tuple;
|
||
|
offset = tupleToFB(fbb, ivalue).Union();
|
||
|
} else if (ivalue.isDouble()) {
|
||
|
- ivalue_type = IValueUnion::Double;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_Double;
|
||
|
offset = fbb.CreateStruct(mobile::serialization::Double(ivalue.toDouble()))
|
||
|
.Union();
|
||
|
} else if (ivalue.isComplexDouble()) {
|
||
|
auto comp = ivalue.toComplexDouble();
|
||
|
- ivalue_type = IValueUnion::ComplexDouble;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_ComplexDouble;
|
||
|
offset = fbb.CreateStruct(mobile::serialization::ComplexDouble(
|
||
|
comp.real(), comp.imag()))
|
||
|
.Union();
|
||
|
} else if (ivalue.isInt()) {
|
||
|
- ivalue_type = IValueUnion::Int;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_Int;
|
||
|
offset =
|
||
|
fbb.CreateStruct(mobile::serialization::Int(ivalue.toInt())).Union();
|
||
|
} else if (ivalue.isBool()) {
|
||
|
- ivalue_type = IValueUnion::Bool;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_Bool;
|
||
|
offset =
|
||
|
fbb.CreateStruct(mobile::serialization::Bool(ivalue.toBool())).Union();
|
||
|
} else if (ivalue.isString()) {
|
||
|
- ivalue_type = IValueUnion::String;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_String;
|
||
|
offset = mobile::serialization::CreateString(
|
||
|
fbb, fbb.CreateSharedString(ivalue.toString()->string()))
|
||
|
.Union();
|
||
|
} else if (ivalue.isGenericDict()) {
|
||
|
- ivalue_type = IValueUnion::Dict;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_Dict;
|
||
|
offset = dictToFB(fbb, ivalue).Union();
|
||
|
} else if (ivalue.isNone()) {
|
||
|
- ivalue_type = IValueUnion::NONE;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_NONE;
|
||
|
offset = 0;
|
||
|
} else if (ivalue.isIntList()) {
|
||
|
- ivalue_type = IValueUnion::IntList;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_IntList;
|
||
|
offset = mobile::serialization::CreateIntList(
|
||
|
fbb, fbb.CreateVector(ivalue.toIntVector()))
|
||
|
.Union();
|
||
|
} else if (ivalue.isDoubleList()) {
|
||
|
- ivalue_type = IValueUnion::DoubleList;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_DoubleList;
|
||
|
offset = mobile::serialization::CreateDoubleList(
|
||
|
fbb, fbb.CreateVector(ivalue.toDoubleVector()))
|
||
|
.Union();
|
||
|
} else if (ivalue.isBoolList()) {
|
||
|
- ivalue_type = IValueUnion::BoolList;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_BoolList;
|
||
|
auto boollist = ivalue.toBoolList();
|
||
|
std::vector<uint8_t> bool_vec(boollist.begin(), boollist.end());
|
||
|
offset =
|
||
|
mobile::serialization::CreateBoolListDirect(fbb, &bool_vec).Union();
|
||
|
} else if (ivalue.isList()) {
|
||
|
- ivalue_type = IValueUnion::List;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_List;
|
||
|
offset = listToFB(fbb, ivalue).Union();
|
||
|
} else if (ivalue.isObject()) {
|
||
|
- ivalue_type = IValueUnion::Object;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_Object;
|
||
|
offset = objectToFB(fbb, ivalue).Union();
|
||
|
} else if (ivalue.isDevice()) {
|
||
|
- ivalue_type = IValueUnion::Device;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_Device;
|
||
|
offset = mobile::serialization::CreateDevice(
|
||
|
fbb, fbb.CreateSharedString(ivalue.toDevice().str()))
|
||
|
.Union();
|
||
|
@@ -732,7 +732,7 @@ flatbuffers::Offset<mobile::serialization::IValue> FlatbufferSerializer::
|
||
|
const auto& qualified_class_name =
|
||
|
enum_holder->type()->qualifiedClassName();
|
||
|
uint32_t ival_pos = storeIValueAndGetIndex(fbb, enum_holder->value());
|
||
|
- ivalue_type = IValueUnion::EnumValue;
|
||
|
+ ivalue_type = IValueUnion::IValueUnion_EnumValue;
|
||
|
offset = mobile::serialization::CreateEnumValue(
|
||
|
fbb,
|
||
|
fbb.CreateSharedString(qualified_class_name.qualifiedName()),
|