diff --git a/dist/app.py b/dist/app.py index d3ee072..62ee945 100644 --- a/dist/app.py +++ b/dist/app.py @@ -304,7 +304,8 @@ class IncidentState(Session): import json from datetime import datetime, timezone -from typing import Generic, Optional, Type, TypeVar +import logging +from typing import Any, Generic, Optional, Type, TypeVar from pydantic import BaseModel from sqlalchemy import desc, select @@ -312,11 +313,6 @@ class IncidentState(Session): -# The legacy ``INC-YYYYMMDD-NNN`` pattern stays here for back-compat -# validation against on-disk rows minted before the ``Session.id_format`` -# hook existed. New rows are validated by ``_SESSION_ID_RE`` which -# accepts any ``PREFIX-YYYYMMDD-NNN`` shape the app's ``id_format`` may -# emit (e.g. ``CR-...`` for code-review). # ----- imports for runtime/storage/event_log.py ----- """Append-only session event log. @@ -398,7 +394,6 @@ class IncidentState(Session): """ -import logging @@ -937,6 +932,7 @@ class IncidentState(Session): from datetime import datetime, timezone, timedelta from sqlalchemy import DateTime, String, delete, select +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Mapped, Session as SqlaSession, mapped_column @@ -1427,7 +1423,7 @@ async def _poll(self, registry): ``config/config.yaml``) and returns a fresh app. """ -from typing import AsyncIterator, Literal +from typing import Any, AsyncIterator, Literal from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware @@ -1531,6 +1527,7 @@ async def _poll(self, registry): from fastapi.responses import StreamingResponse + # ----- imports for runtime/api_static.py ----- """StaticFiles mount + SPA fallback for the React UI bundle. @@ -2405,7 +2402,13 @@ def resolve_framework_app_config( class ApiConfig(BaseModel): - """API surface knobs surfaced to the React frontend.""" + """API surface knobs surfaced to the React frontend. 
+ + Values come from the loaded runtime config. Environment variables + choose the config file (``ASR_CONFIG``) but do not override these + CORS fields directly; deployments that need different origins + should set them in YAML. + """ # CORS origins allowed by the FastAPI CORSMiddleware. Default # covers the two common React dev-server URLs (Vite, CRA/Next). @@ -2549,6 +2552,16 @@ def load_config(path: str | Path) -> AppConfig: # Per-call audit metadata for the risk-rated tool gateway. +SessionStatus = Literal[ + "new", + "in_progress", + "awaiting_input", + "resolved", + "escalated", + "stopped", + "error", + "duplicate", +] ToolRisk = Literal["low", "medium", "high"] ToolStatus = Literal[ "executed", # auto / legacy default @@ -4270,6 +4283,13 @@ def _ef(i, key, default: Any = ""): # ====== module: runtime/storage/session_store.py ====== +_log = logging.getLogger(__name__) + +# The legacy ``INC-YYYYMMDD-NNN`` pattern stays here for back-compat +# validation against on-disk rows minted before the ``Session.id_format`` +# hook existed. New rows are validated by ``_SESSION_ID_RE`` which +# accepts any ``PREFIX-YYYYMMDD-NNN`` shape the app's ``id_format`` may +# emit (e.g. ``CR-...`` for code-review). 
_INC_ID_RE = re.compile(r"^INC-\d{8}-\d{3}$") _SESSION_ID_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_-]*-\d{8}-\d{3}$") @@ -4437,7 +4457,12 @@ def _next_id(self, session: SqlSession) -> str: # ---------- public API ---------- def create(self, *, query: str, environment: str, reporter_id: str = "user-mock", - reporter_team: str = "platform") -> StateT: + reporter_team: str = "platform", + state_overrides: dict[str, Any] | None = None) -> StateT: + extra_fields = { + k: v for k, v in (state_overrides or {}).items() + if k not in self._ROW_TYPED_DOMAIN_COLUMNS + } with SqlSession(self.engine) as session: now = _now() inc_id = self._next_id(session) @@ -4456,6 +4481,7 @@ def create(self, *, query: str, environment: str, tool_calls=[], findings={}, user_inputs=[], + extra_fields=extra_fields, ) session.add(row) session.commit() @@ -4507,7 +4533,10 @@ def save(self, session: StateT) -> None: for k, v in data.items(): setattr(existing, k, v) db_session.commit() - self._refresh_vector(session, prior_text=prior_text) + try: + self._refresh_vector(session, prior_text=prior_text) + except Exception as exc: # noqa: BLE001 — SQL commit already succeeded + self._log_vector_warning("refresh", session.id, exc) def delete(self, incident_id: str) -> StateT: with SqlSession(self.engine) as session: @@ -4659,7 +4688,10 @@ def _refresh_vector(self, inc: Session, *, prior_text: str) -> None: return if prior_text == text: return - self.vector_store.delete(ids=[inc.id]) + try: + self.vector_store.delete(ids=[inc.id]) + except Exception as exc: # noqa: BLE001 — vector delete is idempotent + self._log_vector_warning("delete", inc.id, exc) from langchain_core.documents import Document self.vector_store.add_documents( [Document(page_content=text, metadata={"id": inc.id})], @@ -4667,6 +4699,17 @@ def _refresh_vector(self, inc: Session, *, prior_text: str) -> None: ) self._persist_vector() + def _log_vector_warning(self, operation: str, sid: str, exc: Exception) -> None: + backend = 
type(self.vector_store).__name__ if self.vector_store is not None else "None" + _log.warning( + "session vector %s failed sid=%s backend=%s exc=%s", + operation, + sid, + backend, + type(exc).__name__, + exc_info=True, + ) + # ---------- mapping helpers ---------- # # Round-trip is driven by ``state_cls.model_fields`` so any @@ -6330,7 +6373,6 @@ def start_session( ) sub_id = (resolved_submitter or {}).get("id", "user-mock") sub_team = (resolved_submitter or {}).get("team", "platform") - env = (resolved_overrides or {}).get("environment", "") async def _scheduler() -> str: # Enforce the concurrency cap on the loop thread so the @@ -6340,6 +6382,13 @@ async def _scheduler() -> str: if len(self._registry) >= self.max_concurrent_sessions: raise SessionCapExceeded(self.max_concurrent_sessions) orch = await self._ensure_orchestrator() + overrides_for_create = resolved_overrides + state_overrides_cls = getattr(orch, "_state_overrides_cls", None) + if state_overrides_cls is not None and overrides_for_create is not None: + overrides_for_create = state_overrides_cls.model_validate( + overrides_for_create + ).model_dump(exclude_none=True) + env_for_create = (overrides_for_create or {}).get("environment", "") # Allocate the row (and its id) synchronously on the loop # so the caller gets a stable id back. The graph then runs # in a separate task — registration happens here, before @@ -6347,9 +6396,10 @@ async def _scheduler() -> str: # entry immediately. inc = orch.store.create( query=query, - environment=env, + environment=env_for_create, reporter_id=sub_id, reporter_team=sub_team, + state_overrides=overrides_for_create, ) session_id = inc.id # Emit session.created on the cross-session SSE stream so @@ -6397,8 +6447,8 @@ async def _run() -> None: raise SessionBusy(session_id) # Hold the per-session lock for the full graph turn, # including any HITL interrupt() pause (D-01). 
- async with orch._locks.acquire(session_id): - try: + try: + async with orch._locks.acquire(session_id): await orch.graph.ainvoke( GraphState( session=inc, @@ -6408,34 +6458,36 @@ async def _run() -> None: ), config=orch._thread_config(session_id), ) - except asyncio.CancelledError: - raise - except Exception as exc: # noqa: BLE001 - # Phase 11 (FOC-04 / D-11-04): GraphInterrupt is a - # pending-approval pause, not a failure. Don't stamp - # status='error' on the registry entry -- let - # LangGraph's checkpointer hold the paused state - # and let the UI's Approve/Reject action drive - # resume. - try: - from langgraph.errors import GraphInterrupt - if isinstance(exc, GraphInterrupt): - # Propagate so the underlying Task - # observer (stop_session etc.) still - # sees the exception, but skip the - # status='error' write. - raise - except ImportError: # pragma: no cover - pass - # Mark the registry entry so any concurrent snapshot - # observes the failure before the done-callback - # evicts it. The exception itself is preserved on - # the task object for ``stop_session`` and any - # other observer that holds a Task reference. - e = self._registry.get(session_id) - if e is not None: - e.status = "error" - raise + if not await orch._is_graph_paused(session_id): + await orch._finalize_session_status_async(session_id) + except asyncio.CancelledError: + raise + except Exception as exc: # noqa: BLE001 + # Phase 11 (FOC-04 / D-11-04): GraphInterrupt is a + # pending-approval pause, not a failure. Don't stamp + # status='error' on the registry entry -- let + # LangGraph's checkpointer hold the paused state + # and let the UI's Approve/Reject action drive + # resume. + try: + from langgraph.errors import GraphInterrupt + if isinstance(exc, GraphInterrupt): + # Propagate so the underlying Task + # observer (stop_session etc.) still + # sees the exception, but skip the + # status='error' write. 
+ raise + except ImportError: # pragma: no cover + pass + # Mark the registry entry so any concurrent snapshot + # observes the failure before the done-callback + # evicts it. The exception itself is preserved on + # the task object for ``stop_session`` and any + # other observer that holds a Task reference. + e = self._registry.get(session_id) + if e is not None: + e.status = "error" + raise task = asyncio.create_task(_run(), name=f"session:{session_id}") entry.task = task @@ -10292,6 +10344,12 @@ async def node(state: GraphState) -> dict: inc_id, "agent_started", agent=skill.name, started_at=started_at, ) + event_log.record( + inc_id, + "session.agent_running", + id=inc_id, + agent=skill.name, + ) except Exception: # noqa: BLE001 — telemetry must not break the agent logger.debug( "event_log.record(agent_started) failed", exc_info=True, @@ -11481,7 +11539,7 @@ def get(self, trigger_name: str, key: str) -> str | None: if cache is not None and key in cache: # Bump recency cache.move_to_end(key) - return cache[key] + return cache[key] or None # SQLite fall-through (outside the threading lock — sqlite3 has # its own locking, and this path is rare). with SqlaSession(self._engine) as s: @@ -11509,6 +11567,8 @@ def get(self, trigger_name: str, key: str) -> str | None: s.commit() return None session_id = row.session_id + if not session_id: + return None # Refill LRU. with self._lock: cache = self._lru.setdefault(trigger_name, OrderedDict()) @@ -11557,6 +11617,51 @@ def put( # doesn't accumulate dead records. Cheap (range-bounded delete). self.purge_expired() + def reserve( + self, + trigger_name: str, + key: str, + *, + ttl_hours: int = 24, + ) -> bool: + """Atomically reserve a fresh idempotency key. + + Returns ``True`` only for the caller that inserted the row. A + ``False`` result means another request has either already + completed and stored ``session_id`` or is still in flight. 
+ """ + now = _utc_now() + expires_at = now + timedelta(hours=ttl_hours) + with SqlaSession(self._engine) as s: + existing = s.get(IdempotencyRow, (trigger_name, key)) + if existing is not None: + existing_expires = existing.expires_at + if existing_expires.tzinfo is None: + existing_expires = existing_expires.replace(tzinfo=timezone.utc) + if existing_expires <= now: + s.delete(existing) + s.commit() + else: + return False + s.add(IdempotencyRow( + trigger_name=trigger_name, + key=key, + session_id="", + created_at=now, + expires_at=expires_at, + )) + try: + s.commit() + except IntegrityError: + s.rollback() + return False + with self._lock: + cache = self._lru.setdefault(trigger_name, OrderedDict()) + cache[key] = "" + cache.move_to_end(key) + self._evict_if_needed(cache) + return True + def purge_expired(self) -> int: """Delete all rows whose ``expires_at`` is in the past. Returns the number of rows removed.""" @@ -12044,9 +12149,14 @@ async def dispatch( if spec is None: raise KeyError(f"unknown trigger: {name!r}") - # Idempotency hit: return cached session id without invoking - # transform / orchestrator. Per R3 in the plan, transform errors - # are NOT cached — only successful dispatches. + ttl = ( + spec.config.idempotency_ttl_hours + if isinstance(spec.config, WebhookTriggerConfig) + else 24 + ) + + # Fast idempotency hit: return cached session id without + # invoking transform / orchestrator. if idempotency_key and self._idempotency is not None: cached = self._idempotency.get(name, idempotency_key) if cached is not None: @@ -12071,15 +12181,25 @@ async def dispatch( received_at=datetime.now(timezone.utc), ) + # Idempotency reservation: only the caller that atomically + # inserts the key may start a session. Followers wait briefly + # for the first caller to publish the resulting session id. 
+ if idempotency_key and self._idempotency is not None: + if not self._idempotency.reserve(name, idempotency_key, ttl_hours=ttl): + for _ in range(100): + cached = self._idempotency.get(name, idempotency_key) + if cached is not None: + return cached + await asyncio.sleep(0.01) + raise RuntimeError( + f"idempotency key {idempotency_key!r} for trigger " + f"{name!r} is still in flight" + ) + session_id = await self._start_session_fn(trigger=info, **kwargs) # Record successful dispatch for idempotency. if idempotency_key and self._idempotency is not None: - ttl = ( - spec.config.idempotency_ttl_hours - if isinstance(spec.config, WebhookTriggerConfig) - else 24 - ) self._idempotency.put(name, idempotency_key, session_id, ttl_hours=ttl) return session_id @@ -13976,6 +14096,19 @@ def _emit_status_changed_event( return status_def = statuses.statuses.get(to_status) if status_def is not None and status_def.terminal: + if event_log is not None: + try: + event_log.record( + inc.id, + "session.agent_running", + id=inc.id, + agent=None, + ) + except Exception: # noqa: BLE001 — telemetry must not break finalize + _log.debug( + "event_log.record(session.agent_running) failed", + exc_info=True, + ) _extract_lesson_on_terminal(orch=orch, inc=inc) @@ -14906,13 +15039,20 @@ async def start_session(self, *, query: str, # ``__new__`` (bypassing ``__init__``) working. 
state_overrides_cls = getattr(self, "_state_overrides_cls", None) if state_overrides_cls is not None and state_overrides is not None: - state_overrides_cls.model_validate(state_overrides) + state_overrides = state_overrides_cls.model_validate( + state_overrides + ).model_dump(exclude_none=True) submitter = _coerce_submitter(submitter, reporter_id, reporter_team) sub_id = (submitter or {}).get("id", "user-mock") sub_team = (submitter or {}).get("team", "platform") env = (state_overrides or {}).get("environment", "") - inc = self.store.create(query=query, environment=env, - reporter_id=sub_id, reporter_team=sub_team) + inc = self.store.create( + query=query, + environment=env, + reporter_id=sub_id, + reporter_team=sub_team, + state_overrides=state_overrides, + ) # Emit session.created on the cross-session SSE stream so the # React UI's Other Sessions monitor lights up the new tile in # real time. ``session_id`` already lands on the row; the @@ -14942,6 +15082,8 @@ async def start_session(self, *, query: str, last_agent=None, error=None), config=self._thread_config(inc.id), ) + if not await self._is_graph_paused(inc.id): + await self._finalize_session_status_async(inc.id) return inc.id async def start_investigation(self, *, query: str, environment: str, @@ -15474,6 +15616,7 @@ class ResumeRequest(BaseModel): class SessionStartBody(BaseModel): query: str environment: str + state_overrides: dict[str, Any] | None = None # Generic submitter dict — the framework projects ``id``/``team`` # onto the row's reporter columns; apps interpret the rest. The # legacy ``reporter_id`` / ``reporter_team`` fields were removed @@ -15713,22 +15856,13 @@ def build_app(cfg: AppConfig) -> FastAPI: # at root for monitor / load-balancer health-check conventions. api_v1 = APIRouter(prefix="/api/v1") - # CORS: env-driven so the React dev server (Vite at :5173) can call - # every endpoint, SSE included. 
Override via ``ASR_CORS_ORIGINS`` - # (comma-separated) — production deployments lock the origin list - # down by setting the env var to the narrower allow-list. - # ``allow_credentials=False`` matches the bearer-token auth pattern - # (no cookies); methods are explicit so OPTIONS preflights are - # handled the same way for every route. - _cors_origins_raw = os.environ.get( - "ASR_CORS_ORIGINS", - "http://localhost:5173,http://127.0.0.1:5173", # Vite dev defaults - ) - _cors_origins = [o.strip() for o in _cors_origins_raw.split(",") if o.strip()] + # CORS is config-driven via ``cfg.api``. ``ASR_CONFIG`` selects the + # config file; deployments set origins/credentials in YAML rather + # than via a second env-var override path. fastapi_app.add_middleware( CORSMiddleware, - allow_origins=_cors_origins, - allow_credentials=False, + allow_origins=cfg.api.cors_origins, + allow_credentials=cfg.api.cors_allow_credentials, allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["*"], ) @@ -15883,9 +16017,34 @@ class is matched by name so this handler does not depend on a """ svc = request.app.state.service try: + if ( + body.state_overrides is not None + and "environment" in body.state_overrides + and body.environment != body.state_overrides["environment"] + ): + raise HTTPException( + status_code=422, + detail={ + "error": { + "code": "conflicting_environment", + "message": ( + "environment and state_overrides.environment " + "must match when both are supplied" + ), + "details": { + "environment": body.environment, + "state_overrides_environment": ( + body.state_overrides["environment"] + ), + }, + } + }, + ) + state_overrides = dict(body.state_overrides or {}) + state_overrides.setdefault("environment", body.environment) sid = svc.start_session( query=body.query, - state_overrides={"environment": body.environment}, + state_overrides=state_overrides, submitter=body.submitter, ) except Exception as e: # noqa: BLE001 @@ -16633,6 +16792,29 @@ async def 
list_app_views(app_name: str, request: Request) -> list[dict]: }) +def _event_payload(orch, ev: SessionEvent) -> dict: + payload = dict(ev.payload or {}) + payload.setdefault("id", ev.session_id) + if ev.kind == "session.status_changed": + payload.setdefault("status", payload.get("to")) + if ev.kind == "session.created": + try: + session = orch.store.load(ev.session_id) + except Exception: # noqa: BLE001 — SSE enrichment is best effort + session = None + if session is not None: + payload.setdefault("status", session.status) + payload.setdefault("created_at", session.created_at) + payload.setdefault("updated_at", session.updated_at) + label = ( + getattr(session, "query", None) + or (session.extra_fields or {}).get("query") + or session.id + ) + payload.setdefault("label", label) + return payload + + def add_recent_events_routes(api_v1: APIRouter) -> None: """Mount the /sessions/recent/events SSE handler on the api_v1 router. @@ -16657,7 +16839,8 @@ async def _stream() -> AsyncIterator[str]: if ev.kind in _SESSION_KINDS: payload = {"seq": ev.seq, "kind": ev.kind, "session_id": ev.session_id, - "payload": ev.payload, "ts": ev.ts} + "payload": _event_payload(orch, ev), + "ts": ev.ts} last_seq = ev.seq yield f"data: {json.dumps(payload)}\n\n" # Tail: poll for new rows; exit on client disconnect @@ -16667,7 +16850,8 @@ async def _stream() -> AsyncIterator[str]: if ev.kind in _SESSION_KINDS: payload = {"seq": ev.seq, "kind": ev.kind, "session_id": ev.session_id, - "payload": ev.payload, "ts": ev.ts} + "payload": _event_payload(orch, ev), + "ts": ev.ts} last_seq = ev.seq yield f"data: {json.dumps(payload)}\n\n" diff --git a/dist/apps/code-review.py b/dist/apps/code-review.py index 1ad48ac..067613d 100644 --- a/dist/apps/code-review.py +++ b/dist/apps/code-review.py @@ -304,7 +304,8 @@ class IncidentState(Session): import json from datetime import datetime, timezone -from typing import Generic, Optional, Type, TypeVar +import logging +from typing import Any, Generic, 
Optional, Type, TypeVar from pydantic import BaseModel from sqlalchemy import desc, select @@ -312,11 +313,6 @@ class IncidentState(Session): -# The legacy ``INC-YYYYMMDD-NNN`` pattern stays here for back-compat -# validation against on-disk rows minted before the ``Session.id_format`` -# hook existed. New rows are validated by ``_SESSION_ID_RE`` which -# accepts any ``PREFIX-YYYYMMDD-NNN`` shape the app's ``id_format`` may -# emit (e.g. ``CR-...`` for code-review). # ----- imports for runtime/storage/event_log.py ----- """Append-only session event log. @@ -398,7 +394,6 @@ class IncidentState(Session): """ -import logging @@ -937,6 +932,7 @@ class IncidentState(Session): from datetime import datetime, timezone, timedelta from sqlalchemy import DateTime, String, delete, select +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Mapped, Session as SqlaSession, mapped_column @@ -1427,7 +1423,7 @@ async def _poll(self, registry): ``config/config.yaml``) and returns a fresh app. """ -from typing import AsyncIterator, Literal +from typing import Any, AsyncIterator, Literal from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware @@ -1531,6 +1527,7 @@ async def _poll(self, registry): from fastapi.responses import StreamingResponse + # ----- imports for runtime/api_static.py ----- """StaticFiles mount + SPA fallback for the React UI bundle. @@ -2458,7 +2455,13 @@ def resolve_framework_app_config( class ApiConfig(BaseModel): - """API surface knobs surfaced to the React frontend.""" + """API surface knobs surfaced to the React frontend. + + Values come from the loaded runtime config. Environment variables + choose the config file (``ASR_CONFIG``) but do not override these + CORS fields directly; deployments that need different origins + should set them in YAML. + """ # CORS origins allowed by the FastAPI CORSMiddleware. 
Default # covers the two common React dev-server URLs (Vite, CRA/Next). @@ -2602,6 +2605,16 @@ def load_config(path: str | Path) -> AppConfig: # Per-call audit metadata for the risk-rated tool gateway. +SessionStatus = Literal[ + "new", + "in_progress", + "awaiting_input", + "resolved", + "escalated", + "stopped", + "error", + "duplicate", +] ToolRisk = Literal["low", "medium", "high"] ToolStatus = Literal[ "executed", # auto / legacy default @@ -4323,6 +4336,13 @@ def _ef(i, key, default: Any = ""): # ====== module: runtime/storage/session_store.py ====== +_log = logging.getLogger(__name__) + +# The legacy ``INC-YYYYMMDD-NNN`` pattern stays here for back-compat +# validation against on-disk rows minted before the ``Session.id_format`` +# hook existed. New rows are validated by ``_SESSION_ID_RE`` which +# accepts any ``PREFIX-YYYYMMDD-NNN`` shape the app's ``id_format`` may +# emit (e.g. ``CR-...`` for code-review). _INC_ID_RE = re.compile(r"^INC-\d{8}-\d{3}$") _SESSION_ID_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_-]*-\d{8}-\d{3}$") @@ -4490,7 +4510,12 @@ def _next_id(self, session: SqlSession) -> str: # ---------- public API ---------- def create(self, *, query: str, environment: str, reporter_id: str = "user-mock", - reporter_team: str = "platform") -> StateT: + reporter_team: str = "platform", + state_overrides: dict[str, Any] | None = None) -> StateT: + extra_fields = { + k: v for k, v in (state_overrides or {}).items() + if k not in self._ROW_TYPED_DOMAIN_COLUMNS + } with SqlSession(self.engine) as session: now = _now() inc_id = self._next_id(session) @@ -4509,6 +4534,7 @@ def create(self, *, query: str, environment: str, tool_calls=[], findings={}, user_inputs=[], + extra_fields=extra_fields, ) session.add(row) session.commit() @@ -4560,7 +4586,10 @@ def save(self, session: StateT) -> None: for k, v in data.items(): setattr(existing, k, v) db_session.commit() - self._refresh_vector(session, prior_text=prior_text) + try: + self._refresh_vector(session, 
prior_text=prior_text) + except Exception as exc: # noqa: BLE001 — SQL commit already succeeded + self._log_vector_warning("refresh", session.id, exc) def delete(self, incident_id: str) -> StateT: with SqlSession(self.engine) as session: @@ -4712,7 +4741,10 @@ def _refresh_vector(self, inc: Session, *, prior_text: str) -> None: return if prior_text == text: return - self.vector_store.delete(ids=[inc.id]) + try: + self.vector_store.delete(ids=[inc.id]) + except Exception as exc: # noqa: BLE001 — vector delete is idempotent + self._log_vector_warning("delete", inc.id, exc) from langchain_core.documents import Document self.vector_store.add_documents( [Document(page_content=text, metadata={"id": inc.id})], @@ -4720,6 +4752,17 @@ def _refresh_vector(self, inc: Session, *, prior_text: str) -> None: ) self._persist_vector() + def _log_vector_warning(self, operation: str, sid: str, exc: Exception) -> None: + backend = type(self.vector_store).__name__ if self.vector_store is not None else "None" + _log.warning( + "session vector %s failed sid=%s backend=%s exc=%s", + operation, + sid, + backend, + type(exc).__name__, + exc_info=True, + ) + # ---------- mapping helpers ---------- # # Round-trip is driven by ``state_cls.model_fields`` so any @@ -6383,7 +6426,6 @@ def start_session( ) sub_id = (resolved_submitter or {}).get("id", "user-mock") sub_team = (resolved_submitter or {}).get("team", "platform") - env = (resolved_overrides or {}).get("environment", "") async def _scheduler() -> str: # Enforce the concurrency cap on the loop thread so the @@ -6393,6 +6435,13 @@ async def _scheduler() -> str: if len(self._registry) >= self.max_concurrent_sessions: raise SessionCapExceeded(self.max_concurrent_sessions) orch = await self._ensure_orchestrator() + overrides_for_create = resolved_overrides + state_overrides_cls = getattr(orch, "_state_overrides_cls", None) + if state_overrides_cls is not None and overrides_for_create is not None: + overrides_for_create = 
state_overrides_cls.model_validate( + overrides_for_create + ).model_dump(exclude_none=True) + env_for_create = (overrides_for_create or {}).get("environment", "") # Allocate the row (and its id) synchronously on the loop # so the caller gets a stable id back. The graph then runs # in a separate task — registration happens here, before @@ -6400,9 +6449,10 @@ async def _scheduler() -> str: # entry immediately. inc = orch.store.create( query=query, - environment=env, + environment=env_for_create, reporter_id=sub_id, reporter_team=sub_team, + state_overrides=overrides_for_create, ) session_id = inc.id # Emit session.created on the cross-session SSE stream so @@ -6450,8 +6500,8 @@ async def _run() -> None: raise SessionBusy(session_id) # Hold the per-session lock for the full graph turn, # including any HITL interrupt() pause (D-01). - async with orch._locks.acquire(session_id): - try: + try: + async with orch._locks.acquire(session_id): await orch.graph.ainvoke( GraphState( session=inc, @@ -6461,34 +6511,36 @@ async def _run() -> None: ), config=orch._thread_config(session_id), ) - except asyncio.CancelledError: - raise - except Exception as exc: # noqa: BLE001 - # Phase 11 (FOC-04 / D-11-04): GraphInterrupt is a - # pending-approval pause, not a failure. Don't stamp - # status='error' on the registry entry -- let - # LangGraph's checkpointer hold the paused state - # and let the UI's Approve/Reject action drive - # resume. - try: - from langgraph.errors import GraphInterrupt - if isinstance(exc, GraphInterrupt): - # Propagate so the underlying Task - # observer (stop_session etc.) still - # sees the exception, but skip the - # status='error' write. - raise - except ImportError: # pragma: no cover - pass - # Mark the registry entry so any concurrent snapshot - # observes the failure before the done-callback - # evicts it. The exception itself is preserved on - # the task object for ``stop_session`` and any - # other observer that holds a Task reference. 
- e = self._registry.get(session_id) - if e is not None: - e.status = "error" - raise + if not await orch._is_graph_paused(session_id): + await orch._finalize_session_status_async(session_id) + except asyncio.CancelledError: + raise + except Exception as exc: # noqa: BLE001 + # Phase 11 (FOC-04 / D-11-04): GraphInterrupt is a + # pending-approval pause, not a failure. Don't stamp + # status='error' on the registry entry -- let + # LangGraph's checkpointer hold the paused state + # and let the UI's Approve/Reject action drive + # resume. + try: + from langgraph.errors import GraphInterrupt + if isinstance(exc, GraphInterrupt): + # Propagate so the underlying Task + # observer (stop_session etc.) still + # sees the exception, but skip the + # status='error' write. + raise + except ImportError: # pragma: no cover + pass + # Mark the registry entry so any concurrent snapshot + # observes the failure before the done-callback + # evicts it. The exception itself is preserved on + # the task object for ``stop_session`` and any + # other observer that holds a Task reference. + e = self._registry.get(session_id) + if e is not None: + e.status = "error" + raise task = asyncio.create_task(_run(), name=f"session:{session_id}") entry.task = task @@ -10345,6 +10397,12 @@ async def node(state: GraphState) -> dict: inc_id, "agent_started", agent=skill.name, started_at=started_at, ) + event_log.record( + inc_id, + "session.agent_running", + id=inc_id, + agent=skill.name, + ) except Exception: # noqa: BLE001 — telemetry must not break the agent logger.debug( "event_log.record(agent_started) failed", exc_info=True, @@ -11534,7 +11592,7 @@ def get(self, trigger_name: str, key: str) -> str | None: if cache is not None and key in cache: # Bump recency cache.move_to_end(key) - return cache[key] + return cache[key] or None # SQLite fall-through (outside the threading lock — sqlite3 has # its own locking, and this path is rare). 
with SqlaSession(self._engine) as s: @@ -11562,6 +11620,8 @@ def get(self, trigger_name: str, key: str) -> str | None: s.commit() return None session_id = row.session_id + if not session_id: + return None # Refill LRU. with self._lock: cache = self._lru.setdefault(trigger_name, OrderedDict()) @@ -11610,6 +11670,51 @@ def put( # doesn't accumulate dead records. Cheap (range-bounded delete). self.purge_expired() + def reserve( + self, + trigger_name: str, + key: str, + *, + ttl_hours: int = 24, + ) -> bool: + """Atomically reserve a fresh idempotency key. + + Returns ``True`` only for the caller that inserted the row. A + ``False`` result means another request has either already + completed and stored ``session_id`` or is still in flight. + """ + now = _utc_now() + expires_at = now + timedelta(hours=ttl_hours) + with SqlaSession(self._engine) as s: + existing = s.get(IdempotencyRow, (trigger_name, key)) + if existing is not None: + existing_expires = existing.expires_at + if existing_expires.tzinfo is None: + existing_expires = existing_expires.replace(tzinfo=timezone.utc) + if existing_expires <= now: + s.delete(existing) + s.commit() + else: + return False + s.add(IdempotencyRow( + trigger_name=trigger_name, + key=key, + session_id="", + created_at=now, + expires_at=expires_at, + )) + try: + s.commit() + except IntegrityError: + s.rollback() + return False + with self._lock: + cache = self._lru.setdefault(trigger_name, OrderedDict()) + cache[key] = "" + cache.move_to_end(key) + self._evict_if_needed(cache) + return True + def purge_expired(self) -> int: """Delete all rows whose ``expires_at`` is in the past. Returns the number of rows removed.""" @@ -12097,9 +12202,14 @@ async def dispatch( if spec is None: raise KeyError(f"unknown trigger: {name!r}") - # Idempotency hit: return cached session id without invoking - # transform / orchestrator. Per R3 in the plan, transform errors - # are NOT cached — only successful dispatches. 
+ ttl = ( + spec.config.idempotency_ttl_hours + if isinstance(spec.config, WebhookTriggerConfig) + else 24 + ) + + # Fast idempotency hit: return cached session id without + # invoking transform / orchestrator. if idempotency_key and self._idempotency is not None: cached = self._idempotency.get(name, idempotency_key) if cached is not None: @@ -12124,15 +12234,25 @@ async def dispatch( received_at=datetime.now(timezone.utc), ) + # Idempotency reservation: only the caller that atomically + # inserts the key may start a session. Followers wait briefly + # for the first caller to publish the resulting session id. + if idempotency_key and self._idempotency is not None: + if not self._idempotency.reserve(name, idempotency_key, ttl_hours=ttl): + for _ in range(100): + cached = self._idempotency.get(name, idempotency_key) + if cached is not None: + return cached + await asyncio.sleep(0.01) + raise RuntimeError( + f"idempotency key {idempotency_key!r} for trigger " + f"{name!r} is still in flight" + ) + session_id = await self._start_session_fn(trigger=info, **kwargs) # Record successful dispatch for idempotency. if idempotency_key and self._idempotency is not None: - ttl = ( - spec.config.idempotency_ttl_hours - if isinstance(spec.config, WebhookTriggerConfig) - else 24 - ) self._idempotency.put(name, idempotency_key, session_id, ttl_hours=ttl) return session_id @@ -14029,6 +14149,19 @@ def _emit_status_changed_event( return status_def = statuses.statuses.get(to_status) if status_def is not None and status_def.terminal: + if event_log is not None: + try: + event_log.record( + inc.id, + "session.agent_running", + id=inc.id, + agent=None, + ) + except Exception: # noqa: BLE001 — telemetry must not break finalize + _log.debug( + "event_log.record(session.agent_running) failed", + exc_info=True, + ) _extract_lesson_on_terminal(orch=orch, inc=inc) @@ -14959,13 +15092,20 @@ async def start_session(self, *, query: str, # ``__new__`` (bypassing ``__init__``) working. 
state_overrides_cls = getattr(self, "_state_overrides_cls", None) if state_overrides_cls is not None and state_overrides is not None: - state_overrides_cls.model_validate(state_overrides) + state_overrides = state_overrides_cls.model_validate( + state_overrides + ).model_dump(exclude_none=True) submitter = _coerce_submitter(submitter, reporter_id, reporter_team) sub_id = (submitter or {}).get("id", "user-mock") sub_team = (submitter or {}).get("team", "platform") env = (state_overrides or {}).get("environment", "") - inc = self.store.create(query=query, environment=env, - reporter_id=sub_id, reporter_team=sub_team) + inc = self.store.create( + query=query, + environment=env, + reporter_id=sub_id, + reporter_team=sub_team, + state_overrides=state_overrides, + ) # Emit session.created on the cross-session SSE stream so the # React UI's Other Sessions monitor lights up the new tile in # real time. ``session_id`` already lands on the row; the @@ -14995,6 +15135,8 @@ async def start_session(self, *, query: str, last_agent=None, error=None), config=self._thread_config(inc.id), ) + if not await self._is_graph_paused(inc.id): + await self._finalize_session_status_async(inc.id) return inc.id async def start_investigation(self, *, query: str, environment: str, @@ -15527,6 +15669,7 @@ class ResumeRequest(BaseModel): class SessionStartBody(BaseModel): query: str environment: str + state_overrides: dict[str, Any] | None = None # Generic submitter dict — the framework projects ``id``/``team`` # onto the row's reporter columns; apps interpret the rest. The # legacy ``reporter_id`` / ``reporter_team`` fields were removed @@ -15766,22 +15909,13 @@ def build_app(cfg: AppConfig) -> FastAPI: # at root for monitor / load-balancer health-check conventions. api_v1 = APIRouter(prefix="/api/v1") - # CORS: env-driven so the React dev server (Vite at :5173) can call - # every endpoint, SSE included. 
Override via ``ASR_CORS_ORIGINS`` - # (comma-separated) — production deployments lock the origin list - # down by setting the env var to the narrower allow-list. - # ``allow_credentials=False`` matches the bearer-token auth pattern - # (no cookies); methods are explicit so OPTIONS preflights are - # handled the same way for every route. - _cors_origins_raw = os.environ.get( - "ASR_CORS_ORIGINS", - "http://localhost:5173,http://127.0.0.1:5173", # Vite dev defaults - ) - _cors_origins = [o.strip() for o in _cors_origins_raw.split(",") if o.strip()] + # CORS is config-driven via ``cfg.api``. ``ASR_CONFIG`` selects the + # config file; deployments set origins/credentials in YAML rather + # than via a second env-var override path. fastapi_app.add_middleware( CORSMiddleware, - allow_origins=_cors_origins, - allow_credentials=False, + allow_origins=cfg.api.cors_origins, + allow_credentials=cfg.api.cors_allow_credentials, allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["*"], ) @@ -15936,9 +16070,34 @@ class is matched by name so this handler does not depend on a """ svc = request.app.state.service try: + if ( + body.state_overrides is not None + and "environment" in body.state_overrides + and body.environment != body.state_overrides["environment"] + ): + raise HTTPException( + status_code=422, + detail={ + "error": { + "code": "conflicting_environment", + "message": ( + "environment and state_overrides.environment " + "must match when both are supplied" + ), + "details": { + "environment": body.environment, + "state_overrides_environment": ( + body.state_overrides["environment"] + ), + }, + } + }, + ) + state_overrides = dict(body.state_overrides or {}) + state_overrides.setdefault("environment", body.environment) sid = svc.start_session( query=body.query, - state_overrides={"environment": body.environment}, + state_overrides=state_overrides, submitter=body.submitter, ) except Exception as e: # noqa: BLE001 @@ -16686,6 +16845,29 @@ async def 
list_app_views(app_name: str, request: Request) -> list[dict]: }) +def _event_payload(orch, ev: SessionEvent) -> dict: + payload = dict(ev.payload or {}) + payload.setdefault("id", ev.session_id) + if ev.kind == "session.status_changed": + payload.setdefault("status", payload.get("to")) + if ev.kind == "session.created": + try: + session = orch.store.load(ev.session_id) + except Exception: # noqa: BLE001 — SSE enrichment is best effort + session = None + if session is not None: + payload.setdefault("status", session.status) + payload.setdefault("created_at", session.created_at) + payload.setdefault("updated_at", session.updated_at) + label = ( + getattr(session, "query", None) + or (session.extra_fields or {}).get("query") + or session.id + ) + payload.setdefault("label", label) + return payload + + def add_recent_events_routes(api_v1: APIRouter) -> None: """Mount the /sessions/recent/events SSE handler on the api_v1 router. @@ -16710,7 +16892,8 @@ async def _stream() -> AsyncIterator[str]: if ev.kind in _SESSION_KINDS: payload = {"seq": ev.seq, "kind": ev.kind, "session_id": ev.session_id, - "payload": ev.payload, "ts": ev.ts} + "payload": _event_payload(orch, ev), + "ts": ev.ts} last_seq = ev.seq yield f"data: {json.dumps(payload)}\n\n" # Tail: poll for new rows; exit on client disconnect @@ -16720,7 +16903,8 @@ async def _stream() -> AsyncIterator[str]: if ev.kind in _SESSION_KINDS: payload = {"seq": ev.seq, "kind": ev.kind, "session_id": ev.session_id, - "payload": ev.payload, "ts": ev.ts} + "payload": _event_payload(orch, ev), + "ts": ev.ts} last_seq = ev.seq yield f"data: {json.dumps(payload)}\n\n" @@ -16803,6 +16987,7 @@ class CodeReviewStateOverrides(BaseModel): pr_url: str | None = None repo: str | None = None + environment: str | None = None base_branch: str | None = None pr_number: int | None = None diff --git a/dist/apps/incident-management.py b/dist/apps/incident-management.py index 96ef7bb..a619331 100644 --- a/dist/apps/incident-management.py +++ 
b/dist/apps/incident-management.py @@ -304,7 +304,8 @@ class IncidentState(Session): import json from datetime import datetime, timezone -from typing import Generic, Optional, Type, TypeVar +import logging +from typing import Any, Generic, Optional, Type, TypeVar from pydantic import BaseModel from sqlalchemy import desc, select @@ -312,11 +313,6 @@ class IncidentState(Session): -# The legacy ``INC-YYYYMMDD-NNN`` pattern stays here for back-compat -# validation against on-disk rows minted before the ``Session.id_format`` -# hook existed. New rows are validated by ``_SESSION_ID_RE`` which -# accepts any ``PREFIX-YYYYMMDD-NNN`` shape the app's ``id_format`` may -# emit (e.g. ``CR-...`` for code-review). # ----- imports for runtime/storage/event_log.py ----- """Append-only session event log. @@ -398,7 +394,6 @@ class IncidentState(Session): """ -import logging @@ -937,6 +932,7 @@ class IncidentState(Session): from datetime import datetime, timezone, timedelta from sqlalchemy import DateTime, String, delete, select +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Mapped, Session as SqlaSession, mapped_column @@ -1427,7 +1423,7 @@ async def _poll(self, registry): ``config/config.yaml``) and returns a fresh app. """ -from typing import AsyncIterator, Literal +from typing import Any, AsyncIterator, Literal from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware @@ -1531,6 +1527,7 @@ async def _poll(self, registry): from fastapi.responses import StreamingResponse + # ----- imports for runtime/api_static.py ----- """StaticFiles mount + SPA fallback for the React UI bundle. @@ -2470,7 +2467,13 @@ def resolve_framework_app_config( class ApiConfig(BaseModel): - """API surface knobs surfaced to the React frontend.""" + """API surface knobs surfaced to the React frontend. + + Values come from the loaded runtime config. 
Environment variables + choose the config file (``ASR_CONFIG``) but do not override these + CORS fields directly; deployments that need different origins + should set them in YAML. + """ # CORS origins allowed by the FastAPI CORSMiddleware. Default # covers the two common React dev-server URLs (Vite, CRA/Next). @@ -2614,6 +2617,16 @@ def load_config(path: str | Path) -> AppConfig: # Per-call audit metadata for the risk-rated tool gateway. +SessionStatus = Literal[ + "new", + "in_progress", + "awaiting_input", + "resolved", + "escalated", + "stopped", + "error", + "duplicate", +] ToolRisk = Literal["low", "medium", "high"] ToolStatus = Literal[ "executed", # auto / legacy default @@ -4335,6 +4348,13 @@ def _ef(i, key, default: Any = ""): # ====== module: runtime/storage/session_store.py ====== +_log = logging.getLogger(__name__) + +# The legacy ``INC-YYYYMMDD-NNN`` pattern stays here for back-compat +# validation against on-disk rows minted before the ``Session.id_format`` +# hook existed. New rows are validated by ``_SESSION_ID_RE`` which +# accepts any ``PREFIX-YYYYMMDD-NNN`` shape the app's ``id_format`` may +# emit (e.g. ``CR-...`` for code-review). 
_INC_ID_RE = re.compile(r"^INC-\d{8}-\d{3}$") _SESSION_ID_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_-]*-\d{8}-\d{3}$") @@ -4502,7 +4522,12 @@ def _next_id(self, session: SqlSession) -> str: # ---------- public API ---------- def create(self, *, query: str, environment: str, reporter_id: str = "user-mock", - reporter_team: str = "platform") -> StateT: + reporter_team: str = "platform", + state_overrides: dict[str, Any] | None = None) -> StateT: + extra_fields = { + k: v for k, v in (state_overrides or {}).items() + if k not in self._ROW_TYPED_DOMAIN_COLUMNS + } with SqlSession(self.engine) as session: now = _now() inc_id = self._next_id(session) @@ -4521,6 +4546,7 @@ def create(self, *, query: str, environment: str, tool_calls=[], findings={}, user_inputs=[], + extra_fields=extra_fields, ) session.add(row) session.commit() @@ -4572,7 +4598,10 @@ def save(self, session: StateT) -> None: for k, v in data.items(): setattr(existing, k, v) db_session.commit() - self._refresh_vector(session, prior_text=prior_text) + try: + self._refresh_vector(session, prior_text=prior_text) + except Exception as exc: # noqa: BLE001 — SQL commit already succeeded + self._log_vector_warning("refresh", session.id, exc) def delete(self, incident_id: str) -> StateT: with SqlSession(self.engine) as session: @@ -4724,7 +4753,10 @@ def _refresh_vector(self, inc: Session, *, prior_text: str) -> None: return if prior_text == text: return - self.vector_store.delete(ids=[inc.id]) + try: + self.vector_store.delete(ids=[inc.id]) + except Exception as exc: # noqa: BLE001 — vector delete is idempotent + self._log_vector_warning("delete", inc.id, exc) from langchain_core.documents import Document self.vector_store.add_documents( [Document(page_content=text, metadata={"id": inc.id})], @@ -4732,6 +4764,17 @@ def _refresh_vector(self, inc: Session, *, prior_text: str) -> None: ) self._persist_vector() + def _log_vector_warning(self, operation: str, sid: str, exc: Exception) -> None: + backend = 
type(self.vector_store).__name__ if self.vector_store is not None else "None" + _log.warning( + "session vector %s failed sid=%s backend=%s exc=%s", + operation, + sid, + backend, + type(exc).__name__, + exc_info=True, + ) + # ---------- mapping helpers ---------- # # Round-trip is driven by ``state_cls.model_fields`` so any @@ -6395,7 +6438,6 @@ def start_session( ) sub_id = (resolved_submitter or {}).get("id", "user-mock") sub_team = (resolved_submitter or {}).get("team", "platform") - env = (resolved_overrides or {}).get("environment", "") async def _scheduler() -> str: # Enforce the concurrency cap on the loop thread so the @@ -6405,6 +6447,13 @@ async def _scheduler() -> str: if len(self._registry) >= self.max_concurrent_sessions: raise SessionCapExceeded(self.max_concurrent_sessions) orch = await self._ensure_orchestrator() + overrides_for_create = resolved_overrides + state_overrides_cls = getattr(orch, "_state_overrides_cls", None) + if state_overrides_cls is not None and overrides_for_create is not None: + overrides_for_create = state_overrides_cls.model_validate( + overrides_for_create + ).model_dump(exclude_none=True) + env_for_create = (overrides_for_create or {}).get("environment", "") # Allocate the row (and its id) synchronously on the loop # so the caller gets a stable id back. The graph then runs # in a separate task — registration happens here, before @@ -6412,9 +6461,10 @@ async def _scheduler() -> str: # entry immediately. inc = orch.store.create( query=query, - environment=env, + environment=env_for_create, reporter_id=sub_id, reporter_team=sub_team, + state_overrides=overrides_for_create, ) session_id = inc.id # Emit session.created on the cross-session SSE stream so @@ -6462,8 +6512,8 @@ async def _run() -> None: raise SessionBusy(session_id) # Hold the per-session lock for the full graph turn, # including any HITL interrupt() pause (D-01). 
- async with orch._locks.acquire(session_id): - try: + try: + async with orch._locks.acquire(session_id): await orch.graph.ainvoke( GraphState( session=inc, @@ -6473,34 +6523,36 @@ async def _run() -> None: ), config=orch._thread_config(session_id), ) - except asyncio.CancelledError: - raise - except Exception as exc: # noqa: BLE001 - # Phase 11 (FOC-04 / D-11-04): GraphInterrupt is a - # pending-approval pause, not a failure. Don't stamp - # status='error' on the registry entry -- let - # LangGraph's checkpointer hold the paused state - # and let the UI's Approve/Reject action drive - # resume. - try: - from langgraph.errors import GraphInterrupt - if isinstance(exc, GraphInterrupt): - # Propagate so the underlying Task - # observer (stop_session etc.) still - # sees the exception, but skip the - # status='error' write. - raise - except ImportError: # pragma: no cover - pass - # Mark the registry entry so any concurrent snapshot - # observes the failure before the done-callback - # evicts it. The exception itself is preserved on - # the task object for ``stop_session`` and any - # other observer that holds a Task reference. - e = self._registry.get(session_id) - if e is not None: - e.status = "error" - raise + if not await orch._is_graph_paused(session_id): + await orch._finalize_session_status_async(session_id) + except asyncio.CancelledError: + raise + except Exception as exc: # noqa: BLE001 + # Phase 11 (FOC-04 / D-11-04): GraphInterrupt is a + # pending-approval pause, not a failure. Don't stamp + # status='error' on the registry entry -- let + # LangGraph's checkpointer hold the paused state + # and let the UI's Approve/Reject action drive + # resume. + try: + from langgraph.errors import GraphInterrupt + if isinstance(exc, GraphInterrupt): + # Propagate so the underlying Task + # observer (stop_session etc.) still + # sees the exception, but skip the + # status='error' write. 
+ raise + except ImportError: # pragma: no cover + pass + # Mark the registry entry so any concurrent snapshot + # observes the failure before the done-callback + # evicts it. The exception itself is preserved on + # the task object for ``stop_session`` and any + # other observer that holds a Task reference. + e = self._registry.get(session_id) + if e is not None: + e.status = "error" + raise task = asyncio.create_task(_run(), name=f"session:{session_id}") entry.task = task @@ -10357,6 +10409,12 @@ async def node(state: GraphState) -> dict: inc_id, "agent_started", agent=skill.name, started_at=started_at, ) + event_log.record( + inc_id, + "session.agent_running", + id=inc_id, + agent=skill.name, + ) except Exception: # noqa: BLE001 — telemetry must not break the agent logger.debug( "event_log.record(agent_started) failed", exc_info=True, @@ -11546,7 +11604,7 @@ def get(self, trigger_name: str, key: str) -> str | None: if cache is not None and key in cache: # Bump recency cache.move_to_end(key) - return cache[key] + return cache[key] or None # SQLite fall-through (outside the threading lock — sqlite3 has # its own locking, and this path is rare). with SqlaSession(self._engine) as s: @@ -11574,6 +11632,8 @@ def get(self, trigger_name: str, key: str) -> str | None: s.commit() return None session_id = row.session_id + if not session_id: + return None # Refill LRU. with self._lock: cache = self._lru.setdefault(trigger_name, OrderedDict()) @@ -11622,6 +11682,51 @@ def put( # doesn't accumulate dead records. Cheap (range-bounded delete). self.purge_expired() + def reserve( + self, + trigger_name: str, + key: str, + *, + ttl_hours: int = 24, + ) -> bool: + """Atomically reserve a fresh idempotency key. + + Returns ``True`` only for the caller that inserted the row. A + ``False`` result means another request has either already + completed and stored ``session_id`` or is still in flight. 
+ """ + now = _utc_now() + expires_at = now + timedelta(hours=ttl_hours) + with SqlaSession(self._engine) as s: + existing = s.get(IdempotencyRow, (trigger_name, key)) + if existing is not None: + existing_expires = existing.expires_at + if existing_expires.tzinfo is None: + existing_expires = existing_expires.replace(tzinfo=timezone.utc) + if existing_expires <= now: + s.delete(existing) + s.commit() + else: + return False + s.add(IdempotencyRow( + trigger_name=trigger_name, + key=key, + session_id="", + created_at=now, + expires_at=expires_at, + )) + try: + s.commit() + except IntegrityError: + s.rollback() + return False + with self._lock: + cache = self._lru.setdefault(trigger_name, OrderedDict()) + cache[key] = "" + cache.move_to_end(key) + self._evict_if_needed(cache) + return True + def purge_expired(self) -> int: """Delete all rows whose ``expires_at`` is in the past. Returns the number of rows removed.""" @@ -12109,9 +12214,14 @@ async def dispatch( if spec is None: raise KeyError(f"unknown trigger: {name!r}") - # Idempotency hit: return cached session id without invoking - # transform / orchestrator. Per R3 in the plan, transform errors - # are NOT cached — only successful dispatches. + ttl = ( + spec.config.idempotency_ttl_hours + if isinstance(spec.config, WebhookTriggerConfig) + else 24 + ) + + # Fast idempotency hit: return cached session id without + # invoking transform / orchestrator. if idempotency_key and self._idempotency is not None: cached = self._idempotency.get(name, idempotency_key) if cached is not None: @@ -12136,15 +12246,25 @@ async def dispatch( received_at=datetime.now(timezone.utc), ) + # Idempotency reservation: only the caller that atomically + # inserts the key may start a session. Followers wait briefly + # for the first caller to publish the resulting session id. 
+ if idempotency_key and self._idempotency is not None: + if not self._idempotency.reserve(name, idempotency_key, ttl_hours=ttl): + for _ in range(100): + cached = self._idempotency.get(name, idempotency_key) + if cached is not None: + return cached + await asyncio.sleep(0.01) + raise RuntimeError( + f"idempotency key {idempotency_key!r} for trigger " + f"{name!r} is still in flight" + ) + session_id = await self._start_session_fn(trigger=info, **kwargs) # Record successful dispatch for idempotency. if idempotency_key and self._idempotency is not None: - ttl = ( - spec.config.idempotency_ttl_hours - if isinstance(spec.config, WebhookTriggerConfig) - else 24 - ) self._idempotency.put(name, idempotency_key, session_id, ttl_hours=ttl) return session_id @@ -14041,6 +14161,19 @@ def _emit_status_changed_event( return status_def = statuses.statuses.get(to_status) if status_def is not None and status_def.terminal: + if event_log is not None: + try: + event_log.record( + inc.id, + "session.agent_running", + id=inc.id, + agent=None, + ) + except Exception: # noqa: BLE001 — telemetry must not break finalize + _log.debug( + "event_log.record(session.agent_running) failed", + exc_info=True, + ) _extract_lesson_on_terminal(orch=orch, inc=inc) @@ -14971,13 +15104,20 @@ async def start_session(self, *, query: str, # ``__new__`` (bypassing ``__init__``) working. 
state_overrides_cls = getattr(self, "_state_overrides_cls", None) if state_overrides_cls is not None and state_overrides is not None: - state_overrides_cls.model_validate(state_overrides) + state_overrides = state_overrides_cls.model_validate( + state_overrides + ).model_dump(exclude_none=True) submitter = _coerce_submitter(submitter, reporter_id, reporter_team) sub_id = (submitter or {}).get("id", "user-mock") sub_team = (submitter or {}).get("team", "platform") env = (state_overrides or {}).get("environment", "") - inc = self.store.create(query=query, environment=env, - reporter_id=sub_id, reporter_team=sub_team) + inc = self.store.create( + query=query, + environment=env, + reporter_id=sub_id, + reporter_team=sub_team, + state_overrides=state_overrides, + ) # Emit session.created on the cross-session SSE stream so the # React UI's Other Sessions monitor lights up the new tile in # real time. ``session_id`` already lands on the row; the @@ -15007,6 +15147,8 @@ async def start_session(self, *, query: str, last_agent=None, error=None), config=self._thread_config(inc.id), ) + if not await self._is_graph_paused(inc.id): + await self._finalize_session_status_async(inc.id) return inc.id async def start_investigation(self, *, query: str, environment: str, @@ -15539,6 +15681,7 @@ class ResumeRequest(BaseModel): class SessionStartBody(BaseModel): query: str environment: str + state_overrides: dict[str, Any] | None = None # Generic submitter dict — the framework projects ``id``/``team`` # onto the row's reporter columns; apps interpret the rest. The # legacy ``reporter_id`` / ``reporter_team`` fields were removed @@ -15778,22 +15921,13 @@ def build_app(cfg: AppConfig) -> FastAPI: # at root for monitor / load-balancer health-check conventions. api_v1 = APIRouter(prefix="/api/v1") - # CORS: env-driven so the React dev server (Vite at :5173) can call - # every endpoint, SSE included. 
Override via ``ASR_CORS_ORIGINS`` - # (comma-separated) — production deployments lock the origin list - # down by setting the env var to the narrower allow-list. - # ``allow_credentials=False`` matches the bearer-token auth pattern - # (no cookies); methods are explicit so OPTIONS preflights are - # handled the same way for every route. - _cors_origins_raw = os.environ.get( - "ASR_CORS_ORIGINS", - "http://localhost:5173,http://127.0.0.1:5173", # Vite dev defaults - ) - _cors_origins = [o.strip() for o in _cors_origins_raw.split(",") if o.strip()] + # CORS is config-driven via ``cfg.api``. ``ASR_CONFIG`` selects the + # config file; deployments set origins/credentials in YAML rather + # than via a second env-var override path. fastapi_app.add_middleware( CORSMiddleware, - allow_origins=_cors_origins, - allow_credentials=False, + allow_origins=cfg.api.cors_origins, + allow_credentials=cfg.api.cors_allow_credentials, allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["*"], ) @@ -15948,9 +16082,34 @@ class is matched by name so this handler does not depend on a """ svc = request.app.state.service try: + if ( + body.state_overrides is not None + and "environment" in body.state_overrides + and body.environment != body.state_overrides["environment"] + ): + raise HTTPException( + status_code=422, + detail={ + "error": { + "code": "conflicting_environment", + "message": ( + "environment and state_overrides.environment " + "must match when both are supplied" + ), + "details": { + "environment": body.environment, + "state_overrides_environment": ( + body.state_overrides["environment"] + ), + }, + } + }, + ) + state_overrides = dict(body.state_overrides or {}) + state_overrides.setdefault("environment", body.environment) sid = svc.start_session( query=body.query, - state_overrides={"environment": body.environment}, + state_overrides=state_overrides, submitter=body.submitter, ) except Exception as e: # noqa: BLE001 @@ -16698,6 +16857,29 @@ async def 
list_app_views(app_name: str, request: Request) -> list[dict]: }) +def _event_payload(orch, ev: SessionEvent) -> dict: + payload = dict(ev.payload or {}) + payload.setdefault("id", ev.session_id) + if ev.kind == "session.status_changed": + payload.setdefault("status", payload.get("to")) + if ev.kind == "session.created": + try: + session = orch.store.load(ev.session_id) + except Exception: # noqa: BLE001 — SSE enrichment is best effort + session = None + if session is not None: + payload.setdefault("status", session.status) + payload.setdefault("created_at", session.created_at) + payload.setdefault("updated_at", session.updated_at) + label = ( + getattr(session, "query", None) + or (session.extra_fields or {}).get("query") + or session.id + ) + payload.setdefault("label", label) + return payload + + def add_recent_events_routes(api_v1: APIRouter) -> None: """Mount the /sessions/recent/events SSE handler on the api_v1 router. @@ -16722,7 +16904,8 @@ async def _stream() -> AsyncIterator[str]: if ev.kind in _SESSION_KINDS: payload = {"seq": ev.seq, "kind": ev.kind, "session_id": ev.session_id, - "payload": ev.payload, "ts": ev.ts} + "payload": _event_payload(orch, ev), + "ts": ev.ts} last_seq = ev.seq yield f"data: {json.dumps(payload)}\n\n" # Tail: poll for new rows; exit on client disconnect @@ -16732,7 +16915,8 @@ async def _stream() -> AsyncIterator[str]: if ev.kind in _SESSION_KINDS: payload = {"seq": ev.seq, "kind": ev.kind, "session_id": ev.session_id, - "payload": ev.payload, "ts": ev.ts} + "payload": _event_payload(orch, ev), + "ts": ev.ts} last_seq = ev.seq yield f"data: {json.dumps(payload)}\n\n" diff --git a/examples/code_review/state.py b/examples/code_review/state.py index 4b937b9..8f370cb 100644 --- a/examples/code_review/state.py +++ b/examples/code_review/state.py @@ -23,5 +23,6 @@ class CodeReviewStateOverrides(BaseModel): pr_url: str | None = None repo: str | None = None + environment: str | None = None base_branch: str | None = None pr_number: int 
| None = None diff --git a/src/runtime/api.py b/src/runtime/api.py index 344c443..0cf4fdb 100644 --- a/src/runtime/api.py +++ b/src/runtime/api.py @@ -26,7 +26,7 @@ import os from contextlib import asynccontextmanager from pathlib import Path -from typing import AsyncIterator, Literal +from typing import Any, AsyncIterator, Literal from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware @@ -120,6 +120,7 @@ class ResumeRequest(BaseModel): class SessionStartBody(BaseModel): query: str environment: str + state_overrides: dict[str, Any] | None = None # Generic submitter dict — the framework projects ``id``/``team`` # onto the row's reporter columns; apps interpret the rest. The # legacy ``reporter_id`` / ``reporter_team`` fields were removed @@ -359,22 +360,13 @@ def build_app(cfg: AppConfig) -> FastAPI: # at root for monitor / load-balancer health-check conventions. api_v1 = APIRouter(prefix="/api/v1") - # CORS: env-driven so the React dev server (Vite at :5173) can call - # every endpoint, SSE included. Override via ``ASR_CORS_ORIGINS`` - # (comma-separated) — production deployments lock the origin list - # down by setting the env var to the narrower allow-list. - # ``allow_credentials=False`` matches the bearer-token auth pattern - # (no cookies); methods are explicit so OPTIONS preflights are - # handled the same way for every route. - _cors_origins_raw = os.environ.get( - "ASR_CORS_ORIGINS", - "http://localhost:5173,http://127.0.0.1:5173", # Vite dev defaults - ) - _cors_origins = [o.strip() for o in _cors_origins_raw.split(",") if o.strip()] + # CORS is config-driven via ``cfg.api``. ``ASR_CONFIG`` selects the + # config file; deployments set origins/credentials in YAML rather + # than via a second env-var override path. 
fastapi_app.add_middleware( CORSMiddleware, - allow_origins=_cors_origins, - allow_credentials=False, + allow_origins=cfg.api.cors_origins, + allow_credentials=cfg.api.cors_allow_credentials, allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["*"], ) @@ -529,9 +521,34 @@ class is matched by name so this handler does not depend on a """ svc = request.app.state.service try: + if ( + body.state_overrides is not None + and "environment" in body.state_overrides + and body.environment != body.state_overrides["environment"] + ): + raise HTTPException( + status_code=422, + detail={ + "error": { + "code": "conflicting_environment", + "message": ( + "environment and state_overrides.environment " + "must match when both are supplied" + ), + "details": { + "environment": body.environment, + "state_overrides_environment": ( + body.state_overrides["environment"] + ), + }, + } + }, + ) + state_overrides = dict(body.state_overrides or {}) + state_overrides.setdefault("environment", body.environment) sid = svc.start_session( query=body.query, - state_overrides={"environment": body.environment}, + state_overrides=state_overrides, submitter=body.submitter, ) except Exception as e: # noqa: BLE001 diff --git a/src/runtime/api_recent_events.py b/src/runtime/api_recent_events.py index 7432c28..3984e0d 100644 --- a/src/runtime/api_recent_events.py +++ b/src/runtime/api_recent_events.py @@ -14,6 +14,7 @@ from typing import AsyncIterator from fastapi import APIRouter, HTTPException, Request from fastapi.responses import StreamingResponse +from runtime.storage.event_log import SessionEvent _SSE_MEDIA_TYPE = "text/event-stream" _SESSION_KINDS = frozenset({ @@ -23,6 +24,29 @@ }) +def _event_payload(orch, ev: SessionEvent) -> dict: + payload = dict(ev.payload or {}) + payload.setdefault("id", ev.session_id) + if ev.kind == "session.status_changed": + payload.setdefault("status", payload.get("to")) + if ev.kind == "session.created": + try: + session = 
orch.store.load(ev.session_id) + except Exception: # noqa: BLE001 — SSE enrichment is best effort + session = None + if session is not None: + payload.setdefault("status", session.status) + payload.setdefault("created_at", session.created_at) + payload.setdefault("updated_at", session.updated_at) + label = ( + getattr(session, "query", None) + or (session.extra_fields or {}).get("query") + or session.id + ) + payload.setdefault("label", label) + return payload + + def add_recent_events_routes(api_v1: APIRouter) -> None: """Mount the /sessions/recent/events SSE handler on the api_v1 router. @@ -47,7 +71,8 @@ async def _stream() -> AsyncIterator[str]: if ev.kind in _SESSION_KINDS: payload = {"seq": ev.seq, "kind": ev.kind, "session_id": ev.session_id, - "payload": ev.payload, "ts": ev.ts} + "payload": _event_payload(orch, ev), + "ts": ev.ts} last_seq = ev.seq yield f"data: {json.dumps(payload)}\n\n" # Tail: poll for new rows; exit on client disconnect @@ -57,7 +82,8 @@ async def _stream() -> AsyncIterator[str]: if ev.kind in _SESSION_KINDS: payload = {"seq": ev.seq, "kind": ev.kind, "session_id": ev.session_id, - "payload": ev.payload, "ts": ev.ts} + "payload": _event_payload(orch, ev), + "ts": ev.ts} last_seq = ev.seq yield f"data: {json.dumps(payload)}\n\n" diff --git a/src/runtime/config.py b/src/runtime/config.py index dcebcc5..7d13a68 100644 --- a/src/runtime/config.py +++ b/src/runtime/config.py @@ -761,7 +761,13 @@ def resolve_framework_app_config( class ApiConfig(BaseModel): - """API surface knobs surfaced to the React frontend.""" + """API surface knobs surfaced to the React frontend. + + Values come from the loaded runtime config. Environment variables + choose the config file (``ASR_CONFIG``) but do not override these + CORS fields directly; deployments that need different origins + should set them in YAML. + """ # CORS origins allowed by the FastAPI CORSMiddleware. Default # covers the two common React dev-server URLs (Vite, CRA/Next). 
diff --git a/src/runtime/graph.py b/src/runtime/graph.py index 77f8771..56248c7 100644 --- a/src/runtime/graph.py +++ b/src/runtime/graph.py @@ -743,6 +743,12 @@ async def node(state: GraphState) -> dict: inc_id, "agent_started", agent=skill.name, started_at=started_at, ) + event_log.record( + inc_id, + "session.agent_running", + id=inc_id, + agent=skill.name, + ) except Exception: # noqa: BLE001 — telemetry must not break the agent logger.debug( "event_log.record(agent_started) failed", exc_info=True, diff --git a/src/runtime/orchestrator.py b/src/runtime/orchestrator.py index 1fac9bd..2be5e03 100644 --- a/src/runtime/orchestrator.py +++ b/src/runtime/orchestrator.py @@ -300,6 +300,19 @@ def _emit_status_changed_event( return status_def = statuses.statuses.get(to_status) if status_def is not None and status_def.terminal: + if event_log is not None: + try: + event_log.record( + inc.id, + "session.agent_running", + id=inc.id, + agent=None, + ) + except Exception: # noqa: BLE001 — telemetry must not break finalize + _log.debug( + "event_log.record(session.agent_running) failed", + exc_info=True, + ) _extract_lesson_on_terminal(orch=orch, inc=inc) @@ -1235,13 +1248,20 @@ async def start_session(self, *, query: str, # ``__new__`` (bypassing ``__init__``) working. 
state_overrides_cls = getattr(self, "_state_overrides_cls", None) if state_overrides_cls is not None and state_overrides is not None: - state_overrides_cls.model_validate(state_overrides) + state_overrides = state_overrides_cls.model_validate( + state_overrides + ).model_dump(exclude_none=True) submitter = _coerce_submitter(submitter, reporter_id, reporter_team) sub_id = (submitter or {}).get("id", "user-mock") sub_team = (submitter or {}).get("team", "platform") env = (state_overrides or {}).get("environment", "") - inc = self.store.create(query=query, environment=env, - reporter_id=sub_id, reporter_team=sub_team) + inc = self.store.create( + query=query, + environment=env, + reporter_id=sub_id, + reporter_team=sub_team, + state_overrides=state_overrides, + ) # Emit session.created on the cross-session SSE stream so the # React UI's Other Sessions monitor lights up the new tile in # real time. ``session_id`` already lands on the row; the @@ -1271,6 +1291,8 @@ async def start_session(self, *, query: str, last_agent=None, error=None), config=self._thread_config(inc.id), ) + if not await self._is_graph_paused(inc.id): + await self._finalize_session_status_async(inc.id) return inc.id async def start_investigation(self, *, query: str, environment: str, diff --git a/src/runtime/service.py b/src/runtime/service.py index f52ef6b..afaad7b 100644 --- a/src/runtime/service.py +++ b/src/runtime/service.py @@ -474,7 +474,6 @@ def start_session( ) sub_id = (resolved_submitter or {}).get("id", "user-mock") sub_team = (resolved_submitter or {}).get("team", "platform") - env = (resolved_overrides or {}).get("environment", "") async def _scheduler() -> str: # Enforce the concurrency cap on the loop thread so the @@ -484,6 +483,13 @@ async def _scheduler() -> str: if len(self._registry) >= self.max_concurrent_sessions: raise SessionCapExceeded(self.max_concurrent_sessions) orch = await self._ensure_orchestrator() + overrides_for_create = resolved_overrides + state_overrides_cls = 
getattr(orch, "_state_overrides_cls", None) + if state_overrides_cls is not None and overrides_for_create is not None: + overrides_for_create = state_overrides_cls.model_validate( + overrides_for_create + ).model_dump(exclude_none=True) + env_for_create = (overrides_for_create or {}).get("environment", "") # Allocate the row (and its id) synchronously on the loop # so the caller gets a stable id back. The graph then runs # in a separate task — registration happens here, before @@ -491,9 +497,10 @@ async def _scheduler() -> str: # entry immediately. inc = orch.store.create( query=query, - environment=env, + environment=env_for_create, reporter_id=sub_id, reporter_team=sub_team, + state_overrides=overrides_for_create, ) session_id = inc.id # Emit session.created on the cross-session SSE stream so @@ -541,8 +548,8 @@ async def _run() -> None: raise SessionBusy(session_id) # Hold the per-session lock for the full graph turn, # including any HITL interrupt() pause (D-01). - async with orch._locks.acquire(session_id): - try: + try: + async with orch._locks.acquire(session_id): await orch.graph.ainvoke( GraphState( session=inc, @@ -552,34 +559,36 @@ async def _run() -> None: ), config=orch._thread_config(session_id), ) - except asyncio.CancelledError: - raise - except Exception as exc: # noqa: BLE001 - # Phase 11 (FOC-04 / D-11-04): GraphInterrupt is a - # pending-approval pause, not a failure. Don't stamp - # status='error' on the registry entry -- let - # LangGraph's checkpointer hold the paused state - # and let the UI's Approve/Reject action drive - # resume. - try: - from langgraph.errors import GraphInterrupt - if isinstance(exc, GraphInterrupt): - # Propagate so the underlying Task - # observer (stop_session etc.) still - # sees the exception, but skip the - # status='error' write. - raise - except ImportError: # pragma: no cover - pass - # Mark the registry entry so any concurrent snapshot - # observes the failure before the done-callback - # evicts it. 
The exception itself is preserved on - # the task object for ``stop_session`` and any - # other observer that holds a Task reference. - e = self._registry.get(session_id) - if e is not None: - e.status = "error" - raise + if not await orch._is_graph_paused(session_id): + await orch._finalize_session_status_async(session_id) + except asyncio.CancelledError: + raise + except Exception as exc: # noqa: BLE001 + # Phase 11 (FOC-04 / D-11-04): GraphInterrupt is a + # pending-approval pause, not a failure. Don't stamp + # status='error' on the registry entry -- let + # LangGraph's checkpointer hold the paused state + # and let the UI's Approve/Reject action drive + # resume. + try: + from langgraph.errors import GraphInterrupt + if isinstance(exc, GraphInterrupt): + # Propagate so the underlying Task + # observer (stop_session etc.) still + # sees the exception, but skip the + # status='error' write. + raise + except ImportError: # pragma: no cover + pass + # Mark the registry entry so any concurrent snapshot + # observes the failure before the done-callback + # evicts it. The exception itself is preserved on + # the task object for ``stop_session`` and any + # other observer that holds a Task reference. + e = self._registry.get(session_id) + if e is not None: + e.status = "error" + raise task = asyncio.create_task(_run(), name=f"session:{session_id}") entry.task = task diff --git a/src/runtime/state.py b/src/runtime/state.py index a7f16b4..9679e90 100644 --- a/src/runtime/state.py +++ b/src/runtime/state.py @@ -22,6 +22,16 @@ class IncidentState(Session): # Per-call audit metadata for the risk-rated tool gateway. 
+SessionStatus = Literal[ + "new", + "in_progress", + "awaiting_input", + "resolved", + "escalated", + "stopped", + "error", + "duplicate", +] ToolRisk = Literal["low", "medium", "high"] ToolStatus = Literal[ "executed", # auto / legacy default diff --git a/src/runtime/storage/session_store.py b/src/runtime/storage/session_store.py index 4d0db9e..bad5c1f 100644 --- a/src/runtime/storage/session_store.py +++ b/src/runtime/storage/session_store.py @@ -17,7 +17,8 @@ import json import re from datetime import datetime, timezone -from typing import Generic, Optional, Type, TypeVar +import logging +from typing import Any, Generic, Optional, Type, TypeVar from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore @@ -29,6 +30,8 @@ from runtime.state import AgentRun, Session, TokenUsage, ToolCall from runtime.storage.models import IncidentRow +_log = logging.getLogger(__name__) + # The legacy ``INC-YYYYMMDD-NNN`` pattern stays here for back-compat # validation against on-disk rows minted before the ``Session.id_format`` # hook existed. 
New rows are validated by ``_SESSION_ID_RE`` which @@ -201,7 +204,12 @@ def _next_id(self, session: SqlSession) -> str: # ---------- public API ---------- def create(self, *, query: str, environment: str, reporter_id: str = "user-mock", - reporter_team: str = "platform") -> StateT: + reporter_team: str = "platform", + state_overrides: dict[str, Any] | None = None) -> StateT: + extra_fields = { + k: v for k, v in (state_overrides or {}).items() + if k not in self._ROW_TYPED_DOMAIN_COLUMNS + } with SqlSession(self.engine) as session: now = _now() inc_id = self._next_id(session) @@ -220,6 +228,7 @@ def create(self, *, query: str, environment: str, tool_calls=[], findings={}, user_inputs=[], + extra_fields=extra_fields, ) session.add(row) session.commit() @@ -271,7 +280,10 @@ def save(self, session: StateT) -> None: for k, v in data.items(): setattr(existing, k, v) db_session.commit() - self._refresh_vector(session, prior_text=prior_text) + try: + self._refresh_vector(session, prior_text=prior_text) + except Exception as exc: # noqa: BLE001 — SQL commit already succeeded + self._log_vector_warning("refresh", session.id, exc) def delete(self, incident_id: str) -> StateT: with SqlSession(self.engine) as session: @@ -423,7 +435,10 @@ def _refresh_vector(self, inc: Session, *, prior_text: str) -> None: return if prior_text == text: return - self.vector_store.delete(ids=[inc.id]) + try: + self.vector_store.delete(ids=[inc.id]) + except Exception as exc: # noqa: BLE001 — vector delete is idempotent + self._log_vector_warning("delete", inc.id, exc) from langchain_core.documents import Document self.vector_store.add_documents( [Document(page_content=text, metadata={"id": inc.id})], @@ -431,6 +446,17 @@ def _refresh_vector(self, inc: Session, *, prior_text: str) -> None: ) self._persist_vector() + def _log_vector_warning(self, operation: str, sid: str, exc: Exception) -> None: + backend = type(self.vector_store).__name__ if self.vector_store is not None else "None" + 
_log.warning( + "session vector %s failed sid=%s backend=%s exc=%s", + operation, + sid, + backend, + type(exc).__name__, + exc_info=True, + ) + # ---------- mapping helpers ---------- # # Round-trip is driven by ``state_cls.model_fields`` so any diff --git a/src/runtime/triggers/idempotency.py b/src/runtime/triggers/idempotency.py index 65b0ade..5731a4a 100644 --- a/src/runtime/triggers/idempotency.py +++ b/src/runtime/triggers/idempotency.py @@ -28,6 +28,7 @@ from datetime import datetime, timezone, timedelta from sqlalchemy import DateTime, String, delete, select +from sqlalchemy.exc import IntegrityError from sqlalchemy.engine import Engine, create_engine from sqlalchemy.orm import Mapped, Session as SqlaSession, mapped_column from sqlalchemy.pool import NullPool @@ -106,7 +107,7 @@ def get(self, trigger_name: str, key: str) -> str | None: if cache is not None and key in cache: # Bump recency cache.move_to_end(key) - return cache[key] + return cache[key] or None # SQLite fall-through (outside the threading lock — sqlite3 has # its own locking, and this path is rare). with SqlaSession(self._engine) as s: @@ -134,6 +135,8 @@ def get(self, trigger_name: str, key: str) -> str | None: s.commit() return None session_id = row.session_id + if not session_id: + return None # Refill LRU. with self._lock: cache = self._lru.setdefault(trigger_name, OrderedDict()) @@ -182,6 +185,51 @@ def put( # doesn't accumulate dead records. Cheap (range-bounded delete). self.purge_expired() + def reserve( + self, + trigger_name: str, + key: str, + *, + ttl_hours: int = 24, + ) -> bool: + """Atomically reserve a fresh idempotency key. + + Returns ``True`` only for the caller that inserted the row. A + ``False`` result means another request has either already + completed and stored ``session_id`` or is still in flight. 
+ """ + now = _utc_now() + expires_at = now + timedelta(hours=ttl_hours) + with SqlaSession(self._engine) as s: + existing = s.get(IdempotencyRow, (trigger_name, key)) + if existing is not None: + existing_expires = existing.expires_at + if existing_expires.tzinfo is None: + existing_expires = existing_expires.replace(tzinfo=timezone.utc) + if existing_expires <= now: + s.delete(existing) + s.commit() + else: + return False + s.add(IdempotencyRow( + trigger_name=trigger_name, + key=key, + session_id="", + created_at=now, + expires_at=expires_at, + )) + try: + s.commit() + except IntegrityError: + s.rollback() + return False + with self._lock: + cache = self._lru.setdefault(trigger_name, OrderedDict()) + cache[key] = "" + cache.move_to_end(key) + self._evict_if_needed(cache) + return True + def purge_expired(self) -> int: """Delete all rows whose ``expires_at`` is in the past. Returns the number of rows removed.""" diff --git a/src/runtime/triggers/registry.py b/src/runtime/triggers/registry.py index 82b5927..eba5a24 100644 --- a/src/runtime/triggers/registry.py +++ b/src/runtime/triggers/registry.py @@ -23,6 +23,7 @@ """ from __future__ import annotations +import asyncio import importlib.metadata import logging from datetime import datetime, timezone @@ -277,9 +278,14 @@ async def dispatch( if spec is None: raise KeyError(f"unknown trigger: {name!r}") - # Idempotency hit: return cached session id without invoking - # transform / orchestrator. Per R3 in the plan, transform errors - # are NOT cached — only successful dispatches. + ttl = ( + spec.config.idempotency_ttl_hours + if isinstance(spec.config, WebhookTriggerConfig) + else 24 + ) + + # Fast idempotency hit: return cached session id without + # invoking transform / orchestrator. 
if idempotency_key and self._idempotency is not None: cached = self._idempotency.get(name, idempotency_key) if cached is not None: @@ -304,15 +310,25 @@ async def dispatch( received_at=datetime.now(timezone.utc), ) + # Idempotency reservation: only the caller that atomically + # inserts the key may start a session. Followers wait briefly + # for the first caller to publish the resulting session id. + if idempotency_key and self._idempotency is not None: + if not self._idempotency.reserve(name, idempotency_key, ttl_hours=ttl): + for _ in range(100): + cached = self._idempotency.get(name, idempotency_key) + if cached is not None: + return cached + await asyncio.sleep(0.01) + raise RuntimeError( + f"idempotency key {idempotency_key!r} for trigger " + f"{name!r} is still in flight" + ) + session_id = await self._start_session_fn(trigger=info, **kwargs) # Record successful dispatch for idempotency. if idempotency_key and self._idempotency is not None: - ttl = ( - spec.config.idempotency_ttl_hours - if isinstance(spec.config, WebhookTriggerConfig) - else 24 - ) self._idempotency.put(name, idempotency_key, session_id, ttl_hours=ttl) return session_id diff --git a/tests/test_agent_node.py b/tests/test_agent_node.py index f425747..249328c 100644 --- a/tests/test_agent_node.py +++ b/tests/test_agent_node.py @@ -9,6 +9,7 @@ from runtime.llm import StubChatModel from runtime.storage.embeddings import build_embedder from runtime.storage.engine import build_engine +from runtime.storage.event_log import EventLog from runtime.storage.models import Base from runtime.storage.session_store import SessionStore @@ -51,6 +52,7 @@ async def test_agent_node_runs_llm_records_agent_run_and_routes(incident): skill=skill, llm=llm, tools=[], decide_route=lambda inc: "default", store=store, + event_log=EventLog(engine=store.engine), terminal_tool_names=_TEST_TERMINAL_NAMES, patch_tool_names=_TEST_PATCH_NAMES, ) @@ -74,6 +76,10 @@ async def 
test_agent_node_runs_llm_records_agent_run_and_routes(incident): # The runner stamps these onto the AgentRun. assert intake_runs[0].confidence == approx(0.85) assert intake_runs[0].confidence_rationale == "stub envelope rationale" + events = list(EventLog(engine=store.engine).iter_for(inc.id)) + running = [e for e in events if e.kind == "session.agent_running"] + assert running + assert running[0].payload == {"id": inc.id, "agent": "intake"} @pytest.mark.asyncio diff --git a/tests/test_api.py b/tests/test_api.py index 46dce05..4817f83 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -218,6 +218,44 @@ async def test_post_sessions_omitting_submitter_uses_defaults(cfg): assert res.json()["session_id"] +@pytest.mark.asyncio +async def test_post_sessions_accepts_state_overrides(cfg): + app = build_app(cfg) + async with _client_with_lifespan(app) as client: + res = await client.post( + "/api/v1/sessions", + json={ + "query": "review PR", + "environment": "staging", + "state_overrides": { + "pr_url": "https://github.com/foo/bar/pull/1", + "repo": "foo/bar", + }, + }, + ) + assert res.status_code == 201 + sid = res.json()["session_id"] + inc = app.state.orchestrator.store.load(sid) + assert inc.extra_fields["pr_url"] == "https://github.com/foo/bar/pull/1" + assert inc.extra_fields["repo"] == "foo/bar" + + +@pytest.mark.asyncio +async def test_post_sessions_rejects_conflicting_environment(cfg): + app = build_app(cfg) + async with _client_with_lifespan(app) as client: + res = await client.post( + "/api/v1/sessions", + json={ + "query": "conflict", + "environment": "staging", + "state_overrides": {"environment": "production"}, + }, + ) + assert res.status_code == 422 + assert res.json()["error"]["code"] == "conflicting_environment" + + @pytest.mark.asyncio async def test_get_sessions_returns_list(cfg): """GET /sessions returns a JSON list (possibly empty); shape matches diff --git a/tests/test_api_react_surface.py b/tests/test_api_react_surface.py index 
d367b3d..b484921 100644 --- a/tests/test_api_react_surface.py +++ b/tests/test_api_react_surface.py @@ -490,7 +490,7 @@ async def test_unknown_endpoint_returns_404_envelope(cfg): automatic 404 raised by Starlette for unknown routes).""" app = build_app(cfg) async with _client_with_lifespan(app) as client: - res = await client.get("/this-route-does-not-exist") + res = await client.get("/api/v1/this-route-does-not-exist") assert res.status_code == 404 body = res.json() # Starlette's default 404 body is ``{"detail": "Not Found"}``; the diff --git a/tests/test_api_recent_events.py b/tests/test_api_recent_events.py index e06bcb2..b79d48a 100644 --- a/tests/test_api_recent_events.py +++ b/tests/test_api_recent_events.py @@ -1,55 +1,41 @@ -"""Cross-session SSE — emits session-level events across all sessions.""" -import asyncio -import json -from fastapi.testclient import TestClient -from runtime.api import build_app -from tests.test_api_v1_url_move import _cfg +"""Cross-session SSE payload enrichment.""" +from __future__ import annotations +import json +from types import SimpleNamespace -def test_recent_events_replays_session_creates(tmp_path): - """Two POST /sessions calls each emit a session.created event; - the recent SSE stream replays them on connect.""" - app = build_app(_cfg(tmp_path)) - with TestClient(app) as client: - # Create two sessions - for q in ["alpha", "beta"]: - r = client.post("/api/v1/sessions", json={ - "query": q, "environment": "dev", - "submitter": {"id": "u1", "team": "p"}, - }) - assert r.status_code == 201 +from sqlalchemy import create_engine - # Direct-call the SSE handler with a forced-disconnect to drain - # the backlog (avoids long-poll deadlock in TestClient). 
- from starlette.requests import Request as StarletteRequest +from runtime.api_recent_events import _event_payload +from runtime.storage.event_log import EventLog +from runtime.storage.models import Base +from runtime.storage.session_store import SessionStore - async def _disc(): - return True # exit tail loop after backlog drain - sse_route = next( - r for r in app.router.routes - if getattr(r, "path", "") == "/api/v1/sessions/recent/events" - ) - scope = { - "type": "http", "method": "GET", - "path": "/api/v1/sessions/recent/events", - "query_string": b"since=0", "headers": [], "app": app, - } - request = StarletteRequest(scope) - request.is_disconnected = _disc # type: ignore[method-assign] - response = asyncio.run( - sse_route.endpoint(request=request, since=0) # type: ignore[attr-defined] - ) +def test_recent_events_session_created_payload_matches_session_summary(tmp_path): + engine = create_engine(f"sqlite:///{tmp_path / 't.db'}") + Base.metadata.create_all(engine) + store = SessionStore(engine=engine) + log = EventLog(engine=engine) + inc = store.create(query="alpha", environment="dev") + log.record(inc.id, "session.created") + ev = list(log.iter_recent())[0] - async def _drain(): - frames = [] - async for chunk in response.body_iterator: - text = chunk.decode() if isinstance(chunk, bytes) else chunk - for line in text.splitlines(): - if line.startswith("data: "): - frames.append(json.loads(line[len("data: "):])) - return frames + payload = { + "seq": ev.seq, + "kind": ev.kind, + "session_id": ev.session_id, + "payload": _event_payload(SimpleNamespace(store=store), ev), + "ts": ev.ts, + } + frame = f"data: {json.dumps(payload)}\n\n" + decoded = json.loads(frame.split("data: ", 1)[1]) - frames = asyncio.run(_drain()) - kinds = [f["kind"] for f in frames] - assert kinds.count("session.created") == 2 + assert decoded["kind"] == "session.created" + assert decoded["payload"]["id"] == inc.id + assert { + "id", + "status", + "created_at", + "updated_at", + } <= 
decoded["payload"].keys() diff --git a/tests/test_code_review_e2e.py b/tests/test_code_review_e2e.py index 1485611..7574d0e 100644 --- a/tests/test_code_review_e2e.py +++ b/tests/test_code_review_e2e.py @@ -103,6 +103,7 @@ async def test_code_review_e2e_approves_clean_pr(tmp_path): inc.tool_calls.append(_exec_set_recommendation( recommendation="approve", summary="LGTM", )) + inc.status = "in_progress" inc.extra_fields["overall_recommendation"] = "approve" orch.store.save(inc) @@ -155,6 +156,7 @@ async def test_code_review_e2e_request_changes_on_critical(tmp_path): recommendation="request_changes", summary="Critical bug found", )) + inc.status = "in_progress" orch.store.save(inc) new_status = await orch._finalize_session_status_async(sid) @@ -183,6 +185,7 @@ async def test_code_review_e2e_comment_on_warnings(tmp_path): recommendation="comment", summary="Minor nits", )) + inc.status = "in_progress" orch.store.save(inc) new_status = await orch._finalize_session_status_async(sid) @@ -256,6 +259,7 @@ async def test_code_review_e2e_no_incident_imports(tmp_path): inc.tool_calls.append(_exec_set_recommendation( recommendation="approve", )) + inc.status = "in_progress" orch.store.save(inc) await orch._finalize_session_status_async(sid) finally: diff --git a/tests/test_finalizer_paths.py b/tests/test_finalizer_paths.py new file mode 100644 index 0000000..04801f8 --- /dev/null +++ b/tests/test_finalizer_paths.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import asyncio +from types import SimpleNamespace + +import pytest + +from runtime.config import ( + AppConfig, + LLMConfig, + MCPConfig, + MCPServerConfig, + MetadataConfig, + OrchestratorConfig, + Paths, + StorageConfig, +) +from runtime.orchestrator import Orchestrator +from runtime.service import OrchestratorService +from runtime.state import ToolCall +from runtime.terminal_tools import StatusDef, TerminalToolRule + + +def _cfg(tmp_path) -> AppConfig: + return AppConfig( + llm=LLMConfig.stub(), + 
mcp=MCPConfig(servers=[ + MCPServerConfig( + name="local_inc", + transport="in_process", + module="examples.incident_management.mcp_server", + category="incident_management", + ), + MCPServerConfig( + name="local_obs", + transport="in_process", + module="examples.incident_management.mcp_servers.observability", + category="observability", + ), + MCPServerConfig( + name="local_rem", + transport="in_process", + module="examples.incident_management.mcp_servers.remediation", + category="remediation", + ), + MCPServerConfig( + name="local_user", + transport="in_process", + module="examples.incident_management.mcp_servers.user_context", + category="user_context", + ), + ]), + storage=StorageConfig( + metadata=MetadataConfig(url=f"sqlite:///{tmp_path}/test.db"), + ), + paths=Paths( + skills_dir="examples/incident_management/skills", + incidents_dir=str(tmp_path), + ), + orchestrator=OrchestratorConfig( + statuses={ + "in_progress": StatusDef( + name="in_progress", terminal=False, kind="pending", + ), + "resolved": StatusDef( + name="resolved", terminal=True, kind="success", + ), + "needs_review": StatusDef( + name="needs_review", terminal=True, kind="needs_review", + ), + }, + terminal_tools=[ + TerminalToolRule(tool_name="mark_resolved", status="resolved"), + ], + default_terminal_status="needs_review", + ), + ) + + +class _CompletingGraph: + def __init__(self, store, *, service=None) -> None: + self.store = store + self.service = service + self.captured_entries = [] + + async def ainvoke(self, state, *, config): + inc = state["session"] + if self.service is not None: + self.captured_entries.append(self.service._registry.get(inc.id)) + inc.status = "in_progress" + inc.tool_calls.append(ToolCall( + agent="resolution", + tool="mark_resolved", + args={}, + result={"status": "resolved"}, + ts="2026-01-01T00:00:00Z", + status="executed", + )) + self.store.save(inc) + return {} + + async def aget_state(self, config): + return SimpleNamespace(next=()) + + +@pytest.mark.asyncio 
+async def test_orchestrator_start_session_finalizes_completed_non_streaming_run( + tmp_path, +): + orch = await Orchestrator.create(_cfg(tmp_path)) + try: + orch.graph = _CompletingGraph(orch.store) + sid = await orch.start_session(query="db pool exhausted") + assert orch.store.load(sid).status == "resolved" + finally: + await orch.aclose() + + +def test_service_background_start_session_finalizes_completed_run(tmp_path): + service = OrchestratorService.get_or_create(_cfg(tmp_path)) + service.start() + try: + async def _install_graph(): + orch = await service._ensure_orchestrator() + graph = _CompletingGraph(orch.store, service=service) + orch.graph = graph + return graph + + graph = service.submit_and_wait(_install_graph(), timeout=10.0) + sid = service.start_session(query="db pool exhausted") + + async def _await_background_task(): + while not graph.captured_entries: + await asyncio.sleep(0.01) + entry = graph.captured_entries[0] + if entry is not None and entry.task is not None: + await entry.task + + service.submit_and_wait(_await_background_task(), timeout=10.0) + orch = service.submit_and_wait(service._ensure_orchestrator(), timeout=10.0) + assert orch.store.load(sid).status == "resolved" + finally: + service.shutdown() diff --git a/tests/test_state_overrides_schema.py b/tests/test_state_overrides_schema.py index 59b2bf9..f8cceb0 100644 --- a/tests/test_state_overrides_schema.py +++ b/tests/test_state_overrides_schema.py @@ -23,8 +23,11 @@ """ from __future__ import annotations +from types import SimpleNamespace + import pytest from pydantic import ValidationError +from sqlalchemy.orm import Session as SqlSession from runtime.config import ( AppConfig, @@ -39,6 +42,7 @@ ) from runtime.orchestrator import Orchestrator from runtime.state import ToolCall +from runtime.storage.models import IncidentRow # --------------------------------------------------------------------------- @@ -78,6 +82,14 @@ ] +class _NoopGraph: + async def ainvoke(self, state, *, config): 
+ return {} + + async def aget_state(self, config): + return SimpleNamespace(next=()) + + def _base_cfg( tmp_path, *, @@ -318,6 +330,44 @@ async def test_cross_app_rejection_code_review_rejects_incident_shape(tmp_path): await orch.aclose() +@pytest.mark.asyncio +async def test_code_review_state_overrides_persist_extra_fields_and_environment( + tmp_path, +): + cfg = _base_cfg( + tmp_path, + state_overrides_schema=( + "examples.code_review.state.CodeReviewStateOverrides" + ), + skills_dir="examples/code_review/skills", + ) + orch = await Orchestrator.create(cfg) + try: + orch.graph = _NoopGraph() + sid = await orch.start_session( + query="Review PR", + state_overrides={ + "pr_url": "https://github.com/foo/bar/pull/1", + "repo": "foo/bar", + "environment": "staging", + }, + ) + inc = orch.store.load(sid) + assert inc.extra_fields["pr_url"] == "https://github.com/foo/bar/pull/1" + assert inc.extra_fields["repo"] == "foo/bar" + assert inc.extra_fields.get("environment") == "staging" + with SqlSession(orch.store.engine) as db: + row = db.get(IncidentRow, sid) + assert row is not None + assert row.environment == "staging" + assert row.extra_fields == { + "pr_url": "https://github.com/foo/bar/pull/1", + "repo": "foo/bar", + } + finally: + await orch.aclose() + + # --------------------------------------------------------------------------- # YAML round-trip. 
# --------------------------------------------------------------------------- diff --git a/tests/test_status_change_telemetry.py b/tests/test_status_change_telemetry.py index 6396fb7..f7fe52b 100644 --- a/tests/test_status_change_telemetry.py +++ b/tests/test_status_change_telemetry.py @@ -98,6 +98,11 @@ async def test_finalize_with_mark_resolved_emits_status_changed(tmp_path): assert e.payload["from"] == "in_progress" assert e.payload["to"] == "resolved" assert e.payload["cause"] == "mark_resolved" + clear_events = [ + e for e in events + if e.kind == "session.agent_running" + ] + assert clear_events[-1].payload == {"id": inc.id, "agent": None} finally: await orch.aclose() diff --git a/tests/test_storage_vector.py b/tests/test_storage_vector.py index 3ffd7e3..3925386 100644 --- a/tests/test_storage_vector.py +++ b/tests/test_storage_vector.py @@ -2,9 +2,12 @@ from __future__ import annotations from pathlib import Path import pytest +from sqlalchemy import create_engine from runtime.config import EmbeddingConfig, ProviderConfig, VectorConfig from runtime.storage.embeddings import build_embedder +from runtime.storage.models import Base +from runtime.storage.session_store import SessionStore def _stub_embedder(dim: int = 8): @@ -82,3 +85,31 @@ def test_distance_to_similarity_unknown_raises(): from runtime.storage.vector import distance_to_similarity with pytest.raises(ValueError, match="unknown distance strategy"): distance_to_similarity(0.5, "manhattan") + + +class _FailingVectorStore: + def delete(self, *, ids): + raise KeyError(ids[0]) + + def add_documents(self, documents, *, ids): + raise RuntimeError("vector write failed") + + +def test_session_store_save_persists_sql_when_vector_refresh_fails(tmp_path, caplog): + engine = create_engine(f"sqlite:///{tmp_path / 't.db'}") + Base.metadata.create_all(engine) + store = SessionStore( + engine=engine, + embedder=_stub_embedder(), + ) + inc = store.create(query="first", environment="dev") + store.vector_store = 
_FailingVectorStore() # type: ignore[assignment] + inc.extra_fields["query"] = "changed" + + with caplog.at_level("WARNING", logger="runtime.storage.session_store"): + store.save(inc) + + loaded = store.load(inc.id) + assert loaded.extra_fields["query"] == "changed" + assert any("vector delete failed" in rec.getMessage() for rec in caplog.records) + assert any("vector refresh failed" in rec.getMessage() for rec in caplog.records) diff --git a/tests/test_triggers/test_orchestrator_trigger_kwarg.py b/tests/test_triggers/test_orchestrator_trigger_kwarg.py index e6bacb1..be4d76c 100644 --- a/tests/test_triggers/test_orchestrator_trigger_kwarg.py +++ b/tests/test_triggers/test_orchestrator_trigger_kwarg.py @@ -8,6 +8,10 @@ from runtime.triggers.base import TriggerInfo +async def _always_paused() -> bool: + return True + + @pytest.mark.asyncio async def test_orchestrator_start_session_records_trigger(tmp_path, monkeypatch): """``Orchestrator.start_session(trigger=...)`` stamps provenance on @@ -46,6 +50,7 @@ async def ainvoke(self, state, config): orch.store = _FakeStore() orch.graph = _FakeGraph() orch._thread_config = lambda sid: {"configurable": {"thread_id": sid}} + orch._is_graph_paused = lambda sid: _always_paused() # Tests that bypass __init__ must set the dedup pipeline # attribute to ``None`` so the dedup-check shortcut returns False # without touching the (uninitialised) attribute. 
@@ -90,6 +95,7 @@ async def ainvoke(self, state, config): orch.store = _FakeStore() orch.graph = _FakeGraph() orch._thread_config = lambda sid: {"configurable": {"thread_id": sid}} + orch._is_graph_paused = lambda sid: _always_paused() orch.dedup_pipeline = None sid = await orch.start_session(query="q", environment="dev") diff --git a/tests/test_triggers/test_registry.py b/tests/test_triggers/test_registry.py index c656247..1a89e78 100644 --- a/tests/test_triggers/test_registry.py +++ b/tests/test_triggers/test_registry.py @@ -1,6 +1,8 @@ """TriggerRegistry tests — resolution, lifecycle, dispatch, idempotency.""" from __future__ import annotations +import asyncio + import pytest from runtime.triggers import TriggerRegistry @@ -128,6 +130,41 @@ async def fake_start(*, trigger=None, **kw): assert len(calls) == 1 # only one call to start_session +@pytest.mark.asyncio +async def test_idempotency_reservation_allows_one_concurrent_start(tmp_path): + """Concurrent duplicate keys must not both start sessions.""" + calls = [] + started = asyncio.Event() + release = asyncio.Event() + + async def fake_start(*, trigger=None, **kw): + calls.append(kw) + started.set() + await release.wait() + return f"INC-{len(calls)}" + + store = IdempotencyStore.connect(f"sqlite:///{tmp_path / 'idem.db'}") + reg = TriggerRegistry.create( + [_webhook_cfg()], start_session_fn=fake_start, idempotency=store + ) + from tests.test_triggers.conftest import PagerDutyPayload + + payload = PagerDutyPayload(incident_id="P-1", summary="boom") + first = asyncio.create_task( + reg.dispatch("pd", payload, idempotency_key="K-race") + ) + await started.wait() + second = asyncio.create_task( + reg.dispatch("pd", payload, idempotency_key="K-race") + ) + await asyncio.sleep(0) + release.set() + + s1, s2 = await asyncio.gather(first, second) + assert s1 == s2 + assert len(calls) == 1 + + @pytest.mark.asyncio async def test_lifecycle_start_stop_idempotent(): _StubTransport.instances.clear() diff --git 
a/tests/test_type_literal_parity.py b/tests/test_type_literal_parity.py new file mode 100644 index 0000000..fc35aaf --- /dev/null +++ b/tests/test_type_literal_parity.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import ast +import re +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] + + +def _python_literal_values(alias: str) -> set[str]: + tree = ast.parse((ROOT / "src/runtime/state.py").read_text()) + for node in tree.body: + if isinstance(node, ast.Assign): + if not any(isinstance(t, ast.Name) and t.id == alias for t in node.targets): + continue + value = node.value + if ( + isinstance(value, ast.Subscript) + and isinstance(value.value, ast.Name) + and value.value.id == "Literal" + ): + items = value.slice.elts if isinstance(value.slice, ast.Tuple) else [value.slice] + return { + item.value + for item in items + if isinstance(item, ast.Constant) and isinstance(item.value, str) + } + raise AssertionError(f"literal alias {alias!r} not found") + + +def _ts_union_values(interface: str, field: str) -> set[str]: + text = (ROOT / "web/src/api/types.ts").read_text() + match = re.search( + rf"export interface {interface} \{{(?P<body>.*?)\n\}}", + text, + flags=re.DOTALL, + ) + assert match is not None, f"interface {interface!r} not found" + field_match = re.search(rf"\b{field}:\s*(?P<union>[^;]+);", match.group("body")) + assert field_match is not None, f"field {interface}.{field} not found" + return set(re.findall(r"'([^']+)'", field_match.group("union"))) + + +def test_session_status_literals_match_typescript(): + assert _ts_union_values("Session", "status") == _python_literal_values( + "SessionStatus" + ) + + +def test_tool_call_status_literals_match_typescript(): + assert _ts_union_values("ToolCall", "status") == _python_literal_values( + "ToolStatus" + ) diff --git a/web/src/api/types.ts b/web/src/api/types.ts index 100825e..d8339b0 100644 --- a/web/src/api/types.ts +++ b/web/src/api/types.ts @@ -5,7 +5,7 @@ export type SessionId = string; 
// SES-YYYYMMDD-NNN export interface Session { id: SessionId; - status: 'in_progress' | 'awaiting_input' | 'matched' | 'resolved' | 'escalated' | 'stopped' | 'error' | 'new'; + status: 'new' | 'in_progress' | 'awaiting_input' | 'resolved' | 'escalated' | 'stopped' | 'error' | 'duplicate'; created_at: string; // ISO UTC updated_at: string; deleted_at: string | null; @@ -38,7 +38,7 @@ export interface ToolCall { result: unknown; ts: string; risk: 'low' | 'medium' | 'high' | null; - status: 'executed' | 'executed_with_notify' | 'pending_approval' | 'approved' | 'rejected' | 'timeout' | 'auto_rejected'; + status: 'executed' | 'executed_with_notify' | 'pending_approval' | 'approved' | 'rejected' | 'timeout'; approver: string | null; approved_at: string | null; approval_rationale: string | null; @@ -70,6 +70,13 @@ export interface SessionFullBundle { vm_seq: number; } +export interface SessionStartBody { + query: string; + environment: string; + submitter?: Record | null; + state_overrides?: Record; +} + export interface UiHints { brand_name: string; brand_logo_url: string | null; diff --git a/web/src/modals/NewSessionModal.tsx b/web/src/modals/NewSessionModal.tsx index f6b4feb..8a03dd6 100644 --- a/web/src/modals/NewSessionModal.tsx +++ b/web/src/modals/NewSessionModal.tsx @@ -2,6 +2,7 @@ import { useState, useEffect } from 'react'; import type { CSSProperties } from 'react'; import { Modal } from '@/components/Modal'; import { apiFetch, ApiClientError } from '@/api/client'; +import type { SessionStartBody } from '@/api/types'; interface NewSessionModalProps { open: boolean; @@ -80,7 +81,11 @@ export function NewSessionModal({ open, onOpenChange, environments, onCreated }: try { const res = await apiFetch('/sessions', { method: 'POST', - json: { query: query.trim(), environment, submitter: { id: 'operator' } }, + json: { + query: query.trim(), + environment, + submitter: { id: 'operator' }, + } satisfies SessionStartBody, }); onCreated(res.session_id); 
onOpenChange(false); diff --git a/web/src/state/sessionReducer.ts b/web/src/state/sessionReducer.ts index 9ffd1ab..7705058 100644 --- a/web/src/state/sessionReducer.ts +++ b/web/src/state/sessionReducer.ts @@ -53,6 +53,11 @@ export function sessionReducer(state: SessionState, action: Action): SessionStat started_at: p.started_at ?? '', ended_at: p.ended_at ?? ev.ts, summary: p.summary ?? '', + token_usage: p.token_usage ?? { + input_tokens: Number(ev.payload.input_tokens ?? 0), + output_tokens: Number(ev.payload.output_tokens ?? 0), + total_tokens: Number(ev.payload.total_tokens ?? 0), + }, confidence: p.confidence ?? null, confidence_rationale: p.confidence_rationale ?? null, signal: p.signal ?? null, @@ -75,7 +80,8 @@ export function sessionReducer(state: SessionState, action: Action): SessionStat }]; break; } - case 'approval_pending': { + case 'approval_pending': + case 'gate_fired': { const p = ev.payload as Partial; toolCalls = [...toolCalls, { agent: p.agent ?? '', @@ -92,8 +98,9 @@ export function sessionReducer(state: SessionState, action: Action): SessionStat break; } case 'status_changed': { - const p = ev.payload as { status?: Session['status'] }; - if (session && p.status) session = { ...session, status: p.status }; + const p = ev.payload as { status?: Session['status']; to?: Session['status'] }; + const status = p.to ?? 
p.status; + if (session && status) session = { ...session, status }; break; } } diff --git a/web/tests/unit/sessionReducer.test.ts b/web/tests/unit/sessionReducer.test.ts index 973a1cd..306cab4 100644 --- a/web/tests/unit/sessionReducer.test.ts +++ b/web/tests/unit/sessionReducer.test.ts @@ -47,11 +47,13 @@ describe('sessionReducer', () => { const state = sessionReducer(initialSessionState, { type: 'bootstrap', bundle: baseBundle }); const finished: SessionEvent = { seq: 1, kind: 'agent_finished', ts: 'x', - payload: { agent: 'intake', summary: 'done', confidence: 0.9 }, + payload: { agent: 'intake', input_tokens: 3, output_tokens: 4, total_tokens: 7 }, }; const next = sessionReducer(state, { type: 'event', event: finished }); expect(next.agentsRun).toHaveLength(1); expect(next.agentsRun[0]?.agent).toBe('intake'); + expect(next.agentsRun[0]?.summary).toBe(''); + expect(next.agentsRun[0]?.token_usage?.total_tokens).toBe(7); }); it('event "tool_invoked" appends a ToolCall with status="executed"', () => { @@ -76,14 +78,35 @@ describe('sessionReducer', () => { expect(next.toolCalls[0]?.risk).toBe('high'); }); + it('event "gate_fired" inserts a pending tool call', () => { + const state = sessionReducer(initialSessionState, { type: 'bootstrap', bundle: baseBundle }); + const ev: SessionEvent = { + seq: 1, kind: 'gate_fired', ts: 'x', + payload: { agent: 'investigate', tool: 'rem:propose_fix', reason: 'high_risk' }, + }; + const next = sessionReducer(state, { type: 'event', event: ev }); + expect(next.toolCalls[0]?.status).toBe('pending_approval'); + expect(next.toolCalls[0]?.tool).toBe('rem:propose_fix'); + }); + it('event "status_changed" updates session.status', () => { const state = sessionReducer(initialSessionState, { type: 'bootstrap', bundle: baseBundle }); const ev: SessionEvent = { seq: 1, kind: 'status_changed', ts: 'x', - payload: { status: 'resolved' }, + payload: { from: 'in_progress', to: 'resolved', cause: 'mark_resolved' }, }; const next = 
sessionReducer(state, { type: 'event', event: ev }); expect(next.session?.status).toBe('resolved'); }); + + it('event "status_changed" keeps backwards-compatible status payload', () => { + const state = sessionReducer(initialSessionState, { type: 'bootstrap', bundle: baseBundle }); + const ev: SessionEvent = { + seq: 1, kind: 'status_changed', ts: 'x', + payload: { status: 'error' }, + }; + const next = sessionReducer(state, { type: 'event', event: ev }); + expect(next.session?.status).toBe('error'); + }); }); });