LightRAG

LightRAG was created to address GraphRAG's slow index construction and heavy token consumption. For more on the methodology, head over to the wiki; this post focuses on the code-level implementation.

Caveats

Because Python's async module asyncio does not support nested event loops, `.insert()` raises an error when executed in an environment where a loop is already running (such as Jupyter): This event loop is already running. We need to patch things with nest_asyncio.

First, install nest_asyncio:

```bash
pip install nest-asyncio
```

Then import it and apply the patch:

```python
import nest_asyncio
nest_asyncio.apply()
```

After that, everything works.
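For reference, a minimal sketch of what this looks like in practice, assuming a configured LightRAG instance named `rag` already exists (its construction is omitted here):

```python
import nest_asyncio

nest_asyncio.apply()  # allow run_until_complete() inside an already-running loop (e.g. Jupyter)

# `rag` is assumed to be an already-configured LightRAG instance
rag.insert(["First document text ...", "Second document text ..."])
```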

Kicking off asynchronous index building

`.insert()` is the entry point for inserting documents. It is a synchronous function (even though the README describes it as async, well). All it really does is call an asynchronous function that performs the insertion.

.insert()
```python
def insert(
    self,
    input: str | list[str],
    split_by_character: str | None = None,
    split_by_character_only: bool = False,
    ids: str | list[str] | None = None,
    file_paths: str | list[str] | None = None,
) -> None:
    """
    Sync Insert documents with checkpoint support

    Args:
        input: a document string, or a list of document strings, to insert
        split_by_character: split by this character instead of by token;
            if a split would cut through a token, the whole token is kept
        split_by_character_only: split by the character only, ignoring tokens
        ids: document identifiers (should be unique); if omitted, a hash is used
        file_paths: paths of the documents, used only for citation purposes
    """
    loop = always_get_an_event_loop()
    loop.run_until_complete(
        self.ainsert(
            input, split_by_character, split_by_character_only, ids, file_paths
        )
    )
```

Inserting asynchronously

`.ainsert()` splits the work into two main stages, which is about as async as it gets (

  1. Enqueue the documents (`.apipeline_enqueue_documents()`)
  2. Process the enqueued documents (`.apipeline_process_enqueue_documents()`)

Full code

There probably wasn't any need to open a Heading 4 for this (
.ainsert()
```python
async def ainsert(
    self,
    input: str | list[str],
    split_by_character: str | None = None,
    split_by_character_only: bool = False,
    ids: str | list[str] | None = None,
    file_paths: str | list[str] | None = None,
) -> None:
    """
    Async Insert documents with checkpoint support

    Args: (same as above)
        input: a document string, or a list of document strings, to insert
        split_by_character: split by this character instead of by token;
            if a split would cut through a token, the whole token is kept
        split_by_character_only: split by the character only, ignoring tokens
        ids: document identifiers (should be unique); if omitted, a hash is used
        file_paths: paths of the documents, used only for citation purposes
    """
    await self.apipeline_enqueue_documents(input, ids, file_paths)
    await self.apipeline_process_enqueue_documents(
        split_by_character, split_by_character_only
    )
```

Preprocessing the documents

The docstring spells out fairly clearly what `.apipeline_enqueue_documents()` does. Below, the code is excerpted and walked through in detail.

  1. First come the pre-checks: wrap a single str into a list so everything downstream can be handled uniformly, then verify that the number of file_paths matches the number of documents, since file_path exists purely for citation.
Pre-check
```python
if isinstance(input, str):
    input = [input]
if isinstance(ids, str):
    ids = [ids]
if isinstance(file_paths, str):
    file_paths = [file_paths]

# If file_paths is provided, ensure it matches the number of documents
if file_paths is not None:
    if isinstance(file_paths, str):
        file_paths = [file_paths]
    if len(file_paths) != len(input):
        raise ValueError(
            "Number of file paths must match the number of documents"
        )
```
  2. Check the document IDs.
    1. If IDs are provided:
      1. check that their count matches the number of documents
      2. check that the IDs are unique
    2. Otherwise, generate MD5 hashes as the IDs
    3. This stage also bundles the file_path (used for citation) together with the document content, producing objects of the form id: { content: , file_path: }
pack up information
```python
# 1. Validate ids if provided or generate MD5 hash IDs
if ids is not None:
    # Check if the number of IDs matches the number of documents
    if len(ids) != len(input):
        raise ValueError("Number of IDs must match the number of documents")

    # Check if IDs are unique
    if len(ids) != len(set(ids)):
        raise ValueError("IDs must be unique")

    # Generate contents dict of IDs provided by user and documents
    contents = {
        id_: {"content": doc, "file_path": path}
        for id_, doc, path in zip(ids, input, file_paths)
    }
else:
    # Clean input text and remove duplicates
    cleaned_input = [
        (clean_text(doc), path) for doc, path in zip(input, file_paths)
    ]
    unique_content_with_paths = {}

    # Keep track of unique content and their paths
    for content, path in cleaned_input:
        if content not in unique_content_with_paths:
            unique_content_with_paths[content] = path

    # Generate contents dict of MD5 hash IDs and documents with paths
    contents = {
        compute_mdhash_id(content, prefix="doc-"): {
            "content": content,
            "file_path": path,
        }
        for content, path in unique_content_with_paths.items()
    }
```
  3. Next, deduplicate within the input itself, that is, remove duplicate documents from the input.
Code
```python
# 2. Remove duplicate contents
unique_contents = {}
for id_, content_data in contents.items():
    content = content_data["content"]
    file_path = content_data["file_path"]
    if content not in unique_contents:
        unique_contents[content] = (id_, file_path)

# Reconstruct contents with unique content
contents = {
    id_: {"content": content, "file_path": file_path}
    for content, (id_, file_path) in unique_contents.items()
}
```
  4. Build a status record for each document so it can be tracked (including its update time).

Note that content_summary here is not an LLM-generated summary; it is just a truncation.

Code
```python
# 3. Generate document initial status
new_docs: dict[str, Any] = {
    id_: {
        "status": DocStatus.PENDING,
        "content": content_data["content"],
        "content_summary": get_content_summary(content_data["content"]),
        "content_length": len(content_data["content"]),
        "created_at": datetime.now().isoformat(),
        "updated_at": datetime.now().isoformat(),
        "file_path": content_data[
            "file_path"
        ],  # Store file path in document status
    }
    for id_, content_data in contents.items()
}
```
```python
def get_content_summary(content: str, max_length: int = 250) -> str:
    """Get summary of document content

    Args:
        content: Original document content
        max_length: Maximum length of summary

    Returns:
        Truncated content with ellipsis if needed
    """
    content = content.strip()
    if len(content) <= max_length:
        return content
    return content[:max_length] + "..."
```
  5. Then filter out documents that the existing store has already ingested.
Code
```python
# 4. Filter out already processed documents
# Get docs ids
all_new_doc_ids = set(new_docs.keys())
# Exclude IDs of documents that are already in progress
unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids)

# Log ignored document IDs
ignored_ids = [
    doc_id for doc_id in unique_new_doc_ids if doc_id not in new_docs
]
if ignored_ids:
    logger.warning(
        f"Ignoring {len(ignored_ids)} document IDs not found in new_docs"
    )
    for doc_id in ignored_ids:
        logger.warning(f"Ignored document ID: {doc_id}")

# Filter new_docs to only include documents with unique IDs
new_docs = {
    doc_id: new_docs[doc_id]
    for doc_id in unique_new_doc_ids
    if doc_id in new_docs
}

if not new_docs:
    logger.info("No new unique documents were found.")
    return
```
  6. Finally, the filtered documents are inserted into the doc status store. At last! (lol
Code
```python
# 5. Store status document
await self.doc_status.upsert(new_docs)
logger.info(f"Stored {len(new_docs)} new unique documents")
```

Actually processing the documents

`.apipeline_process_enqueue_documents()` is broadly structured as an `async with` part and a `try ... finally` part, which correspond to "fetch all pending documents" and "process the documents" respectively.
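A hypothetical outline of that two-part structure (only the `async with` part is excerpted below; the rest is sketched here just to show the shape, and is not LightRAG's actual code):

```python
import asyncio

pipeline_status_lock = asyncio.Lock()  # illustrative stand-in for the shared lock

async def process_enqueued_documents_outline() -> None:
    """Hypothetical outline only; not the real implementation."""
    to_process_docs: dict[str, object] = {}

    async with pipeline_status_lock:
        # Part 1: bail out if another worker is busy, otherwise collect the
        # PROCESSING / FAILED / PENDING documents from the doc status store.
        pass

    try:
        # Part 2: chunk each document, extract entities/relations, and merge
        # them into the graph and vector stores.
        for _doc_id in to_process_docs:
            pass
    finally:
        # Wrap-up: clear the busy state and pick up any request that was
        # queued while this worker was running.
        pass
```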

The logic for fetching pending documents is fairly intuitive: after acquiring the pipeline status lock, pull the to-be-processed documents out of the status store. It is written a bit oddly though, so I won't dig into the details for now (leaving that hole for later

Code
```python
async with pipeline_status_lock:
    # Ensure only one worker is processing documents
    if not pipeline_status.get("busy", False):
        processing_docs, failed_docs, pending_docs = await asyncio.gather(
            self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
            self.doc_status.get_docs_by_status(DocStatus.FAILED),
            self.doc_status.get_docs_by_status(DocStatus.PENDING),
        )

        to_process_docs: dict[str, DocProcessingStatus] = {}
        to_process_docs.update(processing_docs)
        to_process_docs.update(failed_docs)
        to_process_docs.update(pending_docs)

        if not to_process_docs:
            logger.info("No documents to process")
            return

        pipeline_status.update(
            {
                "busy": True,
                "job_name": "Default Job",
                "job_start": datetime.now().isoformat(),
                "docs": 0,
                "batchs": 0,  # Total number of files to be processed
                "cur_batch": 0,  # Number of files already processed
                "request_pending": False,  # Clear any previous request
                "latest_message": "",
            }
        )
        # Cleaning history_messages without breaking it as a shared list object
        del pipeline_status["history_messages"][:]
    else:
        # Another process is busy, just set request flag and return
        pipeline_status["request_pending"] = True
        logger.info(
            "Another process is already processing the document queue. Request queued."
        )
        return
```

Note

What follows is the core of LightRAG: extracting entities and relations from a single document. Other operations are therefore skipped, such as inserting chunks into the chunk database and inserting the full document. Many details, including error handling and the async/sync plumbing, are also left unexpanded.

Handling multiple documents at the same time obviously has to be considered as well. The project's approach is the one you would expect: wrap everything with `asyncio.create_task()` and then run the tasks in parallel with `asyncio.gather()`, as in the sketch below.
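A minimal sketch of that pattern, with hypothetical names (this is not the project's actual code):

```python
import asyncio

async def process_document(doc_id: str) -> None:
    # Stand-in for enqueueing + processing one document
    await asyncio.sleep(0)

async def process_documents(doc_ids: list[str]) -> None:
    # Wrap each coroutine in a Task so they all start running concurrently,
    # then gather() blocks until every one of them has finished.
    tasks = [asyncio.create_task(process_document(d)) for d in doc_ids]
    await asyncio.gather(*tasks)

asyncio.run(process_documents(["doc-1", "doc-2", "doc-3"]))
```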

Feeding the document and the prompt to the LLM

This part is done by `_process_single_content()`. It first assembles the prompt input, then calls `use_llm_func_with_cache()` to obtain the LLM output and cache it.
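Roughly speaking, "fill in the prompt, call the LLM, cache the answer" boils down to something like the hypothetical sketch below; the template, names, and cache shape are all illustrative, not LightRAG's actual prompt or API:

```python
from typing import Awaitable, Callable

# Illustrative prompt template; the real one lives in prompt.py
ENTITY_EXTRACT_TEMPLATE = "Extract entities and relations from the text:\n{input_text}"

async def extract_with_cache(
    chunk_text: str,
    llm_func: Callable[[str], Awaitable[str]],
    cache: dict[str, str],
) -> str:
    prompt = ENTITY_EXTRACT_TEMPLATE.format(input_text=chunk_text)
    if prompt in cache:             # reuse the cached response for an identical prompt
        return cache[prompt]
    result = await llm_func(prompt)
    cache[prompt] = result          # remember the response for next time
    return result
```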

Then the output gets parsed; the `for ... in range(entity_extract_max_gleaning)` loop means the extraction is retried ("gleaned") at most entity_extract_max_gleaning times.

Note

The following two steps operate on a single chunk. In other words, if the document is long and gets split into many chunks, these two steps run once per chunk.

So how is batch processing handled?

Here the source code processes the chunks concurrently in a batch: it defines a semaphore, wraps every task with `asyncio.create_task()`, and then has `asyncio.wait()` run them all, blocking until every task completes.
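A hypothetical sketch of that semaphore + `asyncio.wait()` pattern (names and the chunk-processing body are placeholders):

```python
import asyncio

async def process_all_chunks(chunks: list[str], max_parallel: int = 4) -> list[tuple]:
    semaphore = asyncio.Semaphore(max_parallel)

    async def process_one(chunk: str) -> tuple:
        async with semaphore:       # at most `max_parallel` chunks are processed at once
            await asyncio.sleep(0)  # placeholder for the real extraction work
            return ({}, {})         # (maybe_nodes, maybe_edges)

    tasks = [asyncio.create_task(process_one(c)) for c in chunks]
    if not tasks:
        return []
    done, _pending = await asyncio.wait(tasks)  # blocks until every task completes
    return [t.result() for t in done]
```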

We will ignore error handling for now and focus on the rest of the flow.

Once chunk_results has been collected, the results are merged straight into all_nodes and all_edges via list.extend()

Merge chunks
```python
# Collect all nodes and edges from all chunks
all_nodes = defaultdict(list)
all_edges = defaultdict(list)

for maybe_nodes, maybe_edges in chunk_results:
    # Collect nodes
    for entity_name, entities in maybe_nodes.items():
        all_nodes[entity_name].extend(entities)

    # Collect edges with sorted keys for undirected graph
    for edge_key, edges in maybe_edges.items():
        sorted_edge_key = tuple(sorted(edge_key))
        all_edges[sorted_edge_key].extend(edges)
```

Once the entities and relations within this document have been merged, it is time for the graph insert stage. We will come back to that part later.

Parsing the LLM output

`_process_extraction_result()` is an internal helper function defined inside `extract_entities()`. It is responsible for processing the extraction result returned by the large language model (LLM), turning the unstructured text response into structured entity and relation data.

First, the result returned by the LLM is split by the configured delimiters into records/completions; each record/completion may contain an entity or a relation.

```python
async def _process_extraction_result(
    result: str,       # the extraction result text returned by the LLM
    chunk_key: str,    # unique identifier of the text chunk, used for source tracking
    file_path: str = "unknown_source",
    # file path used for citing the source (defaults to "unknown_source")
):
    # returns a tuple (maybe_nodes, maybe_edges) with the extracted entities and relations
```

Each record is then processed (from the regex pattern you can see that every record is wrapped in a pair of parentheses), and inside each record the tuple delimiter is used to split out the entity and its attributes.

```python
for record in records:
    record = re.search(r"\((.*)\)", record)
    if record is None:
        continue
    record = record.group(1)
    record_attributes = split_string_by_multi_markers(
        record, [context_base["tuple_delimiter"]]
    )
```

Next, the code tries to parse the record as an Entity or as a Relation. From the prompt stored in prompt.py, you can see that the model is asked to output tuples separated by the tuple delimiter:

```
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
```
```python
# Try to parse the record as an entity
if_entities = await _handle_single_entity_extraction(
    record_attributes, chunk_key, file_path
)
if if_entities is not None:
    maybe_nodes[if_entities["entity_name"]].append(if_entities)
    continue
```
```python
# Try to parse the record as a relation
if_relation = await _handle_single_relationship_extraction(
    record_attributes, chunk_key, file_path
)
if if_relation is not None:
    maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
        if_relation
    )
```

Merging entity nodes and relation edges

This part simply merges the gleaned extraction results into the existing ones by entity name and relation key.

Code
```python
# Process gleaning result separately with file path
glean_nodes, glean_edges = await _process_extraction_result(
    glean_result, chunk_key, file_path
)

# Merge results - only add entities and edges with new names
for entity_name, entities in glean_nodes.items():
    if (
        entity_name not in maybe_nodes
    ):  # Only accept entities with new names in the gleaning stage
        maybe_nodes[entity_name].extend(entities)
for edge_key, edges in glean_edges.items():
    if (
        edge_key not in maybe_edges
    ):  # Only accept edges with new keys in the gleaning stage
        maybe_edges[edge_key].extend(edges)
```

To keep the LLM from missing entities, an extra LLM call (with if_loop_prompt from prompt.py as the input) checks whether anything was left out; the loop exits once nothing is missing.

Finally, the nodes and edges extracted from this chunk are returned.

Code
```python
if_loop_result: str = await use_llm_func_with_cache(
    if_loop_prompt,
    use_llm_func,
    llm_response_cache=llm_response_cache,
    history_messages=history,
    cache_type="extract",
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
    break
```

Updating the knowledge graph

Since execution is asynchronous, we first need to acquire a lock to ensure data integrity:

```python
async with graph_db_lock:
```

Next, nodes and edges that have already been added are filtered out against the existing graph. This part is fairly implementation-heavy engineering, so it is not expanded here.

Filter against the existing knowledge base
```python
# Centralized processing of all nodes and edges
entities_data = []
relationships_data = []

# Use graph database lock to ensure atomic merges and updates
async with graph_db_lock:
    # Process and update all entities at once
    for entity_name, entities in all_nodes.items():
        entity_data = await _merge_nodes_then_upsert(
            entity_name,
            entities,
            knowledge_graph_inst,
            global_config,
            pipeline_status,
            pipeline_status_lock,
            llm_response_cache,
        )
        entities_data.append(entity_data)

    # Process and update all relationships at once
    for edge_key, edges in all_edges.items():
        edge_data = await _merge_edges_then_upsert(
            edge_key[0],
            edge_key[1],
            edges,
            knowledge_graph_inst,
            global_config,
            pipeline_status,
            pipeline_status_lock,
            llm_response_cache,
        )
        if edge_data is not None:
            relationships_data.append(edge_data)
```

Then the entity vector database is updated ...

Update the entity vector database
```python
# Update vector databases with all collected data
if entity_vdb is not None and entities_data:
    data_for_vdb = {
        compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
            "entity_name": dp["entity_name"],
            "entity_type": dp["entity_type"],
            "content": f"{dp['entity_name']}\n{dp['description']}",
            "source_id": dp["source_id"],
            "file_path": dp.get("file_path", "unknown_source"),
        }
        for dp in entities_data
    }
    await entity_vdb.upsert(data_for_vdb)
```

... followed by the relationship vector database.

Update the relationship vector database
```python
if relationships_vdb is not None and relationships_data:
    data_for_vdb = {
        compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
            "src_id": dp["src_id"],
            "tgt_id": dp["tgt_id"],
            "keywords": dp["keywords"],
            "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
            "source_id": dp["source_id"],
            "file_path": dp.get("file_path", "unknown_source"),
        }
        for dp in relationships_data
    }
    await relationships_vdb.upsert(data_for_vdb)
```