You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

60 lines
2.2 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#coding:utf8
# Copyright 2019 longpeng2008. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# If you find any problem,please contact us
#
# longpeng2008to2012@gmail.com
#
# or create issues
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
## 3层卷积神经网络simpleconv3定义
## 包括3个卷积层3个BN层3个ReLU激活层3个全连接层
class simpleconv3(nn.Module):
## 初始化函数
def __init__(self,nclass):
super(simpleconv3,self).__init__()
self.conv1 = nn.Conv2d(3, 12, 3, 2, 1) #输入图片大小为3*48*48输出特征图大小为12*24*24卷积核大小为3*3步长为2
self.bn1 = nn.BatchNorm2d(12)
self.conv2 = nn.Conv2d(12, 24, 3, 2, 1) #输入图片大小为12*24*24输出特征图大小为24*12*12卷积核大小为3*3步长为2
self.bn2 = nn.BatchNorm2d(24)
self.conv3 = nn.Conv2d(24, 48, 3, 2, 1) #输入图片大小为24*12*12输出特征图大小为48*6*6卷积核大小为3*3步长为2
self.bn3 = nn.BatchNorm2d(48)
self.fc1 = nn.Linear(48 * 6 * 6 , 512) #输入向量长为48*6*6=1728输出向量长为512
self.fc2 = nn.Linear(512 , 128) #输入向量长为512输出向量长为128
self.fc3 = nn.Linear(128 , nclass) #输入向量长为128输出向量长为nclass等于类别数
## 前向函数
def forward(self, x):
## relu函数不需要进行实例化直接进行调用
## convfc层需要调用nn.Module进行实例化
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = x.view(-1 , 48 * 6 * 6)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
if __name__ == '__main__':
import torch
x = torch.randn(1,3,48,48).cuda()
model = simpleconv3(4).cuda()
y = model(x)
## 可视化
from visualize import make_dot
g = make_dot(y)
g.view()
## 统计参数信息
from torchsummary import summary
summary(model,input_size=(3,48,48))
print(model)