mirror of
https://github.com/opencv/opencv.git
synced 2025-06-13 13:13:26 +08:00
Merge pull request #8293 from alalek:update_rng_in_parallel_for
This commit is contained in:
commit
c7049ca627
@ -2834,6 +2834,8 @@ public:
|
|||||||
double gaussian(double sigma);
|
double gaussian(double sigma);
|
||||||
|
|
||||||
uint64 state;
|
uint64 state;
|
||||||
|
|
||||||
|
bool operator ==(const RNG& other) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/** @brief Mersenne Twister random number generator
|
/** @brief Mersenne Twister random number generator
|
||||||
|
@ -349,6 +349,8 @@ inline int RNG::uniform(int a, int b) { return a == b ? a : (int)(next(
|
|||||||
inline float RNG::uniform(float a, float b) { return ((float)*this)*(b - a) + a; }
|
inline float RNG::uniform(float a, float b) { return ((float)*this)*(b - a) + a; }
|
||||||
inline double RNG::uniform(double a, double b) { return ((double)*this)*(b - a) + a; }
|
inline double RNG::uniform(double a, double b) { return ((double)*this)*(b - a) + a; }
|
||||||
|
|
||||||
|
inline bool RNG::operator ==(const RNG& other) const { return state == other.state; }
|
||||||
|
|
||||||
inline unsigned RNG::next()
|
inline unsigned RNG::next()
|
||||||
{
|
{
|
||||||
state = (uint64)(unsigned)state* /*CV_RNG_COEFF*/ 4164903690U + (unsigned)(state >> 32);
|
state = (uint64)(unsigned)state* /*CV_RNG_COEFF*/ 4164903690U + (unsigned)(state >> 32);
|
||||||
|
@ -166,7 +166,8 @@ namespace
|
|||||||
class ParallelLoopBodyWrapper : public cv::ParallelLoopBody
|
class ParallelLoopBodyWrapper : public cv::ParallelLoopBody
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
ParallelLoopBodyWrapper(const cv::ParallelLoopBody& _body, const cv::Range& _r, double _nstripes)
|
ParallelLoopBodyWrapper(const cv::ParallelLoopBody& _body, const cv::Range& _r, double _nstripes) :
|
||||||
|
is_rng_used(false)
|
||||||
{
|
{
|
||||||
|
|
||||||
body = &_body;
|
body = &_body;
|
||||||
@ -174,17 +175,30 @@ namespace
|
|||||||
double len = wholeRange.end - wholeRange.start;
|
double len = wholeRange.end - wholeRange.start;
|
||||||
nstripes = cvRound(_nstripes <= 0 ? len : MIN(MAX(_nstripes, 1.), len));
|
nstripes = cvRound(_nstripes <= 0 ? len : MIN(MAX(_nstripes, 1.), len));
|
||||||
|
|
||||||
|
// propagate main thread state
|
||||||
|
rng = cv::theRNG();
|
||||||
|
|
||||||
#ifdef ENABLE_INSTRUMENTATION
|
#ifdef ENABLE_INSTRUMENTATION
|
||||||
pThreadRoot = cv::instr::getInstrumentTLSStruct().pCurrentNode;
|
pThreadRoot = cv::instr::getInstrumentTLSStruct().pCurrentNode;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
#ifdef ENABLE_INSTRUMENTATION
|
|
||||||
~ParallelLoopBodyWrapper()
|
~ParallelLoopBodyWrapper()
|
||||||
{
|
{
|
||||||
|
#ifdef ENABLE_INSTRUMENTATION
|
||||||
for(size_t i = 0; i < pThreadRoot->m_childs.size(); i++)
|
for(size_t i = 0; i < pThreadRoot->m_childs.size(); i++)
|
||||||
SyncNodes(pThreadRoot->m_childs[i]);
|
SyncNodes(pThreadRoot->m_childs[i]);
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
|
if (is_rng_used)
|
||||||
|
{
|
||||||
|
// Some parallel backends execute nested jobs in the main thread,
|
||||||
|
// so we need to restore initial RNG state here.
|
||||||
|
cv::theRNG() = rng;
|
||||||
|
// We can't properly update RNG state based on RNG usage in worker threads,
|
||||||
|
// so lets just change main thread RNG state to the next value.
|
||||||
|
// Note: this behaviour is not equal to single-threaded mode.
|
||||||
|
cv::theRNG().next();
|
||||||
|
}
|
||||||
|
}
|
||||||
void operator()(const cv::Range& sr) const
|
void operator()(const cv::Range& sr) const
|
||||||
{
|
{
|
||||||
#ifdef ENABLE_INSTRUMENTATION
|
#ifdef ENABLE_INSTRUMENTATION
|
||||||
@ -195,12 +209,18 @@ namespace
|
|||||||
#endif
|
#endif
|
||||||
CV_INSTRUMENT_REGION()
|
CV_INSTRUMENT_REGION()
|
||||||
|
|
||||||
|
// propagate main thread state
|
||||||
|
cv::theRNG() = rng;
|
||||||
|
|
||||||
cv::Range r;
|
cv::Range r;
|
||||||
r.start = (int)(wholeRange.start +
|
r.start = (int)(wholeRange.start +
|
||||||
((uint64)sr.start*(wholeRange.end - wholeRange.start) + nstripes/2)/nstripes);
|
((uint64)sr.start*(wholeRange.end - wholeRange.start) + nstripes/2)/nstripes);
|
||||||
r.end = sr.end >= nstripes ? wholeRange.end : (int)(wholeRange.start +
|
r.end = sr.end >= nstripes ? wholeRange.end : (int)(wholeRange.start +
|
||||||
((uint64)sr.end*(wholeRange.end - wholeRange.start) + nstripes/2)/nstripes);
|
((uint64)sr.end*(wholeRange.end - wholeRange.start) + nstripes/2)/nstripes);
|
||||||
(*body)(r);
|
(*body)(r);
|
||||||
|
|
||||||
|
if (!is_rng_used && !(cv::theRNG() == rng))
|
||||||
|
is_rng_used = true;
|
||||||
}
|
}
|
||||||
cv::Range stripeRange() const { return cv::Range(0, nstripes); }
|
cv::Range stripeRange() const { return cv::Range(0, nstripes); }
|
||||||
|
|
||||||
@ -208,6 +228,8 @@ namespace
|
|||||||
const cv::ParallelLoopBody* body;
|
const cv::ParallelLoopBody* body;
|
||||||
cv::Range wholeRange;
|
cv::Range wholeRange;
|
||||||
int nstripes;
|
int nstripes;
|
||||||
|
cv::RNG rng;
|
||||||
|
mutable bool is_rng_used;
|
||||||
#ifdef ENABLE_INSTRUMENTATION
|
#ifdef ENABLE_INSTRUMENTATION
|
||||||
cv::instr::InstrNode *pThreadRoot;
|
cv::instr::InstrNode *pThreadRoot;
|
||||||
#endif
|
#endif
|
||||||
|
@ -382,3 +382,39 @@ TEST(Core_Rand, Regression_Stack_Corruption)
|
|||||||
ASSERT_EQ(param1, -9);
|
ASSERT_EQ(param1, -9);
|
||||||
ASSERT_EQ(param2, 2);
|
ASSERT_EQ(param2, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class RandRowFillParallelLoopBody : public cv::ParallelLoopBody
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
RandRowFillParallelLoopBody(Mat& dst) : dst_(dst) {}
|
||||||
|
~RandRowFillParallelLoopBody() {}
|
||||||
|
void operator()(const cv::Range& r) const
|
||||||
|
{
|
||||||
|
cv::RNG rng = cv::theRNG(); // copy state
|
||||||
|
for (int y = r.start; y < r.end; y++)
|
||||||
|
{
|
||||||
|
cv::theRNG() = cv::RNG(rng.state + y); // seed is based on processed row
|
||||||
|
cv::randu(dst_.row(y), Scalar(-100), Scalar(100));
|
||||||
|
}
|
||||||
|
// theRNG() state is changed here (but state collision has low probability, so we don't check this)
|
||||||
|
}
|
||||||
|
protected:
|
||||||
|
Mat& dst_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(Core_Rand, parallel_for_stable_results)
|
||||||
|
{
|
||||||
|
cv::RNG rng = cv::theRNG(); // save rng state
|
||||||
|
Mat dst1(1000, 100, CV_8SC1);
|
||||||
|
parallel_for_(cv::Range(0, dst1.rows), RandRowFillParallelLoopBody(dst1));
|
||||||
|
|
||||||
|
cv::theRNG() = rng; // restore rng state
|
||||||
|
Mat dst2(1000, 100, CV_8SC1);
|
||||||
|
parallel_for_(cv::Range(0, dst2.rows), RandRowFillParallelLoopBody(dst2));
|
||||||
|
|
||||||
|
ASSERT_EQ(0, countNonZero(dst1 != dst2));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user