Skip to content
This repository was archived by the owner on Aug 11, 2020. It is now read-only.

enable padding with custom value, default to 0 #299

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions mshadow/extension/pad.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@ struct PaddingExp:
index_t pad_y_;
/*! \brief pad size in x */
index_t pad_x_;
/*! \brief value to pad with */
index_t value_;
/*! \brief source tensor height */
index_t src_height_;
/*! \brief source tensor width */
index_t src_width_;
/*! \brief constructor */
PaddingExp(const SrcExp &src, index_t pad_y, index_t pad_x)
: src_(src), pad_y_(pad_y), pad_x_(pad_x) {
PaddingExp(const SrcExp &src, index_t pad_y, index_t pad_x, DType value)
: src_(src), pad_y_(pad_y), pad_x_(pad_x), value_(value) {
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
src_height_ = this->shape_[srcdim - 2];
src_width_ = this->shape_[srcdim - 1];
Expand All @@ -40,38 +42,40 @@ struct PaddingExp:
}
};
/*!
* \brief padding expression, pad a image with zeros on boundaries, padding affects shape[0], and shape[1]
* \brief padding expression, pad an image on boundaries, padding affects shape[0], and shape[1]
* \param src original image batches
* \param pad padding size
* \param value value to pad with
* \return expression corresponding to padded result
* \tparam SrcExp source expression
* \tparam DType the content data type
* \tparam etype type of expression
*/
template<typename SrcExp, typename DType, int etype>
inline PaddingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
pad(const Exp<SrcExp, DType, etype> &src, index_t pad) {
pad(const Exp<SrcExp, DType, etype> &src, index_t pad, DType value = static_cast<DType>(0)) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return PaddingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), pad, pad);
return PaddingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), pad, pad, value);
}
/*!
* \brief padding expression, pad a image with zeros on boundaries, padding affects shape[0], and shape[1]
* \brief padding expression, pad an image on boundaries, padding affects shape[0], and shape[1]
* \param src original image batches
* \param pad_y padding size in y
* \param pad_x padding size in x
* \param pad value to pad with
* \return expression corresponding to padded result
* \tparam SrcExp source expression
* \tparam DType the content data type
* \tparam etype type of expression
*/
template<typename SrcExp, typename DType, int etype>
inline PaddingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
pad(const Exp<SrcExp, DType, etype> &src, index_t pad_y, index_t pad_x) {
pad(const Exp<SrcExp, DType, etype> &src, index_t pad_y, index_t pad_x, DType value = static_cast<DType>(0)) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return PaddingExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
(src.self(), pad_y, pad_x);
(src.self(), pad_y, pad_x, value);
}
//----------------------
// Execution plan
Expand All @@ -82,26 +86,28 @@ struct Plan<PaddingExp<SrcExp, DType, srcdim>, DType> {
explicit Plan(const PaddingExp<SrcExp, DType, srcdim> &e)
: src_(MakePlan(e.src_)),
pad_y_(e.pad_y_), pad_x_(e.pad_x_),
value_(e.value_),
new_height_(e.shape_[srcdim - 2]),
src_height_(e.src_height_), src_width_(e.src_width_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
const index_t x = j;
const index_t y = i % new_height_;
const index_t c = i / new_height_;
if (y < pad_y_ || x < pad_x_) return static_cast<DType>(0);
if (y < pad_y_ || x < pad_x_) return value_;
const index_t h = y - pad_y_;
const index_t w = x - pad_x_;
if (h < src_height_ && w < src_width_) {
return src_.Eval(c * src_height_ + h, w);
} else {
return static_cast<DType>(0);
return value_;
}
}

private:
Plan<SrcExp, DType> src_;
const index_t pad_y_;
const index_t pad_x_;
const DType value_;
const index_t new_height_;
const index_t src_height_;
const index_t src_width_;
Expand Down