Merge pull request #12559 from dkurt:dnn_remove_usrtype1

This commit is contained in:
Alexander Alekhin 2018-09-17 18:13:29 +00:00
commit 27a4e370f9

View File

@ -74,6 +74,18 @@ enum LuaType
LEGACY_TYPE_RECUR_FUNCTION = 7
};
// We use OpenCV's types to manage CV_ELEM_SIZE.
enum TorchType
{
TYPE_DOUBLE = CV_64F,
TYPE_FLOAT = CV_32F,
TYPE_BYTE = CV_8U,
TYPE_CHAR = CV_8S,
TYPE_SHORT = CV_16S,
TYPE_INT = CV_32S,
TYPE_LONG = CV_32SC2
};
template<typename T>
static String toString(const T &v)
{
@ -203,19 +215,19 @@ struct TorchImporter
String typeStr = str.substr(strlen(prefix), str.length() - strlen(prefix) - strlen(suffix));
if (typeStr == "Double")
return CV_64F;
return TYPE_DOUBLE;
else if (typeStr == "Float" || typeStr == "Cuda")
return CV_32F;
return TYPE_FLOAT;
else if (typeStr == "Byte")
return CV_8U;
return TYPE_BYTE;
else if (typeStr == "Char")
return CV_8S;
return TYPE_CHAR;
else if (typeStr == "Short")
return CV_16S;
return TYPE_SHORT;
else if (typeStr == "Int")
return CV_32S;
else if (typeStr == "Long") //Carefully! CV_64S type coded as CV_USRTYPE1
return CV_USRTYPE1;
return TYPE_INT;
else if (typeStr == "Long")
return TYPE_LONG;
else
CV_Error(Error::StsNotImplemented, "Unknown type \"" + typeStr + "\" of torch class \"" + str + "\"");
}
@ -236,36 +248,44 @@ struct TorchImporter
void readTorchStorage(int index, int type = -1)
{
long size = readLong();
Mat storageMat(1, size, (type != CV_USRTYPE1) ? type : CV_64F); //handle LongStorage as CV_64F Mat
Mat storageMat;
switch (type)
{
case CV_32F:
case TYPE_FLOAT:
storageMat.create(1, size, CV_32F);
THFile_readFloatRaw(file, (float*)storageMat.data, size);
break;
case CV_64F:
case TYPE_DOUBLE:
storageMat.create(1, size, CV_64F);
THFile_readDoubleRaw(file, (double*)storageMat.data, size);
break;
case CV_8S:
case CV_8U:
case TYPE_CHAR:
storageMat.create(1, size, CV_8S);
THFile_readByteRaw(file, (uchar*)storageMat.data, size);
break;
case CV_16S:
case CV_16U:
case TYPE_BYTE:
storageMat.create(1, size, CV_8U);
THFile_readByteRaw(file, (uchar*)storageMat.data, size);
break;
case TYPE_SHORT:
storageMat.create(1, size, CV_16S);
THFile_readShortRaw(file, (short*)storageMat.data, size);
break;
case CV_32S:
case TYPE_INT:
storageMat.create(1, size, CV_32S);
THFile_readIntRaw(file, (int*)storageMat.data, size);
break;
case CV_USRTYPE1:
case TYPE_LONG:
{
storageMat.create(1, size, CV_64F); //handle LongStorage as CV_64F Mat
double *buf = storageMat.ptr<double>();
THFile_readLongRaw(file, (int64*)buf, size);
for (size_t i = (size_t)size; i-- > 0; )
buf[i] = ((int64*)buf)[i];
}
break;
}
default:
CV_Error(Error::StsInternal, "");
break;