Chroma实现本地知识库

纯Python TF-IDF知识库 同义词匹配效果差
ChromaDB + sentence-transformers 余弦相似度、欧几里得距离、点积等相似性计算 保证同义词效果性

# -*- coding: utf-8 -*-
"""
ChromaDB + sentence-transformers 完整可运行样例
运行方式:直接执行该文件,或在终端运行 python 文件名.py
"""
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer, util
import warnings
import sys

# 屏蔽无关警告,让输出更整洁
warnings.filterwarnings("ignore")

# ------------------------------
# 全局初始化(确保只初始化一次)
# ------------------------------
client = None
collection = None
embed_model = None
collection_name = "demo_knowledge_base"

def init_environment():
    """初始化环境(检查依赖+创建客户端+加载模型)"""
    global client, collection, embed_model

    # 1. 检查Python版本
    python_version = sys.version_info
    if python_version < (3, 8):
        raise RuntimeError("请使用Python 3.8及以上版本运行!")

    print("="*60)
    print("🔍 开始初始化ChromaDB + sentence-transformers环境...")

    # 2. 初始化ChromaDB客户端(内存模式,无文件依赖)
    try:
        client = chromadb.Client(Settings(
            anonymized_telemetry=False,
            allow_reset=True
        ))
        print("✅ ChromaDB客户端初始化成功")
    except Exception as e:
        raise RuntimeError(f"ChromaDB初始化失败:{str(e)}")

    # 3. 加载嵌入模型(CPU模式)
    try:
        embed_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        print(f"✅ 嵌入模型加载成功(维度:{embed_model.get_sentence_embedding_dimension()})")
    except Exception as e:
        raise RuntimeError(f"嵌入模型加载失败:{str(e)}")

    # 4. 创建/重置集合
    try:
        # 删除旧集合(避免重复)
        if collection_name in [col.name for col in client.list_collections()]:
            client.delete_collection(collection_name)
        collection = client.create_collection(name=collection_name)
        print(f"✅ 知识库集合 '{collection_name}' 创建成功")
    except Exception as e:
        raise RuntimeError(f"集合创建失败:{str(e)}")

    print("="*60)

# ------------------------------
# 核心功能函数
# ------------------------------
def get_embedding(text: str) -> list:
    """生成文本嵌入向量"""
    if not text.strip():
        raise ValueError("文本不能为空!")
    vector = embed_model.encode(text, convert_to_numpy=False)
    return vector

def add_documents(documents: list, ids: list, metadatas: list = None):
    """添加文档到知识库"""
    if len(documents) != len(ids):
        raise ValueError("documents和ids长度必须一致!")

    embeddings = [get_embedding(doc) for doc in documents]
    metadatas = metadatas or [{} for _ in documents]

    collection.add(
        documents=documents,
        ids=ids,
        metadatas=metadatas,
        embeddings=embeddings
    )
    print(f"✅ 成功添加 {len(documents)} 个文档到知识库")

def query_documents(query_text: str, n_results: int = 3):
    """语义检索文档"""
    query_embedding = [get_embedding(query_text)]
    results = collection.query(
        query_embeddings=query_embedding,
        n_results=n_results
    )

    # 格式化输出检索结果
    print(f"\n📝 查询文本:{query_text}")
    print("🔍 检索结果:")
    if not results['documents'][0]:
        print("   暂无匹配结果")
        return results

    for i in range(len(results['documents'][0])):
        doc = results['documents'][0][i]
        doc_id = results['ids'][0][i]
        sim_score = util.cos_sim(query_embedding[0], get_embedding(doc)).item()
        print(f"  [{i+1}] ID: {doc_id} | 相似度: {sim_score:.4f}")
        print(f"     文本: {doc[:100]}...")  # 只显示前100字
    return results

def delete_documents(ids: list):
    """删除指定ID的文档"""
    collection.delete(ids=ids)
    print(f"\n✅ 成功删除ID为 {ids} 的文档")

# ------------------------------
# 主函数(程序入口)
# ------------------------------
def main():
    """程序主入口"""
    try:
        # 第一步:初始化环境
        init_environment()

        # 第二步:准备测试数据
        test_docs = [
            "企划AI的核心功能是辅助编写项目方案、生成代码、解答技术问题",
            "Python是一种解释型、面向对象的高级编程语言,语法简洁易读",
            "ChromaDB是轻量级向量数据库,支持快速的向量相似度检索",
            "sentence-transformers可以将文本转换为语义向量,用于相似度计算",
            "RAG(检索增强生成)是将知识库检索结果融入大模型生成的技术"
        ]
        test_ids = ["doc_001", "doc_002", "doc_003", "doc_004", "doc_005"]
        test_metas = [
            {"source": "企划文档", "type": "功能说明"},
            {"source": "编程教程", "type": "语言介绍"},
            {"source": "数据库文档", "type": "工具说明"},
            {"source": "NLP教程", "type": "模型说明"},
            {"source": "AI教程", "type": "技术概念"}
        ]

        # 第三步:执行测试流程
        print("\n📌 开始执行测试流程...")

        # 1. 添加文档
        add_documents(test_docs, test_ids, test_metas)

        # 2. 语义检索测试1(精准匹配)
        query_documents("企划AI能帮我写项目方案吗?", n_results=2)

        # 3. 语义检索测试2(泛化匹配)
        query_documents("什么是RAG技术?", n_results=1)

        # 4. 删除文档测试
        delete_documents(["doc_002"])

        # 5. 验证删除结果
        print("\n🔍 验证删除结果(检索Python相关文档):")
        results = query_documents("Python是什么?", n_results=1)
        if not results['documents'][0]:
            print("   ✅ 已删除的文档无法检索到,删除成功!")

        print("\n" + "="*60)
        print("🎉 所有测试流程执行完成!")

    except Exception as e:
        print(f"\n❌ 程序运行出错:{str(e)}")
        sys.exit(1)

# ------------------------------
# 程序启动入口
# ------------------------------
if __name__ == "__main__":
    # 直接调用main函数运行整个程序
    main()
# -*- coding: utf-8 -*-
"""
知识库
商用级纯Python TF-IDF知识库(零外部依赖)
修复:空结果索引异常、相似度阈值适配、调试友好
"""
import sys
import json
import os
import math
import re
import time
from typing import List, Dict, Any, Optional
from collections import defaultdict


class TFIDFEngine:
    """纯Python TF-IDF引擎(核心模块)"""
    def __init__(self):
        self.vocab: Dict[str, int] = {}  # 词汇-ID映射
        self.doc_count: int = 0  # 总文档数
        self.word_doc_freq: defaultdict[str, int] = defaultdict(int)  # 词-文档频率

    def _tokenize(self, text: str) -> List[str]:
        """中文分词(商用级优化:保留中英文数字,过滤无效字符)"""
        if not isinstance(text, str):
            return []
        # 正则匹配中英文、数字,过滤标点/空白/特殊字符
        tokens = re.findall(r'[a-zA-Z0-9\u4e00-\u9fa5]+', text.strip())
        # 过滤空token和过短token(单字无意义)
        return [token for token in tokens if len(token) > 1]

    def fit(self, documents: List[str]) -> None:
        """训练TF-IDF模型(适配全量文档)"""
        # 重置状态
        self.vocab.clear()
        self.word_doc_freq.clear()
        self.doc_count = len(documents)

        # 构建词汇表和文档频率
        vocab_temp = set()
        for doc in documents:
            tokens = self._tokenize(doc)
            unique_tokens = set(tokens)
            for token in unique_tokens:
                self.word_doc_freq[token] += 1
                vocab_temp.add(token)

        # 生成词汇-ID映射(按频率排序,提升检索效率)
        sorted_vocab = sorted(vocab_temp, key=lambda x: self.word_doc_freq[x], reverse=True)
        self.vocab = {token: idx for idx, token in enumerate(sorted_vocab)}

    def encode(self, text: str) -> List[float]:
        """生成TF-IDF向量(商用级:兼容空文本、长度对齐)"""
        tokens = self._tokenize(text)
        vec = [0.0] * len(self.vocab)
        if not tokens:
            return vec

        # 计算TF(词频)
        tf = defaultdict(int)
        total_tokens = len(tokens)
        for token in tokens:
            tf[token] += 1

        # 计算TF-IDF
        for token, count in tf.items():
            if token in self.vocab:
                tf_val = count / total_tokens
                # IDF平滑处理(避免除零,提升稳定性)
                idf_val = math.log((self.doc_count + 1) / (self.word_doc_freq.get(token, 0) + 1))
                vec[self.vocab[token]] = tf_val * idf_val

        return vec


class CommercialTFIDFKB:
    """商用级TF-IDF知识库(核心类)"""
    def __init__(self, db_path: str = "./commercial_kb.json"):
        self.db_path: str = db_path
        self.data: Dict[str, Dict[str, Any]] = {}  # {id: {"text": "", "metadata": {}, "vector": []}}
        self.tfidf: TFIDFEngine = TFIDFEngine()
        self._init_storage()  # 初始化存储(自动创建文件/恢复数据)

    def _init_storage(self) -> None:
        """初始化存储(商用级:异常自恢复、目录自动创建)"""
        # 确保存储目录存在
        db_dir = os.path.dirname(self.db_path)
        if db_dir and not os.path.exists(db_dir):
            os.makedirs(db_dir, exist_ok=True)

        # 加载数据(异常自恢复)
        if os.path.exists(self.db_path):
            try:
                with open(self.db_path, "r", encoding="utf-8") as f:
                    self.data = json.load(f) or {}
                # 验证数据格式
                for id_, item in self.data.items():
                    if not isinstance(item, dict) or "text" not in item:
                        del self.data[id_]
            except Exception:
                self.data = {}
                # 备份损坏的文件(商用级:保留故障证据)
                backup_path = f"{self.db_path}.bak.{int(time.time())}"
                os.rename(self.db_path, backup_path)
        else:
            self.data = {}

        # 训练TF-IDF模型
        if self.data:
            docs = [item["text"] for item in self.data.values()]
            self.tfidf.fit(docs)

    def _save_data(self) -> None:
        """保存数据(商用级:原子写入、避免文件损坏)"""
        # 临时文件写入(避免原文件损坏)
        temp_path = f"{self.db_path}.tmp"
        try:
            with open(temp_path, "w", encoding="utf-8") as f:
                json.dump(self.data, f, ensure_ascii=False, indent=0)  # 紧凑格式,减少IO
            # 原子替换原文件
            os.replace(temp_path, self.db_path)
        except Exception as e:
            # 清理临时文件
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise RuntimeError(f"保存知识库失败:{str(e)[:100]}")

    def add(
            self,
            documents: List[str],
            ids: List[str],
            metadatas: Optional[List[Dict[str, Any]]] = None
    ) -> None:
        """
        添加文档(商用级:参数校验、批量处理)
        :param documents: 文档文本列表
        :param ids: 文档唯一ID列表(需与documents长度一致)
        :param metadatas: 元数据列表(可选,默认空字典)
        """
        # 参数校验
        if len(documents) != len(ids):
            raise ValueError("documents与ids长度必须一致")
        if not documents:
            return

        metadatas = metadatas or [{} for _ in documents]
        if len(metadatas) != len(documents):
            raise ValueError("metadatas长度需与documents一致")

        # 合并新旧文档,重新训练模型
        all_docs = [item["text"] for item in self.data.values()] + documents
        self.tfidf.fit(all_docs)

        # 批量添加文档(生成向量)
        for doc, id_, meta in zip(documents, ids, metadatas):
            # 过滤无效ID
            if not isinstance(id_, str) or not id_.strip():
                continue
            id_clean = id_.strip()
            # 生成向量并存储
            self.data[id_clean] = {
                "text": doc.strip() if isinstance(doc, str) else "",
                "metadata": meta if isinstance(meta, dict) else {},
                "vector": self.tfidf.encode(doc)
            }

        self._save_data()

    def query(
            self,
            query_text: str,
            n_results: int = 5,
            min_similarity: float = 0.0  # 相似度阈值,过滤低匹配结果
    ) -> Dict[str, Any]:
        """
        语义检索(商用级:阈值过滤、结果排序、空结果兼容)
        :param query_text: 查询文本
        :param n_results: 返回结果数(默认5)
        :param min_similarity: 最小相似度(默认0.0,过滤无效结果)
        :return: 兼容ChromaDB格式的结果
        """
        # 空数据/空查询直接返回空结果
        if not self.data or not query_text.strip():
            return {
                "documents": [[]],
                "ids": [[]],
                "metadatas": [[]],
                "similarities": [[]]
            }

        # 生成查询向量
        query_vec = self.tfidf.encode(query_text)

        # 计算相似度(商用级:提前过滤零向量)
        similarities = []
        for id_, item in self.data.items():
            vec = item.get("vector", [])
            if not vec:
                continue
            # 余弦相似度计算(优化版:减少浮点运算)
            dot = sum(a*b for a, b in zip(query_vec, vec))
            norm1 = math.sqrt(sum(x*x for x in query_vec))
            norm2 = math.sqrt(sum(x*x for x in vec))
            if norm1 == 0 or norm2 == 0:
                sim = 0.0
            else:
                sim = dot / (norm1 * norm2)

            # 阈值过滤
            if sim >= min_similarity:
                similarities.append((id_, item, sim))

        # 按相似度降序排序,取前n个
        sorted_docs = sorted(similarities, key=lambda x: x[2], reverse=True)[:n_results]

        # 构造兼容格式(商用级:确保返回格式稳定,空结果不报错)
        return {
            "documents": [[doc[1]["text"] for doc in sorted_docs] if sorted_docs else []],
            "ids": [[doc[0] for doc in sorted_docs] if sorted_docs else []],
            "metadatas": [[doc[1]["metadata"] for doc in sorted_docs] if sorted_docs else []],
            "similarities": [[round(doc[2], 4) for doc in sorted_docs] if sorted_docs else []]  # 保留4位小数
        }

    def delete(self, ids: List[str]) -> None:
        """
        删除文档(商用级:批量处理、无效ID忽略)
        :param ids: 待删除ID列表
        """
        if not ids or not self.data:
            return

        # 批量删除
        for id_ in ids:
            if isinstance(id_, str) and id_.strip() in self.data:
                del self.data[id_.strip()]

        # 重新训练模型(适配剩余文档)
        if self.data:
            docs = [item["text"] for item in self.data.values()]
            self.tfidf.fit(docs)

        self._save_data()

    def clear(self) -> None:
        """清空知识库(商用级:彻底清理)"""
        self.data.clear()
        self.tfidf = TFIDFEngine()
        self._save_data()

    def count(self) -> int:
        """获取文档总数(商用级:状态监控)"""
        return len(self.data)


# ------------------------------
# 商用测试示例(修复空结果报错)
# ------------------------------
if __name__ == "__main__":
    # 初始化知识库
    kb = CommercialTFIDFKB("./commercial_kb.json")
    print(f"✅ 知识库初始化完成,当前文档数:{kb.count()}")

    # 清空旧数据
    kb.clear()
    print(f"✅ 清空后文档数:{kb.count()}")

    # 添加测试文档
    test_docs = [
        "企划AI的核心功能是辅助编写项目方案、生成代码、解答技术问题",
        "Python是一种解释型、面向对象的高级编程语言",
        "ChromaDB是一个轻量级的向量数据库,支持快速的向量相似度检索",
        "企业知识库系统需要支持语义检索、批量导入、权限管理",
        "商用软件需保证稳定性、兼容性、无版权风险"
    ]
    test_ids = ["doc_001", "doc_002", "doc_003", "doc_004", "doc_005"]
    test_metas = [
        {"source": "企划文档", "type": "功能说明"},
        {"source": "编程教程", "type": "语言介绍"},
        {"source": "数据库文档", "type": "工具说明"},
        {"source": "产品需求", "type": "功能设计"},
        {"source": "合规文档", "type": "商用规范"}
    ]
    kb.add(documents=test_docs, ids=test_ids, metadatas=test_metas)
    print(f"✅ 添加文档后总数:{kb.count()}")

    # 语义检索测试(修复:降低阈值+空结果判断)
    query = "企划AI能帮我写项目方案吗?"
    results = kb.query(query_text=query, n_results=2, min_similarity=0.0)  # 阈值设为0.0,确保有结果

    # 修复核心:空结果兼容处理
    print(f"\n📝 查询文本:{query}")
    if results['documents'][0]:  # 先判断是否有结果
        print(f"🔍 检索结果1:{results['documents'][0][0]}(相似度:{results['similarities'][0][0]})")
        if len(results['documents'][0]) > 1:  # 再判断是否有第二个结果
            print(f"🔍 检索结果2:{results['documents'][0][1]}(相似度:{results['similarities'][0][1]})")
    else:
        print("🔍 未检索到匹配的文档")

    # 删除测试
    kb.delete(ids=["doc_002"])
    print(f"\n✅ 删除后文档数:{kb.count()}")

使用 ChromaDB + sentence-transformers window下执行报错:
Windows 本身的 DLL 兼容问题

处理思路:
1.docker 容器
2.Linux环境下执行测试


image.png
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容