流程
代码
void IndexIVF::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params_in) const {
FAISS_THROW_IF_NOT(k > 0);
const IVFSearchParameters* params = nullptr;
if (params_in) {
params = dynamic_cast<const IVFSearchParameters*>(params_in);
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
}
const size_t nprobe =
std::min(nlist, params ? params->nprobe : this->nprobe);
FAISS_THROW_IF_NOT(nprobe > 0);
// search function for a subset of queries
auto sub_search_func = [this, k, nprobe, params](
idx_t n,
const float* x,
float* distances,
idx_t* labels,
IndexIVFStats* ivf_stats) {
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
double t0 = getmillisecs();
quantizer->search(
n,
x,
nprobe,
coarse_dis.get(),
idx.get(),
params ? params->quantizer_params : nullptr);
double t1 = getmillisecs();
invlists->prefetch_lists(idx.get(), n * nprobe);
search_preassigned(
n,
x,
k,
idx.get(),
coarse_dis.get(),
distances,
labels,
false,
params,
ivf_stats);
double t2 = getmillisecs();
ivf_stats->quantization_time += t1 - t0;
ivf_stats->search_time += t2 - t0;
};
if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
int nt = std::min(omp_get_max_threads(), int(n));
std::vector<IndexIVFStats> stats(nt);
std::mutex exception_mutex;
std::string exception_string;
#pragma omp parallel for if (nt > 1)
for (idx_t slice = 0; slice < nt; slice++) {
IndexIVFStats local_stats;
idx_t i0 = n * slice / nt;
idx_t i1 = n * (slice + 1) / nt;
if (i1 > i0) {
try {
sub_search_func(
i1 - i0,
x + i0 * d,
distances + i0 * k,
labels + i0 * k,
&stats[slice]);
} catch (const std::exception& e) {
std::lock_guard<std::mutex> lock(exception_mutex);
exception_string = e.what();
}
}
}
if (!exception_string.empty()) {
FAISS_THROW_MSG(exception_string.c_str());
}
// collect stats
for (idx_t slice = 0; slice < nt; slice++) {
indexIVF_stats.add(stats[slice]);
}
} else {
// handle paralellization at level below (or don't run in parallel at
// all)
sub_search_func(n, x, distances, labels, &indexIVF_stats);
}
}
代码解析
IndexIVF::search 函数是 FAISS 的 IndexIVF 类中实现的一个核心函数,用于在倒排文件(Inverted File List, IVF)索引中执行搜索操作。以下是对函数的详细解析:
函数功能
在倒排文件索引中搜索最近的 k 个向量,返回它们的距离和对应的索引。
支持多线程并行化以提高查询性能。
参数说明
void IndexIVF::search(
idx_t n, // 查询向量的数量
const float* x, // 查询向量(每个向量有 d 个维度)
idx_t k, // 每个查询向量要找到的最近邻个数
float* distances, // 输出的距离数组,大小为 n*k
idx_t* labels, // 输出的索引数组,大小为 n*k
const SearchParameters* params_in // 搜索参数,可选
) const;
- n:查询向量的数量。
- x:指向查询向量的指针,形状为 (n, d)。
- k:每个查询向量需要返回的最近邻数量。
- distances:保存结果的距离数组。
- labels:保存结果的索引数组。
- params_in:可选的搜索参数对象,通常包括 nprobe(控制搜索的倒排列表数量)等。
函数实现解析
- 参数验证
FAISS_THROW_IF_NOT(k > 0);
确保 k > 0,即需要找到至少一个最近邻。
2. 处理搜索参数
const IVFSearchParameters* params = nullptr;
if (params_in) {
params = dynamic_cast<const IVFSearchParameters*>(params_in);
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
}
const size_t nprobe = std::min(nlist, params ? params->nprobe : this->nprobe);
FAISS_THROW_IF_NOT(nprobe > 0);
- 检查输入的搜索参数 params_in 是否是 IVFSearchParameters 类型。
- 从参数中提取 nprobe,即查询时访问的倒排列表数量:
- 如果参数提供了 nprobe,则使用参数中的值。
- 如果未提供,则使用索引默认的 nprobe。
- 确保 nprobe > 0。
- 定义子搜索函数
auto sub_search_func = [this, k, nprobe, params](
idx_t n,
const float* x,
float* distances,
idx_t* labels,
IndexIVFStats* ivf_stats) {
...
};
定义一个局部 lambda 函数 sub_search_func,处理子查询任务。参数包括当前的查询向量、结果存储位置和统计信息。
内部实现的步骤:
量化查询向量:
quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get(), params ? params->quantizer_params : nullptr);
使用量化器将查询向量分配到 nprobe 个倒排列表中。idx 保存分配的倒排列表索引。coarse_dis 保存量化后的距离。
倒排列表的预取:
invlists->prefetch_lists(idx.get(), n * nprobe);
预取倒排列表数据以提高内存访问性能。
实际搜索:
search_preassigned(n, x, k, idx.get(), coarse_dis.get(), distances, labels, false, params, ivf_stats);
在分配好的倒排列表中执行搜索,返回最近邻结果的距离和索引。
更新统计信息:
ivf_stats->quantization_time += t1 - t0;
ivf_stats->search_time += t2 - t0;
记录量化时间和搜索时间。
4. 选择并行模式
if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
...
} else {
sub_search_func(n, x, distances, labels, &indexIVF_stats);
}
根据 parallel_mode 决定并行模式:如果启用了并行模式,则使用 OpenMP 进行多线程查询。否则直接调用 sub_search_func 处理整个查询。
5. 并行查询
int nt = std::min(omp_get_max_threads(), int(n));
std::vector<IndexIVFStats> stats(nt);
std::mutex exception_mutex;
std::string exception_string;
#pragma omp parallel for if (nt > 1)
for (idx_t slice = 0; slice < nt; slice++) {
...
}
- 线程数量:设置线程数量为查询向量数和最大线程数的较小值。
- 分片查询:将查询向量分配到多个线程进行并行处理。
- 异常处理:捕获并记录线程中的异常。
- 统计合并:将各线程的统计结果合并到全局统计对象。
关键步骤总结
- 查询向量量化:使用量化器将查询向量映射到倒排列表。
- 倒排列表预取:优化内存访问以提高效率。
- 倒排列表搜索:在分配的倒排列表中执行精确搜索。
- 支持并行化:利用 OpenMP 将查询任务分片并行化处理。
函数作用
- 高效搜索:支持通过 nprobe 调整查询范围,平衡搜索速度和准确率。
- 并行优化:通过多线程实现大规模查询的加速。
- 灵活性:支持自定义搜索参数(如量化器配置)以适应不同场景。
适用场景
海量数据的最近邻搜索,例如向量化文档、推荐系统和图像检索