课程笔记:transforms
torchvision:
torchvision是pytorch中的计算机视觉工具包
三个主要的模块:
1)transforms模块:提供常用的图像预处理方法
2)datasets:提供常用的公开数据集的dataset
3)model:提供大量常用的预训练模型
transforms
对图片进行增强是为了提高模型的泛化能力
transforms是在dataset中的__getitem__中调用的
transforms.Normalize: 比较常用的预处理方法
逐通道的对图像进行标准化
transforms中的normalize实际是运用了functional中的normalize
原因:对数据进行标准化后可以加快模型的收敛
把数据进行标准化会加速模型的收敛,标准化为均值为0,标准差为1的数据。
当数据分布几乎为0均值的分布时,模型可以很快收敛,达到一个很低的loss值。
当数据均值不在0的附近,而模型初始化都是0均值的,慢慢找到分界平面,迭代更新过程较慢。
数据增强:
数据增强又称为数据增广,数据扩增,对训练集进行变换,使训练集更丰富,从而提高模型的泛化能力
transform_invert函数是对transform进行逆操作,使得我们可以观察到模型输入的数据是什么样的(因为数据进行transform后,转换成张量的形式,可能是一些浮点的数据,无法可视化,因此需要transform_invert对transform进行反操作,将张量的数据变换成PILimage,我们就可以可视化)
def transform_invert(img_, transform_train):
"""
将data 进行反transfrom操作
:param img_: tensor
:param transform_train: torchvision.transforms
:return: PIL image
"""
if 'Normalize' in str(transform_train):
norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms))
mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device)
std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device)
img_.mul_(std[:, None, None]).add_(mean[:, None, None])
img_ = img_.transpose(0, 2).transpose(0, 1) # C*H*W --> H*W*C
if 'ToTensor' in str(transform_train) or img_.max() < 1:
img_ = img_.detach().numpy() * 255
if img_.shape[2] == 3:
img_ = Image.fromarray(img_.astype('uint8')).convert('RGB')
elif img_.shape[2] == 1:
img_ = Image.fromarray(img_.astype('uint8').squeeze())
else:
raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]) )
return img_
for epoch in range(MAX_EPOCH):
for i, data in enumerate(train_loader):
inputs, labels = data # B C H W
img_tensor = inputs[0, ...] # C H W
img = transform_invert(img_tensor, train_transform) # transform_invert用来对transform进行逆操作,使我们可以观察到模型的输入数据是长什么样的
plt.imshow(img)
plt.show()
plt.pause(0.5)
plt.close()
裁剪(crop)
1.transforms.CenterCrop(从图像的中心裁剪)
train_transform = transforms.Compose([
transforms.CenterCrop(196)
# 转为tensor,并归一化为0-1之间
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)])
2.transforms.RandomCrop(随机裁剪)
这里的随机值的位置的随机,不一定从中心、左上角、右下角这样去裁剪,是随机位置的裁剪
train_transform = transforms.Compose([
transforms.RandomCrop(224, padding=16),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)])
3.RandomResizedCrop
插值方法:(由于裁剪后的图像可能小于所需的图片尺寸,故需要插值)
NEAREST(最近邻插值)
BILINEAR(双线性插值)
BICUBIC(双三次插值)
train_transform = transforms.Compose([
transforms.RandomResizedCrop(size=224, scale=(0.5, 0.5)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)])
4.FiveCrop(从左上角、左下角、右上角、右下角和中心裁剪出5张图片)
5.TenCrop(从刚刚裁剪的5张图像进行水平或者垂直镜像得到10张图片)
train_transform = transforms.Compose([
transforms.FiveCrop(112),
transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop)
train_transform = transforms.Compose([
transforms.TenCrop(112, vertical_flip=False),
transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop)
翻转(Flip)
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(p=1),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)]),
旋转(rotation)
图像变换
1.Pad
transforms.Pad(padding=32, fill=(255, 0, 0), padding_mode='constant'),
2.ColorJitter
在图像增强方法中经常使用的,尤其是自然图像,对色彩会有些偏差
3.Grayscale
4.RandomGrayscale(灰度的调整)
Grayscale是RandomGrayscale的特例,概率为1
5.RadomAffine
对图像进行空间几何变换
6.遮挡
7.transforms.Lambda
由于TenCrop返回的是tuple的形式,而transform的输入输出通常是PILimage或tensor的形式
crops是TenCrop的输出,长度为10的tuple,每一个元素是PILimage
transforms.Lambda返回的是4维张量
Transform Operation
对transform的选择操作
自定义transform方法
call函数:类的实例能被调用
总结
A9912616: 还是同样的错误
未来可期,期许未来: 你的报错是什么
一个努力向上的小狼君: 终于明白了~。博主好厉害啊~
xxxxyyyqq: 博主,我想问map是大中小的平均吗?他们有关系吗,不是平均的话,是为什么呢
盼盼编程: 不错