opencv/modules/gapi/src/compiler/passes/transformations.cpp
2021-12-24 18:04:11 +03:00

141 lines
5.5 KiB
C++

// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2019 Intel Corporation
#include "precomp.hpp"
#include <ade/util/zip_range.hpp>
#include <ade/graph.hpp>
#include "api/gcomputation_priv.hpp"
#include "compiler/gmodel.hpp"
#include "compiler/gmodelbuilder.hpp"
#include "compiler/passes/passes.hpp"
#include "compiler/passes/pattern_matching.hpp"
#include <sstream>
namespace cv { namespace gimpl { namespace passes {
namespace
{
// Shorthand for GModel::Graph, the graph type the helpers below operate on
using Graph = GModel::Graph;
// Shorthand for Graph::CMetadataT.
// NOTE(review): not referenced in the visible part of this file - confirm whether it is still needed
using Metadata = typename Graph::CMetadataT;
// Verifies {pattern node, substitute node} pairs taken position-by-position from the two
// handle lists. For every pair it asserts that:
//   - the pattern node is a DATA node,
//   - the substitute node has the same node type,
//   - both nodes carry the same Data shape.
// pattern / substitute are the graphs whose metadata is queried for the respective handles.
void checkDataNodes(const Graph& pattern,
                    const Graph& substitute,
                    const std::vector<ade::NodeHandle>& patternNodes,
                    const std::vector<ade::NodeHandle>& substituteNodes)
{
    for (auto nodePair : ade::util::zip(patternNodes, substituteNodes))
    {
        const auto& patternNode    = std::get<0>(nodePair);
        const auto& substituteNode = std::get<1>(nodePair);

        auto pNodeMeta = pattern.metadata(patternNode);
        auto sNodeMeta = substitute.metadata(substituteNode);

        // Note: assertion expressions are kept verbatim, they are what gets reported on failure
        GAPI_Assert(pNodeMeta.get<NodeType>().t == NodeType::DATA);
        GAPI_Assert(pNodeMeta.get<NodeType>().t == sNodeMeta.get<NodeType>().t);
        GAPI_Assert(pNodeMeta.get<Data>().shape == sNodeMeta.get<Data>().shape);
    }
}
// Asserts that a pattern graph and a substitute graph can be exchanged, judging only by the
// input/output data nodes listed in their protocols (patternP / substituteP):
// the number of inputs and of outputs must match, and every corresponding pair of nodes
// must pass checkDataNodes().
void checkCompatibility(const Graph& pattern,
                        const Graph& substitute,
                        const Protocol& patternP,
                        const Protocol& substituteP)
{
    // protocol inputs of both graphs
    const auto& patternDataInputs    = patternP.in_nhs;
    const auto& substituteDataInputs = substituteP.in_nhs;

    // protocol outputs of both graphs
    const auto& patternDataOutputs    = patternP.out_nhs;
    const auto& substituteDataOutputs = substituteP.out_nhs;

    // both size checks come first, before any per-node verification is attempted
    GAPI_Assert(patternDataInputs.size() == substituteDataInputs.size());
    GAPI_Assert(patternDataOutputs.size() == substituteDataOutputs.size());

    // pair-wise verification: inputs first, then outputs
    checkDataNodes(pattern, substitute, patternDataInputs, substituteDataInputs);
    checkDataNodes(pattern, substitute, patternDataOutputs, substituteDataOutputs);
}
// Tries to substitute a __single__ occurrence of the pattern with the substitute in the given graph.
//
//   main       - graph to transform; the substitute is built directly into it once a match is found
//   patternG   - pre-built pattern graph; its Protocol metadata is read to learn its in/out nodes
//   substitute - computation whose expression (inputs/outputs) defines the replacement
//
// Returns false if findMatches() reports no match (this happens before anything is built into
// `main`), and true after performSubstitution() has been called.
// Pattern/substitute incompatibilities are reported via GAPI_Assert inside checkCompatibility().
bool tryToSubstitute(ade::Graph& main,
                     const std::unique_ptr<ade::Graph>& patternG,
                     const cv::GComputation& substitute)
{
    GModel::Graph gm(main);

    // 1. find a pattern in main graph
    auto match1 = findMatches(*patternG, gm);
    if (!match1.ok()) {
        return false;
    }

    // 2. build substitute graph inside the main graph
    cv::gimpl::GModelBuilder builder(main);
    // Note: the expression is only read by builder.put(), so it is bound by reference
    // instead of being copied on every successful match
    const auto& expr = cv::util::get<cv::GComputation::Priv::Expr>(substitute.priv().m_shape);
    const auto& proto_slots = builder.put(expr.m_ins, expr.m_outs);

    // unpack the builder result into a Protocol describing the freshly built substitute
    Protocol substituteP;
    std::tie(substituteP.inputs, substituteP.outputs, substituteP.in_nhs, substituteP.out_nhs) =
        proto_slots;

    const Protocol& patternP = GModel::Graph(*patternG).metadata().get<Protocol>();

    // 3. check that pattern and substitute are compatible
    // FIXME: in theory, we should always have compatible pattern/substitute. if not, we're in
    //        half-completed state where some transformations are already applied - what can we do
    //        to handle the situation better? -- use transactional API as in fuse_islands pass?
    checkCompatibility(*patternG, gm, patternP, substituteP);

    // 4. make substitution
    performSubstitution(gm, patternP, substituteP, match1);

    return true;
}
} // anonymous namespace
// Applies the transformations of the kernel package to ctx.graph until none of them matches.
//
//   ctx      - pass context; ctx.graph is the graph being transformed
//   pkg      - kernel package providing the list of transformations
//   patterns - pre-created pattern graphs, one per transformation and in the same order
//
// Does nothing when the package has no transformations. Otherwise the number of patterns must
// equal the number of transformations and every pattern must be non-null (both are asserted).
void applyTransformations(ade::passes::PassContext& ctx,
                          const GKernelPackage& pkg,
                          const std::vector<std::unique_ptr<ade::Graph>>& patterns)
{
    const auto& transforms = pkg.get_transformations();
    if (transforms.size() == 0u) {
        return;
    }

    // Note: patterns are already generated at this point
    GAPI_Assert(patterns.size() == transforms.size());

    // keep sweeping over the transformations for as long as one of them gets applied
    bool substituted = false;
    do
    {
        substituted = false;

        for (auto entry : ade::util::zip(ade::util::toRange(transforms), ade::util::toRange(patterns)))
        {
            const auto& transformation = std::get<0>(entry);
            auto& pattern = std::get<1>(entry);  // Note: using pre-created graphs
            GAPI_Assert(nullptr != pattern);

            // A successful substitution may enable further matches (of this or of any
            // _previous_ transformation), so the sweep is restarted from the first
            // transformation. As a result the __same__ substitution is applied as many times
            // as possible before later ones are reached, unless an earlier pattern starts
            // matching again - in which case that one is applied first.
            if (tryToSubstitute(ctx.graph, pattern, transformation.substitute())) {
                substituted = true;
                break;
            }
        }
    } while (substituted);
}
} // namespace passes
} // namespace gimpl
} // namespace cv