# coding: utf-8
# Copyright 2019 longpeng2008. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# If you find any problem, please contact us
#
# longpeng2008to2012@gmail.com
#
# or create issues
# =============================================================================
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
import os
from net import simpleconv3
## Use tensorboardX for visualization
from tensorboardX import SummaryWriter
writer = SummaryWriter('logs') ## create a SummaryWriter instance; events are written to ./logs (the default directory would be ./runs)
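## The scalars logged below can be inspected by pointing TensorBoard at the log directory, e.g. `tensorboard --logdir logs`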
## Main training function
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train(True)  ## set training mode
            else:
                model.train(False)  ## set evaluation mode
            running_loss = 0.0  ## accumulated loss over the phase
            running_accs = 0.0  ## accumulated number of correct predictions
            number_batch = 0  ## number of batches seen in the phase
            ## iterate over the data from the dataloaders
            for data in dataloaders[phase]:
                inputs, labels = data
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()
                optimizer.zero_grad()  ## clear gradients
                outputs = model(inputs)  ## forward pass
                _, preds = torch.max(outputs.data, 1)  ## predicted class = index of the largest output
                loss = criterion(outputs, labels)  ## compute the loss
                if phase == 'train':
                    loss.backward()  ## backpropagate the error
                    optimizer.step()  ## update the parameters
                running_loss += loss.item()
                running_accs += torch.sum(preds == labels).item()
                number_batch += 1
            ## average loss and accuracy for this epoch
            epoch_loss = running_loss / number_batch
            epoch_acc = running_accs / dataset_sizes[phase]
            ## collect loss and accuracy for visualization
            if phase == 'train':
                scheduler.step()  ## adjust the learning rate once per training epoch
                writer.add_scalar('data/trainloss', epoch_loss, epoch)
                writer.add_scalar('data/trainacc', epoch_acc, epoch)
            else:
                writer.add_scalar('data/valloss', epoch_loss, epoch)
                writer.add_scalar('data/valacc', epoch_acc, epoch)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
    writer.close()
    return model
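
## Note: train_model relies on the module-level globals dataloaders, dataset_sizes and use_gpu,
## which are defined in the __main__ block below.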

if __name__ == '__main__':
    crop_size = 48  ## image crop size, i.e. the training input size
    nclass = 4  ## number of classes
    model = simpleconv3(nclass)  ## create the model
    data_dir = './data'  ## data directory
    ## directory for caching trained models
    if not os.path.exists('models'):
        os.mkdir('models')
    ## use the GPU if one is available, otherwise fall back to the CPU
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        model = model.cuda()
    print(model)
    ## data preprocessing: training uses random resized crop, random horizontal flip,
    ## random rotation and normalization; validation uses resize, center crop and normalization
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(20),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ]),
        'val': transforms.Compose([
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ]),
    }
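    ## Normalizing with per-channel mean 0.5 and std 0.5 maps ToTensor's [0, 1] range to [-1, 1]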
    ## read the data with torchvision's ImageFolder dataset interface
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x]) for x in ['train', 'val']}
    ## create the dataloaders; set the batch size, shuffling and the number of worker processes
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=96,
                                                  shuffle=True,
                                                  num_workers=4) for x in ['train', 'val']}
    ## dataset sizes
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
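    ## ImageFolder expects one subdirectory per class, e.g. ./data/train/<class_name>/xxx.jpg and ./data/val/<class_name>/xxx.jpg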
    ## optimization: cross-entropy loss, SGD with momentum, and a StepLR schedule
    ## that multiplies the learning rate by 0.1 every 100 epochs
    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    step_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=100, gamma=0.1)

    model = train_model(model=model,
                        criterion=criterion,
                        optimizer=optimizer_ft,
                        scheduler=step_lr_scheduler,
                        num_epochs=300)
    torch.save(model.state_dict(), 'models/model.pt')
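
    ## A minimal inference sketch (an assumption, not part of the training flow): reload the saved
    ## weights into a fresh simpleconv3 and classify a single image; 'test.jpg' is a hypothetical path.
    # infer_model = simpleconv3(nclass)
    # infer_model.load_state_dict(torch.load('models/model.pt', map_location='cpu'))
    # infer_model.eval()
    # from PIL import Image
    # img = data_transforms['val'](Image.open('test.jpg').convert('RGB')).unsqueeze(0)
    # with torch.no_grad():
    #     pred = torch.argmax(infer_model(img), dim=1).item()
    # print('predicted class index:', pred)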