mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Merge pull request #20818 from rogday:yolov4x_mish_cuda
This commit is contained in:
commit
4672dbda2a
@ -31,7 +31,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
|||||||
size_type boxes_per_cell, size_type box_size,
|
size_type boxes_per_cell, size_type box_size,
|
||||||
size_type rows, size_type cols, T scale_x_y,
|
size_type rows, size_type cols, T scale_x_y,
|
||||||
size_type height_norm, size_type width_norm,
|
size_type height_norm, size_type width_norm,
|
||||||
T object_prob_cutoff)
|
T object_prob_cutoff, bool new_coords)
|
||||||
{
|
{
|
||||||
using vector2_type = get_vector_type_t<T, 2>;
|
using vector2_type = get_vector_type_t<T, 2>;
|
||||||
auto bias_vPtr = vector2_type::get_pointer(bias.data());
|
auto bias_vPtr = vector2_type::get_pointer(bias.data());
|
||||||
@ -47,22 +47,43 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
|||||||
const auto y = (box_index % batch_inner_size) / row_inner_size;
|
const auto y = (box_index % batch_inner_size) / row_inner_size;
|
||||||
const auto x = (box_index % row_inner_size) / col_inner_size;
|
const auto x = (box_index % row_inner_size) / col_inner_size;
|
||||||
|
|
||||||
using device::fast_sigmoid;
|
/* When new_coords is true, we shouldn't use logistic activation again */
|
||||||
const auto tmp_x = (fast_sigmoid(input[box_offset + 0]) - static_cast<T>(0.5)) * scale_x_y + static_cast<T>(0.5);
|
T objectness_prob;
|
||||||
const auto tmp_y = (fast_sigmoid(input[box_offset + 1]) - static_cast<T>(0.5)) * scale_x_y + static_cast<T>(0.5);
|
if (new_coords)
|
||||||
output[box_offset + 0] = (T(x) + tmp_x) / T(cols);
|
{
|
||||||
output[box_offset + 1] = (T(y) + tmp_y) / T(rows);
|
const auto tmp_x = (input[box_offset + 0] - static_cast<T>(0.5)) * scale_x_y + static_cast<T>(0.5);
|
||||||
|
const auto tmp_y = (input[box_offset + 1] - static_cast<T>(0.5)) * scale_x_y + static_cast<T>(0.5);
|
||||||
|
|
||||||
vector2_type bias_xy;
|
output[box_offset + 0] = fast_divide_ftz(static_cast<T>(x) + tmp_x, static_cast<T>(cols));
|
||||||
v_load(bias_xy, bias_vPtr[box_of_the_cell]);
|
output[box_offset + 1] = fast_divide_ftz(static_cast<T>(y) + tmp_y, static_cast<T>(rows));
|
||||||
|
|
||||||
using device::fast_exp;
|
vector2_type bias_xy;
|
||||||
output[box_offset + 2] = fast_exp(input[box_offset + 2]) * bias_xy.data[0] / T(width_norm);
|
v_load(bias_xy, bias_vPtr[box_of_the_cell]);
|
||||||
output[box_offset + 3] = fast_exp(input[box_offset + 3]) * bias_xy.data[1] / T(height_norm);
|
|
||||||
|
|
||||||
/* squash objectness score into a probability */
|
output[box_offset + 2] = input[box_offset + 2] * input[box_offset + 2] *
|
||||||
using device::fast_sigmoid;
|
static_cast<T>(4) * bias_xy.data[0] / static_cast<T>(width_norm);
|
||||||
T objectness_prob = fast_sigmoid(input[box_offset + 4]);
|
output[box_offset + 3] = input[box_offset + 3] * input[box_offset + 3] *
|
||||||
|
static_cast<T>(4) * bias_xy.data[1] / static_cast<T>(height_norm);
|
||||||
|
|
||||||
|
objectness_prob = input[box_offset + 4];
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const auto tmp_x = (fast_sigmoid(input[box_offset + 0]) - static_cast<T>(0.5)) * scale_x_y + static_cast<T>(0.5);
|
||||||
|
const auto tmp_y = (fast_sigmoid(input[box_offset + 1]) - static_cast<T>(0.5)) * scale_x_y + static_cast<T>(0.5);
|
||||||
|
|
||||||
|
output[box_offset + 0] = fast_divide_ftz(static_cast<T>(x) + tmp_x, static_cast<T>(cols));
|
||||||
|
output[box_offset + 1] = fast_divide_ftz(static_cast<T>(y) + tmp_y, static_cast<T>(rows));
|
||||||
|
|
||||||
|
vector2_type bias_xy;
|
||||||
|
v_load(bias_xy, bias_vPtr[box_of_the_cell]);
|
||||||
|
|
||||||
|
output[box_offset + 2] = fast_exp(input[box_offset + 2]) * bias_xy.data[0] / static_cast<T>(width_norm);
|
||||||
|
output[box_offset + 3] = fast_exp(input[box_offset + 3]) * bias_xy.data[1] / static_cast<T>(height_norm);
|
||||||
|
|
||||||
|
/* squash objectness score into a probability */
|
||||||
|
objectness_prob = fast_sigmoid(input[box_offset + 4]);
|
||||||
|
}
|
||||||
|
|
||||||
/* ignore prediction if the objectness probability is less than the cutoff */
|
/* ignore prediction if the objectness probability is less than the cutoff */
|
||||||
if (objectness_prob < object_prob_cutoff)
|
if (objectness_prob < object_prob_cutoff)
|
||||||
@ -73,7 +94,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
__global__ void region_sigmoid_class_score(Span<T> output, View<T> input, T class_prob_cutoff, size_type box_size)
|
__global__ void region_sigmoid_class_score(Span<T> output, View<T> input, T class_prob_cutoff,
|
||||||
|
size_type box_size, bool new_coords)
|
||||||
{
|
{
|
||||||
for (auto idx : grid_stride_range(output.size())) {
|
for (auto idx : grid_stride_range(output.size())) {
|
||||||
const index_type box_no = idx / box_size;
|
const index_type box_no = idx / box_size;
|
||||||
@ -92,9 +114,20 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
|||||||
*
|
*
|
||||||
* to obtain the actual class probability, we multiply the conditional probability
|
* to obtain the actual class probability, we multiply the conditional probability
|
||||||
* with the object probability
|
* with the object probability
|
||||||
|
*
|
||||||
|
* when new_coords is true, we shouldn't use logistic activation again.
|
||||||
*/
|
*/
|
||||||
using device::fast_sigmoid;
|
|
||||||
auto actual_class_prob = objectness_prob * fast_sigmoid(input[idx]);
|
T actual_class_prob;
|
||||||
|
if (new_coords)
|
||||||
|
{
|
||||||
|
actual_class_prob = objectness_prob * input[idx];
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
actual_class_prob = objectness_prob * fast_sigmoid(input[idx]);
|
||||||
|
}
|
||||||
|
|
||||||
if (actual_class_prob <= class_prob_cutoff)
|
if (actual_class_prob <= class_prob_cutoff)
|
||||||
actual_class_prob = T(0);
|
actual_class_prob = T(0);
|
||||||
output[idx] = actual_class_prob;
|
output[idx] = actual_class_prob;
|
||||||
@ -147,7 +180,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
|||||||
std::size_t boxes_per_cell, std::size_t box_size,
|
std::size_t boxes_per_cell, std::size_t box_size,
|
||||||
std::size_t rows, std::size_t cols, T scale_x_y,
|
std::size_t rows, std::size_t cols, T scale_x_y,
|
||||||
std::size_t height_norm, std::size_t width_norm,
|
std::size_t height_norm, std::size_t width_norm,
|
||||||
bool if_true_sigmoid_else_softmax /* true = sigmoid, false = softmax */)
|
bool if_true_sigmoid_else_softmax, /* true = sigmoid, false = softmax */
|
||||||
|
bool new_coords)
|
||||||
{
|
{
|
||||||
CV_Assert(output.size() == input.size());
|
CV_Assert(output.size() == input.size());
|
||||||
CV_Assert(output.size() % box_size == 0);
|
CV_Assert(output.size() % box_size == 0);
|
||||||
@ -158,12 +192,12 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
|||||||
launch_kernel(box_kernel, box_policy,
|
launch_kernel(box_kernel, box_policy,
|
||||||
output, input, bias, boxes_per_cell, box_size,
|
output, input, bias, boxes_per_cell, box_size,
|
||||||
rows, cols, scale_x_y, height_norm, width_norm,
|
rows, cols, scale_x_y, height_norm, width_norm,
|
||||||
object_prob_cutoff);
|
object_prob_cutoff, new_coords);
|
||||||
|
|
||||||
if (if_true_sigmoid_else_softmax) {
|
if (if_true_sigmoid_else_softmax) {
|
||||||
auto kernel_score = raw::region_sigmoid_class_score<T>;
|
auto kernel_score = raw::region_sigmoid_class_score<T>;
|
||||||
auto policy_score = make_policy(kernel_score, output.size(), 0, stream);
|
auto policy_score = make_policy(kernel_score, output.size(), 0, stream);
|
||||||
launch_kernel(kernel_score, policy_score, output, input, class_prob_cutoff, box_size);
|
launch_kernel(kernel_score, policy_score, output, input, class_prob_cutoff, box_size, new_coords);
|
||||||
} else {
|
} else {
|
||||||
auto kernel_score = raw::region_softmax_class_score<T>;
|
auto kernel_score = raw::region_softmax_class_score<T>;
|
||||||
auto policy_score = make_policy(kernel_score, output.size(), 0, stream);
|
auto policy_score = make_policy(kernel_score, output.size(), 0, stream);
|
||||||
@ -173,10 +207,10 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
|||||||
|
|
||||||
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
|
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
|
||||||
template void region(const Stream&, Span<__half>, View<__half>, View<__half>,
|
template void region(const Stream&, Span<__half>, View<__half>, View<__half>,
|
||||||
__half, __half, std::size_t, std::size_t, std::size_t, std::size_t, __half, std::size_t, std::size_t, bool);
|
__half, __half, std::size_t, std::size_t, std::size_t, std::size_t, __half, std::size_t, std::size_t, bool, bool);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template void region(const Stream&, Span<float>, View<float>, View<float>,
|
template void region(const Stream&, Span<float>, View<float>, View<float>,
|
||||||
float, float, std::size_t, std::size_t, std::size_t, std::size_t, float, std::size_t, std::size_t, bool);
|
float, float, std::size_t, std::size_t, std::size_t, std::size_t, float, std::size_t, std::size_t, bool, bool);
|
||||||
|
|
||||||
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
|
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
|
||||||
|
@ -18,7 +18,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
|||||||
std::size_t boxes_per_cell, std::size_t box_size,
|
std::size_t boxes_per_cell, std::size_t box_size,
|
||||||
std::size_t rows, std::size_t cols, T scale_x_y,
|
std::size_t rows, std::size_t cols, T scale_x_y,
|
||||||
std::size_t height_norm, std::size_t width_norm,
|
std::size_t height_norm, std::size_t width_norm,
|
||||||
bool if_true_sigmoid_else_softmax);
|
bool if_true_sigmoid_else_softmax, bool new_coords);
|
||||||
|
|
||||||
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
|
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
|
||||||
|
|
||||||
|
@ -60,6 +60,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
|
|||||||
T class_prob_cutoff;
|
T class_prob_cutoff;
|
||||||
|
|
||||||
T nms_iou_threshold;
|
T nms_iou_threshold;
|
||||||
|
bool new_coords;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
@ -87,6 +88,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
|
|||||||
class_prob_cutoff = config.class_prob_cutoff;
|
class_prob_cutoff = config.class_prob_cutoff;
|
||||||
|
|
||||||
nms_iou_threshold = config.nms_iou_threshold;
|
nms_iou_threshold = config.nms_iou_threshold;
|
||||||
|
new_coords = config.new_coords;
|
||||||
}
|
}
|
||||||
|
|
||||||
void forward(
|
void forward(
|
||||||
@ -115,7 +117,8 @@ namespace cv { namespace dnn { namespace cuda4dnn {
|
|||||||
boxes_per_cell, cell_box_size,
|
boxes_per_cell, cell_box_size,
|
||||||
rows, cols, scale_x_y,
|
rows, cols, scale_x_y,
|
||||||
height_norm, width_norm,
|
height_norm, width_norm,
|
||||||
if_true_sigmoid_else_softmax
|
if_true_sigmoid_else_softmax,
|
||||||
|
new_coords
|
||||||
);
|
);
|
||||||
|
|
||||||
if (nms_iou_threshold > 0) {
|
if (nms_iou_threshold > 0) {
|
||||||
@ -176,6 +179,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
|
|||||||
T object_prob_cutoff, class_prob_cutoff;
|
T object_prob_cutoff, class_prob_cutoff;
|
||||||
|
|
||||||
T nms_iou_threshold;
|
T nms_iou_threshold;
|
||||||
|
bool new_coords;
|
||||||
};
|
};
|
||||||
|
|
||||||
}}} /* namespace cv::dnn::cuda4dnn */
|
}}} /* namespace cv::dnn::cuda4dnn */
|
||||||
|
@ -125,7 +125,7 @@ public:
|
|||||||
#endif
|
#endif
|
||||||
#ifdef HAVE_CUDA
|
#ifdef HAVE_CUDA
|
||||||
if (backendId == DNN_BACKEND_CUDA)
|
if (backendId == DNN_BACKEND_CUDA)
|
||||||
return new_coords == 0;
|
return true;
|
||||||
#endif
|
#endif
|
||||||
return backendId == DNN_BACKEND_OPENCV;
|
return backendId == DNN_BACKEND_OPENCV;
|
||||||
}
|
}
|
||||||
@ -437,11 +437,12 @@ public:
|
|||||||
|
|
||||||
config.scale_x_y = scale_x_y;
|
config.scale_x_y = scale_x_y;
|
||||||
|
|
||||||
config.object_prob_cutoff = (classfix == -1) ? 0.5 : 0.0;
|
config.object_prob_cutoff = (classfix == -1) ? thresh : 0.f;
|
||||||
config.class_prob_cutoff = thresh;
|
config.class_prob_cutoff = thresh;
|
||||||
|
|
||||||
config.nms_iou_threshold = nmsThreshold;
|
config.nms_iou_threshold = nmsThreshold;
|
||||||
|
|
||||||
|
config.new_coords = (new_coords == 1);
|
||||||
return make_cuda_node<cuda4dnn::RegionOp>(preferableTarget, std::move(context->stream), blobs[0], config);
|
return make_cuda_node<cuda4dnn::RegionOp>(preferableTarget, std::move(context->stream), blobs[0], config);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -745,8 +745,14 @@ TEST_P(Test_Darknet_nets, YOLOv4x_mish)
|
|||||||
};
|
};
|
||||||
Mat ref(N0 + N1, 7, CV_32FC1, (void*)ref_);
|
Mat ref(N0 + N1, 7, CV_32FC1, (void*)ref_);
|
||||||
|
|
||||||
double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.006 : 8e-5;
|
double scoreDiff = 8e-5;
|
||||||
double iouDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.042 : 3e-4;
|
double iouDiff = 3e-4;
|
||||||
|
|
||||||
|
if (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD || target == DNN_TARGET_CUDA_FP16)
|
||||||
|
{
|
||||||
|
scoreDiff = 0.006;
|
||||||
|
iouDiff = 0.042;
|
||||||
|
}
|
||||||
|
|
||||||
std::string config_file = "yolov4x-mish.cfg";
|
std::string config_file = "yolov4x-mish.cfg";
|
||||||
std::string weights_file = "yolov4x-mish.weights";
|
std::string weights_file = "yolov4x-mish.weights";
|
||||||
|
Loading…
Reference in New Issue
Block a user