FFA-Net:文章理解于代码注释


转载自:https://blog.csdn.net/weixin_46773169/article/details/105462644,本文只做个人记录学习使用,版权归原作者所有。

github链接:https://github.com/zhilin007/FFA-Net

代码注释:

data_utils.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import torch.utils.data as data
import torchvision.transforms as tfs
from torchvision.transforms import functional as FF
import os, sys

sys.path.append('.')
sys.path.append('..')
import numpy as np
import torch
import random
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from torchvision.utils import make_grid
from net.metrics import *                # metrics.py
from net.option import opt               # option.py

BS = opt.bs
print('BS:',BS)
crop_size = 'whole_img'   # 裁剪图片的大小
if opt.crop:
    crop_size = opt.crop_size

def tensorShow(tensors, titles=None):
    '''
        t:BCWH
    '''
    fig = plt.figure()
    for tensor, tit, i in zip(tensors, titles, range(len(tensors))):
        img = make_grid(tensor)
        npimg = img.numpy()
        ax = fig.add_subplot(211 + i)
        ax.imshow(np.transpose(npimg, (1, 2, 0)))
        ax.set_title(tit)
    plt.show()

class RESIDE_Dataset(data.Dataset):
    def __init__(self, path, train, size=crop_size, format='.png'):
        super(RESIDE_Dataset, self).__init__()
        self.size = size
        # print('crop size:', size) # ---本人测试命令
        self.train = train
        self.format = format
        self.haze_imgs_dir = os.listdir(os.path.join(path, 'hazy'))
        # 返回指定路径下所有文件和文件夹的名字,并存放于一个列表中
        # print('self_haze_imgs_dir :', self.haze_imgs_dir) # 本人测试命令
        self.haze_imgs = [os.path.join(path, 'hazy', img) for img in self.haze_imgs_dir]
        # hazy图像所有的路径,并存放于一个列表中
        # print('self_haze_imgs:',self.haze_imgs) # ---本人测试命令
        self.clear_dir = os.path.join(path, 'clear')
        # print('self_clean:', self.clear_dir) # ---本人测试命令

    def __getitem__(self, index):
        haze = Image.open(self.haze_imgs[index])
        # print('haze_size:',haze.size,haze.size[::-1]) # ---本人测试命令
        # print('index:', index) # ---本人测试命令
        if isinstance(self.size, int):  # 如果size是int型,则返回True
            # print('这个isinstance方法被调用') # ---本人测试命令
            while haze.size[0] < self.size or haze.size[1] < self.size:
                index = random.randint(0, 20000)
                haze = Image.open(self.haze_imgs[index])
        img = self.haze_imgs[index]  # 从haze_imgs(路径名称列表)中取出对于索引值的路径
        # print('img:', img) # ---本人测试命令
        # id = img.split('/')[-1].split('_')[0] # 此命令在windows下执行会报路径错误,改为以下命令
        id = img.split('\')[-1].split('_')[0]
        # 提取最后‘\’之后和第一个‘_’之前的内容,以hazy图像的路径找到对应clear图像的序号
        # print('id:',id) # ---本人测试命令
        clear_name = id + self.format
        # print('clear_name:', clear_name) # ---本人测试命令
        # test_dir = os.path.join(self.clear_dir, clear_name) # ---本人测试命令
        # print('clear_dir:',test_dir) # ---本人测试命令
       
        clear = Image.open(os.path.join(self.clear_dir, clear_name))
        clear = tfs.CenterCrop(haze.size[::-1])(clear)
        # haze.size=(W, H) -> haze.size[::-1]=(H, W),然后按(H, W)对clear进行中心裁剪

        if not isinstance(self.size, str): # 如果size不是str类型,则返回True
            # print('这个not isinstance方法被调用')
            i, j, h, w = tfs.RandomCrop.get_params(haze, output_size=(self.size, self.size))
            '''
            w, h = haze.size
            th, tw = output_size
            i = random.randint(0, h - th)
            j = random.randint(0, w - tw)
            return i, j, th, tw
            '''
            haze = FF.crop(haze, i, j, h, w)  # 把haze随机裁剪成(i, j, h, w)的大小
            clear = FF.crop(clear, i, j, h, w)
        haze, clear = self.augData(haze.convert("RGB"), clear.convert("RGB"))
        # 使用数据增强后把图片转为RGB格式
        return haze, clear

    def augData(self, data, target):  # 数据增强
        if self.train:
            rand_hor = random.randint(0, 1)  # 从[0, 1]中随机选一个数
            rand_rot = random.randint(0, 3)  # 从[0, 1, 2, 3]中随机选一个数
            data = tfs.RandomHorizontalFlip(rand_hor)(data)
            # 依据概率rand_hor对data(图片)进行水平翻转(这里,rand_hor=0:不翻转;=1:翻转)
            target = tfs.RandomHorizontalFlip(rand_hor)(target)
            if rand_rot:  # rand_rot>0时执行此命令
                data = FF.rotate(data, 90 * rand_rot)  # 将data旋转的角度为90*rand_rot
                target = FF.rotate(target, 90 * rand_rot)
        data = tfs.ToTensor()(data)  # range [0, 255] -> [0.0, 1.0]
        data = tfs.Normalize(mean=[0.64, 0.6, 0.58], std=[0.14, 0.15, 0.152])(data)
        # 归一化操作
        # 输入的data(图片)大小为CxWxH(三维张量),mean为各通道的均值,std为各通道的方差
        # output = (input - mean) / std

        target = tfs.ToTensor()(target)
        return data, target

    def __len__(self):
        return len(self.haze_imgs)

import os
pwd = os.getcwd()
print(pwd)
# path = '/FFA-Net-master/data'  # path to your 'data' folder
path = '../data'  # path to your 'data' folder

ITS_train_loader = DataLoader(dataset=RESIDE_Dataset(path + '/RESIDE/ITS', train=True, size=crop_size), batch_size=BS,
                              shuffle=True)
ITS_test_loader = DataLoader(dataset=RESIDE_Dataset(path + '/RESIDE/SOTS/indoor', train=False, size='whole img'),
                             batch_size=1, shuffle=False)
OTS_train_loader = DataLoader(dataset=RESIDE_Dataset(path + '/RESIDE/OTS', train=True, format='.jpg'), batch_size=BS,
                               shuffle=True)
OTS_test_loader = DataLoader(dataset=RESIDE_Dataset(path + '/RESIDE/SOTS/outdoor', train=False, size='whole img', format='.png'), batch_size=1,
                              shuffle=False)
# 如果train_loader没有数据,即检查Dataset的__len__()函数输出为零,会报ValueError:num_samples...的错

if __name__ == "__main__":
    pass

option.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch,os,sys,torchvision,argparse
import torchvision.transforms as tfs
import time,math
import numpy as np
from torch.backends import cudnn
from torch import optim
import torch,warnings
from torch import nn
import torchvision.utils as vutils
warnings.filterwarnings('ignore')

parser=argparse.ArgumentParser()  # 命令行选项、参数和子命令解析器
'''
argparse 模块可以让人轻松编写用户友好的命令行接口。
程序定义它需要的参数,然后 argparse 将弄清如何从 sys.argv 解析出那些参数。
argparse 模块还会自动生成帮助和使用手册,并在用户给程序传入无效参数时报出错误信息。
'''

# 添加参数
# default - 当参数未在命令行中出现时使用的值。
# type - 命令行参数应当被转换成的类型。
# action='store_true',只要运行时该变量有传参就将该变量设为True
parser.add_argument('--steps',type=int,default=10) # 10000
parser.add_argument('--device',type=str,default='Automatic detection')
parser.add_argument('--resume',type=bool,default=True)
parser.add_argument('--eval_step',type=int,default=5)  # 5000
parser.add_argument('--lr', default=0.0001, type=float, help='learning rate')
parser.add_argument('--model_dir',type=str,default='./trained_models/')
parser.add_argument('--trainset',type=str,default='its_train')
parser.add_argument('--testset',type=str,default='its_test')
parser.add_argument('--net',type=str,default='ffa')
parser.add_argument('--gps',type=int,default=3,help='residual_groups')
parser.add_argument('--blocks',type=int,default=20,help='residual_blocks')
parser.add_argument('--bs',type=int,default=16,help='batch size')
parser.add_argument('--crop',action='store_true')
parser.add_argument('--crop_size',type=int,default=240,help='Takes effect when using --crop ')
parser.add_argument('--no_lr_sche',action='store_true',help='no lr cos schedule')
parser.add_argument('--perloss',action='store_true',help='perceptual loss')

opt=parser.parse_args()  # 解析参数
opt.device='cuda' if torch.cuda.is_available() else 'cpu'
model_name=opt.trainset+'_'+opt.net.split('.')[0]+'_'+str(opt.gps)+'_'+str(opt.blocks)
# split('.')[0] , 以'.'作分隔符,输出'.'之前的内容

opt.model_dir=opt.model_dir+model_name+'.pk'
log_dir='logs/'+model_name

# ---以下为本人自己的测试命令---
# print('opt:', opt)
# print('model_name:', model_name)
# print('model_dir:',opt.model_dir)
# print('log_dir:', log_dir)

if not os.path.exists('trained_models'):
   os.mkdir('trained_models')  # 创建路径
if not os.path.exists('numpy_files'):
   os.mkdir('numpy_files')
if not os.path.exists('logs'):
   os.mkdir('logs')
if not os.path.exists('samples'):
   os.mkdir('samples')
if not os.path.exists(f"samples/{model_name}"):
   os.mkdir(f'samples/{model_name}')
if not os.path.exists(log_dir):
   os.mkdir(log_dir)

metrics.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from math import exp
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from  torchvision.transforms import ToPILImage

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)  # 添加一个轴,变成二维张量
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # torch.mul(a, b)是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵
    # torch.mm(a, b)是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵
    # .t(), 求转置,输入tensor结构维度<=2D
    # 在二维张量前面添加2个轴,变成四维张量

    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    # 把张量扩展成(channel, 1, window_size, window_size)的大小,以原来的值填充(其自身的值不变)
    # contiguous:view只能用在contiguous的variable上。contiguous一般与transpose,permute,view搭配使用
    # 即使用transpose或permute进行维度变换后,需要用contiguous()来返回一个contiguous copy,然后方可使用view对维度进行变形

    return window

def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
    mu1_sq = mu1.pow(2)  # mul的2次方
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

def ssim(img1, img2, window_size=11, size_average=True):
    img1=torch.clamp(img1,min=0,max=1)
    # 将输入img1张量每个元素的范围限制到区间[min, max],返回结果到一个新张量。
    img2=torch.clamp(img2,min=0,max=1)

    (_, channel, _, _) = img1.size()  # 取出img1的通道数
    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)  # 将window张量转换为给定img1类型的张量
    return _ssim(img1, img2, window, window_size, channel, size_average)


def psnr(pred, gt):
    pred=pred.clamp(0,1).cpu().numpy() # 将gpu上的数据类型转为cpu上的数据类型,然后转化为numpy()数组
    gt=gt.clamp(0,1).cpu().numpy()
    imdff = pred - gt
    rmse = math.sqrt(np.mean(imdff ** 2))
    if rmse == 0:
        return 100
    return 20 * math.log10( 1.0 / rmse)

if __name__ == "__main__":
    pass

FFA.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch.nn as nn
import torch

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)  # '//'整数除法,'/'浮点数除法

class PALayer(nn.Module):
    def __init__(self, channel):
        super(PALayer, self).__init__()
        self.pa = nn.Sequential(
            # PA层的卷积核不应该是3x3么,为什么这里是1x1?
            # 这样的话PA层与CA层只差一个全局平均池化操作的区别,而且1x1是抓通道特征,并不能实现像素注意的功能
            # 论文中“实施细节”处写道只有CA模块的卷积核为1x1,怀疑此处代码失误
            nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),  # inplace 原位操作,即不经过复制操作,而是直接在原来的内存上改变它的值
            nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True),
            # 第一个'1'表示输出的通道数为1,即实现CxHxW -> 1xHxW
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.pa(x)
        return x * y

class CALayer(nn.Module):
    def __init__(self, channel):
        super(CALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # 自适应平均池化,输出大小为: 1 x 1,即把一张图片(HxW)的所有的值加起来取平均,大小变为1x1
        self.ca = nn.Sequential(
            # 这里,'1'表示卷积核的大小为1x1,这是实现特征注意功能的关键:
            # 用channel个channel//8层的conv2D 1x1滤镜作逐点卷积,抓通道相关性
            nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.ca(y)
        return x * y

class Block(nn.Module):
    def __init__(self, conv, dim, kernel_size, ):
        super(Block, self).__init__()
        self.conv1 = conv(dim, dim, kernel_size, bias=True)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = conv(dim, dim, kernel_size, bias=True)
        self.calayer = CALayer(dim)
        self.palayer = PALayer(dim)

    def forward(self, x):
        res = self.act1(self.conv1(x))
        res = res + x
        res = self.conv2(res)
        res = self.calayer(res)
        res = self.palayer(res)
        res += x
        return res

class Group(nn.Module):
    def __init__(self, conv, dim, kernel_size, blocks):
        super(Group, self).__init__()
        modules = [Block(conv, dim, kernel_size) for _ in range(blocks)]
        # moduels列表里有n(=blocks)个Block块

        modules.append(conv(dim, dim, kernel_size))
        self.gp = nn.Sequential(*modules)
        # modules列表前加*号,表示将列表解开成独立的参数。
        # 转化为Sequential模型,网络为n个Block块线性堆叠。

    def forward(self, x):
        res = self.gp(x)
        res += x
        return res

class FFA(nn.Module):
    def __init__(self, gps, blocks, conv=default_conv):
        super(FFA, self).__init__()
        self.gps = gps
        self.dim = 64
        kernel_size = 3
        pre_process = [conv(3, self.dim, kernel_size)]
        assert self.gps == 3
        self.g1 = Group(conv, self.dim, kernel_size, blocks=blocks)
        self.g2 = Group(conv, self.dim, kernel_size, blocks=blocks)
        self.g3 = Group(conv, self.dim, kernel_size, blocks=blocks)
        self.ca = nn.Sequential(*[
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(self.dim * self.gps, self.dim // 16, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.dim // 16, self.dim * self.gps, 1, padding=0, bias=True),
            nn.Sigmoid()
        ])
        self.palayer = PALayer(self.dim)

        post_precess = [
            conv(self.dim, self.dim, kernel_size),
            conv(self.dim, 3, kernel_size)]

        self.pre = nn.Sequential(*pre_process)
        self.post = nn.Sequential(*post_precess)

    def forward(self, x1):
        x = self.pre(x1)
        res1 = self.g1(x)
        res2 = self.g2(res1)
        res3 = self.g3(res2)
        w = self.ca(torch.cat([res1, res2, res3], dim=1))
        # 按序号为1的轴进行拼接,即按通道进行拼接,每个res大小为([1, 64, H, W]),
        # cat后大小为([1, 192, H, W]),w.size() = ([1, 192, 1, 1])

        w = w.view(-1, self.gps, self.dim)[:, :, :, None, None]  # 添加两个轴(元素是None)
        # w.size() = ([1, 3, 64, 1, 1])

        out = w[:, 0, ::] * res1 + w[:, 1, ::] * res2 + w[:, 2, ::] * res3
        # w的三个通道分别与res1,2,3相乘再相加,out.size()=([1, 64, H, W])

        out = self.palayer(out)
        x = self.post(out)
        return x + x1

if __name__ == "__main__":
    # 当.py文件被直接运行时,if __name__ == '__main__'之下的代码块将被运行;
    # 当.py文件以模块形式被导入时, if __name__ == '__main__'之下的代码块不被运行
    net = FFA(gps=3, blocks=19)
    print(net)

main.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import torch, os, sys, torchvision, argparse
import torchvision.transforms as tfs

from net.models.FFA import FFA # FFA.py
from net.metrics import psnr, ssim # metrics.py
from net.models import *
import time, math
import numpy as np
from torch.backends import cudnn
from torch import optim
import torch, warnings
from torch import nn
# from tensorboardX import SummaryWriter
import torchvision.utils as vutils

warnings.filterwarnings('ignore')
from net.option import opt, model_name, log_dir # option.py
from net.data_utils import *  # data_utils.py
from torchvision.models import vgg16

print('log_dir :', log_dir)
print('model_name:', model_name)

models_ = {
    'ffa': FFA(gps=opt.gps, blocks=opt.blocks),
}

loaders_ = {
    'its_train': ITS_train_loader,
    'its_test': ITS_test_loader,
    'ots_train': OTS_train_loader,
    'ots_test': OTS_test_loader
}

start_time = time.time()  # 返回当前时间的时间戳
T = opt.steps  # default=100000


def lr_schedule_cosdecay(t, T, init_lr=opt.lr):
    # 文章中公式(9),采用cosine annealing strategy进行学习率衰减,直到0
    lr = 0.5 * (1 + math.cos(t * math.pi / T)) * init_lr
    return lr


def train(net, loader_train, loader_test, optim, criterion):
    losses = []
    start_step = 0
    max_ssim = 0
    max_psnr = 0
    ssims = []
    psnrs = []
    if opt.resume and os.path.exists(opt.model_dir):  # 如果已有训练好的模型,返回true
        print(f'resume from {opt.model_dir}')  # 带f的print可以执行表达式
        ckp = torch.load(opt.model_dir)  # 将对象文件反序列化为内存
        losses = ckp['losses']  # 取出已训练好的模型的loss
        net.load_state_dict(ckp['model'])
        # 使用反序列化状态字典加载model’s参数字典
        # state_dict是个简单的Python dictionary对象,它将每个层映射到它的参数张量

        start_step = ckp['step']
        max_ssim = ckp['max_ssim']
        max_psnr = ckp['max_psnr']
        psnrs = ckp['psnrs']
        ssims = ckp['ssims']
        print(f'start_step:{start_step} start training ---')
    else:
        print('train from scratch *** ')
    for step in range(start_step + 1, opt.steps + 1):  # opt.steps=10(default)
        net.train()  # 定义的网络进入训练模式
        lr = opt.lr
        if not opt.no_lr_sche:
            lr = lr_schedule_cosdecay(step, T)
            for param_group in optim.param_groups:  # 在训练中动态的调整学习率
                param_group["lr"] = lr
        x, y = next(iter(loader_train))
        # 读取一个读取一个batch的数据,batch size=16时实际对应16张图像
        # dataloader本质上是一个可迭代对象,使用iter()访问,不能使用next()访问;
        # 使用iter(dataloader)返回的是一个迭代器,然后可以使用next访问

        x = x.to(opt.device)  # 若opt.device=cuda,即转移到GPU运行
        y = y.to(opt.device)
        out = net(x)  # 把x输入网络训练
        loss = criterion[0](out, y)
        if opt.perloss:  # Perceptual loss为L1损失和L2损失的加权和
            loss2 = criterion[1](out, y)
            loss = loss + 0.04 * loss2

        loss.backward()  # 反向传播求梯度

        optim.step()  # 更新参数
        optim.zero_grad()  # 清除梯度,为下一个batch训练做准备
        losses.append(loss.item())  # loss是个标量,item表示取出这个标量,然后放入losses中
        print(
            f'\rtrain loss : {loss.item():.5f}| step :{step}/{opt.steps}|lr :{lr :.7f} |time_used :{(time.time() - start_time) / 60 :.1f}',
            end='', flush=True)

        # with SummaryWriter(logdir=log_dir,comment=log_dir) as writer:
        #  writer.add_scalar('data/loss',loss,step)

        if step % opt.eval_step == 0:  # default=5000
            with torch.no_grad():  # 切断梯度计算,不会进行反向传播,因为SSIM和PSNR的计算不需要
                ssim_eval, psnr_eval = test(net, loader_test, max_psnr, max_ssim, step)  # 计算SSIM,PSNR

            print(f'\nstep :{step} |ssim:{ssim_eval:.4f}| psnr:{psnr_eval:.4f}')

            # with SummaryWriter(logdir=log_dir,comment=log_dir) as writer:
            #  writer.add_scalar('data/ssim',ssim_eval,step)
            #  writer.add_scalar('data/psnr',psnr_eval,step)
            #  writer.add_scalars('group',{
            #     'ssim':ssim_eval,
            #     'psnr':psnr_eval,
            #     'loss':loss
            #  },step)
            ssims.append(ssim_eval)
            psnrs.append(psnr_eval)
            if ssim_eval > max_ssim and psnr_eval > max_psnr:
                max_ssim = max(max_ssim, ssim_eval)
                max_psnr = max(max_psnr, psnr_eval)
                torch.save({
                    'step': step,
                    'max_psnr': max_psnr,
                    'max_ssim': max_ssim,
                    'ssims': ssims,
                    'psnrs': psnrs,
                    'losses': losses,
                    'model': net.state_dict()
                }, opt.model_dir)  # 保存各项参数到model_dir中
                print(f'\n model saved at step :{step}| max_psnr:{max_psnr:.4f}|max_ssim:{max_ssim:.4f}')
               
    # 把参数保存为.npy文件
    np.save(f'./numpy_files/{model_name}_{opt.steps}_losses.npy', losses)    
    np.save(f'./numpy_files/{model_name}_{opt.steps}_ssims.npy', ssims)
    np.save(f'./numpy_files/{model_name}_{opt.steps}_psnrs.npy', psnrs)


def test(net, loader_test, max_psnr, max_ssim, step):
    net.eval()  # 网络参数会被固定,权值不会被更新
    torch.cuda.empty_cache()  # 清空显存
    ssims = []
    psnrs = []
    # s=True
    for i, (inputs, targets) in enumerate(loader_test):
        inputs = inputs.to(opt.device)
        targets = targets.to(opt.device)
        pred = net(inputs)
        # # print(pred)
        # tfs.ToPILImage()(torch.squeeze(targets.cpu())).save('111.png')
        # vutils.save_image(targets.cpu(),'target.png')
        # vutils.save_image(pred.cpu(),'pred.png')
        ssim1 = ssim(pred, targets).item()
        psnr1 = psnr(pred, targets)
        ssims.append(ssim1)
        psnrs.append(psnr1)
    # if (psnr1>max_psnr or ssim1 > max_ssim) and s :
    #     ts=vutils.make_grid([torch.squeeze(inputs.cpu()),torch.squeeze(targets.cpu()),torch.squeeze(pred.clamp(0,1).cpu())])
    #     vutils.save_image(ts,f'samples/{model_name}/{step}_{psnr1:.4}_{ssim1:.4}.png')
    #     s=False
    return np.mean(ssims), np.mean(psnrs)

if __name__ == "__main__":
    '''
    直接执行该模块(main.py),此时__name__=main.py,以下语句才会被执行;
    如果该模块 import 到其他模块中,此时__name__=main,以下语句不会被执行,。
    '''
    loader_train = loaders_[opt.trainset]
    loader_test = loaders_[opt.testset]
    net = models_[opt.net]
    net = net.to(opt.device)
    if opt.device == 'cuda':
        net = torch.nn.DataParallel(net)
        # 在多个GPU上并行计算,是将输入一个batch的数据均分成多份,分别送到对应的GPU进行计算,各个GPU得到的梯度累加。
        # cudnn.benchmark = True让内置的cuDNN 的auto-tuner自动寻找最适合当前配置的高效算法,来达到优化运行的效率
       
    criterion = []
    criterion.append(nn.L1Loss().to(opt.device))  # 采用L1损失,放入certerion[0]中
    if opt.perloss:
        vgg_model = vgg16(pretrained=True).features[:16]
        # 使用预训练的权重,只调用特征提取部分的前16层,分类部分已抛弃掉
        vgg_model = vgg_model.to(opt.device)
        for param in vgg_model.parameters():
            param.requires_grad = False  # vgg_model不进行梯度计算
        criterion.append(PerLoss(vgg_model).to(opt.device))  # 计算的Perceptual loss损失放入criterion[1]中

    optimizer = optim.Adam(params=filter(lambda x: x.requires_grad, net.parameters()), lr=opt.lr, betas=(0.9, 0.999),
                           eps=1e-08)
    # filter函数将net模型中属性requires_grad = True的参数筛选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新
    optimizer.zero_grad()

    train(net, loader_train, loader_test, optimizer, criterion)
n.L1Loss().to(opt.device))  # 采用L1损失,放入certerion[0]中
    if opt.perloss:
        vgg_model = vgg16(pretrained=True).features[:16]
        # 使用预训练的权重,只调用特征提取部分的前16层,分类部分已抛弃掉
        vgg_model = vgg_model.to(opt.device)
        for param in vgg_model.parameters():
            param.requires_grad = False  # vgg_model不进行梯度计算
        criterion.append(PerLoss(vgg_model).to(opt.device))  # 计算的Perceptual loss损失放入criterion[1]中

    optimizer = optim.Adam(params=filter(lambda x: x.requires_grad, net.parameters()), lr=opt.lr, betas=(0.9, 0.999),
                           eps=1e-08)
    # filter函数将net模型中属性requires_grad = True的参数筛选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新
    optimizer.zero_grad()

    train(net, loader_train, loader_test, optimizer, criterion)

附:为方便理解网络,将FFA.py的blocks改为1

1
2
3
if __name__ == "__main__":
    net = FFA(gps=3, blocks=1)  # blocks改为1
    print(net)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
FFA(
  (g1): Group(
    (gp): Sequential(
      (0): Block(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (calayer): CALayer(
          (avg_pool): AdaptiveAvgPool2d(output_size=1)
          (ca): Sequential(
            (0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
            (1): ReLU(inplace=True)
            (2): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1))
            (3): Sigmoid()
          )
        )
        (palayer): PALayer(
          (pa): Sequential(
            (0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
            (1): ReLU(inplace=True)
            (2): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1))
            (3): Sigmoid()
          )
        )
      )
      (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (g2): Group(
    (gp): Sequential(
      (0): Block(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (calayer): CALayer(
          (avg_pool): AdaptiveAvgPool2d(output_size=1)
          (ca): Sequential(
            (0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
            (1): ReLU(inplace=True)
            (2): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1))
            (3): Sigmoid()
          )
        )
        (palayer): PALayer(
          (pa): Sequential(
            (0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
            (1): ReLU(inplace=True)
            (2): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1))
            (3): Sigmoid()
          )
        )
      )
      (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (g3): Group(
    (gp): Sequential(
      (0): Block(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (calayer): CALayer(
          (avg_pool): AdaptiveAvgPool2d(output_size=1)
          (ca): Sequential(
            (0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
            (1): ReLU(inplace=True)
            (2): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1))
            (3): Sigmoid()
          )
        )
        (palayer): PALayer(
          (pa): Sequential(
            (0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
            (1): ReLU(inplace=True)
            (2): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1))
            (3): Sigmoid()
          )
        )
      )
      (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (ca): Sequential(
    (0): AdaptiveAvgPool2d(output_size=1)
    (1): Conv2d(192, 4, kernel_size=(1, 1), stride=(1, 1))
    (2): ReLU(inplace=True)
    (3): Conv2d(4, 192, kernel_size=(1, 1), stride=(1, 1))
    (4): Sigmoid()
  )
  (palayer): PALayer(
    (pa): Sequential(
      (0): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1))
      (3): Sigmoid()
    )
  )
  (pre): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (post): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

用summary()调出每层的输出大小和参数

1
pip install torchsummary

在FFA.py末添加:

1
2
from torchsummary import summary
summary(net, input_size=(3, 64, 64), batch_size=1)

结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [1, 64, 64, 64]           1,792
            Conv2d-2            [1, 64, 64, 64]          36,928
              ReLU-3            [1, 64, 64, 64]               0
            Conv2d-4            [1, 64, 64, 64]          36,928
 AdaptiveAvgPool2d-5              [1, 64, 1, 1]               0
            Conv2d-6               [1, 8, 1, 1]             520
              ReLU-7               [1, 8, 1, 1]               0
            Conv2d-8              [1, 64, 1, 1]             576
           Sigmoid-9              [1, 64, 1, 1]               0
          CALayer-10            [1, 64, 64, 64]               0
           Conv2d-11             [1, 8, 64, 64]             520
             ReLU-12             [1, 8, 64, 64]               0
           Conv2d-13             [1, 1, 64, 64]               9
          Sigmoid-14             [1, 1, 64, 64]               0
          PALayer-15            [1, 64, 64, 64]               0
            Block-16            [1, 64, 64, 64]               0
           Conv2d-17            [1, 64, 64, 64]          36,928
            Group-18            [1, 64, 64, 64]               0
           Conv2d-19            [1, 64, 64, 64]          36,928
             ReLU-20            [1, 64, 64, 64]               0
           Conv2d-21            [1, 64, 64, 64]          36,928
AdaptiveAvgPool2d-22              [1, 64, 1, 1]               0
           Conv2d-23               [1, 8, 1, 1]             520
             ReLU-24               [1, 8, 1, 1]               0
           Conv2d-25              [1, 64, 1, 1]             576
          Sigmoid-26              [1, 64, 1, 1]               0
          CALayer-27            [1, 64, 64, 64]               0
           Conv2d-28             [1, 8, 64, 64]             520
             ReLU-29             [1, 8, 64, 64]               0
           Conv2d-30             [1, 1, 64, 64]               9
          Sigmoid-31             [1, 1, 64, 64]               0
          PALayer-32            [1, 64, 64, 64]               0
            Block-33            [1, 64, 64, 64]               0
           Conv2d-34            [1, 64, 64, 64]          36,928
            Group-35            [1, 64, 64, 64]               0
           Conv2d-36            [1, 64, 64, 64]          36,928
             ReLU-37            [1, 64, 64, 64]               0
           Conv2d-38            [1, 64, 64, 64]          36,928
AdaptiveAvgPool2d-39              [1, 64, 1, 1]               0
           Conv2d-40               [1, 8, 1, 1]             520
             ReLU-41               [1, 8, 1, 1]               0
           Conv2d-42              [1, 64, 1, 1]             576
          Sigmoid-43              [1, 64, 1, 1]               0
          CALayer-44            [1, 64, 64, 64]               0
           Conv2d-45             [1, 8, 64, 64]             520
             ReLU-46             [1, 8, 64, 64]               0
           Conv2d-47             [1, 1, 64, 64]               9
          Sigmoid-48             [1, 1, 64, 64]               0
          PALayer-49            [1, 64, 64, 64]               0
            Block-50            [1, 64, 64, 64]               0
           Conv2d-51            [1, 64, 64, 64]          36,928
            Group-52            [1, 64, 64, 64]               0
AdaptiveAvgPool2d-53             [1, 192, 1, 1]               0
           Conv2d-54               [1, 4, 1, 1]             772
             ReLU-55               [1, 4, 1, 1]               0
           Conv2d-56             [1, 192, 1, 1]             960
          Sigmoid-57             [1, 192, 1, 1]               0
           Conv2d-58             [1, 8, 64, 64]             520
             ReLU-59             [1, 8, 64, 64]               0
           Conv2d-60             [1, 1, 64, 64]               9
          Sigmoid-61             [1, 1, 64, 64]               0
          PALayer-62            [1, 64, 64, 64]               0
           Conv2d-63            [1, 64, 64, 64]          36,928
           Conv2d-64             [1, 3, 64, 64]           1,731
================================================================
Total params: 379,939
Trainable params: 379,939
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 56.35
Params size (MB): 1.45
Estimated Total Size (MB): 57.85
----------------------------------------------------------------

注: CALayer-10即为残差连接部分,对应于Class CALayer 中最后一条语句 return x * y。假如summary()中不指定batch_size,那么Output Shape 的第一个轴将为-1。

总结:
整个网络由1个卷积层+3个群结构+Concatenate模块+1个CA模块+1个PA模块组成+2个卷积层组成,其中,每个群结构包含19个基础块结构,每个基础块结构又由1个卷积层+1个relu层+1个卷积层+1个CA模块+1个PA模块组成,CA和PA模块详细见“主要内容”部分,另外通过长跳和短跳残差连接绕过薄雾或低频区域等不太重要的信息,使得信息的流动更加容易。一般网络越深(如大于400层),网络训练将更加困难,使用残差连接能够让很深的网络训练更加容易。本文网络共704层,训练总参数:4455913。
疑问:作者在PA模块代码中使用1x1卷积核和论文描述不符。(见疑惑1)
疑问:在PA模块中,实现像素注意的原理。(见疑惑2)