前沿重器
This column shares papers and talks from major companies and top conferences, distilling the key parts so we can track frontier techniques together. About the series: the 仓颉专项 covers both the heavy weaponry and the inner craft. (Counting back, the project kicked off in 2020!)
The 2024 article collection is now out! Find it here: nearly 200,000 more words added — the 2024 article collection of CS的陋室 has been updated.
Previous posts
- 前沿重器[63] | SIGIR25: personalized real-time retrieval augmentation for chatbots' proactive dialogue
- 前沿重器[64] | Alimama's URM model: an LLM-based universal recommendation approach
- 前沿重器[65] | A survey of LLM judging capability
- 前沿重器[66] | Lessons from Meituan's search-ads retrieval iteration
- 前沿重器[67] | BUPT & Tencent's MemoryOS: hierarchical memory for long-term LLM memory (part 1)
Picking up from part 1 (前沿重器[67] | BUPT & Tencent's MemoryOS: hierarchical memory for long-term LLM memory, part 1), no pleasantries — let's dive straight in. This installment focuses on the paper's open-source code: the overall implementation and the details worth absorbing. As usual, the key links:
- Paper: https://arxiv.org/abs/2506.06326
- Code: https://github.com/BAI-LAB/MemoryOS
- Paper walkthrough: https://mp.weixin.qq.com/s/5sppdbUPzrxdKlDPzVFfkg
The article, code included, runs long, so here is the table of contents.
- Project structure
- Content overview
- Short-term memory
- Mid-term memory
- Long-term memory
- Core service
- Core flow
- Retrieve context
- Get short-term history
- Format retrieved mid-term pages
- Get user profile
- Format
- Prompt assembly and the LLM call
- add_memory
- Memory retrieval and updating
- retriever.py
- updater.py
- Memory stores
- short_term.py
- mid_term.py
- long_term.py
- Afterword
Long-read warning: check the progress bar on the right — with the code this article runs well past ten thousand characters, but the modules are too tightly coupled to split across posts, so pace yourself as needed.
One more note: the internal flow code in this project is long, and some of it is not placed particularly well, so I strongly suggest finishing the original paper before coming back here; otherwise it is easy to get lost.
Project structure
Let's start from the overall project layout. The key files are below (I removed PDFs, videos, and other content unrelated to the engineering).
.
|-- README.md
|-- example.py
|-- memoryos-mcp
|   |-- config.example.json
|   |-- config.json
|   |-- mcp.json
|   |-- memoryos
|   |   |-- __init__.py
|   |   |-- long_term.py
|   |   |-- memoryos.py
|   |   |-- mid_term.py
|   |   |-- prompts.py
|   |   |-- retriever.py
|   |   |-- short_term.py
|   |   |-- updater.py
|   |   `-- utils.py
|   |-- requirements.txt
|   |-- server_new.py
|   `-- test_comprehensive.py
`-- requirements.txt
As you can see, there isn't much here and the structure is simple.
- MCP appears to be the pattern that strings the project together, providing a common protocol.
- The short-, mid-, and long-term components each look fairly self-contained — essentially one file apiece.
requirements
It has these dependencies, which are also simple:
openai
numpy
sentence-transformers
faiss-gpu
Flask
This again shows how lightweight the project is: openai is the client component for calling the LLM; numpy and sentence-transformers cover basic computation and the embedding model; faiss-gpu is the single-machine vector retrieval component; and Flask is the serving tool (though in the latest version of the project MCP is used instead, and Flask is currently unused).
PS: MCP itself is not listed here even though the code imports it — heads up.
Content overview
After reading through the code, I think it's worth spelling out up front how the paper's three core modules — the storage and retrieval of short-, mid-, and long-term memory — are actually implemented, and which the core functions are; otherwise the cross-jumping gets disorienting. With this map in hand, the code should read much more clearly.
Short-term memory
Short-term memory simply records the dialogue history — the page in the code.
- It has a length limit; when it overflows, the transfer to mid-term is triggered. The code for that is process_short_term_to_mid_term.
Its usage is equally simple: the history is concatenated and dropped straight into the prompt for the LLM — nothing special going on.
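The bounded-history-plus-overflow-trigger mechanics can be sketched in a few lines (an illustrative sketch with a simplified class body, not the project's actual short_term.py):

```python
from collections import deque

class ShortTermMemorySketch:
    """Capacity-bounded dialogue history; illustrative simplification only."""
    def __init__(self, capacity=10):
        self.capacity = capacity
        self.memory = deque()  # each item is one QA pair (a "page")

    def add_qa_pair(self, qa):
        self.memory.append(qa)

    def is_full(self):
        # When full, the caller would trigger the short-to-mid-term transfer
        return len(self.memory) >= self.capacity

stm = ShortTermMemorySketch(capacity=2)
stm.add_qa_pair({"user_input": "hi", "agent_response": "hello"})
assert not stm.is_full()
stm.add_qa_pair({"user_input": "how are you", "agent_response": "fine"})
assert stm.is_full()  # would trigger process_short_term_to_mid_term
```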
A page generally has this format:
processed_page = {
    **page_data,  # Carry over existing fields like user_input, agent_response, timestamp
    "page_id": page_id,
    "page_embedding": inp_vec,
    "page_keywords": page_keywords,
    "preloaded": page_data.get("preloaded", False),  # Preserve if passed
    "analyzed": page_data.get("analyzed", False),  # Preserve if passed
    # pre_page, next_page, meta_info are handled by DynamicUpdater
}
Mid-term memory
Mid-term updates are also triggered in process_short_term_to_mid_term: each incoming page is checked against the existing topics/conversation segments (session_id) in history; if one is related, the page is merged into it, otherwise a new session is created. The basic session format is as follows.
session_obj = {
    "id": session_id,
    "summary": summary,
    "summary_keywords": summary_keywords,
    "summary_embedding": summary_vec,
    "details": processed_details,
    "L_interaction": len(processed_details),
    "R_recency": 1.0,  # Initial recency
    "N_visit": 0,
    "H_segment": 0.0,  # Initial heat, will be computed
    "timestamp": current_ts,  # Creation timestamp
    "last_visit_time": current_ts,  # Also initial last_visit_time for recency calc
    "access_count_lfu": 0  # For LFU eviction policy
}
session_obj["H_segment"] = compute_segment_heat(session_obj)
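For reference, the heat can be sketched as a weighted combination of the fields above. The weights and the exact functional form here are my assumptions for illustration — the real formula lives in the project's utils and in the paper:

```python
# Hedged sketch of compute_segment_heat: heat grows with visit count,
# interaction length, and recency. Weights alpha/beta/gamma are illustrative
# assumptions, not the project's actual values.
def compute_segment_heat_sketch(session, alpha=1.0, beta=1.0, gamma=1.0):
    return (alpha * session.get("N_visit", 0)
            + beta * session.get("L_interaction", 0)
            + gamma * session.get("R_recency", 1.0))

s = {"N_visit": 2, "L_interaction": 5, "R_recency": 1.0}
assert compute_segment_heat_sketch(s) == 8.0
```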
Retrieval lives in search_sessions:
- Use query_vec to find the most relevant sessions, matching against summary_embedding.
- Within those sessions, find the pages closest to query_vec, matching against page_embedding.
The resulting pages are then concatenated into the prompt as the mid-term context.
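The two-stage lookup described above — sessions by summary_embedding, then pages by page_embedding — can be sketched like this (a toy version with hypothetical helper names, not the project's search_sessions):

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb) if na and nb else 0.0

def search_sessions_sketch(query_vec, sessions, top_k_sessions=1):
    # Stage 1: rank whole sessions by summary_embedding similarity
    ranked = sorted(sessions, key=lambda s: cosine(query_vec, s["summary_embedding"]), reverse=True)
    hits = []
    # Stage 2: within the best sessions, score each page by page_embedding similarity
    for sess in ranked[:top_k_sessions]:
        for page in sess["details"]:
            hits.append((cosine(query_vec, page["page_embedding"]), page))
    hits.sort(key=lambda t: t[0], reverse=True)
    return [page for _, page in hits]

sessions = [
    {"summary_embedding": [1.0, 0.0],
     "details": [{"id": "p1", "page_embedding": [1.0, 0.0]},
                 {"id": "p2", "page_embedding": [0.0, 1.0]}]},
    {"summary_embedding": [0.0, 1.0],
     "details": [{"id": "p3", "page_embedding": [0.0, 1.0]}]},
]
assert [p["id"] for p in search_sessions_sketch([1.0, 0.0], sessions)] == ["p1", "p2"]
```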
Long-term memory
Long-term memory is fed by heat promotion from mid-term memory: information that keeps being revisited gets promoted to long-term storage, triggered in _trigger_profile_and_knowledge_update_if_needed.
- If a mid-term session's current_heat rises above the threshold, the system starts persisting that session.
- The content is recorded along three separate tracks:
- User Profile: the user's preferences — the prompt's wording is "user's personality preferences".
- User Private Knowledge: personal information — "personal information, preferences, or private facts about the user" (which seems to cover preferences too).
- Assistant Knowledge: essentially the assistant's persona — information the assistant itself has stated, kept for cross-turn consistency; the prompt's wording is "explicit statements about what the assistant did, provided, or demonstrated".
Retrieval is handled in long_term.py. Both knowledge stores go through one shared function, _search_knowledge_deque — in plain terms, query_vec is searched against the user store and the assistant store separately. The Profile, by contrast, is used in the outermost get_response and fetched directly by user id: the profile spans the whole conversation flow, so it is simply always included.
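A minimal sketch of what a shared _search_knowledge_deque-style search does — score, threshold, top-k — with illustrative data (my simplification, not the project's function):

```python
import math
from collections import deque

def search_knowledge_deque_sketch(query_vec, knowledge_deque, threshold=0.01, top_k=20):
    # Score every stored entry's embedding against the query vector,
    # drop entries below the threshold, and return the top_k best.
    def cos(a, b):
        d = sum(x * y for x, y in zip(a, b))
        n = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b))
        return d / n if n else 0.0
    scored = [(cos(query_vec, e["knowledge_embedding"]), e) for e in knowledge_deque]
    scored = [t for t in scored if t[0] >= threshold]
    scored.sort(key=lambda t: t[0], reverse=True)
    return [e for _, e in scored[:top_k]]

kb = deque([
    {"knowledge": "user likes jazz", "knowledge_embedding": [1.0, 0.0]},
    {"knowledge": "user owns a cat", "knowledge_embedding": [0.0, 1.0]},
])
hits = search_knowledge_deque_sketch([1.0, 0.0], kb)
assert [h["knowledge"] for h in hits] == ["user likes jazz"]
```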
Core service
The service entry point is server_new.py. First, MCP is used to unify the calls into the various internal components.
try:
    from mcp.server.fastmcp import FastMCP
except ImportError as e:
    print(f"ERROR: Failed to import FastMCP. Exception: {e}", file=sys.stderr)
    print("请安装最新版本的MCP: pip install --upgrade mcp", file=sys.stderr)
    sys.exit(1)
The various tool functions below can then be defined through MCP. Before defining any tools, though, the core Memoryos instance has to be created.
def init_memoryos(config_path: str) -> Memoryos:
    """初始化MemoryOS实例"""
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"配置文件不存在: {config_path}")
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    required_fields = ['user_id', 'openai_api_key', 'data_storage_path']
    for field in required_fields:
        if field not in config:
            raise ValueError(f"配置文件缺少必需字段: {field}")
    return Memoryos(
        user_id=config['user_id'],
        openai_api_key=config['openai_api_key'],
        data_storage_path=config['data_storage_path'],
        openai_base_url=config.get('openai_base_url'),
        assistant_id=config.get('assistant_id', 'default_assistant_profile'),
        short_term_capacity=config.get('short_term_capacity', 10),
        mid_term_capacity=config.get('mid_term_capacity', 2000),
        long_term_knowledge_capacity=config.get('long_term_knowledge_capacity', 100),
        retrieval_queue_capacity=config.get('retrieval_queue_capacity', 7),
        mid_term_heat_threshold=config.get('mid_term_heat_threshold', 5.0),
        llm_model=config.get('llm_model', 'gpt-4o-mini')
    )
There is a lot of definition code, but it is mostly loading and wiring; the pattern of building the object and returning it directly is worth borrowing. Next come the tool definitions, which all look much alike. The main ones:
@mcp.tool()
def add_memory(user_input: str, agent_response: str, timestamp: Optional[str] = None, meta_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    向MemoryOS系统添加新的记忆(用户输入和助手回应的对话对)
    Args:
        user_input: 用户的输入或问题
        agent_response: 助手的回应
        timestamp: 时间戳(可选,格式:YYYY-MM-DD HH:MM:SS)
        meta_data: 可选的元数据(JSON对象)
    Returns:
        包含操作结果的字典
    """
    global memoryos_instance
    if memoryos_instance is None:
        return {
            "status": "error",
            "message": "MemoryOS is not initialized. Please check the configuration file."
        }
    try:
        if not user_input or not agent_response:
            return {
                "status": "error",
                "message": "user_input and agent_response are required"
            }
        memoryos_instance.add_memory(
            user_input=user_input,
            agent_response=agent_response,
            timestamp=timestamp,
            meta_data=meta_data
        )
        result = {
            "status": "success",
            "message": "Memory has been successfully added to MemoryOS",
            "timestamp": timestamp or get_timestamp(),
            "details": {
                "user_input_length": len(user_input),
                "agent_response_length": len(agent_response),
                "has_meta_data": meta_data is not None
            }
        }
        return result
    except Exception as e:
        return {
            "status": "error",
            "message": f"Error adding memory: {str(e)}"
        }

@mcp.tool()
def retrieve_memory(query: str, relationship_with_user: str = "friend", style_hint: str = "", max_results: int = 10) -> Dict[str, Any]:
    """
    根据查询从MemoryOS检索相关的记忆和上下文信息,包括短期记忆、中期记忆和长期知识
    Args:
        query: 检索查询,描述要寻找的信息
        relationship_with_user: 与用户的关系类型(如:friend, assistant, colleague等)
        style_hint: 回应风格提示
        max_results: 返回的最大结果数量
    Returns:
        包含检索结果的字典,包括:
        - short_term_memory: 当前短期记忆中的所有QA对
        - retrieved_pages: 从中期记忆检索的相关页面
        - retrieved_user_knowledge: 从用户长期知识库检索的相关条目
        - retrieved_assistant_knowledge: 从助手知识库检索的相关条目
    """
    pass

@mcp.tool()
def get_user_profile(include_knowledge: bool = True, include_assistant_knowledge: bool = False) -> Dict[str, Any]:
    """
    获取用户的画像信息,包括个性特征、偏好和相关知识
    Args:
        include_knowledge: 是否包括用户相关的知识条目
        include_assistant_knowledge: 是否包括助手知识库
    Returns:
        包含用户画像信息的字典
    """
    pass
I've shown the full code of add_memory; the bodies of the other two are elided since they follow the same pattern.
- There are three tools — add_memory, retrieve_memory, and get_user_profile — and the names say it all: adding a memory, querying memory, and fetching user information.
- Every tool follows the same shape: input validation, execution, result checking, wrapped in try-except for robustness. A fairly standard design pattern.
- The execution logic is not inlined here; each tool delegates to an encapsulated function — adding a memory, for example, calls memoryos_instance.add_memory — so this layer only defines the tools.
- The global variable memoryos_instance is an instance of MemoryOS; you can see it set in the main function later.
memoryos_instance deserves emphasis. Essentially a single memoryos_instance is kept alive across everything as the store for all state. For an open-source demo that is entirely adequate, but as an engineering project it is rough: much of the state is maintained locally, so robustness, stability, and security all fall short. If I were to harden it:
- Persist the various stores in proper storage middleware — Redis, ES, Postgres, or even MySQL would all do.
- For a single conversation turn, maintain the per-turn state in a dedicated class initialized fresh each turn; content that needs fast updates fits nicely in a fast store like Redis.
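As an illustration of the second suggestion, a hypothetical per-turn state container might look like this (the names are invented; a real version could back the fast-changing fields with Redis):

```python
from dataclasses import dataclass, field

@dataclass
class TurnContext:
    """Hypothetical per-turn container: initialized fresh each turn, so
    turn-local state never leaks through a module-level singleton."""
    user_id: str
    query: str
    retrieved_pages: list = field(default_factory=list)
    retrieved_user_knowledge: list = field(default_factory=list)
    response: str = ""

ctx = TurnContext(user_id="u1", query="hello")
ctx.retrieved_pages.append({"user_input": "hi", "agent_response": "hello"})
assert ctx.user_id == "u1" and len(ctx.retrieved_pages) == 1
```

Because each turn builds its own TurnContext, two concurrent turns never share mutable lists, which the global-singleton design cannot guarantee.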
Finally, the main function starts the service. Nothing much to unpack here — it's all basic.
def main():
    """主函数"""
    parser = argparse.ArgumentParser(description="MemoryOS MCP Server")
    parser.add_argument(
        "--config",
        type=str,
        default="config.json",
        help="配置文件路径 (默认: config.json)"
    )
    args = parser.parse_args()
    global memoryos_instance
    try:
        # 初始化MemoryOS
        memoryos_instance = init_memoryos(args.config)
        print(f"MemoryOS MCP Server 已启动,用户ID: {memoryos_instance.user_id}", file=sys.stderr)
        print(f"配置文件: {args.config}", file=sys.stderr)
        # 启动MCP服务器 - 使用stdio传输
        mcp.run(transport="stdio")
    except KeyboardInterrupt:
        print("服务器被用户中断", file=sys.stderr)
    except Exception as e:
        print(f"启动服务器时发生错误: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)
Core flow
That was the outer engineering shell; now let's look at what happens inside during one turn of dialogue. The core flow lives in memoryos.py, where all the functions are defined. The easiest entry point is get_response, the function that produces a reply for a given turn.
def get_response(self, query: str, relationship_with_user="friend", style_hint="", user_conversation_meta_data: dict = None) -> str:
    """
    Generates a response to the user's query, incorporating memory and context.
    """
    print(f"Memoryos: Generating response for query: '{query[:50]}...'")
    # 1. Retrieve context
    retrieval_results = self.retriever.retrieve_context(
        user_query=query,
        user_id=self.user_id
        # Using default thresholds from Retriever class for now
    )
    retrieved_pages = retrieval_results["retrieved_pages"]
    retrieved_user_knowledge = retrieval_results["retrieved_user_knowledge"]
    retrieved_assistant_knowledge = retrieval_results["retrieved_assistant_knowledge"]
    # 2. Get short-term history
    short_term_history = self.short_term_memory.get_all()
    history_text = "\n".join([
        f"User: {qa.get('user_input', '')}\nAssistant: {qa.get('agent_response', '')} (Time: {qa.get('timestamp', '')})"
        for qa in short_term_history
    ])
    # 3. Format retrieved mid-term pages (retrieval_queue equivalent)
    retrieval_text = "\n".join([
        f"【Historical Memory】\nUser: {page.get('user_input', '')}\nAssistant: {page.get('agent_response', '')}\nTime: {page.get('timestamp', '')}\nConversation chain overview: {page.get('meta_info','N/A')}"
        for page in retrieved_pages
    ])
    # 4. Get user profile
    user_profile_text = self.user_long_term_memory.get_raw_user_profile(self.user_id)
    if not user_profile_text or user_profile_text.lower() == "none":
        user_profile_text = "No detailed profile available yet."
    # 5. Format retrieved user knowledge for background
    user_knowledge_background = ""
    if retrieved_user_knowledge:
        user_knowledge_background = "\n【Relevant User Knowledge Entries】\n"
        for kn_entry in retrieved_user_knowledge:
            user_knowledge_background += f"- {kn_entry['knowledge']} (Recorded: {kn_entry['timestamp']})\n"
    background_context = f"【User Profile】\n{user_profile_text}\n{user_knowledge_background}"
    # 6. Format retrieved Assistant Knowledge (from assistant's LTM)
    # Use retrieved assistant knowledge instead of all assistant knowledge
    assistant_knowledge_text_for_prompt = "【Assistant Knowledge Base】\n"
    if retrieved_assistant_knowledge:
        for ak_entry in retrieved_assistant_knowledge:
            assistant_knowledge_text_for_prompt += f"- {ak_entry['knowledge']} (Recorded: {ak_entry['timestamp']})\n"
    else:
        assistant_knowledge_text_for_prompt += "- No relevant assistant knowledge found for this query.\n"
    # 7. Format user_conversation_meta_data (if provided)
    meta_data_text_for_prompt = "【Current Conversation Metadata】\n"
    if user_conversation_meta_data:
        try:
            meta_data_text_for_prompt += json.dumps(user_conversation_meta_data, ensure_ascii=False, indent=2)
        except TypeError:
            meta_data_text_for_prompt += str(user_conversation_meta_data)
    else:
        meta_data_text_for_prompt += "None provided for this turn."
    # 8. Construct Prompts
    system_prompt_text = prompts.GENERATE_SYSTEM_RESPONSE_SYSTEM_PROMPT.format(
        relationship=relationship_with_user,
        assistant_knowledge_text=assistant_knowledge_text_for_prompt,
        meta_data_text=meta_data_text_for_prompt  # Using meta_data_text placeholder for user_conversation_meta_data
    )
    user_prompt_text = prompts.GENERATE_SYSTEM_RESPONSE_USER_PROMPT.format(
        history_text=history_text,
        retrieval_text=retrieval_text,
        background=background_context,
        relationship=relationship_with_user,
        query=query
    )
    messages = [
        {"role": "system", "content": system_prompt_text},
        {"role": "user", "content": user_prompt_text}
    ]
    # 9. Call LLM for response
    print("Memoryos: Calling LLM for final response generation...")
    # print("System Prompt:\n", system_prompt_text)
    # print("User Prompt:\n", user_prompt_text)
    response_content = self.client.chat_completion(
        model=self.llm_model,
        messages=messages,
        temperature=0.7,
        max_tokens=1500  # As in original main
    )
    # 10. Add this interaction to memory
    self.add_memory(user_input=query, agent_response=response_content, timestamp=get_timestamp())
    return response_content
The comments in this code are genuinely good — each step states what it does, so it reads easily. The ten steps boil down to three larger phases.
- Memory extraction: whether it's search, short-term memory, mid-term memory, or the user profile, this phase pulls what's needed from the various memory stores.
- Context assembly and response generation: everything retrieved is merged with the current turn's content to generate the reply.
- Memory writing: this turn's input and output are thrown back into the memory system.
Now let's work through the steps one by one.
Retrieve context
This step queries the mid- and long-term content relevant to the current conversation state.
retrieval_results = self.retriever.retrieve_context(
    user_query=query,
    user_id=self.user_id
    # Using default thresholds from Retriever class for now
)
retrieved_pages = retrieval_results["retrieved_pages"]
retrieved_user_knowledge = retrieval_results["retrieved_user_knowledge"]
retrieved_assistant_knowledge = retrieval_results["retrieved_assistant_knowledge"]
Get short-term history
Short-term memory is blunt: it just hands back the whole list.
def get_all(self):
    return list(self.memory)
The short-term content is then assembled (in preparation for the prompt later):
history_text = "\n".join([
    f"User: {qa.get('user_input', '')}\nAssistant: {qa.get('agent_response', '')} (Time: {qa.get('timestamp', '')})"
    for qa in short_term_history
])
Format retrieved mid-term pages
This assembles the mid-term information (honestly it didn't need to be a separate step — it could be post-processing inside Retrieve context).
retrieval_text = "\n".join([
    f"【Historical Memory】\nUser: {page.get('user_input', '')}\nAssistant: {page.get('agent_response', '')}\nTime: {page.get('timestamp', '')}\nConversation chain overview: {page.get('meta_info','N/A')}"
    for page in retrieved_pages
])
Get user profile
Then the user profile is fetched.
user_profile_text = self.user_long_term_memory.get_raw_user_profile(self.user_id)
if not user_profile_text or user_profile_text.lower() == "none":
    user_profile_text = "No detailed profile available yet."
One look at the input (the user id) tells you what's going on here.
Format
Steps 5/6/7 are all assembly — string concatenation of the earlier retrieval results.
user_knowledge_background = ""
if retrieved_user_knowledge:
    user_knowledge_background = "\n【Relevant User Knowledge Entries】\n"
    for kn_entry in retrieved_user_knowledge:
        user_knowledge_background += f"- {kn_entry['knowledge']} (Recorded: {kn_entry['timestamp']})\n"
background_context = f"【User Profile】\n{user_profile_text}\n{user_knowledge_background}"

assistant_knowledge_text_for_prompt = "【Assistant Knowledge Base】\n"
if retrieved_assistant_knowledge:
    for ak_entry in retrieved_assistant_knowledge:
        assistant_knowledge_text_for_prompt += f"- {ak_entry['knowledge']} (Recorded: {ak_entry['timestamp']})\n"
else:
    assistant_knowledge_text_for_prompt += "- No relevant assistant knowledge found for this query.\n"

meta_data_text_for_prompt = "【Current Conversation Metadata】\n"
if user_conversation_meta_data:
    try:
        meta_data_text_for_prompt += json.dumps(user_conversation_meta_data, ensure_ascii=False, indent=2)
    except TypeError:
        meta_data_text_for_prompt += str(user_conversation_meta_data)
else:
    meta_data_text_for_prompt += "None provided for this turn."
Prompt assembly and the LLM call
Then come prompt assembly and the request.
# 8. Construct Prompts
system_prompt_text = prompts.GENERATE_SYSTEM_RESPONSE_SYSTEM_PROMPT.format(
    relationship=relationship_with_user,
    assistant_knowledge_text=assistant_knowledge_text_for_prompt,
    meta_data_text=meta_data_text_for_prompt  # Using meta_data_text placeholder for user_conversation_meta_data
)
user_prompt_text = prompts.GENERATE_SYSTEM_RESPONSE_USER_PROMPT.format(
    history_text=history_text,
    retrieval_text=retrieval_text,
    background=background_context,
    relationship=relationship_with_user,
    query=query
)
messages = [
    {"role": "system", "content": system_prompt_text},
    {"role": "user", "content": user_prompt_text}
]
# 9. Call LLM for response
print("Memoryos: Calling LLM for final response generation...")
# print("System Prompt:\n", system_prompt_text)
# print("User Prompt:\n", user_prompt_text)
response_content = self.client.chat_completion(
    model=self.llm_model,
    messages=messages,
    temperature=0.7,
    max_tokens=1500  # As in original main
)
Not much to say here — a fairly conventional assembly-and-LLM-request flow.
- Concatenation is done with format-string templates, a common enough approach.
- The input uses the classic messages format, and with only two entries (note len(messages)); there is no multi-turn alternating pattern.
add_memory
The flow isn't over: don't forget that everything from this turn still has to be thrown into the memory system.
self.add_memory(user_input=query, agent_response=response_content, timestamp=get_timestamp())
Note that only two things are stored here: the user query and the freshly generated response_content. Internally it looks like this:
def add_memory(self, user_input: str, agent_response: str, timestamp: str = None, meta_data: dict = None):
    """
    Adds a new QA pair (memory) to the system.
    meta_data is not used in the current refactoring but kept for future use.
    """
    if not timestamp:
        timestamp = get_timestamp()
    qa_pair = {
        "user_input": user_input,
        "agent_response": agent_response,
        "timestamp": timestamp
        # meta_data can be added here if it needs to be stored with the QA pair
    }
    self.short_term_memory.add_qa_pair(qa_pair)
    print(f"Memoryos: Added QA to short-term. User: {user_input[:30]}...")
    if self.short_term_memory.is_full():
        print("Memoryos: Short-term memory full. Processing to mid-term.")
        self.updater.process_short_term_to_mid_term()
    # After any memory addition that might impact mid-term, check for profile updates
    self._trigger_profile_and_knowledge_update_if_needed()
It boils down to three plain steps: first add to short_term_memory; when short-term memory fills, trigger process_short_term_to_mid_term to move content into mid-term; finally update long-term memory via _trigger_profile_and_knowledge_update_if_needed. Honestly, both that last function and the short_term_memory update feel oddly placed here — to my mind they belong in the updater: they are both updates, and the update layer is the more fitting home.
Since it lives at the memoryos layer anyway, let's look at _trigger_profile_and_knowledge_update_if_needed right here — there is a lot going on inside.
def _trigger_profile_and_knowledge_update_if_needed(self):
    """
    Checks mid-term memory for hot segments and triggers profile/knowledge update if threshold is met.
    Adapted from main_memoybank.py's update_user_profile_from_top_segment.
    """
    if not self.mid_term_memory.heap:
        return
    # Peek at the top of the heap (hottest segment)
    # MidTermMemory heap stores (-H_segment, sid)
    neg_heat, sid = self.mid_term_memory.heap[0]
    current_heat = -neg_heat
    if current_heat >= self.mid_term_heat_threshold:
        session = self.mid_term_memory.sessions.get(sid)
        if not session:
            self.mid_term_memory.rebuild_heap()  # Clean up if session is gone
            return
        # Get unanalyzed pages from this hot session
        # A page is a dict: {"user_input": ..., "agent_response": ..., "timestamp": ..., "analyzed": False, ...}
        unanalyzed_pages = [p for p in session.get("details", []) if not p.get("analyzed", False)]
        if unanalyzed_pages:
            print(f"Memoryos: Mid-term session {sid} heat ({current_heat:.2f}) exceeded threshold. Analyzing {len(unanalyzed_pages)} pages for profile/knowledge update.")
            # Perform user profile analysis and knowledge extraction separately
            # First call: User profile analysis
            new_user_profile_text = gpt_user_profile_analysis(unanalyzed_pages, self.client, model=self.llm_model)
            # Second call: Knowledge extraction (user private data and assistant knowledge)
            knowledge_result = gpt_knowledge_extraction(unanalyzed_pages, self.client, model=self.llm_model)
            new_user_private_knowledge = knowledge_result.get("private")
            new_assistant_knowledge = knowledge_result.get("assistant_knowledge")
            # Update User Profile in user's LTM
            if new_user_profile_text and new_user_profile_text.lower() != "none":
                old_profile = self.user_long_term_memory.get_raw_user_profile(self.user_id)
                if old_profile and old_profile.lower() != "none":
                    updated_profile = gpt_update_profile(old_profile, new_user_profile_text, self.client, model=self.llm_model)
                else:
                    updated_profile = new_user_profile_text
                self.user_long_term_memory.update_user_profile(self.user_id, updated_profile, merge=False)  # Don't merge, replace with latest
            # Add User Private Knowledge to user's LTM
            if new_user_private_knowledge and new_user_private_knowledge.lower() != "none":
                for line in new_user_private_knowledge.split('\n'):
                    if line.strip() and line.strip().lower() not in ["none", "- none", "- none."]:
                        self.user_long_term_memory.add_user_knowledge(line.strip())
            # Add Assistant Knowledge to assistant's LTM
            if new_assistant_knowledge and new_assistant_knowledge.lower() != "none":
                for line in new_assistant_knowledge.split('\n'):
                    if line.strip() and line.strip().lower() not in ["none", "- none", "- none."]:
                        self.assistant_long_term_memory.add_assistant_knowledge(line.strip())  # Save to dedicated assistant LTM
            # Mark pages as analyzed and reset session heat contributors
            for p in session["details"]:
                p["analyzed"] = True  # Mark all pages in session, or just unanalyzed_pages?
                # Original code marked all pages in session
            session["N_visit"] = 0  # Reset visits after analysis
            session["L_interaction"] = 0  # Reset interaction length contribution
            # session["R_recency"] = 1.0 # Recency will re-calculate naturally
            session["H_segment"] = compute_segment_heat(session)  # Recompute heat with reset factors
            session["last_visit_time"] = get_timestamp()  # Update last visit time
            self.mid_term_memory.rebuild_heap()  # Heap needs rebuild due to H_segment change
            self.mid_term_memory.save()
            print(f"Memoryos: Profile/Knowledge update for session {sid} complete. Heat reset.")
        else:
            print(f"Memoryos: Hot session {sid} has no unanalyzed pages. Skipping profile update.")
    else:
        # print(f"Memoryos: Top session {sid} heat ({current_heat:.2f}) below threshold. No profile update.")
        pass  # No action if below threshold
The function's job is stated in its own docstring:
Checks mid-term memory for hot segments and triggers profile/knowledge update if threshold is met.
- First it checks whether the current heat, current_heat, exceeds the threshold; if so, the update begins.
- The update targets are the unanalyzed_pages: the relevant segments are analyzed by gpt_user_profile_analysis and gpt_knowledge_extraction. Both functions essentially call the LLM — prompt assembly plus result parsing, nothing remarkable.
def gpt_user_profile_analysis(dialogs, client: OpenAIClient, model="gpt-4o-mini", known_user_traits="None"):
    """Analyze user personality profile from dialogs"""
    conversation = "\n".join([f"User: {d.get('user_input','')} (Timestamp: {d.get('timestamp', '')})\nAssistant: {d.get('agent_response','')} (Timestamp: {d.get('timestamp', '')})" for d in dialogs])
    messages = [
        {"role": "system", "content": prompts.PERSONALITY_ANALYSIS_SYSTEM_PROMPT},
        {"role": "user", "content": prompts.PERSONALITY_ANALYSIS_USER_PROMPT.format(
            conversation=conversation,
            known_user_traits=known_user_traits
        )}
    ]
    print("Calling LLM for user profile analysis...")
    result_text = client.chat_completion(model=model, messages=messages)
    return result_text.strip() if result_text else "None"

def gpt_knowledge_extraction(dialogs, client: OpenAIClient, model="gpt-4o-mini"):
    """Extract user private data and assistant knowledge from dialogs"""
    conversation = "\n".join([f"User: {d.get('user_input','')} (Timestamp: {d.get('timestamp', '')})\nAssistant: {d.get('agent_response','')} (Timestamp: {d.get('timestamp', '')})" for d in dialogs])
    messages = [
        {"role": "system", "content": prompts.KNOWLEDGE_EXTRACTION_SYSTEM_PROMPT},
        {"role": "user", "content": prompts.KNOWLEDGE_EXTRACTION_USER_PROMPT.format(
            conversation=conversation
        )}
    ]
    print("Calling LLM for knowledge extraction...")
    result_text = client.chat_completion(model=model, messages=messages)
    private_data = "None"
    assistant_knowledge = "None"
    try:
        if "【User Private Data】" in result_text:
            private_data_start = result_text.find("【User Private Data】") + len("【User Private Data】")
            if "【Assistant Knowledge】" in result_text:
                private_data_end = result_text.find("【Assistant Knowledge】")
                private_data = result_text[private_data_start:private_data_end].strip()
                assistant_knowledge_start = result_text.find("【Assistant Knowledge】") + len("【Assistant Knowledge】")
                assistant_knowledge = result_text[assistant_knowledge_start:].strip()
            else:
                private_data = result_text[private_data_start:].strip()
        elif "【Assistant Knowledge】" in result_text:
            assistant_knowledge_start = result_text.find("【Assistant Knowledge】") + len("【Assistant Knowledge】")
            assistant_knowledge = result_text[assistant_knowledge_start:].strip()
    except Exception as e:
        print(f"Error parsing knowledge extraction: {e}. Raw result: {result_text}")
    return {
        "private": private_data if private_data else "None",
        "assistant_knowledge": assistant_knowledge if assistant_knowledge else "None"
    }
- Then comes the update, covering the profile and the knowledge stores: update_user_profile, add_user_knowledge, plus add_assistant_knowledge for the assistant side. All three are implemented inside long_term (which does feel a bit scattered — this part of the design could be cleaner). The three functions are simple, so I've pulled them over here: they are essentially information storage, computing whatever embeddings are needed and placing them in the dict.
def update_user_profile(self, user_id, new_data, merge=True):
    if merge and user_id in self.user_profiles and self.user_profiles[user_id].get("data"):  # Check if data exists
        current_data = self.user_profiles[user_id]["data"]
        if isinstance(current_data, str) and isinstance(new_data, str):
            updated_data = f"{current_data}\n\n--- Updated on {get_timestamp()} ---\n{new_data}"
        else:  # Fallback to overwrite if types are not strings or for more complex merge
            updated_data = new_data
    else:
        # If merge=False or no existing data, replace with new data
        updated_data = new_data
    self.user_profiles[user_id] = {
        "data": updated_data,
        "last_updated": get_timestamp()
    }
    print(f"LongTermMemory: Updated user profile for {user_id} (merge={merge}).")
    self.save()

def add_user_knowledge(self, knowledge_text):
    self.add_knowledge_entry(knowledge_text, self.knowledge_base, "user knowledge")

def add_assistant_knowledge(self, knowledge_text):
    self.add_knowledge_entry(knowledge_text, self.assistant_knowledge, "assistant knowledge")

def add_knowledge_entry(self, knowledge_text, knowledge_deque: deque, type_name="knowledge"):
    if not knowledge_text or knowledge_text.strip().lower() in ["", "none", "- none", "- none."]:
        print(f"LongTermMemory: Empty {type_name} received, not saving.")
        return
    # If deque is full, the oldest item is automatically removed when appending.
    vec = get_embedding(knowledge_text)
    vec = normalize_vector(vec).tolist()
    entry = {
        "knowledge": knowledge_text,
        "timestamp": get_timestamp(),
        "knowledge_embedding": vec
    }
    knowledge_deque.append(entry)
    print(f"LongTermMemory: Added {type_name}. Current count: {len(knowledge_deque)}.")
    self.save()
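The comment about the oldest item being removed automatically relies on the semantics of deque(maxlen=...), which are worth a quick demonstration:

```python
from collections import deque

kb = deque(maxlen=2)  # capacity-bounded knowledge store
kb.append({"knowledge": "fact 1"})
kb.append({"knowledge": "fact 2"})
kb.append({"knowledge": "fact 3"})  # oldest entry is silently evicted
assert [e["knowledge"] for e in kb] == ["fact 2", "fact 3"]
```

So the long-term knowledge capacity is enforced passively by the container itself — a plain append, no explicit eviction code.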
Memory retrieval and updating
The author keeps memory querying and updating in dedicated modules: the upper layer's flow decides when to trigger, the lower layer's stores define how content is laid out, and this middle layer — query and update — decides what to search and what to write. It lives in two files, retriever.py and updater.py.
retriever.py
Retrieval functionality shows up in many places, but only the essential piece is pulled out here: retrieve_context. The core is retriever.retrieve_context in retriever.py — let's see exactly what it searches.
def retrieve_context(self, user_query: str,
                     user_id: str,  # Needed for profile, can be used for context filtering if desired
                     segment_similarity_threshold=0.1,  # From main_memoybank example
                     page_similarity_threshold=0.1,  # From main_memoybank example
                     knowledge_threshold=0.01,  # From main_memoybank example
                     top_k_sessions=5,  # From MidTermMemory search default
                     top_k_knowledge=20  # Default for knowledge search
                     ):
    print(f"Retriever: Starting retrieval for query: '{user_query[:50]}...'")
    # 1. Retrieve from Mid-Term Memory
    # MidTermMemory.search_sessions now takes client for its internal keyword extraction
    # It also returns a more structured result including scores.
    matched_sessions = self.mid_term_memory.search_sessions(
        query_text=user_query,
        segment_similarity_threshold=segment_similarity_threshold,
        page_similarity_threshold=page_similarity_threshold,
        top_k_sessions=top_k_sessions
    )
    # Use a heap to get top N pages across all relevant sessions based on their scores
    top_pages_heap = []
    page_counter = 0  # Add counter to ensure unique comparison
    for session_match in matched_sessions:
        for page_match in session_match.get("matched_pages", []):
            page_data = page_match["page_data"]
            page_score = page_match["score"]  # Using the page score directly
            # Add session relevance score to page score or combine them?
            # For now, using page_score. Could be: page_score * session_match["session_relevance_score"]
            combined_score = page_score  # Potentially adjust with session_relevance_score
            if len(top_pages_heap) < self.retrieval_queue_capacity:
                heapq.heappush(top_pages_heap, (combined_score, page_counter, page_data))
                page_counter += 1
            elif combined_score > top_pages_heap[0][0]:  # If current page is better than the worst in heap
                heapq.heappop(top_pages_heap)
                heapq.heappush(top_pages_heap, (combined_score, page_counter, page_data))
                page_counter += 1
    # Extract pages from heap, already sorted by heapq property (smallest first)
    # We want highest scores, so either use a max-heap or sort after popping from min-heap.
    retrieved_mid_term_pages = [item[2] for item in sorted(top_pages_heap, key=lambda x: x[0], reverse=True)]
    print(f"Retriever: Mid-term memory recalled {len(retrieved_mid_term_pages)} pages.")
    # 2. Retrieve from Long-Term User Knowledge (specific to the user)
    # Assuming LongTermMemory for a user stores their specific knowledge/private data.
    # The main LongTermMemory class in `long_term.py` has `search_user_knowledge` which doesn't need user_id as it's implicit in the instance
    # However, if a single LTM instance handles multiple users, it would need user_id.
    # For the Memoryos class, LTM will be user-specific or assistant-specific.
    retrieved_user_knowledge = self.long_term_memory.search_user_knowledge(
        user_query, threshold=knowledge_threshold, top_k=top_k_knowledge
    )
    print(f"Retriever: Long-term user knowledge recalled {len(retrieved_user_knowledge)} items.")
    # 3. Retrieve from Long-Term Assistant Knowledge (general for the assistant)
    # This requires a separate LTM instance or a method in LTM that queries a different knowledge base.
    # In our Memoryos structure, there will be a separate LTM for assistant.
    # For now, assuming self.long_term_memory is the USER's LTM.
    # The Memoryos class will handle passing the correct LTM instance for assistant knowledge.
    # This function will just return what it can from the provided LTM.
    # If assistant_ltm is passed, it can be used: self.assistant_long_term_memory.search_assistant_knowledge(...)
    retrieved_assistant_knowledge = []
    if self.assistant_long_term_memory:
        retrieved_assistant_knowledge = self.assistant_long_term_memory.search_assistant_knowledge(
            user_query, threshold=knowledge_threshold, top_k=top_k_knowledge
        )
        print(f"Retriever: Long-term assistant knowledge recalled {len(retrieved_assistant_knowledge)} items.")
    else:
        print("Retriever: No assistant long-term memory provided, skipping assistant knowledge retrieval.")
    return {
        "retrieved_pages": retrieved_mid_term_pages,  # List of page dicts
        "retrieved_user_knowledge": retrieved_user_knowledge,  # List of knowledge entry dicts
        "retrieved_assistant_knowledge": retrieved_assistant_knowledge,  # List of assistant knowledge entry dicts
        "retrieved_at": get_timestamp()
    }
The comments here are again excellent. To summarize:
- First, mid_term_memory is searched for the mid-term content; after the search, a heap is used to extract the top-N pages by score.
- Then user and assistant information is pulled from long_term_memory — both stores are searched in the same pass here.
Note that the user knowledge searched in long_term_memory is not keyed to the user's identity — note this comment:
The main LongTermMemory class in long_term.py has search_user_knowledge which doesn't need user_id as it's implicit in the instance
What is retrieved is knowledge similar to the query, for both the user and the assistant stores: the input is the query, not the user id.
```python
retrieved_user_knowledge = self.long_term_memory.search_user_knowledge(
    user_query, threshold=knowledge_threshold, top_k=top_k_knowledge
)
###############
retrieved_assistant_knowledge = self.assistant_long_term_memory.search_assistant_knowledge(
    user_query, threshold=knowledge_threshold, top_k=top_k_knowledge
)
```
updater.py
The updater implements the logic that consolidates short-term memory and updates it into mid-term memory. It also runs heat analysis on mid-term segments; once a segment's heat crosses a threshold, it is further analyzed and used to update long-term memory.
This should sound familiar: it was mentioned in `add_memory`, and the key function `process_short_term_to_mid_term` discussed above is implemented here.
```python
def process_short_term_to_mid_term(self):
    evicted_qas = []
    while self.short_term_memory.is_full():
        qa = self.short_term_memory.pop_oldest()
        if qa and qa.get("user_input") and qa.get("agent_response"):
            evicted_qas.append(qa)
    if not evicted_qas:
        print("Updater: No QAs evicted from short-term memory.")
        return
    print(f"Updater: Processing {len(evicted_qas)} QAs from short-term to mid-term.")

    # 1. Create page structures and handle continuity within the evicted batch
    current_batch_pages = []
    temp_last_page_in_batch = self.last_evicted_page_for_continuity  # Carry over from previous batch if any
    for qa_pair in evicted_qas:
        current_page_obj = {
            "page_id": generate_id("page"),
            "user_input": qa_pair.get("user_input", ""),
            "agent_response": qa_pair.get("agent_response", ""),
            "timestamp": qa_pair.get("timestamp", get_timestamp()),
            "preloaded": False,  # Default for new pages from short-term
            "analyzed": False,   # Default for new pages from short-term
            "pre_page": None,
            "next_page": None,
            "meta_info": None
        }
        is_continuous = check_conversation_continuity(temp_last_page_in_batch, current_page_obj, self.client, model=self.llm_model)
        if is_continuous and temp_last_page_in_batch:
            current_page_obj["pre_page"] = temp_last_page_in_batch["page_id"]
            # The actual next_page for temp_last_page_in_batch will be set when it's stored in mid-term
            # or if it's already there, it needs an update. This linking is tricky.
            # For now, we establish the link from current to previous.
            # MidTermMemory's update_page_connections can fix the other side if pages are already there.

            # Meta info generation based on continuity
            last_meta = temp_last_page_in_batch.get("meta_info")
            new_meta = generate_page_meta_info(last_meta, current_page_obj, self.client, model=self.llm_model)
            current_page_obj["meta_info"] = new_meta
            # If temp_last_page_in_batch was part of a chain, its meta_info and subsequent ones should update.
            # This implies that meta_info should perhaps be updated more globally or propagated.
            # For now, new_meta applies to current_page_obj and potentially its chain.
            # We can call _update_linked_pages_meta_info if temp_last_page_in_batch is in mid-term already.
            if temp_last_page_in_batch.get("page_id") and self.mid_term_memory.get_page_by_id(temp_last_page_in_batch["page_id"]):
                self._update_linked_pages_meta_info(temp_last_page_in_batch["page_id"], new_meta)
        else:
            # Start of a new chain or no previous page
            current_page_obj["meta_info"] = generate_page_meta_info(None, current_page_obj, self.client, model=self.llm_model)

        current_batch_pages.append(current_page_obj)
        temp_last_page_in_batch = current_page_obj  # Update for the next iteration in this batch

    # Update the global last evicted page for the next run of this method
    if current_batch_pages:
        self.last_evicted_page_for_continuity = current_batch_pages[-1]

    # 2. Consolidate text from current_batch_pages for multi-summary
    if not current_batch_pages:
        return
    input_text_for_summary = "\n".join([
        f"User: {p.get('user_input','')}\nAssistant: {p.get('agent_response','')}"
        for p in current_batch_pages
    ])
    print("Updater: Generating multi-topic summary for the evicted batch...")
    multi_summary_result = gpt_generate_multi_summary(input_text_for_summary, self.client, model=self.llm_model)

    # 3. Insert pages into MidTermMemory based on summaries
    if multi_summary_result and multi_summary_result.get("summaries"):
        for summary_item in multi_summary_result["summaries"]:
            theme_summary = summary_item.get("content", "General summary of recent interactions.")
            theme_keywords = summary_item.get("keywords", [])
            print(f"Updater: Processing theme '{summary_item.get('theme')}' for mid-term insertion.")
            # Pass the already processed pages (with IDs, embeddings to be added by MidTermMemory if not present)
            self.mid_term_memory.insert_pages_into_session(
                summary_for_new_pages=theme_summary,
                keywords_for_new_pages=theme_keywords,
                pages_to_insert=current_batch_pages,  # These pages now have pre_page, next_page, meta_info set up
                similarity_threshold=self.topic_similarity_threshold
            )
    else:
        # Fallback: if no summaries, add as one session or handle as a single block
        print("Updater: No specific themes from multi-summary. Adding batch as a general session.")
        fallback_summary = "General conversation segment from short-term memory."
        fallback_keywords = llm_extract_keywords(input_text_for_summary, self.client, model=self.llm_model) if input_text_for_summary else []
        self.mid_term_memory.insert_pages_into_session(
            summary_for_new_pages=fallback_summary,
            keywords_for_new_pages=list(fallback_keywords),
            pages_to_insert=current_batch_pages,
            similarity_threshold=self.topic_similarity_threshold
        )

    # After pages are in mid-term, ensure their connections are doubly linked if needed.
    # MidTermMemory.insert_pages_into_session should ideally handle this internally
    # or we might need a separate pass to solidify connections after all insertions.
    for page in current_batch_pages:
        if page.get("pre_page"):
            self.mid_term_memory.update_page_connections(page["pre_page"], page["page_id"])
        if page.get("next_page"):
            self.mid_term_memory.update_page_connections(page["page_id"], page["next_page"])  # This seems redundant if next is set by prior

    if current_batch_pages:  # Save if any pages were processed
        self.mid_term_memory.save()
```
- First, check whether short-term memory is full; if so, entries are popped out for processing and collected in `evicted_qas`.
- Then judge conversation continuity with `check_conversation_continuity`, which is also done by the LLM. (The code is fairly repetitive, so it is not reproduced here.)
- If the conversation is continuous, the new page is linked to the most recent page still in the chain (`temp_last_page_in_batch`), and the chain's meta information is updated via `generate_page_meta_info` (again an LLM call). If not continuous, the page simply generates its own meta info with `generate_page_meta_info`.
- The batch is then consolidated into a summary with `gpt_generate_multi_summary`, also done by the LLM.
- The organized pages are then inserted into a session; this is the step that writes into mid-term memory: `insert_pages_into_session`.
- Finally, the transition links between pages are updated with `update_page_connections`.
Back to `insert_pages_into_session`: the insertion itself is implemented in `mid_term.py`, but since this is the only place it comes up, the code is included here. (It is long.)
```python
def insert_pages_into_session(self, summary_for_new_pages, keywords_for_new_pages, pages_to_insert,
                              similarity_threshold=0.6, keyword_similarity_alpha=1.0):
    if not self.sessions:  # If no existing sessions, just add as a new one
        print("MidTermMemory: No existing sessions. Adding new session directly.")
        return self.add_session(summary_for_new_pages, pages_to_insert)

    new_summary_vec = get_embedding(summary_for_new_pages)
    new_summary_vec = normalize_vector(new_summary_vec)

    best_sid = None
    best_overall_score = -1
    for sid, existing_session in self.sessions.items():
        existing_summary_vec = np.array(existing_session["summary_embedding"], dtype=np.float32)
        semantic_sim = float(np.dot(existing_summary_vec, new_summary_vec))

        # Keyword similarity (Jaccard index based)
        existing_keywords = set(existing_session.get("summary_keywords", []))
        new_keywords_set = set(keywords_for_new_pages)
        s_topic_keywords = 0
        if existing_keywords and new_keywords_set:
            intersection = len(existing_keywords.intersection(new_keywords_set))
            union = len(existing_keywords.union(new_keywords_set))
            if union > 0:
                s_topic_keywords = intersection / union

        overall_score = semantic_sim + keyword_similarity_alpha * s_topic_keywords
        if overall_score > best_overall_score:
            best_overall_score = overall_score
            best_sid = sid

    if best_sid and best_overall_score >= similarity_threshold:
        print(f"MidTermMemory: Merging pages into session {best_sid}. Score: {best_overall_score:.2f} (Threshold: {similarity_threshold})")
        target_session = self.sessions[best_sid]
        processed_new_pages = []
        for page_data in pages_to_insert:
            page_id = page_data.get("page_id", generate_id("page"))  # Use existing or generate new ID
            full_text = f"User: {page_data.get('user_input','')} Assistant: {page_data.get('agent_response','')}"
            inp_vec = get_embedding(full_text)
            inp_vec = normalize_vector(inp_vec).tolist()
            page_keywords_current = list(llm_extract_keywords(full_text, client=self.client))
            processed_page = {
                **page_data,  # Carry over existing fields
                "page_id": page_id,
                "page_embedding": inp_vec,
                "page_keywords": page_keywords_current,
                # analyzed, preloaded flags should be part of page_data if set
            }
            target_session["details"].append(processed_page)
            processed_new_pages.append(processed_page)

        target_session["L_interaction"] += len(pages_to_insert)
        target_session["last_visit_time"] = get_timestamp()  # Update last visit time on modification
        target_session["H_segment"] = compute_segment_heat(target_session)
        self.rebuild_heap()  # Rebuild heap as heat has changed
        self.save()
        return best_sid
    else:
        print(f"MidTermMemory: No suitable session to merge (best score {best_overall_score:.2f} < threshold {similarity_threshold}). Creating new session.")
        return self.add_session(summary_for_new_pages, pages_to_insert)
```
The logic here in brief:

- Compute the similarity between the incoming content and every known session; find the closest session whose score exceeds the threshold, insert each `page_data` into it one by one, and update the corresponding fields.
- If no session scores above the threshold, create a new session instead.
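The merge decision above can be reproduced in isolation. A small worked example (all numbers invented for illustration) combining the dot product of normalized summary vectors with the Jaccard keyword overlap, as `insert_pages_into_session` does:

```python
import numpy as np

def session_match_score(vec_a, vec_b, kw_a, kw_b, keyword_similarity_alpha=1.0):
    # Both vectors are assumed L2-normalized, so the dot product is cosine similarity.
    semantic_sim = float(np.dot(vec_a, vec_b))
    # Jaccard index over keyword sets, mirroring s_topic_keywords in the project.
    union = kw_a | kw_b
    jaccard = len(kw_a & kw_b) / len(union) if union else 0.0
    return semantic_sim + keyword_similarity_alpha * jaccard

a = np.array([1.0, 0.0])
b = np.array([0.6, 0.8])  # both already unit length
score = session_match_score(a, b, {"memory", "llm"}, {"memory", "faiss"})
print(score)  # 0.6 (cosine) + 1/3 (keyword Jaccard), about 0.933
```

With the default threshold of 0.6, this hypothetical pair would merge into the existing session; note that because the two terms are simply added, a strong keyword overlap can push a weak semantic match over the line.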
The memory system
With the overall inference flow understood, the working details inside the memory system become easier to follow.
A reminder: for any memory system, focus on two operations, storage and retrieval: when new information arrives, how each module stores it, and how it is later extracted.
short_term.py
Short-term memory is the simplest part here: it mainly just records the content of each turn.
```python
class ShortTermMemory:
    def __init__(self, file_path, max_capacity=10):
        self.max_capacity = max_capacity
        self.file_path = file_path
        ensure_directory_exists(self.file_path)
        self.memory = deque(maxlen=max_capacity)
        self.load()

    def add_qa_pair(self, qa_pair):
        # Ensure timestamp exists, add if not
        if 'timestamp' not in qa_pair or not qa_pair['timestamp']:
            qa_pair["timestamp"] = get_timestamp()
        self.memory.append(qa_pair)
        print(f"ShortTermMemory: Added QA. User: {qa_pair.get('user_input','')[:30]}...")
        self.save()

    def get_all(self):
        return list(self.memory)

    def is_full(self):
        return len(self.memory) >= self.max_capacity  # Use >= to be safe

    def pop_oldest(self):
        if self.memory:
            msg = self.memory.popleft()
            print("ShortTermMemory: Evicted oldest QA pair.")
            self.save()
            return msg
        return None

    def save(self):
        try:
            with open(self.file_path, "w", encoding="utf-8") as f:
                json.dump(list(self.memory), f, ensure_ascii=False, indent=2)
        except IOError as e:
            print(f"Error saving ShortTermMemory to {self.file_path}: {e}")

    def load(self):
        try:
            with open(self.file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            # Ensure items are loaded correctly, especially if file was empty or malformed
            if isinstance(data, list):
                self.memory = deque(data, maxlen=self.max_capacity)
            else:
                self.memory = deque(maxlen=self.max_capacity)
            print(f"ShortTermMemory: Loaded from {self.file_path}.")
        except FileNotFoundError:
            self.memory = deque(maxlen=self.max_capacity)
            print(f"ShortTermMemory: No history file found at {self.file_path}. Initializing new memory.")
        except json.JSONDecodeError:
            self.memory = deque(maxlen=self.max_capacity)
            print(f"ShortTermMemory: Error decoding JSON from {self.file_path}. Initializing new memory.")
        except Exception as e:
            self.memory = deque(maxlen=self.max_capacity)
            print(f"ShortTermMemory: An unexpected error occurred during load from {self.file_path}: {e}. Initializing new memory.")
```
A few points worth noting:

- As mentioned earlier, user information here is essentially stored statically in a local file as JSON; `load` and `save` are plain file operations. This is fine for a demo but not for production. In production, middleware such as Redis is a better choice, especially under concurrency, where writing to local files easily leads to corruption.
- Memory is stored as a `list<dict>`, as `add_qa_pair` shows. But the list is not a plain one: it is a `deque` following first-in-first-out order, with a `max_capacity` beyond which the earliest content is evicted.
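The FIFO eviction of a bounded `deque` described above can be demonstrated in a few lines (a standalone illustration, not project code):

```python
from collections import deque

memory = deque(maxlen=3)  # mirrors max_capacity in ShortTermMemory
for turn in ["qa1", "qa2", "qa3", "qa4"]:
    memory.append(turn)

# "qa1" has been evicted automatically: oldest out first.
print(list(memory))  # ['qa2', 'qa3', 'qa4']
print(memory.popleft())  # 'qa2', the same operation pop_oldest() uses
```

One subtlety: because `deque(maxlen=...)` silently drops the oldest item on `append`, the project's `is_full()` check must fire before the deque overflows, otherwise evicted turns would be lost without ever reaching mid-term memory.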
mid_term.py
Mid-term memory is where things get more complex; again, start from the key function.
The key function here is `search_sessions`, which is called from `retrieve_context` in `Retriever`. Given a query, it finds the most similar sessions.
```python
def search_sessions(self, query_text, segment_similarity_threshold=0.1, page_similarity_threshold=0.1,
                    top_k_sessions=5, keyword_alpha=1.0, recency_tau_search=3600):
    if not self.sessions:
        return []

    query_vec = get_embedding(query_text)
    query_vec = normalize_vector(query_vec)
    query_keywords = set(llm_extract_keywords(query_text, client=self.client))

    candidate_sessions = []
    session_ids = list(self.sessions.keys())
    if not session_ids:
        return []

    summary_embeddings_list = [self.sessions[s]["summary_embedding"] for s in session_ids]
    summary_embeddings_np = np.array(summary_embeddings_list, dtype=np.float32)
    dim = summary_embeddings_np.shape[1]
    index = faiss.IndexFlatIP(dim)  # Inner product for similarity
    index.add(summary_embeddings_np)

    query_arr_np = np.array([query_vec], dtype=np.float32)
    distances, indices = index.search(query_arr_np, min(top_k_sessions, len(session_ids)))

    results = []
    current_time_str = get_timestamp()
    for i, idx in enumerate(indices[0]):
        if idx == -1:
            continue
        session_id = session_ids[idx]
        session = self.sessions[session_id]
        semantic_sim_score = float(distances[0][i])  # This is the dot product

        # Keyword similarity for session summary
        session_keywords = set(session.get("summary_keywords", []))
        s_topic_keywords = 0
        if query_keywords and session_keywords:
            intersection = len(query_keywords.intersection(session_keywords))
            union = len(query_keywords.union(session_keywords))
            if union > 0:
                s_topic_keywords = intersection / union

        # Time decay for session recency in search scoring
        # time_decay_factor = compute_time_decay(session["timestamp"], current_time_str, tau_hours=recency_tau_search)

        # Combined score for session relevance
        session_relevance_score = (semantic_sim_score + keyword_alpha * s_topic_keywords)

        if session_relevance_score >= segment_similarity_threshold:
            matched_pages_in_session = []
            for page in session.get("details", []):
                page_embedding = np.array(page["page_embedding"], dtype=np.float32)
                # page_keywords = set(page.get("page_keywords", []))
                page_sim_score = float(np.dot(page_embedding, query_vec))
                # Can also add keyword sim for pages if needed, but keeping it simpler for now
                if page_sim_score >= page_similarity_threshold:
                    matched_pages_in_session.append({"page_data": page, "score": page_sim_score})

            if matched_pages_in_session:
                # Update session access stats
                session["N_visit"] += 1
                session["last_visit_time"] = current_time_str
                session["access_count_lfu"] = session.get("access_count_lfu", 0) + 1
                self.access_frequency[session_id] = session["access_count_lfu"]
                session["H_segment"] = compute_segment_heat(session)
                self.rebuild_heap()  # Heat changed

                results.append({
                    "session_id": session_id,
                    "session_summary": session["summary"],
                    "session_relevance_score": session_relevance_score,
                    "matched_pages": sorted(matched_pages_in_session, key=lambda x: x["score"], reverse=True)  # Sort pages by score
                })

    self.save()  # Save changes from access updates
    # Sort final results by session_relevance_score
    return sorted(results, key=lambda x: x["session_relevance_score"], reverse=True)
```
`get_embedding` is the function that produces the embedding vector; the embedding model comes from `sentence_transformers`. There is a small design detail here: lazy loading, meaning the model is only loaded the first time it is used. Also, instead of the common pattern of a class holding the model plus an `encode` method, the implementation maintains the model in a module-level global dict.
```python
_model_cache = {}

def get_embedding(text, model_name="all-MiniLM-L6-v2"):
    if model_name not in _model_cache:
        print(f"Loading sentence transformer model: {model_name}")
        _model_cache[model_name] = SentenceTransformer(model_name)
    model = _model_cache[model_name]
    embedding = model.encode([text], convert_to_numpy=True)[0]
    return embedding
```
`llm_extract_keywords` asks the LLM to extract keywords; the extracted keywords are later used in the keyword-similarity computation.
```python
# Keyword extraction (via the LLM)
def llm_extract_keywords(text, client: OpenAIClient, model="gpt-4o-mini"):
    messages = [
        {"role": "system", "content": prompts.EXTRACT_KEYWORDS_SYSTEM_PROMPT},
        {"role": "user", "content": prompts.EXTRACT_KEYWORDS_USER_PROMPT.format(text=text)}
    ]
    print("Calling LLM to extract keywords...")
    response = client.chat_completion(model=model, messages=messages)
    return [kw.strip() for kw in response.split(',') if kw.strip()]

# Similarity computation
s_topic_keywords = 0
if query_keywords and session_keywords:
    intersection = len(query_keywords.intersection(session_keywords))
    union = len(query_keywords.union(session_keywords))
    if union > 0:
        s_topic_keywords = intersection / union
```
- Session recall builds a faiss index over all sessions and then searches it, and the index is a temporary one. This actually feels a bit awkward to me: if `summary_embeddings_list` is long, the search itself is fast, but `add` takes a while; conversely, if it is short, searching is hardly faster than pairwise matching and `add` buys nothing. And judging from `self.sessions`, all sessions are kept globally, which puts us in the former case: building a fresh index on every call and throwing it away seems wasteful.
```python
summary_embeddings_list = [self.sessions[s]["summary_embedding"] for s in session_ids]
summary_embeddings_np = np.array(summary_embeddings_list, dtype=np.float32)
dim = summary_embeddings_np.shape[1]
index = faiss.IndexFlatIP(dim)  # Inner product for similarity
index.add(summary_embeddings_np)
query_arr_np = np.array([query_vec], dtype=np.float32)
distances, indices = index.search(query_arr_np, min(top_k_sessions, len(session_ids)))
```
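Given the critique above, one alternative at this scale is a brute-force numpy scan, which gives the same top-k ordering as a throwaway `IndexFlatIP` without constructing an index on every call (a sketch, not the project's implementation):

```python
import numpy as np

def top_k_inner_product(query_vec, embeddings, k):
    # embeddings: (n, dim) array; query_vec: (dim,) array. The scores are plain
    # dot products, equivalent to searching faiss.IndexFlatIP over the same vectors.
    scores = embeddings @ query_vec
    order = np.argsort(-scores)[:k]  # indices of the k highest scores
    return order, scores[order]

emb = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]], dtype=np.float32)
idx, scores = top_k_inner_product(np.array([1.0, 0.0], dtype=np.float32), emb, 2)
print(idx.tolist())  # [0, 2]
```

For a few hundred sessions the matrix-vector product is effectively free; faiss only starts paying off once the corpus is large enough to justify keeping a persistent index.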
- The relevance score is a weighted sum of lexical (keyword) similarity and vector similarity.

```python
session_relevance_score = (semantic_sim_score + keyword_alpha * s_topic_keywords)
```
- Besides similarity against `summary_embedding`, similarity against `page_embedding` is also computed. The `page_embedding` corresponds to the page in the paper, emphasizing topical segments.

```python
page_embedding = np.array(page["page_embedding"], dtype=np.float32)
```
- A series of bookkeeping fields is then updated: visit count, last visit time, the LFU (Least Frequently Used) counter, and the heat score.

```python
session["N_visit"] += 1
session["last_visit_time"] = current_time_str
session["access_count_lfu"] = session.get("access_count_lfu", 0) + 1
self.access_frequency[session_id] = session["access_count_lfu"]
session["H_segment"] = compute_segment_heat(session)
self.rebuild_heap()  # Heat changed
```
```python
# compute_segment_heat
def compute_segment_heat(session, alpha=HEAT_ALPHA, beta=HEAT_BETA, gamma=HEAT_GAMMA, tau_hours=RECENCY_TAU_HOURS):
    N_visit = session.get("N_visit", 0)
    L_interaction = session.get("L_interaction", 0)

    # Calculate recency based on last_visit_time
    R_recency = 1.0  # Default if no last_visit_time
    if session.get("last_visit_time"):
        R_recency = compute_time_decay(session["last_visit_time"], get_timestamp(), tau_hours)
    session["R_recency"] = R_recency  # Update session's recency factor

    return alpha * N_visit + beta * L_interaction + gamma * R_recency
```
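`compute_time_decay` itself is not shown above. A plausible implementation, consistent with the `tau_hours` parameter, is exponential decay over the elapsed hours; both the exact formula and the timestamp format below are my assumptions, not the project's code:

```python
import math
from datetime import datetime

def compute_time_decay_sketch(earlier_ts, later_ts, tau_hours=24.0, fmt="%Y-%m-%d %H:%M:%S"):
    # Decays from 1.0 toward 0 as the gap between the two timestamps grows,
    # with time constant tau_hours. Exponential form is an assumption.
    delta = datetime.strptime(later_ts, fmt) - datetime.strptime(earlier_ts, fmt)
    hours = delta.total_seconds() / 3600.0
    return math.exp(-hours / tau_hours)

r = compute_time_decay_sketch("2024-01-01 00:00:00", "2024-01-02 00:00:00", tau_hours=24.0)
print(round(r, 4))  # one full tau elapsed, so the factor is e^-1, about 0.3679
```

Under this shape, a session untouched for one `tau_hours` keeps about 37% of its recency weight, which is then blended into the heat via `gamma`.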
- There is also a `rebuild_heap` function, which supports bottom-ranked eviction based on the heat score `H_segment`.
```python
def rebuild_heap(self):
    self.heap = []
    for sid, session_data in self.sessions.items():
        # Ensure H_segment is up-to-date before rebuilding heap if necessary
        # session_data["H_segment"] = compute_segment_heat(session_data)
        heapq.heappush(self.heap, (-session_data["H_segment"], sid))
    # heapq.heapify(self.heap)  # Not needed if pushing one by one
    # No save here, it's an internal operation often followed by other ops that save
```
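Since Python's `heapq` is a min-heap, the code stores `-H_segment` so that the structure behaves as a max-heap: the hottest session surfaces first. A minimal standalone illustration:

```python
import heapq

heap = []
for sid, heat in [("s1", 2.5), ("s2", 7.1), ("s3", 4.0)]:
    heapq.heappush(heap, (-heat, sid))  # negate heat so the min-heap acts as a max-heap

neg_heat, sid = heapq.heappop(heap)
print(sid, -neg_heat)  # s2 7.1, the highest-heat session comes out first
```

This negation trick is the standard way to get max-heap behavior from `heapq` without writing a custom comparator.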
With retrieval covered, on to storage. Many of the variables above probably seemed to come from nowhere; this is where they originate, specifically in the `add_session` function.
```python
def add_session(self, summary, details):
    session_id = generate_id("session")
    summary_vec = get_embedding(summary)
    summary_vec = normalize_vector(summary_vec).tolist()
    summary_keywords = list(llm_extract_keywords(summary, client=self.client))

    processed_details = []
    for page_data in details:
        page_id = page_data.get("page_id", generate_id("page"))
        full_text = f"User: {page_data.get('user_input','')} Assistant: {page_data.get('agent_response','')}"
        inp_vec = get_embedding(full_text)
        inp_vec = normalize_vector(inp_vec).tolist()
        page_keywords = list(llm_extract_keywords(full_text, client=self.client))
        processed_page = {
            **page_data,  # Carry over existing fields like user_input, agent_response, timestamp
            "page_id": page_id,
            "page_embedding": inp_vec,
            "page_keywords": page_keywords,
            "preloaded": page_data.get("preloaded", False),  # Preserve if passed
            "analyzed": page_data.get("analyzed", False),  # Preserve if passed
            # pre_page, next_page, meta_info are handled by DynamicUpdater
        }
        processed_details.append(processed_page)

    current_ts = get_timestamp()
    session_obj = {
        "id": session_id,
        "summary": summary,
        "summary_keywords": summary_keywords,
        "summary_embedding": summary_vec,
        "details": processed_details,
        "L_interaction": len(processed_details),
        "R_recency": 1.0,  # Initial recency
        "N_visit": 0,
        "H_segment": 0.0,  # Initial heat, will be computed
        "timestamp": current_ts,  # Creation timestamp
        "last_visit_time": current_ts,  # Also initial last_visit_time for recency calc
        "access_count_lfu": 0  # For LFU eviction policy
    }
    session_obj["H_segment"] = compute_segment_heat(session_obj)
    self.sessions[session_id] = session_obj
    self.access_frequency[session_id] = 0  # Initialize for LFU
    heapq.heappush(self.heap, (-session_obj["H_segment"], session_id))  # Use negative heat for max-heap behavior
    print(f"MidTermMemory: Added new session {session_id}. Initial heat: {session_obj['H_segment']:.2f}.")

    if len(self.sessions) > self.max_capacity:
        self.evict_lfu()
    self.save()
    return session_id
```
Focus on two structures here: `processed_page` and `session_obj`.
```python
processed_page = {
    **page_data,  # Carry over existing fields like user_input, agent_response, timestamp
    "page_id": page_id,
    "page_embedding": inp_vec,
    "page_keywords": page_keywords,
    "preloaded": page_data.get("preloaded", False),  # Preserve if passed
    "analyzed": page_data.get("analyzed", False),  # Preserve if passed
    # pre_page, next_page, meta_info are handled by DynamicUpdater
}

session_obj = {
    "id": session_id,
    "summary": summary,
    "summary_keywords": summary_keywords,
    "summary_embedding": summary_vec,
    "details": processed_details,
    "L_interaction": len(processed_details),
    "R_recency": 1.0,  # Initial recency
    "N_visit": 0,
    "H_segment": 0.0,  # Initial heat, will be computed
    "timestamp": current_ts,  # Creation timestamp
    "last_visit_time": current_ts,  # Also initial last_visit_time for recency calc
    "access_count_lfu": 0  # For LFU eviction policy
}
session_obj["H_segment"] = compute_segment_heat(session_obj)
```
These fields are essentially precomputed and initialized up front.
There is also the `evict_lfu` function, which implements the LFU policy mentioned in the paper: a cache-eviction algorithm based on access frequency, widely used in operating systems and database management.
```python
def evict_lfu(self):
    if not self.access_frequency or not self.sessions:
        return
    lfu_sid = min(self.access_frequency, key=self.access_frequency.get)
    print(f"MidTermMemory: LFU eviction. Session {lfu_sid} has lowest access frequency.")

    if lfu_sid not in self.sessions:
        del self.access_frequency[lfu_sid]  # Clean up access frequency if session already gone
        self.rebuild_heap()
        return

    session_to_delete = self.sessions.pop(lfu_sid)  # Remove from sessions
    del self.access_frequency[lfu_sid]  # Remove from LFU tracking

    # Clean up page connections if this session's pages were linked
    for page in session_to_delete.get("details", []):
        prev_page_id = page.get("pre_page")
        next_page_id = page.get("next_page")
        # If a page from this session was linked to an external page, nullify the external link
        if prev_page_id and not self.get_page_by_id(prev_page_id):  # Check if prev page is still in memory
            # This case should ideally not happen if connections are within sessions or handled carefully
            pass
        if next_page_id and not self.get_page_by_id(next_page_id):
            pass
        # More robustly, one might need to search all other sessions if inter-session linking was allowed
        # For now, assuming internal consistency or that MemoryOS class manages higher-level links

    self.rebuild_heap()
    self.save()
    print(f"MidTermMemory: Evicted session {lfu_sid}.")
```
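Stripped of the bookkeeping, the LFU choice reduces to taking `min` over the access-frequency dict, which a standalone few lines make clear:

```python
access_frequency = {"session_a": 5, "session_b": 1, "session_c": 3}

# Least Frequently Used: evict the session with the smallest access count.
lfu_sid = min(access_frequency, key=access_frequency.get)
print(lfu_sid)  # session_b
del access_frequency[lfu_sid]
print(sorted(access_frequency))  # ['session_a', 'session_c']
```

Note this scan is O(n) per eviction; for a small `max_capacity` that is fine, but a production LFU would usually pair the counts with a frequency-bucketed structure instead.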
long_term.py
Again, start with retrieval. Both the user information and the assistant information mentioned earlier need to be queried, and both are ultimately served by the same function, `_search_knowledge_deque`, as seen here:
```python
def search_user_knowledge(self, query, threshold=0.1, top_k=5):
    results = self._search_knowledge_deque(query, self.knowledge_base, threshold, top_k)
    print(f"LongTermMemory: Searched user knowledge for '{query[:30]}...'. Found {len(results)} matches.")
    return results

def search_assistant_knowledge(self, query, threshold=0.1, top_k=5):
    results = self._search_knowledge_deque(query, self.assistant_knowledge, threshold, top_k)
    print(f"LongTermMemory: Searched assistant knowledge for '{query[:30]}...'. Found {len(results)} matches.")
    return results
```
The internal flow of the function is as follows.
```python
def _search_knowledge_deque(self, query, knowledge_deque: deque, threshold=0.1, top_k=5):
    if not knowledge_deque:
        return []
    query_vec = get_embedding(query)
    query_vec = normalize_vector(query_vec)

    embeddings = []
    valid_entries = []
    for entry in knowledge_deque:
        if "knowledge_embedding" in entry and entry["knowledge_embedding"]:
            embeddings.append(np.array(entry["knowledge_embedding"], dtype=np.float32))
            valid_entries.append(entry)
        else:
            print(f"Warning: Entry without embedding found in knowledge_deque: {entry.get('knowledge','N/A')[:50]}")

    if not embeddings:
        return []

    embeddings_np = np.array(embeddings, dtype=np.float32)
    if embeddings_np.ndim == 1:  # Single item case
        if embeddings_np.shape[0] == 0:
            return []  # Empty embeddings
        embeddings_np = embeddings_np.reshape(1, -1)

    if embeddings_np.shape[0] == 0:  # No valid embeddings
        return []

    dim = embeddings_np.shape[1]
    index = faiss.IndexFlatIP(dim)  # Using Inner Product for similarity
    index.add(embeddings_np)

    query_arr = np.array([query_vec], dtype=np.float32)
    distances, indices = index.search(query_arr, min(top_k, len(valid_entries)))  # Search at most k or length of valid_entries

    results = []
    for i, idx in enumerate(indices[0]):
        if idx != -1:  # faiss returns -1 for no valid index
            similarity_score = float(distances[0][i])  # For IndexFlatIP, distance is the dot product (similarity)
            if similarity_score >= threshold:
                results.append(valid_entries[idx])  # Add the original entry dict

    # Sort by similarity score descending before returning, as faiss might not guarantee order for IP
    results.sort(key=lambda x: float(np.dot(np.array(x["knowledge_embedding"], dtype=np.float32), query_vec)), reverse=True)
    return results
```
- Reading through, this is very similar to the mid-term logic: compute similarity, then recall. The only difference is what the similarity is computed against; here it is the knowledge embedding, `knowledge_embedding`.

Not much more to say here; now look at how storage works.
```python
def add_user_knowledge(self, knowledge_text):
    self.add_knowledge_entry(knowledge_text, self.knowledge_base, "user knowledge")

def add_assistant_knowledge(self, knowledge_text):
    self.add_knowledge_entry(knowledge_text, self.assistant_knowledge, "assistant knowledge")
```
As with retrieval above, both paths go through the same function, `add_knowledge_entry`, distinguished only by a type label: `user knowledge` versus `assistant knowledge`.
```python
def add_knowledge_entry(self, knowledge_text, knowledge_deque: deque, type_name="knowledge"):
    if not knowledge_text or knowledge_text.strip().lower() in ["", "none", "- none", "- none."]:
        print(f"LongTermMemory: Empty {type_name} received, not saving.")
        return
    # If deque is full, the oldest item is automatically removed when appending.
    vec = get_embedding(knowledge_text)
    vec = normalize_vector(vec).tolist()
    entry = {
        "knowledge": knowledge_text,
        "timestamp": get_timestamp(),
        "knowledge_embedding": vec
    }
    knowledge_deque.append(entry)
    print(f"LongTermMemory: Added {type_name}. Current count: {len(knowledge_deque)}.")
    self.save()
```
This is straightforward as well: it mainly stores the given `knowledge_text`, along with its timestamp and embedding.
Afterword
Having read through the whole codebase, the authors' intent comes across clearly: the key idea is layered, dynamic short/mid/long-term memory. The key pieces are fairly clear, showing what the internals actually do and in what detail, which makes the paper easier to understand.
A few further takeaways, summarized here:

- The heat and similarity mechanisms are worth borrowing; heat in particular is very useful in production.
- Vector matching and recall with sentence_transformers plus faiss is indeed a common recipe, and worth learning.
- The layered design is worth absorbing. In many scenarios quantity turns into quality, and much of the information genuinely needs to be distilled before being stored.
- The code comments are quite clear; the author's thinking can be followed through them. Keep this in mind in day-to-day coding.
- Many of the extraction tasks here are delegated to the LLM. Try this where it fits your situation; the prompt formats and phrasing are also worth studying. (The specific prompts are not listed in this article; see the project itself.)
- All storage in this project is maintained with native Python dict types, which keeps things simple but is not really recommended for production; consider combining proper middleware instead.