mirror of https://github.com/opencv/opencv.git
Merge pull request #26056 from vpisarev:new_dnn_engine

New dnn engine #26056

This is the 1st PR with the new engine; CI is green and the PR is ready to be merged, I think. Merge together with https://github.com/opencv/opencv_contrib/pull/3794

---

**Known limitations:**

* [solved] OpenVINO is temporarily disabled, but is probably easy to restore (it's not a deal breaker for merging this PR, I guess).
* The new engine does not support any backends or targets except for the default CPU implementation. But it's possible to choose the old engine when loading a model, in which case all the functionality is available.
* [Caffe patch is here: #26208] The new engine only supports ONNX. When a model is constructed manually or is loaded from a file of a different format (.tf, .tflite, .caffe, .darknet), the old engine is used.
* Even in the case of ONNX, some layers are not supported by the new engine, such as all quantized layers (including DequantizeLinear, QuantizeLinear, QLinearConv etc.), LSTM, GRU, .... It's planned, of course, to have full ONNX support by the OpenCV 5.0 gold release. When a loaded model contains unsupported layers, we switch to the old engine automatically (at ONNX parsing time, not at `forward()` time).
* Some layers, e.g. Expand, are only partially supported by the new engine. For unsupported flavours it switches to the old engine automatically (at ONNX parsing time, not at `forward()` time).
* The 'Concat' graph optimization is disabled. The optimization eliminates the Concat layer and instead makes the layers that produce the tensors to be concatenated write their outputs directly to the final destination. Of course, it's only possible when `axis=0` or `axis=N=1`. The optimization is not compatible with dynamic shapes, since we need to know in advance where to store the tensors. Because some of the layer implementations have been modified to become more compatible with the new engine, the feature appears to be broken even when the old engine is used.
* Some of the `dnn::Net` API is not available with the new engine. Also, shape inference may return false if some of the output or intermediate tensors' shapes cannot be inferred without running the model. Probably this can be fixed by a dummy run of the model with zero inputs.
* Some overloads of `dnn::Net::getFLOPs()` and `dnn::Net::getMemoryConsumption()` are no longer exposed in wrapper generators; but the most useful overloads are exposed (and checked by Java tests).
* [in progress] A few Einsum tests related to empty shapes have been disabled due to crashes in the tests and in the Einsum implementation. The code and the tests need to be repaired.
* The OpenCL implementation of Deconvolution is disabled. It's very bad and very slow anyway; it needs to be completely revised.
* The Deconvolution3D test is now skipped, because it was only supported by the CUDA and OpenVINO backends, both of which are not supported by the new engine.
* Some tests, such as FastNeuralStyle, checked that in the case of the CUDA backend there is no fallback to CPU. Currently all layers in the new engine are processed on CPU, so there are many fallbacks. The checks, therefore, have been temporarily disabled.

---

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [ ] There is a reference to the original bug report and related work
- [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable. Patch to opencv_extra has the same branch name.
- [ ] The feature is well documented and sample code can be built with the project CMake
parent 12738deaef
commit 3cd57ea09e
112 changed files with 11197 additions and 554 deletions
@@ -0,0 +1,336 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "precomp.hpp"
#include "net_impl.hpp"

namespace cv { namespace dnn {
CV__DNN_INLINE_NS_BEGIN

using std::vector;
using std::string;

/* Assigns buffers for all intermediate tensors of the graph/model.

   The algorithm is quite simple, but there are some nuances in the attempt to re-use memory more efficiently:

   All layer arguments in the graph and sub-graphs are classified into 4 categories:
   a) inputs, b) outputs, c) constants and d) temporary values/tensors.

   Except for the temporary values (the "d" category), each argument gets
   its own dedicated storage, which makes things more clear and predictable.
   So, this algorithm assigns buffers only for the temporary values.

   During the inference process, each temporary value is computed
   by one of the layers and then used by zero or more subsequent layers (only as input).
   An example of a model where some tensors are used more than once is ResNet.
   After a tensor is used for the last time and
   won't be used in any subsequent layer, the memory buffer for that tensor could be re-used for
   other arguments. We want to assign each temporary tensor to some temporary buffer,
   and it's typically an N:1 mapping.

   We do it using a 2-stage algorithm:

   1. First, we calculate how many times each argument is used and store the counters in 'usecounts'.
   2. Second, we scan the layers in topologically sorted order.
      2.0. Sanity check: we check that each input argument of the operation is either an input, a constant,
           or a temporary tensor with a buffer already assigned to it.
           If not, then the layers are not sorted in topological order.
      2.1. For in-place reshape operations, such as squeeze/unsqueeze/flatten etc.,
           or for unary element-wise operations,
           we check whether the input is a temporary value and is not used in any subsequent operations.
           If these checks all pass, we assign the output argument to the same buffer as the input. Note that
           we don't try to reuse inputs of binary/ternary etc. operations because of broadcasting:
           we would need symbolic shape inference to prove that the output has the same shape as one of the inputs.
      2.2. Otherwise, for each output argument of the operation that is not a network output argument,
           we assign the most recently used free buffer (i.e. the top buffer in the stack of free buffers).
           If there are no free buffers, i.e. the stack is empty, we create a new buffer and use it.
      2.3. For each input we decrement the corresponding element of 'usecounts'. If the counter reaches 0 and the input
           is not aliased with one of the outputs (see 2.1),
           we push the corresponding buffer index onto the stack of free buffers.
      2.4. In the case of in-place operations, and sometimes when using subgraphs (e.g. in If, Loop operations), we may
           re-use the same buffer for several arguments
           (which can be outputs of some operations and inputs of some subsequent operations).
           In order to handle it all properly, during the buffer assignment algorithm we maintain a use counter for each
           buffer, which should not be confused with the use counters for arguments. The pool of free buffers contains zero
           or more "spare" buffers with 0 use counts; a buffer in use has a usage count > 0.
           When some argument is not needed anymore, and if it's not a constant, it decrements the usage counter of the buffer
           where it resides. When the counter reaches zero, we return the buffer to the pool of free buffers and then
           we can reuse the same buffer for another argument (possibly of a different shape and/or type, see below).
           In principle, we could 'protect' some buffers from premature release and re-use by incrementing the use counts
           of the respective arguments that reside in those buffers, but that would make the bookkeeping much more complex.

   Please note that when we reuse buffers, we don't check the types, shapes or total size of the buffer needed.
   We reallocate each buffer at runtime to fit each argument that it's used for. For example, let's say buffer #3
   is used for arguments #5 (10x10x10 FP32), #10 (6x6x32 FP32) and #14 (300x1 UINT64). Then during the first run of
   the inference buffer #3 will be reallocated from 0 bytes to 1000*4 = 4000 bytes to fit arg #5,
   then from 4000 to 6*6*32*4 = 4608 bytes to fit arg #10, and then it will fit arg #14 without reallocation.
   During the second run of inference with the same input resolution the buffer will not be reallocated.

   The reallocation is done using the Buffer.fit() function.
*/

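/*
   A small worked example of the algorithm above (a hypothetical 5-layer chain, not taken from any
   real model), assuming all intermediate tensors t1..t5 are temporary values:

       conv1(x) -> t1;  relu1(t1) -> t2;  conv2(t2) -> t3;  relu2(t3) -> t4;  add(t4, t2) -> t5

   usecounts: t1=1, t2=2 (conv2 and add), t3=1, t4=1, t5=1.

       conv1: t1 -> new buffer #0                                    free stack: []
       relu1: unary, t1 is temp with usecount 1 -> t2 aliases #0     free stack: []
       conv2: t3 -> new buffer #1; t2 usecount 2 -> 1                free stack: []
       relu2: unary, t3 is temp with usecount 1 -> t4 aliases #1     free stack: []
       add:   binary, no aliasing -> t5 -> new buffer #2;
              t4 usecount -> 0, release #1; t2 usecount -> 0, release #0
                                                                     free stack: [#1, #0]

   So the 5 temporary tensors share just 3 buffers.
*/
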
struct BufferAllocator |
||||
{ |
||||
Net::Impl* netimpl; |
||||
vector<int> usecounts; |
||||
vector<int> freebufs; |
||||
vector<int> buf_usecounts; |
||||
vector<int> bufidxs; |
||||
int nbufs = 0; |
||||
|
||||
BufferAllocator(Net::Impl* netimpl_) : netimpl(netimpl_) {} |
||||
|
||||
/*
   Here are 3 workhorse methods that abstract the use and bookkeeping of buffers:
   1. getFreeBuffer() takes the first spare buffer from the pool of free buffers. Since
      we don't necessarily know the shape/type of the tensor at this stage, this is quite
      reasonable behaviour - we cannot do anything more complex than that. On the positive side,
      since the pool of free buffers operates like a stack, the first free buffer is the most
      recently released buffer, so we improve cache locality using this pattern.
      When we don't have spare buffers in the pool, we "virtually" create a new buffer
      (by incrementing the number of buffers used) and return it.

      For the retrieved buffer we set its use count to 1.
   2. releaseBuffer(bufidx) decrements the buffer use count and returns the buffer to the pool
      of free buffers once the use counter reaches 0.
   3. shareBuffer(from_arg, to_arg) takes two argument indices.
      It makes argument 'to_arg' use the same buffer as 'from_arg'.
      The use counter of the buffer previously assigned to 'to_arg' (if any) is decremented.
      The use counter of the 'from_arg' buffer is incremented, correspondingly.
*/
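
/*
   Typical usage pattern in assign() below (a descriptive summary, not additional logic):
   an output that cannot alias an input gets bufidxs[out.idx] = getFreeBuffer();
   an in-place-capable op calls shareBuffer(inp0, out0) instead; and once an argument's
   use count drops to zero, releaseBuffer(bufidxs[arg.idx]) returns the underlying buffer
   to the free pool so that later layers can pick it up again.
*/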
||||
|
||||
int getFreeBuffer() |
||||
{ |
||||
if (freebufs.empty()) { |
||||
freebufs.push_back(nbufs); |
||||
buf_usecounts.push_back(0); |
||||
//printf("added buf %d\n", nbufs);
|
||||
nbufs++; |
||||
} |
||||
int outidx = freebufs.back(); |
||||
freebufs.pop_back(); |
||||
buf_usecounts[outidx] = 1; |
||||
return outidx; |
||||
} |
||||
|
||||
void releaseBuffer(int bufidx) |
||||
{ |
||||
if (bufidx >= 0) { |
||||
CV_Assert(buf_usecounts[bufidx] > 0); |
||||
if (--buf_usecounts[bufidx] == 0) |
||||
freebufs.push_back(bufidx); |
||||
} |
||||
} |
||||
|
||||
void shareBuffer(Arg fromArg, Arg toArg) |
||||
{ |
||||
CV_Assert(!netimpl->isConstArg(fromArg) && !netimpl->isConstArg(toArg)); |
||||
int fromBuf = bufidxs[fromArg.idx], toBuf = bufidxs[toArg.idx]; |
||||
CV_Assert(fromBuf >= 0); |
||||
bufidxs[toArg.idx] = fromBuf; |
||||
buf_usecounts[fromBuf]++; |
||||
if (toBuf >= 0) |
||||
releaseBuffer(toBuf); |
||||
} |
||||
|
||||
void assign() |
||||
{ |
||||
netimpl->useCounts(usecounts); |
||||
size_t nargs = usecounts.size(); |
||||
bufidxs.assign(nargs, -1); |
||||
nbufs = 0; |
||||
assign(netimpl->mainGraph); |
||||
netimpl->bufidxs = bufidxs; |
||||
netimpl->buffers.resize(nbufs); |
||||
for (int i = 0; i < nbufs; i++) |
||||
netimpl->buffers[i] = Mat(); |
||||
} |
||||
|
||||
void assign(const Ptr<Graph>& graph) |
||||
{ |
||||
if (!graph) |
||||
return; |
||||
const std::vector<Ptr<Layer> >& prog = graph->prog(); |
||||
for (const auto& layer: prog) { |
||||
bool inplace = false; |
||||
Arg reuseArg; |
||||
|
||||
if (!layer) continue; |
||||
|
||||
const std::vector<Arg>& inputs = layer->inputs; |
||||
const std::vector<Arg>& outputs = layer->outputs; |
||||
size_t ninputs = inputs.size(); |
||||
size_t noutputs = outputs.size(); |
||||
|
||||
/*
   Determine if we can possibly re-use some of the input buffers for the output as well,
   in other words, whether we can run the operation in-place.
   Not only does this save memory, it can also:
   1. improve L2/L3 cache re-use
   2. effectively convert some copy/re-shape operations
      (Identity, Flatten, Reshape, Squeeze, Unsqueeze)
      into a Nop (no-operation).
*/
||||
//const ElemwiseOp* elemwise_op = dynamic_cast<const ElemwiseOp*>(op);
|
||||
|
||||
if (/*dynamic_cast<const BatchNormOp*>(op) != 0 ||
|
||||
dynamic_cast<const FlattenOp*>(op) != 0 || |
||||
(elemwise_op != 0 && elemwise_op->getActivation(CV_32F) != 0) || |
||||
dynamic_cast<const ReshapeOp*>(op) != 0 || |
||||
dynamic_cast<const SqueezeOp*>(op) != 0 || |
||||
dynamic_cast<const UnsqueezeOp*>(op) != 0*/ |
||||
layer->alwaysSupportInplace()) { |
||||
CV_Assert(ninputs >= 1); |
||||
Arg inp0 = inputs[0]; |
||||
inplace = netimpl->argKind(inp0) == DNN_ARG_TEMP && usecounts[inp0.idx] == 1; |
||||
reuseArg = inp0; |
||||
} |
||||
|
||||
/*
   Unless the operation is in-place, assign buffers for each output.
   We do it before we recursively process subgraphs inside If/Loop/Scan;
   this way we avoid any possible influence of buffer allocation inside a subgraph
   on the parent graphs.
*/
||||
//if (layer->type == "Softmax")
|
||||
// putchar('.');
|
||||
if (noutputs > 0) { |
||||
Arg out0 = outputs[0]; |
||||
if (inplace && |
||||
noutputs == 1 && |
||||
netimpl->argKind(out0) == DNN_ARG_TEMP && |
||||
bufidxs.at(out0.idx) < 0) |
||||
shareBuffer(reuseArg, out0); |
||||
else { |
||||
for (auto out: outputs) { |
||||
if (netimpl->argKind(out) == DNN_ARG_TEMP && |
||||
bufidxs.at(out.idx) < 0) { |
||||
bufidxs.at(out.idx) = getFreeBuffer(); |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
std::string opname = layer->type; |
||||
|
||||
if (opname == "If") { |
||||
/*
   Pre-allocate buffers for the output nodes of the then- and else- branches.
   We try to alias them with the corresponding t_out[i] elements, so
   that we save one copy operation.
   [TODO]
   This is not the most optimal buffer allocation.
   In the ideal case, e.g. when both the then- and else- branches
   are just sequences of element-wise operations that can be executed in-place,
   we could simply use a single buffer for both branches.
   Here we use separate buffers, but let's assume we can
   optimize out such trivial branches at the graph fusion level
   (especially when we have JIT).
*/
||||
auto branches = layer->subgraphs(); |
||||
CV_Assert(branches->size() == 2); |
||||
|
||||
const Ptr<Graph>& thenBranch = branches->at(0); |
||||
const Ptr<Graph>& elseBranch = branches->at(1); |
||||
const vector<Arg>& thenOutargs = thenBranch->outputs(); |
||||
const vector<Arg>& elseOutargs = elseBranch->outputs(); |
||||
CV_Assert(thenOutargs.size() == noutputs && elseOutargs.size() == noutputs); |
||||
for (size_t i = 0; i < noutputs; i++) { |
||||
Arg outarg = outputs[i]; |
||||
Arg thenOutarg = thenOutargs[i]; |
||||
Arg elseOutarg = elseOutargs[i]; |
||||
|
||||
if (!netimpl->isConstArg(thenOutarg) && usecounts[thenOutarg.idx] == 1) |
||||
shareBuffer(outarg, thenOutarg); |
||||
if (!netimpl->isConstArg(elseOutarg) && usecounts[elseOutarg.idx] == 1) |
||||
shareBuffer(outarg, elseOutarg); |
||||
} |
||||
|
||||
assign(thenBranch); |
||||
assign(elseBranch); |
||||
|
||||
for (size_t i = 0; i < noutputs; i++) { |
||||
Arg thenOutarg = thenOutargs[i]; |
||||
Arg elseOutarg = elseOutargs[i]; |
||||
releaseBuffer(bufidxs[thenOutarg.idx]); |
||||
releaseBuffer(bufidxs[elseOutarg.idx]); |
||||
} |
||||
} else if (opname == "Loop") { |
||||
/*
|
||||
In the case of loop we try to alias t_v_in[i] and t_v_out[i] so that |
||||
we eliminate some copy operations after each loop iteration. |
||||
*/ |
||||
//LoopLayer* loop = dynamic_cast<LoopLayer*>(op);
|
||||
CV_Assert(ninputs >= 2); |
||||
auto subgraphs = layer->subgraphs(); |
||||
CV_Assert(subgraphs && subgraphs->size() == 1); |
||||
const Ptr<Graph>& body = subgraphs->at(0); |
||||
Arg trip_count = inputs[0]; |
||||
const std::vector<Arg>& body_inputs = body->inputs(); |
||||
const std::vector<Arg>& body_outputs = body->outputs(); |
||||
size_t body_ninputs = body_inputs.size(); |
||||
size_t body_noutputs = body_outputs.size(); |
||||
int n_state_vars = (int)(ninputs - 2); |
||||
int n_accums = (int)(body_noutputs - n_state_vars - 1); |
||||
CV_Assert(body_ninputs == ninputs); |
||||
CV_Assert(body_noutputs == noutputs+1); |
||||
CV_Assert(n_state_vars >= 0 && n_accums >= 0); |
||||
Arg inp0 = inputs[0]; |
||||
if (inp0.idx > 0 && usecounts[inp0.idx] > 0) { |
||||
CV_Assert(!netimpl->isConstArg(inp0)); |
||||
if (!netimpl->isConstArg(trip_count)) |
||||
shareBuffer(trip_count, inputs[0]); |
||||
else |
||||
bufidxs.at(inputs[0].idx) = getFreeBuffer(); |
||||
} |
||||
|
||||
for (int i = -1; i < n_state_vars; i++) { |
||||
Arg inparg = body_inputs[i+2]; |
||||
Arg outarg = body_outputs[i+1]; |
||||
Arg v_inp = inputs[i+2]; |
||||
Arg v_out = i >= 0 ? outputs[i] : Arg(); |
||||
if (inparg.idx > 0 && usecounts[inparg.idx] > 0) { |
||||
CV_Assert(!netimpl->isConstArg(inparg)); |
||||
if (!netimpl->isConstArg(v_inp)) |
||||
shareBuffer(v_inp, inparg); |
||||
else |
||||
bufidxs[inparg.idx] = getFreeBuffer(); |
||||
} |
||||
if (!netimpl->isConstArg(v_out)) { |
||||
if (!netimpl->isConstArg(outarg) && usecounts[outarg.idx] == 1) |
||||
shareBuffer(v_out, outarg); |
||||
} |
||||
} |
||||
|
||||
assign(body); |
||||
for (auto body_out: body_outputs) |
||||
releaseBuffer(bufidxs.at(body_out.idx)); |
||||
} |
||||
|
||||
for (auto out: outputs) { |
||||
if (usecounts[out.idx] == 0) |
||||
releaseBuffer(bufidxs.at(out.idx)); |
||||
} |
||||
// let's release inputs in the reverse order to keep the buffer allocation consistent across the network
|
||||
for (size_t i = 0; i < ninputs; i++) { |
||||
Arg inp = inputs[ninputs-i-1]; |
||||
int bufidx = bufidxs[inp.idx]; |
||||
if (bufidx >= 0) { |
||||
if (--usecounts.at(inp.idx) == 0) |
||||
releaseBuffer(bufidx); |
||||
} |
||||
} |
||||
} |
||||
} |
||||
}; |
||||
|
||||
void Net::Impl::assignBuffers() |
||||
{ |
||||
BufferAllocator buf_allocator(this); |
||||
buf_allocator.assign(); |
||||
} |
||||
|
||||
CV__DNN_INLINE_NS_END |
||||
}} |
@@ -0,0 +1,139 @@
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "precomp.hpp" |
||||
#include "net_impl.hpp" |
||||
|
||||
namespace cv { namespace dnn { |
||||
CV__DNN_INLINE_NS_BEGIN |
||||
|
||||
using std::vector; |
||||
using std::string; |
||||
|
||||
typedef std::pair<int, int> int_pair; |
||||
typedef std::pair<int, Arg> int_arg_pair; |
||||
|
||||
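/*
   Constant folding pass (descriptive note): processGraph() walks the program (including subgraphs
   of If/Loop), and whenever all inputs of a layer are constants it runs the layer once right here,
   stores the result in netimpl->__tensors__, re-classifies the outputs as DNN_ARG_CONST and drops
   the layer from the program. For example (hypothetical, for illustration only), a
   Shape -> Gather -> Unsqueeze chain computed from a constant tensor collapses into a single
   constant that a later Reshape can consume directly.
*/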
struct ConstFolding |
||||
{ |
||||
Net::Impl* netimpl; |
||||
std::vector<int> usecounts; |
||||
|
||||
ConstFolding(Net::Impl* netimpl_) : netimpl(netimpl_) {} |
||||
|
||||
void process() |
||||
{ |
||||
size_t nargs = netimpl->args.size(); |
||||
netimpl->__tensors__.resize(nargs); |
||||
netimpl->useCounts(usecounts); |
||||
netimpl->scratchBufs.clear(); |
||||
processGraph(netimpl->mainGraph); |
||||
netimpl->scratchBufs.clear(); |
||||
} |
||||
|
||||
Layer* getLayer(std::vector<Ptr<Layer> >& newprog, int op_idx) const |
||||
{ |
||||
return op_idx >= 0 ? newprog.at(op_idx).get() : 0; |
||||
} |
||||
|
||||
void unuse(Arg inp) |
||||
{ |
||||
CV_Assert(usecounts[inp.idx] > 0); |
||||
if (--usecounts[inp.idx] == 0 && netimpl->isConstArg(inp)) { |
||||
netimpl->__tensors__[inp.idx] = Mat(); // deallocate unused tensor
|
||||
} |
||||
} |
||||
|
||||
bool processGraph(Ptr<Graph>& graph) |
||||
{ |
||||
bool modified = false; |
||||
const std::vector<Ptr<Layer> >& prog = graph->prog(); |
||||
size_t i, nops = prog.size(); |
||||
std::vector<Ptr<Layer> > newprog; |
||||
std::vector<Arg> removed_args; |
||||
std::vector<Mat> inpMats, tempMats; |
||||
std::vector<int> inpTypes, outTypes, tempTypes; |
||||
std::vector<MatShape> inpShapes, outShapes, tempShapes; |
||||
|
||||
for (i = 0; i < nops; i++) { |
||||
const Ptr<Layer>& layer = prog[i]; |
||||
std::vector<Ptr<Graph> >* subgraphs = layer->subgraphs(); |
||||
if (subgraphs) { |
||||
for (Ptr<Graph>& g: *subgraphs) { |
||||
if (processGraph(g)) |
||||
modified = true; |
||||
} |
||||
} |
||||
const std::vector<Arg>& inputs = layer->inputs; |
||||
const std::vector<Arg>& outputs = layer->outputs; |
||||
size_t j, ninputs = inputs.size(), noutputs = outputs.size(); |
||||
bool all_const = true; |
||||
inpMats.assign(ninputs, Mat()); |
||||
inpTypes.resize(ninputs); |
||||
inpShapes.resize(ninputs); |
||||
for (j = 0; j < ninputs; j++) { |
||||
Arg inp = inputs[j]; |
||||
bool const_arg = netimpl->isConstArg(inp); |
||||
if (!const_arg) |
||||
all_const = false; |
||||
if (all_const) { |
||||
const Mat& m = netimpl->argTensor(inp); |
||||
inpMats[j] = m; |
||||
inpTypes[j] = m.type(); |
||||
inpShapes[j] = m.shape(); |
||||
} |
||||
} |
||||
|
||||
if (all_const /*&&
|
||||
op->supportBlockLayout(0, (int)ninputs) <= 0 // we don't currently support constant folding
|
||||
// for block-layout operations (Convolution, MaxPool, AveragePool)
|
||||
*/) { |
||||
// Use a fresh vector of Mat's for the outputs since we want to make these outputs the new constant tensors.
// So, they must be unique and must not interfere with other tensors.
||||
std::vector<Mat> outMats(noutputs); |
||||
std::vector<std::pair<uchar*, size_t> > outOrigData; |
||||
if (!layer->dynamicOutputShapes()) |
||||
netimpl->allocateLayerOutputs(layer, inpTypes, inpShapes, outTypes, |
||||
outShapes, outOrigData, outMats, tempTypes, tempShapes, tempMats, |
||||
netimpl->scratchBufs, false); |
||||
layer->finalize(inpMats, outMats); |
||||
layer->forward(inpMats, outMats, tempMats); |
||||
CV_Assert(outMats.size() == noutputs); |
||||
for (j = 0; j < noutputs; j++) { |
||||
Arg out = outputs[j]; |
||||
ArgData& out_data = netimpl->args.at(out.idx); |
||||
const Mat& m = outMats[j]; |
||||
out_data.type = m.type(); |
||||
out_data.shape = m.shape(); |
||||
out_data.kind = DNN_ARG_CONST; // re-classify each output as constant
|
||||
netimpl->__tensors__.at(out.idx) = m; |
||||
} |
||||
|
||||
modified = true; |
||||
for (size_t i = 0; i < ninputs; i++) |
||||
unuse(inputs[i]); |
||||
//printf("folded %s: %s\n", op->name().data(), node->name().data());
|
||||
// we don't add the operation into the new program,
// because the output of an operation with all-constant inputs is now a constant itself,
// stored in a separate tensor
||||
} else { |
||||
newprog.push_back(layer); |
||||
} |
||||
} |
||||
|
||||
if (modified) { |
||||
graph->setProg(newprog); |
||||
} |
||||
|
||||
return modified; |
||||
} |
||||
}; |
||||
|
||||
void Net::Impl::constFold() |
||||
{ |
||||
ConstFolding constfolder(this); |
||||
constfolder.process(); |
||||
} |
||||
|
||||
CV__DNN_INLINE_NS_END |
||||
}} |
@@ -0,0 +1,191 @@
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
#include "layers_common.hpp" |
||||
#include "../net_impl.hpp" |
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
|
||||
/*
   Concat layer, as defined in the ONNX specification:
   https://onnx.ai/onnx/operators/onnx__Concat.html

   Opsets 1 to 13 are covered.
*/
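
/* Shape example (illustrative only): concatenating inputs of shapes [2 x 3 x 8] and [2 x 5 x 8]
   along axis=1 yields an output of shape [2 x 8 x 8]; all dimensions except the concatenation
   axis must match exactly, which is what getOutShape() below verifies. */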
||||
|
||||
// out must be pre-allocated
|
||||
static void concat(const std::vector<Mat>& inps, Mat& out, int axis) |
||||
{ |
||||
CV_Assert(out.isContinuous()); |
||||
|
||||
MatShape outShape = out.shape(); |
||||
int ndims = outShape.dims, nslices = 1; |
||||
size_t esz = out.elemSize(); |
||||
size_t sliceSize = esz; |
||||
size_t totalSize = 0; |
||||
size_t outStep = 0; |
||||
int ninputs = (int)inps.size(); |
||||
for (int i = ndims-1; i > axis; i--) |
||||
sliceSize *= outShape[i]; |
||||
outStep = sliceSize*outShape[axis]; |
||||
for (int i = 0; i < axis; i++) |
||||
nslices *= outShape[i]; |
||||
for (int i = 0; i < ninputs; i++) { |
||||
CV_Assert(inps[i].isContinuous()); |
||||
totalSize += inps[i].total()*esz; |
||||
} |
||||
|
||||
parallel_for_(Range(0, ninputs), [&](const Range& r) { |
||||
for (int k = r.start; k < r.end; k++) { |
||||
const Mat& inp_k = inps[k]; |
||||
uchar* outptr = out.data; |
||||
const uchar* inptr_k = inp_k.data; |
||||
int sz_a; |
||||
for (int i = 0; i < k; i++) { |
||||
sz_a = inps[i].size[axis]; |
||||
outptr += sliceSize*sz_a; |
||||
} |
||||
sz_a = inp_k.size[axis]; |
||||
size_t sliceSize_k = sliceSize*sz_a; |
||||
for (int i = 0; i < nslices; i++) |
||||
memcpy(outptr + i*outStep, inptr_k + i*sliceSize_k, sliceSize_k); |
||||
} |
||||
}, (totalSize > 1000000 ? ninputs : 1)); |
||||
} |
||||
|
||||
class Concat2LayerImpl CV_FINAL : public Concat2Layer |
||||
{ |
||||
public: |
||||
Concat2LayerImpl(const LayerParams& params) |
||||
{ |
||||
setParamsFrom(params); |
||||
axis = params.get<int>("axis", 1); |
||||
} |
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE |
||||
{ |
||||
return backendId == DNN_BACKEND_OPENCV; |
||||
} |
||||
|
||||
MatShape getOutShape(const std::vector<MatShape>& inpShapes) const |
||||
{ |
||||
size_t ninputs = inpShapes.size(); |
||||
CV_Assert(ninputs == inputs.size()); |
||||
|
||||
const MatShape& inpShape0 = inpShapes[0]; |
||||
int inpDims = inpShape0.dims; |
||||
int axis_ = normalize_axis(axis, inpDims); |
||||
CV_Assert(0 <= axis_ && axis_ < inpDims); |
||||
MatShape outShape = inpShape0; |
||||
outShape[axis_] = 0; |
||||
|
||||
for (size_t i = 0; i < ninputs; i++) { |
||||
const MatShape& inpShape_i = inpShapes[i]; |
||||
CV_Assert(inpShape_i.dims == inpDims); |
||||
for (int j = 0; j < inpDims; j++) { |
||||
if (j == axis_) { |
||||
outShape[j] += inpShape_i[j]; |
||||
continue; |
||||
} |
||||
CV_Assert(inpShape0[j] == inpShape_i[j]); |
||||
} |
||||
} |
||||
|
||||
return outShape; |
||||
} |
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape> &inputs, |
||||
const int, |
||||
std::vector<MatShape> &outputs, |
||||
std::vector<MatShape> &internals) const CV_OVERRIDE |
||||
{ |
||||
outputs.assign(1, getOutShape(inputs)); |
||||
internals.clear(); |
||||
return true; |
||||
} |
||||
|
||||
void getTypes(const std::vector<MatType>& inputs, |
||||
const int requiredOutputs, |
||||
const int requiredInternals, |
||||
std::vector<MatType>& outputs, |
||||
std::vector<MatType>& internals) const CV_OVERRIDE |
||||
{ |
||||
size_t ninputs = inputs.size(); |
||||
CV_Assert(ninputs > 0); |
||||
for (size_t i = 1; i < ninputs; i++) { |
||||
CV_Assert(inputs[i] == inputs[0]); |
||||
} |
||||
outputs.assign(requiredOutputs, inputs[0]); |
||||
CV_Assert(requiredInternals == 0); |
||||
internals.clear(); |
||||
} |
||||
|
||||
void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE |
||||
{ |
||||
} |
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, |
||||
OutputArrayOfArrays outputs_arr, |
||||
OutputArrayOfArrays) CV_OVERRIDE |
||||
{ |
||||
CV_TRACE_FUNCTION(); |
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); |
||||
|
||||
Size size = inputs_arr.size(); |
||||
int ninputs = size.area(); |
||||
|
||||
CV_Assert(ninputs > 0); |
||||
|
||||
std::vector<MatShape> inpShapes(ninputs); |
||||
int inpType = inputs_arr.type(0); |
||||
|
||||
for (int i = 0; i < ninputs; i++) { |
||||
inpShapes[i] = inputs_arr.shape(i); |
||||
CV_Assert(inputs_arr.type(i) == inpType); |
||||
} |
||||
|
||||
MatShape outShape = getOutShape(inpShapes); |
||||
int outKind = outputs_arr.kind(); |
||||
int axis_ = normalize_axis(axis, inpShapes[0].dims); |
||||
|
||||
CV_Assert(outKind == _InputArray::STD_VECTOR_MAT || |
||||
outKind == _InputArray::STD_VECTOR_UMAT); |
||||
|
||||
if (outKind == _InputArray::STD_VECTOR_MAT) { |
||||
std::vector<Mat> inps; |
||||
inputs_arr.getMatVector(inps); |
||||
std::vector<Mat>& outs = outputs_arr.getMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(outShape, inpType); |
||||
runOp(inps, outs[0], axis_); |
||||
} else { |
||||
// [TODO] more efficient OpenCL implementation
|
||||
std::vector<Mat> inps; |
||||
inputs_arr.getMatVector(inps); |
||||
std::vector<UMat>& outs = outputs_arr.getUMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(outShape, inpType); |
||||
Mat temp(outShape, inpType); |
||||
runOp(inps, temp, axis_); |
||||
temp.copyTo(outs[0]); |
||||
} |
||||
} |
||||
|
||||
void runOp(const std::vector<Mat>& inps, Mat& out, int axis_) |
||||
{ |
||||
concat(inps, out, axis_); |
||||
} |
||||
}; |
||||
|
||||
Ptr<Concat2Layer> Concat2Layer::create(const LayerParams& params) |
||||
{ |
||||
return Ptr<Concat2Layer>(new Concat2LayerImpl(params)); |
||||
} |
||||
|
||||
} |
||||
} |
@@ -0,0 +1,149 @@
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
#include "layers_common.hpp" |
||||
#include "../net_impl.hpp" |
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
|
||||
/*
   ConstantOfShape layer, as defined in the ONNX specification:
   https://onnx.ai/onnx/operators/onnx__ConstantOfShape.html

   Opsets 9 to 23 are covered.
*/
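
/* Semantics example (illustrative only): given the 1-D input shape tensor [2, 3, 4] and the
   'value' attribute stored in blobs[0] (a single element, e.g. 1.5f), the layer produces a
   2 x 3 x 4 tensor filled with 1.5f. Per the ONNX specification the value defaults to float 0
   when the attribute is absent. */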
||||
|
||||
// out must be pre-allocated
|
||||
static void constantOfShape(const Mat& value, Mat& out) |
||||
{ |
||||
CV_Assert(value.total() == 1); |
||||
CV_Assert(out.isContinuous()); |
||||
CV_CheckEQ(value.type(), out.type(), "input and output tensor types must be the same"); |
||||
|
||||
size_t esz = value.elemSize(); |
||||
size_t total = out.total(); |
||||
const uchar* inpdata_ = value.data; |
||||
uchar* outdata_ = out.data; |
||||
|
||||
#undef IMPL_CONST_OF_SHAPE |
||||
#define IMPL_CONST_OF_SHAPE(T) \ |
||||
T val = *(const T*)inpdata_; \
|
||||
T* outdata = (T*)outdata_; \
|
||||
for (size_t i = 0; i < total; i++) \
|
||||
outdata[i] = val |
||||
|
||||
if (esz == 1) { |
||||
IMPL_CONST_OF_SHAPE(uint8_t); |
||||
} else if (esz == 2) { |
||||
IMPL_CONST_OF_SHAPE(uint16_t); |
||||
} else if (esz == 4) { |
||||
IMPL_CONST_OF_SHAPE(uint32_t); |
||||
} else if (esz == 8) { |
||||
IMPL_CONST_OF_SHAPE(uint64_t); |
||||
} else { |
||||
CV_Error_(Error::StsNotImplemented, ("invalid/unsupported tensor type: %s", typeToString(value.type()).c_str())); |
||||
} |
||||
} |
||||
|
||||
class ConstantOfShapeLayerImpl CV_FINAL : public ConstantOfShapeLayer |
||||
{ |
||||
public: |
||||
ConstantOfShapeLayerImpl(const LayerParams& params) |
||||
{ |
||||
setParamsFrom(params); |
||||
} |
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE |
||||
{ |
||||
return backendId == DNN_BACKEND_OPENCV; |
||||
} |
||||
|
||||
virtual bool dynamicOutputShapes() const CV_OVERRIDE |
||||
{ |
||||
Net::Impl* netimpl_ = getNetImpl(this); |
||||
CV_Assert(netimpl_); |
||||
CV_Assert(this->inputs.size() == 1); |
||||
return !netimpl_->isConstArg(this->inputs[0]); |
||||
} |
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape>&, |
||||
const int, |
||||
std::vector<MatShape> &outputs, |
||||
std::vector<MatShape> &internals) const CV_OVERRIDE |
||||
{ |
||||
CV_Assert(!dynamicOutputShapes()); |
||||
|
||||
CV_Assert(this->inputs.size() == (size_t)1); |
||||
Net::Impl* netimpl_ = getNetImpl(this); |
||||
Mat shapeTensor = netimpl_->argTensor(this->inputs[0]); |
||||
MatShape shape = tensorToShape(shapeTensor); |
||||
outputs.assign(1, shape); |
||||
internals.clear(); |
||||
return true; |
||||
} |
||||
|
||||
void getTypes(const std::vector<MatType>& inputs, |
||||
const int requiredOutputs, |
||||
const int requiredInternals, |
||||
std::vector<MatType>& outputs, |
||||
std::vector<MatType>& internals) const CV_OVERRIDE |
||||
{ |
||||
CV_Assert(blobs.size() == 1); |
||||
size_t ninputs = inputs.size(); |
||||
CV_Assert(ninputs == (size_t)1); |
||||
outputs.assign(requiredOutputs, blobs[0].type()); |
||||
CV_Assert(requiredInternals == 0); |
||||
internals.clear(); |
||||
} |
||||
|
||||
void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE |
||||
{ |
||||
} |
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, |
||||
OutputArrayOfArrays outputs_arr, |
||||
OutputArrayOfArrays) CV_OVERRIDE |
||||
{ |
||||
CV_Assert(blobs.size() == 1); |
||||
CV_TRACE_FUNCTION(); |
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); |
||||
|
||||
Size size = inputs_arr.size(); |
||||
int ninputs = size.area(); |
||||
CV_Assert(ninputs == 1); |
||||
|
||||
const Mat& value = blobs[0]; |
||||
Mat shapeTensor = inputs_arr.getMat(0); |
||||
MatShape shape = tensorToShape(shapeTensor); |
||||
|
||||
auto kind = outputs_arr.kind(); |
||||
if (kind == _InputArray::STD_VECTOR_MAT) { |
||||
std::vector<Mat>& outs = outputs_arr.getMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(shape, value.type()); |
||||
constantOfShape(value, outs[0]); |
||||
} else if (kind == _InputArray::STD_VECTOR_UMAT) { |
||||
std::vector<UMat>& outs = outputs_arr.getUMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(shape, value.type()); |
||||
Mat temp(shape, value.type()); |
||||
constantOfShape(value, temp); |
||||
temp.copyTo(outs[0]); |
||||
} else { |
||||
CV_Error(Error::StsNotImplemented, ""); |
||||
} |
||||
} |
||||
}; |
||||
|
||||
Ptr<ConstantOfShapeLayer> ConstantOfShapeLayer::create(const LayerParams& params) |
||||
{ |
||||
return Ptr<ConstantOfShapeLayer>(new ConstantOfShapeLayerImpl(params)); |
||||
} |
||||
|
||||
} |
||||
} |
@@ -0,0 +1,348 @@
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
#include "layers_common.hpp" |
||||
#include "../net_impl.hpp" |
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
|
||||
/*
   DequantizeLinear layer, as defined in the ONNX specification:
   https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html

   Opsets 10 to 23 are covered.
*/
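
/* The computation is y = (x - x_zero_point) * x_scale, carried out in floating point.
   The scale/zero-point layouts handled below are (illustrative summary):
     - per-tensor: block_size == 0 and scale.total() == 1;
     - per-axis:   block_size == 0 and scale has inp.shape[axis] elements;
     - blocked:    block_size > 0 and scale.shape[axis] == ceil(inp.shape[axis] / block_size),
                   each scale element covering block_size consecutive entries along 'axis'. */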
||||
|
||||
template <typename _InpTp, typename _ScaleTp, typename _OutTp> |
||||
static void dequantizeLinear(const _InpTp* inp_, const _ScaleTp* scale_, |
||||
const _InpTp* zp_, _OutTp* out_, |
||||
int64_t nslices, int sz_a_, |
||||
int64_t slice_size_, int block_size_) |
||||
{ |
||||
int bsz_ = std::max(block_size_, 1); |
||||
int nblocks_per_axis = (sz_a_ + bsz_ - 1) / bsz_; |
||||
int64_t nmacro_blocks = nslices * nblocks_per_axis; |
||||
CV_Assert(nmacro_blocks <= (int64_t)INT_MAX); |
||||
|
||||
parallel_for_(Range(0, (int)nmacro_blocks), [&](const Range& r) { |
||||
int sz_a = sz_a_; |
||||
int64_t slice_size = slice_size_; |
||||
int block_size = block_size_; |
||||
int delta = 0; |
||||
int64_t scale_step = block_size > 0 ? slice_size : 1; |
||||
int64_t zp_step = zp_ ? scale_step : 0; |
||||
|
||||
for (int i = r.start; i < r.end; i += delta) { |
||||
int slice_idx = i / nblocks_per_axis; |
||||
int block_idx = i - slice_idx * nblocks_per_axis; |
||||
int64_t block_ofs, scale_ofs; |
||||
if (block_size > 0) { |
||||
delta = std::min(nblocks_per_axis - block_idx, r.end - i); |
||||
block_ofs = (slice_idx*sz_a + block_idx*block_size)*slice_size; |
||||
scale_ofs = (slice_idx*nblocks_per_axis + block_idx)*slice_size; |
||||
} else { |
||||
delta = std::min(sz_a - block_idx, r.end - i); |
||||
block_ofs = (slice_idx*sz_a + block_idx)*slice_size; |
||||
scale_ofs = block_idx; |
||||
} |
||||
const _InpTp* inp = inp_ + block_ofs; |
||||
const _InpTp* zp = zp_ ? zp_ + scale_ofs : nullptr; |
||||
const _ScaleTp* sc = scale_ + scale_ofs; |
||||
_OutTp* out = out_ + block_ofs; |
||||
|
||||
// [TODO] vectorize using intrinsics
|
||||
if (slice_size > 1) { |
||||
for (int k = 0; k < delta; k++, inp += slice_size, out += slice_size, |
||||
sc += scale_step, zp += zp_step) { |
||||
float scval = (float)*sc; |
||||
_InpTp zpval = zp ? *zp : (_InpTp)0; |
||||
|
||||
for (int64_t j = 0; j < slice_size; j++) |
||||
out[j] = _OutTp((inp[j] - zpval)*scval); |
||||
} |
||||
} else if (block_size > 0 ) { |
||||
int bsz = block_size; |
||||
for (int k = 0; k < delta; k++, inp += bsz, out += bsz) { |
||||
bsz = std::min(bsz, sz_a - (block_idx + k)*block_size); |
||||
float scval = (float)sc[k]; |
||||
_InpTp zpval = zp ? zp[k] : (_InpTp)0; |
||||
|
||||
for (int j = 0; j < bsz; j++) |
||||
out[j] = _OutTp((inp[j] - zpval)*scval); |
||||
} |
||||
sc += delta; |
||||
zp += zp ? delta : 0; |
||||
} else { |
||||
if (zp) { |
||||
for (int j = 0; j < delta; j++) { |
||||
float scval = (float)sc[j]; |
||||
_InpTp zpval = zp[j]; |
||||
out[j] = _OutTp((inp[j] - zpval)*scval); |
||||
} |
||||
} else { |
||||
for (int j = 0; j < delta; j++) { |
||||
float scval = (float)sc[j]; |
||||
out[j] = _OutTp(inp[j]*scval); |
||||
} |
||||
} |
||||
inp += delta; |
||||
out += delta; |
||||
} |
||||
} |
||||
}); |
||||
} |
||||
|
||||
// Dequantize INT8/UINT8/INT32 to FP32/FP16; out must be preallocated
||||
static void dequantizeLinear(const Mat& inp, const Mat& scale_, const Mat& zp, |
||||
int axis, int block_size, Mat& out) |
||||
{ |
||||
Mat scale = scale_; |
||||
CV_Assert(inp.isContinuous()); |
||||
CV_Assert(scale.isContinuous()); |
||||
CV_Assert(out.isContinuous()); |
||||
|
||||
int inptype = inp.type(); |
||||
int outtype = out.type(); |
||||
int sctype = scale.type(); |
||||
int zptype = zp.type(); |
||||
MatShape inpshape = inp.shape(); |
||||
MatShape scshape = scale.shape(); |
||||
MatShape zpshape = zp.shape(); |
||||
int i, ndims = inpshape.dims; |
||||
int64_t nslices = 1, slice_size = 1; |
||||
|
||||
CV_Assert(inptype == CV_8U || inptype == CV_8S || inptype == CV_32S); |
||||
CV_Assert(sctype == CV_32F || sctype == CV_16F); |
||||
CV_Assert(outtype == CV_32F || outtype == CV_16F); |
||||
|
||||
if (!zp.empty()) { |
||||
CV_Assert(zp.isContinuous()); |
||||
CV_Assert(zptype == inptype); |
||||
CV_Assert(zpshape == scshape); |
||||
} |
||||
|
||||
axis = normalize_axis(axis, ndims); |
||||
for (i = 0; i < axis; i++) |
||||
nslices *= inpshape[i]; |
||||
for (i = axis+1; i < ndims; i++) |
||||
slice_size *= inpshape[i]; |
||||
int sz_a = inpshape[axis]; |
||||
|
||||
if (block_size == 0) { |
||||
size_t sc_total = scshape.total(); |
||||
CV_Assert(scale.dims <= 1); |
||||
CV_Assert(sc_total == 1 || sc_total == (size_t)sz_a); |
||||
|
||||
// fold the axis into the innermost loop if a single scale/zero point covers the whole tensor
||||
if (sc_total == 1) { |
||||
slice_size *= sz_a; |
||||
sz_a = 1; |
||||
} |
||||
|
||||
// avoid FP16 => FP32 conversion for scale inside the innermost loop
|
||||
if (sctype == CV_16F && slice_size == 1 && nslices > 1) { |
||||
Mat temp; |
||||
scale_.convertTo(temp, CV_32F); |
||||
scale = temp; |
||||
sctype = CV_32F; |
||||
} |
||||
} else { |
||||
CV_Assert(block_size > 0); |
||||
CV_Assert(scale.dims == ndims); |
||||
for (int i = 0; i < ndims; i++) { |
||||
int inp_i = inpshape[i]; |
||||
int sc_i = scshape[i]; |
||||
if (i == axis) { |
||||
CV_Assert((inp_i + block_size - 1)/block_size == sc_i); |
||||
} else { |
||||
CV_Assert(sc_i == inp_i); |
||||
} |
||||
} |
||||
} |
||||
|
||||
if (inptype == CV_8U && sctype == CV_32F && outtype == CV_32F) |
||||
dequantizeLinear(reinterpret_cast<const uint8_t*>(inp.data), |
||||
reinterpret_cast<const float*>(scale.data), |
||||
reinterpret_cast<const uint8_t*>(zp.data), |
||||
reinterpret_cast<float*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (inptype == CV_8U && sctype == CV_16F && outtype == CV_32F) |
||||
dequantizeLinear(reinterpret_cast<const uint8_t*>(inp.data), |
||||
reinterpret_cast<const hfloat*>(scale.data), |
||||
reinterpret_cast<const uint8_t*>(zp.data), |
||||
reinterpret_cast<float*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (inptype == CV_8U && sctype == CV_32F && outtype == CV_16F) |
||||
dequantizeLinear(reinterpret_cast<const uint8_t*>(inp.data), |
||||
reinterpret_cast<const float*>(scale.data), |
||||
reinterpret_cast<const uint8_t*>(zp.data), |
||||
reinterpret_cast<hfloat*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (inptype == CV_8U && sctype == CV_16F && outtype == CV_16F) |
||||
dequantizeLinear(reinterpret_cast<const uint8_t*>(inp.data), |
||||
reinterpret_cast<const hfloat*>(scale.data), |
||||
reinterpret_cast<const uint8_t*>(zp.data), |
||||
reinterpret_cast<hfloat*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (inptype == CV_8S && sctype == CV_32F && outtype == CV_32F) |
||||
dequantizeLinear(reinterpret_cast<const int8_t*>(inp.data), |
||||
reinterpret_cast<const float*>(scale.data), |
||||
reinterpret_cast<const int8_t*>(zp.data), |
||||
reinterpret_cast<float*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (inptype == CV_8S && sctype == CV_16F && outtype == CV_32F) |
||||
dequantizeLinear(reinterpret_cast<const int8_t*>(inp.data), |
||||
reinterpret_cast<const hfloat*>(scale.data), |
||||
reinterpret_cast<const int8_t*>(zp.data), |
||||
reinterpret_cast<float*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (inptype == CV_8S && sctype == CV_32F && outtype == CV_16F) |
||||
dequantizeLinear(reinterpret_cast<const int8_t*>(inp.data), |
||||
reinterpret_cast<const float*>(scale.data), |
||||
reinterpret_cast<const int8_t*>(zp.data), |
||||
reinterpret_cast<hfloat*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (inptype == CV_8S && sctype == CV_16F && outtype == CV_16F) |
||||
dequantizeLinear(reinterpret_cast<const int8_t*>(inp.data), |
||||
reinterpret_cast<const hfloat*>(scale.data), |
||||
reinterpret_cast<const int8_t*>(zp.data), |
||||
reinterpret_cast<hfloat*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (inptype == CV_32S && sctype == CV_32F && outtype == CV_32F) |
||||
dequantizeLinear(reinterpret_cast<const int32_t*>(inp.data), |
||||
reinterpret_cast<const float*>(scale.data), |
||||
reinterpret_cast<const int32_t*>(zp.data), |
||||
reinterpret_cast<float*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (inptype == CV_32S && sctype == CV_16F && outtype == CV_32F) |
||||
dequantizeLinear(reinterpret_cast<const int32_t*>(inp.data), |
||||
reinterpret_cast<const hfloat*>(scale.data), |
||||
reinterpret_cast<const int32_t*>(zp.data), |
||||
reinterpret_cast<float*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (inptype == CV_32S && sctype == CV_32F && outtype == CV_16F) |
||||
dequantizeLinear(reinterpret_cast<const int32_t*>(inp.data), |
||||
reinterpret_cast<const float*>(scale.data), |
||||
reinterpret_cast<const int32_t*>(zp.data), |
||||
reinterpret_cast<hfloat*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (inptype == CV_32S && sctype == CV_16F && outtype == CV_16F) |
||||
dequantizeLinear(reinterpret_cast<const int32_t*>(inp.data), |
||||
reinterpret_cast<const hfloat*>(scale.data), |
||||
reinterpret_cast<const int32_t*>(zp.data), |
||||
reinterpret_cast<hfloat*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else { |
||||
CV_Error_(Error::StsNotImplemented, |
||||
("the following combination of types is not supported in " |
||||
"DequantizeLinear: inp=%s, scale=%s, out=%s", |
||||
typeToString(inptype).c_str(), |
||||
typeToString(sctype).c_str(), |
||||
typeToString(outtype).c_str())); |
||||
} |
||||
} |
||||
|
||||
class DequantizeLinearLayerImpl CV_FINAL : public DequantizeLinearLayer |
||||
{ |
||||
public: |
||||
DequantizeLinearLayerImpl(const LayerParams& params) |
||||
{ |
||||
setParamsFrom(params); |
||||
|
||||
axis = params.get<int>("axis", 1); |
||||
block_size = params.get<int>("block_size", 0); |
||||
CV_Assert(block_size >= 0); |
||||
} |
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE |
||||
{ |
||||
return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH; |
||||
} |
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape> &inputs, |
||||
const int requiredOutputs, |
||||
std::vector<MatShape> &outputs, |
||||
std::vector<MatShape> &internals) const CV_OVERRIDE |
||||
{ |
||||
size_t ninputs = inputs.size(); |
||||
CV_Assert(2 <= ninputs && ninputs <= 3); |
||||
CV_Assert(requiredOutputs == 1); |
||||
outputs.assign(1, inputs[0]); |
||||
return true; |
||||
} |
||||
|
||||
int getOutType() const |
||||
{ |
||||
Net::Impl* netimpl_ = getNetImpl(this); |
||||
return netimpl_->enableFP16 ? CV_16F : CV_32F; |
||||
} |
||||
|
||||
void getTypes(const std::vector<MatType>& inputs, |
||||
const int requiredOutputs, |
||||
const int requiredInternals, |
||||
std::vector<MatType>& outputs, |
||||
std::vector<MatType>& internals) const CV_OVERRIDE |
||||
{ |
||||
size_t ninputs = inputs.size(); |
||||
CV_Assert(2 <= ninputs && ninputs <= 3); |
||||
if (ninputs == 3) { |
||||
CV_Assert(inputs[0] == inputs[2]); |
||||
} |
||||
outputs.assign(1, getOutType()); |
||||
} |
||||
|
||||
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE |
||||
{ |
||||
} |
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, |
||||
OutputArrayOfArrays outputs_arr, |
||||
OutputArrayOfArrays) CV_OVERRIDE |
||||
{ |
||||
CV_TRACE_FUNCTION(); |
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); |
||||
|
||||
int ninputs = inputs_arr.size(-1).area(); |
||||
CV_Assert(2 <= ninputs && ninputs <= 3); |
||||
|
||||
Mat inp = inputs_arr.getMat(0); |
||||
Mat scale = inputs_arr.getMat(1); |
||||
Mat zeropoint; |
||||
int outtype = getOutType(); |
||||
MatShape inpshape = inp.shape(); |
||||
|
||||
if (ninputs >= 3) { |
||||
zeropoint = inputs_arr.getMat(2); |
||||
} |
||||
|
||||
auto kind = outputs_arr.kind(); |
||||
|
||||
if (kind == _InputArray::STD_VECTOR_MAT) { |
||||
std::vector<Mat>& outs = outputs_arr.getMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(inpshape, outtype); |
||||
dequantizeLinear(inp, scale, zeropoint, axis, block_size, outs[0]); |
||||
} else if (kind == _InputArray::STD_VECTOR_UMAT) { |
||||
std::vector<UMat>& outs = outputs_arr.getUMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(inpshape, outtype); |
||||
Mat temp(inpshape, outtype); |
||||
dequantizeLinear(inp, scale, zeropoint, axis, block_size, temp); |
||||
temp.copyTo(outs[0]); |
||||
} else { |
||||
CV_Error(Error::StsNotImplemented, ""); |
||||
} |
||||
} |
||||
}; |
||||
|
||||
Ptr<DequantizeLinearLayer> DequantizeLinearLayer::create(const LayerParams& params) |
||||
{ |
||||
return Ptr<DequantizeLinearLayer>(new DequantizeLinearLayerImpl(params)); |
||||
} |
||||
|
||||
}} |
@@ -0,0 +1,130 @@
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
#include "layers_common.hpp" |
||||
#include "../net_impl.hpp" |
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
|
||||
/*
   Expand layer, as defined in the ONNX specification:
   https://onnx.ai/onnx/operators/onnx__Expand.html

   Opsets 8 to 13 are covered.
*/
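
/* Shape example (illustrative only): an input of shape [3, 1] expanded with the shape tensor
   [2, 1, 6] produces an output of shape [2, 3, 6], following standard ONNX/numpy broadcasting
   rules (shapes are right-aligned and dimensions of size 1 are stretched to match). */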
||||
|
||||
class Expand2LayerImpl CV_FINAL : public Expand2Layer |
||||
{ |
||||
public: |
||||
Expand2LayerImpl(const LayerParams& params) |
||||
{ |
||||
setParamsFrom(params); |
||||
} |
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE |
||||
{ |
||||
return backendId == DNN_BACKEND_OPENCV; |
||||
} |
||||
|
||||
virtual bool dynamicOutputShapes() const CV_OVERRIDE |
||||
{ |
||||
Net::Impl* netimpl_ = getNetImpl(this); |
||||
CV_Assert(netimpl_); |
||||
size_t ninputs = this->inputs.size(); |
||||
CV_Assert(ninputs == 2); |
||||
return !netimpl_->isConstArg(this->inputs[1]); |
||||
} |
||||
|
||||
MatShape getOutShape(const MatShape& inpshape, const Mat& shapeTensor) const |
||||
{ |
||||
MatShape shape0 = tensorToShape(shapeTensor); |
||||
MatShape shape = inpshape.expand(shape0); |
||||
// according to the ONNX specification, the specified shape can be smaller than the input!
// so the check is left disabled:
// CV_Assert(shape == shape0); // check that the input can be expanded to the specified shape
||||
return shape; |
||||
} |
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape>& inputs, |
||||
const int, |
||||
std::vector<MatShape> &outputs, |
||||
std::vector<MatShape> &internals) const CV_OVERRIDE |
||||
{ |
||||
CV_Assert(!dynamicOutputShapes()); |
||||
|
||||
size_t ninputs = inputs.size(); |
||||
CV_Assert(ninputs == (size_t)2); |
||||
Net::Impl* netimpl_ = getNetImpl(this); |
||||
|
||||
Mat shapeTensor = netimpl_->argTensor(this->inputs[1]); |
||||
|
||||
outputs.assign(1, getOutShape(inputs[0], shapeTensor)); |
||||
internals.clear(); |
||||
return true; |
||||
} |
||||
|
||||
void getTypes(const std::vector<MatType>& inputs, |
||||
const int requiredOutputs, |
||||
const int requiredInternals, |
||||
std::vector<MatType>& outputs, |
||||
std::vector<MatType>& internals) const CV_OVERRIDE |
||||
{ |
||||
size_t ninputs = inputs.size(); |
||||
CV_Assert(ninputs == (size_t)2); |
||||
outputs.assign(requiredOutputs, inputs[0]); |
||||
CV_Assert(requiredInternals == 0); |
||||
internals.clear(); |
||||
} |
||||
|
||||
void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE |
||||
{ |
||||
} |
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, |
||||
OutputArrayOfArrays outputs_arr, |
||||
OutputArrayOfArrays) CV_OVERRIDE |
||||
{ |
||||
CV_TRACE_FUNCTION(); |
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); |
||||
|
||||
Size size = inputs_arr.size(); |
||||
int ninputs = size.area(); |
||||
CV_Assert(ninputs == 2); |
||||
|
||||
Mat inp = inputs_arr.getMat(0); |
||||
int inptype = inp.type(); |
||||
Mat shapeTensor = inputs_arr.getMat(1); |
||||
|
||||
MatShape outshape = getOutShape(inp.shape(), shapeTensor); |
||||
|
||||
auto kind = outputs_arr.kind(); |
||||
if (kind == _InputArray::STD_VECTOR_MAT) { |
||||
std::vector<Mat>& outs = outputs_arr.getMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(outshape, inptype); |
||||
broadcast(inp, outshape, outs[0]); |
||||
} else if (kind == _InputArray::STD_VECTOR_UMAT) { |
||||
std::vector<UMat>& outs = outputs_arr.getUMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(outshape, inptype); |
||||
Mat temp(outshape, inptype); |
||||
broadcast(inp, outshape, temp); |
||||
temp.copyTo(outs[0]); |
||||
} else { |
||||
CV_Error(Error::StsNotImplemented, ""); |
||||
} |
||||
} |
||||
}; |
||||
|
||||
Ptr<Expand2Layer> Expand2Layer::create(const LayerParams& params) |
||||
{ |
||||
return Ptr<Expand2Layer>(new Expand2LayerImpl(params)); |
||||
} |
||||
|
||||
} |
||||
} |
@@ -0,0 +1,210 @@
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
#include "layers_common.hpp" |
||||
#include "../net_impl.hpp" |
||||
//#include "../op_cuda.hpp"
|
||||
//#include "../op_inf_engine.hpp"
|
||||
//#include "../ie_ngraph.hpp"
|
||||
//#include "../op_webnn.hpp"
|
||||
//#include "../op_timvx.hpp"
|
||||
//#include "../op_cann.hpp"
|
||||
|
||||
//#include <opencv2/dnn/shape_utils.hpp>
|
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
|
||||
/*
   Gather layer, as defined in the ONNX specification:
   https://onnx.ai/onnx/operators/onnx__Gather.html

   Opsets 1 to 13 are covered.
*/
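
/* Output shape rule (illustrative summary): for 'data' and integer 'indices',
       out.shape = data.shape[:axis] + indices.shape + data.shape[axis+1:]
   e.g. data [5 x 4 x 3], indices [2 x 6], axis=1  ->  output [5 x 2 x 6 x 3].
   Negative indices are allowed and count from the end of the 'axis' dimension. */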
||||
|
||||
// out must be pre-allocated
|
||||
static void gather(const Mat& data, const Mat& ind, Mat& out, int axis) |
||||
{ |
||||
CV_Assert_N(data.isContinuous(), ind.isContinuous(), out.isContinuous()); |
||||
int indType = ind.type(); |
||||
CV_Assert(indType == CV_32S || indType == CV_64S); |
||||
|
||||
MatShape dataShape = data.shape(); |
||||
MatShape indShape = ind.shape(); |
||||
MatShape outShape = out.shape(); |
||||
int dataDims = dataShape.dims; |
||||
int indDims = indShape.dims; |
||||
int outDims = outShape.dims; |
||||
|
||||
CV_Assert(outDims == dataDims + indDims - 1); |
||||
size_t indTotal = indShape.total(), nslices = 1; |
||||
size_t elemSize = data.elemSize(); |
||||
size_t sliceSize = elemSize; |
||||
|
||||
for(int j = 0; j < dataDims; j++) { |
||||
int szj = dataShape[j]; |
||||
if (j < axis) |
||||
nslices *= szj; |
||||
else if (j > axis) |
||||
sliceSize *= szj; |
||||
} |
||||
size_t dataStep = sliceSize * dataShape[axis]; |
||||
size_t outStep = sliceSize * indTotal; |
||||
volatile bool globOutOfRangeIdx = false; |
||||
|
||||
parallel_for_(Range(0, (int)indTotal), [&](const Range& r) { |
||||
int shape_a = dataShape[axis]; |
||||
const uchar* dataptr0 = data.data; |
||||
uchar* outptr0 = out.data; |
||||
const int32_t* ind32 = indType == CV_32S ? ind.ptr<int32_t>() : nullptr; |
||||
const int64_t* ind64 = indType == CV_64S ? ind.ptr<int64_t>() : nullptr; |
||||
bool outOfRangeIdx = globOutOfRangeIdx; |
||||
for (int j = r.start; j < r.end && !outOfRangeIdx; j++) { |
||||
int k = ind32 ? (int)ind32[j] : (int)ind64[j]; |
||||
uchar* outptr = outptr0 + j*sliceSize; |
||||
const uchar* dataptr = dataptr0; |
||||
for (size_t i = 0; i < nslices; i++, dataptr += dataStep, outptr += outStep) { |
||||
k += k < 0 ? shape_a : 0; |
||||
if (k < 0 || k >= shape_a) { |
||||
outOfRangeIdx = true; |
||||
break; |
||||
} |
||||
memcpy(outptr, dataptr + k*sliceSize, sliceSize); |
||||
} |
||||
} |
||||
if (outOfRangeIdx) |
||||
globOutOfRangeIdx = true; |
||||
}, std::min((double)indTotal, (double)sliceSize*nslices*indTotal/1e6)); |
||||
|
||||
if (globOutOfRangeIdx) { |
||||
CV_Error(Error::StsOutOfRange, "some of indices are outside of range"); |
||||
} |
||||
} |
||||
|
||||
class Gather2LayerImpl CV_FINAL : public Gather2Layer |
||||
{ |
||||
public: |
||||
Gather2LayerImpl(const LayerParams& params) |
||||
{ |
||||
setParamsFrom(params); |
||||
axis = params.get<int>("axis", 0); |
||||
} |
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE |
||||
{ |
||||
return backendId == DNN_BACKEND_OPENCV; |
||||
} |
||||
|
||||
MatShape getOutShape(const MatShape& dataShape, const MatShape& indShape) const |
||||
{ |
||||
int dataDims = dataShape.dims; |
||||
int indDims = indShape.dims; |
||||
|
||||
int axis_ = normalize_axis(axis, dataDims); |
||||
CV_Assert(0 <= axis_ && axis_ < dataDims); |
||||
MatShape outShape(dataDims + indDims - 1); |
||||
|
||||
for (int i = 0; i < outShape.dims; i++) { |
||||
if (i < axis_) { |
||||
outShape[i] = dataShape[i]; |
||||
} else { |
||||
int j = i - axis_; |
||||
outShape[i] = j < indDims ? indShape[j] : dataShape[i - indDims + 1]; |
||||
} |
||||
} |
||||
return outShape; |
||||
} |
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape> &inputs, |
||||
const int, |
||||
std::vector<MatShape> &outputs, |
||||
std::vector<MatShape> &internals) const CV_OVERRIDE |
||||
{ |
||||
CV_Assert(inputs.size() == 2); |
||||
outputs.assign(1, getOutShape(inputs[0], inputs[1])); |
||||
internals.clear(); |
||||
return true; |
||||
} |
||||
|
||||
void getTypes(const std::vector<MatType>& inputs, |
||||
const int requiredOutputs, |
||||
const int requiredInternals, |
||||
std::vector<MatType>& outputs, |
||||
std::vector<MatType>& internals) const CV_OVERRIDE |
||||
{ |
||||
size_t ninputs = inputs.size(); |
||||
CV_Assert(ninputs == 2); |
||||
int dataType = inputs[0]; |
||||
int indType = inputs[1]; |
||||
CV_Assert(indType == CV_32S || indType == CV_64S); |
||||
outputs.assign(requiredOutputs, dataType); |
||||
CV_Assert(requiredInternals == 0); |
||||
internals.clear(); |
||||
} |
||||
|
||||
void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE |
||||
{ |
||||
} |
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, |
||||
OutputArrayOfArrays outputs_arr, |
||||
OutputArrayOfArrays) CV_OVERRIDE |
||||
{ |
||||
CV_TRACE_FUNCTION(); |
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); |
||||
|
||||
Size size = inputs_arr.size(); |
||||
int ninputs = size.area(); |
||||
|
||||
CV_Assert(ninputs == 2); |
||||
|
||||
MatShape dataShape = inputs_arr.shape(0); |
||||
MatShape indShape = inputs_arr.shape(1); |
||||
int dataType = inputs_arr.type(0); |
||||
int indType = inputs_arr.type(1); |
||||
CV_Assert(indType == CV_32S || indType == CV_64S); |
||||
|
||||
MatShape outShape = getOutShape(dataShape, indShape); |
||||
int outKind = outputs_arr.kind(); |
||||
int axis_ = normalize_axis(axis, dataShape.dims); |
||||
|
||||
CV_Assert(outKind == _InputArray::STD_VECTOR_MAT || |
||||
outKind == _InputArray::STD_VECTOR_UMAT); |
||||
|
||||
if (outKind == _InputArray::STD_VECTOR_MAT) { |
||||
Mat data = inputs_arr.getMat(0); |
||||
Mat ind = inputs_arr.getMat(1); |
||||
std::vector<Mat>& outs = outputs_arr.getMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(outShape, dataType); |
||||
runOp(data, ind, outs[0], axis_); |
||||
} else { |
||||
// [TODO] more efficient OpenCL implementation
|
||||
Mat data = inputs_arr.getMat(0); |
||||
Mat ind = inputs_arr.getMat(1); |
||||
std::vector<UMat>& outs = outputs_arr.getUMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(outShape, dataType); |
||||
Mat temp(outShape, dataType); |
||||
runOp(data, ind, temp, axis_); |
||||
temp.copyTo(outs[0]); |
||||
} |
||||
} |
||||
|
||||
void runOp(const Mat& data, const Mat& ind, Mat& out, int axis_) |
||||
{ |
||||
gather(data, ind, out, axis_); |
||||
} |
||||
}; |
||||
|
||||
Ptr<Gather2Layer> Gather2Layer::create(const LayerParams& params) |
||||
{ |
||||
return Ptr<Gather2Layer>(new Gather2LayerImpl(params)); |
||||
} |
||||
|
||||
} |
||||
} |
@@ -0,0 +1,377 @@
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
#include "layers_common.hpp" |
||||
#include "../net_impl.hpp" |
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
|
||||
static constexpr int PAD_MAX_DIMS = 5; |
||||
|
||||
/*
    Padding layer, as defined in ONNX specification:
    https://onnx.ai/onnx/operators/onnx__Pad.html

    Opsets 1 to 23 are covered.
*/
||||
|
||||
// out must be pre-allocated
// pads[] must contain as many elements as inp.dims*2
static void pad(const Mat& inp, const std::vector<int>& pads_, int mode_, const Mat& value, Mat& out) |
||||
{ |
||||
int inptype = inp.type(); |
||||
MatShape inpshape_ = inp.shape(); |
||||
MatShape outshape_ = out.shape(); |
||||
double buf = 0; |
||||
Mat vbuf(1, 1, inptype, &buf); |
||||
|
||||
int inpshape[PAD_MAX_DIMS]; |
||||
int outshape[PAD_MAX_DIMS]; |
||||
int pads[PAD_MAX_DIMS*2]; |
||||
int64_t inpstep[PAD_MAX_DIMS]; |
||||
int64_t outstep[PAD_MAX_DIMS]; |
||||
std::vector<int> tab[PAD_MAX_DIMS]; |
||||
|
||||
int ndims = inp.dims, delta = PAD_MAX_DIMS - ndims; |
||||
int64_t esz = inp.elemSize(); |
||||
|
||||
CV_Assert(inp.isContinuous()); |
||||
CV_Assert(out.isContinuous()); |
||||
CV_Assert(inp.type() == out.type()); |
||||
CV_Assert(esz == 1 || esz == 2 || esz == 4 || esz == 8); |
||||
CV_Assert(inp.dims == out.dims); |
||||
CV_Assert(inp.dims <= PAD_MAX_DIMS); |
||||
|
||||
if (!value.empty()) { |
||||
CV_Assert(value.dims <= 2 && value.total() == 1 && value.channels() == 1); |
||||
tensorToScalar(value, inptype, &buf); |
||||
} |
||||
|
||||
for (int i = 0; i < PAD_MAX_DIMS; i++) { |
||||
inpshape[i] = outshape[i] = 1; |
||||
pads[i] = pads[i + PAD_MAX_DIMS] = 0; |
||||
} |
||||
|
||||
for (int i = 0; i < ndims; i++) { |
||||
inpshape[i+delta] = inpshape_[i]; |
||||
outshape[i+delta] = outshape_[i]; |
||||
pads[i+delta] = pads_[i]; |
||||
pads[i+delta + PAD_MAX_DIMS] = pads_[i + ndims]; |
||||
|
||||
// initialize lookup table along the corresponding axis
|
||||
int inpsz_i = inpshape_[i]; |
||||
int outsz_i = outshape_[i]; |
||||
tab[i+delta].resize(outsz_i); |
||||
int* tab_i = tab[i+delta].data(); |
||||
int before = pads_[i]; |
||||
for (int j = 0; j < outsz_i; j++) |
||||
tab_i[j] = borderInterpolate(j - before, inpsz_i, mode_); |
||||
} |
||||
|
||||
for (int i = PAD_MAX_DIMS-1; i >= 0; i--) { |
||||
if (i == PAD_MAX_DIMS-1) |
||||
inpstep[i] = outstep[i] = 1; |
||||
else { |
||||
inpstep[i] = inpstep[i+1]*inpshape[i+1]; |
||||
outstep[i] = outstep[i+1]*outshape[i+1]; |
||||
} |
||||
} |
||||
|
||||
int nplanes = outshape[0]*outshape[1]*outshape[2]; |
||||
|
||||
CV_Assert(!tab[4].empty()); |
||||
|
||||
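    // The output is processed as nplanes = outshape[0]*outshape[1]*outshape[2] independent
    // "planes"; within each plane the two innermost dimensions are padded row by row using
    // the per-axis lookup tables tab[] built above with borderInterpolate().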
#undef IMPL_PAD |
||||
#define IMPL_PAD(T) \ |
||||
parallel_for_(Range(0, nplanes), [&](const Range& r) { \
|
||||
int mode = mode_; \
|
||||
int sz1 = outshape[1], sz2 = outshape[2], sz3 = outshape[3], sz4 = outshape[4]; \
|
||||
const int* tab0 = tab[0].data(); \
|
||||
const int* tab1 = tab[1].data(); \
|
||||
const int* tab2 = tab[2].data(); \
|
||||
const int* tab3 = tab[3].data(); \
|
||||
const int* tab4 = tab[4].data(); \
|
||||
const T* inpdata0 = (const T*)inp.data; \
|
||||
T val0 = *reinterpret_cast<T*>(vbuf.data); \
|
||||
T* outdata0 = (T*)out.data; \
|
||||
int p0 = pads[PAD_MAX_DIMS-1], p1 = pads[PAD_MAX_DIMS*2-1]; \
|
||||
int p0_ = std::max(p0, 0), p1_ = std::max(p1, 0); \
|
||||
for (int plane = r.start; plane < r.end; plane++) { \
|
||||
int plane_ = plane; \
|
||||
int i2 = plane_ % sz2; \
|
||||
plane_ /= sz2; \
|
||||
int i1 = plane_ % sz1; \
|
||||
int i0 = plane_ / sz1; \
|
||||
int ii0 = tab0 ? tab0[i0] : i0; \
|
||||
int ii1 = tab1 ? tab1[i1] : i1; \
|
||||
int ii2 = tab2 ? tab2[i2] : i2; \
|
||||
for (int i3 = 0; i3 < sz3; i3++) { \
|
||||
int ii3 = tab3 ? tab3[i3] : i3; \
|
||||
T* outdata = outdata0 + i0*outstep[0] + i1*outstep[1] + i2*outstep[2] + i3*outstep[3]; \
|
||||
int i4 = 0; \
|
||||
if ((ii0|ii1|ii2|ii3) < 0) { \
|
||||
for (; i4 < sz4; i4++) \
|
||||
outdata[i4] = val0; \
|
||||
continue; \
|
||||
} \
|
||||
const T* inpdata = inpdata0 + ii0*inpstep[0] + ii1*inpstep[1] + ii2*inpstep[2] + ii3*inpstep[3]; \
|
||||
if (mode == BORDER_CONSTANT) {\
|
||||
for (; i4 < p0_; i4++) \
|
||||
outdata[i4] = val0; \
|
||||
} else { \
|
||||
for (; i4 < p0_; i4++) \
|
||||
outdata[i4] = inpdata[tab4[i4]]; \
|
||||
} \
|
||||
for (; i4 < sz4 - p1_; i4++) \
|
||||
outdata[i4] = inpdata[i4 - p0]; \
|
||||
if (mode == BORDER_CONSTANT) { \
|
||||
for (; i4 < sz4; i4++) \
|
||||
outdata[i4] = val0; \
|
||||
} else { \
|
||||
for (; i4 < sz4; i4++) \
|
||||
outdata[i4] = inpdata[tab4[i4]]; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}) |
||||
|
||||
if (esz == 1) { |
||||
IMPL_PAD(uint8_t); |
||||
} else if (esz == 2) { |
||||
IMPL_PAD(uint16_t); |
||||
} else if (esz == 4) { |
||||
IMPL_PAD(uint32_t); |
||||
} else { |
||||
CV_Assert(esz == 8); |
||||
IMPL_PAD(uint64_t); |
||||
} |
||||
} |
||||
|
||||
class Pad2LayerImpl CV_FINAL : public Pad2Layer |
||||
{ |
||||
public: |
||||
std::vector<int> pads0; |
||||
float value0 = 0.f; |
||||
int mode = BORDER_CONSTANT; |
||||
|
||||
Pad2LayerImpl(const LayerParams& params) |
||||
{ |
||||
setParamsFrom(params); |
||||
std::vector<int> pads0_ = params.getVector<int>("paddings"); |
||||
// [TODO] remove this transposition after the original transposition is removed from onnx importer 2
|
||||
if (!pads0_.empty()) { |
||||
int i, ndims = (int)(pads0_.size()/2); |
||||
pads0.resize(ndims*2); |
||||
for (i = 0; i < ndims; i++) { |
||||
pads0[i] = pads0_[i*2]; |
||||
pads0[i + ndims] = pads0_[i*2+1]; |
||||
} |
||||
} |
||||
std::string strmode = params.get<std::string>("mode", "constant"); |
||||
if (strmode == "constant") |
||||
mode = BORDER_CONSTANT; |
||||
else if (strmode == "reflect") |
||||
mode = BORDER_REFLECT101; |
||||
else if (strmode == "edge") |
||||
mode = BORDER_REPLICATE; |
||||
else if (strmode == "wrap") |
||||
mode = BORDER_WRAP; |
||||
else { |
||||
CV_Error_(Error::StsNotImplemented, ("mode '%s' is not supported", strmode.c_str())); |
||||
} |
||||
value0 = params.get<float>("value", 0.f); |
||||
} |
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE |
||||
{ |
||||
return backendId == DNN_BACKEND_OPENCV; |
||||
} |
||||
|
||||
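    // The output shape can only be inferred ahead of time when the 'pads' tensor (input 1)
    // and the optional 'axes' tensor (input 3) are constant arguments of the graph.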
virtual bool dynamicOutputShapes() const CV_OVERRIDE |
||||
{ |
||||
Net::Impl* netimpl_ = getNetImpl(this); |
||||
CV_Assert(netimpl_); |
||||
size_t ninputs = this->inputs.size(); |
||||
CV_Assert(1 <= ninputs && ninputs <= 4); |
||||
return (ninputs >= 2 && !netimpl_->isConstArg(this->inputs[1])) || |
||||
(ninputs >= 4 && !netimpl_->isConstArg(this->inputs[3])); |
||||
} |
||||
|
||||
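    // Convert the ONNX 'pads' (and optional 'axes') tensors into a dense per-dimension
    // vector: pads[i] is the padding before axis i, pads[i + ndims] is the padding after it.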
void getPads(int ndims, const Mat& pads_, const Mat& axes_, std::vector<int>& pads) const |
||||
{ |
||||
int atype = axes_.type(), ptype = pads_.type(); |
||||
CV_Assert(ndims <= PAD_MAX_DIMS); |
||||
|
||||
const int32_t* adata_i32 = nullptr; |
||||
const int64_t* adata_i64 = nullptr; |
||||
const int32_t* pdata_i32 = nullptr; |
||||
const int64_t* pdata_i64 = nullptr; |
||||
|
||||
bool axismask[PAD_MAX_DIMS]; |
||||
int naxes = !axes_.empty() ? (int)axes_.total() : ndims; |
||||
|
||||
CV_Assert(pads_.dims == 1); |
||||
CV_Assert(ptype == CV_32S || ptype == CV_64S); |
||||
|
||||
if (ptype == CV_32S) |
||||
pdata_i32 = reinterpret_cast<const int32_t*>(pads_.data); |
||||
else |
||||
pdata_i64 = reinterpret_cast<const int64_t*>(pads_.data); |
||||
|
||||
if (!axes_.empty()) { |
||||
CV_Assert(axes_.dims == 1); |
||||
CV_Assert(atype == CV_32S || atype == CV_64S); |
||||
CV_Assert(pads_.total() == axes_.total()*2); |
||||
CV_Assert(axes_.total() <= (size_t)ndims); |
||||
|
||||
if (atype == CV_32S) |
||||
adata_i32 = reinterpret_cast<const int32_t*>(axes_.data); |
||||
else |
||||
adata_i64 = reinterpret_cast<const int64_t*>(axes_.data); |
||||
} else { |
||||
CV_Assert(pads_.total() == (size_t)ndims*2); |
||||
} |
||||
|
||||
pads.resize(ndims*2); |
||||
|
||||
for (int i = 0; i < ndims; i++) { |
||||
pads[i] = pads[i+ndims] = 0; |
||||
axismask[i] = false; |
||||
} |
||||
|
||||
for (int i = 0; i < naxes; i++) { |
||||
int a = adata_i32 ? (int)adata_i32[i] : adata_i64 ? (int)adata_i64[i] : i; |
||||
a = normalize_axis(a, ndims); |
||||
if (axismask[a]) { |
||||
CV_Error_(Error::StsBadArg, ("duplicate axis %d in Pad", a)); |
||||
} |
||||
axismask[a] = true; |
||||
int p0 = pdata_i32 ? (int)pdata_i32[i] : pdata_i64 ? (int)pdata_i64[i] : 0; |
||||
int p1 = pdata_i32 ? (int)pdata_i32[i+naxes] : pdata_i64 ? (int)pdata_i64[i+naxes] : 0; |
||||
pads[a] = p0; |
||||
pads[a+ndims] = p1; |
||||
// p0, p1 can be positive, zero or even negative, according to ONNX specification.
|
||||
// so we don't put any checks here.
|
||||
} |
||||
} |
||||
|
||||
MatShape getOutShape(const MatShape& inpshape, const std::vector<int>& pads) const |
||||
{ |
||||
MatShape outshape = inpshape; |
||||
int ndims = inpshape.dims; |
||||
for (int i = 0; i < ndims; i++) { |
||||
outshape[i] += pads[i] + pads[i+ndims]; |
||||
CV_Assert(outshape[i] >= 0); |
||||
} |
||||
return outshape; |
||||
} |
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape>& inputs, |
||||
const int, |
||||
std::vector<MatShape> &outputs, |
||||
std::vector<MatShape> &internals) const CV_OVERRIDE |
||||
{ |
||||
CV_Assert(!dynamicOutputShapes()); |
||||
|
||||
size_t ninputs = inputs.size(); |
||||
CV_Assert(1 <= ninputs && ninputs <= 4); |
||||
Net::Impl* netimpl_ = getNetImpl(this); |
||||
|
||||
std::vector<int> padsbuf; |
||||
const std::vector<int>* pads = &pads0; |
||||
|
||||
if (ninputs >= 2) { |
||||
int ndims = inputs[0].dims; |
||||
Mat padsTensor = netimpl_->argTensor(this->inputs[1]); |
||||
Mat axesTensor; |
||||
if (ninputs >= 4) |
||||
axesTensor = netimpl_->argTensor(this->inputs[3]); |
||||
getPads(ndims, padsTensor, axesTensor, padsbuf); |
||||
pads = &padsbuf; |
||||
} |
||||
|
||||
outputs.assign(1, getOutShape(inputs[0], *pads)); |
||||
internals.clear(); |
||||
return true; |
||||
} |
||||
|
||||
void getTypes(const std::vector<MatType>& inputs, |
||||
const int requiredOutputs, |
||||
const int requiredInternals, |
||||
std::vector<MatType>& outputs, |
||||
std::vector<MatType>& internals) const CV_OVERRIDE |
||||
{ |
||||
size_t ninputs = inputs.size(); |
||||
CV_Assert(1 <= ninputs && ninputs <= 4); |
||||
outputs.assign(requiredOutputs, inputs[0]); |
||||
CV_Assert(requiredInternals == 0); |
||||
internals.clear(); |
||||
} |
||||
|
||||
void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE |
||||
{ |
||||
} |
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, |
||||
OutputArrayOfArrays outputs_arr, |
||||
OutputArrayOfArrays) CV_OVERRIDE |
||||
{ |
||||
CV_TRACE_FUNCTION(); |
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); |
||||
|
||||
Size size = inputs_arr.size(); |
||||
int ninputs = size.area(); |
||||
CV_Assert(1 <= ninputs && ninputs <= 4); |
||||
|
||||
Mat inp = inputs_arr.getMat(0); |
||||
Mat value(1, 1, CV_32F, &value0); |
||||
int inptype = inp.type(); |
||||
std::vector<int> padsbuf; |
||||
const std::vector<int>* pads = &pads0; |
||||
|
||||
if (ninputs >= 2) { |
||||
int ndims = inp.dims; |
||||
Mat padsTensor = inputs_arr.getMat(1); |
||||
Mat axesTensor; |
||||
if (ninputs >= 4) |
||||
axesTensor = inputs_arr.getMat(3); |
||||
getPads(ndims, padsTensor, axesTensor, padsbuf); |
||||
pads = &padsbuf; |
||||
if (ninputs >= 3) |
||||
value = inputs_arr.getMat(2); |
||||
} |
||||
|
||||
MatShape inpshape = inp.shape(); |
||||
MatShape outshape = getOutShape(inpshape, *pads); |
||||
|
||||
auto kind = outputs_arr.kind(); |
||||
if (kind == _InputArray::STD_VECTOR_MAT) { |
||||
std::vector<Mat>& outs = outputs_arr.getMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(outshape, inptype); |
||||
pad(inp, *pads, mode, value, outs[0]); |
||||
} else if (kind == _InputArray::STD_VECTOR_UMAT) { |
||||
std::vector<UMat>& outs = outputs_arr.getUMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(outshape, inptype); |
||||
Mat temp(outshape, inptype); |
||||
pad(inp, *pads, mode, value, temp); |
||||
temp.copyTo(outs[0]); |
||||
} else { |
||||
CV_Error(Error::StsNotImplemented, ""); |
||||
} |
||||
} |
||||
}; |
||||
|
||||
Ptr<Pad2Layer> Pad2Layer::create(const LayerParams& params) |
||||
{ |
||||
return Ptr<Pad2Layer>(new Pad2LayerImpl(params)); |
||||
} |
||||
|
||||
} |
||||
} |
@@ -0,0 +1,336 @@
||||
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
#include "layers_common.hpp" |
||||
#include "../net_impl.hpp" |
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
|
||||
/*
    QuantizeLinear layer, as defined in ONNX specification:
    https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html

    Opsets 10 to 23 are covered.
*/
||||
|
||||
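// Generic quantization kernel: out[j] = saturate_cast<_OutTp>(inp[j]/scale + zero_point).
// It handles per-axis scales (block_size == 0) as well as blocked quantization
// (block_size > 0), parallelized over outer slices x blocks along the quantization axis.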
template <typename _InpTp, typename _ScaleTp, typename _OutTp> |
||||
static void quantizeLinear(const _InpTp* inp_, const _ScaleTp* scale_, |
||||
const _OutTp* zp_, _OutTp* out_, |
||||
int64_t nslices, int sz_a_, |
||||
int64_t slice_size_, int block_size_) |
||||
{ |
||||
int bsz_ = std::max(block_size_, 1); |
||||
int nblocks_per_axis = (sz_a_ + bsz_ - 1) / bsz_; |
||||
int64_t nmacro_blocks = nslices * nblocks_per_axis; |
||||
CV_Assert(nmacro_blocks <= (int64_t)INT_MAX); |
||||
|
||||
parallel_for_(Range(0, (int)nmacro_blocks), [&](const Range& r) { |
||||
int sz_a = sz_a_; |
||||
int64_t slice_size = slice_size_; |
||||
int block_size = block_size_; |
||||
int delta = 0; |
||||
int64_t scale_step = block_size > 0 ? slice_size : 1; |
||||
int64_t zp_step = zp_ ? scale_step : 0; |
||||
|
||||
for (int i = r.start; i < r.end; i += delta) { |
||||
int slice_idx = i / nblocks_per_axis; |
||||
int block_idx = i - slice_idx * nblocks_per_axis; |
||||
int64_t block_ofs, scale_ofs; |
||||
if (block_size > 0) { |
||||
delta = std::min(nblocks_per_axis - block_idx, r.end - i); |
||||
block_ofs = (slice_idx*sz_a + block_idx*block_size)*slice_size; |
||||
scale_ofs = (slice_idx*nblocks_per_axis + block_idx)*slice_size; |
||||
} else { |
||||
delta = std::min(sz_a - block_idx, r.end - i); |
||||
block_ofs = (slice_idx*sz_a + block_idx)*slice_size; |
||||
scale_ofs = block_idx; |
||||
} |
||||
const _InpTp* inp = inp_ + block_ofs; |
||||
const _OutTp* zp = zp_ ? zp_ + scale_ofs : nullptr; |
||||
const _ScaleTp* sc = scale_ + scale_ofs; |
||||
_OutTp* out = out_ + block_ofs; |
||||
|
||||
// [TODO] vectorize using intrinsics
|
||||
if (slice_size > 1) { |
||||
for (int k = 0; k < delta; k++, inp += slice_size, out += slice_size, |
||||
sc += scale_step, zp += zp_step) { |
||||
float scval = 1.f/(float)(*sc); |
||||
_OutTp zpval = zp ? *zp : (_InpTp)0; |
||||
|
||||
for (int64_t j = 0; j < slice_size; j++) |
||||
out[j] = saturate_cast<_OutTp>(inp[j]*scval + zpval); |
||||
} |
||||
} else if (block_size > 0 ) { |
||||
int bsz = block_size; |
||||
for (int k = 0; k < delta; k++, inp += bsz, out += bsz) { |
||||
bsz = std::min(bsz, sz_a - (block_idx + k)*block_size); |
||||
float scval = 1.f/(float)sc[k]; |
||||
_OutTp zpval = zp ? zp[k] : (_InpTp)0; |
||||
|
||||
for (int j = 0; j < bsz; j++) |
||||
out[j] = saturate_cast<_OutTp>(inp[j]*scval + zpval); |
||||
} |
||||
sc += delta; |
||||
zp += zp ? delta : 0; |
||||
} else { |
||||
                // here we assume that the scales have been inverted in advance in the parent function
||||
if (zp) { |
||||
for (int j = 0; j < delta; j++) { |
||||
float scval = (float)sc[j]; |
||||
_OutTp zpval = zp[j]; |
||||
out[j] = saturate_cast<_OutTp>(inp[j]*scval + zpval); |
||||
} |
||||
} else { |
||||
for (int j = 0; j < delta; j++) { |
||||
float scval = (float)sc[j]; |
||||
out[j] = saturate_cast<_OutTp>(inp[j]*scval); |
||||
} |
||||
} |
||||
inp += delta; |
||||
out += delta; |
||||
} |
||||
} |
||||
}); |
||||
} |
||||
|
||||
// Quantize FP32/FP16 to INT8/UINT8; out must be pre-allocated
||||
static void quantizeLinear(const Mat& inp, const Mat& scale_, const Mat& zp, |
||||
int axis, int block_size, Mat& out) |
||||
{ |
||||
Mat scale = scale_; |
||||
CV_Assert(inp.isContinuous()); |
||||
CV_Assert(scale.isContinuous()); |
||||
CV_Assert(out.isContinuous()); |
||||
|
||||
int inptype = inp.type(); |
||||
int outtype = out.type(); |
||||
int sctype = scale.type(); |
||||
int zptype = zp.type(); |
||||
MatShape inpshape = inp.shape(); |
||||
MatShape scshape = scale.shape(); |
||||
MatShape zpshape = zp.shape(); |
||||
int i, ndims = inpshape.dims; |
||||
int64_t nslices = 1, slice_size = 1; |
||||
|
||||
CV_Assert(inptype == CV_32F || inptype == CV_16F); |
||||
CV_Assert(sctype == CV_32F || sctype == CV_16F); |
||||
CV_Assert(outtype == CV_8U || outtype == CV_8S); |
||||
|
||||
if (!zp.empty()) { |
||||
CV_Assert(zp.isContinuous()); |
||||
CV_Assert(zptype == outtype); |
||||
CV_Assert(zpshape == scshape); |
||||
} |
||||
|
||||
axis = normalize_axis(axis, ndims); |
||||
for (i = 0; i < axis; i++) |
||||
nslices *= inpshape[i]; |
||||
for (i = axis+1; i < ndims; i++) |
||||
slice_size *= inpshape[i]; |
||||
int sz_a = inpshape[axis]; |
||||
|
||||
if (block_size == 0) { |
||||
size_t sc_total = scshape.total(); |
||||
CV_Assert(scale.dims <= 1); |
||||
CV_Assert(sc_total == 1 || sc_total == (size_t)sz_a); |
||||
|
||||
// unroll the innermost loop if the scale's/zp's are the same
|
||||
if (sc_total == 1) { |
||||
slice_size *= sz_a; |
||||
sz_a = 1; |
||||
} |
||||
|
||||
// avoid repeated inversion and FP16 => FP32 conversion inside the innermost loop
|
||||
if (slice_size == 1) { |
||||
Mat temp(scale.size(), CV_32F); |
||||
const float* scdata_32f = reinterpret_cast<const float*>(scale.data); |
||||
const hfloat* scdata_16f = reinterpret_cast<const hfloat*>(scale.data); |
||||
float* tempdata = temp.ptr<float>(); |
||||
|
||||
for (size_t i = 0; i < sc_total; i++) |
||||
tempdata[i] = 1.f/(sctype == CV_32F ? scdata_32f[i] : (float)scdata_16f[i]); |
||||
scale = temp; |
||||
sctype = CV_32F; |
||||
} |
||||
} else { |
||||
CV_Assert(block_size > 0); |
||||
CV_Assert(scale.dims == ndims); |
||||
for (int i = 0; i < ndims; i++) { |
||||
int inp_i = inpshape[i]; |
||||
int sc_i = scshape[i]; |
||||
if (i == axis) { |
||||
CV_Assert((inp_i + block_size - 1)/block_size == sc_i); |
||||
} else { |
||||
CV_Assert(sc_i == inp_i); |
||||
} |
||||
} |
||||
} |
||||
|
||||
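    // Dispatch over the supported (input, scale, output) type combinations;
    // each branch instantiates the template kernel above for the concrete types.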
if (outtype == CV_8U && sctype == CV_32F && inptype == CV_32F) |
||||
quantizeLinear(reinterpret_cast<const float*>(inp.data), |
||||
reinterpret_cast<const float*>(scale.data), |
||||
reinterpret_cast<const uint8_t*>(zp.data), |
||||
reinterpret_cast<uint8_t*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (outtype == CV_8U && sctype == CV_16F && inptype == CV_32F) |
||||
quantizeLinear(reinterpret_cast<const float*>(inp.data), |
||||
reinterpret_cast<const hfloat*>(scale.data), |
||||
reinterpret_cast<const uint8_t*>(zp.data), |
||||
reinterpret_cast<uint8_t*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (outtype == CV_8U && sctype == CV_32F && inptype == CV_16F) |
||||
quantizeLinear(reinterpret_cast<const hfloat*>(inp.data), |
||||
reinterpret_cast<const float*>(scale.data), |
||||
reinterpret_cast<const uint8_t*>(zp.data), |
||||
reinterpret_cast<uint8_t*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (outtype == CV_8U && sctype == CV_16F && inptype == CV_16F) |
||||
quantizeLinear(reinterpret_cast<const hfloat*>(inp.data), |
||||
reinterpret_cast<const hfloat*>(scale.data), |
||||
reinterpret_cast<const uint8_t*>(zp.data), |
||||
reinterpret_cast<uint8_t*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (outtype == CV_8S && sctype == CV_32F && inptype == CV_32F) |
||||
quantizeLinear(reinterpret_cast<const float*>(inp.data), |
||||
reinterpret_cast<const float*>(scale.data), |
||||
reinterpret_cast<const int8_t*>(zp.data), |
||||
reinterpret_cast<int8_t*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (outtype == CV_8S && sctype == CV_16F && inptype == CV_32F) |
||||
quantizeLinear(reinterpret_cast<const float*>(inp.data), |
||||
reinterpret_cast<const hfloat*>(scale.data), |
||||
reinterpret_cast<const int8_t*>(zp.data), |
||||
reinterpret_cast<int8_t*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (outtype == CV_8S && sctype == CV_32F && inptype == CV_16F) |
||||
quantizeLinear(reinterpret_cast<const hfloat*>(inp.data), |
||||
reinterpret_cast<const float*>(scale.data), |
||||
reinterpret_cast<const int8_t*>(zp.data), |
||||
reinterpret_cast<int8_t*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else if (outtype == CV_8S && sctype == CV_16F && inptype == CV_16F) |
||||
quantizeLinear(reinterpret_cast<const hfloat*>(inp.data), |
||||
reinterpret_cast<const hfloat*>(scale.data), |
||||
reinterpret_cast<const int8_t*>(zp.data), |
||||
reinterpret_cast<int8_t*>(out.data), |
||||
nslices, sz_a, slice_size, block_size); |
||||
else { |
||||
CV_Error_(Error::StsNotImplemented, |
||||
("the following combination of types is not supported in " |
||||
"QuantizeLinear: inp=%s, scale=%s, out=%s", |
||||
typeToString(inptype).c_str(), |
||||
typeToString(sctype).c_str(), |
||||
typeToString(outtype).c_str())); |
||||
} |
||||
} |
||||
|
||||
class QuantizeLinearLayerImpl CV_FINAL : public QuantizeLinearLayer |
||||
{ |
||||
public: |
||||
QuantizeLinearLayerImpl(const LayerParams& params) |
||||
{ |
||||
setParamsFrom(params); |
||||
|
||||
axis = params.get<int>("axis", 1); |
||||
block_size = params.get<int>("block_size", 0); |
||||
saturate = params.get<bool>("saturate", true); |
||||
output_dtype = params.get<int>("output_dtype", -1); |
||||
CV_Assert(block_size >= 0); |
||||
CV_Assert(saturate); |
||||
} |
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE |
||||
{ |
||||
return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH; |
||||
} |
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape> &inputs, |
||||
const int requiredOutputs, |
||||
std::vector<MatShape> &outputs, |
||||
std::vector<MatShape> &internals) const CV_OVERRIDE |
||||
{ |
||||
size_t ninputs = inputs.size(); |
||||
CV_Assert(2 <= ninputs && ninputs <= 3); |
||||
CV_Assert(requiredOutputs == 1); |
||||
outputs.assign(1, inputs[0]); |
||||
return true; |
||||
} |
||||
|
||||
int getOutType(int zptype) const |
||||
{ |
||||
return output_dtype >= 0 ? output_dtype : zptype; |
||||
} |
||||
|
||||
void getTypes(const std::vector<MatType>& inputs, |
||||
const int requiredOutputs, |
||||
const int requiredInternals, |
||||
std::vector<MatType>& outputs, |
||||
std::vector<MatType>& internals) const CV_OVERRIDE |
||||
{ |
||||
size_t ninputs = inputs.size(); |
||||
CV_Assert(2 <= ninputs && ninputs <= 3); |
||||
int zptype = CV_8U; |
||||
if (ninputs == 3) { |
||||
zptype = inputs[2]; |
||||
} |
||||
outputs.assign(1, getOutType(zptype)); |
||||
} |
||||
|
||||
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE |
||||
{ |
||||
} |
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, |
||||
OutputArrayOfArrays outputs_arr, |
||||
OutputArrayOfArrays) CV_OVERRIDE |
||||
{ |
||||
CV_TRACE_FUNCTION(); |
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); |
||||
|
||||
int ninputs = inputs_arr.size(-1).area(); |
||||
CV_Assert(2 <= ninputs && ninputs <= 3); |
||||
|
||||
Mat inp = inputs_arr.getMat(0); |
||||
Mat scale = inputs_arr.getMat(1); |
||||
Mat zeropoint; |
||||
int zptype = CV_8U, outtype; |
||||
MatShape inpshape = inp.shape(); |
||||
|
||||
if (ninputs >= 3) { |
||||
zeropoint = inputs_arr.getMat(2); |
||||
zptype = zeropoint.type(); |
||||
} |
||||
|
||||
outtype = getOutType(zptype); |
||||
auto kind = outputs_arr.kind(); |
||||
|
||||
if (kind == _InputArray::STD_VECTOR_MAT) { |
||||
std::vector<Mat>& outs = outputs_arr.getMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(inpshape, outtype); |
||||
quantizeLinear(inp, scale, zeropoint, axis, block_size, outs[0]); |
||||
} else if (kind == _InputArray::STD_VECTOR_UMAT) { |
||||
std::vector<UMat>& outs = outputs_arr.getUMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(inpshape, outtype); |
||||
Mat temp(inpshape, outtype); |
||||
quantizeLinear(inp, scale, zeropoint, axis, block_size, temp); |
||||
temp.copyTo(outs[0]); |
||||
} else { |
||||
CV_Error(Error::StsNotImplemented, ""); |
||||
} |
||||
} |
||||
}; |
||||
|
||||
Ptr<QuantizeLinearLayer> QuantizeLinearLayer::create(const LayerParams& params) |
||||
{ |
||||
return Ptr<QuantizeLinearLayer>(new QuantizeLinearLayerImpl(params)); |
||||
} |
||||
|
||||
}} |
@@ -0,0 +1,224 @@
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
#include "layers_common.hpp" |
||||
#include "../net_impl.hpp" |
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
|
||||
/*
    Range layer, as defined in ONNX specification:
    https://onnx.ai/onnx/operators/onnx__Range.html

    Opset 11 is covered.
*/
||||
|
||||
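// Number of elements produced by Range: max(ceil((limit - start)/delta), 0).
// Separate overloads are used for floating-point and integer arguments to avoid
// rounding issues with large 64-bit values.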
static int rangeSize(double start, double limit, double delta) |
||||
{ |
||||
return std::max((int)ceil((limit - start)/delta), 0); |
||||
} |
||||
|
||||
static int rangeSize(int64_t start, int64_t limit, int64_t delta) |
||||
{ |
||||
return delta > 0 ? |
||||
std::max((int)((limit - start + delta - 1)/delta), 0) : |
||||
std::max((int)((start - limit - delta - 1)/-delta), 0); |
||||
} |
||||
|
||||
// out must be pre-allocated
|
||||
template <typename _Tp> |
||||
static void makeRange(_Tp start, _Tp limit, _Tp delta, Mat& out) |
||||
{ |
||||
int nout = rangeSize(start, limit, delta); |
||||
CV_Assert(out.dims == 1); |
||||
CV_Assert(out.total() == (size_t)nout); |
||||
uchar* outdata_ = out.data; |
||||
|
||||
int type = out.type(); |
||||
|
||||
#undef IMPL_RANGE |
||||
#define IMPL_RANGE(T) \ |
||||
T* outdata = (T*)outdata_; \
|
||||
for (int i = 0; i < nout; i++) \
|
||||
outdata[i] = saturate_cast<T>(start + i*delta) |
||||
|
||||
if (type == CV_32F) { |
||||
IMPL_RANGE(float); |
||||
} else if (type == CV_64F) { |
||||
IMPL_RANGE(double); |
||||
} else if (type == CV_32S) { |
||||
IMPL_RANGE(int32_t); |
||||
} else if (type == CV_64S) { |
||||
IMPL_RANGE(int64_t); |
||||
} else { |
||||
CV_Error_(Error::StsNotImplemented, ("invalid/unsupported tensor type: %s", typeToString(out.type()).c_str())); |
||||
} |
||||
} |
||||
|
||||
class RangeLayerImpl CV_FINAL : public RangeLayer |
||||
{ |
||||
public: |
||||
RangeLayerImpl(const LayerParams& params) |
||||
{ |
||||
setParamsFrom(params); |
||||
} |
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE |
||||
{ |
||||
return backendId == DNN_BACKEND_OPENCV; |
||||
} |
||||
|
||||
virtual bool dynamicOutputShapes() const CV_OVERRIDE |
||||
{ |
||||
Net::Impl* netimpl_ = getNetImpl(this); |
||||
CV_Assert(netimpl_); |
||||
CV_Assert(this->inputs.size() == 3); |
||||
return !netimpl_->isConstArg(this->inputs[0]) || |
||||
!netimpl_->isConstArg(this->inputs[1]) || |
||||
!netimpl_->isConstArg(this->inputs[2]); |
||||
} |
||||
|
||||
int getRangeParams(const Mat& startTensor, const Mat& limitTensor, const Mat& deltaTensor, |
||||
double& fstart, double& flimit, double& fdelta, |
||||
int64_t& istart, int64_t& ilimit, int64_t& idelta, bool& isflt) const |
||||
{ |
||||
CV_Assert(startTensor.total() == (size_t)1); |
||||
CV_Assert(limitTensor.total() == (size_t)1); |
||||
CV_Assert(deltaTensor.total() == (size_t)1); |
||||
|
||||
int rtype = startTensor.type(); |
||||
CV_Assert(rtype == limitTensor.type()); |
||||
CV_Assert(rtype == deltaTensor.type()); |
||||
|
||||
fstart = flimit = fdelta = 0.; |
||||
istart = ilimit = idelta = 0; |
||||
|
||||
isflt = rtype == CV_32F || rtype == CV_64F || rtype == CV_16F || rtype == CV_16BF; |
||||
|
||||
if (isflt) { |
||||
fstart = tensorToScalar<double>(startTensor); |
||||
flimit = tensorToScalar<double>(limitTensor); |
||||
fdelta = tensorToScalar<double>(deltaTensor); |
||||
|
||||
return rangeSize(fstart, flimit, fdelta); |
||||
} else { |
||||
istart = tensorToScalar<int64_t>(startTensor); |
||||
ilimit = tensorToScalar<int64_t>(limitTensor); |
||||
idelta = tensorToScalar<int64_t>(deltaTensor); |
||||
|
||||
return rangeSize(istart, ilimit, idelta); |
||||
} |
||||
} |
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape>& inputs, |
||||
const int, |
||||
std::vector<MatShape> &outputs, |
||||
std::vector<MatShape> &internals) const CV_OVERRIDE |
||||
{ |
||||
CV_Assert(!dynamicOutputShapes()); |
||||
|
||||
CV_Assert(inputs.size() == (size_t)3); |
||||
CV_Assert(inputs.size() == this->inputs.size()); |
||||
Net::Impl* netimpl_ = getNetImpl(this); |
||||
|
||||
Mat startTensor = netimpl_->argTensor(this->inputs[0]); |
||||
Mat limitTensor = netimpl_->argTensor(this->inputs[1]); |
||||
Mat deltaTensor = netimpl_->argTensor(this->inputs[2]); |
||||
|
||||
double fstart, flimit, fdelta; |
||||
int64_t istart, ilimit, idelta; |
||||
bool isflt; |
||||
|
||||
int nout = getRangeParams(startTensor, limitTensor, deltaTensor, |
||||
fstart, flimit, fdelta, istart, ilimit, idelta, isflt); |
||||
MatShape shape(1); |
||||
shape[0] = nout; |
||||
outputs.assign(1, shape); |
||||
internals.clear(); |
||||
return true; |
||||
} |
||||
|
||||
void getTypes(const std::vector<MatType>& inputs, |
||||
const int requiredOutputs, |
||||
const int requiredInternals, |
||||
std::vector<MatType>& outputs, |
||||
std::vector<MatType>& internals) const CV_OVERRIDE |
||||
{ |
||||
size_t ninputs = inputs.size(); |
||||
CV_Assert(ninputs == (size_t)3); |
||||
CV_Assert(inputs[0] == inputs[1]); |
||||
CV_Assert(inputs[0] == inputs[2]); |
||||
outputs.assign(requiredOutputs, inputs[0]); |
||||
CV_Assert(requiredInternals == 0); |
||||
internals.clear(); |
||||
} |
||||
|
||||
void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE |
||||
{ |
||||
} |
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, |
||||
OutputArrayOfArrays outputs_arr, |
||||
OutputArrayOfArrays) CV_OVERRIDE |
||||
{ |
||||
CV_TRACE_FUNCTION(); |
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); |
||||
|
||||
Size size = inputs_arr.size(); |
||||
int ninputs = size.area(); |
||||
CV_Assert(ninputs == 3); |
||||
|
||||
Mat startTensor = inputs_arr.getMat(0); |
||||
Mat limitTensor = inputs_arr.getMat(1); |
||||
Mat deltaTensor = inputs_arr.getMat(2); |
||||
|
||||
double fstart, flimit, fdelta; |
||||
int64_t istart, ilimit, idelta; |
||||
bool isflt; |
||||
|
||||
int nout = getRangeParams(startTensor, limitTensor, deltaTensor, |
||||
fstart, flimit, fdelta, istart, ilimit, idelta, isflt); |
||||
MatShape shape(1); |
||||
shape[0] = nout; |
||||
|
||||
int rtype = startTensor.type(); |
||||
|
||||
auto kind = outputs_arr.kind(); |
||||
if (kind == _InputArray::STD_VECTOR_MAT) { |
||||
std::vector<Mat>& outs = outputs_arr.getMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(shape, rtype); |
||||
if (isflt) { |
||||
makeRange(fstart, flimit, fdelta, outs[0]); |
||||
} else { |
||||
makeRange(istart, ilimit, idelta, outs[0]); |
||||
} |
||||
} else if (kind == _InputArray::STD_VECTOR_UMAT) { |
||||
std::vector<UMat>& outs = outputs_arr.getUMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(shape, rtype); |
||||
Mat temp(shape, rtype); |
||||
if (isflt) { |
||||
makeRange(fstart, flimit, fdelta, temp); |
||||
} else { |
||||
makeRange(istart, ilimit, idelta, temp); |
||||
} |
||||
temp.copyTo(outs[0]); |
||||
} else { |
||||
CV_Error(Error::StsNotImplemented, ""); |
||||
} |
||||
} |
||||
}; |
||||
|
||||
Ptr<RangeLayer> RangeLayer::create(const LayerParams& params) |
||||
{ |
||||
return Ptr<RangeLayer>(new RangeLayerImpl(params)); |
||||
} |
||||
|
||||
} |
||||
} |
@@ -0,0 +1,190 @@
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
#include "layers_common.hpp" |
||||
#include "../net_impl.hpp" |
||||
//#include "../op_cuda.hpp"
|
||||
//#include "../op_inf_engine.hpp"
|
||||
//#include "../ie_ngraph.hpp"
|
||||
//#include "../op_webnn.hpp"
|
||||
//#include "../op_timvx.hpp"
|
||||
//#include "../op_cann.hpp"
|
||||
|
||||
//#include <opencv2/dnn/shape_utils.hpp>
|
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
|
||||
/*
    Reshape2 layer, as defined in ONNX specification:
    https://onnx.ai/onnx/operators/onnx__Reshape.html

    Opsets 1 to 23 are covered.

    The layers Flatten, Reshape2, Squeeze and Unsqueeze all share the same
    implementation idea:
    1. calculate the shape of the output tensor;
    2. assuming that the input is continuous, just copy all the data to the output tensor;
       reshapeAndCopyFirst() does that.
    The engine's buffer allocator recognizes all these operations and tries to run
    them in-place. In such a case no copy is actually performed,
    so the operations are really cheap.
*/
||||
|
||||
class Reshape2LayerImpl CV_FINAL : public Reshape2Layer |
||||
{ |
||||
public: |
||||
bool dynamicShapeSpec; |
||||
|
||||
Reshape2LayerImpl(const LayerParams& params) |
||||
{ |
||||
dynamicShapeSpec = true; |
||||
setParamsFrom(params); |
||||
if (params.has("shape")) |
||||
{ |
||||
dynamicShapeSpec = false; |
||||
|
||||
const DictValue& shapeParam = params.get("shape"); |
||||
int i, ndims = shapeParam.size(); |
||||
newShapeDesc.resize(ndims); |
||||
for (i = 0; i < ndims; i++) { |
||||
int sz = shapeParam.get<int>(i); |
||||
if (sz <= 0) |
||||
dynamicShapeSpec = true; |
||||
newShapeDesc[i] = sz; |
||||
} |
||||
} |
||||
} |
||||
|
||||
virtual bool dynamicOutputShapes() const CV_OVERRIDE |
||||
{ |
||||
        // [TODO] fix. If the 'shape' spec is an attribute,
        // or if shape is a constant 2nd input of the layer,
        // then the output shape can be inferred from the input tensor shape.
        // That is, dynamicShapeSpec is not quite correct.
||||
return dynamicShapeSpec; |
||||
} |
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE |
||||
{ |
||||
return backendId == DNN_BACKEND_OPENCV; |
||||
} |
||||
|
||||
bool haveShapeSpec() const |
||||
{ |
||||
return newShapeDesc.dims >= 0; |
||||
} |
||||
|
||||
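    // Resolve the ONNX Reshape shape spec: 0 means "copy the corresponding input
    // dimension", a single -1 means "infer this dimension from the remaining total size".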
MatShape getOutShape(const MatShape& inpShape, MatShape& shapeSpec) const |
||||
{ |
||||
MatShape outShape = shapeSpec; |
||||
int m1idx = -1; |
||||
int i, ndims = outShape.dims; |
||||
int64_t outTotal = 1; |
||||
for (i = 0; i < ndims; i++) { |
||||
if (outShape[i] < 0) { |
||||
CV_Assert(outShape[i] == -1); |
||||
if (m1idx >= 0) { |
||||
CV_Error(Error::StsBadArg, "invalid shape spec, there must be at most one '-1'"); |
||||
} |
||||
m1idx = i; |
||||
} |
||||
else { |
||||
if (outShape[i] == 0) { |
||||
if (i >= inpShape.dims) { |
||||
CV_Error(Error::StsBadArg, "cannot copy dimension from the input tensor"); |
||||
} |
||||
outShape[i] = inpShape[i]; |
||||
} |
||||
outTotal *= outShape[i]; |
||||
} |
||||
} |
||||
|
||||
int64_t inpTotal = (int64_t)inpShape.total(); |
||||
if (m1idx >= 0) { |
||||
int64_t autoSize = inpTotal/outTotal; |
||||
CV_Assert(autoSize <= INT_MAX && autoSize*outTotal == inpTotal); |
||||
outShape[m1idx] = (int)autoSize; |
||||
} else { |
||||
CV_Assert(outTotal == inpTotal); |
||||
} |
||||
|
||||
return outShape; |
||||
} |
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape> &inputs, |
||||
const int, |
||||
std::vector<MatShape> &outputs, |
||||
std::vector<MatShape> &internals) const CV_OVERRIDE |
||||
{ |
||||
bool haveShapeSpec_ = haveShapeSpec(); |
||||
CV_Assert((inputs.size() == 1 && haveShapeSpec_) || |
||||
(inputs.size() == 2 && !haveShapeSpec_)); |
||||
MatShape shapeSpec = newShapeDesc, outShape; |
||||
|
||||
if (inputs.size() == 2) |
||||
{ |
||||
CV_Assert(this->inputs.size() == 2); |
||||
Net::Impl* netimpl_ = getNetImpl(this); |
||||
Mat shapeTensor = netimpl_->argTensor(this->inputs[1]); |
||||
shapeSpec = tensorToShape(shapeTensor); |
||||
} else { |
||||
CV_Assert(shapeSpec.dims >= 0); |
||||
} |
||||
outputs.assign(1, getOutShape(inputs[0], shapeSpec)); |
||||
internals.clear(); |
||||
return true; |
||||
} |
||||
|
||||
void getTypes(const std::vector<MatType>& inputs, |
||||
const int requiredOutputs, |
||||
const int requiredInternals, |
||||
std::vector<MatType>& outputs, |
||||
std::vector<MatType>& internals) const CV_OVERRIDE |
||||
{ |
||||
size_t ninputs = inputs.size(); |
||||
CV_Assert(ninputs == 1 || ninputs == 2); |
||||
outputs.assign(requiredOutputs, inputs[0]); |
||||
CV_Assert(requiredInternals == 0); |
||||
internals.clear(); |
||||
} |
||||
|
||||
void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE |
||||
{ |
||||
} |
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, |
||||
OutputArrayOfArrays outputs_arr, |
||||
OutputArrayOfArrays) CV_OVERRIDE |
||||
{ |
||||
CV_TRACE_FUNCTION(); |
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); |
||||
|
||||
Size size = inputs_arr.size(); |
||||
int ninputs = size.area(); |
||||
bool haveShapeSpec_ = haveShapeSpec(); |
||||
CV_Assert((ninputs == 1 && haveShapeSpec_) || |
||||
(ninputs == 2 && !haveShapeSpec_)); |
||||
|
||||
MatShape inpShape = inputs_arr.shape(0); |
||||
MatShape shapeSpec = newShapeDesc; |
||||
if (!haveShapeSpec_) { |
||||
Mat shapeTensor = inputs_arr.getMat(1); |
||||
shapeSpec = tensorToShape(shapeTensor); |
||||
} |
||||
MatShape outShape = getOutShape(inpShape, shapeSpec); |
||||
reshapeAndCopyFirst(inputs_arr, outputs_arr, outShape); |
||||
} |
||||
}; |
||||
|
||||
Ptr<Reshape2Layer> Reshape2Layer::create(const LayerParams& params) |
||||
{ |
||||
return Ptr<Reshape2Layer>(new Reshape2LayerImpl(params)); |
||||
} |
||||
|
||||
} |
||||
} |
@@ -0,0 +1,137 @@
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
#include "layers_common.hpp" |
||||
#include "../net_impl.hpp" |
||||
//#include "../op_cuda.hpp"
|
||||
//#include "../op_inf_engine.hpp"
|
||||
//#include "../ie_ngraph.hpp"
|
||||
//#include "../op_webnn.hpp"
|
||||
//#include "../op_timvx.hpp"
|
||||
//#include "../op_cann.hpp"
|
||||
|
||||
//#include <opencv2/dnn/shape_utils.hpp>
|
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
|
||||
class ShapeLayerImpl CV_FINAL : public ShapeLayer |
||||
{ |
||||
public: |
||||
typedef int64_t shape_type_t; |
||||
int shapeType; |
||||
|
||||
ShapeLayerImpl(const LayerParams& params) |
||||
{ |
||||
setParamsFrom(params); |
||||
|
||||
start = params.get<int>("start", 0); |
||||
end = params.get<int>("end", INT_MAX); |
||||
shapeType = DataType<shape_type_t>::type; |
||||
} |
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE |
||||
{ |
||||
return backendId == DNN_BACKEND_OPENCV; |
||||
} |
||||
|
||||
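    // Clamp the optional 'start'/'end' attributes (negative values count from the end of
    // the shape) to [0, inpShape.dims]; the layer then outputs inpShape[start:end].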
Range getShapeRange(const MatShape& inpShape) const |
||||
{ |
||||
int outDims = inpShape.dims; |
||||
int start_ = start < 0 ? start + outDims : start; |
||||
int end_ = end >= outDims ? outDims : end < 0 ? end + outDims : end; |
||||
|
||||
CV_Assert(0 <= start_); |
||||
CV_Assert(start_ <= end_); |
||||
CV_Assert(end_ <= outDims); |
||||
|
||||
return Range(start_, end_); |
||||
} |
||||
|
||||
MatShape getOutShape(const MatShape& inpShape) const |
||||
{ |
||||
MatShape outShape; |
||||
outShape.dims = 1; |
||||
|
||||
Range r = getShapeRange(inpShape); |
||||
|
||||
outShape[0] = r.end - r.start; |
||||
return outShape; |
||||
} |
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape> &inputs, |
||||
const int requiredOutputs, |
||||
std::vector<MatShape> &outputs, |
||||
std::vector<MatShape> &internals) const CV_OVERRIDE |
||||
{ |
||||
CV_Assert(inputs.size() == 1); |
||||
|
||||
outputs.assign(1, getOutShape(inputs[0])); |
||||
internals.clear(); |
||||
|
||||
return true; |
||||
} |
||||
|
||||
void getTypes(const std::vector<MatType>& inputs, |
||||
const int requiredOutputs, |
||||
const int requiredInternals, |
||||
std::vector<MatType>& outputs, |
||||
std::vector<MatType>& internals) const CV_OVERRIDE |
||||
{ |
||||
CV_Assert(inputs.size() == 1); |
||||
outputs.assign(requiredOutputs, shapeType); |
||||
CV_Assert(requiredInternals == 0); |
||||
internals.clear(); |
||||
} |
||||
|
||||
void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE |
||||
{ |
||||
} |
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, |
||||
OutputArrayOfArrays outputs_arr, |
||||
OutputArrayOfArrays) CV_OVERRIDE |
||||
{ |
||||
CV_TRACE_FUNCTION(); |
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); |
||||
|
||||
Size size = inputs_arr.size(); |
||||
int ninputs = size.area(); |
||||
CV_Assert(ninputs == 1); |
||||
|
||||
MatShape inpShape = inputs_arr.shape(0); |
||||
Range r = getShapeRange(inpShape); |
||||
|
||||
shape_type_t shapeData[CV_MAX_DIM]; |
||||
for (int i = r.start; i < r.end; i++) |
||||
shapeData[i] = (shape_type_t)inpShape[i]; |
||||
|
||||
Mat shape({r.end - r.start}, shapeType, shapeData); |
||||
|
||||
int outKind = outputs_arr.kind(); |
||||
|
||||
if (outKind == _InputArray::STD_VECTOR_MAT) { |
||||
std::vector<Mat>& out = outputs_arr.getMatVecRef(); |
||||
CV_Assert(out.size() == 1); |
||||
shape.copyTo(out[0]); |
||||
} else if (outKind == _InputArray::STD_VECTOR_UMAT) { |
||||
std::vector<UMat>& out = outputs_arr.getUMatVecRef(); |
||||
CV_Assert(out.size() == 1); |
||||
shape.copyTo(out[0]); |
||||
} else { |
||||
CV_Error_(Error::StsBadArg, ("invalid/unsupported outputs_arr kind: %d", outKind)); |
||||
} |
||||
} |
||||
}; |
||||
|
||||
Ptr<ShapeLayer> ShapeLayer::create(const LayerParams& params) |
||||
{ |
||||
return Ptr<ShapeLayer>(new ShapeLayerImpl(params)); |
||||
} |
||||
|
||||
} |
||||
} |
@@ -0,0 +1,359 @@
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
#include "layers_common.hpp" |
||||
#include "../net_impl.hpp" |
||||
//#include "../op_cuda.hpp"
|
||||
//#include "../op_inf_engine.hpp"
|
||||
//#include "../ie_ngraph.hpp"
|
||||
//#include "../op_webnn.hpp"
|
||||
//#include "../op_timvx.hpp"
|
||||
//#include "../op_cann.hpp"
|
||||
|
||||
//#include <opencv2/dnn/shape_utils.hpp>
|
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
|
||||
/*
    Slice2 layer, as defined in ONNX specification (the ONNX operator is called Slice):
    https://onnx.ai/onnx/operators/onnx__Slice.html

    Opsets 1 to 13 are covered.
*/
||||
|
||||
/* Slice op for CPU.
   starts_, ends_ and steps_ must contain as many elements as
   there are dimensions in inp and out.
*/
||||
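// The input is virtually extended to SLICE_MAX_DIMS (=7) dimensions by prepending 1's;
// the start offsets are folded into the base pointer and the per-dimension steps into the
// strides, so the copy reduces to 7 nested loops with a partially unrolled innermost loop.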
static void slice(const Mat& inp, const int* starts_, |
||||
const int*, const int* steps_, |
||||
Mat& out) |
||||
{ |
||||
/// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
||||
/// in this function steps can be negative, so
|
||||
/// please don't replace int64_t's with size_t's
|
||||
/// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
||||
enum {SLICE_MAX_DIMS=7}; |
||||
|
||||
CV_Assert_N(inp.isContinuous(), out.isContinuous()); |
||||
CV_Assert(inp.type() == out.type()); |
||||
CV_Assert_N(inp.dims <= SLICE_MAX_DIMS, inp.dims == out.dims); |
||||
|
||||
MatShape inpShape = inp.shape(); |
||||
MatShape outShape = out.shape(); |
||||
int64_t esz = (int64_t)inp.elemSize(); |
||||
|
||||
int ndims = inpShape.dims; |
||||
int starts[SLICE_MAX_DIMS], steps[SLICE_MAX_DIMS]; |
||||
int inpsz[SLICE_MAX_DIMS], outsz[SLICE_MAX_DIMS]; |
||||
int64_t inpstep[SLICE_MAX_DIMS]; |
||||
|
||||
int delta = SLICE_MAX_DIMS - ndims; |
||||
bool emptyOut = false; |
||||
|
||||
for (int i = 0; i < SLICE_MAX_DIMS; i++) { |
||||
inpsz[i] = outsz[i] = steps[i] = 1; |
||||
starts[i] = 0; |
||||
} |
||||
|
||||
for (int i = 0; i < ndims; i++) { |
||||
inpsz[delta + i] = inpShape[i]; |
||||
outsz[delta + i] = outShape[i]; |
||||
starts[delta + i] = starts_[i]; |
||||
steps[delta + i] = steps_[i]; |
||||
if (outShape[i] == 0) |
||||
emptyOut = true; |
||||
} |
||||
|
||||
for (int i = SLICE_MAX_DIMS-1; i >= 0; i--) |
||||
inpstep[i] = i == SLICE_MAX_DIMS-1 ? 1 : inpstep[i+1]*inpsz[i+1]; |
||||
|
||||
const uchar* inptr0 = inp.data; |
||||
|
||||
for (int i = 0; i < SLICE_MAX_DIMS; i++) { |
||||
inptr0 += starts[i]*inpstep[i]*esz; |
||||
inpstep[i] *= steps[i]; |
||||
} |
||||
|
||||
int sz0 = outsz[6], sz1 = outsz[5]; |
||||
int sz2 = outsz[4], sz3 = outsz[3]; |
||||
int sz4 = outsz[2], sz5 = outsz[1], sz6 = outsz[0]; |
||||
int64_t p0 = inpstep[6], p1 = inpstep[5]; |
||||
int64_t p2 = inpstep[4], p3 = inpstep[3]; |
||||
int64_t p4 = inpstep[2], p5 = inpstep[1], p6 = inpstep[0]; |
||||
|
||||
#undef CV_IMPLEMENT_SLICE |
||||
#define CV_IMPLEMENT_SLICE(typ) \ |
||||
typ* outptr = (typ*)(out.data); \
|
||||
for(int i6 = 0; i6 < sz6; i6++) { \
|
||||
for(int i5 = 0; i5 < sz5; i5++) { \
|
||||
for(int i4 = 0; i4 < sz4; i4++) { \
|
||||
for(int i3 = 0; i3 < sz3; i3++) { \
|
||||
for(int i2 = 0; i2 < sz2; i2++) { \
|
||||
for(int i1 = 0; i1 < sz1; i1++, outptr += sz0) { \
|
||||
const typ* inptr = (const typ*)inptr0 + i6*p6 + \
|
||||
i5*p5 + i4*p4 + i3*p3 + i2*p2 + i1*p1; \
|
||||
int i0 = 0; \
|
||||
if (p0 == 1) { \
|
||||
for (; i0 < sz0; i0++) \
|
||||
outptr[i0] = inptr[i0]; \
|
||||
} \
|
||||
else { \
|
||||
for (; i0 <= sz0 - 4; i0 += 4) { \
|
||||
int64_t ip0 = i0*p0; \
|
||||
typ t0 = inptr[ip0], t1 = inptr[ip0 + p0]; \
|
||||
typ t2 = inptr[ip0 + p0*2], t3 = inptr[ip0 + p0*3]; \
|
||||
outptr[i0] = t0; outptr[i0+1] = t1; \
|
||||
outptr[i0+2] = t2; outptr[i0+3] = t3; \
|
||||
} \
|
||||
for (; i0 < sz0; i0++) \
|
||||
outptr[i0] = inptr[i0*p0]; \
|
||||
} \
|
||||
}}}}}} |
||||
|
||||
if (emptyOut) return; |
||||
if (esz == 4) { |
||||
CV_IMPLEMENT_SLICE(int) |
||||
} else if (esz == 2) { |
||||
CV_IMPLEMENT_SLICE(int16_t) |
||||
} else if (esz == 1) { |
||||
CV_IMPLEMENT_SLICE(int8_t) |
||||
} else if (esz == 8) { |
||||
CV_IMPLEMENT_SLICE(int64_t) |
||||
} else { |
||||
CV_Error(Error::StsNotImplemented, ""); |
||||
} |
||||
} |
||||
|
||||
class Slice2LayerImpl CV_FINAL : public Slice2Layer |
||||
{ |
||||
public: |
||||
Slice2LayerImpl(const LayerParams& params) |
||||
{ |
||||
setParamsFrom(params); |
||||
axes = params.getVector<int>("axes"); |
||||
starts = params.getVector<int>("starts"); |
||||
ends = params.getVector<int>("ends"); |
||||
} |
||||
|
||||
void checkNumInputs(size_t ninputs) const |
||||
{ |
||||
CV_Assert(ninputs == 1 || (3 <= ninputs && ninputs <= 5)); |
||||
} |
||||
|
||||
virtual bool dynamicOutputShapes() const CV_OVERRIDE |
||||
{ |
||||
Net::Impl* netimpl_ = getNetImpl(this); |
||||
size_t ninputs = inputs.size(); |
||||
|
||||
for (size_t i = 1; i < ninputs; i++) { |
||||
if (!netimpl_->isConstArg(inputs[i])) |
||||
return true; |
||||
} |
||||
return false; |
||||
} |
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE |
||||
{ |
||||
return backendId == DNN_BACKEND_OPENCV; |
||||
} |
||||
|
||||
MatShape getOutShape(const MatShape& inpShape, |
||||
const std::vector<int>& starts_, |
||||
const std::vector<int>& ends_, |
||||
const std::vector<int>& axes_, |
||||
const std::vector<int>& steps_, |
||||
int* allStarts = nullptr, |
||||
int* allEnds = nullptr, |
||||
int* allSteps = nullptr) const |
||||
{ |
||||
bool sliceMask[MatShape::MAX_DIMS]; |
||||
|
||||
int ndims = inpShape.dims; |
||||
int nstarts = (int)starts_.size(), nends = (int)ends_.size(); |
||||
int naxes = (int)axes_.size(), nsteps = (int)steps_.size(); |
||||
|
||||
CV_Assert_N(nstarts > 0, nstarts <= ndims, nstarts == nends); |
||||
CV_Assert(naxes == 0 || naxes == nstarts); |
||||
CV_Assert(nsteps == 0 || nsteps == nstarts); |
||||
|
||||
MatShape outShape = inpShape; |
||||
|
||||
for (int i = 0; i < ndims; i++) { |
||||
sliceMask[i] = false; |
||||
if (allStarts) |
||||
allStarts[i] = 0; |
||||
if (allEnds) |
||||
allEnds[i] = inpShape[i]; |
||||
if (allSteps) |
||||
allSteps[i] = 1; |
||||
} |
||||
|
||||
for (int i = 0; i < nstarts; i++) { |
||||
int axis = i; |
||||
if (!axes_.empty()) { |
||||
axis = axes_[i]; |
||||
axis = normalize_axis(axis, ndims); |
||||
if (sliceMask[axis]) { |
||||
CV_Error(Error::StsBadArg, "duplicate axis occurs in Slice"); |
||||
} |
||||
} |
||||
sliceMask[axis] = true; |
||||
int inpsz = inpShape[axis]; |
||||
int start = starts_[i]; |
||||
int end = ends_[i]; |
||||
int step = 1; |
||||
if (!steps_.empty()) |
||||
step = steps_[i]; |
||||
CV_Assert(step != 0); |
||||
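            // ONNX semantics: negative starts/ends count from the end of the axis; they are
            // then clamped so the slice stays within the axis (for a negative step the start
            // is clamped to inpsz-1 and the end may go down to -1).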
start = start < 0 ? std::max(start + inpsz, 0) : |
||||
std::min(start, inpsz - (step < 0)); |
||||
end = end < 0 ? std::max(end + inpsz, -(step < 0)) : |
||||
std::min(end, inpsz); |
||||
if (allStarts) |
||||
allStarts[axis] = start; |
||||
if (allEnds) |
||||
allEnds[axis] = end; |
||||
if (allSteps) |
||||
allSteps[axis] = step; |
||||
int outsz = step > 0 ? (end - start + step-1)/step : |
||||
(start - end - step-1)/(-step); |
||||
CV_Assert(outsz >= 0); |
||||
outShape[axis] = outsz; |
||||
} |
||||
|
||||
return outShape; |
||||
} |
||||
|
||||
bool getMemoryShapes(const std::vector<MatShape> &inputs, |
||||
const int, |
||||
std::vector<MatShape> &outputs, |
||||
std::vector<MatShape> &internals) const CV_OVERRIDE |
||||
{ |
||||
size_t ninputs = inputs.size(); |
||||
checkNumInputs(ninputs); |
||||
std::vector<int> tempStarts, tempEnds, tempAxes, steps; |
||||
const std::vector<int> *starts_ = &starts, *ends_ = &ends, *axes_ = &axes; |
||||
|
||||
if (ninputs > 1) { |
||||
Net::Impl* netimpl_ = getNetImpl(this); |
||||
Mat startsTensor = netimpl_->argTensor(this->inputs[1]); |
||||
tensorToIntVec(startsTensor, tempStarts); |
||||
starts_ = &tempStarts; |
||||
Mat endsTensor = netimpl_->argTensor(this->inputs[2]); |
||||
tensorToIntVec(endsTensor, tempEnds); |
||||
ends_ = &tempEnds; |
||||
if (ninputs > 3) { |
||||
Mat axesTensor = netimpl_->argTensor(this->inputs[3]); |
||||
tensorToIntVec(axesTensor, tempAxes); |
||||
axes_ = &tempAxes; |
||||
} |
||||
if (ninputs > 4) { |
||||
Mat stepsTensor = netimpl_->argTensor(this->inputs[4]); |
||||
tensorToIntVec(stepsTensor, steps); |
||||
} |
||||
} |
||||
MatShape outShape = getOutShape(inputs[0], *starts_, *ends_, *axes_, steps); |
||||
outputs.assign(1, outShape); |
||||
internals.clear(); |
||||
return true; |
||||
} |
||||
|
||||
void getTypes(const std::vector<MatType>& inputs, |
||||
const int requiredOutputs, |
||||
const int requiredInternals, |
||||
std::vector<MatType>& outputs, |
||||
std::vector<MatType>& internals) const CV_OVERRIDE |
||||
{ |
||||
size_t ninputs = inputs.size(); |
||||
checkNumInputs(ninputs); |
||||
outputs.assign(requiredOutputs, inputs[0]); |
||||
CV_Assert(requiredInternals == 0); |
||||
internals.clear(); |
||||
} |
||||
|
||||
void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE |
||||
{ |
||||
} |
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, |
||||
OutputArrayOfArrays outputs_arr, |
||||
OutputArrayOfArrays) CV_OVERRIDE |
||||
{ |
||||
CV_TRACE_FUNCTION(); |
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str()); |
||||
|
||||
Size size = inputs_arr.size(); |
||||
int ninputs = size.area(); |
||||
checkNumInputs(ninputs); |
||||
|
||||
int inpType = inputs_arr.type(0); |
||||
MatShape inpShape = inputs_arr.shape(0); |
||||
std::vector<int> tempStarts, tempEnds, tempAxes, steps; |
||||
const std::vector<int> *starts_ = &starts, *ends_ = &ends, *axes_ = &axes; |
||||
|
||||
if (ninputs > 1) { |
||||
Mat startsTensor = inputs_arr.getMat(1); |
||||
tensorToIntVec(startsTensor, tempStarts); |
||||
starts_ = &tempStarts; |
||||
Mat endsTensor = inputs_arr.getMat(2); |
||||
tensorToIntVec(endsTensor, tempEnds); |
||||
ends_ = &tempEnds; |
||||
if (ninputs > 3) { |
||||
Mat axesTensor = inputs_arr.getMat(3); |
||||
tensorToIntVec(axesTensor, tempAxes); |
||||
axes_ = &tempAxes; |
||||
} |
||||
if (ninputs > 4) { |
||||
Mat stepsTensor = inputs_arr.getMat(4); |
||||
tensorToIntVec(stepsTensor, steps); |
||||
} |
||||
} |
||||
int allStarts[MatShape::MAX_DIMS]; |
||||
int allEnds[MatShape::MAX_DIMS]; |
||||
int allSteps[MatShape::MAX_DIMS]; |
||||
MatShape outShape = getOutShape(inpShape, *starts_, *ends_, *axes_, steps, |
||||
allStarts, allEnds, allSteps); |
||||
|
||||
int outKind = outputs_arr.kind(); |
||||
|
||||
CV_Assert(outKind == _InputArray::STD_VECTOR_MAT || |
||||
outKind == _InputArray::STD_VECTOR_UMAT); |
||||
|
||||
if (outKind == _InputArray::STD_VECTOR_MAT) { |
||||
Mat inp = inputs_arr.getMat(0); |
||||
std::vector<Mat>& outs = outputs_arr.getMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(outShape, inpType); |
||||
runOp(inp, allStarts, allEnds, allSteps, outs[0]); |
||||
} else { |
||||
// [TODO] more efficient OpenCL implementation
|
||||
Mat inp = inputs_arr.getMat(0); |
||||
std::vector<UMat>& outs = outputs_arr.getUMatVecRef(); |
||||
outs.resize(1); |
||||
outs[0].fit(outShape, inpType); |
||||
Mat temp(outShape, inpType); |
||||
runOp(inp, allStarts, allEnds, allSteps, temp); |
||||
temp.copyTo(outs[0]); |
||||
} |
||||
} |
||||
|
||||
void runOp(const Mat& inp, const int* starts_, |
||||
const int* ends_, const int* steps_, Mat& out) |
||||
{ |
||||
slice(inp, starts_, ends_, steps_, out); |
||||
} |
||||
}; |
||||
|
||||
Ptr<Slice2Layer> Slice2Layer::create(const LayerParams& params) |
||||
{ |
||||
return Ptr<Slice2Layer>(new Slice2LayerImpl(params)); |
||||
} |
||||
|
||||
} |
||||
} |
@@ -0,0 +1,266 @@
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
#include "layers_common.hpp" |
||||
#include "../net_impl.hpp" |
||||
//#include "../op_cuda.hpp"
|
||||
//#include "../op_inf_engine.hpp"
|
||||
//#include "../ie_ngraph.hpp"
|
||||
//#include "../op_webnn.hpp"
|
||||
//#include "../op_timvx.hpp"
|
||||
//#include "../op_cann.hpp"
|
||||
|
||||
//#include <opencv2/dnn/shape_utils.hpp>
|
||||
|
||||
namespace cv |
||||
{ |
||||
namespace dnn |
||||
{ |
||||
|
||||
/*
    Split2 layer, as defined in ONNX specification (the ONNX operator is called Split):
    https://onnx.ai/onnx/operators/onnx__Split.html

    Opsets 1 to 13 are covered.
*/
||||
|
||||
// all outputs must be pre-allocated.
|
||||
// axis must be normalized
|
||||
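// Example (ONNX Split semantics): an input of shape [2, 6] with axis = 1 and
// split = [2, 4] produces two outputs of shapes [2, 2] and [2, 4]; the sizes
// along 'axis' must sum to the input size along that axis.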
static void split(const Mat& inp, std::vector<Mat>& outs, int axis)
{
    CV_Assert(inp.isContinuous());

    MatShape inpShape = inp.shape();
    int ndims = inpShape.dims;

    CV_Assert_N(0 <= axis, axis <= inp.dims);

    int nslices = 1;
    int inpType = inp.type();
    size_t esz = inp.elemSize();
    size_t sliceSize = esz;
    size_t inpStep = 0;
    size_t totalSize = inp.total()*esz;
    int outSize_a = 0;
    for (int i = ndims-1; i > axis; i--)
        sliceSize *= inpShape[i];
    inpStep = sliceSize*inpShape[axis];
    for (int i = 0; i < axis; i++)
        nslices *= inpShape[i];

    size_t noutputs = outs.size();
    for (size_t k = 0; k < noutputs; k++) {
        Mat& out = outs[k];
        MatShape outShape = out.shape();
        CV_Assert(out.isContinuous());
        CV_Assert(out.type() == inpType);
        CV_Assert(out.dims == ndims);
        for (int i = 0; i < ndims; i++) {
            if (i == axis)
                outSize_a += outShape[i];
            else {
                CV_Assert(inpShape[i] == outShape[i]);
            }
        }
    }

    CV_Assert(outSize_a == inpShape[axis]);

    parallel_for_(Range(0, (int)noutputs), [&](const Range& r) {
        for (int k = r.start; k < r.end; k++) {
            const uchar* inptr = inp.data;
            Mat& out_k = outs[k];
            uchar* outptr_k = out_k.data;
            int sz_a;
            for (int i = 0; i < k; i++) {
                sz_a = outs[i].size[axis];
                inptr += sliceSize*sz_a;
            }
            sz_a = out_k.size[axis];
            size_t sliceSize_k = sliceSize*sz_a;
            for (int i = 0; i < nslices; i++)
                memcpy(outptr_k + i*sliceSize_k, inptr + i*inpStep, sliceSize_k);
        }
    }, (totalSize > 1000000 ? noutputs : 1));
}

class Split2LayerImpl CV_FINAL : public Split2Layer
{
public:
    Split2LayerImpl(const LayerParams& params)
    {
        setParamsFrom(params);
        axis = params.get<int>("axis", 1);
        split = params.getVector<int>("split");
    }

    virtual bool supportBackend(int backendId) CV_OVERRIDE
    {
        return backendId == DNN_BACKEND_OPENCV;
    }

    void getOutShapes(const MatShape& inpShape, int axis_,
                      const std::vector<int>& split,
                      std::vector<MatShape>& outShapes) const
    {
        size_t noutputs = split.size();
        CV_Assert(noutputs == outputs.size());

        int inpDims = inpShape.dims;
        CV_Assert(0 <= axis_ && axis_ < inpDims);
        int totalSize_a = 0;

        outShapes.resize(noutputs);
        for (size_t i = 0; i < noutputs; i++) {
            MatShape outShape = inpShape;
            int s = split[i];
            CV_Assert(s >= 0);
            CV_Assert(s <= inpShape[axis_] - totalSize_a);
            outShape[axis_] = s;
            outShapes[i] = outShape;
            totalSize_a += s;
        }
    }

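    // Used when neither the 'split' attribute nor the 'split' input is given:
    // the input is divided into roughly equal chunks, e.g. totalSize = 7 with
    // noutputs = 3 gives split = {3, 3, 1} (the last chunk takes the remainder).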
    void makeDefaultSplit(int totalSize, size_t noutputs, std::vector<int>& split_) const
    {
        split_.resize(noutputs);
        int chunkSize = (int)((totalSize + noutputs - 1) / noutputs);
        for (size_t i = 0; i < noutputs; i++) {
            int sz_i = std::min(totalSize, chunkSize);
            split_[i] = sz_i;
            totalSize -= sz_i;
        }
    }

    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int noutputs,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        CV_Assert(noutputs == (int)this->outputs.size());

        size_t ninputs = inputs.size();
        CV_Assert(ninputs == 1 || ninputs == 2);

        MatShape inpShape = inputs[0];
        std::vector<int> tempSplit;
        const std::vector<int>* split_ = &split;
        int axis_ = normalize_axis(axis, inpShape.dims);

        if (ninputs == 2) {
            Net::Impl* netimpl_ = getNetImpl(this);
            Mat splitTensor = netimpl_->argTensor(this->inputs[1]);
            tensorToIntVec(splitTensor, tempSplit);
            split_ = &tempSplit;
        }
        else if (split.empty()) {
            makeDefaultSplit(inpShape[axis_], noutputs, tempSplit);
            split_ = &tempSplit;
        }

        getOutShapes(inputs[0], axis_, *split_, outputs);
        internals.clear();
        return true;
    }

    void getTypes(const std::vector<MatType>& inputs,
                  const int requiredOutputs,
                  const int requiredInternals,
                  std::vector<MatType>& outputs,
                  std::vector<MatType>& internals) const CV_OVERRIDE
    {
        size_t ninputs = inputs.size();
        CV_Assert(ninputs == 1 || ninputs == 2);
        outputs.assign(requiredOutputs, inputs[0]);
        CV_Assert(requiredInternals == 0);
        internals.clear();
    }

    void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
    {
    }

    void forward(InputArrayOfArrays inputs_arr,
                 OutputArrayOfArrays outputs_arr,
                 OutputArrayOfArrays) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        Size size = inputs_arr.size();
        int ninputs = size.area();
        int noutputs = (int)outputs.size();

        CV_Assert(ninputs == 1 || ninputs == 2);

        int inpType = inputs_arr.type(0);
        MatShape inpShape = inputs_arr.shape(0);
        std::vector<int> tempSplit;
        const std::vector<int>* split_ = &split;
        std::vector<MatShape> outShapes;

        int axis_ = normalize_axis(axis, inpShape.dims);

        if (ninputs == 2) {
            Mat splitTensor = inputs_arr.getMat(1);
            tensorToIntVec(splitTensor, tempSplit);
            split_ = &tempSplit;
        }
        else if (split.empty()) {
            makeDefaultSplit(inpShape[axis_], noutputs, tempSplit);
            split_ = &tempSplit;
        }
        getOutShapes(inpShape, axis_, *split_, outShapes);
        CV_Assert(outShapes.size() == (size_t)noutputs);

        int outKind = outputs_arr.kind();

        CV_Assert(outKind == _InputArray::STD_VECTOR_MAT ||
                  outKind == _InputArray::STD_VECTOR_UMAT);

        if (outKind == _InputArray::STD_VECTOR_MAT) {
            Mat inp = inputs_arr.getMat(0);
            std::vector<Mat>& outs = outputs_arr.getMatVecRef();
            outs.resize(noutputs);
            for (int i = 0; i < noutputs; i++) {
                MatShape outShape = outShapes[i];
                outs[i].fit(outShape, inpType);
            }
            runOp(inp, outs, axis_);
        } else {
            // [TODO] more efficient OpenCL implementation
            Mat inp = inputs_arr.getMat(0);
            std::vector<UMat>& outs = outputs_arr.getUMatVecRef();
            outs.resize(noutputs);

            std::vector<Mat> temps(noutputs);
            for (int i = 0; i < noutputs; i++) {
                MatShape outShape = outShapes[i];
                temps[i].fit(outShape, inpType);
            }
            runOp(inp, temps, axis_);
            for (int i = 0; i < noutputs; i++) {
                MatShape outShape = outShapes[i];
                outs[i].fit(outShape, inpType);
                temps[i].copyTo(outs[i]);
                temps[i].release();
            }
        }
    }

    void runOp(const Mat& inp, std::vector<Mat>& outs, int axis_)
    {
        cv::dnn::split(inp, outs, axis_);
    }
};

Ptr<Split2Layer> Split2Layer::create(const LayerParams& params)
{
    return Ptr<Split2Layer>(new Split2LayerImpl(params));
}

}
}
@@ -0,0 +1,159 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "../precomp.hpp"
#include "layers_common.hpp"
#include "../net_impl.hpp"
//#include "../op_cuda.hpp"
//#include "../op_inf_engine.hpp"
//#include "../ie_ngraph.hpp"
//#include "../op_webnn.hpp"
//#include "../op_timvx.hpp"
//#include "../op_cann.hpp"

//#include <opencv2/dnn/shape_utils.hpp>

namespace cv
{
namespace dnn
{

/*
    Squeeze layer, as defined in the ONNX specification:
    https://onnx.ai/onnx/operators/onnx__Squeeze.html

    Opsets 1 to 13 are covered.

    See the description in reshape2_layer.cpp
    for some common implementation details.
*/
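// Example: an input of shape [1, 3, 1, 5] with axes = [0, 2] is squeezed to [3, 5];
// with no axes given, all dimensions equal to 1 are removed. Only the shape changes,
// so forward() delegates to reshapeAndCopyFirst() and no element reordering is needed.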
class SqueezeLayerImpl CV_FINAL : public SqueezeLayer
{
public:
    SqueezeLayerImpl(const LayerParams& params)
    {
        setParamsFrom(params);
        axes = params.getVector<int>("axes");
    }

    virtual bool dynamicOutputShapes() const CV_OVERRIDE
    {
        Net::Impl* netimpl_ = getNetImpl(this);
        return inputs.size() == 2 && !netimpl_->isConstArg(inputs[1]);
    }

    virtual bool supportBackend(int backendId) CV_OVERRIDE
    {
        return backendId == DNN_BACKEND_OPENCV;
    }

    MatShape getOutShape(const MatShape& inpShape, const std::vector<int>& axes_) const
    {
        bool squeezeMask[MatShape::MAX_DIMS];

        if (axes_.empty()) {
            // remove all 1's
            for (int i = 0; i < inpShape.dims; i++)
                squeezeMask[i] = inpShape[i] == 1;
        } else {
            for (int i = 0; i < inpShape.dims; i++)
                squeezeMask[i] = false;
            for (int a: axes_) {
                int a_ = normalize_axis(a, inpShape.dims);
                if (squeezeMask[a_]) {
                    CV_Error_(Error::StsBadArg, ("duplicate squeezed axis #%d", a));
                }
                if (inpShape[a_] != 1) {
                    CV_Error_(Error::StsBadArg, ("squeezed axis #%d (== %d) != 1", a, inpShape[a_]));
                }
                squeezeMask[a_] = true;
            }
        }

        MatShape outShape(inpShape.dims);
        int j = 0;
        for (int i = 0; i < inpShape.dims; i++) {
            if (!squeezeMask[i])
                outShape[j++] = inpShape[i];
        }
        outShape.dims = j;
        return outShape;
    }

    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        CV_Assert(inputs.size() == 1 || inputs.size() == 2);
        MatShape outShape;
        std::vector<int> tempAxes;
        const std::vector<int>* axes_ = &axes;

        if (inputs.size() == 2)
        {
            CV_Assert(axes.empty()); // if we have a dedicated 'axes' input,
                                     // we should not have 'axes' attribute at the same time
            Net::Impl* netimpl_ = getNetImpl(this);
            Mat axesTensor = netimpl_->argTensor(this->inputs[1]);
            tensorToIntVec(axesTensor, tempAxes);
            axes_ = &tempAxes;
        }
        outputs.assign(1, getOutShape(inputs[0], *axes_));
        internals.clear();
        return true;
    }

    void getTypes(const std::vector<MatType>& inputs,
                  const int requiredOutputs,
                  const int requiredInternals,
                  std::vector<MatType>& outputs,
                  std::vector<MatType>& internals) const CV_OVERRIDE
    {
        size_t ninputs = inputs.size();
        CV_Assert(ninputs == 1 || ninputs == 2);
        outputs.assign(requiredOutputs, inputs[0]);
        CV_Assert(requiredInternals == 0);
        internals.clear();
    }

    void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
    {
    }

    void forward(InputArrayOfArrays inputs_arr,
                 OutputArrayOfArrays outputs_arr,
                 OutputArrayOfArrays) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        Size size = inputs_arr.size();
        int ninputs = size.area();
        CV_Assert(ninputs == 1 || ninputs == 2);

        MatShape inpShape = inputs_arr.shape(0);
        std::vector<int> tempAxes;
        const std::vector<int>* axes_ = &axes;

        if (ninputs == 2)
        {
            CV_Assert(axes.empty()); // if we have a dedicated 'axes' input,
                                     // we should not have 'axes' attribute at the same time
            Mat axesTensor = inputs_arr.getMat(1);
            tensorToIntVec(axesTensor, tempAxes);
            axes_ = &tempAxes;
        }
        MatShape outShape = getOutShape(inpShape, *axes_);
        reshapeAndCopyFirst(inputs_arr, outputs_arr, outShape);
    }
};

Ptr<SqueezeLayer> SqueezeLayer::create(const LayerParams& params)
{
    return Ptr<SqueezeLayer>(new SqueezeLayerImpl(params));
}

}
}
@@ -0,0 +1,304 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "../precomp.hpp"
#include "layers_common.hpp"
#include "../net_impl.hpp"

namespace cv
{
namespace dnn
{

static constexpr int TILE_MAX_DIMS = 6;

/*
    Tile layer, as defined in the ONNX specification:
    https://onnx.ai/onnx/operators/onnx__Tile.html

    Opsets 1 to 13 are covered.
*/

// out must be pre-allocated
// repeats_[] should contain as many elements as inp.dims (== out.dims)
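// Example (ONNX Tile semantics): an input of shape [2, 3] with repeats = [2, 3]
// yields an output of shape [4, 9]: the input block is replicated 2 times along
// axis 0 and 3 times along axis 1, i.e. total_repeats = 6 copies of the input.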
static void tile(const Mat& inp, const int* repeats_, Mat& out)
{
    MatShape inpshape_ = inp.shape();
    MatShape outshape_ = out.shape();
    const uchar* inpdata0 = inp.data;
    uchar* outdata0_ = out.data;

    int inpshape[TILE_MAX_DIMS];
    int outshape[TILE_MAX_DIMS];
    int repeats[TILE_MAX_DIMS];
    int64_t inpstep[TILE_MAX_DIMS];
    int64_t outstep[TILE_MAX_DIMS];

    int ndims = inp.dims, delta = TILE_MAX_DIMS - ndims;
    int64_t esz = inp.elemSize();
    int64_t total_size = 1, total_repeats = 1;

    CV_Assert(inp.isContinuous());
    CV_Assert(out.isContinuous());
    CV_Assert(inp.type() == out.type());
    CV_Assert(esz == 1 || esz == 2 || esz == 4 || esz == 8);
    CV_Assert(inp.dims == out.dims);
    CV_Assert(inp.dims <= TILE_MAX_DIMS);

    for (int i = 0; i < TILE_MAX_DIMS; i++) {
        inpshape[i] = outshape[i] = repeats[i] = 1;
    }

    for (int i = 0; i < ndims; i++) {
        inpshape[i + delta] = inpshape_[i];
        outshape[i + delta] = outshape_[i];
        repeats[i + delta] = repeats_[i];

        CV_Assert(inpshape_[i]*repeats_[i] == outshape_[i]);

        total_size *= outshape_[i];
        total_repeats *= repeats_[i];
    }

    for (int i = TILE_MAX_DIMS-1; i >= 0; i--) {
        if (i == TILE_MAX_DIMS-1)
            inpstep[i] = outstep[i] = 1;
        else {
            inpstep[i] = inpstep[i+1]*inpshape[i+1];
            outstep[i] = outstep[i+1]*outshape[i+1];
        }
    }

    int ntasks = 8;
    if (ntasks > total_repeats)
        ntasks = (int)total_repeats;
    if (total_size < 1000000)
        ntasks = 1;

    parallel_for_(Range(0, ntasks), [&](const Range& r)
    {
        int sz0 = inpshape[0], sz1 = inpshape[1], sz2 = inpshape[2];
        int sz3 = inpshape[3], sz4 = inpshape[4], sz5 = inpshape[5];

        int64_t outstep_prelast = outstep[TILE_MAX_DIMS-2];
        int64_t j0 = r.start*total_repeats/ntasks, j1 = r.end*total_repeats/ntasks;

        for (int64_t j = j0; j < j1; j++)
        {
            // convert raw tile index into n-dim tile index.
            // but we don't need this nd-index itself, we just need the
            // offset of the tile in the output tensor
            int64_t j_ = j, rawofs = 0;
            for (int k = TILE_MAX_DIMS-1; k >= 0; k--) {
                int r = repeats[k];
                int64_t q = j_ / r;
                rawofs += (j_ - q*r)*inpshape[k]*outstep[k];
                j_ = q;
            }

            #undef IMPL_COPY_TILE
            #define IMPL_COPY_TILE(T) \
                T* inpdata = (T*)inpdata0; \
                T* outdata0 = (T*)outdata0_ + rawofs; \
                for (int i0 = 0; i0 < sz0; i0++) { \
                for (int i1 = 0; i1 < sz1; i1++) { \
                for (int i2 = 0; i2 < sz2; i2++) { \
                for (int i3 = 0; i3 < sz3; i3++) { \
                    T* outdata = outdata0 + i0*outstep[0] + i1*outstep[1] + i2*outstep[2] + i3*outstep[3]; \
                    for (int i4 = 0; i4 < sz4; i4++, outdata += outstep_prelast, inpdata += sz5) { \
                        for (int i5 = 0; i5 < sz5; i5++) \
                            outdata[i5] = inpdata[i5]; \
                    } \
                }}}}

            if (esz == 1) {
                IMPL_COPY_TILE(uint8_t)
            } else if (esz == 2) {
                IMPL_COPY_TILE(uint16_t)
            } else if (esz == 4) {
                IMPL_COPY_TILE(uint32_t)
            } else {
                IMPL_COPY_TILE(uint64_t)
            }
        }
    }
    , ntasks);
}

class Tile2LayerImpl CV_FINAL : public Tile2Layer
{
public:
    Tile2LayerImpl(const LayerParams& params)
    {
        setParamsFrom(params);
    }

    virtual bool supportBackend(int backendId) CV_OVERRIDE
    {
        return backendId == DNN_BACKEND_OPENCV;
    }

    virtual bool dynamicOutputShapes() const CV_OVERRIDE
    {
        Net::Impl* netimpl_ = getNetImpl(this);
        CV_Assert(netimpl_);
        size_t ninputs = this->inputs.size();
        CV_Assert(ninputs == 2 || ninputs == 3);
        return !netimpl_->isConstArg(this->inputs[1]) ||
               (ninputs == 3 && !netimpl_->isConstArg(this->inputs[2]));
    }

    void getRepeats(const Mat& repeats_, const Mat& axes_, int ndims, int* repeats) const
    {
        int atype = axes_.type(), rtype = repeats_.type();
        CV_Assert(ndims <= TILE_MAX_DIMS);

        const int32_t* adata_i32 = nullptr;
        const int64_t* adata_i64 = nullptr;
        const int32_t* rdata_i32 = nullptr;
        const int64_t* rdata_i64 = nullptr;

        bool axismask[TILE_MAX_DIMS];

        CV_Assert(repeats_.dims == 1);
        CV_Assert(rtype == CV_32S || rtype == CV_64S);

        if (rtype == CV_32S)
            rdata_i32 = reinterpret_cast<const int32_t*>(repeats_.data);
        else
            rdata_i64 = reinterpret_cast<const int64_t*>(repeats_.data);

        if (!axes_.empty()) {
            CV_Assert(axes_.dims == 1);
            CV_Assert(atype == CV_32S || atype == CV_64S);
            CV_Assert(repeats_.total() == axes_.total());
            CV_Assert(axes_.total() <= (size_t)ndims);

            if (atype == CV_32S)
                adata_i32 = reinterpret_cast<const int32_t*>(axes_.data);
            else
                adata_i64 = reinterpret_cast<const int64_t*>(axes_.data);
        } else {
            CV_Assert(repeats_.total() == (size_t)ndims);
        }

        for (int i = 0; i < ndims; i++) {
            repeats[i] = 1;
            axismask[i] = false;
        }

        int nrepeats = (int)repeats_.total();
        for (int i = 0; i < nrepeats; i++) {
            int a = adata_i32 ? (int)adata_i32[i] : adata_i64 ? (int)adata_i64[i] : i;
            a = normalize_axis(a, ndims);
            if (axismask[a]) {
                CV_Error_(Error::StsBadArg, ("duplicate axis %d in Tile", a));
            }
            axismask[a] = true;
            int r = rdata_i32 ? (int)rdata_i32[i] : rdata_i64 ? (int)rdata_i64[i] : 1;
            repeats[a] = r;
        }
    }

    MatShape getOutShape(const MatShape& inpshape, const int* repeats) const
    {
        MatShape outshape = inpshape;
        for (int i = 0; i < outshape.dims; i++)
            outshape[i] *= repeats[i];
        return outshape;
    }

    bool getMemoryShapes(const std::vector<MatShape>& inputs,
                         const int,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        CV_Assert(!dynamicOutputShapes());

        size_t ninputs = inputs.size();
        CV_Assert(ninputs == (size_t)2 || ninputs == (size_t)3);
        Net::Impl* netimpl_ = getNetImpl(this);

        int repeats[TILE_MAX_DIMS];

        Mat repeatsTensor = netimpl_->argTensor(this->inputs[1]);
        Mat axesTensor;
        if (ninputs > 2)
            axesTensor = netimpl_->argTensor(this->inputs[2]);

        int ndims = inputs[0].dims;
        getRepeats(repeatsTensor, axesTensor, ndims, repeats);

        outputs.assign(1, getOutShape(inputs[0], repeats));
        internals.clear();
        return true;
    }

    void getTypes(const std::vector<MatType>& inputs,
                  const int requiredOutputs,
                  const int requiredInternals,
                  std::vector<MatType>& outputs,
                  std::vector<MatType>& internals) const CV_OVERRIDE
    {
        size_t ninputs = inputs.size();
        CV_Assert(ninputs == (size_t)2 || ninputs == (size_t)3);
        outputs.assign(requiredOutputs, inputs[0]);
        CV_Assert(requiredInternals == 0);
        internals.clear();
    }

    void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
    {
    }

    void forward(InputArrayOfArrays inputs_arr,
                 OutputArrayOfArrays outputs_arr,
                 OutputArrayOfArrays) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        Size size = inputs_arr.size();
        int ninputs = size.area();
        CV_Assert(ninputs == 2 || ninputs == 3);

        Mat inp = inputs_arr.getMat(0);
        Mat repeatsTensor = inputs_arr.getMat(1);
        Mat axesTensor;
        int repeats[TILE_MAX_DIMS];
        int inptype = inp.type();
        int ndims = inp.dims;

        if (ninputs > 2)
            axesTensor = inputs_arr.getMat(2);

        getRepeats(repeatsTensor, axesTensor, ndims, repeats);
        MatShape outshape = getOutShape(inp.shape(), repeats);

        auto kind = outputs_arr.kind();
        if (kind == _InputArray::STD_VECTOR_MAT) {
            std::vector<Mat>& outs = outputs_arr.getMatVecRef();
            outs.resize(1);
            outs[0].fit(outshape, inptype);
            tile(inp, repeats, outs[0]);
        } else if (kind == _InputArray::STD_VECTOR_UMAT) {
            std::vector<UMat>& outs = outputs_arr.getUMatVecRef();
            outs.resize(1);
            outs[0].fit(outshape, inptype);
            Mat temp(outshape, inptype);
            tile(inp, repeats, temp);
            temp.copyTo(outs[0]);
        } else {
            CV_Error(Error::StsNotImplemented, "");
        }
    }
};

Ptr<Tile2Layer> Tile2Layer::create(const LayerParams& params)
{
    return Ptr<Tile2Layer>(new Tile2LayerImpl(params));
}

}
}
@@ -0,0 +1,218 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "../precomp.hpp"
#include "layers_common.hpp"
#include "../net_impl.hpp"
//#include "../op_cuda.hpp"
//#include "../op_inf_engine.hpp"
//#include "../ie_ngraph.hpp"
//#include "../op_webnn.hpp"
//#include "../op_timvx.hpp"
//#include "../op_cann.hpp"

//#include <opencv2/dnn/shape_utils.hpp>

namespace cv
{
namespace dnn
{

/*
    Transpose layer, as defined in the ONNX specification:
    https://onnx.ai/onnx/operators/onnx__Transpose.html

    Opsets 1 to 23 are covered.
*/

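// Example (ONNX Transpose semantics): an input of shape [2, 3, 4] with perm = [0, 2, 1]
// gives an output of shape [2, 4, 3], i.e. outShape[i] = inpShape[perm[i]];
// an empty perm reverses the dimension order, e.g. [2, 3, 4] -> [4, 3, 2].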
static void transpose(const Mat& inp, const std::vector<int>& perm, Mat& out)
{
    enum {TRANSPOSE_MAX_DIMS=7};
    MatShape inpShape = inp.shape();
    MatShape outShape = out.shape();
    int ndims = inpShape.dims;
    size_t esz = inp.elemSize();
    CV_Assert(esz == 1 || esz == 2 || esz == 4 || esz == 8);

    int perm_[TRANSPOSE_MAX_DIMS];
    int inpShape_[TRANSPOSE_MAX_DIMS];
    int outShape_[TRANSPOSE_MAX_DIMS];
    size_t inpStep_[TRANSPOSE_MAX_DIMS];
    int delta = TRANSPOSE_MAX_DIMS - ndims;

    CV_Assert(ndims <= TRANSPOSE_MAX_DIMS);
    CV_Assert(inp.isContinuous());
    CV_Assert(out.isContinuous());

    for (int i = 0; i < TRANSPOSE_MAX_DIMS; i++) {
        perm_[i] = i;
        inpShape_[i] = outShape_[i] = 1;
        inpStep_[i] = 0;
    }
    inpStep_[TRANSPOSE_MAX_DIMS-1] = 1; // steps are measured in elements, not bytes

    for(int i = 0; i < ndims; i++) {
        int j = perm.empty() ? ndims - i - 1 : perm[i];
        if (j < 0)
            j += ndims;
        CV_Assert(0 <= j && j < ndims);
        perm_[i + delta] = j + delta;
        int inpsz = inpShape[j];
        int outsz = outShape[i];
        CV_Assert(inpsz == outsz);
        inpShape_[i + delta] = inpShape[i];
        outShape_[i + delta] = outShape[i];
    }

    for (int i = TRANSPOSE_MAX_DIMS-2; i >= 0; i--)
        inpStep_[i] = inpStep_[i+1]*inpShape_[i+1];

    int sz6 = outShape_[0], sz5 = outShape_[1];
    int sz4 = outShape_[2], sz3 = outShape_[3];
    int sz2 = outShape_[4], sz1 = outShape_[5], sz0 = outShape_[6];
    size_t p6 = inpStep_[perm_[0]], p5 = inpStep_[perm_[1]];
    size_t p4 = inpStep_[perm_[2]], p3 = inpStep_[perm_[3]];
    size_t p2 = inpStep_[perm_[4]], p1 = inpStep_[perm_[5]], p0 = inpStep_[perm_[6]];

    #undef CV_IMPLEMENT_TRANSPOSE
    #define CV_IMPLEMENT_TRANSPOSE(typ) \
        const typ* inptr0 = (const typ*)inp.data; \
        typ* outptr = (typ*)out.data; \
        for (int i6 = 0; i6 < sz6; i6++) { \
        for (int i5 = 0; i5 < sz5; i5++) { \
        for (int i4 = 0; i4 < sz4; i4++) { \
        for (int i3 = 0; i3 < sz3; i3++) { \
        for (int i2 = 0; i2 < sz2; i2++) { \
            for (int i1 = 0; i1 < sz1; i1++, outptr += sz0) { \
                int i0 = 0; \
                const typ* inptr = inptr0 + i6*p6 + i5*p5 + i4*p4 + i3*p3 + i2*p2 + i1*p1; \
                for (; i0 <= sz0 - 3; i0 += 3) { \
                    size_t ip0 = i0*p0; \
                    typ t0 = inptr[ip0]; \
                    typ t1 = inptr[ip0+p0]; \
                    typ t2 = inptr[ip0+p0*2]; \
                    outptr[i0] = t0; \
                    outptr[i0+1] = t1; \
                    outptr[i0+2] = t2; \
                } \
                for (; i0 < sz0; i0++) \
                    outptr[i0] = inptr[i0*p0]; \
            }}}}}}

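    // The element values are never interpreted, only moved, so it is enough to
    // dispatch on the element size and reuse the same copy loop for every
    // 1-, 2-, 4- and 8-byte type (e.g. CV_32F and CV_32S both take the 4-byte path).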
    if (esz == 4) {
        CV_IMPLEMENT_TRANSPOSE(int)
    } else if (esz == 2) {
        CV_IMPLEMENT_TRANSPOSE(short)
    } else if (esz == 1) {
        CV_IMPLEMENT_TRANSPOSE(char)
    } else if (esz == 8) {
        CV_IMPLEMENT_TRANSPOSE(int64_t)
    }
}

class TransposeLayerImpl CV_FINAL : public TransposeLayer
{
public:
    TransposeLayerImpl(const LayerParams& params)
    {
        setParamsFrom(params);
        perm = params.getVector<int>("perm");
    }

    virtual bool supportBackend(int backendId) CV_OVERRIDE
    {
        return backendId == DNN_BACKEND_OPENCV;
    }

    MatShape getOutShape(const MatShape& inpShape) const
    {
        MatShape outShape(inpShape.dims);
        CV_Assert(perm.empty() || perm.size() == (size_t)inpShape.dims);

        for (int i = 0; i < inpShape.dims; i++) {
            int j = perm.empty() ? inpShape.dims - i - 1 : perm[i];
            CV_Assert(0 <= j && j < inpShape.dims);
            outShape[i] = inpShape[j];
        }

        return outShape;
    }

    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        CV_Assert(inputs.size() == 1);
        outputs.assign(1, getOutShape(inputs[0]));
        internals.clear();
        return true;
    }

    void getTypes(const std::vector<MatType>& inputs,
                  const int requiredOutputs,
                  const int requiredInternals,
                  std::vector<MatType>& outputs,
                  std::vector<MatType>& internals) const CV_OVERRIDE
    {
        CV_Assert(inputs.size() == 1);
        outputs.assign(requiredOutputs, inputs[0]);
        CV_Assert(requiredInternals == 0);
        internals.clear();
    }

    void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
    {
    }

    void forward(InputArrayOfArrays inputs_arr,
                 OutputArrayOfArrays outputs_arr,
                 OutputArrayOfArrays) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        Size size = inputs_arr.size();
        int ninputs = size.area();
        CV_Assert(ninputs == 1);

        MatShape inpShape = inputs_arr.shape(0);
        MatShape outShape = getOutShape(inpShape);
        int inpType = inputs_arr.type(0);
        int outKind = outputs_arr.kind();

        CV_Assert(outKind == _InputArray::STD_VECTOR_MAT ||
                  outKind == _InputArray::STD_VECTOR_UMAT);

        if (outKind == _InputArray::STD_VECTOR_MAT) {
            Mat inp = inputs_arr.getMat(0);
            std::vector<Mat>& outs = outputs_arr.getMatVecRef();
            outs.resize(1);
            outs[0].fit(outShape, inpType);
            runOp(inp, outs[0]);
        } else {
            // [TODO] more efficient OpenCL implementation
            Mat inp = inputs_arr.getMat(0);
            std::vector<UMat>& outs = outputs_arr.getUMatVecRef();
            outs.resize(1);
            outs[0].fit(outShape, inpType);
            Mat temp(outShape, inpType);
            runOp(inp, temp);
            temp.copyTo(outs[0]);
        }
    }

    void runOp(const Mat& inp, Mat& out)
    {
        transpose(inp, perm, out);
    }
};

Ptr<TransposeLayer> TransposeLayer::create(const LayerParams& params)
{
    return Ptr<TransposeLayer>(new TransposeLayerImpl(params));
}

}
}
@@ -0,0 +1,156 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "../precomp.hpp"
#include "layers_common.hpp"
#include "../net_impl.hpp"
//#include "../op_cuda.hpp"
//#include "../op_inf_engine.hpp"
//#include "../ie_ngraph.hpp"
//#include "../op_webnn.hpp"
//#include "../op_timvx.hpp"
//#include "../op_cann.hpp"

//#include <opencv2/dnn/shape_utils.hpp>

namespace cv
{
namespace dnn
{

/*
    Unsqueeze layer, as defined in the ONNX specification:
    https://onnx.ai/onnx/operators/onnx__Unsqueeze.html

    Opsets 1 to 23 are covered.

    See the description in reshape2_layer.cpp
    for some common implementation details.
*/
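// Example: an input of shape [3, 5] with axes = [0, 3] is unsqueezed to [1, 3, 5, 1];
// the axes index into the *output* shape, which has inpShape.dims + axes.size() dimensions.
// As with Squeeze, only the shape changes, so forward() uses reshapeAndCopyFirst().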
class UnsqueezeLayerImpl CV_FINAL : public UnsqueezeLayer
{
public:
    UnsqueezeLayerImpl(const LayerParams& params)
    {
        setParamsFrom(params);
        axes = params.getVector<int>("axes");
    }

    virtual bool dynamicOutputShapes() const CV_OVERRIDE
    {
        Net::Impl* netimpl_ = getNetImpl(this);
        return inputs.size() == 2 && !netimpl_->isConstArg(inputs[1]);
    }

    virtual bool supportBackend(int backendId) CV_OVERRIDE
    {
        return backendId == DNN_BACKEND_OPENCV;
    }

    MatShape getOutShape(const MatShape& inpShape, const std::vector<int>& axes_) const
    {
        bool unsqueezeMask[MatShape::MAX_DIMS];

        int outDims = inpShape.dims + (int)axes_.size();
        CV_Assert(0 <= outDims && outDims <= MatShape::MAX_DIMS);

        for (int i = 0; i < outDims; i++)
            unsqueezeMask[i] = false;
        for (int a: axes_) {
            int a_ = normalize_axis(a, outDims);
            if (unsqueezeMask[a_]) {
                CV_Error_(Error::StsBadArg, ("duplicate unsqueezed axis #%d", a));
            }
            unsqueezeMask[a_] = true;
        }

        MatShape outShape(outDims);
        int j = 0;
        for (int i = 0; i < outDims; i++) {
            if (unsqueezeMask[i])
                outShape[i] = 1;
            else {
                CV_Assert(j < inpShape.dims);
                outShape[i] = inpShape[j++];
            }
        }
        return outShape;
    }

    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        CV_Assert((inputs.size() == 1 && !axes.empty()) ||
                  (inputs.size() == 2 && axes.empty()));
        MatShape outShape;
        std::vector<int> tempAxes;
        const std::vector<int>* axes_ = &axes;

        if (inputs.size() == 2)
        {
            Net::Impl* netimpl_ = getNetImpl(this);
            Mat axesTensor = netimpl_->argTensor(this->inputs[1]);
            tensorToIntVec(axesTensor, tempAxes);
            axes_ = &tempAxes;
        }
        outputs.assign(1, getOutShape(inputs[0], *axes_));
        internals.clear();
        return true;
    }

    void getTypes(const std::vector<MatType>& inputs,
                  const int requiredOutputs,
                  const int requiredInternals,
                  std::vector<MatType>& outputs,
                  std::vector<MatType>& internals) const CV_OVERRIDE
    {
        size_t ninputs = inputs.size();
        CV_Assert(ninputs == 1 || ninputs == 2);
        outputs.assign(requiredOutputs, inputs[0]);
        CV_Assert(requiredInternals == 0);
        internals.clear();
    }

    void finalize(InputArrayOfArrays, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
    {
    }

    void forward(InputArrayOfArrays inputs_arr,
                 OutputArrayOfArrays outputs_arr,
                 OutputArrayOfArrays) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        Size size = inputs_arr.size();
        int ninputs = size.area();
        CV_Assert((ninputs == 1 && !axes.empty()) ||
                  (ninputs == 2 && axes.empty()));

        MatShape inpShape = inputs_arr.shape(0);
        std::vector<int> tempAxes;
        const std::vector<int>* axes_ = &axes;

        if (ninputs == 2)
        {
            CV_Assert(axes.empty()); // if we have a dedicated 'axes' input,
                                     // we should not have 'axes' attribute at the same time
            Mat axesTensor = inputs_arr.getMat(1);
            tensorToIntVec(axesTensor, tempAxes);
            axes_ = &tempAxes;
        }
        MatShape outShape = getOutShape(inpShape, *axes_);
        reshapeAndCopyFirst(inputs_arr, outputs_arr, outShape);
    }
};

Ptr<UnsqueezeLayer> UnsqueezeLayer::create(const LayerParams& params)
{
    return Ptr<UnsqueezeLayer>(new UnsqueezeLayerImpl(params));
}

}
}