diff --git a/atomicmemory/providers/atomicmemory/async_handle_impl.py b/atomicmemory/providers/atomicmemory/async_handle_impl.py index 9384235..a3af9ca 100644 --- a/atomicmemory/providers/atomicmemory/async_handle_impl.py +++ b/atomicmemory/providers/atomicmemory/async_handle_impl.py @@ -47,6 +47,7 @@ scope_to_fields, scope_to_query_pairs, strip_agent_scope, + strip_read_filters, ) Route = Callable[[str], str] @@ -93,7 +94,7 @@ async def expand(self, refs: list[str], scope: MemoryScope) -> list[AtomicMemory method="POST", json=body, ) - echoed = strip_agent_scope(scope) + echoed = strip_read_filters(scope) return [_to_atomic_memory(m, echoed) for m in raw.get("memories", [])] async def list( @@ -103,7 +104,7 @@ async def list( ) -> AtomicMemoryListResultPage: opts = _coerce_list_options(options) _assert_list_options_scope_compat(scope, opts) - pairs: list[tuple[str, str]] = scope_to_query_pairs(scope) + pairs: list[tuple[str, str]] = scope_to_query_pairs(scope, include_thread=True) if opts.limit is not None: pairs.append(("limit", str(opts.limit))) if opts.offset is not None: @@ -127,14 +128,16 @@ async def list( ) async def get(self, id: str, scope: MemoryScope) -> AtomicMemoryMemory | None: - path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(scope))}") + unfiltered_scope = strip_read_filters(scope) + path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(unfiltered_scope))}") raw = await afetch_json_or_none(self._client, self._http, path) if raw is None: return None - return _to_atomic_memory(raw, strip_agent_scope(scope)) + return _to_atomic_memory(raw, unfiltered_scope) async def delete(self, id: str, scope: MemoryScope) -> None: - path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(scope))}") + unfiltered_scope = strip_read_filters(scope) + path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(unfiltered_scope))}") try: await afetch_void(self._client, self._http, path, method="DELETE") except ProviderError as exc: @@ -152,7 +155,7 @@ async def _post_ingest( ) -> AtomicMemoryIngestResult: assert_scope_allows_visibility(scope, input.visibility) body: dict[str, Any] = { - **scope_to_fields(scope), + **scope_to_fields(scope, include_thread=True), "conversation": input.conversation, "source_site": input.source_site, "source_url": input.source_url or "", @@ -173,7 +176,7 @@ async def _post_search( scope: MemoryScope, ) -> AtomicMemorySearchResultPage: body: dict[str, Any] = { - **scope_to_fields(scope, include_agent_scope=True), + **scope_to_fields(scope, include_agent_scope=True, include_thread=True), "query": request.query, } if request.limit is not None: diff --git a/atomicmemory/providers/atomicmemory/async_provider.py b/atomicmemory/providers/atomicmemory/async_provider.py index 72fd76f..026b10c 100644 --- a/atomicmemory/providers/atomicmemory/async_provider.py +++ b/atomicmemory/providers/atomicmemory/async_provider.py @@ -55,6 +55,7 @@ from atomicmemory.providers.atomicmemory.path import normalize_api_version from atomicmemory.providers.atomicmemory.provider import ( _build_ingest_body, + _build_list_path, _build_package_body, _build_search_body, _qs, @@ -137,7 +138,7 @@ async def do_delete(self, ref: MemoryRef) -> None: async def do_list(self, request: ListRequest) -> ListResultPage: offset = int(request.cursor) if request.cursor else 0 limit = request.limit if request.limit is not None else 20 - path = self._route(f"/memories/list?user_id={_qs(request.scope.user)}&limit={limit}&offset={offset}") + path = self._route(_build_list_path(request.scope, limit, offset)) raw = await afetch_json(self._require_client(), self._http_options, path) memories = [to_memory(m, request.scope) for m in raw.get("memories", [])] next_offset = offset + len(memories) diff --git a/atomicmemory/providers/atomicmemory/handle.py b/atomicmemory/providers/atomicmemory/handle.py index 749da21..50fdd08 100644 --- a/atomicmemory/providers/atomicmemory/handle.py +++ b/atomicmemory/providers/atomicmemory/handle.py @@ -35,6 +35,7 @@ class UserScope(BaseModel): model_config = ConfigDict(extra="forbid", populate_by_name=True) kind: Literal["user"] = "user" user_id: str = Field(alias="userId") + thread: str | None = None class WorkspaceScope(BaseModel): @@ -45,6 +46,7 @@ class WorkspaceScope(BaseModel): user_id: str = Field(alias="userId") workspace_id: str = Field(alias="workspaceId") agent_id: str = Field(alias="agentId") + thread: str | None = None agent_scope: AgentScope | None = Field(default=None, alias="agentScope") diff --git a/atomicmemory/providers/atomicmemory/handle_impl.py b/atomicmemory/providers/atomicmemory/handle_impl.py index 638d235..c253c2a 100644 --- a/atomicmemory/providers/atomicmemory/handle_impl.py +++ b/atomicmemory/providers/atomicmemory/handle_impl.py @@ -41,6 +41,7 @@ scope_to_fields, scope_to_query_pairs, strip_agent_scope, + strip_read_filters, ) Route = Callable[[str], str] @@ -97,7 +98,7 @@ def expand(self, refs: list[str], scope: MemoryScope) -> list[AtomicMemoryMemory method="POST", json=body, ) - echoed = strip_agent_scope(scope) + echoed = strip_read_filters(scope) return [_to_atomic_memory(m, echoed) for m in raw.get("memories", [])] def list( @@ -107,7 +108,7 @@ def list( ) -> AtomicMemoryListResultPage: opts = _coerce_list_options(options) _assert_list_options_scope_compat(scope, opts) - pairs: list[tuple[str, str]] = scope_to_query_pairs(scope) + pairs: list[tuple[str, str]] = scope_to_query_pairs(scope, include_thread=True) if opts.limit is not None: pairs.append(("limit", str(opts.limit))) if opts.offset is not None: @@ -131,14 +132,16 @@ def list( ) def get(self, id: str, scope: MemoryScope) -> AtomicMemoryMemory | None: - path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(scope))}") + unfiltered_scope = strip_read_filters(scope) + path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(unfiltered_scope))}") raw = fetch_json_or_none(self._client, self._http, path) if raw is None: return None - return _to_atomic_memory(raw, strip_agent_scope(scope)) + return _to_atomic_memory(raw, unfiltered_scope) def delete(self, id: str, scope: MemoryScope) -> None: - path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(scope))}") + unfiltered_scope = strip_read_filters(scope) + path = self._route(f"/memories/{quote(id, safe='')}?{urlencode(scope_to_query_pairs(unfiltered_scope))}") try: fetch_void(self._client, self._http, path, method="DELETE") except ProviderError as exc: @@ -160,7 +163,7 @@ def _post_ingest( ) -> AtomicMemoryIngestResult: assert_scope_allows_visibility(scope, input.visibility) body: dict[str, Any] = { - **scope_to_fields(scope), + **scope_to_fields(scope, include_thread=True), "conversation": input.conversation, "source_site": input.source_site, "source_url": input.source_url or "", @@ -181,7 +184,7 @@ def _post_search( scope: MemoryScope, ) -> AtomicMemorySearchResultPage: body: dict[str, Any] = { - **scope_to_fields(scope, include_agent_scope=True), + **scope_to_fields(scope, include_agent_scope=True, include_thread=True), "query": request.query, } if request.limit is not None: @@ -263,7 +266,7 @@ def _to_atomic_memory(raw: dict[str, Any], scope: MemoryScope) -> AtomicMemoryMe payload: dict[str, Any] = { "id": raw["id"], "content": raw.get("content") or "", - "scope": scope, + "scope": _build_memory_scope(raw, scope), "created_at": _parse_iso(raw.get("created_at")) or _now_utc(), } if raw.get("updated_at"): @@ -274,6 +277,23 @@ def _to_atomic_memory(raw: dict[str, Any], scope: MemoryScope) -> AtomicMemoryMe return AtomicMemoryMemory.model_validate(payload) +def _build_memory_scope(raw: dict[str, Any], requested_scope: MemoryScope) -> MemoryScope: + """Validate and project Core ``session_id`` back into namespace scope.""" + session_id = raw.get("session_id") + if requested_scope.thread is not None: + if not session_id: + raise ValueError( + "atomicmemory provider: backend response missing required `session_id` for thread-scoped request" + ) + if session_id != requested_scope.thread: + raise ValueError( + "atomicmemory provider: backend response `session_id` did not match requested thread scope" + ) + if not session_id: + return requested_scope + return requested_scope.model_copy(update={"thread": session_id}) + + def _to_atomic_search_result(raw: dict[str, Any], scope: MemoryScope) -> AtomicMemorySearchResult: similarity = _coalesce(raw.get("semantic_similarity"), raw.get("similarity")) ranking_score = _coalesce(raw.get("ranking_score"), raw.get("score")) diff --git a/atomicmemory/providers/atomicmemory/mappers.py b/atomicmemory/providers/atomicmemory/mappers.py index a49c2aa..bef5a4a 100644 --- a/atomicmemory/providers/atomicmemory/mappers.py +++ b/atomicmemory/providers/atomicmemory/mappers.py @@ -51,13 +51,37 @@ def to_memory(raw: dict[str, Any], scope: Scope) -> Memory: return Memory( id=raw["id"], content=raw["content"], - scope=scope, + scope=_build_scope(raw, scope), created_at=created_at, provenance=_build_provenance(raw), metadata=_build_metadata(raw), ) +def _build_scope(raw: dict[str, Any], scope: Scope) -> Scope: + """Merge backend-projected scope fields and validate scoped reads.""" + namespace = raw.get("namespace") + session_id = raw.get("session_id") + if scope.namespace is not None and namespace is not None and namespace != scope.namespace: + raise ValueError("atomicmemory provider: backend response `namespace` did not match requested namespace scope") + if scope.thread is not None: + if not session_id: + raise ValueError( + "atomicmemory provider: backend response missing required `session_id` for thread-scoped request" + ) + if session_id != scope.thread: + raise ValueError( + "atomicmemory provider: backend response `session_id` did not match requested thread scope" + ) + + updates: dict[str, Any] = {} + if namespace: + updates["namespace"] = namespace + if session_id: + updates["thread"] = session_id + return scope.model_copy(update=updates) + + def _build_provenance(raw: dict[str, Any]) -> Provenance | None: fields: dict[str, Any] = {} if "source_site" in raw and raw["source_site"] is not None: diff --git a/atomicmemory/providers/atomicmemory/provider.py b/atomicmemory/providers/atomicmemory/provider.py index 6b3f758..a8d43e6 100644 --- a/atomicmemory/providers/atomicmemory/provider.py +++ b/atomicmemory/providers/atomicmemory/provider.py @@ -9,7 +9,7 @@ from datetime import datetime from typing import Any -from urllib.parse import quote +from urllib.parse import quote, urlencode import httpx @@ -31,6 +31,7 @@ MemoryVersion, PackageFormat, PackageRequest, + Scope, SearchRequest, SearchResult, SearchResultPage, @@ -129,7 +130,7 @@ def do_delete(self, ref: MemoryRef) -> None: def do_list(self, request: ListRequest) -> ListResultPage: offset = int(request.cursor) if request.cursor else 0 limit = request.limit if request.limit is not None else 20 - path = self._route(f"/memories/list?user_id={_qs(request.scope.user)}&limit={limit}&offset={offset}") + path = self._route(_build_list_path(request.scope, limit, offset)) raw = fetch_json(self._require_client(), self._http_options, path) memories = [to_memory(m, request.scope) for m in raw.get("memories", [])] next_offset = offset + len(memories) @@ -264,6 +265,8 @@ def _build_ingest_body(input: IngestInput) -> dict[str, Any]: "source_site": input.provenance.source if input.provenance and input.provenance.source else "sdk", "source_url": input.provenance.source_url if input.provenance and input.provenance.source_url else "", } + if input.scope.thread is not None: + body["session_id"] = input.scope.thread if input.mode == "verbatim": body["skip_extraction"] = True if input.metadata: @@ -282,6 +285,8 @@ def _build_search_body(request: SearchRequest) -> dict[str, Any]: body["threshold"] = request.threshold if request.scope.namespace is not None: body["namespace_scope"] = request.scope.namespace + if request.scope.thread is not None: + body["session_id"] = request.scope.thread return body @@ -298,3 +303,15 @@ def _build_package_body(request: PackageRequest) -> dict[str, Any]: def _qs(value: str | None) -> str: """URL-encode a query-string value; empty string when falsy.""" return quote(value, safe="") if value else "" + + +def _build_list_path(scope: Scope, limit: int, offset: int) -> str: + """Build the Core list path, including optional thread scope.""" + pairs = [ + ("user_id", scope.user or ""), + ("limit", str(limit)), + ("offset", str(offset)), + ] + if scope.thread is not None: + pairs.append(("session_id", scope.thread)) + return f"/memories/list?{urlencode(pairs)}" diff --git a/atomicmemory/providers/atomicmemory/scope_mapper.py b/atomicmemory/providers/atomicmemory/scope_mapper.py index af9373a..865fa99 100644 --- a/atomicmemory/providers/atomicmemory/scope_mapper.py +++ b/atomicmemory/providers/atomicmemory/scope_mapper.py @@ -11,13 +11,14 @@ from typing import Any from atomicmemory.core.errors import ValidationError -from atomicmemory.providers.atomicmemory.handle import MemoryScope, WorkspaceScope +from atomicmemory.providers.atomicmemory.handle import MemoryScope, UserScope, WorkspaceScope def scope_to_fields( scope: MemoryScope, *, include_agent_scope: bool = False, + include_thread: bool = False, ) -> dict[str, Any]: """Translate a `MemoryScope` to wire-format request fields. @@ -26,6 +27,8 @@ def scope_to_fields( include_agent_scope: Emit ``agent_scope`` on the wire. Defaults to ``False``; only the search routes opt in (core ignores ``agent_scope`` on expand/list/get/delete). + include_thread: Emit ``session_id`` on routes Core honors: + ingest, search, and list. Returns: A dict with ``user_id`` always set, plus ``workspace_id`` / @@ -33,21 +36,27 @@ def scope_to_fields( scopes. """ if not isinstance(scope, WorkspaceScope): - return {"user_id": scope.user_id} - fields: dict[str, Any] = { + user_fields: dict[str, Any] = {"user_id": scope.user_id} + if include_thread and scope.thread is not None: + user_fields["session_id"] = scope.thread + return user_fields + workspace_fields: dict[str, Any] = { "user_id": scope.user_id, "workspace_id": scope.workspace_id, "agent_id": scope.agent_id, } if include_agent_scope and scope.agent_scope is not None: - fields["agent_scope"] = scope.agent_scope - return fields + workspace_fields["agent_scope"] = scope.agent_scope + if include_thread and scope.thread is not None: + workspace_fields["session_id"] = scope.thread + return workspace_fields def scope_to_query_pairs( scope: MemoryScope, *, include_agent_scope: bool = False, + include_thread: bool = False, ) -> list[tuple[str, str]]: """Translate a scope to ``[(key, value)]`` pairs for query strings. @@ -66,6 +75,8 @@ def scope_to_query_pairs( pairs.extend(("agent_scope", v) for v in value) else: pairs.append(("agent_scope", value)) + if include_thread and scope.thread is not None: + pairs.append(("session_id", scope.thread)) return pairs @@ -91,6 +102,18 @@ def strip_agent_scope(scope: MemoryScope) -> MemoryScope: """ if not isinstance(scope, WorkspaceScope): return scope + return WorkspaceScope( + user_id=scope.user_id, + workspace_id=scope.workspace_id, + agent_id=scope.agent_id, + thread=scope.thread, + ) + + +def strip_read_filters(scope: MemoryScope) -> MemoryScope: + """Drop filters the target route did not apply before echoing scope.""" + if not isinstance(scope, WorkspaceScope): + return UserScope(user_id=scope.user_id) return WorkspaceScope( user_id=scope.user_id, workspace_id=scope.workspace_id, diff --git a/tests/providers/atomicmemory/test_async_provider.py b/tests/providers/atomicmemory/test_async_provider.py index 3e6480c..972c876 100644 --- a/tests/providers/atomicmemory/test_async_provider.py +++ b/tests/providers/atomicmemory/test_async_provider.py @@ -2,11 +2,15 @@ from __future__ import annotations +import json +from datetime import datetime, timezone + import httpx import pytest import pytest_asyncio import respx +from atomicmemory.core.errors import ProviderError from atomicmemory.memory.types import ( ListRequest, MemoryRef, @@ -52,6 +56,33 @@ async def test_async_ingest_text(provider: AsyncAtomicMemoryProvider) -> None: assert result.created == ["m-1"] +@pytest.mark.asyncio +@respx.mock +async def test_async_ingest_maps_thread_to_session_id(provider: AsyncAtomicMemoryProvider) -> None: + route = respx.post("http://core.test/v1/memories/ingest").mock( + return_value=httpx.Response( + 200, + json={ + "episode_id": "ep-1", + "facts_extracted": 1, + "memories_stored": 1, + "memories_updated": 0, + "memories_deleted": 0, + "memories_skipped": 0, + "stored_memory_ids": ["m-1"], + "updated_memory_ids": [], + "links_created": 0, + "composites_created": 0, + }, + ) + ) + + await provider.ingest(TextIngest(content="hi", scope=Scope(user="u1", thread="thread-1"))) + + body = json.loads(route.calls[0].request.content) + assert body["session_id"] == "thread-1" + + @pytest.mark.asyncio @respx.mock async def test_async_search_returns_typed_page( @@ -77,6 +108,32 @@ async def test_async_search_returns_typed_page( assert page.results[0].score == 0.42 +@pytest.mark.asyncio +@respx.mock +async def test_async_search_maps_thread_to_session_id(provider: AsyncAtomicMemoryProvider) -> None: + route = respx.post("http://core.test/v1/memories/search/fast").mock( + return_value=httpx.Response(200, json={"memories": [], "count": 0}), + ) + + await provider.search(SearchRequest(query="q", scope=Scope(user="u1", thread="thread-1"))) + + body = json.loads(route.calls[0].request.content) + assert body["session_id"] == "thread-1" + + +@pytest.mark.asyncio +@respx.mock +async def test_async_search_rejects_thread_scoped_rows_without_session_id( + provider: AsyncAtomicMemoryProvider, +) -> None: + respx.post("http://core.test/v1/memories/search/fast").mock( + return_value=httpx.Response(200, json={"memories": [{"id": "m-1", "content": "x"}], "count": 1}), + ) + + with pytest.raises(ProviderError, match="session_id"): + await provider.search(SearchRequest(query="q", scope=Scope(user="u1", thread="thread-1"))) + + @pytest.mark.asyncio @respx.mock async def test_async_get_returns_none_on_404( @@ -105,6 +162,18 @@ async def test_async_list_paginates(provider: AsyncAtomicMemoryProvider) -> None assert page.cursor == "2" +@pytest.mark.asyncio +@respx.mock +async def test_async_list_maps_thread_to_session_id(provider: AsyncAtomicMemoryProvider) -> None: + route = respx.get("http://core.test/v1/memories/list").mock( + return_value=httpx.Response(200, json={"memories": [], "count": 0}), + ) + + await provider.list(ListRequest(scope=Scope(user="u1", thread="thread-1"), limit=10)) + + assert route.calls[0].request.url.params["session_id"] == "thread-1" + + @pytest.mark.asyncio @respx.mock async def test_async_health(provider: AsyncAtomicMemoryProvider) -> None: @@ -135,6 +204,38 @@ async def test_async_package(provider: AsyncAtomicMemoryProvider) -> None: assert pkg.budget_constrained is False +@pytest.mark.asyncio +@respx.mock +async def test_async_package_maps_thread_to_session_id(provider: AsyncAtomicMemoryProvider) -> None: + route = respx.post("http://core.test/v1/memories/search").mock( + return_value=httpx.Response( + 200, + json={"memories": [], "injection_text": "", "estimated_context_tokens": 0, "budget_constrained": False}, + ) + ) + + await provider.package(PackageRequest(query="q", scope=Scope(user="u1", thread="thread-1"))) + + body = json.loads(route.calls[0].request.content) + assert body["session_id"] == "thread-1" + + +@pytest.mark.asyncio +@respx.mock +async def test_async_search_as_of_maps_thread_to_session_id(provider: AsyncAtomicMemoryProvider) -> None: + route = respx.post("http://core.test/v1/memories/search").mock( + return_value=httpx.Response(200, json={"memories": []}), + ) + + await provider.search_as_of( + SearchRequest(query="q", scope=Scope(user="u1", thread="thread-1")), + datetime(2024, 6, 1, tzinfo=timezone.utc), + ) + + body = json.loads(route.calls[0].request.content) + assert body["session_id"] == "thread-1" + + @pytest.mark.asyncio @respx.mock async def test_async_package_propagates_budget_constrained_true( diff --git a/tests/providers/atomicmemory/test_handle_base.py b/tests/providers/atomicmemory/test_handle_base.py index 9f5421c..8490857 100644 --- a/tests/providers/atomicmemory/test_handle_base.py +++ b/tests/providers/atomicmemory/test_handle_base.py @@ -61,6 +61,22 @@ def test_ingest_full_workspace_includes_visibility(provider: AtomicMemoryProvide assert body["visibility"] == "restricted" +@respx.mock +def test_ingest_full_forwards_thread_as_session_id(provider: AtomicMemoryProvider) -> None: + route = respx.post("http://core.test/v1/memories/ingest").mock( + return_value=httpx.Response(200, json=_ingest_response()) + ) + handle = provider.get_extension("atomicmemory.base") + assert handle is not None + handle.ingest_full( + AtomicMemoryIngestInput(conversation="hi", source_site="chat"), + UserScope(user_id="u1", thread="thread-1"), + ) + + body = json.loads(route.calls[0].request.content) + assert body["session_id"] == "thread-1" + + def test_ingest_user_scope_with_visibility_raises(provider: AtomicMemoryProvider) -> None: handle = provider.get_extension("atomicmemory.base") assert handle is not None @@ -87,6 +103,43 @@ def test_search_includes_agent_scope_in_body(provider: AtomicMemoryProvider) -> assert body["agent_scope"] == "self" +@respx.mock +def test_search_forwards_thread_and_maps_session_id(provider: AtomicMemoryProvider) -> None: + route = respx.post("http://core.test/v1/memories/search").mock( + return_value=httpx.Response( + 200, + json={ + "count": 1, + "retrieval_mode": "flat", + "memories": [{"id": "m-1", "content": "x", "session_id": "thread-1"}], + }, + ) + ) + handle = provider.get_extension("atomicmemory.base") + assert handle is not None + + page = handle.search(AtomicMemorySearchRequest(query="q"), UserScope(user_id="u1", thread="thread-1")) + + body = json.loads(route.calls[0].request.content) + assert body["session_id"] == "thread-1" + assert page.results[0].memory.scope.thread == "thread-1" + + +@respx.mock +def test_search_rejects_thread_scoped_rows_without_session_id(provider: AtomicMemoryProvider) -> None: + respx.post("http://core.test/v1/memories/search").mock( + return_value=httpx.Response( + 200, + json={"count": 1, "retrieval_mode": "flat", "memories": [{"id": "m-1", "content": "x"}]}, + ) + ) + handle = provider.get_extension("atomicmemory.base") + assert handle is not None + + with pytest.raises(ValueError, match="session_id"): + handle.search(AtomicMemorySearchRequest(query="q"), UserScope(user_id="u1", thread="thread-1")) + + @respx.mock def test_search_returns_namespace_typed_page(provider: AtomicMemoryProvider) -> None: respx.post("http://core.test/v1/memories/search").mock( @@ -145,6 +198,21 @@ def test_expand_strips_agent_scope_on_returned_memories( assert memories[0].scope.agent_scope is None +@respx.mock +def test_expand_strips_thread_on_returned_memories(provider: AtomicMemoryProvider) -> None: + respx.post("http://core.test/v1/memories/expand").mock( + return_value=httpx.Response( + 200, + json={"memories": [{"id": "m-1", "content": "a", "created_at": "2024-01-01T00:00:00Z"}]}, + ) + ) + handle = provider.get_extension("atomicmemory.base") + assert handle is not None + memories = handle.expand(["m-1"], UserScope(user_id="u1", thread="thread-1")) + + assert memories[0].scope.thread is None + + def test_list_rejects_workspace_with_source_site(provider: AtomicMemoryProvider) -> None: handle = provider.get_extension("atomicmemory.base") assert handle is not None @@ -155,6 +223,22 @@ def test_list_rejects_workspace_with_source_site(provider: AtomicMemoryProvider) ) +@respx.mock +def test_list_forwards_thread_and_maps_session_id(provider: AtomicMemoryProvider) -> None: + route = respx.get("http://core.test/v1/memories/list").mock( + return_value=httpx.Response( + 200, + json={"memories": [{"id": "m-1", "content": "a", "session_id": "thread-1"}], "count": 1}, + ) + ) + handle = provider.get_extension("atomicmemory.base") + assert handle is not None + page = handle.list(UserScope(user_id="u1", thread="thread-1")) + + assert route.calls[0].request.url.params["session_id"] == "thread-1" + assert page.memories[0].scope.thread == "thread-1" + + @respx.mock def test_get_returns_none_on_404_with_full_scope_echo( provider: AtomicMemoryProvider, @@ -166,6 +250,20 @@ def test_get_returns_none_on_404_with_full_scope_echo( assert result is None +@respx.mock +def test_get_omits_thread_filter(provider: AtomicMemoryProvider) -> None: + route = respx.get(url__regex=r"http://core.test/v1/memories/m-1.*").mock( + return_value=httpx.Response(200, json={"id": "m-1", "content": "a"}) + ) + handle = provider.get_extension("atomicmemory.base") + assert handle is not None + memory = handle.get("m-1", UserScope(user_id="u1", thread="thread-1")) + + assert memory is not None + assert "session_id" not in route.calls[0].request.url.params + assert memory.scope.thread is None + + @respx.mock def test_delete_swallows_404(provider: AtomicMemoryProvider) -> None: respx.delete(url__regex=r"http://core.test/v1/memories/m-x.*").mock(return_value=httpx.Response(404)) diff --git a/tests/providers/atomicmemory/test_mappers.py b/tests/providers/atomicmemory/test_mappers.py index 3831d13..7a127a2 100644 --- a/tests/providers/atomicmemory/test_mappers.py +++ b/tests/providers/atomicmemory/test_mappers.py @@ -53,6 +53,21 @@ def test_to_memory_includes_episode_id_in_metadata() -> None: assert memory.metadata == {"episode_id": "ep-1"} +def test_to_memory_maps_session_id_to_thread_scope() -> None: + memory = to_memory({"id": "m1", "content": "hi", "session_id": "thread-1"}, _SCOPE) + assert memory.scope.thread == "thread-1" + + +def test_to_memory_rejects_thread_scoped_row_without_session_id() -> None: + with pytest.raises(ValueError, match="session_id"): + to_memory({"id": "m1", "content": "hi"}, Scope(user="u1", thread="thread-1")) + + +def test_to_memory_rejects_thread_scoped_row_with_mismatched_session_id() -> None: + with pytest.raises(ValueError, match="session_id"): + to_memory({"id": "m1", "content": "hi", "session_id": "thread-2"}, Scope(user="u1", thread="thread-1")) + + def test_to_search_result_prefers_semantic_similarity() -> None: raw = {"id": "m1", "content": "hi", "semantic_similarity": 0.9, "similarity": 0.5, "score": 0.7} diff --git a/tests/providers/atomicmemory/test_provider.py b/tests/providers/atomicmemory/test_provider.py index 334d6a4..2bbb978 100644 --- a/tests/providers/atomicmemory/test_provider.py +++ b/tests/providers/atomicmemory/test_provider.py @@ -9,6 +9,7 @@ import pytest import respx +from atomicmemory.core.errors import ProviderError from atomicmemory.memory.types import ( ListRequest, MemoryRef, @@ -91,6 +92,32 @@ def test_ingest_verbatim_posts_quick_path_with_skip_extraction( assert body["metadata"] == {"k": "v"} +@respx.mock +def test_ingest_maps_thread_to_session_id(provider: AtomicMemoryProvider) -> None: + route = respx.post("http://core.test/v1/memories/ingest").mock( + return_value=httpx.Response( + 200, + json={ + "episode_id": "ep-1", + "facts_extracted": 1, + "memories_stored": 1, + "memories_updated": 0, + "memories_deleted": 0, + "memories_skipped": 0, + "stored_memory_ids": ["m-1"], + "updated_memory_ids": [], + "links_created": 0, + "composites_created": 0, + }, + ) + ) + + provider.ingest(TextIngest(content="thread note", scope=Scope(user="u1", thread="thread-1"))) + + body = json.loads(route.calls[0].request.content) + assert body["session_id"] == "thread-1" + + @respx.mock def test_search_posts_fast_path_and_maps_scores(provider: AtomicMemoryProvider) -> None: respx.post("http://core.test/v1/memories/search/fast").mock( @@ -122,6 +149,28 @@ def test_search_posts_fast_path_and_maps_scores(provider: AtomicMemoryProvider) assert hit.relevance == 0.75 +@respx.mock +def test_search_maps_thread_to_session_id(provider: AtomicMemoryProvider) -> None: + route = respx.post("http://core.test/v1/memories/search/fast").mock( + return_value=httpx.Response(200, json={"memories": [], "count": 0}), + ) + + provider.search(SearchRequest(query="q", scope=Scope(user="u1", thread="thread-1"))) + + body = json.loads(route.calls[0].request.content) + assert body["session_id"] == "thread-1" + + +@respx.mock +def test_search_rejects_thread_scoped_rows_without_session_id(provider: AtomicMemoryProvider) -> None: + respx.post("http://core.test/v1/memories/search/fast").mock( + return_value=httpx.Response(200, json={"memories": [{"id": "m-1", "content": "x"}], "count": 1}), + ) + + with pytest.raises(ProviderError, match="session_id"): + provider.search(SearchRequest(query="q", scope=Scope(user="u1", thread="thread-1"))) + + @respx.mock def test_get_returns_none_on_404(provider: AtomicMemoryProvider) -> None: respx.get("http://core.test/v1/memories/m-x").mock(return_value=httpx.Response(404)) @@ -155,6 +204,17 @@ def test_list_paginates_with_cursor(provider: AtomicMemoryProvider) -> None: assert page.cursor == "2" +@respx.mock +def test_list_maps_thread_to_session_id(provider: AtomicMemoryProvider) -> None: + route = respx.get("http://core.test/v1/memories/list").mock( + return_value=httpx.Response(200, json={"memories": [], "count": 0}), + ) + + provider.list(ListRequest(scope=Scope(user="u1", thread="thread-1"), limit=10)) + + assert route.calls[0].request.url.params["session_id"] == "thread-1" + + @respx.mock def test_search_as_of_serializes_iso_datetime(provider: AtomicMemoryProvider) -> None: route = respx.post("http://core.test/v1/memories/search").mock( @@ -168,6 +228,21 @@ def test_search_as_of_serializes_iso_datetime(provider: AtomicMemoryProvider) -> assert body["as_of"] == "2024-06-01T00:00:00+00:00" +@respx.mock +def test_search_as_of_maps_thread_to_session_id(provider: AtomicMemoryProvider) -> None: + route = respx.post("http://core.test/v1/memories/search").mock( + return_value=httpx.Response(200, json={"memories": []}), + ) + + provider.search_as_of( + SearchRequest(query="q", scope=Scope(user="u1", thread="thread-1")), + datetime(2024, 6, 1, tzinfo=timezone.utc), + ) + + body = json.loads(route.calls[0].request.content) + assert body["session_id"] == "thread-1" + + @respx.mock def test_package_returns_text_and_tokens(provider: AtomicMemoryProvider) -> None: respx.post("http://core.test/v1/memories/search").mock( @@ -193,6 +268,21 @@ def test_package_returns_text_and_tokens(provider: AtomicMemoryProvider) -> None assert pkg.budget_constrained is False +@respx.mock +def test_package_maps_thread_to_session_id(provider: AtomicMemoryProvider) -> None: + route = respx.post("http://core.test/v1/memories/search").mock( + return_value=httpx.Response( + 200, + json={"memories": [], "injection_text": "", "estimated_context_tokens": 0, "budget_constrained": False}, + ) + ) + + provider.package(PackageRequest(query="q", scope=Scope(user="u1", thread="thread-1"))) + + body = json.loads(route.calls[0].request.content) + assert body["session_id"] == "thread-1" + + @respx.mock def test_package_propagates_budget_constrained_true(provider: AtomicMemoryProvider) -> None: respx.post("http://core.test/v1/memories/search").mock( diff --git a/tests/providers/atomicmemory/test_scope_mapper.py b/tests/providers/atomicmemory/test_scope_mapper.py index e31a1ef..69d5cbc 100644 --- a/tests/providers/atomicmemory/test_scope_mapper.py +++ b/tests/providers/atomicmemory/test_scope_mapper.py @@ -11,6 +11,7 @@ scope_to_fields, scope_to_query_pairs, strip_agent_scope, + strip_read_filters, ) @@ -37,6 +38,17 @@ def test_workspace_scope_emits_agent_scope_when_opt_in() -> None: assert fields["agent_scope"] == "self" +def test_user_scope_emits_session_id_when_thread_opted_in() -> None: + fields = scope_to_fields(UserScope(user_id="u1", thread="thread-1"), include_thread=True) + assert fields == {"user_id": "u1", "session_id": "thread-1"} + + +def test_workspace_scope_emits_session_id_when_thread_opted_in() -> None: + scope = WorkspaceScope(user_id="u1", workspace_id="w1", agent_id="a1", thread="thread-1") + fields = scope_to_fields(scope, include_thread=True) + assert fields["session_id"] == "thread-1" + + def test_query_pairs_repeats_agent_scope_for_lists() -> None: scope = WorkspaceScope(user_id="u1", workspace_id="w1", agent_id="a1", agent_scope=["a2", "a3"]) pairs = scope_to_query_pairs(scope, include_agent_scope=True) @@ -44,6 +56,11 @@ def test_query_pairs_repeats_agent_scope_for_lists() -> None: assert agent_pairs == ["a2", "a3"] +def test_query_pairs_emit_session_id_when_thread_opted_in() -> None: + pairs = scope_to_query_pairs(UserScope(user_id="u1", thread="thread-1"), include_thread=True) + assert ("session_id", "thread-1") in pairs + + def test_visibility_rejected_on_user_scope() -> None: with pytest.raises(ValidationError): assert_scope_allows_visibility(UserScope(user_id="u1"), "workspace") @@ -64,3 +81,19 @@ def test_strip_agent_scope_clears_workspace_filter() -> None: def test_strip_agent_scope_leaves_user_scope_unchanged() -> None: scope = UserScope(user_id="u1") assert strip_agent_scope(scope) is scope + + +def test_strip_agent_scope_preserves_thread() -> None: + scope = WorkspaceScope(user_id="u1", workspace_id="w1", agent_id="a1", thread="thread-1", agent_scope="self") + stripped = strip_agent_scope(scope) + assert isinstance(stripped, WorkspaceScope) + assert stripped.thread == "thread-1" + assert stripped.agent_scope is None + + +def test_strip_read_filters_drops_thread_and_agent_scope() -> None: + scope = WorkspaceScope(user_id="u1", workspace_id="w1", agent_id="a1", thread="thread-1", agent_scope="self") + stripped = strip_read_filters(scope) + assert isinstance(stripped, WorkspaceScope) + assert stripped.thread is None + assert stripped.agent_scope is None