Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .fern/metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@
]
},
"originGitCommit": "efe71642022d9d3303fd78c648e5b2539192230e",
"sdkVersion": "1.2.1"
"sdkVersion": "1.2.2"
}
79 changes: 70 additions & 9 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ dynamic = ["version"]

[tool.poetry]
name = "schematichq"
version = "1.2.1"
version = "1.2.2"
description = ""
readme = "README.md"
authors = []
Expand Down
4 changes: 2 additions & 2 deletions src/schematic/core/client_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ def get_headers(self) -> typing.Dict[str, str]:
import platform

headers: typing.Dict[str, str] = {
"User-Agent": "schematichq/1.2.1",
"User-Agent": "schematichq/1.2.2",
"X-Fern-Language": "Python",
"X-Fern-Runtime": f"python/{platform.python_version()}",
"X-Fern-Platform": f"{platform.system().lower()}/{platform.release()}",
"X-Fern-SDK-Name": "schematichq",
"X-Fern-SDK-Version": "1.2.1",
"X-Fern-SDK-Version": "1.2.2",
**(self.get_custom_headers() or {}),
}
headers["X-Schematic-Api-Key"] = self.api_key
Expand Down
144 changes: 92 additions & 52 deletions src/schematic/datastream/datastream_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@ def __init__(self, options: DataStreamClientOptions) -> None:
self._pending_user: Dict[str, List[asyncio.Future[RulesengineUser]]] = {}
self._pending_flags: Optional[asyncio.Future[bool]] = None

# Per-entity locks serialize read-modify-write on the cache so that the
# WS handler and external callers (e.g. update_company_metrics) can't
# lose each other's updates when they interleave at await points.
self._company_locks: Dict[str, asyncio.Lock] = {}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to consider bounding these

self._user_locks: Dict[str, asyncio.Lock] = {}

# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
Expand Down Expand Up @@ -434,16 +440,23 @@ async def update_company_metrics(self, keys: Dict[str, str], event: str, quantit
if company is None:
return

updated = company.model_copy(deep=True)
if updated.metrics:
new_metrics = [
metric.model_copy(update={"value": (metric.value or 0) + quantity})
if metric.event_subtype == event else metric
for metric in updated.metrics
]
updated = updated.model_copy(update={"metrics": new_metrics})
async with self._get_company_lock(company.id):
# Re-fetch under lock — a concurrent partial may have changed state
# between the unlocked lookup above and lock acquisition.
company = await self._get_company_from_cache(keys)
if company is None:
return

updated = company.model_copy(deep=True)
if updated.metrics:
new_metrics = [
metric.model_copy(update={"value": (metric.value or 0) + quantity})
if metric.event_subtype == event else metric
for metric in updated.metrics
]
updated = updated.model_copy(update={"metrics": new_metrics})

await self._cache_company(updated)
await self._cache_company(updated)

async def close(self) -> None:
"""Gracefully close the datastream client."""
Expand Down Expand Up @@ -518,75 +531,88 @@ async def _handle_company_message(self, message: DataStreamResp) -> None:
return

# For partial updates, look up the cached entity by envelope entity_id
# and merge the wrapped data payload into it.
# and merge the wrapped data payload into it. The read-merge-write must
# run under a per-id lock so concurrent calls (e.g. update_company_metrics)
# can't read stale state and overwrite our merge.
if message.message_type == MessageType.PARTIAL.value:
entity_id = message.entity_id
if not entity_id:
self._logger.warning("Partial company message missing entity_id")
return

rk = self._resource_id_cache_key(_PREFIX_COMPANY, entity_id)
raw_existing = await self._company_cache.get(rk)
if raw_existing is None:
self._logger.warning("Partial company update for unknown entity: %s", entity_id)
return
async with self._get_company_lock(entity_id):
rk = self._resource_id_cache_key(_PREFIX_COMPANY, entity_id)
raw_existing = await self._company_cache.get(rk)
if raw_existing is None:
self._logger.warning("Partial company update for unknown entity: %s", entity_id)
return

existing = _validate(RulesengineCompany, raw_existing)
partial_data = raw if isinstance(raw, dict) else raw.model_dump()
try:
company = partial_company(existing, partial_data)
except Exception as exc:
self._logger.error("Failed to merge partial company: %s", exc)
return
else:
company = _validate(RulesengineCompany, raw)
existing = _validate(RulesengineCompany, raw_existing)
partial_data = raw if isinstance(raw, dict) else raw.model_dump()
try:
company = partial_company(existing, partial_data)
except Exception as exc:
self._logger.error("Failed to merge partial company: %s", exc)
return

if message.message_type == MessageType.DELETE.value:
await self._delete_entity(
company.id, company.keys, _PREFIX_COMPANY, self._company_cache, self._company_key_cache,
)
await self._cache_company(company)
self._notify_pending_company(company.keys or {}, company)
return

await self._cache_company(company)
self._notify_pending_company(company.keys or {}, company)
company = _validate(RulesengineCompany, raw)

async with self._get_company_lock(company.id):
if message.message_type == MessageType.DELETE.value:
await self._delete_entity(
company.id, company.keys, _PREFIX_COMPANY, self._company_cache, self._company_key_cache,
)
return

await self._cache_company(company)
self._notify_pending_company(company.keys or {}, company)

async def _handle_user_message(self, message: DataStreamResp) -> None:
raw = message.data
if not raw:
return

# For partial updates, look up the cached entity by envelope entity_id
# and merge the wrapped data payload into it.
# See _handle_company_message — same per-id locking rationale.
if message.message_type == MessageType.PARTIAL.value:
entity_id = message.entity_id
if not entity_id:
self._logger.warning("Partial user message missing entity_id")
return

rk = self._resource_id_cache_key(_PREFIX_USER, entity_id)
raw_existing = await self._user_cache.get(rk)
if raw_existing is None:
self._logger.warning("Partial user update for unknown entity: %s", entity_id)
return
async with self._get_user_lock(entity_id):
rk = self._resource_id_cache_key(_PREFIX_USER, entity_id)
raw_existing = await self._user_cache.get(rk)
if raw_existing is None:
self._logger.warning("Partial user update for unknown entity: %s", entity_id)
return

existing = _validate(RulesengineUser, raw_existing)
partial_data = raw if isinstance(raw, dict) else raw.model_dump()
try:
user = partial_user(existing, partial_data)
except Exception as exc:
self._logger.error("Failed to merge partial user: %s", exc)
return
else:
user = _validate(RulesengineUser, raw)
existing = _validate(RulesengineUser, raw_existing)
partial_data = raw if isinstance(raw, dict) else raw.model_dump()
try:
user = partial_user(existing, partial_data)
except Exception as exc:
self._logger.error("Failed to merge partial user: %s", exc)
return

if message.message_type == MessageType.DELETE.value:
await self._delete_entity(
user.id, user.keys, _PREFIX_USER, self._user_cache, self._user_key_cache,
)
await self._cache_user(user)
self._notify_pending_user(user.keys or {}, user)
return

await self._cache_user(user)
self._notify_pending_user(user.keys or {}, user)
user = _validate(RulesengineUser, raw)

async with self._get_user_lock(user.id):
if message.message_type == MessageType.DELETE.value:
await self._delete_entity(
user.id, user.keys, _PREFIX_USER, self._user_cache, self._user_key_cache,
)
return

await self._cache_user(user)
self._notify_pending_user(user.keys or {}, user)

async def _handle_flags_message(self, message: DataStreamResp) -> None:
raw_flags = message.data
Expand Down Expand Up @@ -936,6 +962,20 @@ def _cleanup_pending_user(self, cache_keys: List[str], future: asyncio.Future[An
if not futures:
del self._pending_user[ck]

def _get_company_lock(self, company_id: str) -> asyncio.Lock:
    """Return the lock associated with ``company_id``, creating it on first use.

    The same ``asyncio.Lock`` object is handed back for every later call
    with the same id, so callers that wrap their cache read-modify-write in
    ``async with`` on it are serialized per company.
    """
    try:
        return self._company_locks[company_id]
    except KeyError:
        # First request for this id: register a fresh lock and return it.
        created = asyncio.Lock()
        self._company_locks[company_id] = created
        return created

def _get_user_lock(self, user_id: str) -> asyncio.Lock:
    """Return the lock associated with ``user_id``, creating it on first use.

    The same ``asyncio.Lock`` object is handed back for every later call
    with the same id, so callers that wrap their cache read-modify-write in
    ``async with`` on it are serialized per user.
    """
    try:
        return self._user_locks[user_id]
    except KeyError:
        # First request for this id: register a fresh lock and return it.
        created = asyncio.Lock()
        self._user_locks[user_id] = created
        return created

def _clear_pending_requests(self) -> None:
for futures in self._pending_company.values():
for fut in futures:
Expand Down
38 changes: 38 additions & 0 deletions src/schematic/datastream/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,17 @@ def partial_company(existing: RulesengineCompany, partial: Dict[str, Any]) -> Ru
Only fields present in `partial` are applied. Maps (keys, credit_balances)
merge additively. Metrics are upserted by (event_subtype, period, month_reset).
All other fields replace the existing value. The original is not mutated.

Partials don't carry refreshed entitlements, so when their derived fields
change in another part of the company we sync them here to match server
behavior:
- credit_remaining ← credit_balances[credit_id]
- usage ← metric value matching (event_name, metric_period, month_reset)
Both are skipped when the partial also sends entitlements wholesale.
"""
updates: Dict[str, Any] = {}
updated_balances: Optional[Dict[str, float]] = None
metrics_updated = False

for key, value in partial.items():
if key == "keys":
Expand All @@ -24,13 +33,42 @@ def partial_company(existing: RulesengineCompany, partial: Dict[str, Any]) -> Ru
merged_cb = dict(existing.credit_balances) if existing.credit_balances else {}
merged_cb.update(value or {})
updates["credit_balances"] = merged_cb
updated_balances = value or {}
elif key == "metrics":
incoming = _parse_metrics(value)
existing_metrics = [m.model_dump() for m in (existing.metrics or [])]
updates["metrics"] = _upsert_metrics(existing_metrics, incoming)
metrics_updated = True
else:
updates[key] = value

if (updated_balances or metrics_updated) and "entitlements" not in updates:
existing_ents = existing.entitlements or []
if existing_ents:
metrics_lookup: Dict[Tuple[str, str, str], int] = {}
if metrics_updated:
for m in updates["metrics"]:
if isinstance(m, dict):
metrics_lookup[(
m.get("event_subtype", ""),
m.get("period", "") or "",
m.get("month_reset", "") or "",
)] = m.get("value", 0)

new_ents = []
for ent in existing_ents:
ent_dict = ent.model_dump()
if updated_balances and ent.credit_id and ent.credit_id in updated_balances:
ent_dict["credit_remaining"] = updated_balances[ent.credit_id]
if metrics_lookup and ent.event_name:
period = ent.metric_period or "all_time"
month_reset = ent.month_reset or "first_of_month"
matched = metrics_lookup.get((ent.event_name, period, month_reset))
if matched is not None:
ent_dict["usage"] = matched
new_ents.append(ent_dict)
updates["entitlements"] = new_ents

base = existing.model_dump()
base.update(updates)
return RulesengineCompany.model_validate(base)
Expand Down
97 changes: 97 additions & 0 deletions tests/datastream/test_datastream_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,103 @@ async def test_partial_user_merges_keys(
assert user.keys == {"email": "orig@test.com", "slack_id": "U123"}


class TestDataStreamClientConcurrentUpdates:
    """Concurrent read-modify-write callers must not drop each other's writes.

    Scenario under test: the host application calls
    ``update_company_metrics`` for a track event while the server pushes a
    partial carrying new ``credit_balances`` for the same company. Both
    changes are expected to be present in the cache afterwards.
    """

    @pytest.fixture
    def cache_with_delay(self) -> "_DelayedCache":
        # A small delay inside get() lets the scheduler interleave the two
        # concurrent callers.
        return _DelayedCache(get_delay=0.01)

    async def _seed_company(self, client: DataStreamClient, credit: float = 100.0) -> None:
        """Send a FULL company message so the cache holds a baseline entity."""
        baseline_metric = {
            "event_subtype": "credits_used",
            "period": "all_time",
            "month_reset": "first_of_month",
            "value": 0,
            "account_id": "acc_1",
            "company_id": "co_race",
            "environment_id": "env_1",
            "created_at": "2026-01-01T00:00:00Z",
        }
        company_payload = {
            "id": "co_race",
            "keys": {"slug": "race"},
            "account_id": "acc_1",
            "environment_id": "env_1",
            "billing_product_ids": [],
            "credit_balances": {"credit-1": credit},
            "metrics": [baseline_metric],
            "plan_ids": [],
            "plan_version_ids": [],
            "rules": [],
            "traits": [],
        }
        seed_message = DataStreamResp(
            data=company_payload,
            entity_type=EntityType.COMPANY.value,
            message_type=MessageType.FULL.value,
        )
        await client._handle_message(seed_message)

    async def test_concurrent_partial_and_metric_update_both_applied(
        self, logger: logging.Logger, cache_with_delay: "_DelayedCache",
    ) -> None:
        options = DataStreamClientOptions(
            api_key="test-key",
            logger=logger,
            replicator_mode=True,
            company_cache=cache_with_delay,
            company_lookup_cache=cache_with_delay,
            user_cache=cache_with_delay,
            user_lookup_cache=cache_with_delay,
            flag_cache=cache_with_delay,
        )
        client = DataStreamClient(options)

        await self._seed_company(client, credit=100.0)

        # Server side: a partial that replaces the credit balance.
        credit_partial = DataStreamResp(
            data={"credit_balances": {"credit-1": 25.0}},
            entity_id="co_race",
            entity_type=EntityType.COMPANY.value,
            message_type=MessageType.PARTIAL.value,
        )
        server_side = client._handle_message(credit_partial)
        # Host side: a track event that bumps the metric.
        host_side = client.update_company_metrics(
            {"slug": "race"}, "credits_used", 5,
        )

        # Run both together so they interleave at await points.
        await asyncio.gather(server_side, host_side)

        cached = await client._get_company_from_cache({"slug": "race"})
        assert cached is not None
        # If either writer clobbered the other, one of these would fail.
        assert cached.credit_balances == {"credit-1": 25.0}
        assert cached.metrics[0].value == 5

    def test_lock_object_is_reused_across_calls(self, logger: logging.Logger) -> None:
        options = DataStreamClientOptions(
            api_key="test-key",
            base_url="https://api.schematichq.com",
            logger=logger,
        )
        client = DataStreamClient(options)

        first = client._get_company_lock("co_1")
        again = client._get_company_lock("co_1")
        other = client._get_company_lock("co_2")

        # Same id -> same lock object; different id -> different lock.
        assert first is again
        assert first is not other


class _DelayedCache(MockCacheProvider):
    """Mock cache whose ``get()`` can pause before returning.

    The pause hands control back to the asyncio scheduler so that
    concurrent callers interleave, which is what a read-modify-write race
    test needs.
    """

    def __init__(self, get_delay: float = 0.0) -> None:
        super().__init__()
        # Seconds to sleep inside get(); a falsy value disables the pause.
        self._get_delay = get_delay

    async def get(self, key: str) -> Optional[Any]:
        # Take the snapshot first, then sleep: the caller receives the value
        # as it was before the pause, even if the store changed meanwhile.
        # NOTE(review): self._store is presumably set up by
        # MockCacheProvider.__init__ -- confirm against that class.
        snapshot = self._store.get(key)
        if not self._get_delay:
            return snapshot
        await asyncio.sleep(self._get_delay)
        return snapshot


class TestDataStreamClientDeepCopy:
"""Spec test #12: Deep copy prevents mutation of cached entities."""

Expand Down
Loading
Loading