mirror of
https://github.com/opencv/opencv.git
synced 2025-01-10 22:28:13 +08:00
31 lines
1011 B
Python
31 lines
1011 B
Python
|
import cv2 as cv
|
||
|
|
||
|
#! [CropLayer]
|
||
|
class CropLayer(object):
    """Custom DNN layer that center-crops its first input to the spatial
    size of its second input.

    The layer takes two input blobs. Batch size and channel count come from
    the first input; height and width come from the second. The crop window
    is computed in ``getMemoryShapes`` and applied in ``forward``.
    """

    def __init__(self, params, blobs):
        # ``params`` and ``blobs`` are accepted to match the constructor
        # signature used when the class is registered as a DNN layer; this
        # layer does not use them.
        # Crop window bounds as half-open [start, end) ranges along the
        # height (y) and width (x) axes. They stay 0 until getMemoryShapes()
        # computes them.
        self.xstart = 0
        self.xend = 0
        self.ystart = 0
        self.yend = 0

    # Our layer receives two inputs. We need to crop the first input blob
    # to match the shape of the second one (keeping batch size and number of channels).
    def getMemoryShapes(self, inputs):
        """Compute the output shape and store the centered crop window.

        Args:
            inputs: two shapes, each indexed as [N, C, H, W] (indices 0-3
                are read): the shape of the blob to crop, then the shape of
                the reference blob.

        Returns:
            A one-element list holding the output shape
            ``[N_input, C_input, H_reference, W_reference]``.

        Raises:
            ValueError: if the reference blob is taller or wider than the
                input blob, so no crop window fits inside the input.
        """
        inputShape, targetShape = inputs[0], inputs[1]
        batchSize, numChannels = inputShape[0], inputShape[1]
        height, width = targetShape[2], targetShape[3]

        # Offsets that center the window; with floor division any odd
        # leftover row/column ends up on the bottom/right side.
        ystart = (inputShape[2] - height) // 2
        xstart = (inputShape[3] - width) // 2
        if ystart < 0 or xstart < 0:
            # A negative start would be read by Python slicing as an index
            # from the end, so forward() would return a blob whose shape
            # contradicts the one declared below. Fail loudly instead.
            raise ValueError(
                'CropLayer: reference size {}x{} exceeds input size {}x{}'.format(
                    height, width, inputShape[2], inputShape[3]))

        self.ystart = ystart
        self.xstart = xstart
        self.yend = ystart + height
        self.xend = xstart + width

        return [[batchSize, numChannels, height, width]]

    def forward(self, inputs):
        """Crop the first input to the window stored by getMemoryShapes().

        getMemoryShapes() is expected to have run first; otherwise the
        window is still [0, 0) and the result is empty along H and W.

        Args:
            inputs: list of input blobs; only ``inputs[0]`` is used. It is
                sliced along four axes, so it is expected to be a 4-D array
                (presumably a NumPy array supplied by OpenCV — confirm).

        Returns:
            A one-element list with the cropped blob (for a NumPy array this
            is a view produced by basic slicing, not a copy).
        """
        return [inputs[0][:, :, self.ystart:self.yend, self.xstart:self.xend]]
|
||
|
#! [CropLayer]
|
||
|
|
||
|
#! [Register]
|
||
|
# Register CropLayer with OpenCV's DNN module as the implementation for
# layers of type 'Crop'. This statement runs once, when this module is executed.
cv.dnn_registerLayer('Crop', CropLayer)
|
||
|
#! [Register]
|