正文共:17787字 预计阅读时间:45分钟
1 引言
本篇,我们继续深入PyTorch图像分类建模, 相较于之前介绍的CIFAR10图像分类建模,本次我们选择了另一个更具挑战性的图像分类数据集——垃圾分类数据集。该数据集包含了更丰富的垃圾类别,共40类。因此,相比上篇CIFAR-10数据集,本次建模难度更大,需要用到更多优化方法。此次使用的垃圾分类数据集,可以在这里下载到。https://pan.baidu.com/share/init?surl=AnMwl7NmdSl8pGPA780juQ&pwd=x69p
具体来说,本篇中,我们会用到以下几方面前面文章介绍过的知识:
PyTorch数据集读取,我们加载了ImageFolder格式的数据;
查看并分析了数据集中图像的大小分布,为后续转换操作提供依据;
计算了数据集图像的均值与方差,并进行了归一化,这在“PyTorch那些事儿:图像增强”中介绍过;
构建了一个自定义的卷积神经网络,包含卷积层、全连接层等,这些内容在“PyTorch那些事儿:解析nn.modules”中介绍过;
使用交叉熵损失函数和SGD优化器对模型进行优化,这在之前的文中也多次用到过;
保存了模型参数,“PyTorch那些事儿:模型保存与加载”介绍过;
使用了迁移学习技术,加载了预训练的Resnet50模型进行微调,这是在“PyTorch那些事儿:迁移学习”中介绍的内容。
通过本文,我们可以更加深入运用前面学习的知识,针对实际的数据集完成一个完整的图像分类建模过程。同时也会引入一些新的方法,使建模效果更优。
2 数据集介绍
数据集存储在“data/垃圾分类数据集”,目录结构如下所示。垃圾分类数据集目录下包括json文件garbage_dict.json和三个分别用于存储训练集、测试集、验证集的目录。在这三个目录内,都分别有40个子目录,这40个子目录以0到39的序号命名,至于每个序号分别对应哪个垃圾分类,可以在garbage_dict.json找到对应关系。
data
└── 垃圾分类数据集
├── garbage_dict.json
├── test
├── 0
├── 1
├── 39
├── train
├── 0
├── 1
├── 39
└── val
├── 0
├── 1
├── 39
我们先导入后续过程可能需要用到的各个代码库:
import os
import json
from pathlib import Path
import torch
torch.__version__
Out:
'2.0.1+cu118'torch.cuda.is_available()
Out:
Truefrom torch import nn
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision.models.resnet import resnet50
from torchvision.models import ResNet50_Weights
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from chb import show_image
from chb import bar
from chb._utils import EarlyStopping
img_lst = [img for img in Path().glob('./data/垃圾分类数据集/train/*/*.jpg')]
len(img_lst)
Out:
数据集中提供了一个garbage_dict.json文件,文件中记录了数据集目录中目录名与数据集图像类别的对应关系:
with open('./data/垃圾分类数据集/garbage_dict.json', 'r') as f:
garbage_dict = json.load(f)
定义本篇中需要使用的全局超参数:
# 判断是否存在GPU
device = torch.device('cuda') if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# 定义损失函数和优化器:探索图像大小分布
loss_fn = nn.CrossEntropyLoss()
num_epochs = 100
在大部分AI视觉建模过程中,都需要对输入图像的尺寸进行变换,做到统一尺寸输入到模型中。统一数据集图像尺寸,一般是在进行图像增强的时候调用transforms.Resize进行,需要考虑两个主要因素:计算资源限制和模型性能。
如果训练使用的硬件资源(如GPU)充足,并且不需要担心内存使用,那么可以将图像的尺寸设定得相对较大,例如设为均值400或稍大一些,如448。这样做的好处是能够保留更多的图像细节,有可能提高模型性能。但是,更大的图像尺寸会增加计算量,可能会导致训练速度变慢。
如果需要考虑计算资源的限制,或者希望提高训练速度,那么可以将图像的尺寸设定得较小,例如设为256或更小。这样做的好处是能够减少计算量,提高训练速度。但是,较小的图像尺寸可能会丢失一些图像细节,可能会对模型性能产生一定影响。
以下是一些常见的图像尺寸,这些尺寸经常在图像处理和计算机视觉的任务中使用:
224x224:这是一个非常常见的图像尺寸,在许多著名的神经网络模型中都有使用,例如VGG16、VGG19和ResNet。
256x256:这个尺寸也相当常见,一些图像处理和计算机视觉的模型会使用这个尺寸。
299x299:Google的Inception模型和Xception模型就使用了这个尺寸。
331x331:这个尺寸在NasNetLarge模型中使用。
512x512:这个尺寸在一些需要更高分辨率输入的任务中使用,例如某些医学图像处理任务。
32x32或64x64:这些较小的尺寸在一些经典的图像数据集例如CIFAR-10,CIFAR-100和MNIST中使用。
在我们建模过程中,也建议使用上面这些尺寸。这些尺寸并非随机选择,而是根据卷积,池化等操作的特性,以及模型在特定尺寸下的性能。
以299x299为例,这是Inception v3模型的输入尺寸。该尺寸的选择与模型的特定结构有关。Inception v3包含多个模块,每个模块包含多个卷积层,这些卷积层具有不同的感受野(卷积核大小)。模型的设计者发现,当输入尺寸为299x299时,模型可以在有效地利用计算资源的同时,实现最优的性能。
同样,VGG和ResNet模型选择224x224作为输入尺寸,也是因为在这个尺寸下,模型可以最好地平衡计算效率和性能。
注意,并非所有的数字都适合作为输入尺寸。例如,对于某些包含卷积层和池化层的模型,如果输入尺寸无法被2的n次方整除(n是卷积层和池化层的数量),可能会导致在计算过程中出现错误。因此,模型的设计者需要根据模型的具体结构和需求,选择合适的输入尺寸,上述常见的图像尺寸就是不错的选择。
当然,探索数据集中图像大小的分布情况还有其他作用,例如:
预处理决策:了解图像大小分布有助于做出更合适的预处理决策,比如如何需要缩放、裁剪或填充等。
性能和效率:如果图像大小非常不一致,那么在进行计算时可能会遇到效率问题。特别是一些需要固定输入大小的模型(比如CNN)需要将所有图像转换为相同的尺寸,这可能需要额外的计算时间。
数据不均衡:如果某一特定大小或尺寸的图像在数据集中过多或过少,可能会导致模型对这类图像过拟合或欠拟合。
模型选择:了解图像的大小和纵横比也有助于选择更适合的模型架构。某些模型可能更适合特定尺寸或纵横比的图像。
特征捕捉:对于小图像,细粒度的特征可能更难以捕捉;而对于大图像,全局特征可能更难以捕捉。了解这一点有助于在模型设计(例如,选择合适的卷积核大小)或数据增强(例如,不同程度的裁剪或缩放)时做出更好的决策。
内存和存储:如果数据集中的图像大小非常大,那么在存储和训练模型时可能需要更多的资源。
可解释性和调试:如果模型在某些特定大小的图像上表现不佳,了解这一点可能有助于更有效地调试模型。
下面,我们尝试对本次建模数据集中的的图像大小进行探索:
# 存储图像大小的列表
img_sizes = []
# 遍历图像路径列表
for img_path in bar(img_lst):
# 打开图像文件
img = Image.open(img_path)
# 获取图像的宽度和高度
width, height = img.size
# 将图像的宽度和高度添加到列表中
img_sizes.append((width, height))
# 将列表转换为Pandas DataFrame
df = pd.DataFrame(img_sizes, columns=['Width', 'Height'])
fig, ax =plt.subplots(1,2,constrained_layout=True, figsize=(10, 3))
axesSub = sns.histplot(x="Width", data=df, ax=ax[0])
axesSub.set_title('Width')
axesSub = sns.histplot(x="Height", data=df, ax=ax[1])
axesSub.set_title('Height')
100%|██████████████████████████████████████████████████| 11530/11530 [Time cost: 1.87041 s]
Out:
Text(0.5, 1.0, 'Height')
可以看到,数据集中图像宽度大致分布为均值为400左右的正态分布,高度分布大致为均值为400左右的正态分布。我们选择resize尺寸时,一般选择比均值略大一些的尺寸,保留大部分图像信息,另外,结合之前举例的常见尺寸,我们选择512 * 512的尺寸进行resize。
热门跟贴