Pure-Python TF-IDF knowledge base: synonym matching is poor
ChromaDB + sentence-transformers: cosine similarity, Euclidean distance, dot product, and other similarity measures keep synonym matching reliable
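A quick way to see the difference before the full demos below: embeddings place paraphrases near each other even with zero token overlap, which is exactly where TF-IDF breaks down. A minimal sketch (the model name matches the file below; the sentence pair is illustrative):

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
a = model.encode("Can the planning AI help me write a project proposal?")
b = model.encode("Could this assistant compose a business plan for me?")
# These paraphrases share no surface tokens, so a TF-IDF pair would score
# (near) zero; the embedding cosine score typically lands well above what
# unrelated sentence pairs get.
print(f"cosine similarity: {util.cos_sim(a, b).item():.4f}")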
# -*- coding: utf-8 -*-
"""
ChromaDB + sentence-transformers: complete runnable example.
Run it directly, or from a terminal: python <filename>.py
"""
import sys
import warnings

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer, util

# Silence unrelated warnings to keep the output clean
warnings.filterwarnings("ignore")
# ------------------------------
# Global state (initialized exactly once)
# ------------------------------
client = None
collection = None
embed_model = None
collection_name = "demo_knowledge_base"
def init_environment():
    """Initialize the environment (check deps + create client + load model)."""
    global client, collection, embed_model
    # 1. Check the Python version
    if sys.version_info < (3, 8):
        raise RuntimeError("Please run with Python 3.8 or newer!")
    print("=" * 60)
    print("🔍 Initializing the ChromaDB + sentence-transformers environment...")
    # 2. Create the ChromaDB client (in-memory mode, no files on disk)
    try:
        client = chromadb.Client(Settings(
            anonymized_telemetry=False,
            allow_reset=True
        ))
        print("✅ ChromaDB client created")
    except Exception as e:
        raise RuntimeError(f"ChromaDB initialization failed: {e}")
    # 3. Load the embedding model (CPU mode)
    try:
        embed_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        print(f"✅ Embedding model loaded (dimension: {embed_model.get_sentence_embedding_dimension()})")
    except Exception as e:
        raise RuntimeError(f"Embedding model failed to load: {e}")
    # 4. Create / reset the collection
    try:
        # Drop a stale collection if present. The return type of
        # list_collections() changed across chromadb versions, so just
        # attempt the delete and ignore "not found".
        try:
            client.delete_collection(collection_name)
        except Exception:
            pass
        collection = client.create_collection(name=collection_name)
        print(f"✅ Knowledge-base collection '{collection_name}' created")
    except Exception as e:
        raise RuntimeError(f"Collection creation failed: {e}")
    print("=" * 60)
# ------------------------------
# Core functions
# ------------------------------
def get_embedding(text: str) -> list:
    """Generate an embedding vector for a piece of text."""
    if not text.strip():
        raise ValueError("Text must not be empty!")
    # encode() returns a numpy array; hand ChromaDB a plain list
    vector = embed_model.encode(text, convert_to_numpy=True)
    return vector.tolist()
def add_documents(documents: list, ids: list, metadatas: list = None):
    """Add documents to the knowledge base."""
    if len(documents) != len(ids):
        raise ValueError("documents and ids must have the same length!")
    embeddings = [get_embedding(doc) for doc in documents]
    metadatas = metadatas or [{} for _ in documents]
    collection.add(
        documents=documents,
        ids=ids,
        metadatas=metadatas,
        embeddings=embeddings
    )
    print(f"✅ Added {len(documents)} documents to the knowledge base")
def query_documents(query_text: str, n_results: int = 3):
    """Semantic document retrieval."""
    query_embedding = [get_embedding(query_text)]
    results = collection.query(
        query_embeddings=query_embedding,
        n_results=n_results
    )
    # Pretty-print the results
    print(f"\n📝 Query text: {query_text}")
    print("🔍 Results:")
    if not results['documents'][0]:
        print("    No matches found")
        return results
    for i in range(len(results['documents'][0])):
        doc = results['documents'][0][i]
        doc_id = results['ids'][0][i]
        # Re-embed the hit to print a cosine score (ChromaDB's own
        # 'distances' are L2 by default, which is less intuitive to read)
        sim_score = util.cos_sim(query_embedding[0], get_embedding(doc)).item()
        print(f"    [{i+1}] ID: {doc_id} | cosine similarity: {sim_score:.4f}")
        print(f"        Text: {doc[:100]}...")  # show only the first 100 chars
    return results
def delete_documents(ids: list):
    """Delete documents by ID."""
    collection.delete(ids=ids)
    print(f"\n✅ Deleted documents with IDs {ids}")
# ------------------------------
# Main (program entry point)
# ------------------------------
def main():
    """Program entry point."""
    try:
        # Step 1: initialize the environment
        init_environment()
        # Step 2: prepare test data
        test_docs = [
            "The planning AI's core features are drafting project proposals, generating code, and answering technical questions",
            "Python is an interpreted, object-oriented, high-level programming language with clean, readable syntax",
            "ChromaDB is a lightweight vector database that supports fast vector-similarity search",
            "sentence-transformers turns text into semantic vectors for similarity computation",
            "RAG (retrieval-augmented generation) feeds knowledge-base search results into an LLM's generation"
        ]
        test_ids = ["doc_001", "doc_002", "doc_003", "doc_004", "doc_005"]
        test_metas = [
            {"source": "planning docs", "type": "feature description"},
            {"source": "programming tutorial", "type": "language intro"},
            {"source": "database docs", "type": "tool description"},
            {"source": "NLP tutorial", "type": "model description"},
            {"source": "AI tutorial", "type": "concept"}
        ]
        # Step 3: run the test flow
        print("\n📌 Running the test flow...")
        # 1. Add documents
        add_documents(test_docs, test_ids, test_metas)
        # 2. Semantic search test 1 (close match)
        query_documents("Can the planning AI help me write a project proposal?", n_results=2)
        # 3. Semantic search test 2 (paraphrase match)
        query_documents("What is RAG?", n_results=1)
        # 4. Deletion test
        delete_documents(["doc_002"])
        # 5. Verify the deletion. With other documents still stored, the query
        #    always returns *something*, so check that doc_002 itself is gone
        #    rather than checking for an empty result.
        print("\n🔍 Verifying deletion (searching for the Python document):")
        results = query_documents("What is Python?", n_results=1)
        if "doc_002" not in results['ids'][0]:
            print("    ✅ The deleted document is no longer retrievable; deletion succeeded!")
        print("\n" + "=" * 60)
        print("🎉 All test steps completed!")
    except Exception as e:
        print(f"\n❌ Program error: {e}")
        sys.exit(1)
# ------------------------------
# Script entry point
# ------------------------------
if __name__ == "__main__":
    main()
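The demo above keeps everything in memory, so the collection disappears when the process exits. A minimal sketch of the persistent variant, assuming chromadb 0.4+ (where PersistentClient and get_or_create_collection are available); the storage path is illustrative:

import chromadb

# Data is written under ./chroma_store and survives restarts
client = chromadb.PersistentClient(path="./chroma_store")
collection = client.get_or_create_collection(name="demo_knowledge_base")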
# -*- coding: utf-8 -*-
"""
Knowledge base
Production-grade pure-Python TF-IDF knowledge base (zero external dependencies)
Fixes: empty-result index errors, similarity-threshold handling, debug-friendly output
"""
import json
import math
import os
import re
import time
from collections import defaultdict
from typing import Any, Dict, List, Optional
class TFIDFEngine:
    """Pure-Python TF-IDF engine (core module)."""
    def __init__(self):
        self.vocab: Dict[str, int] = {}                        # token -> column index
        self.doc_count: int = 0                                # total number of documents
        self.word_doc_freq: Dict[str, int] = defaultdict(int)  # token -> document frequency
    def _tokenize(self, text: str) -> List[str]:
        """Coarse tokenizer: keeps runs of CJK characters, ASCII letters, and
        digits; drops punctuation, whitespace, and other characters. Note this
        is NOT true Chinese word segmentation: a Chinese clause with no
        punctuation comes back as a single token."""
        if not isinstance(text, str):
            return []
        # Lowercase so 'Python' and 'python' count as the same token
        tokens = re.findall(r'[a-zA-Z0-9\u4e00-\u9fa5]+', text.strip().lower())
        # Drop empty and single-character tokens (too little signal)
        return [token for token in tokens if len(token) > 1]
    def fit(self, documents: List[str]) -> None:
        """Fit the TF-IDF model on the full document set."""
        # Reset state
        self.vocab.clear()
        self.word_doc_freq.clear()
        self.doc_count = len(documents)
        # Build the vocabulary and document frequencies
        vocab_temp = set()
        for doc in documents:
            tokens = self._tokenize(doc)
            unique_tokens = set(tokens)
            for token in unique_tokens:
                self.word_doc_freq[token] += 1
                vocab_temp.add(token)
        # Assign column indices, sorted by document frequency for a
        # deterministic, reproducible layout
        sorted_vocab = sorted(vocab_temp, key=lambda x: self.word_doc_freq[x], reverse=True)
        self.vocab = {token: idx for idx, token in enumerate(sorted_vocab)}
    def encode(self, text: str) -> List[float]:
        """Produce a TF-IDF vector (handles empty text; length always matches vocab)."""
        tokens = self._tokenize(text)
        vec = [0.0] * len(self.vocab)
        if not tokens:
            return vec
        # Term frequency
        tf = defaultdict(int)
        total_tokens = len(tokens)
        for token in tokens:
            tf[token] += 1
        # TF-IDF, with add-one smoothing on the IDF to avoid division by zero
        for token, count in tf.items():
            if token in self.vocab:
                tf_val = count / total_tokens
                idf_val = math.log((self.doc_count + 1) / (self.word_doc_freq.get(token, 0) + 1))
                vec[self.vocab[token]] = tf_val * idf_val
        return vec
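# Worked example of the formulas above (illustrative numbers): with 5 documents
# and a token that appears in 2 of them, idf = log((5 + 1) / (2 + 1)) = log(2)
# ≈ 0.6931; if that token occurs 3 times in a 10-token document, tf = 3 / 10
# = 0.3, so its weight in the vector is 0.3 * 0.6931 ≈ 0.2079.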
class CommercialTFIDFKB:
    """Production-grade TF-IDF knowledge base (core class)."""
    def __init__(self, db_path: str = "./commercial_kb.json"):
        self.db_path: str = db_path
        self.data: Dict[str, Dict[str, Any]] = {}  # {id: {"text": "", "metadata": {}, "vector": []}}
        self.tfidf: TFIDFEngine = TFIDFEngine()
        self._init_storage()  # create the file / restore data automatically
    def _init_storage(self) -> None:
        """Initialize storage (self-healing on corrupt data, auto-creates the directory)."""
        # Make sure the storage directory exists
        db_dir = os.path.dirname(self.db_path)
        if db_dir and not os.path.exists(db_dir):
            os.makedirs(db_dir, exist_ok=True)
        # Load data (recover from corruption)
        if os.path.exists(self.db_path):
            try:
                with open(self.db_path, "r", encoding="utf-8") as f:
                    self.data = json.load(f) or {}
                # Validate record shape; collect bad IDs first so the dict is
                # not mutated while being iterated over
                bad_ids = [id_ for id_, item in self.data.items()
                           if not isinstance(item, dict) or "text" not in item]
                for id_ in bad_ids:
                    del self.data[id_]
            except Exception:
                self.data = {}
                # Keep the corrupt file as evidence for debugging
                backup_path = f"{self.db_path}.bak.{int(time.time())}"
                os.rename(self.db_path, backup_path)
        else:
            self.data = {}
        # Fit the TF-IDF model on whatever was loaded
        if self.data:
            docs = [item["text"] for item in self.data.values()]
            self.tfidf.fit(docs)
    def _save_data(self) -> None:
        """Persist data (atomic write so a crash cannot corrupt the file)."""
        # Write to a temp file first so the original is never half-written
        temp_path = f"{self.db_path}.tmp"
        try:
            with open(temp_path, "w", encoding="utf-8") as f:
                # Compact separators keep the file small and the write fast
                json.dump(self.data, f, ensure_ascii=False, separators=(",", ":"))
            # Atomically replace the original file
            os.replace(temp_path, self.db_path)
        except Exception as e:
            # Clean up the temp file
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise RuntimeError(f"Failed to save knowledge base: {str(e)[:100]}")
    def add(
        self,
        documents: List[str],
        ids: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None
    ) -> None:
        """
        Add documents (validates parameters, processes in batch).
        :param documents: list of document texts
        :param ids: list of unique document IDs (same length as documents)
        :param metadatas: optional list of metadata dicts (defaults to empty dicts)
        """
        # Parameter validation
        if len(documents) != len(ids):
            raise ValueError("documents and ids must have the same length")
        if not documents:
            return
        metadatas = metadatas or [{} for _ in documents]
        if len(metadatas) != len(documents):
            raise ValueError("metadatas must have the same length as documents")
        # Re-fit on old + new documents. fit() rebuilds the vocabulary, which
        # changes the vector layout, so every *stored* vector must be
        # re-encoded as well, not just the new ones.
        all_docs = [item["text"] for item in self.data.values()] + documents
        self.tfidf.fit(all_docs)
        for item in self.data.values():
            item["vector"] = self.tfidf.encode(item["text"])
        # Add the new documents (encode each one)
        for doc, id_, meta in zip(documents, ids, metadatas):
            # Skip invalid IDs
            if not isinstance(id_, str) or not id_.strip():
                continue
            id_clean = id_.strip()
            self.data[id_clean] = {
                "text": doc.strip() if isinstance(doc, str) else "",
                "metadata": meta if isinstance(meta, dict) else {},
                "vector": self.tfidf.encode(doc)
            }
        self._save_data()
    def query(
        self,
        query_text: str,
        n_results: int = 5,
        min_similarity: float = 0.0  # similarity threshold for filtering weak matches
    ) -> Dict[str, Any]:
        """
        Semantic search (threshold filtering, sorted results, safe on empty data).
        :param query_text: query text
        :param n_results: number of results to return (default 5)
        :param min_similarity: minimum similarity (default 0.0)
        :return: results in a ChromaDB-compatible shape
        """
        # Empty store / empty query: return an empty result immediately
        if not self.data or not query_text.strip():
            return {
                "documents": [[]],
                "ids": [[]],
                "metadatas": [[]],
                "similarities": [[]]
            }
        # Encode the query once; its norm is constant, so compute it outside the loop
        query_vec = self.tfidf.encode(query_text)
        query_norm = math.sqrt(sum(x * x for x in query_vec))
        similarities = []
        for id_, item in self.data.items():
            vec = item.get("vector", [])
            if not vec:
                continue
            # Cosine similarity
            dot = sum(a * b for a, b in zip(query_vec, vec))
            doc_norm = math.sqrt(sum(x * x for x in vec))
            if query_norm == 0 or doc_norm == 0:
                sim = 0.0
            else:
                sim = dot / (query_norm * doc_norm)
            # Threshold filter
            if sim >= min_similarity:
                similarities.append((id_, item, sim))
        # Sort by similarity, descending; keep the top n
        sorted_docs = sorted(similarities, key=lambda x: x[2], reverse=True)[:n_results]
        # ChromaDB-compatible shape; an empty result stays well-formed
        return {
            "documents": [[doc[1]["text"] for doc in sorted_docs]],
            "ids": [[doc[0] for doc in sorted_docs]],
            "metadatas": [[doc[1]["metadata"] for doc in sorted_docs]],
            "similarities": [[round(doc[2], 4) for doc in sorted_docs]]  # 4 decimal places
        }
    def delete(self, ids: List[str]) -> None:
        """
        Delete documents (batch; invalid IDs are ignored).
        :param ids: IDs to delete
        """
        if not ids or not self.data:
            return
        # Batch delete
        for id_ in ids:
            if isinstance(id_, str) and id_.strip() in self.data:
                del self.data[id_.strip()]
        # Re-fit on the remaining documents and re-encode their vectors
        # (the vocabulary, and therefore the vector layout, has changed)
        if self.data:
            docs = [item["text"] for item in self.data.values()]
            self.tfidf.fit(docs)
            for item in self.data.values():
                item["vector"] = self.tfidf.encode(item["text"])
        self._save_data()
    def clear(self) -> None:
        """Wipe the knowledge base completely."""
        self.data.clear()
        self.tfidf = TFIDFEngine()
        self._save_data()
    def count(self) -> int:
        """Number of stored documents (for status monitoring)."""
        return len(self.data)
# ------------------------------
# Test example (fixes the empty-result crash)
# ------------------------------
if __name__ == "__main__":
    # Initialize the knowledge base
    kb = CommercialTFIDFKB("./commercial_kb.json")
    print(f"✅ Knowledge base initialized; current document count: {kb.count()}")
    # Clear out old data
    kb.clear()
    print(f"✅ Document count after clear: {kb.count()}")
    # Add test documents
    test_docs = [
        "The planning AI's core features are drafting project proposals, generating code, and answering technical questions",
        "Python is an interpreted, object-oriented, high-level programming language",
        "ChromaDB is a lightweight vector database that supports fast vector-similarity search",
        "An enterprise knowledge-base system needs semantic search, batch import, and access control",
        "Commercial software must guarantee stability, compatibility, and freedom from licensing risk"
    ]
    test_ids = ["doc_001", "doc_002", "doc_003", "doc_004", "doc_005"]
    test_metas = [
        {"source": "planning docs", "type": "feature description"},
        {"source": "programming tutorial", "type": "language intro"},
        {"source": "database docs", "type": "tool description"},
        {"source": "product requirements", "type": "feature design"},
        {"source": "compliance docs", "type": "commercial policy"}
    ]
    kb.add(documents=test_docs, ids=test_ids, metadatas=test_metas)
    print(f"✅ Document count after add: {kb.count()}")
    # Semantic search test (fix: lower the threshold + guard empty results)
    query = "Can the planning AI help me write a project proposal?"
    results = kb.query(query_text=query, n_results=2, min_similarity=0.0)  # threshold 0.0 so something comes back
    # Key fix: guard against empty results before indexing into them
    print(f"\n📝 Query text: {query}")
    if results['documents'][0]:  # check there is at least one result
        print(f"🔍 Result 1: {results['documents'][0][0]} (similarity: {results['similarities'][0][0]})")
        if len(results['documents'][0]) > 1:  # check there is a second result
            print(f"🔍 Result 2: {results['documents'][0][1]} (similarity: {results['similarities'][0][1]})")
    else:
        print("🔍 No matching documents found")
    # Deletion test
    kb.delete(ids=["doc_002"])
    print(f"\n✅ Document count after delete: {kb.count()}")
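To see the synonym weakness called out at the top of these notes: a hypothetical probe that could be appended to the test block above (it reuses kb and the test documents). The paraphrase shares no surface tokens with doc_001, and TF-IDF can only match exact tokens, so the score collapses to (near) zero, whereas the embedding model in the ChromaDB version still scores the pair as related:

# Hypothetical probe: a paraphrase of doc_001 with (near-)zero token overlap
synonym_query = "Could this assistant compose a business plan for me?"
res = kb.query(query_text=synonym_query, n_results=1, min_similarity=0.0)
if res["similarities"][0]:
    print(f"TF-IDF similarity for the paraphrase: {res['similarities'][0][0]}")  # ~0.0
else:
    print("TF-IDF found no overlap at all for the paraphrase")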
Running the ChromaDB + sentence-transformers version on Windows fails with errors.
Cause: Windows' own DLL compatibility problems (commonly missing native runtime libraries required by the compiled extensions).
Workarounds (a container sketch follows the list):
1. Run inside a Docker container
2. Test in a Linux environment
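The container route, as a minimal sketch: the script name chroma_demo.py is illustrative (use whatever the file above was saved as), and the image tag is just an example:

docker run --rm -it -v "%cd%":/app -w /app python:3.10-slim bash -c "pip install chromadb sentence-transformers && python chroma_demo.py"

(From PowerShell, replace %cd% with ${PWD}.)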
