// This file is part of OpenCV project. // It is subject to the license terms in the LICENSE file found in the top-level directory // of this distribution and at http://opencv.org/license.html. // // Copyright (C) 2017, Intel Corporation, all rights reserved. // Third party copyrights are property of their respective owners. #include "precomp.hpp" #include "halide_scheduler.hpp" #include "op_halide.hpp" namespace cv { namespace dnn { #ifdef HAVE_HALIDE static void applySplit(const FileNode& directive, Halide::Func& func, const FileNode& params) { for (const auto& varNode : directive) { const std::string varName = varNode.name(); const std::string factorName = (std::string)varNode; Halide::Var var(varName); Halide::Var outerVar(varName + "o"); Halide::Var innerVar(varName + "i"); // If split factor is integer or parameters map has parameter value. CV_Assert(varNode.isString() && !params[factorName].empty() || varNode.isInt()); int factor = (int)(varNode.isInt() ? varNode : params[factorName]); func.split(var, outerVar, innerVar, factor); } } static void applyReorder(const FileNode& directive, Halide::Func& func) { std::string varName; const int numVars = directive.size(); std::vector reorderedVars; reorderedVars.reserve(numVars); for (int i = 0; i < numVars; ++i) { directive[i] >> varName; reorderedVars.push_back(Halide::Var(varName)); } func.reorder(reorderedVars); } static void applyFuse(const FileNode& directive, Halide::Func& func) { CV_Assert(directive["src"].size() >= 2); CV_Assert(directive["dst"].size() == 1); std::string str; directive["src"][0] >> str; Halide::Var firstVar(str); directive["src"][1] >> str; Halide::Var secondVar(str); directive["dst"] >> str; Halide::Var dstVar(str); func.fuse(firstVar, secondVar, dstVar); for (int i = 2, n = directive["src"].size(); i < n; ++i) { directive["src"][i] >> str; func.fuse(Halide::Var(str), dstVar, dstVar); } } static void applyParallel(const FileNode& directive, Halide::Func& func) { std::string varName; if (directive.isSeq()) { for (int i = 0, n = directive.size(); i < n; ++i) { directive[i] >> varName; func.parallel(Halide::Var(varName)); } } else { directive >> varName; func.parallel(Halide::Var(varName)); } } static void applyUnroll(const FileNode& directive, Halide::Func& func) { std::string varName; if (directive.isSeq()) { for (int i = 0, n = directive.size(); i < n; ++i) { directive[i] >> varName; func.unroll(Halide::Var(varName)); } } else { directive >> varName; func.unroll(Halide::Var(varName)); } } static void applyVectorize(const FileNode& directive, Halide::Func& func, const FileNode& params) { for (const auto& varNode : directive) { const std::string varName = varNode.name(); const std::string factorName = (std::string)varNode; // If split factor is integer or parameters map has parameter value. CV_Assert(varNode.isString() && !params[factorName].empty() || varNode.isInt()); int factor = (int)(varNode.isInt() ? varNode : params[factorName]); Halide::Var var(varName); Halide::Var inner(varName + "v"); func.split(var, var, inner, factor); func.vectorize(inner); } } static void applyStoreAt(const FileNode& directive, Halide::Func& func, std::map& funcsMap) { for (const auto& funcNode : directive) { const std::string targetFuncName = funcNode.name(); if (funcsMap.find(targetFuncName) == funcsMap.end()) CV_Error(cv::Error::StsParseError, "Function " + targetFuncName + " is not represented in Halide pipeline"); Halide::Func targetFunc = funcsMap[targetFuncName]; func.store_at(targetFunc, (std::string)funcNode); break; } } static void applyComputeAt(const FileNode& directive, Halide::Func& func, std::map& funcsMap) { for (const auto& funcNode : directive) { const std::string targetFuncName = funcNode.name(); if (funcsMap.find(targetFuncName) == funcsMap.end()) CV_Error(cv::Error::StsParseError, "Function " + targetFuncName + " is not represented in Halide pipeline"); Halide::Func targetFunc = funcsMap[targetFuncName]; func.compute_at(targetFunc, (std::string)funcNode); break; } } static void applyComputeRoot(const FileNode& directive, Halide::Func& func) { bool compute_root; directive >> compute_root; if (compute_root) func.compute_root(); } static void applyGpuBlocks(const FileNode& directive, Halide::Func& func) { std::string varName; for (int i = 0, n = directive.size(); i < n; ++i) { directive[i] >> varName; func.gpu_blocks(Halide::Var(varName)); } } static void applyGpuThreads(const FileNode& directive, Halide::Func& func) { std::string varName; for (int i = 0, n = directive.size(); i < n; ++i) { directive[i] >> varName; func.gpu_threads(Halide::Var(varName)); } } static void apply(const FileNode& directives, Halide::Func& func, std::map& funcsMap, const FileNode& params) { for (const auto& directive : directives) { if (directive.name() == "split") applySplit(directive, func, params); else if (directive.name() == "reorder") applyReorder(directive, func); else if (directive.name() == "fuse") applyFuse(directive, func); else if (directive.name() == "parallel") applyParallel(directive, func); else if (directive.name() == "unroll") applyUnroll(directive, func); else if (directive.name() == "vectorize") applyVectorize(directive, func, params); else if (directive.name() == "store_at") applyStoreAt(directive, func, funcsMap); else if (directive.name() == "compute_at") applyComputeAt(directive, func, funcsMap); else if (directive.name() == "compute_root") applyComputeRoot(directive, func); else if (directive.name() == "gpu_blocks") applyGpuBlocks(directive, func); else if (directive.name() == "gpu_threads") applyGpuThreads(directive, func); else CV_Error(Error::StsNotImplemented, "Scheduling directive " + directive.name() + " is not implemented."); } } // Remove any numeric symbols after '$' sign. static std::string Deunique(std::string str) { int pos = -1; do { pos = str.find('$'); if (pos != -1) { int len = str.find_first_not_of("0123456789", pos + 1) - pos; str = str.replace(pos, len, ""); } } while (pos != -1); return str; } #endif // HAVE_HALIDE HalideScheduler::HalideScheduler(const std::string& configFile) { if (!configFile.empty()) fs = FileStorage(configFile, FileStorage::READ); } HalideScheduler::~HalideScheduler() { if (fs.isOpened()) fs.release(); } bool HalideScheduler::process(Ptr& node) { #ifdef HAVE_HALIDE if (!fs.isOpened()) return false; const FileNode& scheduleNode = fs["scheduling"]; if (scheduleNode.empty()) CV_Error(cv::Error::StsParseError, "Scheduling file should has scheduling node"); std::string str; std::map funcsMap; // Scheduled functions. // For every function, from top to bottom, we try to find a scheduling node. // Scheduling is successful (return true) if for the first function (top) // node is represented. CV_Assert(!node.empty()); std::vector& funcs = node.dynamicCast()->funcs; for (int i = funcs.size() - 1; i >= 0; --i) { Halide::Func& func = funcs[i]; // For functions with the same name Halide generates unique names // for example func, func$1, func$2. // They are always formed with '$' and number. std::string funcName = Deunique(func.name()); const FileNode& funcNode = scheduleNode[funcName]; if (!funcNode.empty()) { if (!funcNode["pattern"].empty()) { funcNode["pattern"] >> str; if (fs["patterns"][str].empty()) CV_Error(cv::Error::StsParseError, "Scheduling pattern " + str + " is not defined"); apply(fs["patterns"][str], func, funcsMap, funcNode["params"]); } else { apply(funcNode, func, funcsMap, funcNode["params"]); } } else { if (funcsMap.empty()) return false; } funcsMap[funcName] = func; } return true; #endif // HAVE_HALIDE return false; } } // namespace dnn } // namespace cv