BERT与C++结合:打造强大的语义搜索系统

云信安装大师
90
AI 质量分
10 5 月, 2025
5 分钟阅读
0 阅读

BERT与C++结合:打造强大的语义搜索系统

引言

在当今信息爆炸的时代,传统的基于关键词的搜索已经无法满足用户对精准语义匹配的需求。BERT作为Google推出的革命性自然语言处理模型,能够理解词语的上下文含义。本文将带你了解如何将BERT的强大语义理解能力与C++的高效性能相结合,构建一个强大的语义搜索系统。

准备工作

环境要求

  • C++17或更高版本
  • Python 3.7+ (用于BERT模型处理)
  • CMake 3.12+
  • PyTorch 1.8+
  • Transformers库 (Hugging Face)

前置知识

  • 基本C++编程能力
  • Python基础
  • 对神经网络有基本了解

系统架构设计

我们的语义搜索系统将分为两个主要部分:

  1. Python服务:负责加载BERT模型并生成文本的嵌入向量
  2. C++核心:处理向量索引、快速相似度计算和搜索结果排序
代码片段
[用户查询] -> [Python BERT服务] -> [向量嵌入] -> [C++搜索核心] -> [排序结果]

详细实现步骤

步骤1:设置Python BERT服务

首先我们创建一个简单的Flask服务来提供BERT嵌入功能:

代码片段
# bert_server.py
from flask import Flask, request, jsonify
from transformers import BertModel, BertTokenizer
import torch
import numpy as np

app = Flask(__name__)

# 加载预训练的BERT模型和分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

@app.route('/embed', methods=['POST'])
def embed_text():
    text = request.json['text']

    # Tokenize输入文本
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

    # 获取BERT嵌入
    with torch.no_grad():
        outputs = model(**inputs)

    # 使用[CLS]标记的隐藏状态作为句子表示
    embedding = outputs.last_hidden_state[:, 0, :].numpy()

    return jsonify({'embedding': embedding.tolist()[0]})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)

代码说明
1. 使用Hugging Face的Transformers库加载预训练的BERT模型
2. /embed端点接收文本并返回其768维的嵌入向量
3. [CLS]标记的隐藏状态被用作整个句子的表示

启动服务

代码片段
python bert_server.py

步骤2:C++客户端实现

我们将使用libcurl与Python服务通信,并实现一个简单的最近邻搜索:

代码片段
// search_engine.h
#pragma once

#include <vector>
#include <string>
#include <map>

class SemanticSearchEngine {
public:
    // 添加文档到搜索索引
    void addDocument(const std::string& doc_id, const std::vector<float>& embedding);

    // 查询最相似的文档
    std::vector<std::pair<std::string, float>> query(const std::vector<float>& query_embedding, int top_k=5);

private:
    std::map<std::string, std::vector<float>> document_embeddings;

    // 计算余弦相似度
    float cosineSimilarity(const std::vector<float>& a, const std::vector<float>& b);
};
代码片段
// search_engine.cpp
#include "search_engine.h"
#include <cmath>
#include <algorithm>

void SemanticSearchEngine::addDocument(const std::string& doc_id, const std::vector<float>& embedding) {
    document_embeddings[doc_id] = embedding;
}

std::vector<std::pair<std::string, float>> SemanticSearchEngine::query(
    const std::vector<float>& query_embedding, int top_k) {

    std::vector<std::pair<std::string, float>> results;

    for (const auto& [doc_id, doc_embedding] : document_embeddings) {
        float similarity = cosineSimilarity(query_embedding, doc_embedding);
        results.emplace_back(doc_id, similarity);
    }

    // 按相似度降序排序
    std::sort(results.begin(), results.end(), 
        [](const auto& a, const auto& b) { return a.second > b.second; });

    // 返回top_k结果
    if (results.size() > top_k) {
        results.resize(top_k);
    }

    return results;
}

float SemanticSearchEngine::cosineSimilarity(const std::vector<float>& a, const std::vector<float>& b) {
    if (a.size() != b.size()) return -1.0f;

    float dot_product = 0.0f;
    float norm_a = 0.0f;
    float norm_b = 0.0f;

    for (size_t i = 0; i < a.size(); ++i) {
        dot_product += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }

    norm_a = sqrt(norm_a);
    norm_b = sqrt(norm_b);

    if (norm_a == 0 || norm_b == 0) return -1.0f;

    return dot_product / (norm_a * norm_b);
}

步骤3:集成HTTP客户端

使用libcurl与Python BERT服务通信:

代码片段
// http_client.h
#pragma once

#include <string>
#include <vector>

class HttpClient {
public:
    static std::vector<float> getEmbedding(const std::string& text);
};
代码片段
// http_client.cpp
#include "http_client.h"
#include <curl/curl.h>
#include <nlohmann/json.hpp>
#include <sstream>

using json = nlohmann::json;

static size_t WriteCallback(void* contents, size_t size, size_t nmemb, void* userp) {
    ((std::string*)userp)->append((char*)contents, size * nmemb);
    return size * nmemb;
}

std::vector<float> HttpClient::getEmbedding(const std::string& text) {
    CURL* curl;
    CURLcode res;
    std::string readBuffer;

    curl = curl_easy_init();
    if(curl) {
        json request_body;
        request_body["text"] = text;

        struct curl_slist* headers = NULL;
        headers = curl_slist_append(headers, "Content-Type: application/json");

        curl_easy_setopt(curl, CURLOPT_URL, "http://localhost:5000/embed");
        curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request_body.dump().c_str());
        curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
        curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
        curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer);

        res = curl_easy_perform(curl);

        if(res != CURLE_OK) {
            fprintf(stderr, "curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
            return {};
        }

        auto response_json = json::parse(readBuffer);

        std::vector<float> embedding;
        for (auto& val : response_json["embedding"]) {
            embedding.push_back(val.get<float>());
        }

        return embedding;

    }
    return {};
}

CMake配置

代码片段
cmake_minimum_required(VERSION 3.12)
project(SemanticSearch)

set(CMAKE_CXX_STANDARD 17)

find_package(CURL REQUIRED)
find_package(nlohmann_json REQUIRED)

add_executable(semantic_search 
    src/main.cpp 
    src/search_engine.cpp 
    src/http_client.cpp)

target_link_libraries(semantic_search PRIVATE 
    CURL::libcurl 
    nlohmann_json)

main函数示例

代码片段
// main.cpp示例用法:
#include "search_engine.h"
#include "http_client.h"
#include <iostream>

int main() {    
    std::cout << "Building semantic search index..." << std:endl;

    std:SemanticSearchEngine engine;

    std:cout << "Adding documents..." << endl;

    std:document1_text ="The quick brown fox jumps over the lazy dog";
    std:document2_text ="Artificial intelligence is transforming industries";
    std:document3_text ="Deep learning models require large amounts of data";

    auto embed_doc1 HttpClient.getEmbedding(document1_text); 
    auto embed_doc2 HttpClient.getEmbedding(document2_text); 
    auto embed_doc3 HttpClient.getEmbedding(document3_text);

    if(embed_doc1.empty() || embed_doc2.empty() || embed_doc3.empty()) { 
        cerr <<"Failed to get embeddings from BERT service"<< endl; 
        return -1; }

   engine.addDocument("doc1", embed_doc1);  
   engine.addDocument("doc2", embed_doc2);  
   engine.addDocument("doc3", embed_doc3);

   while(true){  
       cout<<"Enter your search query(or 'exit' to quit):";  
       string query;  
       getline(cin query);

       if(query=="exit") break;

       auto query_embed=HttpClient.getEmbedding(query);  
       if(query_embed.empty()){  
           cerr<<"Failed to get query embedding"<<endl;  
           continue;}  

       auto results=engine.query(query_embed);  

       cout<<"Top "<<results.size()<<" results:"<<endl;  
       for(const auto&[doc_id score]:results){  
           cout<<doc_id<<": score="<<score<<endl;} }  

   return EXIT_SUCCESS;}

性能优化建议(实践经验)

1.批处理请求:修改Python服务以支持一次处理多个文本,减少HTTP往返次数。

2.近似最近邻搜索:对于大规模数据集,考虑使用FAISS或Annoy等库替代暴力搜索。

3.缓存机制:缓存常见查询的结果以提高响应速度。

4.多线程:在C++端使用多线程处理并发查询。

5.量化:考虑对BERT模型进行量化以减少内存占用和提高推理速度。

常见问题解决(注意事项)

Q:BERT服务响应慢怎么办?
A:-尝试更小的BERT变体如DistilBERT或TinyBERT。
-使用GPU加速推理。
-启用动态批处理。

Q:如何处理长文本?
A:-分段处理然后合并嵌入。
-使用专门处理长文本的模型如Longformer。

Q:如何扩展到大规规模文档集?
A:-实现分片索引。
-考虑专门的向量数据库如Milvus或Weaviate。

总结

通过本文我们实现了:
1.Python端的BERT嵌入服务搭建。
2.C++端的语义搜索引擎核心功能。
3.HTTP通信集成方案。
4.完整的示例代码和构建配置。

这种架构结合了Python在深度学习方面的便利性和C++的高效性,非常适合需要高性能语义搜索的场景。你可以在此基础上进一步扩展功能,如添加更复杂的排序算法、支持多语言或集成到现有系统中。

原创 高质量