MUVERA:让RAG系统中的多向量检索像单向量一样高效

向量数据库机器学习推荐算法

在向量数据库和信息检索领域,多向量嵌入模型(如 ColBERT、ColPali)凭借其强大的语义捕获能力正在成为主流选择。这类模型能够保留文本的词元级别含义,或是识别图像不同部分的信息特征。然而,它们也带来了显著的性能挑战:庞大的内存占用和较慢的检索速度。Weaviate 在 1.31 版本中引入的 MUVERA 编码算法,正是为了解决这些问题而生。

多向量模型的优势与困境

多向量嵌入的核心优势在于其细粒度的语义表达能力。相比单向量模型将整个文档压缩成一个固定长度的向量,多向量模型为文档的每个词元或图像块生成独立的向量表示。这种设计使得模型能够捕捉更丰富的语义信息,在检索任务中展现出更高的准确性。

picture.image

单向量与多向量对比

但这种精细化表示的代价同样明显。假设要索引一百万个文档,每个文档平均包含 100 个词元。使用传统的单向量模型(768 维,32 位浮点数),大约需要 3.1GB 内存。而多向量模型(96 维)的内存消耗可能高达 40GB,超过十倍的差距。这种内存压力在大规模部署场景下会转化为实实在在的成本负担。

picture.image

多向量嵌入内存对比

性能瓶颈不仅体现在存储层面。在检索阶段,多向量模型需要使用 MaxSim 运算符计算相似度。这个过程需要遍历查询的每个词元,找出它与文档所有词元中的最佳匹配,然后累加所有匹配得分。数学表达式如下:

这种非线性计算相比简单的点积运算复杂得多,直接影响了查询响应速度和数据导入效率。

picture.image

单向量和多向量内存使用情况

MUVERA 的核心思想

MUVERA(Multi-Vector Retrieval via Fixed Dimensional Encodings)的设计哲学是将复杂的多向量检索问题转化为单向量最大内积搜索。算法的关键在于构建固定维度编码(FDE),将一组长度不定的向量集合压缩成单个固定长度的向量表示。

整个转换过程可以用一个简洁的映射函数表示:

这里的核心目标是让编码后的单向量点积能够很好地近似原始多向量的 MaxSim 相似度:

picture.image

MUVERA 高层概览

这种转换带来的效率提升是显著的。对于包含 100 万个文档、每个文档 100 个向量的数据集,传统方案需要索引 1 亿个向量,而 MUVERA 只需处理 100 万个 FDE 向量,将 HNSW 图的规模缩减到原来的 1%。

算法实现细节

MUVERA 通过四个精心设计的步骤完成编码转换:空间划分、降维、重复增强和最终投影。每个步骤都有明确的数学基础和实际考量。

空间划分策略

第一步是将高维向量空间划分成若干个桶。算法采用 SimHash 技术实现这一过程,这是一种基于局部敏感哈希的方法。具体来说,算法会采样

个高斯向量,然后通过计算输入向量与这些高斯向量的点积符号来确定桶编号:

这种划分方式的优势在于其与数据分布无关,不需要预先训练,也不会因为数据漂移而失效。划分完成后,属于同一个桶的向量会被聚合成一个代表性向量。

picture.image

MUVERA 步骤 1 - 空间划分

picture.image

MUVERA 步骤 2 - 填充空簇

对于文档编码,每个桶的子向量计算方式为:

而查询编码则直接对属于同一桶的向量求和:

这种不对称处理恰好对应了 MaxSim 运算的特性。

降维与重复

空间划分后得到的向量维度是

,其中

是桶的数量,

是原始向量维度。为了进一步压缩表示,MUVERA 使用随机投影矩阵进行降维:

这里的随机矩阵

元素取值为

,遵循 Johnson-Lindenstrauss 引理,能够在降维的同时保持向量间点积的相对关系。

picture.image

MUVERA 步骤 3 - 降维

为了提高编码的鲁棒性,算法会重复执行空间划分和降维步骤

次,将得到的多个编码向量拼接起来。最终的 FDE 维度为

性能评测与实际效果

Weaviate 团队使用 LoTTE 基准测试数据集进行了详细的性能评估。该数据集包含约 11.9 万个文档,使用 ColBERT v2.0 编码后生成了 1500 万个 128 维向量,总内存占用约 8GB。

启用 MUVERA 后(参数设置为

),每个文档被编码为 2560 维的单一向量。这使得总浮点数存储量从 19 亿降至 3.04 亿,内存节省接近 80%。更重要的是,HNSW 图的节点数从 1500 万降至 11.9 万,这对于图遍历效率的提升是质的飞跃。

picture.image

未使用 MUVERA + SQ 与 MUVERA + SQ 时的堆内存分配

数据导入速度的改善同样显著。基准场景下,导入 11 万个对象需要 20 多分钟,相当于每秒只能处理约 100 个对象。而使用 MUVERA 后,这个时间缩短到 3-6 分钟。对于需要频繁更新索引的生产环境,这种效率提升意义重大。

性能权衡考量

技术方案从来不是完美的,MUVERA 也有其代价。最主要的妥协体现在召回率上。测试数据显示,在相同的搜索参数下,启用 MUVERA 会导致召回率下降。不过,这个问题可以通过调整 HNSW 的 ef 参数来缓解。

ef 值设置在 512 以上时,召回率可以恢复到 80% 以上;而在 2048 时甚至能超过 90%。但提高 ef 值意味着要检索更多的候选集,这会降低查询吞吐量。因此,实际应用中需要在召回质量和查询速度之间找到平衡点。

picture.image

MUVERA 对比

Google Research 团队的实验结果进一步验证了 MUVERA 的效果。在 BEIR 基准测试中,相比基于单向量启发式的 PLAID 系统,MUVERA 在召回率平均提升 10% 的同时,将延迟降低了 90%。这种性能提升在大规模部署中的价值不言而喻。

适用场景分析

MUVERA 并非万能方案,它最适合以下几类应用场景。首先是内存成本敏感的大规模部署。当数据集规模达到千万甚至亿级时,内存占用的降低可以直接转化为每年数万甚至数十万美元的成本节约。其次是对索引速度有较高要求的场景,比如需要频繁更新的实时系统。

另一个重要考量是对召回质量的容忍度。如果应用场景对检索精度有极致要求,那么需要仔细权衡 MUVERA 带来的召回率下降是否可以接受。不过对于许多实际应用来说,轻微的召回损失往往是可以承受的,特别是考虑到可以通过调整搜索参数来部分恢复性能。

从实现角度看,Weaviate 的集成使得启用 MUVERA 变得非常简单,只需要几行配置代码。用户可以设置的主要参数包括 k\_sim(空间划分的细粒度)、d\_proj(降维后的维度)和 r\_reps(重复次数)。Weaviate 团队为这些参数提供了合理的默认值,大多数场景下可以直接使用。

值得注意的是,MUVERA 的固定维度编码还可以结合标量量化(Scalar Quantization)等技术进一步压缩。Google 的研究表明,通过乘积量化可以在几乎不影响检索质量的前提下,将内存占用再减少 32 倍。这为超大规模应用提供了更多优化空间。

实现

picture.image

https://github.com/sionic-ai/muvera-py/tree/master

我在github上面找到一个MUVERA的python实现,大家可以尝试一下


 
 
 
 
   
import logging  
import time  
  
import numpy as np  
from dataclasses import dataclass, replace  
from enum import Enum  
from typing import Optional, List  
  
  
class EncodingType(Enum):  
    DEFAULT\_SUM = 0  
    AVERAGE = 1  
  
  
class ProjectionType(Enum):  
    DEFAULT\_IDENTITY = 0  
    AMS\_SKETCH = 1  
  
  
@dataclass  
class FixedDimensionalEncodingConfig:  
    dimension: int = 128  
    num\_repetitions: int = 10  
    num\_simhash\_projections: int = 6  
    seed: int = 42  
    encoding\_type: EncodingType = EncodingType.DEFAULT\_SUM  
    projection\_type: ProjectionType = ProjectionType.DEFAULT\_IDENTITY  
    projection\_dimension: Optional[int] = None  
    fill\_empty\_partitions: bool = False  
    final\_projection\_dimension: Optional[int] = None  
  
  
def \_append\_to\_gray\_code(gray\_code: int, bit: bool) -> int:  
    return (gray\_code << 1) + (int(bit) ^ (gray\_code & 1))  
  
  
def \_gray\_code\_to\_binary(num: int) -> int:  
    mask = num >> 1  
    while mask != 0:  
        num = num ^ mask  
        mask >>= 1  
    return num  
  
  
def \_simhash\_matrix\_from\_seed(  
    dimension: int, num\_projections: int, seed: int  
) -> np.ndarray:  
    rng = np.random.default\_rng(seed)  
    return rng.normal(loc=0.0, scale=1.0, size=(dimension, num\_projections)).astype(  
        np.float32  
    )  
  
  
def \_ams\_projection\_matrix\_from\_seed(  
    dimension: int, projection\_dim: int, seed: int  
) -> np.ndarray:  
    rng = np.random.default\_rng(seed)  
    out = np.zeros((dimension, projection\_dim), dtype=np.float32)  
    indices = rng.integers(0, projection\_dim, size=dimension)  
    signs = rng.choice([-1.0, 1.0], size=dimension)  
    out[np.arange(dimension), indices] = signs  
    return out  
  
  
def \_apply\_count\_sketch\_to\_vector(  
    input\_vector: np.ndarray, final\_dimension: int, seed: int  
) -> np.ndarray:  
    rng = np.random.default\_rng(seed)  
    out = np.zeros(final\_dimension, dtype=np.float32)  
    indices = rng.integers(0, final\_dimension, size=input\_vector.shape[0])  
    signs = rng.choice([-1.0, 1.0], size=input\_vector.shape[0])  
    np.add.at(out, indices, signs * input\_vector)  
    return out  
  
  
def \_simhash\_partition\_index\_gray(sketch\_vector: np.ndarray) -> int:  
    partition\_index = 0  
    for val in sketch\_vector:  
        partition\_index = \_append\_to\_gray\_code(partition\_index, val > 0)  
    return partition\_index  
  
  
def \_distance\_to\_simhash\_partition(  
    sketch\_vector: np.ndarray, partition\_index: int  
) -> int:  
    num\_projections = sketch\_vector.size  
    binary\_representation = \_gray\_code\_to\_binary(partition\_index)  
    sketch\_bits = (sketch\_vector > 0).astype(int)  
    binary\_array = (binary\_representation >> np.arange(num\_projections - 1, -1, -1)) & 1  
    return int(np.sum(sketch\_bits != binary\_array))  
  
  
def \_generate\_fde\_internal(  
    point\_cloud: np.ndarray, config: FixedDimensionalEncodingConfig  
) -> np.ndarray:  
    if point\_cloud.ndim != 2 or point\_cloud.shape[1] != config.dimension:  
        raise ValueError(  
            f"Input data shape {point\_cloud.shape} is inconsistent with config dimension {config.dimension}."  
        )  
    if not (0 <= config.num\_simhash\_projections < 32):  
        raise ValueError(  
            f"num\_simhash\_projections must be in [0, 31]: {config.num\_simhash\_projections}"  
        )  
  
    num\_points, original\_dim = point\_cloud.shape  
    num\_partitions = 2**config.num\_simhash\_projections  
  
    use\_identity\_proj = config.projection\_type == ProjectionType.DEFAULT\_IDENTITY  
    projection\_dim = original\_dim if use\_identity\_proj else config.projection\_dimension  
    if not use\_identity\_proj and (not projection\_dim or projection\_dim <= 0):  
        raise ValueError(  
            "A positive projection\_dimension is required for non-identity projections."  
        )  
  
    final\_fde\_dim = config.num\_repetitions * num\_partitions * projection\_dim  
    out\_fde = np.zeros(final\_fde\_dim, dtype=np.float32)  
  
    for rep\_num in range(config.num\_repetitions):  
        current\_seed = config.seed + rep\_num  
  
        sketches = point\_cloud @ \_simhash\_matrix\_from\_seed(  
            original\_dim, config.num\_simhash\_projections, current\_seed  
        )  
  
        if use\_identity\_proj:  
            projected\_matrix = point\_cloud  
        elif config.projection\_type == ProjectionType.AMS\_SKETCH:  
            ams\_matrix = \_ams\_projection\_matrix\_from\_seed(  
                original\_dim, projection\_dim, current\_seed  
            )  
            projected\_matrix = point\_cloud @ ams\_matrix  
  
        rep\_fde\_sum = np.zeros(num\_partitions * projection\_dim, dtype=np.float32)  
        partition\_counts = np.zeros(num\_partitions, dtype=np.int32)  
        partition\_indices = np.array(  
            [\_simhash\_partition\_index\_gray(sketches[i]) for i in range(num\_points)]  
        )  
  
        for i in range(num\_points):  
            start\_idx = partition\_indices[i] * projection\_dim  
            rep\_fde\_sum[start\_idx : start\_idx + projection\_dim] += projected\_matrix[i]  
            partition\_counts[partition\_indices[i]] += 1  
  
        if config.encoding\_type == EncodingType.AVERAGE:  
            for i in range(num\_partitions):  
                start\_idx = i * projection\_dim  
                if partition\_counts[i] > 0:  
                    rep\_fde\_sum[start\_idx : start\_idx + projection\_dim] /= (  
                        partition\_counts[i]  
                    )  
                elif config.fill\_empty\_partitions and num\_points > 0:  
                    distances = [  
                        \_distance\_to\_simhash\_partition(sketches[j], i)  
                        for j in range(num\_points)  
                    ]  
                    nearest\_point\_idx = np.argmin(distances)  
                    rep\_fde\_sum[start\_idx : start\_idx + projection\_dim] = (  
                        projected\_matrix[nearest\_point\_idx]  
                    )  
  
        rep\_start\_index = rep\_num * num\_partitions * projection\_dim  
        out\_fde[rep\_start\_index : rep\_start\_index + rep\_fde\_sum.size] = rep\_fde\_sum  
  
    if config.final\_projection\_dimension and config.final\_projection\_dimension > 0:  
        return \_apply\_count\_sketch\_to\_vector(  
            out\_fde, config.final\_projection\_dimension, config.seed  
        )  
  
    return out\_fde  
  
  
def generate\_query\_fde(  
    point\_cloud: np.ndarray, config: FixedDimensionalEncodingConfig  
) -> np.ndarray:  
    """Generates a Fixed Dimensional Encoding for a query point cloud (using SUM)."""  
    if config.fill\_empty\_partitions:  
        raise ValueError(  
            "Query FDE generation does not support 'fill\_empty\_partitions'."  
        )  
    query\_config = replace(config, encoding\_type=EncodingType.DEFAULT\_SUM)  
    return \_generate\_fde\_internal(point\_cloud, query\_config)  
  
  
def generate\_document\_fde(  
    point\_cloud: np.ndarray, config: FixedDimensionalEncodingConfig  
) -> np.ndarray:  
    """Generates a Fixed Dimensional Encoding for a document point cloud (using AVERAGE)."""  
    doc\_config = replace(config, encoding\_type=EncodingType.AVERAGE)  
    return \_generate\_fde\_internal(point\_cloud, doc\_config)  
  
  
def generate\_fde(  
    point\_cloud: np.ndarray, config: FixedDimensionalEncodingConfig  
) -> np.ndarray:  
    if config.encoding\_type == EncodingType.DEFAULT\_SUM:  
        return generate\_query\_fde(point\_cloud, config)  
    elif config.encoding\_type == EncodingType.AVERAGE:  
        return generate\_document\_fde(point\_cloud, config)  
    else:  
        raise ValueError(f"Unsupported encoding type in config: {config.encoding\_type}")  
  
  
def generate\_document\_fde\_batch(  
    doc\_embeddings\_list: List[np.ndarray], config: FixedDimensionalEncodingConfig  
) -> np.ndarray:  
    """  
    Generates FDEs for a batch of documents using highly optimized NumPy vectorization.  
    Fully compliant with C++ implementation including all projection types.  
    """  
    batch\_start\_time = time.perf\_counter()  
    num\_docs = len(doc\_embeddings\_list)  
  
    if num\_docs == 0:  
        logging.warning("[FDE Batch] Empty document list provided")  
        return np.array([])  
  
    logging.info(f"[FDE Batch] Starting batch FDE generation for {num\_docs} documents")  
  
    # Input validation  
    valid\_docs = []  
    for i, doc in enumerate(doc\_embeddings\_list):  
        if doc.ndim != 2:  
            logging.warning(  
                f"[FDE Batch] Document {i} has invalid shape (ndim={doc.ndim}), skipping"  
            )  
            continue  
        if doc.shape[1] != config.dimension:  
            raise ValueError(  
                f"Document {i} has incorrect dimension: expected {config.dimension}, got {doc.shape[1]}"  
            )  
        if doc.shape[0] == 0:  
            logging.warning(f"[FDE Batch] Document {i} has no vectors, skipping")  
            continue  
        valid\_docs.append(doc)  
  
    if len(valid\_docs) == 0:  
        logging.warning("[FDE Batch] No valid documents after filtering")  
        return np.array([])  
  
    num\_docs = len(valid\_docs)  
    doc\_embeddings\_list = valid\_docs  
  
    # Determine projection dimension (matching C++ logic)  
    use\_identity\_proj = config.projection\_type == ProjectionType.DEFAULT\_IDENTITY  
    if use\_identity\_proj:  
        projection\_dim = config.dimension  
        logging.info(f"[FDE Batch] Using identity projection (dim={projection\_dim})")  
    else:  
        if not config.projection\_dimension or config.projection\_dimension <= 0:  
            raise ValueError(  
                "A positive projection\_dimension must be specified for non-identity projections"  
            )  
        projection\_dim = config.projection\_dimension  
        logging.info(  
            f"[FDE Batch] Using {config.projection\_type.name} projection: "  
            f"{config.dimension} -> {projection\_dim}"  
        )  
  
    # Configuration summary  
    num\_partitions = 2**config.num\_simhash\_projections  
    logging.info(  
        f"[FDE Batch] Configuration: {config.num\_repetitions} repetitions, "  
        f"{num\_partitions} partitions, projection\_dim={projection\_dim}"  
    )  
  
    # Document tracking  
    doc\_lengths = np.array([len(doc) for doc in doc\_embeddings\_list], dtype=np.int32)  
    total\_vectors = np.sum(doc\_lengths)  
    doc\_boundaries = np.insert(np.cumsum(doc\_lengths), 0, 0)  
    doc\_indices = np.repeat(np.arange(num\_docs), doc\_lengths)  
  
    logging.info(  
        f"[FDE Batch] Total vectors: {total\_vectors}, avg per doc: {total\_vectors / num\_docs:.1f}"  
    )  
  
    # Concatenate all embeddings  
    concat\_start = time.perf\_counter()  
    all\_points = np.vstack(doc\_embeddings\_list).astype(np.float32)  
    concat\_time = time.perf\_counter() - concat\_start  
    logging.info(f"[FDE Batch] Concatenation completed in {concat\_time:.3f}s")  
  
    # Pre-allocate output  
    final\_fde\_dim = config.num\_repetitions * num\_partitions * projection\_dim  
    out\_fdes = np.zeros((num\_docs, final\_fde\_dim), dtype=np.float32)  
    logging.info(f"[FDE Batch] Output FDE dimension: {final\_fde\_dim}")  
  
    # Process each repetition  
    for rep\_num in range(config.num\_repetitions):  
        # rep\_start\_time = time.perf\_counter()  
        current\_seed = config.seed + rep\_num  
  
        if rep\_num % 5 == 0:  # Log every 5 repetitions  
            logging.info(  
                f"[FDE Batch] Processing repetition {rep\_num + 1}/{config.num\_repetitions}"  
            )  
  
        # Step 1: SimHash projection  
        simhash\_start = time.perf\_counter()  
        simhash\_matrix = \_simhash\_matrix\_from\_seed(  
            config.dimension, config.num\_simhash\_projections, current\_seed  
        )  
        all\_sketches = all\_points @ simhash\_matrix  
        simhash\_time = time.perf\_counter() - simhash\_start  
  
        # Step 2: Apply dimensionality reduction if configured  
        proj\_start = time.perf\_counter()  
        if use\_identity\_proj:  
            projected\_points = all\_points  
        elif config.projection\_type == ProjectionType.AMS\_SKETCH:  
            ams\_matrix = \_ams\_projection\_matrix\_from\_seed(  
                config.dimension, projection\_dim, current\_seed  
            )  
            projected\_points = all\_points @ ams\_matrix  
        else:  
            raise ValueError(f"Unsupported projection type: {config.projection\_type}")  
        proj\_time = time.perf\_counter() - proj\_start  
  
        # Step 3: Vectorized partition index calculation  
        partition\_start = time.perf\_counter()  
        bits = (all\_sketches > 0).astype(np.uint32)  
        partition\_indices = np.zeros(total\_vectors, dtype=np.uint32)  
  
        # Vectorized Gray Code computation  
        for bit\_idx in range(config.num\_simhash\_projections):  
            partition\_indices = (partition\_indices << 1) + (  
                bits[:, bit\_idx] ^ (partition\_indices & 1)  
            )  
  
        partition\_time = time.perf\_counter() - partition\_start  
  
        # Step 4: Vectorized aggregation  
        agg\_start = time.perf\_counter()  
  
        # Initialize storage for this repetition  
        rep\_fde\_sum = np.zeros(  
            (num\_docs * num\_partitions * projection\_dim,), dtype=np.float32  
        )  
        partition\_counts = np.zeros((num\_docs, num\_partitions), dtype=np.int32)  
  
        # Count vectors per partition per document  
        np.add.at(partition\_counts, (doc\_indices, partition\_indices), 1)  
  
        # Aggregate vectors using flattened indexing for efficiency  
        doc\_part\_indices = doc\_indices * num\_partitions + partition\_indices  
        base\_indices = doc\_part\_indices * projection\_dim  
  
        for d in range(projection\_dim):  
            flat\_indices = base\_indices + d  
            np.add.at(rep\_fde\_sum, flat\_indices, projected\_points[:, d])  
  
        # Reshape for easier manipulation  
        rep\_fde\_sum = rep\_fde\_sum.reshape(num\_docs, num\_partitions, projection\_dim)  
  
        agg\_time = time.perf\_counter() - agg\_start  
  
        # Step 5: Convert sums to averages (for document FDE)  
        avg\_start = time.perf\_counter()  
  
        # Vectorized division where counts > 0  
        non\_zero\_mask = partition\_counts > 0  
        counts\_3d = partition\_counts[:, :, np.newaxis]  # Broadcasting for division  
  
        # Safe division (avoid divide by zero)  
        np.divide(rep\_fde\_sum, counts\_3d, out=rep\_fde\_sum, where=counts\_3d > 0)  
  
        # Fill empty partitions if configured  
        empty\_filled = 0  
        if config.fill\_empty\_partitions:  
            empty\_mask = ~non\_zero\_mask  
            empty\_docs, empty\_parts = np.where(empty\_mask)  
  
            for doc\_idx, part\_idx in zip(empty\_docs, empty\_parts):  
                if doc\_lengths[doc\_idx] == 0:  
                    continue  
  
                # Get sketches for this document  
                doc\_start = doc\_boundaries[doc\_idx]  
                doc\_end = doc\_boundaries[doc\_idx + 1]  
                doc\_sketches = all\_sketches[doc\_start:doc\_end]  
  
                # Vectorized distance calculation  
                binary\_rep = \_gray\_code\_to\_binary(part\_idx)  
                target\_bits = (  
                    binary\_rep >> np.arange(config.num\_simhash\_projections - 1, -1, -1)  
                ) & 1  
                distances = np.sum(  
                    (doc\_sketches > 0).astype(int) != target\_bits, axis=1  
                )  
  
                nearest\_local\_idx = np.argmin(distances)  
                nearest\_global\_idx = doc\_start + nearest\_local\_idx  
  
                rep\_fde\_sum[doc\_idx, part\_idx, :] = projected\_points[nearest\_global\_idx]  
                empty\_filled += 1  
  
        avg\_time = time.perf\_counter() - avg\_start  
  
        # Step 6: Copy results to output array  
        rep\_output\_start = rep\_num * num\_partitions * projection\_dim  
        out\_fdes[  
            :, rep\_output\_start : rep\_output\_start + num\_partitions * projection\_dim  
        ] = rep\_fde\_sum.reshape(num\_docs, -1)  
  
        # Log timing for first repetition  
        if rep\_num == 0:  
            logging.info("[FDE Batch] Repetition timing breakdown:")  
            logging.info(f"  - SimHash: {simhash\_time:.3f}s")  
            logging.info(f"  - Projection: {proj\_time:.3f}s")  
            logging.info(f"  - Partition indices: {partition\_time:.3f}s")  
            logging.info(f"  - Aggregation: {agg\_time:.3f}s")  
            logging.info(f"  - Averaging: {avg\_time:.3f}s")  
            if config.fill\_empty\_partitions:  
                logging.info(f"  - Filled {empty\_filled} empty partitions")  
  
    # Step 7: Apply final projection if configured  
    if config.final\_projection\_dimension and config.final\_projection\_dimension > 0:  
        logging.info(  
            f"[FDE Batch] Applying final projection: {final\_fde\_dim} -> "  
            f"{config.final\_projection\_dimension}"  
        )  
        final\_proj\_start = time.perf\_counter()  
  
        # Process in chunks to avoid memory issues  
        chunk\_size = min(100, num\_docs)  
        final\_fdes = []  
  
        for i in range(0, num\_docs, chunk\_size):  
            chunk\_end = min(i + chunk\_size, num\_docs)  
            chunk\_fdes = np.array(  
                [  
                    \_apply\_count\_sketch\_to\_vector(  
                        out\_fdes[j], config.final\_projection\_dimension, config.seed  
                    )  
                    for j in range(i, chunk\_end)  
                ]  
            )  
            final\_fdes.append(chunk\_fdes)  
  
        out\_fdes = np.vstack(final\_fdes)  
        final\_proj\_time = time.perf\_counter() - final\_proj\_start  
        logging.info(  
            f"[FDE Batch] Final projection completed in {final\_proj\_time:.3f}s"  
        )  
  
    # Final statistics and validation  
    total\_time = time.perf\_counter() - batch\_start\_time  
    logging.info(f"[FDE Batch] Batch generation completed in {total\_time:.3f}s")  
    logging.info(  
        f"[FDE Batch] Average time per document: {total\_time / num\_docs * 1000:.2f}ms"  
    )  
    logging.info(f"[FDE Batch] Throughput: {num\_docs / total\_time:.1f} docs/sec")  
    logging.info(f"[FDE Batch] Output shape: {out\_fdes.shape}")  
  
    # Validate output dimensions  
    expected\_dim = (  
        final\_fde\_dim  
        if not config.final\_projection\_dimension  
        else config.final\_projection\_dimension  
    )  
    assert out\_fdes.shape == (num\_docs, expected\_dim), (  
        f"Output shape mismatch: {out\_fdes.shape} != ({num\_docs}, {expected\_dim})"  
    )  
  
    # doc\_config = replace(config, encoding\_type=EncodingType.AVERAGE)  
  
    return out\_fdes  
  
  
if \_\_name\_\_ == "\_\_main\_\_":  
    print(f"\n{'=' * 20} SCENARIO 1: Basic FDE Generation {'=' * 20}")  
  
    base\_config = FixedDimensionalEncodingConfig(  
        dimension=128, num\_repetitions=2, num\_simhash\_projections=4, seed=42  
    )  
    query\_data = np.random.randn(32, base\_config.dimension).astype(np.float32)  
    doc\_data = np.random.randn(80, base\_config.dimension).astype(np.float32)  
  
    query\_fde = generate\_query\_fde(query\_data, base\_config)  
    doc\_fde = generate\_document\_fde(  
        doc\_data, replace(base\_config, fill\_empty\_partitions=True)  
    )  
  
    expected\_dim = (  
        base\_config.num\_repetitions  
        * (2**base\_config.num\_simhash\_projections)  
        * base\_config.dimension  
    )  
    print(f"Query FDE Shape: {query\_fde.shape} (Expected: {expected\_dim})")  
    print(f"Document FDE Shape: {doc\_fde.shape} (Expected: {expected\_dim})")  
    print(f"Similarity Score: {np.dot(query\_fde, doc\_fde):.4f}")  
    assert query\_fde.shape[0] == expected\_dim  
  
    print(f"\n{'=' * 20} SCENARIO 2: Inner Projection (AMS Sketch) {'=' * 20}")  
  
    ams\_config = replace(  
        base\_config, projection\_type=ProjectionType.AMS\_SKETCH, projection\_dimension=16  
    )  
    query\_fde\_ams = generate\_query\_fde(query\_data, ams\_config)  
    expected\_dim\_ams = (  
        ams\_config.num\_repetitions  
        * (2**ams\_config.num\_simhash\_projections)  
        * ams\_config.projection\_dimension  
    )  
    print(f"AMS Sketch FDE Shape: {query\_fde\_ams.shape} (Expected: {expected\_dim\_ams})")  
    assert query\_fde\_ams.shape[0] == expected\_dim\_ams  
  
    print(f"\n{'=' * 20} SCENARIO 3: Final Projection (Count Sketch) {'=' * 20}")  
  
    final\_proj\_config = replace(base\_config, final\_projection\_dimension=1024)  
    query\_fde\_final = generate\_query\_fde(query\_data, final\_proj\_config)  
    print(  
        f"Final Projection FDE Shape: {query\_fde\_final.shape} (Expected: {final\_proj\_config.final\_projection\_dimension})"  
    )  
    assert query\_fde\_final.shape[0] == final\_proj\_config.final\_projection\_dimension  
  
    print(f"\n{'=' * 20} SCENARIO 4: Top-level `generate\_fde` wrapper {'=' * 20}")  
  
    query\_fde\_2 = generate\_fde(  
        query\_data, replace(base\_config, encoding\_type=EncodingType.DEFAULT\_SUM)  
    )  
    doc\_fde\_2 = generate\_fde(  
        doc\_data, replace(base\_config, encoding\_type=EncodingType.AVERAGE)  
    )  
    print(  
        f"Wrapper-generated Query FDE is identical: {np.allclose(query\_fde, query\_fde\_2)}"  
    )  
    print(  
        f"Wrapper-generated Document FDE is identical: {np.allclose(doc\_fde, doc\_fde\_2)}"  
    )  
  
    print("\nAll test scenarios completed successfully.")

结语

随着 ColBERT、ColPali 等多向量模型的进一步发展,以及 MUVERA 这类优化算法的不断演进,多向量检索的效率瓶颈正在逐步被克服。未来,在推荐系统、搜索引擎、文档检索等场景中,多向量技术很可能成为标准配置。而 MUVERA 所展示的将复杂问题简化为经典问题的思路,也为其他领域的算法优化提供了有价值的参考。

picture.image

添加微信,备注” LLM “进入大模型技术交流群

picture.image

picture.image

如果你觉得这篇文章对你有帮助,别忘了点个赞、送个喜欢

/ 作者:致Great

/ 作者:欢迎转载,标注来源即可

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论