diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp index 0e126ff..4e6048e 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp +++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp @@ -123,29 +123,29 @@ FlatbufferLoader::FlatbufferLoader() : mcu_(std::make_shared()), cu_(std::make_shared()), ivalue_parsers_{nullptr} { - registerIValueParser(mobile::serialization::IValueUnion::NONE, &parseBasic); - registerIValueParser(mobile::serialization::IValueUnion::Int, &parseBasic); - registerIValueParser(mobile::serialization::IValueUnion::Bool, &parseBasic); - registerIValueParser(mobile::serialization::IValueUnion::Double, &parseBasic); + registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_NONE, &parseBasic); + registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_Int, &parseBasic); + registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_Bool, &parseBasic); + registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_Double, &parseBasic); registerIValueParser( - mobile::serialization::IValueUnion::ComplexDouble, &parseBasic); + mobile::serialization::IValueUnion::IValueUnion_ComplexDouble, &parseBasic); registerIValueParser( - mobile::serialization::IValueUnion::TensorMetadata, &parseTensor); - registerIValueParser(mobile::serialization::IValueUnion::String, &parseBasic); - registerIValueParser(mobile::serialization::IValueUnion::List, &parseList); + mobile::serialization::IValueUnion::IValueUnion_TensorMetadata, &parseTensor); + registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_String, &parseBasic); + registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_List, &parseList); registerIValueParser( - mobile::serialization::IValueUnion::IntList, &parseIntList); + mobile::serialization::IValueUnion::IValueUnion_IntList, &parseIntList); registerIValueParser( - mobile::serialization::IValueUnion::DoubleList, &parseDoubleList); + mobile::serialization::IValueUnion::IValueUnion_DoubleList, &parseDoubleList); registerIValueParser( - mobile::serialization::IValueUnion::BoolList, &parseBoolList); - registerIValueParser(mobile::serialization::IValueUnion::Tuple, &parseTuple); - registerIValueParser(mobile::serialization::IValueUnion::Dict, &parseDict); + mobile::serialization::IValueUnion::IValueUnion_BoolList, &parseBoolList); + registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_Tuple, &parseTuple); + registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_Dict, &parseDict); registerIValueParser( - mobile::serialization::IValueUnion::Object, &parseObject); - registerIValueParser(mobile::serialization::IValueUnion::Device, &parseBasic); + mobile::serialization::IValueUnion::IValueUnion_Object, &parseObject); + registerIValueParser(mobile::serialization::IValueUnion::IValueUnion_Device, &parseBasic); registerIValueParser( - mobile::serialization::IValueUnion::EnumValue, &parseEnum); + mobile::serialization::IValueUnion::IValueUnion_EnumValue, &parseEnum); internal_registerTypeResolver(&resolveType); } @@ -334,21 +334,21 @@ IValue parseBasic( FlatbufferLoader&, const mobile::serialization::IValue& ivalue) { switch (ivalue.val_type()) { - case mobile::serialization::IValueUnion::NONE: + case mobile::serialization::IValueUnion::IValueUnion_NONE: return {}; - case mobile::serialization::IValueUnion::Int: + case mobile::serialization::IValueUnion::IValueUnion_Int: return ivalue.val_as_Int()->int_val(); - case mobile::serialization::IValueUnion::Bool: + case mobile::serialization::IValueUnion::IValueUnion_Bool: return ivalue.val_as_Bool()->bool_val(); - case mobile::serialization::IValueUnion::Double: + case mobile::serialization::IValueUnion::IValueUnion_Double: return ivalue.val_as_Double()->double_val(); - case mobile::serialization::IValueUnion::ComplexDouble: { + case mobile::serialization::IValueUnion::IValueUnion_ComplexDouble: { const auto* comp = ivalue.val_as_ComplexDouble(); return c10::complex(comp->real(), comp->imag()); } - case mobile::serialization::IValueUnion::String: + case mobile::serialization::IValueUnion::IValueUnion_String: return ivalue.val_as_String()->data()->str(); - case mobile::serialization::IValueUnion::Device: { + case mobile::serialization::IValueUnion::IValueUnion_Device: { return c10::Device(ivalue.val_as_Device()->str()->str()); } default: @@ -504,7 +504,7 @@ ClassTypePtr FlatbufferLoader::getOrCreateClassTypeForObject( TORCH_CHECK(object->type_index() < all_ivalues_.size()); all_types_[object->type_index()] = cls; - if (obj_type->type() == mobile::serialization::TypeType::CLASS_WITH_FIELD) { + if (obj_type->type() == mobile::serialization::TypeType::TypeType_CLASS_WITH_FIELD) { for (uint32_t i = 0; i < object->attrs()->size(); i++) { IValue val = getIValue(object->attrs()->Get(i)); // Need to use concrete object's field's type to set type of field. @@ -529,7 +529,7 @@ IValue parseObject( auto cls = loader.getOrCreateClassTypeForObject(object); Stack stack; switch (obj_type->type()) { - case mobile::serialization::TypeType::CLASS_WITH_FIELD: { + case mobile::serialization::TypeType::TypeType_CLASS_WITH_FIELD: { auto obj = c10::ivalue::Object::create( at::StrongTypePtr(loader.cu_, cls), object->attrs()->size()); for (uint32_t i = 0; i < object->attrs()->size(); i++) { @@ -538,7 +538,7 @@ IValue parseObject( } return obj; } - case mobile::serialization::TypeType::CLASS_WITH_SETSTATE: { + case mobile::serialization::TypeType::TypeType_CLASS_WITH_SETSTATE: { IValue input = loader.getIValue(object->state()); mobile::Function* setstate = loader.getFunction(object->setstate_func()); auto obj = @@ -548,7 +548,7 @@ IValue parseObject( setstate->run(stack); return obj; } - case mobile::serialization::TypeType::CUSTOM_CLASS: { + case mobile::serialization::TypeType::TypeType_CUSTOM_CLASS: { auto custom_class_type = torch::jit::getCustomClass(cls->name()->qualifiedName()); IValue input = loader.getIValue(object->state()); diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h index 7e88326..dc95202 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.h +++ b/torch/csrc/jit/mobile/flatbuffer_loader.h @@ -151,7 +151,7 @@ class TORCH_API FlatbufferLoader { std::vector all_ivalues_; std::array< IValueParser, - static_cast(mobile::serialization::IValueUnion::MAX) + 1> + static_cast(mobile::serialization::IValueUnion::IValueUnion_MAX) + 1> ivalue_parsers_; TypeResolver type_resolver_ = nullptr; mobile::serialization::Module* module_ = nullptr; diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp index 066a78c..a9e19bb 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp @@ -336,7 +336,7 @@ flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule( mcu_ = &module.compilation_unit(); // first element is None. - insertIValue(CreateIValue(fbb, mobile::serialization::IValueUnion::NONE, 0)); + insertIValue(CreateIValue(fbb, mobile::serialization::IValueUnion::IValueUnion_NONE, 0)); auto methods = module.get_methods(); std::vector functions_index; @@ -459,7 +459,7 @@ flatbuffers::Offset FlatbufferSerializer::dictToFB( flatbuffers::Offset FlatbufferSerializer:: classTypeToFB(FlatBufferBuilder& fbb, ClassTypePtr class_ptr) { mobile::serialization::TypeType typetype = - mobile::serialization::TypeType::UNSET; + mobile::serialization::TypeType::TypeType_UNSET; flatbuffers::Offset< flatbuffers::Vector>> @@ -469,11 +469,11 @@ flatbuffers::Offset FlatbufferSerializer:: const mobile::Function* setstate = mcu_->find_function(setstate_name); const mobile::Function* getstate = mcu_->find_function(getstate_name); if (setstate != nullptr && getstate != nullptr) { - typetype = mobile::serialization::TypeType::CLASS_WITH_SETSTATE; + typetype = mobile::serialization::TypeType::TypeType_CLASS_WITH_SETSTATE; } else if ( class_ptr->findMethod("__setstate__") && class_ptr->findMethod("__getstate__")) { - typetype = mobile::serialization::TypeType::CUSTOM_CLASS; + typetype = mobile::serialization::TypeType::TypeType_CUSTOM_CLASS; } else { size_t num_attr = class_ptr->numAttributes(); std::vector> names; @@ -482,7 +482,7 @@ flatbuffers::Offset FlatbufferSerializer:: names.push_back(fbb.CreateSharedString(class_ptr->getAttributeName(i))); } names_offset = fbb.CreateVector(names); - typetype = mobile::serialization::TypeType::CLASS_WITH_FIELD; + typetype = mobile::serialization::TypeType::TypeType_CLASS_WITH_FIELD; } auto name_offset = fbb.CreateString(class_ptr->name()->qualifiedName()); @@ -500,7 +500,7 @@ uint32_t FlatbufferSerializer::storeFunctionAndGetIndex( auto offset = CreateIValue( fbb, - mobile::serialization::IValueUnion::Function, + mobile::serialization::IValueUnion::IValueUnion_Function, functionToFB(fbb, qn, function).Union()); uint32_t index = insertIValue(offset); @@ -662,68 +662,68 @@ flatbuffers::Offset FlatbufferSerializer:: iValueToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) { using mobile::serialization::IValueUnion; - IValueUnion ivalue_type = IValueUnion::NONE; + IValueUnion ivalue_type = IValueUnion::IValueUnion_NONE; flatbuffers::Offset offset = 0; if (ivalue.isTensor()) { - ivalue_type = IValueUnion::TensorMetadata; + ivalue_type = IValueUnion::IValueUnion_TensorMetadata; offset = tensorToFB(fbb, ivalue).Union(); } else if (ivalue.isTuple()) { - ivalue_type = IValueUnion::Tuple; + ivalue_type = IValueUnion::IValueUnion_Tuple; offset = tupleToFB(fbb, ivalue).Union(); } else if (ivalue.isDouble()) { - ivalue_type = IValueUnion::Double; + ivalue_type = IValueUnion::IValueUnion_Double; offset = fbb.CreateStruct(mobile::serialization::Double(ivalue.toDouble())) .Union(); } else if (ivalue.isComplexDouble()) { auto comp = ivalue.toComplexDouble(); - ivalue_type = IValueUnion::ComplexDouble; + ivalue_type = IValueUnion::IValueUnion_ComplexDouble; offset = fbb.CreateStruct(mobile::serialization::ComplexDouble( comp.real(), comp.imag())) .Union(); } else if (ivalue.isInt()) { - ivalue_type = IValueUnion::Int; + ivalue_type = IValueUnion::IValueUnion_Int; offset = fbb.CreateStruct(mobile::serialization::Int(ivalue.toInt())).Union(); } else if (ivalue.isBool()) { - ivalue_type = IValueUnion::Bool; + ivalue_type = IValueUnion::IValueUnion_Bool; offset = fbb.CreateStruct(mobile::serialization::Bool(ivalue.toBool())).Union(); } else if (ivalue.isString()) { - ivalue_type = IValueUnion::String; + ivalue_type = IValueUnion::IValueUnion_String; offset = mobile::serialization::CreateString( fbb, fbb.CreateSharedString(ivalue.toString()->string())) .Union(); } else if (ivalue.isGenericDict()) { - ivalue_type = IValueUnion::Dict; + ivalue_type = IValueUnion::IValueUnion_Dict; offset = dictToFB(fbb, ivalue).Union(); } else if (ivalue.isNone()) { - ivalue_type = IValueUnion::NONE; + ivalue_type = IValueUnion::IValueUnion_NONE; offset = 0; } else if (ivalue.isIntList()) { - ivalue_type = IValueUnion::IntList; + ivalue_type = IValueUnion::IValueUnion_IntList; offset = mobile::serialization::CreateIntList( fbb, fbb.CreateVector(ivalue.toIntVector())) .Union(); } else if (ivalue.isDoubleList()) { - ivalue_type = IValueUnion::DoubleList; + ivalue_type = IValueUnion::IValueUnion_DoubleList; offset = mobile::serialization::CreateDoubleList( fbb, fbb.CreateVector(ivalue.toDoubleVector())) .Union(); } else if (ivalue.isBoolList()) { - ivalue_type = IValueUnion::BoolList; + ivalue_type = IValueUnion::IValueUnion_BoolList; auto boollist = ivalue.toBoolList(); std::vector bool_vec(boollist.begin(), boollist.end()); offset = mobile::serialization::CreateBoolListDirect(fbb, &bool_vec).Union(); } else if (ivalue.isList()) { - ivalue_type = IValueUnion::List; + ivalue_type = IValueUnion::IValueUnion_List; offset = listToFB(fbb, ivalue).Union(); } else if (ivalue.isObject()) { - ivalue_type = IValueUnion::Object; + ivalue_type = IValueUnion::IValueUnion_Object; offset = objectToFB(fbb, ivalue).Union(); } else if (ivalue.isDevice()) { - ivalue_type = IValueUnion::Device; + ivalue_type = IValueUnion::IValueUnion_Device; offset = mobile::serialization::CreateDevice( fbb, fbb.CreateSharedString(ivalue.toDevice().str())) .Union(); @@ -732,7 +732,7 @@ flatbuffers::Offset FlatbufferSerializer:: const auto& qualified_class_name = enum_holder->type()->qualifiedClassName(); uint32_t ival_pos = storeIValueAndGetIndex(fbb, enum_holder->value()); - ivalue_type = IValueUnion::EnumValue; + ivalue_type = IValueUnion::IValueUnion_EnumValue; offset = mobile::serialization::CreateEnumValue( fbb, fbb.CreateSharedString(qualified_class_name.qualifiedName()),