Use TorchType enum

This commit is contained in:
Dmitry Kurtaev 2018-09-17 17:28:26 +03:00
parent a7b3d2581f
commit 7d75526373

View File

@ -74,6 +74,18 @@ enum LuaType
LEGACY_TYPE_RECUR_FUNCTION = 7 LEGACY_TYPE_RECUR_FUNCTION = 7
}; };
// We use OpenCV's types to manage CV_ELEM_SIZE.
enum TorchType
{
TYPE_DOUBLE = CV_64F,
TYPE_FLOAT = CV_32F,
TYPE_BYTE = CV_8U,
TYPE_CHAR = CV_8S,
TYPE_SHORT = CV_16S,
TYPE_INT = CV_32S,
TYPE_LONG = CV_32SC2
};
template<typename T> template<typename T>
static String toString(const T &v) static String toString(const T &v)
{ {
@ -203,19 +215,19 @@ struct TorchImporter
String typeStr = str.substr(strlen(prefix), str.length() - strlen(prefix) - strlen(suffix)); String typeStr = str.substr(strlen(prefix), str.length() - strlen(prefix) - strlen(suffix));
if (typeStr == "Double") if (typeStr == "Double")
return CV_64F; return TYPE_DOUBLE;
else if (typeStr == "Float" || typeStr == "Cuda") else if (typeStr == "Float" || typeStr == "Cuda")
return CV_32F; return TYPE_FLOAT;
else if (typeStr == "Byte") else if (typeStr == "Byte")
return CV_8U; return TYPE_BYTE;
else if (typeStr == "Char") else if (typeStr == "Char")
return CV_8S; return TYPE_CHAR;
else if (typeStr == "Short") else if (typeStr == "Short")
return CV_16S; return TYPE_SHORT;
else if (typeStr == "Int") else if (typeStr == "Int")
return CV_32S; return TYPE_INT;
else if (typeStr == "Long") //Carefully! CV_64S type coded as CV_32SC2 else if (typeStr == "Long")
return CV_32SC2; return TYPE_LONG;
else else
CV_Error(Error::StsNotImplemented, "Unknown type \"" + typeStr + "\" of torch class \"" + str + "\""); CV_Error(Error::StsNotImplemented, "Unknown type \"" + typeStr + "\" of torch class \"" + str + "\"");
} }
@ -236,36 +248,44 @@ struct TorchImporter
void readTorchStorage(int index, int type = -1) void readTorchStorage(int index, int type = -1)
{ {
long size = readLong(); long size = readLong();
Mat storageMat(1, size, (type != CV_32SC2) ? type : CV_64F); //handle LongStorage as CV_64F Mat Mat storageMat;
switch (type) switch (type)
{ {
case CV_32F: case TYPE_FLOAT:
storageMat.create(1, size, CV_32F);
THFile_readFloatRaw(file, (float*)storageMat.data, size); THFile_readFloatRaw(file, (float*)storageMat.data, size);
break; break;
case CV_64F: case TYPE_DOUBLE:
storageMat.create(1, size, CV_64F);
THFile_readDoubleRaw(file, (double*)storageMat.data, size); THFile_readDoubleRaw(file, (double*)storageMat.data, size);
break; break;
case CV_8S: case TYPE_CHAR:
case CV_8U: storageMat.create(1, size, CV_8S);
THFile_readByteRaw(file, (uchar*)storageMat.data, size); THFile_readByteRaw(file, (uchar*)storageMat.data, size);
break; break;
case CV_16S: case TYPE_BYTE:
case CV_16U: storageMat.create(1, size, CV_8U);
THFile_readByteRaw(file, (uchar*)storageMat.data, size);
break;
case TYPE_SHORT:
storageMat.create(1, size, CV_16S);
THFile_readShortRaw(file, (short*)storageMat.data, size); THFile_readShortRaw(file, (short*)storageMat.data, size);
break; break;
case CV_32S: case TYPE_INT:
storageMat.create(1, size, CV_32S);
THFile_readIntRaw(file, (int*)storageMat.data, size); THFile_readIntRaw(file, (int*)storageMat.data, size);
break; break;
case CV_32SC2: case TYPE_LONG:
{ {
storageMat.create(1, size, CV_64F); //handle LongStorage as CV_64F Mat
double *buf = storageMat.ptr<double>(); double *buf = storageMat.ptr<double>();
THFile_readLongRaw(file, (int64*)buf, size); THFile_readLongRaw(file, (int64*)buf, size);
for (size_t i = (size_t)size; i-- > 0; ) for (size_t i = (size_t)size; i-- > 0; )
buf[i] = ((int64*)buf)[i]; buf[i] = ((int64*)buf)[i];
}
break; break;
}
default: default:
CV_Error(Error::StsInternal, ""); CV_Error(Error::StsInternal, "");
break; break;