mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 14:36:36 +08:00
Add sync infer request
This commit is contained in:
parent
96844b0ca5
commit
15d2a5faf8
@ -828,15 +828,15 @@ std::vector<InferenceEngine::InferRequest> cv::gimpl::ie::IECompiled::createInfe
|
||||
|
||||
class cv::gimpl::ie::RequestPool {
|
||||
public:
|
||||
using RunF = std::function<void(InferenceEngine::InferRequest&)>;
|
||||
using CallbackF = std::function<void(InferenceEngine::InferRequest&, InferenceEngine::StatusCode)>;
|
||||
using SetInputDataF = std::function<void(InferenceEngine::InferRequest&)>;
|
||||
using ReadOutputDataF = std::function<void(InferenceEngine::InferRequest&, InferenceEngine::StatusCode)>;
|
||||
|
||||
// NB: The task is represented by:
|
||||
// RunF - function which is set blobs and run async inference.
|
||||
// CallbackF - function which is obtain output blobs and post it to output.
|
||||
// SetInputDataF - function which set input data.
|
||||
// ReadOutputDataF - function which read output data.
|
||||
struct Task {
|
||||
RunF run;
|
||||
CallbackF callback;
|
||||
SetInputDataF set_input_data;
|
||||
ReadOutputDataF read_output_data;
|
||||
};
|
||||
|
||||
explicit RequestPool(std::vector<InferenceEngine::InferRequest>&& requests);
|
||||
@ -850,11 +850,21 @@ private:
|
||||
IE::InferRequest request,
|
||||
IE::StatusCode code) noexcept;
|
||||
void setup();
|
||||
void releaseRequest(const int id);
|
||||
|
||||
QueueClass<size_t> m_idle_ids;
|
||||
std::vector<InferenceEngine::InferRequest> m_requests;
|
||||
bool m_use_sync_api = false;
|
||||
};
|
||||
|
||||
void cv::gimpl::ie::RequestPool::releaseRequest(const int id) {
|
||||
if (!m_use_sync_api) {
|
||||
auto& request = m_requests[id];
|
||||
request.SetCompletionCallback([](){});
|
||||
}
|
||||
m_idle_ids.push(id);
|
||||
}
|
||||
|
||||
// RequestPool implementation //////////////////////////////////////////////
|
||||
cv::gimpl::ie::RequestPool::RequestPool(std::vector<InferenceEngine::InferRequest>&& requests)
|
||||
: m_requests(std::move(requests)) {
|
||||
@ -867,25 +877,30 @@ void cv::gimpl::ie::RequestPool::setup() {
|
||||
}
|
||||
}
|
||||
|
||||
void cv::gimpl::ie::RequestPool::execute(cv::gimpl::ie::RequestPool::Task&& t) {
|
||||
void cv::gimpl::ie::RequestPool::execute(cv::gimpl::ie::RequestPool::Task&& task) {
|
||||
size_t id = 0u;
|
||||
m_idle_ids.pop(id);
|
||||
|
||||
auto& request = m_requests[id];
|
||||
|
||||
using namespace std::placeholders;
|
||||
using callback_t = std::function<void(IE::InferRequest, IE::StatusCode)>;
|
||||
request.SetCompletionCallback(
|
||||
static_cast<callback_t>(
|
||||
std::bind(&cv::gimpl::ie::RequestPool::callback, this,
|
||||
t, id, _1, _2)));
|
||||
// NB: InferRequest is already marked as busy
|
||||
// in case of exception need to return it back to the idle.
|
||||
try {
|
||||
t.run(request);
|
||||
task.set_input_data(request);
|
||||
if (m_use_sync_api) {
|
||||
request.Infer();
|
||||
task.read_output_data(request, IE::StatusCode::OK);
|
||||
releaseRequest(id);
|
||||
} else {
|
||||
using namespace std::placeholders;
|
||||
using callback_t = std::function<void(IE::InferRequest, IE::StatusCode)>;
|
||||
request.SetCompletionCallback(
|
||||
static_cast<callback_t>(
|
||||
std::bind(&cv::gimpl::ie::RequestPool::callback, this,
|
||||
task, id, _1, _2)));
|
||||
request.StartAsync();
|
||||
}
|
||||
} catch (...) {
|
||||
request.SetCompletionCallback([](){});
|
||||
m_idle_ids.push(id);
|
||||
// NB: InferRequest is already marked as busy
|
||||
// in case of exception need to return it back to the idle.
|
||||
releaseRequest(id);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
@ -898,9 +913,8 @@ void cv::gimpl::ie::RequestPool::callback(cv::gimpl::ie::RequestPool::Task task,
|
||||
// 1. Run callback
|
||||
// 2. Destroy callback to free resources.
|
||||
// 3. Mark InferRequest as idle.
|
||||
task.callback(request, code);
|
||||
request.SetCompletionCallback([](){});
|
||||
m_idle_ids.push(id);
|
||||
task.read_output_data(request, code);
|
||||
releaseRequest(id);
|
||||
}
|
||||
|
||||
// NB: Not thread-safe.
|
||||
@ -1335,9 +1349,6 @@ struct Infer: public cv::detail::KernelTag {
|
||||
cv::util::optional<cv::Rect>{});
|
||||
setBlob(req, layer_name, this_blob, *ctx);
|
||||
}
|
||||
// FIXME: Should it be done by kernel ?
|
||||
// What about to do that in RequestPool ?
|
||||
req.StartAsync();
|
||||
},
|
||||
std::bind(PostOutputs, _1, _2, ctx)
|
||||
}
|
||||
@ -1455,9 +1466,6 @@ struct InferROI: public cv::detail::KernelTag {
|
||||
*(ctx->uu.params.input_names.begin()),
|
||||
this_blob, *ctx);
|
||||
}
|
||||
// FIXME: Should it be done by kernel ?
|
||||
// What about to do that in RequestPool ?
|
||||
req.StartAsync();
|
||||
},
|
||||
std::bind(PostOutputs, _1, _2, ctx)
|
||||
}
|
||||
@ -1575,7 +1583,6 @@ struct InferList: public cv::detail::KernelTag {
|
||||
cv::gimpl::ie::RequestPool::Task {
|
||||
[ctx, rc, this_blob](InferenceEngine::InferRequest &req) {
|
||||
setROIBlob(req, ctx->uu.params.input_names[0u], this_blob, rc, *ctx);
|
||||
req.StartAsync();
|
||||
},
|
||||
std::bind(callback, std::placeholders::_1, std::placeholders::_2, pos)
|
||||
}
|
||||
@ -1748,7 +1755,6 @@ struct InferList2: public cv::detail::KernelTag {
|
||||
"Only Rect and Mat types are supported for infer list 2!");
|
||||
}
|
||||
}
|
||||
req.StartAsync();
|
||||
},
|
||||
std::bind(callback, std::placeholders::_1, std::placeholders::_2, list_idx)
|
||||
} // task
|
||||
|
Loading…
Reference in New Issue
Block a user