-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Add SQLAlchemy session backend for conversation history management #1357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The new
The Temporary Solution As a temporary measure to allow the feature to be merged without breaking the build, the following changes have been made:
Next Steps This solution is a trade-off. It prevents the CI from breaking, but it means that the I think the correct long-term fix is to update Let me know what you think. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At a glance, this looks good to me; @rm-openai any thoughts?
@habema could you resolve the conflict? |
@seratch you got it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, this enhancement looks good to me; great work 👍
@habema Thanks for working on this. Overall, this PR looks good to me, but I don't have the bandwidth to do thorough testing with it before merging. I will take a look this week or early next week, but until then, please feel free to make further adjustments and/or adding unit tests. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this PR! Overall, it already looks great, but I have a few minor suggestions. Let me know what you think!
diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py
index 038f845..20f0555 100644
--- a/src/agents/extensions/memory/sqlalchemy_session.py
+++ b/src/agents/extensions/memory/sqlalchemy_session.py
@@ -31,6 +31,7 @@ from sqlalchemy import (
TIMESTAMP,
Column,
ForeignKey,
+ Index,
Integer,
MetaData,
String,
@@ -122,18 +123,10 @@ class SQLAlchemySession(SessionABC):
server_default=sql_text("CURRENT_TIMESTAMP"),
nullable=False,
),
+ Index(f"idx_{messages_table}_session_time", "session_id", "created_at"),
sqlite_autoincrement=True,
)
- # Index for efficient retrieval of messages per session ordered by time
- from sqlalchemy import Index
-
- Index(
- f"idx_{messages_table}_session_time",
- self._messages.c.session_id,
- self._messages.c.created_at,
- )
-
# Async session factory
self._session_factory = async_sessionmaker(
self._engine, expire_on_commit=False
@@ -180,21 +173,28 @@ class SQLAlchemySession(SessionABC):
await conn.run_sync(self._metadata.create_all)
self._create_tables = False # Only create once
+ async def _serialize_message_data(self, item: TResponseInputItem) -> str:
+ return json.dumps(item, separators=(",", ":"))
+
+ async def _deserialize_message_data(self, item: str) -> TResponseInputItem:
+ return json.loads(item)
+
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
await self._ensure_tables()
async with self._session_factory() as sess:
- if limit is None:
+ if limit is not None:
stmt = (
select(self._messages.c.message_data)
.where(self._messages.c.session_id == self.session_id)
- .order_by(self._messages.c.created_at.asc())
+ # pick up the top-n new messages here, then reverse the order below
+ .order_by(self._messages.c.created_at.desc())
+ .limit(limit)
)
else:
stmt = (
select(self._messages.c.message_data)
.where(self._messages.c.session_id == self.session_id)
- .order_by(self._messages.c.created_at.desc())
- .limit(limit)
+ .order_by(self._messages.c.created_at.asc())
)
result = await sess.execute(stmt)
@@ -206,7 +206,7 @@ class SQLAlchemySession(SessionABC):
items: list[TResponseInputItem] = []
for raw in rows:
try:
- items.append(json.loads(raw))
+ items.append(await self._deserialize_message_data(raw))
except json.JSONDecodeError:
# Skip corrupted rows
continue
@@ -220,7 +220,7 @@ class SQLAlchemySession(SessionABC):
payload = [
{
"session_id": self.session_id,
- "message_data": json.dumps(item, separators=(",", ":")),
+ "message_data": await self._serialize_message_data(item),
}
for item in items
]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you check my last two requests?
When we get limited amount of rows, we can get to the point when we get function_call_output row, but truncate function_call which will result in Error getting response: |
@russ-safeturf from my understanding what you mean isnt truncation, but rather retrieval of coupled items that might be separated due to the limit. If thats what you mean, this problem isnt inherent to the SQLAlchemy session implementation. |
Documentation for SQLAlchemy-powered sessions, to be merged after merging and releasing #1357
This PR introduces
SQLAlchemySession
, providing a production-grade session storage backend that can connect to any database supported by SQLAlchemy (e.g., PostgreSQL, MySQL).This implementation is based on the discussion in Issue #1328 and follows the agreed-upon architectural pattern:
src/agents/extensions/memory/
.[sqlalchemy]
extra.A test has been added.
Resolves #1328.