np.pad
是 NumPy 中用于对数组进行填充的函数,它可以在数组的不同维度上添加指定数量的值。
X
:输入的 NumPy 数组。通常是一个 4 维数组,可能表示图像数据,形状为 (batch_size, height, width, channels)
,例如 (样本数量, 高度, 宽度, 通道数)。
((0, 0), (pad, pad), (pad, pad), (0, 0))
: 这是填充方式的定义,np.pad
需要一个形状与数组维度相同
的 tuple(元组)来指示每个维度的填充值。这个元组中的每个元素是一个二元 tuple
,表示对应维度的左边和右边
需要填充的数量。
(0, 0)
:对第一个维度(样本数量)不做填充。
(pad, pad)
:对第二个维度(高度)进行 pad 大小的填充,前后都填充 pad 个单位。
(pad, pad)
:对第三个维度(宽度)进行 pad 大小的填充,前后都填充 pad 个单位。
(0, 0)
:对第四个维度(通道数)不做填充。
'constant'
:指定填充的模式,这里 ‘constant’ 表示用常数来填充。
constant_values=0
:指定填充常数的值。这里用 0 进行填充,因此 pad 大小的区域将填充为 0。
def test1():np.random.seed(1)x = np.random.randn(4, 3, 3, 2)x_pad = zero_pad(x, 2)print("x.shape:", x.shape)print("x_pad.shape:", x_pad.shape)print("x[1,1]=", x[1, 1])print("x_pad[1,1]=", x_pad[1, 1])fig, axarr = plt.subplots(1, 2)axarr[0].set_title('x')print(x)print("-=-----")print(x[0, :, :, 0])axarr[0].imshow(x[0, :, :, 0])axarr[1].set_title('x_pad')axarr[1].imshow(x_pad[0, :, :, 0])plt.show()