Update LSTM layer API, fix LSTM bug in tanh() implementation

The LSTM bug was caused by the different behaviour of std::tanh() and the
hand-crafted tanh() implemented via std::exp().
pull/707/head
Vitaliy Lyudvichenko 9 years ago
parent 44a8e81856
commit a3c6f1dcf2
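For reference, a minimal standalone check (not part of this commit; written only to illustrate the bug described in the message above) shows the divergence: the exp()-based formula overflows to inf/inf and yields NaN for large inputs, while std::tanh() saturates to 1.

#include <cmath>
#include <cstdio>

// Hand-crafted tanh as in the old implementation: (e^x - e^-x) / (e^x + e^-x)
static float tanhViaExp(float x)
{
    float ep = std::exp(x), em = std::exp(-x); // std::exp(100.f) overflows to +inf
    return (ep - em) / (ep + em);              // inf / inf evaluates to NaN
}

int main()
{
    float x = 100.0f;
    std::printf("std::tanh : %f\n", std::tanh(x));  // prints 1.000000
    std::printf("exp-based : %f\n", tanhViaExp(x)); // prints nan
    return 0;
}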
  1. modules/dnn/CMakeLists.txt (3 changes)
  2. modules/dnn/cmake/OpenCVFindCBLAS.cmake (11 changes)
  3. modules/dnn/include/opencv2/dnn/all_layers.hpp (48 changes)
  4. modules/dnn/include/opencv2/dnn/blob.hpp (17 changes)
  5. modules/dnn/include/opencv2/dnn/blob.inl.hpp (28 changes)
  6. modules/dnn/src/layers/op_blas.cpp (23 changes)
  7. modules/dnn/src/layers/recurrent_layers.cpp (233 changes)
  8. modules/dnn/test/test_layers.cpp (39 changes)

@@ -24,7 +24,7 @@ ocv_module_include_directories(include ${PROTOBUF_INCLUDE_DIR})
OCV_OPTION(${the_module}_WITH_BLAS "Use external BLAS library to speed up processing" OFF)
include(cmake/OpenCVFindCBLAS.cmake)
ocv_glob_module_sources(${PROTOBUF_SRCS} ${PROTOBUF_HDRS} ${CBLAS_H_PATH})
ocv_glob_module_sources(${PROTOBUF_SRCS} ${PROTOBUF_HDRS} ${CBLAS_H_PROXY_PATH})
ocv_create_module(${PROTOBUF_LIBRARIES})
ocv_add_samples()
ocv_add_accuracy_tests()
@@ -37,6 +37,7 @@ if(${the_module}_WITH_BLAS AND HAVE_BLAS)
add_definitions(-DHAVE_CBLAS=1)
ocv_module_include_directories(${${the_module}_BLAS_INCLUDE_DIR})
ocv_add_dependencies(${the_module} ${${the_module}_BLAS_LIBRARIES})
target_link_libraries(${the_module} ${${the_module}_BLAS_LIBRARIES})
if(${the_module}_BLAS_BINARIES)
ocv_install_target(${the_module} EXPORT ${the_module}_BLAS_BINARIES

@@ -1,6 +1,6 @@
macro(_find_file_in_dirs VAR NAME DIRS)
find_path(${VAR} ${NAME} ${DIRS} NO_DEFAULT_PATH)
set(${VAR} ${${VAR}})
set(${VAR} ${${VAR}}/${NAME})
unset(${VAR} CACHE)
endmacro()
@@ -16,7 +16,7 @@ if(${the_module}_WITH_BLAS)
endif()
if(NOT HAVE_BLAS)
include(cmake/OpenCVFindMKL.cmake)
if(MKL_FOUND)
if(MKL_FOUND AND FALSE)
set(BLAS_INCLUDE_DIR ${MKL_INCLUDE_DIRS})
set(BLAS_LIBRARIES ${MKL_LIBRARIES} )
set(BLAS_CBLAS_H "mkl_cblas.h" )
@@ -52,8 +52,9 @@ if(${the_module}_WITH_BLAS)
if(NOT CBLAS_H_PATH)
message(WARNING "CBLAS header '${${_bp}_CBLAS_H}' not found in '${${_bp}_INCLUDE_DIR}'")
endif()
add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/cblas.h #TARGET ${the_module} PRE_BUILD
COMMAND ${CMAKE_COMMAND} ARGS -E echo "\#include \"${CBLAS_H_PATH}\"" > ${CMAKE_CURRENT_BINARY_DIR}/cblas.h
COMMENT "Adding proxy cblas.h header")
set(CBLAS_H_PROXY_PATH ${CMAKE_CURRENT_BINARY_DIR}/opencv_cblas.hpp)
set(_include_str "\#include \"${CBLAS_H_PATH}\"")
file(WRITE ${CBLAS_H_PROXY_PATH} ${_include_str})
endif()
endif()
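For illustration, the generated proxy header is just a one-line wrapper around whichever CBLAS header was detected; with MKL, for example, it would contain something like the following (the path is hypothetical):

// opencv_cblas.hpp, written at configure time by file(WRITE ${CBLAS_H_PROXY_PATH} ...)
#include "/opt/intel/mkl/include/mkl_cblas.h"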

@@ -101,8 +101,8 @@ namespace dnn
@f$W_{x?} \in R^{N_c \times N_x}@f$, @f$W_{h?} \in R^{N_c \times N_h}@f$, @f$b_? \in R^{N_c}@f$.
For simplicity and performance purposes we use @f$ W_x = [W_{xi}; W_{xf}; W_{xo}; W_{xg}] @f$
(i.e. @f$W_x@f$ is the vertical concatenation of the @f$ W_{x?} @f$), @f$ W_x \in R^{4N_c x N_x} @f$.
The same holds for @f$ W_h = [W_{hi}; W_{hf}; W_{ho}; W_{hg}], W_h \in R^{4N_c x N_h} @f$
(i.e. @f$W_x@f$ is the vertical concatenation of the @f$ W_{x?} @f$), @f$ W_x \in R^{4N_c \times N_x} @f$.
The same holds for @f$ W_h = [W_{hi}; W_{hf}; W_{ho}; W_{hg}], W_h \in R^{4N_c \times N_h} @f$
and for @f$ b = [b_i; b_f; b_o; b_g]@f$, @f$b \in R^{4N_c} @f$.
@param Wh is the matrix defining how the previous output is transformed to internal gates (i.e. according to the above-mentioned notation it is @f$ W_h @f$)
@@ -111,16 +111,44 @@ namespace dnn
*/
virtual void setWeights(const Blob &Wh, const Blob &Wx, const Blob &b) = 0;
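A minimal usage sketch of this method (all sizes below are illustrative, not taken from the diff); note that the implementation further down additionally requires N_h = N_c, since it asserts Wh.size(0) == 4*Wh.size(1):

// Illustrative sizes: numOut (= N_c = N_h) = 4, numInp (= N_x) = 3
int numOut = 4, numInp = 3;

Blob Wh, Wx, b;
Wh.create(BlobShape(4*numOut, numOut), CV_32F); // [W_hi; W_hf; W_ho; W_hg] stacked vertically
Wx.create(BlobShape(4*numOut, numInp), CV_32F); // [W_xi; W_xf; W_xo; W_xg] stacked vertically
b.create(BlobShape(1, 4*numOut), CV_32F);       // [b_i; b_f; b_o; b_g]
// ... fill the three blobs with trained values ...

Ptr<LSTMLayer> lstm = LSTMLayer::create();
lstm->setWeights(Wh, Wx, b);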
/** In the common case it uses three inputs (@f$x_t@f$, @f$h_{t-1}@f$ and @f$c_{t-1}@f$) to compute two outputs (@f$h_t@f$ and @f$c_t@f$).
/** @brief Set @f$ h_{t-1} @f$ value that will be used in subsequent forward() calls.
* @details By default @f$ h_{t-1} @f$ is initialized with zeros and updated after each forward() call.
*/
virtual void setH(const Blob &H) = 0;
/** @brief Returns current @f$ h_{t-1} @f$ value (deep copy). */
virtual Blob getH() const = 0;
@param input could contain three inputs: @f$x_t@f$, @f$h_{t-1}@f$ and @f$c_{t-1}@f$.
@param output contains computed outputs: @f$h_t@f$ and @f$c_t@f$.
/** @brief Set @f$ c_{t-1} @f$ value that will be used in subsequent forward() calls.
* @details By default @f$ c_{t-1} @f$ is initialized with zeros and updated after each forward() call.
*/
virtual void setC(const Blob &C) = 0;
/** @brief Returns current @f$ c_{t-1} @f$ value (deep copy). */
virtual Blob getC() const = 0;
/** @brief Specifies whether the first dimension of the input blob is interpreted as the timestamp dimension or as the sample dimension.
*
* If the flag is set to true then the shape of the input blob will be interpreted as [`T`, `N`, `[data dims]`], where `T` is the number of timestamps and `N` is the number of independent streams.
* In this case each forward() call will iterate through `T` timestamps and update the layer's state `T` times.
*
* If the flag is set to false then the shape of the input blob will be interpreted as [`N`, `[data dims]`].
* In this case each forward() call will make one iteration and produce one timestamp with shape [`N`, `[out dims]`].
*/
virtual void setUseTimstampsDim(bool use = true) = 0;
The first input @f$x_t@f$ is required.
The second and third inputs are optional: if they are not set then the layer will use internal @f$h_{t-1}@f$ and @f$c_{t-1}@f$ from previous calls,
but at the first call they will be filled with zeros.
The size of the last dimension of @f$x_t@f$ must be @f$N_x@f$ (@f$N_h@f$ for @f$h_{t-1}@f$ and @f$N_c@f$ for @f$c_{t-1}@f$).
The sizes of the remaining dimensions may be arbitrary, but they must be consistent among @f$x_t@f$, @f$h_{t-1}@f$ and @f$c_{t-1}@f$.
/** @brief If this flag is set to true then the layer will produce @f$ c_t @f$ as its second output.
* @details The shape of the second output is the same as the first output.
*/
virtual void setProduceCellOutput(bool produce = false) = 0;
/** In the common case it uses a single input with @f$x_t@f$ values to compute the output(s) @f$h_t@f$ (and @f$c_t@f$).
* @param input should contain packed values @f$x_t@f$
* @param output contains computed outputs: @f$h_t@f$ (and @f$c_t@f$ if the setProduceCellOutput() flag was set to true).
*
* If setUseTimstampsDim() is set to true then @p input[0] should have at least two dimensions with the following shape: [`T`, `N`, `[data dims]`],
* where `T` is the number of timestamps and `N` is the number of independent streams (i.e. x_{t_0 + t}^{stream} is @p input[0][t, stream, ...]).
*
* If setUseTimstampsDim() is set to false then @p input[0] should contain a single timestamp; its shape should have the form [`N`, `[data dims]`] with at least one dimension
* (i.e. x_{t}^{stream} = @p input[0][stream, ...]).
*/
void forward(std::vector<Blob*> &input, std::vector<Blob> &output);
};
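Continuing the sketch from setWeights() above, a hedged example of the full call sequence with the default setUseTimstampsDim(true) layout (sizes again illustrative): one input blob of shape [T, N, N_x] yields an output of shape [T, N, N_c], and the running state can be read back or overridden between calls.

// Illustrative sizes: T = 5 timestamps, N = 2 independent streams
Blob x;
x.create(BlobShape(5, 2, numInp), CV_32F); // [T, N, data dims]
// ... fill x with input values ...

std::vector<Blob*> inputs(1, &x);
std::vector<Blob> outputs;
lstm->allocate(inputs, outputs); // output[0] gets shape [T, N, numOut]
lstm->forward(inputs, outputs);  // iterates over the T timestamps, updating h and c at each step

Blob h = lstm->getH(); // deep copy of the final h_t; may be fed back via setH() before the next call
Blob c = lstm->getC(); // likewise for the cell state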

@@ -107,6 +107,18 @@ namespace dnn
bool operator== (const BlobShape &r) const;
/** @brief Concatenates two shapes */
BlobShape operator+ (const BlobShape &r) const;
/** @brief Returns shape of passed Mat. */
static BlobShape like(const Mat &m);
/** @brief Returns shape of passed Mat. */
static BlobShape like(const UMat &m);
#ifdef CV_CXX_MOVE_SEMANTICS
//TBD
#endif
private:
cv::AutoBuffer<int,4> sz;
};
@@ -228,6 +240,11 @@ namespace dnn
*/
Blob &reshape(const BlobShape &shape);
/** @brief Changes shape of the blob without copying the data.
* @returns shallow copy of original blob with new shape.
*/
Blob reshaped(const BlobShape &newShape) const;
/** @brief Returns type of the blob. */
int type() const;

@@ -185,6 +185,16 @@ inline bool BlobShape::operator==(const BlobShape &r) const
return this->equal(r);
}
inline BlobShape BlobShape::like(const Mat &m)
{
return BlobShape(m.dims, (const int*)m.size);
}
inline BlobShape BlobShape::like(const UMat &m)
{
return BlobShape(m.dims, (const int*)m.size);
}
CV_EXPORTS std::ostream &operator<< (std::ostream &stream, const BlobShape &shape);
/////////////////////////////////////////////////////////////////////
@@ -277,6 +287,17 @@ inline BlobShape Blob::shape() const
return BlobShape(dims(), sizes());
}
inline BlobShape BlobShape::operator+(const BlobShape &r) const
{
BlobShape newShape(this->dims() + r.dims(), (int*)NULL);
for (int i = 0; i < this->dims(); i++)
newShape[i] = (*this)[i];
for (int i = 0; i < r.dims(); i++)
newShape[this->dims() + i] = r[i];
return newShape;
}
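For example (dimensions illustrative), concatenation simply appends the right-hand dimensions to the left-hand ones:

BlobShape a(2, 3, 4), b(5, 6, 7);
BlobShape c = a + b;                   // 6-dimensional shape [2, 3, 4, 5, 6, 7]
CV_Assert(c.dims() == 6 && c[3] == 5); // right-hand dims start at index 3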
inline bool Blob::equalShape(const Blob &other) const
{
if (this->dims() != other.dims())
@@ -366,6 +387,13 @@ inline Blob &Blob::reshape(const BlobShape &newShape)
return *this;
}
inline Blob Blob::reshaped(const BlobShape &newShape) const
{
Blob res(*this); //also, res.shareFrom(*this) could be used
res.reshape(newShape);
return res;
}
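A short usage note (shapes illustrative): this is what lets the LSTM forward pass further below view a [T, N, N_x] input as a flat [T*N, N_x] matrix without copying the data.

Blob inp, flat;
inp.create(BlobShape(5, 2, 3), CV_32F); // [T, N, data dims]
flat = inp.reshaped(BlobShape(5*2, 3)); // shallow 2D view sharing the same data
CV_Assert(flat.shape() == BlobShape(10, 3));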
}
}

@@ -1,9 +1,11 @@
#include "op_blas.hpp"
#if HAVE_CBLAS
#include "cblas.h"
#include "opencv_cblas.hpp"
#endif
#include <iostream>
namespace cv
{
namespace dnn
@@ -14,18 +16,19 @@ void gemm(InputArray A, InputArray B, double alpha, InputOutputArray C, double b
cv::gemm(A, B, alpha, C, beta, C, flags);
}
inline void SwapRowCols(const Mat &A, int &rows, int &cols, bool transA)
inline void SwapRowCols(const Mat &A, int &rows, int &cols, bool isTrans)
{
rows = (transA) ? A.cols : A.rows;
cols = (transA) ? A.rows : A.cols;
CV_DbgAssert(A.dims == 2);
rows = (isTrans) ? A.cols : A.rows;
cols = (isTrans) ? A.rows : A.cols;
}
void gemmCPU(const Mat &A, const Mat &B, double alpha, Mat &C, double beta, int flags /*= 0*/)
{
#if HAVE_CBLAS
int transA = flags & GEMM_1_T;
int transB = flags & GEMM_2_T;
int transC = flags & GEMM_3_T;
bool transA = static_cast<bool>(flags & GEMM_1_T);
bool transB = static_cast<bool>(flags & GEMM_2_T);
bool transC = static_cast<bool>(flags & GEMM_3_T);
int Arows, Acols, Brows, Bcols, Crows, Ccols;
SwapRowCols(A, Arows, Acols, transA);
@@ -34,9 +37,9 @@ void gemmCPU(const Mat &A, const Mat &B, double alpha, Mat &C, double beta, int
CV_DbgAssert(!(flags & GEMM_3_T));
CV_Assert(Acols == Brows && Arows == Crows && Bcols == Ccols);
CV_DbgAssert(A.isContinuous() && B.isContinuous() && C.isContinuous());
CV_DbgAssert(A.type() == CV_32F || A.type() == CV_64F);
CV_DbgAssert(A.type() == B.type() && B.type() == C.type());
CV_Assert(A.isContinuous() && B.isContinuous() && C.isContinuous());
CV_Assert(A.type() == CV_32F || A.type() == CV_64F);
CV_Assert(A.type() == B.type() && B.type() == C.type());
if (C.type() == CV_32F)
{
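The hunk is truncated here. For the CV_32F branch one would expect a call along these lines (a sketch only, assuming continuous row-major matrices; the actual body is not shown in this diff):

cblas_sgemm(CblasRowMajor,
            transA ? CblasTrans : CblasNoTrans,
            transB ? CblasTrans : CblasNoTrans,
            Arows, Bcols, Acols,                  // M, N, K after the optional transposes
            (float)alpha, A.ptr<float>(), A.cols, // lda/ldb/ldc are the stored row strides
            B.ptr<float>(), B.cols,
            (float)beta, C.ptr<float>(), C.cols);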

@@ -43,35 +43,118 @@
#include "recurrent_layers.hpp"
#include "op_blas.hpp"
#include <iostream>
#include <cmath>
namespace cv
{
namespace dnn
{
template<typename Dtype>
static void tanh(const Mat &src, Mat &dst)
{
MatConstIterator_<Dtype> itSrc = src.begin<Dtype>();
MatIterator_<Dtype> itDst = dst.begin<Dtype>();
for (; itSrc != src.end<Dtype>(); itSrc++, itDst++)
*itDst = std::tanh(*itSrc);
}
static void tanh(const Mat &src, Mat &dst)
{
dst.create(src.dims, (const int*)src.size, src.type());
if (src.type() == CV_32F)
tanh<float>(src, dst);
else if (src.type() == CV_64F)
tanh<double>(src, dst);
else
CV_Error(Error::StsUnsupportedFormat, "Function supports only floating point types");
}
static void sigmoid(const Mat &src, Mat &dst)
{
cv::exp(-src, dst);
cv::pow(1 + dst, -1, dst);
}
class LSTMLayerImpl : public LSTMLayer
{
int numOut, numTimeStamps, numSamples, numInp;
Mat hInternal, cInternal;
Mat gates, dummyOnes;
int dtype;
bool allocated;
bool useTimestampDim;
bool produceCellOutput;
public:
LSTMLayerImpl()
{
type = "LSTM";
useTimestampDim = true;
produceCellOutput = false;
allocated = false;
}
int nH, nX, nC, numSamples;
Mat prevH, prevC;
Mat gates, dummyOnes;
void setUseTimstampsDim(bool use)
{
CV_Assert(!allocated);
useTimestampDim = use;
}
void setProduceCellOutput(bool produce)
{
CV_Assert(!allocated);
produceCellOutput = produce;
}
void setC(const Blob &C)
{
CV_Assert(!allocated || C.total() == cInternal.total());
C.matRefConst().copyTo(cInternal);
}
void setH(const Blob &H)
{
CV_Assert(!allocated || H.total() == hInternal.total());
H.matRefConst().copyTo(hInternal);
}
Blob getC() const
{
CV_Assert(!cInternal.empty());
//TODO: add convenient Mat -> Blob constructor
Blob res;
res.fill(BlobShape::like(cInternal), cInternal.type(), cInternal.data);
return res;
}
Blob getH() const
{
CV_Assert(!hInternal.empty());
Blob res;
res.fill(BlobShape::like(hInternal), hInternal.type(), hInternal.data);
return res;
}
void setWeights(const Blob &Wh, const Blob &Wx, const Blob &bias)
{
CV_Assert(Wh.dims() == 2 && Wx.dims() == 2);
CV_Assert(Wh.size(0) == Wx.size(0) && Wh.size(0) % 4 == 0);
CV_Assert(Wh.size(0) == Wx.size(0));
CV_Assert(Wh.size(0) == 4*Wh.size(1));
CV_Assert(Wh.size(0) == (int)bias.total());
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
blobs.resize(3);
blobs[0] = Wh;
blobs[1] = Wx;
blobs[2] = bias;
blobs[2].reshape(BlobShape(1, (int)bias.total()));
}
void allocate(const std::vector<Blob*> &input, std::vector<Blob> &output)
@@ -79,103 +162,83 @@ public:
CV_Assert(blobs.size() == 3);
Blob &Wh = blobs[0], &Wx = blobs[1];
nH = Wh.size(1);
nX = Wx.size(1);
nC = Wh.size(0) / 4;
numOut = Wh.size(1);
numInp = Wx.size(1);
CV_Assert(input.size() >= 1 && input.size() <= 3);
CV_Assert(input[0]->size(-1) == nX);
CV_Assert(input.size() == 1);
CV_Assert(input[0]->dims() > 2 && (int)input[0]->total(2) == numInp);
BlobShape inpShape = input[0]->shape();
numSamples = input[0]->total(0, input[0]->dims()-1);
numTimeStamps = input[0]->size(0);
numSamples = input[0]->size(1);
dtype = input[0]->type();
BlobShape hShape = inpShape;
hShape[-1] = nH;
BlobShape cShape = inpShape;
cShape[-1] = nC;
CV_Assert(dtype == CV_32F || dtype == CV_64F);
CV_Assert(Wh.type() == dtype);
BlobShape outShape(numTimeStamps, numSamples, numOut);
output.resize(2);
output[0].create(hShape, input[0]->type());
output[1].create(cShape, input[0]->type());
output[0].create(outShape, dtype);
output[1].create(outShape, dtype);
if (input.size() < 2)
{
prevH.create(numSamples, nH, input[0]->type());
prevH.setTo(0);
}
else
CV_Assert(input[1]->shape() == hShape);
hInternal.create(numSamples, numOut, dtype);
hInternal.setTo(0);
if (input.size() < 3)
{
prevC.create(numSamples, nC, input[0]->type());
prevC.setTo(0);
}
else
CV_Assert(input[2]->shape() == cShape);
cInternal.create(numSamples, numOut, dtype);
cInternal.setTo(0);
gates.create(numSamples, 4*numOut, dtype);
gates.create(numSamples, 4*nC, input[0]->type());
dummyOnes.create(numSamples, 1, input[0]->type());
dummyOnes.create(numSamples, 1, dtype);
dummyOnes.setTo(1);
}
Mat ep, em;
void tanh(Mat &x, Mat &d)
{
//TODO: two exp() is bad idea
cv::exp(-x, em);
cv::exp( x, ep);
cv::divide(ep - em, ep + em, d);
allocated = true;
}
void sigmoid(Mat &x)
void forward(std::vector<Blob*> &input, std::vector<Blob> &output)
{
cv::exp(-x, x);
cv::pow(1 + x, -1, x);
}
const Mat &Wh = blobs[0].matRefConst();
const Mat &Wx = blobs[1].matRefConst();
const Mat &bias = blobs[2].matRefConst();
void forward(std::vector<Blob*> &input, std::vector<Blob> &output)
int numSamplesTotal = numTimeStamps*numSamples;
Mat xTs = input[0]->reshaped(BlobShape(numSamplesTotal, numInp)).matRefConst();
BlobShape outMatShape(numSamplesTotal, numOut);
Mat hOutTs = output[0].reshaped(outMatShape).matRef();
Mat cOutTs = (produceCellOutput) ? output[1].reshaped(outMatShape).matRef() : Mat();
for (int ts = 0; ts < numTimeStamps; ts++)
{
CV_DbgAssert(blobs.size() == 3);
const Mat &Wh = blobs[0].matRefConst(), &Wx = blobs[1].matRefConst();
Mat bias = blobs[2].matRefConst().reshape(1, 1);
CV_DbgAssert(Wh.type() == CV_32F && Wx.type() == CV_32F && bias.type() == CV_32F);
int szx[] = { numSamples, nX };
int szc[] = { numSamples, nC };
Mat xCurr = input[0]->matRefConst().reshape(1, 2, szx);
Mat hPrev = (input.size() >= 2) ? input[1]->matRefConst().reshape(1, 2, szc) : prevH;
Mat cPrev = (input.size() >= 3) ? input[2]->matRefConst().reshape(1, 2, szc) : prevC;
CV_Assert(xCurr.type() == CV_32F && hPrev.type() == CV_32F && cPrev.type() == CV_32F);
Mat hCurr = output[0].matRef().reshape(1, 2, szc);
Mat cCurr = output[1].matRef().reshape(1, 2, szc);
CV_Assert(hCurr.type() == CV_32F && cCurr.type() == CV_32F);
Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
Mat xCurr = xTs.rowRange(curRowRange);
gemmCPU(xCurr, Wx, 1, gates, 0, GEMM_2_T); // Wx * x_t
gemmCPU(hPrev, Wh, 1, gates, 1, GEMM_2_T); //+Wh * h_{t-1}
gemmCPU(hInternal, Wh, 1, gates, 1, GEMM_2_T); //+Wh * h_{t-1}
gemmCPU(dummyOnes, bias, 1, gates, 1); //+b
Mat gatesDiv = gates.reshape(1, 4*numSamples);
Mat getesIFO = gatesDiv(Range(0, 3*numSamples), Range::all());
Mat gateI = gatesDiv(Range(0*numSamples, 1*numSamples), Range::all());
Mat gateF = gatesDiv(Range(1*numSamples, 2*numSamples), Range::all());
Mat gateO = gatesDiv(Range(2*numSamples, 3*numSamples), Range::all());
Mat gateG = gatesDiv(Range(3*numSamples, 4*numSamples), Range::all());
Mat getesIFO = gates.colRange(0, 3*numOut);
Mat gateI = gates.colRange(0*numOut, 1*numOut);
Mat gateF = gates.colRange(1*numOut, 2*numOut);
Mat gateO = gates.colRange(2*numOut, 3*numOut);
Mat gateG = gates.colRange(3*numOut, 4*numOut);
sigmoid(getesIFO);
sigmoid(getesIFO, getesIFO);
tanh(gateG, gateG);
cv::add(gateF.mul(cPrev), gateI.mul(gateG), cCurr);
//compute c_t
cv::multiply(gateF, cInternal, gateF); // f_t (*) c_{t-1}
cv::multiply(gateI, gateG, gateI); // i_t (*) g_t
cv::add(gateF, gateI, cInternal); // c_t = f_t (*) c_{t-1} + i_t (*) g_t
tanh(cCurr, hCurr);
cv::multiply(gateO, hCurr, hCurr);
//compute h_t
tanh(cInternal, hInternal);
cv::multiply(gateO, hInternal, hInternal);
//save answers for next iteration
if (input.size() <= 2)
hCurr.copyTo(hPrev);
if (input.size() <= 3)
cCurr.copyTo(cPrev);
//save results in output blobs
hInternal.copyTo(hOutTs.rowRange(curRowRange));
if (produceCellOutput)
cInternal.copyTo(cOutTs.rowRange(curRowRange));
}
}
};
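For readers cross-checking the loop above against the textbook formulation: with \odot denoting element-wise multiplication, each timestamp computes

i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)
f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)
o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
g_t = \tanh(W_{xg} x_t + W_{hg} h_{t-1} + b_g)
c_t = f_t \odot c_{t-1} + i_t \odot g_t
h_t = o_t \odot \tanh(c_t)

In the code the four gates live side by side as column blocks of width numOut inside `gates`, and gateF/gateI are overwritten in place to avoid extra allocations.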
@@ -193,6 +256,7 @@ void LSTMLayer::forward(std::vector<Blob*>&, std::vector<Blob>&)
class RNNLayerImpl : public RNNLayer
{
int nX, nH, nO, nSamples;
int dtype;
Mat Whh, Wxh, bh;
Mat Who, bo;
Mat hPrevInternal, dummyBiasOnes;
@@ -210,7 +274,6 @@ public:
CV_Assert(W_hh.size(0) == W_xh.size(0) && W_hh.size(0) == W_hh.size(1) && (int)b_h.total() == W_xh.size(0));
CV_Assert(W_ho.size(0) == (int)b_o.total());
CV_Assert(W_ho.size(1) == W_hh.size(1));
//TODO: Check type
blobs.resize(5);
blobs[0] = W_hh;
@@ -262,16 +325,6 @@ public:
bo = bo.reshape(1, 1); //is 1 x nO mat
}
//in-place tanh function
static void tanh(Mat &x) // 2 / (1 + e^(-2x)) - 1
{
x.convertTo(x, x.type(), -2); // -2x
cv::exp(x, x); // e^(-2x)
x.convertTo(x, x.type(), 1, 1); // 1 + e^(-2x)
cv::pow(x, -1, x); // 1 / (1 + e^(-2x))
x.convertTo(x, x.type(), 2, -1);// 2 / (1 + e^(-2x)) - 1
}
void forward(std::vector<Blob*> &input, std::vector<Blob> &output)
{
Mat xCurr = input[0]->matRefConst();
@@ -292,11 +345,11 @@ public:
gemmCPU(hPrev, Whh, 1, hCurr, 0, GEMM_2_T); // W_{hh} * h_{prev}
gemmCPU(xCurr, Wxh, 1, hCurr, 1, GEMM_2_T); //+W_{xh} * x_{curr}
gemmCPU(dummyBiasOnes, bh, 1, hCurr, 1); //+bh
tanh(hCurr);
tanh(hCurr, hCurr);
gemmCPU(hPrev, Who, 1, oCurr, 0, GEMM_2_T); // W_{ho} * h_{prev}
gemmCPU(dummyBiasOnes, bo, 1, oCurr, 1); //+b_o
tanh(oCurr);
tanh(oCurr, oCurr);
if (input.size() < 2) //save h_{prev}
hCurr.copyTo(hPrevInternal);

@@ -177,6 +177,23 @@ TEST(Layer_Test_Reshape_Split_Slice, Accuracy)
normAssert(input, output);
}
enum RunLayerMode
{
ALLOC_ONLY = 1,
FORWARD_ONLY = 2,
ALLOC_AND_FORWARD = 3
};
void runLayer(Ptr<Layer> layer, std::vector<Blob> &inpBlobs, std::vector<Blob> &outBlobs, int mode=ALLOC_AND_FORWARD)
{
std::vector<Blob*> inpPtrs(inpBlobs.size());
for (size_t i = 0; i < inpBlobs.size(); i++)
inpPtrs[i] = &inpBlobs[i];
if (mode & ALLOC_ONLY) layer->allocate(inpPtrs, outBlobs);
if (mode & FORWARD_ONLY) layer->forward(inpPtrs, outBlobs);
}
class Layer_LSTM_Test : public ::testing::Test
{
public:
@@ -233,6 +250,28 @@ TEST_F(Layer_LSTM_Test, BasicTest_2)
EXPECT_EQ(outputs[1].shape(), BlobShape(1, 2, 3, Nc));
}
TEST(Layer_LSTM_Test_Accuracy_Reference_with_, CaffeRecurrent)
{
Ptr<LSTMLayer> layer = LSTMLayer::create();
Blob Wx = blobFromNPY(_tf("lstm.prototxt.w_0.npy"));
Blob Wh = blobFromNPY(_tf("lstm.prototxt.w_2.npy"));
Blob b = blobFromNPY(_tf("lstm.prototxt.w_1.npy"));
layer->setWeights(Wh, Wx, b);
Blob inp = blobFromNPY(_tf("blob.npy"));
std::vector<Blob> inputs(1, inp), outputs;
runLayer(layer, inputs, outputs, ALLOC_ONLY | FORWARD_ONLY);
Blob &h_t_gathered = outputs[0];
Blob h_t_reference = blobFromNPY(_tf("lstm.prototxt.h_1.npy"));
//h_t_gathered.reshape(h_t_reference.shape());
normAssert(h_t_reference, h_t_gathered);
}
class Layer_RNN_Test : public ::testing::Test
{
