文章目录
- 一、理论
- 1.1 为什么要用centerloss
- 1.2 centerloss的理论基础
- 二、损失
- 2.1 公式
- 三、代码
- 3.1 centerloss.py
- 3.2 nets.py
- 3.3 Train.py
- 3.4 效果展示(优化后)
- 四、优化总结
- 五、完整代码
高能警告:文末有所有完整代码(内含权重)~~
一、理论
网络效果的提升除了改变网络结构外,还有一群人在研究损失层的改进,本篇博文要介绍的就是较为新颖的center loss。
对于常见的图像分类问题,我们常常用softmax loss来求损失。以MNIST数据集为例,如果你的损失采用softmax loss,那么最后各个类别学出来的特征分布大概如下图:
一共10个类别(代表0-9十个数字),用不同的颜色表示。从上图可以看出不管是训练数据集还是测试数据集,都能看出比较清晰的类别界限。
1.1 为什么要用centerloss
在图像识别中,一个很关键的要素就是图像中提取出来的特征,它关乎着图像识别的精准度。而通常用的softmax输出函数提取到的特征之间往往接的很紧,无太大的明显界限。在根据这些特征做识别的时候会出现模棱两可的情况,那么怎么让提取到的特征之间差异性更大从而提高识别的正确率就成了图像识别的一个重大问题。
因此:centerloss应运而生
1.2 centerloss的理论基础
它的目的是给每个类别的特征加一个中心点,然后使这一类别的特征点与它的中心的距离总和作为一个损失,然后去优化这个损失,使他们彼此无限靠近。从理论层面上讲,当学习到一定程度后,每个类别的特征会集中为一个点上,但从实际上说,这几乎是不太可能的,只能说接近于重叠在一个点。
回到刚才的问题,在MNIST数据集中,如果你是采用softmax loss加上center loss的损失,那么最后各个类别的特征分布大概如下图。与上图相比,类间距离变大了,类内距离减少了(主要变化在于类内距离),这就是直观的结果。因此:我们得到结论:centerloss的主要功能:缩小类内聚
二、损失
2.1 公式
softmax函数公式:
softmaxloss公式:
centerloss公式:
总损失:(softmaxloss + centerloss)
关于centerloss的公式:
centerloss中关于Lc的梯度和cyi的更新公式如下:
因此上面关于cyi的更新的公式中,当yi(表示yi类别)和cj的类别j不一样的时候,cj是不需要更新的,只有当yi和j一样才需要更新。
三、代码
代码目录
3.1 centerloss.py
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | import torch import torch.nn as nn class Centerloss(nn.Module): def __init__(self): super().__init__() self.center = nn.Parameter(torch.randn(10,2),requires_grad=True) def forward(self,features,ys,lambdas=2): center_exp = self.center.index_select(dim=0,index=ys.long()) count = torch.histc(ys,bins=int(max(ys).item()+1),min=0,max=int(max(ys).item())) count_exp = count.index_select(dim=0,index=ys.long()) loss = lambdas/2*torch.mean(torch.div(torch.sum(torch.pow(features-center_exp,2),dim=1),count_exp)) return loss if __name__ == '__main__':#测试 a = Centerloss() feature = torch.randn(5, 2, dtype=torch.float32) ys = torch.tensor([0, 0, 1, 0, 1,], dtype=torch.float32) b = a(feature, ys) |
3.2 nets.py
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 | import torch import torch.nn as nn import matplotlib.pyplot as plt class ConvolutionalLayer(nn.Module): def __init__(self,in_channels,out_channels,kernel_size,stride,padding,bias=False): super().__init__() self.cnn_layer = nn.Sequential( nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=bias), nn.BatchNorm2d(out_channels), nn.PReLU() ) def forward(self,x): return self.cnn_layer(x) class MyNet(nn.Module): def __init__(self): super().__init__() self.conv_layer = nn.Sequential( ConvolutionalLayer(1, 32, 5, 1, 2), # 28-5+4+1=28 ConvolutionalLayer(32, 32, 5, 1, 2), # 28 nn.MaxPool2d(2, 2), ConvolutionalLayer(32, 64, 5, 1, 2), # 14 ConvolutionalLayer(64, 64, 5, 1, 2), # 14 nn.MaxPool2d(2, 2), ConvolutionalLayer(64, 128, 5, 1, 2), # 7 ConvolutionalLayer(128, 128, 5, 1, 2), # 7 nn.MaxPool2d(2, 2) # 3 ) self.features = nn.Linear(128*3*3,2) self.output = nn.Linear(2,10) def forward(self,x): y_conv = self.conv_layer(x) y_conv = torch.reshape(y_conv,[-1,128*3*3]) y_feature = self.features(y_conv) y_output = torch.log_softmax(self.output(y_feature),dim=1) return y_feature,y_output def visualize(self,feat,labels,epoch): color = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff', '#ff00ff', '#990000', '#999900', '#009900', '#009999'] plt.clf() for i in range(10): plt.plot(feat[labels==i,0],feat[labels==i,1],".",c=color[i]) plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],loc="upper right") plt.title("epoch=%d" % epoch) plt.savefig("./images/epoch=%d.jpg" % epoch) |
3.3 Train.py
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | import torch import torch.nn as nn import os from torchvision import transforms from nets import MyNet from centerloss import Centerloss import torch.optim.lr_scheduler as lr_scheduler from torch.utils import data from torchvision.datasets import MNIST class Trainer: def __init__(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.loss_fn_cls = nn.NLLLoss() self.trans = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1309,),(0.3084,)) ]) def train(self): BATCH_SIZE = 100 save_path = "models/net_center.pth" if not os.path.exists("models"): os.mkdir("models") train_data = MNIST(root="./MNIST",train=True,download=False,transform=self.trans) train_loader = data.DataLoader(dataset=train_data,shuffle=True,batch_size=BATCH_SIZE) net = MyNet().to(self.device) c_net = Centerloss().to(self.device) if os.path.exists(save_path): net.load_state_dict(torch.load(save_path)) else: print("No Param") net_opt = torch.optim.SGD(net.parameters(),lr=0.001, momentum=0.9, weight_decay=0.0005) # net_opt = torch.optim.Adam(net.parameters()) scheduler = lr_scheduler.StepLR(net_opt, 20, gamma=0.8) c_net_opt = torch.optim.SGD(c_net.parameters(),lr=0.5) EPOCHS = 0 while True: feat_loader = [] label_loader = [] for i,(x,y) in enumerate(train_loader): x = x.to(self.device) y = y.to(self.device) feature,output = net(x) loss_cls = self.loss_fn_cls(output,y) y = y.float() loss_center = c_net(feature,y) loss = loss_cls+loss_center net_opt.zero_grad() c_net_opt.zero_grad() loss.backward() net_opt.step() c_net_opt.step() feat_loader.append(feature) label_loader.append(y) if i % 600 == 0: print("epoch:", EPOCHS, "i:", i, "total_loss:", (loss_cls.item() + loss_center.item()), "softmax_loss:", loss_cls.item(), "center_loss:", loss_center.item()) feat = torch.cat(feat_loader, 0) labels = torch.cat(label_loader, 0) net.visualize(feat.data.cpu().numpy(), labels.data.cpu().numpy(), EPOCHS) EPOCHS += 1 torch.save(net.state_dict(), save_path) scheduler.step() # 150轮停止 if EPOCHS == 150: break if __name__ == '__main__': t = Trainer() t.train() |
3.4 效果展示(优化后)
四、优化总结
1 2 3 4 5 6 | 1.选择NLLloss效果比CrossEntropyLoss效果好,nllloss=log()+nllloss() 2.center loss 和网络分开优化,效果会更好,速度也更快(center loss learning rate=0.5) 3.使用SGD优化时,如果没有添加动量,则会在三十轮左右开始出现无法(难以)收敛的情况,如果仅4.仅增加动量,而没有人为更新学习率,则收敛速度超慢; 5.使用Adam优化时,速度比SGD更快,但效果欠佳; 6.最终搭配:NLLLOSS+SGD optmizer(momentum+lr updata) 7.关于网络,卷积比全连接效果略好,网络设计大一点效果会更好。 |
五、完整代码
完整代码请点这里链接:完整代码内含权重
提取码:5a04