faiss库中ivf-sq(ScalarQuantizer,标量量化)代码解读-7

news/2024/12/26 4:01:22 标签: faiss, android

流程

在这里插入图片描述

代码

void IndexIVF::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params_in) const {
    FAISS_THROW_IF_NOT(k > 0);
    const IVFSearchParameters* params = nullptr;
    if (params_in) {
        params = dynamic_cast<const IVFSearchParameters*>(params_in);
        FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
    }
    const size_t nprobe =
            std::min(nlist, params ? params->nprobe : this->nprobe);
    FAISS_THROW_IF_NOT(nprobe > 0);

    // search function for a subset of queries
    auto sub_search_func = [this, k, nprobe, params](
                                   idx_t n,
                                   const float* x,
                                   float* distances,
                                   idx_t* labels,
                                   IndexIVFStats* ivf_stats) {
        std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
        std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

        double t0 = getmillisecs();
        quantizer->search(
                n,
                x,
                nprobe,
                coarse_dis.get(),
                idx.get(),
                params ? params->quantizer_params : nullptr);

        double t1 = getmillisecs();
        invlists->prefetch_lists(idx.get(), n * nprobe);

        search_preassigned(
                n,
                x,
                k,
                idx.get(),
                coarse_dis.get(),
                distances,
                labels,
                false,
                params,
                ivf_stats);
        double t2 = getmillisecs();
        ivf_stats->quantization_time += t1 - t0;
        ivf_stats->search_time += t2 - t0;
    };

    if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
        int nt = std::min(omp_get_max_threads(), int(n));
        std::vector<IndexIVFStats> stats(nt);
        std::mutex exception_mutex;
        std::string exception_string;

#pragma omp parallel for if (nt > 1)
        for (idx_t slice = 0; slice < nt; slice++) {
            IndexIVFStats local_stats;
            idx_t i0 = n * slice / nt;
            idx_t i1 = n * (slice + 1) / nt;
            if (i1 > i0) {
                try {
                    sub_search_func(
                            i1 - i0,
                            x + i0 * d,
                            distances + i0 * k,
                            labels + i0 * k,
                            &stats[slice]);
                } catch (const std::exception& e) {
                    std::lock_guard<std::mutex> lock(exception_mutex);
                    exception_string = e.what();
                }
            }
        }

        if (!exception_string.empty()) {
            FAISS_THROW_MSG(exception_string.c_str());
        }

        // collect stats
        for (idx_t slice = 0; slice < nt; slice++) {
            indexIVF_stats.add(stats[slice]);
        }
    } else {
        // handle paralellization at level below (or don't run in parallel at
        // all)
        sub_search_func(n, x, distances, labels, &indexIVF_stats);
    }
}

代码解析

IndexIVF::search 函数是 FAISS 的 IndexIVF 类中实现的一个核心函数,用于在倒排文件(Inverted File List, IVF)索引中执行搜索操作。以下是对函数的详细解析:

函数功能

在倒排文件索引中搜索最近的 k 个向量,返回它们的距离和对应的索引。
支持多线程并行化以提高查询性能。

参数说明

void IndexIVF::search(
    idx_t n,                    // 查询向量的数量
    const float* x,             // 查询向量(每个向量有 d 个维度)
    idx_t k,                    // 每个查询向量要找到的最近邻个数
    float* distances,           // 输出的距离数组,大小为 n*k
    idx_t* labels,              // 输出的索引数组,大小为 n*k
    const SearchParameters* params_in // 搜索参数,可选
) const;
  • n:查询向量的数量。
  • x:指向查询向量的指针,形状为 (n, d)。
  • k:每个查询向量需要返回的最近邻数量。
  • distances:保存结果的距离数组。
  • labels:保存结果的索引数组。
  • params_in:可选的搜索参数对象,通常包括 nprobe(控制搜索的倒排列表数量)等。
函数实现解析
  1. 参数验证
FAISS_THROW_IF_NOT(k > 0);

确保 k > 0,即需要找到至少一个最近邻。
2. 处理搜索参数

const IVFSearchParameters* params = nullptr;
if (params_in) {
    params = dynamic_cast<const IVFSearchParameters*>(params_in);
    FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
}
const size_t nprobe = std::min(nlist, params ? params->nprobe : this->nprobe);
FAISS_THROW_IF_NOT(nprobe > 0);
  • 检查输入的搜索参数 params_in 是否是 IVFSearchParameters 类型。
  • 从参数中提取 nprobe,即查询时访问的倒排列表数量:
    • 如果参数提供了 nprobe,则使用参数中的值。
    • 如果未提供,则使用索引默认的 nprobe。
  • 确保 nprobe > 0。
  1. 定义子搜索函数
auto sub_search_func = [this, k, nprobe, params](
    idx_t n,
    const float* x,
    float* distances,
    idx_t* labels,
    IndexIVFStats* ivf_stats) {
    ...
};

定义一个局部 lambda 函数 sub_search_func,处理子查询任务。参数包括当前的查询向量、结果存储位置和统计信息。
内部实现的步骤:
量化查询向量:

quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get(), params ? params->quantizer_params : nullptr);

使用量化器将查询向量分配到 nprobe 个倒排列表中。idx 保存分配的倒排列表索引。coarse_dis 保存量化后的距离。
倒排列表的预取:

invlists->prefetch_lists(idx.get(), n * nprobe);

预取倒排列表数据以提高内存访问性能。
实际搜索:

search_preassigned(n, x, k, idx.get(), coarse_dis.get(), distances, labels, false, params, ivf_stats);

在分配好的倒排列表中执行搜索,返回最近邻结果的距离和索引。
更新统计信息:

ivf_stats->quantization_time += t1 - t0;
ivf_stats->search_time += t2 - t0;

记录量化时间和搜索时间。
4. 选择并行模式

if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
    ...
} else {
    sub_search_func(n, x, distances, labels, &indexIVF_stats);
}

根据 parallel_mode 决定并行模式:如果启用了并行模式,则使用 OpenMP 进行多线程查询。否则直接调用 sub_search_func 处理整个查询。
5. 并行查询

int nt = std::min(omp_get_max_threads(), int(n));
std::vector<IndexIVFStats> stats(nt);
std::mutex exception_mutex;
std::string exception_string;

#pragma omp parallel for if (nt > 1)
for (idx_t slice = 0; slice < nt; slice++) {
    ...
}
  • 线程数量:设置线程数量为查询向量数和最大线程数的较小值。
  • 分片查询:将查询向量分配到多个线程进行并行处理。
  • 异常处理:捕获并记录线程中的异常。
  • 统计合并:将各线程的统计结果合并到全局统计对象。
关键步骤总结
  • 查询向量量化:使用量化器将查询向量映射到倒排列表。
  • 倒排列表预取:优化内存访问以提高效率。
  • 倒排列表搜索:在分配的倒排列表中执行精确搜索。
  • 支持并行化:利用 OpenMP 将查询任务分片并行化处理。
函数作用
  • 高效搜索:支持通过 nprobe 调整查询范围,平衡搜索速度和准确率。
  • 并行优化:通过多线程实现大规模查询的加速。
  • 灵活性:支持自定义搜索参数(如量化器配置)以适应不同场景。
适用场景

海量数据的最近邻搜索,例如向量化文档、推荐系统和图像检索


http://www.niftyadmin.cn/n/5799771.html

相关文章

Redis--通用命令学习

目录 一、引言 二、基础命令 1.set 2.get 3.keys 3.1 keys &#xff1f; 3.2 keys * 3.3 keys [abe] 3.4 keys [^] 3.5 keys [a-b] 4.exists 5.delete 6.expire 7.ttl 8.type 三、Redis中的过期策略&#xff08;面试题&#xff09; 1.惰性删除 2.定期删除 …

C++的侵入式链表

非侵入式链表 非侵入式链表是一种链表数据结构&#xff0c;其中每个元素&#xff08;节点&#xff09;并不需要自己包含指向前后节点的指针。链表的结构和节点的存储是分开的&#xff0c;链表容器会单独管理这些指针。 常见的非侵入式链表节点可以由以下所示&#xff0c;即&a…

C语言的预处理器

C语言的预处理器是C语言编译器的一个组成部分&#xff0c;它在编译代码之前对源代码进行文本替换和条件编译等操作。预处理器指令以#字符开头&#xff0c;它们不是C语言的正式语法部分&#xff0c;但它们在代码生成过程中起着非常重要的作用。 C语言预处理器的功能 宏定义&am…

云手机与Temu矩阵:跨境电商运营新引擎

云手机与 Temu 矩阵结合的基础 云手机技术原理 云手机基于先进的 ARM 虚拟化技术&#xff0c;在服务器端运行 APP。通过在服务器上利用容器虚拟化软件技术&#xff0c;能够虚拟出多个独立的手机操作系统实例&#xff0c;每个实例等同于一部单独的手机&#xff0c;可独立运行各…

【Linux】虚拟机扩展磁盘

文章目录 1、扩展虚拟机磁盘2、扩展卷3、扩展文件系统磁盘存储知识扩展在 CentOS 上扩展磁盘空间并将其应用到根目录(/)的过程通常包括以下几个步骤: 1、扩展虚拟机磁盘 扩展前使用lsblk命令 sda 8:0 0 4G 0 disk ├─sda1 8:1 0 30…

【201】进销存管理系统

--基于springboot的图书进销存管理系统 本图书进销存管理系统管理员功能有: 个人中心&#xff0c;用户管理&#xff0c;图书类型管理&#xff0c;进货订单管理&#xff0c;商品退货管理&#xff0c;批销订单管理&#xff0c;图书信息管理&#xff0c;客户信息管理&#xff0c;供…

VMware虚拟机三种网络工作模式

vmware为我们提供了三种网络工作模式,它们分别是:Bridged(桥接模式)、NAT(网络地址转换模式)、Host-Only(仅主机模式)。 打开vmware虚拟机,我们可以在选项栏的“编辑”下的“虚拟网络编辑器”中看到VMnet0(桥接模式)、VMnet1(仅主机模式)、VMnet8(NAT模式),那…

WEB:如何在 Vue 中同步数据的技术指南

1、简述 在 Vue 应用中&#xff0c;数据同步是保持用户界面和数据状态一致的关键。当多个组件或页面共享同一个状态时&#xff0c;如何保证它们的同步更新变得尤为重要。本篇博客将介绍 Vue 中几种常见的数据同步技术&#xff0c;包括 v-model 双向绑定、事件总线、Vuex 状态管…