"""k-nearest-neighbour classification of MNIST digits, implemented with NumPy."""
__author__ = 'Administrator'

from collections import Counter
from operator import itemgetter

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

batch_size = 100


def predict(self, k, dis, X_test):
    """Predict a label for every row of ``X_test`` from training data held on ``self``.

    ``self`` must expose ``Xtr`` (training samples) and ``ytr`` (training labels).
    NOTE(review): this reads like a method detached from its class; confirm it
    is still needed.

    Args:
        k: number of neighbours that vote.
        dis: 'E' for Euclidean distance or 'M' for Manhattan distance.
        X_test: test samples, one per row.

    Returns:
        np.ndarray of predicted labels.
    """
    assert dis == 'E' or dis == 'M', 'dis must E or M'
    # Delegate so both metrics work; the old copy only implemented 'E' and
    # silently returned None for 'M'.
    return kNN_classify(k, dis, self.Xtr, self.ytr, X_test)


def getXmean(X_train):
    """Return the per-position mean of ``X_train`` after flattening each sample to 1-D."""
    X_train = np.reshape(X_train, (X_train.shape[0], -1))  # flatten each image to 1-D
    return np.mean(X_train, axis=0)  # mean of every pixel position over all samples


def centralized(X_test, mean_image):
    """Flatten each sample of ``X_test`` and subtract ``mean_image`` (zero-centring).

    Returns a new float64 array; ``astype`` copies, so the input is not modified.
    """
    X_test = np.reshape(X_test, (X_test.shape[0], -1))  # flatten each image to 1-D
    # np.float was removed in NumPy 1.24; np.float64 is the type it aliased.
    X_test = X_test.astype(np.float64)
    X_test -= mean_image
    return X_test


def _pairwise_distances(X_train, x, dis):
    """Return the distance from sample ``x`` to every row of ``X_train``.

    'E' gives Euclidean (L2) distance, 'M' gives Manhattan (L1) distance.
    ``x`` broadcasts against the rows, so no tiled copy of it is built.
    """
    diff = X_train - x
    if dis == 'E':
        return np.sqrt(np.sum(diff ** 2, axis=1))
    if dis == 'M':
        return np.sum(np.abs(diff), axis=1)
    # Only reachable when asserts are stripped (python -O).
    raise ValueError("dis must be 'E' or 'M', got %r" % (dis,))


def _majority_vote(labels):
    """Return the most frequent label in ``labels``."""
    counts = Counter(labels)
    # sorted() stays stable with reverse=True, so among equally frequent labels
    # the one met first (belonging to the closer neighbour) wins.
    return sorted(counts.items(), key=itemgetter(1), reverse=True)[0][0]


def kNN_classify(k, dis, X_train, x_train, Y_test):
    """Classify each row of ``Y_test`` by majority vote of its ``k`` nearest training rows.

    Note the historical parameter names: ``x_train`` holds the training
    *labels* and ``Y_test`` holds the test *samples*.

    Args:
        k: number of nearest neighbours that vote (should be >= 1).
        dis: 'E' for Euclidean distance or 'M' for Manhattan distance.
        X_train: 2-D array of training samples, one flattened sample per row.
        x_train: training labels, one per row of ``X_train``.
        Y_test: 2-D array of test samples, one flattened sample per row.

    Returns:
        np.ndarray with one predicted label per test row.

    Raises:
        AssertionError: if ``dis`` is neither 'E' nor 'M'.
    """
    assert dis == 'E' or dis == 'M', 'dis must E or M,E代表欧式距离,M代表曼哈顿距离'
    X_train = np.asarray(X_train)
    labels = np.asarray(x_train)
    predictions = []
    for sample in Y_test:
        distances = _pairwise_distances(X_train, sample, dis)
        nearest_k = np.argsort(distances)[:k]  # indices of the k closest training rows
        predictions.append(_majority_vote(labels[nearest_k]))
    return np.array(predictions)


def main():
    """Load MNIST, preview one digit, then score kNN on the first 1000 test images."""
    # MNIST dataset
    train_dataset = dsets.MNIST(root='/ml/pymnist',  # data root directory
                                train=True,          # training split
                                transform=None,      # no preprocessing
                                download=True)       # download from the network
    test_dataset = dsets.MNIST(root='/ml/pymnist',   # data root directory
                               train=False,          # test split
                               transform=None,       # no preprocessing
                               download=True)        # download from the network

    # Data loaders
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True)  # shuffle the samples
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=batch_size,
                             shuffle=True)

    # `.data` / `.targets` replace the deprecated train_data/test_data and
    # train_labels/test_labels aliases.
    print("train_data:", train_dataset.data.size())
    print("train_labels:", train_dataset.targets.size())
    print("test_data:", test_dataset.data.size())
    print("test_labels:", test_dataset.targets.size())

    digit = train_loader.dataset.data[0]  # first training image
    plt.imshow(digit, cmap=plt.cm.binary)
    plt.show()
    print(train_loader.dataset.targets[0])  # its label

    X_train = train_loader.dataset.data.numpy()
    mean_image = getXmean(X_train)
    X_train = centralized(X_train, mean_image)
    y_train = train_loader.dataset.targets.numpy()
    X_test = test_loader.dataset.data[:1000].numpy()
    X_test = centralized(X_test, mean_image)
    y_test = test_loader.dataset.targets[:1000].numpy()

    num_test = y_test.shape[0]
    y_test_pred = kNN_classify(5, 'M', X_train, y_train, X_test)
    num_correct = np.sum(y_test_pred == y_test)
    accuracy = float(num_correct) / num_test
    print('Got %d / %d correct => accuracy: %f' % (num_correct, num_test, accuracy))
    # Author-reported result for this configuration: accuracy 96%.


if __name__ == '__main__':
    main()