mirror of
https://github.com/opencv/opencv.git
synced 2025-06-17 15:20:51 +08:00
add broadcast where node
This commit is contained in:
parent
097891e311
commit
0513741a85
@ -48,6 +48,7 @@ public:
|
|||||||
SUM,
|
SUM,
|
||||||
ADD,
|
ADD,
|
||||||
DIV,
|
DIV,
|
||||||
|
WHERE,
|
||||||
} op;
|
} op;
|
||||||
|
|
||||||
NaryEltwiseLayerImpl(const LayerParams& params)
|
NaryEltwiseLayerImpl(const LayerParams& params)
|
||||||
@ -94,6 +95,8 @@ public:
|
|||||||
op = OPERATION::OR;
|
op = OPERATION::OR;
|
||||||
else if (operation == "xor")
|
else if (operation == "xor")
|
||||||
op = OPERATION::XOR;
|
op = OPERATION::XOR;
|
||||||
|
else if (operation == "where")
|
||||||
|
op = OPERATION::WHERE;
|
||||||
else
|
else
|
||||||
CV_Error(cv::Error::StsBadArg, "Unknown operation type \"" + operation + "\"");
|
CV_Error(cv::Error::StsBadArg, "Unknown operation type \"" + operation + "\"");
|
||||||
}
|
}
|
||||||
@ -499,6 +502,120 @@ public:
|
|||||||
f, scale, ninputs, max_ndims, shapes[0], inp, out, (const size_t **) steps, ptrs);
|
f, scale, ninputs, max_ndims, shapes[0], inp, out, (const size_t **) steps, ptrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, typename Functor>
|
||||||
|
void trinary_forward(const Functor& f, const std::vector<Mat>& inputs, std::vector<Mat>& outputs)
|
||||||
|
{
|
||||||
|
const Mat& a = inputs[0];
|
||||||
|
const Mat& b = inputs[1];
|
||||||
|
const Mat& c = inputs[2];
|
||||||
|
Mat& out = outputs[0];
|
||||||
|
|
||||||
|
// collect info of inputs and output
|
||||||
|
const int* in_shape[] = {a.size.p, b.size.p, c.size.p};
|
||||||
|
const size_t* in_step[] = {a.step.p, b.step.p, c.step.p};
|
||||||
|
const int* out_shape = out.size.p;
|
||||||
|
const size_t* out_step = out.step.p;
|
||||||
|
const int in_ndims[] = {a.dims, b.dims, c.dims};
|
||||||
|
int out_ndims = out.dims;
|
||||||
|
|
||||||
|
int max_ndims = std::max(a.dims, std::max(b.dims, std::max(c.dims, out.dims)));
|
||||||
|
|
||||||
|
AutoBuffer<size_t> buf(4 * (2 * max_ndims + 6));
|
||||||
|
|
||||||
|
int** orig_shapes = (int**)(buf.data());
|
||||||
|
int** shapes = orig_shapes + 4;
|
||||||
|
size_t** orig_steps = (size_t**)(shapes + 4);
|
||||||
|
size_t** steps = orig_steps + 4;
|
||||||
|
|
||||||
|
int* shape_buf = (int*)(steps + 4);
|
||||||
|
size_t* step_buf = (size_t*)(shape_buf + 4 * max_ndims);
|
||||||
|
|
||||||
|
int* all_ndims = (int*)(step_buf + 4 * max_ndims);
|
||||||
|
size_t* all_type_sizes = (size_t*)(all_ndims + 4);
|
||||||
|
|
||||||
|
// assign orig_shapes, shapes, orig_steps, steps, all_ndims, all_type_sizes
|
||||||
|
for (int i = 0; i < 4; i++)
|
||||||
|
{
|
||||||
|
orig_shapes[i] = (int*)(i == 0 ? out_shape : in_shape[i-1]);
|
||||||
|
orig_steps[i] = (size_t*)(i == 0 ? out_step : in_step[i-1]);
|
||||||
|
shapes[i] = shape_buf + i * max_ndims;
|
||||||
|
steps[i] = step_buf + i * max_ndims;
|
||||||
|
all_ndims[i] = i == 0 ? out_ndims : in_ndims[i-1];
|
||||||
|
all_type_sizes[i] = sizeof(T);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!prepare_for_broadcast_op(4, max_ndims, all_type_sizes,
|
||||||
|
all_ndims, (const int**)orig_shapes,
|
||||||
|
(const size_t**)orig_steps,
|
||||||
|
shapes, steps))
|
||||||
|
return;
|
||||||
|
|
||||||
|
trinary_forward_impl<T, Functor>(
|
||||||
|
max_ndims, shapes[0], a.ptr<char>(), steps[1], b.ptr<char>(), steps[2],
|
||||||
|
c.ptr<char>(), steps[3], out.ptr<char>(), steps[0],
|
||||||
|
f);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename Functor>
|
||||||
|
void trinary_forward_impl(
|
||||||
|
int ndims, const int* shape,
|
||||||
|
const char* data1, const size_t* step1,
|
||||||
|
const char* data2, const size_t* step2,
|
||||||
|
const char* data3, const size_t* step3,
|
||||||
|
char* data, const size_t* step,
|
||||||
|
const Functor& op)
|
||||||
|
{
|
||||||
|
assert(ndims >= 2);
|
||||||
|
size_t dp1 = step1[ndims-1]/sizeof(T);
|
||||||
|
size_t dp2 = step2[ndims-1]/sizeof(T);
|
||||||
|
size_t dp3 = step3[ndims-1]/sizeof(T);
|
||||||
|
size_t dp = step[ndims-1]/sizeof(T);
|
||||||
|
int k, n1 = shape[ndims-1], n2 = shape[ndims-2];
|
||||||
|
size_t plane_idx, nplanes = 1;
|
||||||
|
for (k = 0; k < ndims-2; k++) nplanes *= shape[k];
|
||||||
|
|
||||||
|
for (plane_idx = 0; plane_idx < nplanes; plane_idx++)
|
||||||
|
{
|
||||||
|
const char* ptr1_ = data1;
|
||||||
|
const char* ptr2_ = data2;
|
||||||
|
const char* ptr3_ = data3;
|
||||||
|
char* ptr_ = data;
|
||||||
|
size_t idx = plane_idx;
|
||||||
|
for (k = ndims-3; k >= 0; k--)
|
||||||
|
{
|
||||||
|
size_t next_idx = idx/shape[k];
|
||||||
|
int i_k = (int)(idx - next_idx*shape[k]);
|
||||||
|
ptr1_ += i_k*step1[k];
|
||||||
|
ptr2_ += i_k*step2[k];
|
||||||
|
ptr3_ += i_k*step3[k];
|
||||||
|
ptr_ += i_k*step[k];
|
||||||
|
idx = next_idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i2 = 0; i2 < n2; i2++, ptr1_ += step1[ndims-2],
|
||||||
|
ptr2_ += step2[ndims-2],
|
||||||
|
ptr3_ += step3[ndims-2],
|
||||||
|
ptr_ += step[ndims-2])
|
||||||
|
{
|
||||||
|
const T* ptr1 = (const T*)ptr1_;
|
||||||
|
const T* ptr2 = (const T*)ptr2_;
|
||||||
|
const T* ptr3 = (const T*)ptr3_;
|
||||||
|
T* ptr = (T*)ptr_;
|
||||||
|
|
||||||
|
if (dp1 == 1 && dp2 == 1 && dp3 == 1 && dp == 1)
|
||||||
|
{
|
||||||
|
for(int i1 = 0; i1 < n1; i1++)
|
||||||
|
ptr[i1] = op(ptr1[i1], ptr2[i1], ptr3[i1]);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for(int i1 = 0; i1 < n1; i1++, ptr1 += dp1, ptr2 += dp2, ptr3 += dp3, ptr += dp)
|
||||||
|
*ptr = op(*ptr1, *ptr2, *ptr3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
|
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
|
||||||
{
|
{
|
||||||
CV_TRACE_FUNCTION();
|
CV_TRACE_FUNCTION();
|
||||||
@ -637,6 +754,12 @@ public:
|
|||||||
binary_forward<T>(op_xor, std::forward<Args>(args)...);
|
binary_forward<T>(op_xor, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case OPERATION::WHERE:
|
||||||
|
{
|
||||||
|
auto op_where = [](const T &a, const T &b, const T &c) { return a ? b : c; };
|
||||||
|
trinary_forward<T>(op_where, std::forward<Args>(args)...);
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
CV_Error(Error::StsBadArg, "Unsupported operation.");
|
CV_Error(Error::StsBadArg, "Unsupported operation.");
|
||||||
};
|
};
|
||||||
|
@ -1168,8 +1168,12 @@ Mat getMatFromTensor(const opencv_onnx::TensorProto& tensor_proto)
|
|||||||
else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
|
else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
|
||||||
{
|
{
|
||||||
const ::google::protobuf::RepeatedField<double> field = tensor_proto.double_data();
|
const ::google::protobuf::RepeatedField<double> field = tensor_proto.double_data();
|
||||||
CV_Assert(!field.empty());
|
char* val = nullptr;
|
||||||
char* val = (char *)field.data();
|
if (!field.empty())
|
||||||
|
val = (char *)field.data();
|
||||||
|
else
|
||||||
|
val = const_cast<char*>(tensor_proto.raw_data().c_str()); // sometime, the double will be stored at raw_data.
|
||||||
|
|
||||||
#if CV_STRONG_ALIGNMENT
|
#if CV_STRONG_ALIGNMENT
|
||||||
// Aligned pointer is required.
|
// Aligned pointer is required.
|
||||||
AutoBuffer<double, 16> aligned_val;
|
AutoBuffer<double, 16> aligned_val;
|
||||||
|
@ -4058,6 +4058,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
|
|||||||
dispatch["LessOrEqual"] = &ONNXImporter::parseElementWise;
|
dispatch["LessOrEqual"] = &ONNXImporter::parseElementWise;
|
||||||
|
|
||||||
dispatch["Sum"] = dispatch["Min"] = dispatch["Max"] = &ONNXImporter::parseElementWise;
|
dispatch["Sum"] = dispatch["Min"] = dispatch["Max"] = &ONNXImporter::parseElementWise;
|
||||||
|
dispatch["Where"] = &ONNXImporter::parseElementWise;
|
||||||
dispatch["Range"] = &ONNXImporter::parseRange;
|
dispatch["Range"] = &ONNXImporter::parseRange;
|
||||||
|
|
||||||
std::vector<std::string> simpleLayers{"Acos", "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil", "Celu", "Cos",
|
std::vector<std::string> simpleLayers{"Acos", "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil", "Celu", "Cos",
|
||||||
|
@ -2492,6 +2492,11 @@ TEST_P(Test_ONNX_layers, OpenAI_CLIP_head)
|
|||||||
testONNXModels("clip-vit-base-head");
|
testONNXModels("clip-vit-base-head");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(Test_ONNX_layers, where_node)
|
||||||
|
{
|
||||||
|
testONNXModels("where_layer");
|
||||||
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());
|
INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());
|
||||||
|
|
||||||
}} // namespace
|
}} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user