mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 06:03:15 +08:00
Use TorchType enum
This commit is contained in:
parent
a7b3d2581f
commit
7d75526373
@ -74,6 +74,18 @@ enum LuaType
|
||||
LEGACY_TYPE_RECUR_FUNCTION = 7
|
||||
};
|
||||
|
||||
// We use OpenCV's types to manage CV_ELEM_SIZE.
|
||||
enum TorchType
|
||||
{
|
||||
TYPE_DOUBLE = CV_64F,
|
||||
TYPE_FLOAT = CV_32F,
|
||||
TYPE_BYTE = CV_8U,
|
||||
TYPE_CHAR = CV_8S,
|
||||
TYPE_SHORT = CV_16S,
|
||||
TYPE_INT = CV_32S,
|
||||
TYPE_LONG = CV_32SC2
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
static String toString(const T &v)
|
||||
{
|
||||
@ -203,19 +215,19 @@ struct TorchImporter
|
||||
String typeStr = str.substr(strlen(prefix), str.length() - strlen(prefix) - strlen(suffix));
|
||||
|
||||
if (typeStr == "Double")
|
||||
return CV_64F;
|
||||
return TYPE_DOUBLE;
|
||||
else if (typeStr == "Float" || typeStr == "Cuda")
|
||||
return CV_32F;
|
||||
return TYPE_FLOAT;
|
||||
else if (typeStr == "Byte")
|
||||
return CV_8U;
|
||||
return TYPE_BYTE;
|
||||
else if (typeStr == "Char")
|
||||
return CV_8S;
|
||||
return TYPE_CHAR;
|
||||
else if (typeStr == "Short")
|
||||
return CV_16S;
|
||||
return TYPE_SHORT;
|
||||
else if (typeStr == "Int")
|
||||
return CV_32S;
|
||||
else if (typeStr == "Long") //Carefully! CV_64S type coded as CV_32SC2
|
||||
return CV_32SC2;
|
||||
return TYPE_INT;
|
||||
else if (typeStr == "Long")
|
||||
return TYPE_LONG;
|
||||
else
|
||||
CV_Error(Error::StsNotImplemented, "Unknown type \"" + typeStr + "\" of torch class \"" + str + "\"");
|
||||
}
|
||||
@ -236,36 +248,44 @@ struct TorchImporter
|
||||
void readTorchStorage(int index, int type = -1)
|
||||
{
|
||||
long size = readLong();
|
||||
Mat storageMat(1, size, (type != CV_32SC2) ? type : CV_64F); //handle LongStorage as CV_64F Mat
|
||||
Mat storageMat;
|
||||
|
||||
switch (type)
|
||||
{
|
||||
case CV_32F:
|
||||
case TYPE_FLOAT:
|
||||
storageMat.create(1, size, CV_32F);
|
||||
THFile_readFloatRaw(file, (float*)storageMat.data, size);
|
||||
break;
|
||||
case CV_64F:
|
||||
case TYPE_DOUBLE:
|
||||
storageMat.create(1, size, CV_64F);
|
||||
THFile_readDoubleRaw(file, (double*)storageMat.data, size);
|
||||
break;
|
||||
case CV_8S:
|
||||
case CV_8U:
|
||||
case TYPE_CHAR:
|
||||
storageMat.create(1, size, CV_8S);
|
||||
THFile_readByteRaw(file, (uchar*)storageMat.data, size);
|
||||
break;
|
||||
case CV_16S:
|
||||
case CV_16U:
|
||||
case TYPE_BYTE:
|
||||
storageMat.create(1, size, CV_8U);
|
||||
THFile_readByteRaw(file, (uchar*)storageMat.data, size);
|
||||
break;
|
||||
case TYPE_SHORT:
|
||||
storageMat.create(1, size, CV_16S);
|
||||
THFile_readShortRaw(file, (short*)storageMat.data, size);
|
||||
break;
|
||||
case CV_32S:
|
||||
case TYPE_INT:
|
||||
storageMat.create(1, size, CV_32S);
|
||||
THFile_readIntRaw(file, (int*)storageMat.data, size);
|
||||
break;
|
||||
case CV_32SC2:
|
||||
case TYPE_LONG:
|
||||
{
|
||||
storageMat.create(1, size, CV_64F); //handle LongStorage as CV_64F Mat
|
||||
double *buf = storageMat.ptr<double>();
|
||||
THFile_readLongRaw(file, (int64*)buf, size);
|
||||
|
||||
for (size_t i = (size_t)size; i-- > 0; )
|
||||
buf[i] = ((int64*)buf)[i];
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
CV_Error(Error::StsInternal, "");
|
||||
break;
|
||||
|
Loading…
Reference in New Issue
Block a user