Torch、Java、Milvus快速搭建以图搜图系统
1 原理概述
以图搜图大致原理(口水话版)
以图搜图,即通过一张图片去匹配数据库中的图片,找到最相似的N张图。在我们普通的搜索系统中,文字匹配的搜索单纯的MySQL数据库就能实现简单的搜索,但是图片就存在很多难点。
1、首先要解决的是图片怎么表达的问题,肯定不会是每个像素点去匹配,而是对图像提取特征。在传统的数字图像处理中,图像的特征有很多:颜色特征、纹理特征、关键点特征、几何特征,可以将具有代表性的特征提取处理归一化后形成一个多维向量去表示图片。在深度学习如火如荼的时代,卷积神经网络能更好的做到特征提取这个工作。
2、特征提取到了,自然而然的就是将每个图片的特征(即一个向量)存入数据库,要搜索一张图片时就去数据库匹配。第二个问题就是如何去匹配图片,两个向量相等?当然不是。我们用距离来表达两个向量的相似程度,距离越近就越相似。距离用得最多的就是欧式距离和余弦距离(简单来说区别就是欧氏距离体现数值上的差异、余弦距离体现方向上的相对差异)。
3、怎么判断两个图片是否相似解决了,通过距离!第三个问题:来一张图时去数据库查询怎么查?一个一个匹配,最后排个序?当然不是!MySQL可以建索引,这个好像建索引也无从下手。这里就需要借助向量搜索引擎了。目前开源的向量搜索引擎还是有很多的,这里采用Milvus这个开源项目实现向量搜索引擎,详细了解的去自行百度。
2、ResNet提取深度特征向量
环境:Pytorch1.1 python3.6 cuda9.0 采用pretrainedmodels库快速搭建ResNet(pip安装即可)
几行代码搭建出一个特征提取网络
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 | from torch.autograd import Variable import torch import torch.nn as nn import torchvision.transforms as transforms import pretrainedmodels from PIL import Image TARGET_IMG_SIZE = 224 img_to_tensor = transforms.ToTensor() def get_seresnet50(): encoder = pretrainedmodels.se_resnet50() model = nn.Sequential(encoder.layer0, encoder.layer1, encoder.layer2, encoder.layer3, encoder.layer4, encoder.avg_pool # 平均池化,张成一个[batchSize,2048]的特征向量 ) for param in model.parameters(): param.requires_grad = False model.cuda() # 使用GPU,CPU版去掉 model.eval() return model # 特征提取 def extract_feature(model, imgpath): img = Image.open(imgpath) # 读取图片 img = img.resize((TARGET_IMG_SIZE, TARGET_IMG_SIZE)) tensor = img_to_tensor(img) # 将图片矩阵转化成tensor tensor = tensor.cuda() # GPU tensor = torch.unsqueeze(tensor, 0) result = model(Variable(tensor)) result_npy = result.data.cpu().numpy()[0].ravel().tolist() return result_npy |
利用serverSocket搭建服务器端,Java端通信,调用python的特征提取。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 | import socket import threading import json from model import extract_feature, get_seresnet50 def main(): # 创建服务器套接字 serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # 获取本地主机名称 host = socket.gethostname() # 设置一个端口 port = 12345 # 将套接字与本地主机和端口绑定 serversocket.bind((host, port)) # 设置监听最大连接数 serversocket.listen(10) # 模型创建 model = get_seresnet50() print("等待连接") while True: # 获取一个客户端连接 clientsocket, addr = serversocket.accept() print("连接地址:%s" % str(addr)) try: t = ServerThreading(model, clientsocket) # 为每一个请求开启一个处理线程 t.start() except Exception as identifier: print(identifier) pass serversocket.close() pass class ServerThreading(threading.Thread): def __init__(self, model, clientsocket, recvsize=1024 * 1024, encoding="utf-8"): threading.Thread.__init__(self) self.model = model self._socket = clientsocket self._recvsize = recvsize self._encoding = encoding pass def run(self): print("开启线程.....") try: # 接受数据 msg = '' while True: # 读取recvsize个字节 rec = self._socket.recv(self._recvsize) # 解码 msg += rec.decode(self._encoding) # 文本接受是否完毕,因为python socket不能自己判断接收数据是否完毕, # 所以需要自定义协议标志数据接受完毕 if msg.strip().endswith('over'): msg = msg[:-4] break # 解析json格式的数据 # 调用神经网络模型处理请求 res = extract_feature(self.model, msg) sendmsg = json.dumps(res) print(sendmsg) # 发送数据 self._socket.send(("%s" % sendmsg).encode(self._encoding)) except Exception as identifier: self._socket.send("500".encode(self._encoding)) print(identifier) pass finally: self._socket.close() print("任务结束.....") if __name__ == "__main__": main() |
3、Java端调用Python函数提取特征
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | package top.maolaoe.imgsearch.service; import org.springframework.stereotype.Service; import java.io.*; import java.net.Socket; import java.util.ArrayList; import java.util.List; /** * 通过socket调用python获得图片的特征向量 */ @Service public class FeatureService { private String HOST = "192.168.1.103"; private final int PORT = 12345; public List<Float> remoteCall(String path){ // 访问服务进程的套接字 // System.out.println("调用远程接口:host=>"+HOST+",port=>"+PORT); try(Socket socket = new Socket(HOST, PORT)) { // 初始化套接字,设置访问服务的主机和进程端口号,HOST是访问python进程的主机名称,可以是IP地址或者域名,PORT是python进程绑定的端口号 // 获取输出流对象 OutputStream os = socket.getOutputStream(); PrintStream out = new PrintStream(os); // 发送内容 out.print(path); // 告诉服务进程,内容发送完毕,可以开始处理 out.print("over"); // 获取服务进程的输入流 InputStream is = socket.getInputStream(); BufferedReader br = new BufferedReader(new InputStreamReader(is,"utf-8")); String tmp = null; StringBuilder sb = new StringBuilder(); // 读取内容 while((tmp=br.readLine())!=null) sb.append(tmp).append('\n'); // 解析结果 tmp = sb.toString().substring(1, sb.length()-2); String[] split = tmp.split(","); List<Float> list = new ArrayList<Float>(split.length); for (int i = 0; i < split.length; i++) { list.add(Float.valueOf(split[i])); split[i] = null; } // System.out.println(list); // System.out.println(list.size()); return list; } catch (IOException e) { e.printStackTrace(); } return null; } public static void main(String[] args) throws IOException { FeatureService featureService = new FeatureService(); featureService.remoteCall("E:\\data\\tx.jpg"); } } |
4、安装启动milvus向量搜索引擎
官方教程
注意修改配置文件中的内存大小以适应自己的机器,否则docker启动时报错。
5、编写milvus插入和搜索向量的方法
引入依赖:milvus中的guava容易与其他包冲突,单独引入
1 2 3 4 5 6 7 8 9 10 11 | <dependency> <groupId>io.milvus</groupId> <artifactId>milvus-sdk-java</artifactId> <exclusions> <exclusion> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> </exclusion> </exclusions> <version>0.8.2</version> </dependency> |
提供search和insert功能
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | package top.maolaoe.imgsearch.service; import com.google.gson.JsonObject; import io.milvus.client.*; import org.springframework.stereotype.Service; import java.util.List; @Service public class MilvusService { private MilvusClient client = new MilvusGrpcClient(); private String collectionName; MilvusService(String host, int port, String collectionName, int nlist){ this.collectionName = collectionName; //建立连接 ConnectParam connectParam = new ConnectParam.Builder().withHost(host).withPort(port).build(); try { client.connect(connectParam); //连接正常则创建collection Response responseCollect = createCollect(collectionName, 2048, 1024, MetricType.IP); System.out.println(responseCollect.getMessage()); //创建索引 if(createIndex(nlist)){ System.out.println("创建索引"); }else { System.out.println("索引创建失败"); } } catch (ConnectFailedException e) { System.out.println("连接失败!"); } } MilvusService(String host, int port, String collectionName){ this(host, port, collectionName, 1000); } MilvusService(){ this("localhost", 19530, "imgsearch"); } /** * 插入特征向量 * @param features * @return */ public List<Long> insertFeatures(List<List<Float>> features){ //先判断是否正常连接 boolean connected = client.isConnected(); if(!connected){ System.out.println("连接失败!!"); } //插入特征向量 InsertParam insertParam = new InsertParam.Builder(collectionName).withFloatVectors(features).build(); InsertResponse insertResponse = client.insert(insertParam); client.flush(collectionName); boolean flag = insertResponse.ok(); if(!flag){ System.out.println("插入失败"); return null; } List<Long> vectorIds = insertResponse.getVectorIds(); return vectorIds; } /** * 查询相似的特诊向量 * @param vectorsToSearch * @param topK * @return */ public SearchResponse searchFeature(List<List<Float>> vectorsToSearch, long topK, int nprobe){ //先判断是否正常连接 boolean connected = client.isConnected(); if(!connected){ System.out.println("连接失败!!"); } JsonObject indexParamsJson = new JsonObject(); indexParamsJson.addProperty("nprobe", nprobe); //nprobe代表选择最近的多少个聚类去比较。 SearchParam searchParam =new SearchParam.Builder(collectionName).withFloatVectors(vectorsToSearch) .withParamsInJson(indexParamsJson.toString()) .withTopK(topK) .build(); SearchResponse searchResponse = client.search(searchParam); return searchResponse; } public SearchResponse searchFeature(List<List<Float>> vectorsToSearch, long topK){ return searchFeature(vectorsToSearch, topK, 15); } /** * 创建数据库表 * @param collectionName 表的名称 * @param dimension 向量维度 * @param indexFileSize 单个文件的大小值 * @param metricType * @return */ private Response createCollect(String collectionName, int dimension, int indexFileSize,MetricType metricType){ CollectionMapping collectionMapping = new CollectionMapping.Builder(collectionName, dimension) .withIndexFileSize(indexFileSize) .withMetricType(metricType) .build(); Response response = client.createCollection(collectionMapping); return response; } // 创建索引,指定聚类数 private boolean createIndex(int nlist){ final IndexType indexType = IndexType.IVF_SQ8; JsonObject indexParamsJson = new JsonObject(); indexParamsJson.addProperty("nlist", nlist); //nlist代表聚类数,根据数据量多少设置 Index index = new Index.Builder(collectionName, indexType) .withParamsInJson(indexParamsJson.toString()) .build(); Response createIndexResponse = client.createIndex(index); return createIndexResponse.ok(); } } |
6、其他处理
接下来就是细枝末节上的处理,
我是前端上传图片时保存到本地,然后发送图片路径给python端提取特征返回Java端。
Java端有了特征向量就调用milvus的方法获取最近的topK条特征向量的ID,再根据ID查询数据库获取图片的路径,然后展示到前端。
导入VOC2007的数据集,5000张图片,传入一张猫的图片,搜索结果还是比较满意的。
当然,项目只是简单搭建完成,,精度上还有待优化,参数的调整还有待优化,,项目中还有大量bug没处理,Python端BIO的方式问题太多,batchSize只为1并发度太低,,
有时间再完整的搭建一下项目。