LetNet网络结构
网络结构定义
定义一个网络结构首先要先定义一个类并继承torch.nn.Module类,在其构造函数中定义网络结构
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
| class LetNet(nn.Module): # 构造函数定义网络结构 def __init__(self): # 调用父类构造函数 super(LetNet, self).__init__() # 定义卷积层1 输入3通道,输出16通道,卷积核大小5*5 self.conv1 = nn.Conv2d(3,16,5) # input:(3,32,32) output:(16,28,28) # 定义池化层1 self.pool1 = nn.MaxPool2d(2,2) # input:(16,28,28) output:(16,14,14) # 定义卷积层2 输入16通道,输出32通道,卷积核大小5*5 self.conv2 = nn.Conv2d(16,32,5) # input:(16,14,14) output:(32,10,10) # 定义池化层2 self.pool2 = nn.MaxPool2d(2,2) # input:(32,10,10) output:(32,5,5) # 定义全连接层1 self.fc1 = nn.Linear(32*5*5,120) # input:(32*5*5) output:(120) # 定义全连接层2 self.fc2 = nn.Linear(120,84) # input:(120) output:(84) # 定义全连接层3 self.fc3 = nn.Linear(84,10) # input:(84) output:(10) (10个分类)
def forward(self, x): # 定义卷积层1的前向传播过程 x = F.relu(self.conv1(x)) # 定义池化层1的前向传播过程 x = self.pool1(x) # 定义卷积层2的前向传播过程 x = F.relu(self.conv2(x)) # 定义池化层2的前向传播过程 x = self.pool2(x) # 定义全连接层1的前向传播过程 x = x.view(-1, 32*5*5) # 将x变成一维向量 x = F.relu(self.fc1(x)) # 定义全连接层2的前向传播过程 x = F.relu(self.fc2(x)) # 定义全连接层3的前向传播过程 x = self.fc3(x) return x
# 测试网络 if __name__ == '__main__': import torch input1 = torch.randn([32,3,32,32]) model = LetNet() print(model) outpput = model(input1)
|
CIFAR-10数据集包含10个分类 60000张大小为32x32的彩色图片,每个类别6000张图片。训练集包含50000张图片测试集包含10000张图片。
导入数据集
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| # 下载训练集 首次运行需要将download设置为True train_dataset = datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
# 加载训练集 batch_size每次训练的图片数 shuffle打乱数据集 num_workers子线程数 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
# 导入测试集 test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
# 加载测试集 test_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10000, shuffle=False, num_workers=0)
# 将测试集加载器转换为迭代器 test_iter = iter(test_loader) test_image, test_label = test_iter.next() # GPU加载数据 test_image, test_label = test_image.to(DEVICE), test_label.to(DEVICE)
|
因为我用的是GPU训练,所以数据集都需要在GPU中加载,否则在训练过程中会报错
训练
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
| # 定义标签 classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
net = LetNet().to(DEVICE) loss_func = nn.CrossEntropyLoss() optimizer = optim.Adam(net.parameters(), lr=0.001) # lr学习率
for epoch in range(EPOCH): running_loss = 0.0 for i, data in enumerate(train_loader): # 获取输入数据 和标签 inputs, labels = data
# 将数据加载到DEVICE inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
# 将优化器梯度置0 optimizer.zero_grad()
# 前向传播 获取预测值 和 损失 (预测值与标签进行比较) 反向传播 计算梯度并进行反向传播更新参数 outputs = net(inputs) loss = loss_func(outputs, labels) loss.backward() optimizer.step()
# 计算平均损失 running_loss += loss.item() if i % 500 == 499: # 每500次输出一次损失 with torch.no_grad(): test_outputs = net(test_image)# 预测值 [batch_size, 10] predict_lable = torch.max(test_outputs, 1)[1] accuracy = (predict_lable == test_label).sum().item() / test_label.size(0) # 计算准确率 print('[%d, %5d] loss: %.3f, accuracy: %.3f' % (epoch + 1, i + 1, running_loss / 500, accuracy)) # 输出损失和准确率 running_loss = 0.0 # 每次损失置0
|
保存训练模型
1 2 3 4 5
| save_path = './model/letnet.pth' # 判断是文件夹是否存在 if not os.path.exists('./model'): os.mkdir('./model') torch.save(net.state_dict(), save_path)
|
预测
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| transform = transforms.Compose([ transforms.Resize((32, 32)), # 将图片转换为32*32 transforms.ToTensor(), # 将图片转换为Tensor transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ])
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
net = LetNet() net.load_state_dict(torch.load('./model/letnet.pth'))
img = Image.open('test.jpg') # 加载图片 [height, width, channel] img = transform(img) # 将图片转换成tensor [channel, height, width] img = torch.unsqueeze(img, 0) # 增加一个维度 [batch, channel, height, width]
with torch.no_grad(): output = net(img) predict = torch.max(output, 1)[1].data.numpy() # torch.softmax(output, 1) 计算概率 print(classes[int(predict)])
|
笔记根据B站UP主霹雳吧啦Wz视频合集【深度学习-图像分类篇章】学习整理