You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

60 lines
2.2 KiB
Python

11 months ago
#coding:utf8
# Copyright 2019 longpeng2008. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# If you find any problem,please contact us
#
# longpeng2008to2012@gmail.com
#
# or create issues
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
## 3层卷积神经网络simpleconv3定义
## 包括3个卷积层3个BN层3个ReLU激活层3个全连接层
class simpleconv3(nn.Module):
## 初始化函数
def __init__(self,nclass):
super(simpleconv3,self).__init__()
self.conv1 = nn.Conv2d(3, 12, 3, 2, 1) #输入图片大小为3*48*48输出特征图大小为12*24*24卷积核大小为3*3步长为2
self.bn1 = nn.BatchNorm2d(12)
self.conv2 = nn.Conv2d(12, 24, 3, 2, 1) #输入图片大小为12*24*24输出特征图大小为24*12*12卷积核大小为3*3步长为2
self.bn2 = nn.BatchNorm2d(24)
self.conv3 = nn.Conv2d(24, 48, 3, 2, 1) #输入图片大小为24*12*12输出特征图大小为48*6*6卷积核大小为3*3步长为2
self.bn3 = nn.BatchNorm2d(48)
self.fc1 = nn.Linear(48 * 6 * 6 , 512) #输入向量长为48*6*6=1728输出向量长为512
self.fc2 = nn.Linear(512 , 128) #输入向量长为512输出向量长为128
self.fc3 = nn.Linear(128 , nclass) #输入向量长为128输出向量长为nclass等于类别数
## 前向函数
def forward(self, x):
## relu函数不需要进行实例化直接进行调用
## convfc层需要调用nn.Module进行实例化
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = x.view(-1 , 48 * 6 * 6)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
if __name__ == '__main__':
import torch
x = torch.randn(1,3,48,48).cuda()
model = simpleconv3(4).cuda()
y = model(x)
## 可视化
from visualize import make_dot
g = make_dot(y)
g.view()
## 统计参数信息
from torchsummary import summary
summary(model,input_size=(3,48,48))
print(model)