diff --git a/.gitignore b/.gitignore index ac5342dfb..966779d6f 100644 --- a/.gitignore +++ b/.gitignore @@ -208,6 +208,11 @@ cython_debug/ # Claude .claude/*.local.json +# Temporary and backup files +*.tmp +*.bak +*.backup + # Dashboard generated files agentlightning/dashboard/**/*.css agentlightning/dashboard/**/*.js diff --git a/agentlightning/store/__init__.py b/agentlightning/store/__init__.py index 6cc9be8a9..b94037342 100644 --- a/agentlightning/store/__init__.py +++ b/agentlightning/store/__init__.py @@ -2,6 +2,7 @@ from .base import LightningStore, LightningStoreCapabilities from .client_server import LightningStoreClient, LightningStoreServer +from .database import SqlLightningStore from .memory import InMemoryLightningStore from .threading import LightningStoreThreaded @@ -12,4 +13,5 @@ "LightningStoreServer", "InMemoryLightningStore", "LightningStoreThreaded", + "SqlLightningStore", ] diff --git a/agentlightning/store/base.py b/agentlightning/store/base.py index 76df9deac..a07f855a2 100644 --- a/agentlightning/store/base.py +++ b/agentlightning/store/base.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Literal, Optional, Sequence, TypedDict +from typing import Any, Dict, List, Literal, Optional, Sequence, Union, TypedDict from opentelemetry.sdk.trace import ReadableSpan @@ -267,7 +267,7 @@ async def add_otel_span( async def query_rollouts( self, *, status: Optional[Sequence[RolloutStatus]] = None, rollout_ids: Optional[Sequence[str]] = None - ) -> List[Rollout]: + ) -> List[Union[Rollout, AttemptedRollout]]: """Retrieve rollouts filtered by status and/or explicit identifiers. Args: @@ -297,7 +297,7 @@ async def query_attempts(self, rollout_id: str) -> List[Attempt]: """ raise NotImplementedError() - async def get_rollout_by_id(self, rollout_id: str) -> Optional[Rollout]: + async def get_rollout_by_id(self, rollout_id: str) -> Optional[Union[Rollout, AttemptedRollout]]: """Fetch a rollout by identifier without mutating its state. Args: @@ -457,6 +457,8 @@ async def update_resources(self, resources_id: str, resources: NamedResources) - This API is typically used by algorithms that maintain mutable resources (e.g., model checkpoints) under a stable identifier. + If `resources_id` does not exist, implementations should add it as a new snapshot. + Args: resources_id: Identifier of the snapshot to replace. resources: Updated mapping of resource names to payloads. @@ -466,7 +468,6 @@ async def update_resources(self, resources_id: str, resources: NamedResources) - Raises: NotImplementedError: Subclasses must implement resource persistence. - ValueError: Implementations must raise when `resources_id` does not exist. """ raise NotImplementedError() diff --git a/agentlightning/store/database/__init__.py b/agentlightning/store/database/__init__.py new file mode 100644 index 000000000..60a61edfa --- /dev/null +++ b/agentlightning/store/database/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft. All rights reserved. + +from .sqlite import SqlLightningStore + +__all__ = [ + "SqlLightningStore", +] diff --git a/agentlightning/store/database/orm/__init__.py b/agentlightning/store/database/orm/__init__.py new file mode 100644 index 000000000..bb0c00b58 --- /dev/null +++ b/agentlightning/store/database/orm/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft. All rights reserved. 
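+# ORM models backing SqlLightningStore. Each *InDB class mirrors a type from
+# agentlightning.types and converts back via its as_*() helper (e.g. as_attempt()).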
+
+from .attempt import AttemptInDB, SpanSeqIdInDB
+from .base import (
+    AttemptStatusUpdateMessage,
+    SqlAlchemyBase,
+)
+from .resources import ResourcesUpdateInDB
+from .rollout import RolloutInDB
+from .span import SpanInDB
+
+__all__ = [
+    "SqlAlchemyBase",
+    "AttemptStatusUpdateMessage",
+    "RolloutInDB",
+    "AttemptInDB",
+    "ResourcesUpdateInDB",
+    "SpanSeqIdInDB",
+    "SpanInDB",
+]
diff --git a/agentlightning/store/database/orm/attempt.py b/agentlightning/store/database/orm/attempt.py
new file mode 100644
index 000000000..a11a2cdc2
--- /dev/null
+++ b/agentlightning/store/database/orm/attempt.py
@@ -0,0 +1,251 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+from __future__ import annotations
+
+import hashlib
+import logging
+import time
+import uuid
+from dataclasses import InitVar
+from typing import Any, Dict, List, Optional
+
+from sqlalchemy import JSON, Float, Integer, String, select
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
+from sqlalchemy.orm import Mapped, mapped_column
+
+from agentlightning.types import Attempt
+
+from .base import AttemptStatusUpdateMessage, SqlAlchemyBase
+
+logger = logging.getLogger(__name__)
+
+
+def _generate_attempt_id() -> str:
+    """A short ID suffices because attempt IDs are scoped to a single rollout."""
+    short_id = hashlib.sha1(uuid.uuid4().bytes).hexdigest()[:8]
+    return "at-" + short_id
+
+
+class AttemptInDB(SqlAlchemyBase):
+    __tablename__ = "attempts"
+
+    rollout_id: Mapped[str] = mapped_column(String, nullable=False)
+    attempt_id: Mapped[str] = mapped_column(String, primary_key=True, default_factory=_generate_attempt_id)
+    sequence_id: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
+    start_time: Mapped[float] = mapped_column(Float, default_factory=time.time, nullable=False)
+    end_time: Mapped[Optional[float]] = mapped_column(Float, nullable=True, default=None)
+    status: Mapped[str] = mapped_column(String, default="preparing", nullable=False)
+    worker_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, default=None)
+    last_heartbeat_time: Mapped[Optional[float]] = mapped_column(Float, nullable=True, default_factory=time.time)
+    attempt_metadata: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True, default=None)
+
+    # additional columns for timeout processing
+    max_duration: Mapped[Optional[float]] = mapped_column(
+        Float, nullable=True, default=None
+    )  # maximum duration allowed for this attempt in seconds
+    max_heartbeat_interval: Mapped[Optional[float]] = mapped_column(
+        Float, nullable=True, default=None
+    )  # maximum allowed heartbeat interval in seconds
+
+    version_id: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
+    __mapper_args__ = {
+        "version_id_col": version_id,
+    }
+
+    def is_unresponsive(self, current_time: float) -> bool:
+        """Check if the attempt is unresponsive based on the last heartbeat time and max_heartbeat_interval."""
+        if self.max_heartbeat_interval is None:
+            return False
+        if self.last_heartbeat_time is None:
+            return False
+        return (current_time - self.last_heartbeat_time) > self.max_heartbeat_interval
+
+    def is_timed_out(self, current_time: float) -> bool:
+        """Check if the attempt has timed out based on the start time and max_duration."""
+        if self.max_duration is None:
+            return False
+        return (current_time - self.start_time) > self.max_duration
+
+    def as_attempt(self) -> Attempt:
+        return Attempt(
+            **self.model_dump(
+                exclude={"max_duration", "max_heartbeat_interval", "version_id"},
+                mapper={"metadata": lambda obj:
obj.attempt_metadata},  # type: ignore
+            )
+        )
+
+    def _validate_status_message(self, msg: Dict[str, Any]) -> None:
+        """Validate a status update message from the caller.
+        Raises ValueError if the message is invalid.
+        """
+        if "event" not in msg:
+            raise ValueError("Status update message must contain 'event' field.")
+        if "timestamp" not in msg:
+            msg["timestamp"] = time.time()
+        if msg["event"] not in [
+            "user_update",  # user updates attempt status via dbstore.update_attempt()
+            "span_received",  # new span received
+            "single_step_timeout",  # single step timeout detected (from last span heartbeat)
+            "overall_timeout",  # overall timeout detected
+        ]:
+            raise ValueError(f"Unsupported event type: {msg['event']}")
+        if msg["event"] == "user_update" and "new_status" not in msg:
+            raise ValueError("User update event must contain 'new_status' field.")
+
+    def get_finished_statuses(self) -> List[str]:
+        """Return the list of statuses that are considered finished."""
+        return [
+            "succeeded",
+            "failed",
+            "timeout",
+        ]
+
+    def update_status(self, msg: Dict[str, Any]) -> Optional[AttemptStatusUpdateMessage]:
+        """Update the status of the attempt based on the given event.
+        Args:
+            msg: A dictionary containing the status update message. It must contain an "event" field,
+                and optionally "new_status" and "timestamp" fields. See `_validate_status_message()`
+                for the accepted message formats.
+        Returns:
+            An AttemptStatusUpdateMessage recording the transition (old_status, new_status, timestamp).
+            If no meaningful status update is performed, returns None.
+        Raises:
+            ValueError: If the event is not recognized or the status transition is invalid.
+            NotImplementedError: If the event handling is not implemented for the current status.
+            RuntimeError: If the new status is not set after processing the event.
+        """
+        self._validate_status_message(msg)
+        event = msg["event"]
+        current_time = msg.get("timestamp", time.time())
+        old_status = self.status
+        new_status = msg.get("new_status", None)
+
+        # Step 1: Determine the new status based on the event and current status
+        if event == "user_update":
+            if not new_status:
+                raise ValueError("new_status must be provided for user_update event.")
+        elif event == "span_received":
+            self.last_heartbeat_time = current_time
+            if old_status in ["preparing", "unresponsive", "running"]:
+                new_status = "running"
+            elif old_status in self.get_finished_statuses():
+                logger.warning(
+                    f"Span received after attempt is already in status {self.status}. No status update performed."
+                )
+                return  # no further status update needed
+            else:
+                raise NotImplementedError(f"Event {event} is not implemented for status {old_status}.")
+        elif event == "single_step_timeout":
+            if old_status in [
+                "preparing",
+                "running",
+            ]:
+                new_status = "unresponsive"
+            else:
+                logger.warning(
+                    f"Single step timeout detected but attempt is in status {self.status}. No status update performed."
+                )
+                return  # no further status update needed
+        elif event == "overall_timeout":
+            if old_status not in self.get_finished_statuses():
+                new_status = "timeout"
+            else:
+                logger.warning(
+                    f"Overall timeout detected but attempt is in status {self.status}. No status update performed."
+                )
+                return  # no further status update needed
+        else:
+            raise NotImplementedError(f"Event {event} is not implemented for status update.")
+
+        # Step 2: Update the status
+        if not new_status:
+            raise RuntimeError(
+                f"new_status must be set after processing event {event} on status {old_status}."
+            )
+        if new_status == old_status:
+            return  # no status change
+        if new_status in self.get_finished_statuses():
+            # when attempt is finished, set end_time
+            self.end_time = current_time
+        self.status = new_status
+
+        # Step 3: Return the status update info for further processing
+        return AttemptStatusUpdateMessage(
+            attempt_id=self.attempt_id,
+            rollout_id=self.rollout_id,
+            timestamp=current_time,
+            old_status=old_status,
+            new_status=new_status,
+        )
+
+    @classmethod
+    async def get_latest_attempt_for_rollout(
+        cls: type[AttemptInDB], session_factory: async_sessionmaker[AsyncSession], rollout_id: str
+    ) -> Optional[Attempt]:
+        async with session_factory() as session:
+            async with session.begin():
+                result = await session.scalars(
+                    select(cls).where(cls.rollout_id == rollout_id).order_by(cls.sequence_id.desc()).limit(1)
+                )
+                attempt_obj = result.one_or_none()
+                if attempt_obj is None:
+                    return None
+                return attempt_obj.as_attempt()
+
+    @classmethod
+    async def get_attempts_for_rollout(
+        cls: type[AttemptInDB], session_factory: async_sessionmaker[AsyncSession], rollout_id: str
+    ) -> List[Attempt]:
+        async with session_factory() as session:
+            async with session.begin():
+                result = await session.scalars(
+                    select(cls).where(cls.rollout_id == rollout_id).order_by(cls.sequence_id.asc())
+                )
+                return [attempt.as_attempt() for attempt in result.all()]
+
+
+class SpanSeqIdInDB(SqlAlchemyBase):
+    __tablename__ = "span_sequence"
+
+    rollout_id: Mapped[str] = mapped_column(nullable=False, primary_key=True)
+
+    # FIXME InMemoryLightningStore lets all attempts under the same rollout share the same span sequence for sorting
+    # attempt_id: Mapped[str] = mapped_column(nullable=False)
+    attempt_id: InitVar[str]  # not a mapped column, kept only for type hinting
+
+    current_sequence: Mapped[int] = mapped_column(default=1, nullable=False)
+
+    # Versioning for optimistic concurrency control
+    version_id: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
+    __mapper_args__ = {
+        "version_id_col": version_id,
+        # "primary_key": [rollout_id, attempt_id],
+        # "primary_key": [rollout_id],
+    }
+
+    @classmethod
+    async def get_next_sequence_id(
+        cls: type[SpanSeqIdInDB],
+        session_factory: async_sessionmaker[AsyncSession],
+        rollout_id: str,
+        attempt_id: str,
+        external_seq_id: Optional[int] = None,
+    ) -> int:
+        """Get the next sequence ID; optimistic locking via version_id (together with the caller's retry policy) handles race conditions.
+        If external_seq_id is provided and is greater than current_sequence, set current_sequence to external_seq_id.
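+        Example (illustrative; ``sf`` is an ``async_sessionmaker`` and the row's current_sequence is 4):
+            await SpanSeqIdInDB.get_next_sequence_id(sf, "ro-1", "at-1")      # -> 4, stores 5
+            await SpanSeqIdInDB.get_next_sequence_id(sf, "ro-1", "at-1", 10)  # -> 10, stores 11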
+ """ + async with session_factory() as session: + async with session.begin(): + seq_obj = await session.get(cls, rollout_id) + # seq_obj = await session.get(cls, [rollout_id, attempt_id]) + if seq_obj is None: + raise ValueError(f"Rollout {rollout_id} not found") + else: + current_seq = ( + external_seq_id + if external_seq_id is not None and external_seq_id > seq_obj.current_sequence + else seq_obj.current_sequence + ) + seq_obj.current_sequence = current_seq + 1 + await session.flush() + return current_seq diff --git a/agentlightning/store/database/orm/base.py b/agentlightning/store/database/orm/base.py new file mode 100644 index 000000000..b99253f7c --- /dev/null +++ b/agentlightning/store/database/orm/base.py @@ -0,0 +1,186 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import json +import time +from typing import Any, Callable, Dict, List, Optional + +from pydantic import BaseModel, Field, TypeAdapter, computed_field + +# from dataclasses import asdict +from sqlalchemy import JSON, TypeDecorator +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass + + +class SqlAlchemyBase(AsyncAttrs, MappedAsDataclass, DeclarativeBase): + pass + + def model_dump( + self, + exclude: set[str] | None = None, + mapper: Dict[str, Callable[["SqlAlchemyBase"], Any]] | None = None, + ) -> Dict[str, Any]: + """Dump the SQLAlchemy model to a dictionary. + Args: + exclude: set[str] + The set of field names to exclude. + mapper: Dict[str, Callable[[SqlAlchemyBase], Any]] + A mapping from field names to functions that take the model instance and return the value to be used for that field. + If the key is "*", the function should return a dictionary of additional fields to be added to the output. + Returns: + Dict[str, Any]: The dumped model as a dictionary. + """ + exclude = exclude or set() + mapper = mapper or {} + dic = {k: getattr(self, k) for k in self.__table__.columns.keys() if k not in exclude} + for k, func in mapper.items(): + if k == "*": + dic.update(func(self)) + else: + dic[k] = func(self) + return dic + + +class PydanticInDB(TypeDecorator[BaseModel]): + """Custom SQLAlchemy type to store pydantic.BaseModel as JSON in the database. + Attributes: + target_type: type[BaseModel], the type of the pydantic model to be stored. + """ + + impl = JSON + target_type: type[BaseModel] | None = None + + def process_bind_param(self, value: BaseModel | None, dialect: Any) -> Optional[str]: + if value is None: + return None + if self.target_type is not None: + return TypeAdapter(self.target_type).validate_python(value).model_dump_json() # type: ignore + return json.dumps(value) + + def process_result_value(self, value: Optional[str], dialect: Any) -> Optional[BaseModel]: + if value is None: + return None + if self.target_type is not None: + return TypeAdapter(self.target_type).validate_json(value) # type: ignore + dic = json.loads(value) + return dic # type: ignore + + +class PydanticListInDB(TypeDecorator[list[BaseModel]]): + """Custom SQLAlchemy type to store List[pydantic.BaseModel] as JSON in the database. + Attributes: + value_type: type[BaseModel], the type of the pydantic model to be stored in the list. 
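+        Example (sketch): bind ``value_type`` in a subclass, as span.py does:
+            class EventListInDB(PydanticListInDB):
+                value_type = Event  # persists List[Event] as a JSON array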
+ """ + + impl = JSON + value_type: type[BaseModel] | None = None + + def process_bind_param(self, value: List[BaseModel] | None, dialect: Any) -> Optional[str]: + if value is None: + return None + if self.value_type is not None: + lst = [TypeAdapter(self.value_type).validate_python(v).model_dump() for v in value] + return json.dumps(lst) + raise ValueError("target_type must be set for PydanticListInDB") + + def process_result_value(self, value: Optional[str], dialect: Any) -> Optional[List[BaseModel]]: + if value is None: + return None + if self.value_type is not None: + dic = json.loads(value) + return [TypeAdapter(self.value_type).validate_python(v) for v in dic] # type: ignore + raise ValueError("target_type must be set for PydanticListInDB") + + +class NamedDictBase(TypeDecorator[Dict[str, Any]]): + """Custom SQLAlchemy type to store Dict[str, pydantic.BaseModel] as JSON in the database. + Attributes: + target_alias: type[Dict[str, BaseModel]], the alias type of the dict. + value_type: type[BaseModel], the type of the values in the dict. + + For example, given NamedResources = Dict[str, ResourceUnion], + we can define NamedDictBase with target_alias=NamedResources and target_type=ResourceUnion. + """ + + impl = JSON + target_alias: type | None = None + value_type: type[BaseModel] | Any = None + + def process_bind_param(self, value: Dict[str, Any] | None, dialect: Any) -> Optional[str]: + if value is None: + return None + + # ignore target_alias for when dumping because Dict is not a pydantic model + if self.value_type is not None: + dic = { + k: TypeAdapter(self.value_type).validate_python(v).model_dump() if isinstance(v, BaseModel) else v + for k, v in value.items() + } + return json.dumps(dic) + dic = {k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in value.items()} + return json.dumps(dic) + + def process_result_value(self, value: Optional[str], dialect: Any) -> Optional[Dict[str, Any]]: + if value is None: + return None + if self.target_alias is not None: + return TypeAdapter(self.target_alias).validate_json(value) # type: ignore + if self.value_type is not None: + dic = json.loads(value) + return {k: TypeAdapter(self.value_type).validate_python(v) for k, v in dic.items()} # type: ignore + return json.loads(value) + + +class DatabaseRuntimeError(Exception): + """Raised when a runtime error occurs during database operations. + Particularly used when the execution of a query fails. 
+ """ + + pass + + +class RaceConditionError(Exception): + """Raised when a race condition is detected during database operations.""" + + pass + + +class NoRolloutToDequeueError(Exception): + """Raised when there is no rollout available to dequeue.""" + + pass + + +class AttemptStatusUpdateMessage(BaseModel): + attempt_id: str + rollout_id: str + timestamp: float = Field(default_factory=time.time) + old_status: Optional[str] = None + new_status: str + + @computed_field + @property + def event(self) -> str: + return "attempt_status_update" + + @computed_field + @property + def is_failed(self) -> bool: + return self.new_status in ["failed", "timeout", "unresponsive"] + + @computed_field + @property + def is_succeeded(self) -> bool: + return self.new_status == "succeeded" + + @computed_field + @property + def is_finished(self) -> bool: + return self.is_failed or self.is_succeeded + + @computed_field + @property + def is_running(self) -> bool: + return self.new_status in ["running", "preparing"] diff --git a/agentlightning/store/database/orm/resources.py b/agentlightning/store/database/orm/resources.py new file mode 100644 index 000000000..1a045444b --- /dev/null +++ b/agentlightning/store/database/orm/resources.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import hashlib +import time +import uuid +from typing import Optional + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from sqlalchemy.orm import Mapped, mapped_column + +from agentlightning.types import NamedResources, ResourcesUpdate + +from .base import NamedDictBase, SqlAlchemyBase + + +def _generate_resources_id() -> str: + short_id = hashlib.sha1(uuid.uuid4().bytes).hexdigest()[:12] + return "rs-" + short_id + + +class NamedResourcesInDB(NamedDictBase): + """Custom SQLAlchemy type to store NamedResources as JSON in the database.""" + + target_alias = NamedResources + + +class ResourcesUpdateInDB(SqlAlchemyBase): + __tablename__ = "resources" + resources: Mapped[NamedResources] = mapped_column( + NamedResourcesInDB, nullable=False + ) # JSON serialized, convert to NamedResources when needed + resources_id: Mapped[str] = mapped_column(primary_key=True, default_factory=_generate_resources_id) + create_time: Mapped[float] = mapped_column(nullable=False, default_factory=time.time) + update_time: Mapped[float] = mapped_column(nullable=False, default_factory=time.time, onupdate=time.time) + version: Mapped[int] = mapped_column(nullable=False, default=1) + + __mapper_args__ = { + "version_id_col": version, + } + + @classmethod + async def get_resources_by_id( + cls, session_factory: async_sessionmaker[AsyncSession], resources_id: str + ) -> Optional[ResourcesUpdate]: + async with session_factory() as session: + async with session.begin(): + obj = await session.get(cls, resources_id) + if obj is None: + return None + return obj.as_resources_update() + + def as_resources_update(self) -> ResourcesUpdate: + return ResourcesUpdate(**self.model_dump()) diff --git a/agentlightning/store/database/orm/rollout.py b/agentlightning/store/database/orm/rollout.py new file mode 100644 index 000000000..45735d4c4 --- /dev/null +++ b/agentlightning/store/database/orm/rollout.py @@ -0,0 +1,201 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +from __future__ import annotations + +import hashlib +import logging +import time +import uuid +from typing import Any, Dict, List, Optional, cast + +from sqlalchemy import JSON, Float, Integer, String, and_, select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from sqlalchemy.orm import Mapped, mapped_column + +from agentlightning.types import AttemptedRollout, Rollout, RolloutConfig, RolloutStatus + +from ...base import is_finished, is_queuing +from .attempt import AttemptInDB +from .base import AttemptStatusUpdateMessage, PydanticInDB, SqlAlchemyBase + +logger = logging.getLogger(__name__) + + +def _generate_rollout_id() -> str: + short_id = hashlib.sha1(uuid.uuid4().bytes).hexdigest()[:12] + return "ro-" + short_id + + +class RolloutConfigInDB(PydanticInDB): + """Custom SQLAlchemy type to store RolloutConfig as JSON in the database.""" + + target_type = RolloutConfig + + +class RolloutInDB(SqlAlchemyBase): + __tablename__ = "rollouts" + + input: Mapped[Any] = mapped_column(JSON, nullable=False) + rollout_id: Mapped[str] = mapped_column(String, primary_key=True, default_factory=_generate_rollout_id) + start_time: Mapped[float] = mapped_column(Float, default_factory=time.time, nullable=False) + end_time: Mapped[Optional[float]] = mapped_column(Float, nullable=True, default=None) + mode: Mapped[Optional[str]] = mapped_column(String, nullable=True, default=None) + resources_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, default=None) + status: Mapped[RolloutStatus] = mapped_column(String, default="queuing", nullable=False) + config: Mapped[RolloutConfig] = mapped_column( + RolloutConfigInDB, nullable=False, default_factory=RolloutConfig + ) # JSON serialized, convert to RolloutConfig when needed + rollout_metadata: Mapped[Optional[Dict[str, Any]]] = mapped_column( + JSON, nullable=True, default=None + ) # JSON serialized, convert to Dict when needed + + # Attempt-related helper methods can be added here if needed + num_attempts: Mapped[int] = mapped_column( + Integer, default=0, nullable=False + ) # number of attempts made for this rollout + enqueue_time: Mapped[Optional[float]] = mapped_column( + Float, nullable=True, default_factory=time.time + ) # time when the rollout was enqueued (for FIFO scheduling) + latest_attempt_id: Mapped[Optional[str]] = mapped_column( + String, nullable=True, default=None + ) # the attempt_id of the latest attempt + + # use optimistic concurrency control + version_id: Mapped[int] = mapped_column(Integer, nullable=False, default=1) + __mapper_args__ = { + "version_id_col": version_id, + } + + def __post_init__(self): + if self.status not in ["queuing", "running", "succeeded", "failed", "requeuing"]: + raise ValueError(f"Invalid rollout status: {self.status}") + + def as_rollout(self) -> Rollout: + return Rollout( + **self.model_dump( + exclude={"rollout_metadata", "num_attempts", "enqueue_time", "latest_attempt_id", "version_id"}, + mapper={ + "metadata": lambda obj: obj.rollout_metadata, # type: ignore + "config": lambda obj: obj.config if obj.config is not None else RolloutConfig(), # type: ignore + }, + ) + ) + + def _validate_status_message(self, msg: Dict[str, str]) -> None: + """Validate the status update message. + Raises: + ValueError: If the message is invalid. 
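+            Example of a valid message (illustrative):
+                {"event": "user_update", "new_status": "failed"}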
+ """ + if "event" not in msg: + raise ValueError("Status update message must contain 'event' field.") + event = msg["event"] + if event not in [ + "attempt_status_update", # from attempt status update + "user_update", # from user-initiated update + ]: + raise ValueError(f"Invalid event type in status update message: {event}") + if event == "user_update": + if "new_status" not in msg: + raise ValueError("Status update message for event 'user_update' must contain 'new_status' field.") + if event == "attempt_status_update": + # leverage AttemptStatusUpdateMessage for validation + pass + + async def update_status(self, msg: Dict[str, Any] | AttemptStatusUpdateMessage) -> None: + """Update the rollout status based on the provided message. + Args: + msg (Dict[str, str]): The status update message. Refer to `_validate_status_message` for the expected format. + current_time (Optional[float]): The current time to set end_time or enqueue_time if needed. + """ + if isinstance(msg, dict): + self._validate_status_message(msg) + event = msg["event"] + current_time = msg.get("timestamp", time.time()) + else: + event = msg.event + current_time = msg.timestamp + + old_status = self.status + new_status = self.status # initialize new_status with old_status + + # Step 1: Determine the new status based on the event + if event == "user_update": + assert isinstance(msg, dict) + new_status = msg["new_status"] + elif event == "attempt_status_update": + msg = AttemptStatusUpdateMessage(**msg) if isinstance(msg, dict) else msg + if msg.attempt_id == self.latest_attempt_id: + new_status = msg.new_status # directly take the latest attempt status + if msg.is_succeeded: + new_status = "succeeded" + elif msg.is_failed: + # no other attempts running, decide whether to requeue or fail + config = self.config + if config.max_attempts > self.num_attempts and msg.new_status in config.retry_condition: + new_status = "requeuing" + else: + new_status = "failed" + # elif msg.is_running and old_status in ["failed", "requeuing"]: + # new_status = "running" + else: + # ignore attempts from old attempts + new_status = old_status + + # Step 2: Update the status if it has changed and handle follow-up actions + if new_status is None: + raise RuntimeError( + f"New status of `{old_status}` and `{self.latest_attempt_id}` could not be determined from the message {msg}." + ) + if new_status == old_status: + return + self.status = cast(RolloutStatus, new_status) + + if is_finished(self): # type: ignore + self.end_time = current_time + if is_queuing(self): # type: ignore + self.enqueue_time = current_time + # When requeuing, we do not reset latest_attempt_id or num_attempts, + # as they should persist across requeues. 
+ + @classmethod + async def get_rollout_by_id( + cls: type[RolloutInDB], session_factory: async_sessionmaker[AsyncSession], rollout_id: str + ) -> Optional[Rollout | AttemptedRollout]: + """Query a specific rollout from the database.""" + async with session_factory() as session: + async with session.begin(): + rollout_obj = await session.get(cls, rollout_id) + if rollout_obj is None: + return None + if rollout_obj.latest_attempt_id is not None: + attempt_obj = await session.get(AttemptInDB, rollout_obj.latest_attempt_id) + if attempt_obj is not None: + return AttemptedRollout( + **rollout_obj.as_rollout().model_dump(), attempt=attempt_obj.as_attempt() + ) + return rollout_obj.as_rollout() + + @classmethod + async def query_rollouts( + cls: type[RolloutInDB], + session_factory: async_sessionmaker[AsyncSession], + *, + statuses: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + ) -> List[RolloutInDB]: + """ + Query rollouts from the database with optional filters. + """ + async with session_factory() as session: + async with session.begin(): + conditions: list[Any] = [] + if statuses is not None: + conditions.append(cls.status.in_(statuses)) + if ids is not None: + conditions.append(cls.rollout_id.in_(ids)) + query = select(cls) + if conditions: + query = query.where(and_(*conditions)) + result = await session.scalars(query) + rollout_objs = result.all() + return list(rollout_objs) diff --git a/agentlightning/store/database/orm/span.py b/agentlightning/store/database/orm/span.py new file mode 100644 index 000000000..dc13897cc --- /dev/null +++ b/agentlightning/store/database/orm/span.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + +from sqlalchemy import JSON, Float, Integer, String +from sqlalchemy.orm import Mapped, mapped_column + +logger = logging.getLogger(__name__) + +from agentlightning.types.tracer import ( + Attributes, + AttributeValue, + Event, + Link, + OtelResource, + Span, + SpanContext, + TraceStatus, +) + +from .base import NamedDictBase, PydanticInDB, PydanticListInDB, SqlAlchemyBase + + +class TraceStatusInDB(PydanticInDB): + target_type = TraceStatus + + +class AttributesInDB(NamedDictBase): + target_alias = None # type: ignore + value_type = AttributeValue + + +class EventListInDB(PydanticListInDB): + value_type = Event + + +class LinkListInDB(PydanticListInDB): + value_type = Link + + +class SpanContextInDB(PydanticInDB): + target_type = SpanContext + + +class OtelResourceInDB(PydanticInDB): + target_type = OtelResource + + +class SpanInDB(SqlAlchemyBase): + __tablename__ = "spans" + + rollout_id: Mapped[str] = mapped_column(String, nullable=False) # The rollout which this span belongs to. + attempt_id: Mapped[str] = mapped_column(String, nullable=False) # The attempt which this span belongs to. + sequence_id: Mapped[int] = mapped_column( + Integer, nullable=False + ) # The ID to make spans ordered within a single attempt. + + # Current ID (in hex, formatted via trace_api.format_*) + trace_id: Mapped[str] = mapped_column( + String, nullable=False + ) # one rollout can have traces coming from multiple places + + # FIXME: span_id may be not unique across different attempts/rollouts, use (rollout_id, attempt_id, sequence_id) as the primary key instead + span_id: Mapped[str] = mapped_column( + String, nullable=False + ) # The span ID of the span. This ID comes from the OpenTelemetry span ID generator. 
+ parent_id: Mapped[Optional[str]] = mapped_column(String, nullable=True) # The parent span ID of the span. + + # Core ReadableSpan fields + name: Mapped[str] = mapped_column(String, nullable=False) + status: Mapped[TraceStatus] = mapped_column(TraceStatusInDB, nullable=False) + attributes: Mapped[Attributes] = mapped_column(AttributesInDB, nullable=False) + events: Mapped[List[Event]] = mapped_column(EventListInDB, nullable=False) + links: Mapped[List[Link]] = mapped_column(LinkListInDB, nullable=False) + + # Timestamps + start_time: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + end_time: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + + # Other parsable fields + context: Mapped[Optional[SpanContext]] = mapped_column(SpanContextInDB, nullable=True) + parent: Mapped[Optional[SpanContext]] = mapped_column(SpanContextInDB, nullable=True) + resource: Mapped[OtelResource] = mapped_column(OtelResourceInDB, nullable=False) + + # extra fields can be added here as needed + extra: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True, default=None) + + __mapper_args__ = { + "primary_key": [rollout_id, attempt_id, sequence_id], + } + + def as_span(self) -> Span: + return Span( + **self.model_dump( + exclude={"extra"}, + mapper={"*": lambda obj: obj.extra or {}}, # type: ignore + ) + ) diff --git a/agentlightning/store/database/retry_helper.py b/agentlightning/store/database/retry_helper.py new file mode 100644 index 000000000..600b6e6f5 --- /dev/null +++ b/agentlightning/store/database/retry_helper.py @@ -0,0 +1,316 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""This file contains a configurable async retry decorator based on exception type.""" + +from __future__ import annotations + +import functools +import importlib +import logging +import random +from dataclasses import asdict, dataclass +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Optional, Type, TypeVar + +from tenacity import AsyncRetrying, RetryCallState, retry_if_exception + +# ---------------------------------------------------------------------- +# Logging setup +# ---------------------------------------------------------------------- +logger = logging.getLogger("async_retry") +logging.basicConfig(level=logging.INFO) + +# ---------------------------------------------------------------------- +# Type alias for async callable +# ---------------------------------------------------------------------- +F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) + + +# ---------------------------------------------------------------------- +# Dataclass definition for retry configuration +# ---------------------------------------------------------------------- +@dataclass +class RetryStrategy: + """Configuration schema for retry behavior of a specific exception type. + The wait time before $n$-th retry is calculated as ($n$ starts from 1): + wait_time = wait_seconds * (backoff ** (n - 1)) * (1 + jitter * U(-1, 1)) + where U(-1, 1) is a uniform random variable between -1 and 1. + Attributes: + max_attempts: Maximum number of attempts before giving up. Default is 1 (no retry). None means infinite retries. + max_retry_delay: Optional maximum delay between retries in seconds. Default is None (no limit). + wait_seconds: Base wait time in seconds before the first retry. Default is 0.0. + max_wait_seconds: Maximum wait time in seconds between retries. Default is None (no limit). + backoff: Exponential backoff multiplier. Default is 1.0 (no backoff). 
+ jitter: Fractional (relative) jitter to apply to wait time. Default is 0.0 (no jitter). + log: Whether to log each retry attempt. Default is False. + """ + + max_attempts: Optional[int] = 1 + max_retry_delay: Optional[float] = None + wait_seconds: float = 0.0 + max_wait_seconds: Optional[float] = None + backoff: float = 1.0 + jitter: float = 0.0 + log: bool = False + + def asdict(self) -> Dict[str, Any]: + return asdict(self) + + def __post_init__(self): + if self.max_attempts is not None and self.max_attempts < 1: + raise ValueError("max_attempts must be at least 1 or None for infinite retries") + if self.wait_seconds < 0.0: + raise ValueError("wait_seconds must be non-negative") + if self.backoff < 1.0: + raise ValueError("backoff must be at least 1.0") + if not (0.0 <= self.jitter <= 1.0): + raise ValueError("jitter must be between 0.0 and 1.0") + + def _get_wait_time(self, attempt_number: int) -> float: + """Calculate the wait time before the given attempt number.""" + base_wait = self.wait_seconds * (self.backoff ** (attempt_number - 1)) + if self.jitter > 0: + delta = base_wait * self.jitter + wait_time = random.uniform(base_wait - delta, base_wait + delta) + else: + wait_time = base_wait + wait_time = max(wait_time, 0.0) + if self.max_wait_seconds is not None: + wait_time = min(wait_time, self.max_wait_seconds) + return wait_time + + def wait_func(self, retry_state: RetryCallState) -> float: + """Tenacity wait function based on the given strategy.""" + return self._get_wait_time(retry_state.attempt_number) + + def stop_func(self, retry_state: RetryCallState) -> bool: + """Tenacity stop function based on the given strategy.""" + if self.max_attempts is not None: + if retry_state.attempt_number >= self.max_attempts: + return True + if self.max_retry_delay is not None: + time_since_start = retry_state.seconds_since_start + if time_since_start is None: + logger.warning("Cannot determine time since start for retry stop condition.") + return False + if time_since_start >= self.max_retry_delay: + return True + return False + + async def before_sleep(self, retry_state: RetryCallState): + """Tenacity before_sleep callback to log retry attempts.""" + if self.log: + exc = retry_state.outcome.exception() if retry_state.outcome else None + next_wait = self.wait_func(retry_state) + logger.warning( + f"[Retry] {exc.__class__.__name__}: attempt={retry_state.attempt_number}, " + f"next_wait={next_wait:.2f}s, message={exc}" + ) + + +# ---------------------------------------------------------------------- +# Exception Registry — shared, reusable, and extensible +# ---------------------------------------------------------------------- +class ExceptionRegistry: + """ + Global registry for mapping string keys to Exception classes. + Supports dynamic registration and fallback to importlib. 
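+    Example (sketch; the second call shows the explicit-type form with a
+    hypothetical exception class):
+        ExceptionRegistry.register("sqlalchemy.exc.OperationalError")  # resolved via importlib
+        ExceptionRegistry.register("mypkg.errors.MyError", MyError)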
+ """ + + _registry: Dict[str, Type[BaseException]] = {} + + @classmethod + def register(cls, name: str, exc_type: Type[BaseException] | None = None) -> None: + """Register an exception type under a given name.""" + if name in cls._registry: + logger.warning(f"Overwriting existing exception registration for name '{name}'.") + if exc_type is None: + # Try to dynamically import the exception class + try: + module_name, class_name = name.rsplit(".", 1) + module = importlib.import_module(module_name) + exc_type = getattr(module, class_name) + if exc_type is None: + raise TypeError(f"{name} is not an Exception type.") + except (ImportError, AttributeError, ValueError, TypeError) as e: + raise ValueError(f"Cannot resolve exception type for name '{name}': {e}") + cls._registry[name] = exc_type + + @classmethod + def all_registered(cls) -> Dict[str, Type[BaseException]]: + """Return the current registry mapping.""" + return dict(cls._registry) + + @classmethod + def clear(cls): + """Clear all registered exception mappings.""" + cls._registry.clear() + + +# ---------------------------------------------------------------------- +# Async Retry Decorator +# ---------------------------------------------------------------------- +class AsyncTypeBasedRetry: + """ + A configurable async retry decorator based on exception type. + + - Takes configuration as a Dict[str, RetryStrategy]. + - Provides `from_json()` for quick loading. + - Uses a global ExceptionRegistry to resolve exception names. + """ + + def __init__(self, strategies: Dict[str, RetryStrategy], default_strategy: RetryStrategy | None = None): + self.exception_map = self._build_exception_map(strategies) + self.default_strategy = default_strategy or RetryStrategy() + + # ------------------------------------------------------------------ + # Build exception map + # ------------------------------------------------------------------ + def _build_exception_map(self, strategies: Dict[str, RetryStrategy]) -> Dict[Type[BaseException], RetryStrategy]: + mapping: Dict[Type[BaseException], RetryStrategy] = {} + all_registered = ExceptionRegistry.all_registered() + for name, strat in strategies.items(): + if name in all_registered: + exc_type = all_registered[name] + else: + raise ValueError(f"Exception type '{name}' is not registered in ExceptionRegistry.") + mapping[exc_type] = strat + return mapping + + # ------------------------------------------------------------------ + # Retry core logic + # ------------------------------------------------------------------ + def get_exception(self, retry_state: RetryCallState) -> Optional[BaseException]: + """Get the exception from the given retry state, if any.""" + return retry_state.outcome.exception() if retry_state.outcome else None + + def get_strategy(self, retry_state: RetryCallState) -> Optional[RetryStrategy]: + """Get the RetryStrategy for the exception in the given retry state. + IF no matching exception type is found, return the default strategy. + IF no exception is found, return None. 
+ """ + exc = self.get_exception(retry_state) + if exc is None: + return None + for exc_type, strat in self.exception_map.items(): + if isinstance(exc, exc_type): + return strat + return self.default_strategy + + def should_retry(self, exc: BaseException) -> bool: + return any(isinstance(exc, t) for t in self.exception_map.keys()) + + def wait_func(self, retry_state: RetryCallState) -> float: + strat = self.get_strategy(retry_state) + if strat is None: + return 0.0 + return strat.wait_func(retry_state) + + def stop_func(self, retry_state: RetryCallState) -> bool: + strat = self.get_strategy(retry_state) + if strat is None: + return False + return strat.stop_func(retry_state) + + async def before_sleep(self, retry_state: RetryCallState): + strat = self.get_strategy(retry_state) + if strat is None: + return + await strat.before_sleep(retry_state) + + # ------------------------------------------------------------------ + # Decorator entry point + # ------------------------------------------------------------------ + def __call__(self, func: F) -> F: + @functools.wraps(func) + async def wrapper(*args, **kwargs): # type: ignore + async for attempt in AsyncRetrying( + retry=retry_if_exception(lambda e: self.should_retry(e)), + wait=self.wait_func, + stop=self.stop_func, + before_sleep=self.before_sleep, + reraise=True, + ): + with attempt: + return await func(*args, **kwargs) + + return wrapper # type: ignore + + +# ---------------------------------------------------------------------- +# A configurable async retrier for any code block +# ---------------------------------------------------------------------- + + +class AsyncRetryBlock: + """ + Async retry helper for a single exception type and strategy. + + Usage: + async with AsyncRetryBlock(strategy): + await some_async_function() + """ + + def __init__(self, strategy: RetryStrategy, **retry_kwargs): # type: ignore + self.strategy = strategy + self._retryer = AsyncRetrying( + wait=self._wait_func, + stop=self._stop_func, + before_sleep=self._before_sleep, + **retry_kwargs, # type: ignore + ) + + async def run(self, coro: Callable[..., Awaitable[Any]]) -> Any: + """Run the given coroutine with retries according to the strategy. + For example: + async def my_coro(): + ... + retry_block = AsyncRetryBlock(strategy) + result = await retry_block.run(my_coro) + """ + async for attempt in self._retryer: + with attempt: + return await coro() + + # ------------------------------------------------------------------ + # Core: async iterator interface + # ------------------------------------------------------------------ + def __aiter__(self) -> AsyncIterator[Any]: + """Return an async iterator that yields retry attempts. 
+ Usage: + async for attempt in retry_block: + with attempt: + await some_async_function() + """ + return self._retryer.__aiter__() + + # ------------------------------------------------------------------ + # Context manager entry + # ------------------------------------------------------------------ + async def __aenter__(self): + self._aiter = self._retryer.__aiter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): # type: ignore + # Consume the retry iterator + try: + # If exception occurred, let the retryer handle it + async for attempt in self._aiter: + with attempt: + if exc_val: + raise exc_val + except Exception: + # Allow exception to propagate if retries exhausted + pass + return False + + # ------------------------------------------------------------------ + # Strategy function + # ------------------------------------------------------------------ + def _wait_func(self, retry_state: RetryCallState) -> float: + return self.strategy.wait_func(retry_state) + + def _stop_func(self, retry_state: RetryCallState) -> bool: + return self.strategy.stop_func(retry_state) + + async def _before_sleep(self, retry_state: RetryCallState): + await self.strategy.before_sleep(retry_state) diff --git a/agentlightning/store/database/sqlite.py b/agentlightning/store/database/sqlite.py new file mode 100644 index 000000000..65ee02b66 --- /dev/null +++ b/agentlightning/store/database/sqlite.py @@ -0,0 +1,685 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import asyncio +import logging +import os +import time +from datetime import datetime, timedelta +from typing import Any, Dict, List, Literal, Optional, Sequence, Union + +from apscheduler.schedulers.background import BackgroundScheduler +from apscheduler.triggers.interval import IntervalTrigger +from opentelemetry.sdk.trace import ReadableSpan +from pydantic import BaseModel +from sqlalchemy import and_, or_, select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm.exc import StaleDataError +from tenacity import RetryError + +from agentlightning.types import ( + Attempt, + AttemptedRollout, + AttemptStatus, + NamedResources, + ResourcesUpdate, + Rollout, + RolloutConfig, + RolloutStatus, + Span, + TaskInput, +) + +from ..base import UNSET, LightningStore, Unset, is_finished +from .orm import ( + AttemptInDB, + ResourcesUpdateInDB, + RolloutInDB, + SpanInDB, + SpanSeqIdInDB, + SqlAlchemyBase, +) +from .retry_helper import AsyncRetryBlock, AsyncTypeBasedRetry, ExceptionRegistry, RetryStrategy + +logger = logging.getLogger(__name__) + +# TODO add periodic cleanup of old rollouts/attempts/spans + +ExceptionRegistry.register("sqlalchemy.orm.exc.StaleDataError") +ExceptionRegistry.register("sqlalchemy.exc.OperationalError") + +db_retry = AsyncTypeBasedRetry( + { + "sqlalchemy.exc.OperationalError": RetryStrategy( + max_attempts=5, wait_seconds=1, backoff=1.5, jitter=0.3, log=True + ), + "sqlalchemy.orm.exc.StaleDataError": RetryStrategy( + max_attempts=100, wait_seconds=1e-3, backoff=1.0, jitter=0.1, log=True + ), + } +) + + +class _WaitForRolloutsCompleted(Exception): + """Internal exception to signal that not all rollouts have completed yet.""" + + pass + + +class BackgroundTaskConfig(BaseModel): + name: str # unique name for the task + method: str # method name to call, currently only supports methods of SqlLightningStore + interval: Dict[Literal["seconds", "minutes", "hours"], float] # interval for the task + is_async: bool 
= True  # whether the task method is async, defaults to True
+
+
+class SqlLightningStore(LightningStore):
+    """
+    A LightningStore implementation that uses a database backend to store and manage rollouts and attempts.
+    The database backend is expected to support asynchronous operations.
+    The store uses SQLAlchemy ORM models to interact with the database.
+    Args:
+        database_url (str):
+            The database URL for connecting to the database.
+            If None, will read from the 'DATABASE_URL' environment variable.
+        retry_for_waiting (RetryStrategy):
+            Retry strategy for polling when waiting for rollouts to complete.
+            If None, a default strategy will be used.
+        wait_for_nonexistent_rollout (bool):
+            If True, when waiting for rollouts, will wait for all specified rollouts to complete, including non-existing ones.
+            If False, will treat non-existing rollouts as completed and ignore them. (Default: False)
+        background_tasks_cfg (list[Dict[str, Any]]):
+            The configuration for in-process periodic tasks, following the definition of `BackgroundTaskConfig`.
+            If not provided (None as default), the dbstore will incorporate a default set of periodic tasks as follows:
+            [
+                BackgroundTaskConfig(name="check_attempt_timeout", method="check_attempt_timeout", interval={"seconds": 10.0}),
+            ]
+            To disable all periodic tasks, provide an empty list `[]`.
+    Note:
+        Explicitly use the async `start()` and `stop()` methods to manage the database connection lifecycle.
+    """
+
+    def __init__(
+        self,
+        database_url: Optional[str] = None,
+        *,
+        retry_for_waiting: Optional[dict[str, Any] | RetryStrategy] = None,
+        wait_for_nonexistent_rollout: bool = False,
+        background_tasks_cfg: list[Dict[str, Any]] | None = None,
+    ) -> None:
+        super().__init__()
+        if database_url is None:
+            database_url = os.getenv("DATABASE_URL", None)
+        if database_url is None:
+            raise ValueError(
+                "A database URL must be provided either via the 'database_url' parameter or the 'DATABASE_URL' environment variable."
+ ) + + self._engine = create_async_engine(database_url, echo=False) + self._async_session = async_sessionmaker(self._engine, expire_on_commit=False) + + self._latest_resources_id = None + + # special handling for retry strategy + retry_for_waiting = retry_for_waiting or RetryStrategy( + max_attempts=10, # set a limit for retries if timeout is specified, otherwise will change to None later + max_retry_delay=None, # set later + wait_seconds=10.0, # poll every 10 seconds + max_wait_seconds=60.0, # at most wait 60 seconds between retries + backoff=1.0, + jitter=0.0, + log=True, + ) + self.retry_for_waiting = ( + retry_for_waiting if isinstance(retry_for_waiting, RetryStrategy) else RetryStrategy(**retry_for_waiting) + ) + self.wait_for_nonexistent_rollout = wait_for_nonexistent_rollout + + # setup in-process periodic tasks + if background_tasks_cfg is None: + self.background_tasks_cfg = [ + BackgroundTaskConfig( + name="check_attempt_timeout", method="check_attempt_timeout", interval={"seconds": 10.0} + ), + ] + else: + self.background_tasks_cfg = [BackgroundTaskConfig(**cfg) for cfg in background_tasks_cfg] + self._background_scheduler = BackgroundScheduler() + + async def start(self): + async with self._engine.begin() as conn: + await conn.run_sync(SqlAlchemyBase.metadata.create_all) + for task_cfg in self.background_tasks_cfg: + self.add_background_task(task_cfg, to_scheduler_only=True) + self._background_scheduler.start() # type: ignore + + async def stop(self): + await self._engine.dispose() + self._background_scheduler.shutdown() # type: ignore + + def add_background_task( + self, task_cfg: Dict[str, Any] | BackgroundTaskConfig, to_scheduler_only: bool = False + ) -> None: + """Add a new periodic background task to the scheduler. + Args: + task_cfg (Dict[str, Any] | BackgroundTaskConfig): The configuration for the background task. + to_scheduler_only (bool): If True, only add the task to the scheduler without updating the configuration list. + Raises: + ValueError: If the task method is not defined in SqlLightningStore. + """ + config = task_cfg if isinstance(task_cfg, BackgroundTaskConfig) else BackgroundTaskConfig(**task_cfg) + if not to_scheduler_only: + # check existing tasks + for existing in self.background_tasks_cfg: + if existing.name == config.name: + logger.warning( + f"Background task {config.name} is already scheduled, will update its configuration." 
+ ) + self.background_tasks_cfg.append(config) + delta_t = timedelta(**config.interval) + if not hasattr(self, config.method): + raise ValueError(f"Periodic task method {config.method} is not defined in SqlLightningStore.") + if config.is_async: + func = lambda: asyncio.run(getattr(self, config.method)()) + else: + func = lambda: getattr(self, config.method)() + + self._background_scheduler.add_job( # type: ignore + func=func, + trigger=IntervalTrigger(**config.interval), # type: ignore + name=f"SqlLightningStore.{config.name}", + replace_existing=True, + next_run_time=datetime.now() + delta_t, # schedule the first run after the interval + ) + + # ------------------------------------------------------ + # Public methods defined in LightningStore + # ------------------------------------------------------ + + @db_retry + async def start_rollout( + self, + input: TaskInput, + mode: Literal["train", "val", "test"] | None = None, + resources_id: str | None = None, + config: RolloutConfig | None = None, + metadata: Dict[str, Any] | None = None, + ) -> AttemptedRollout: + async with self._async_session() as session: + async with session.begin(): + rollout_obj = RolloutInDB( + input=input, + mode=mode, + resources_id=resources_id or self._latest_resources_id, + status="queuing", + config=config or RolloutConfig(), + rollout_metadata=metadata, + ) + session.add(rollout_obj) + attempted_rollout = await self._start_attempt_for_rollout(session, rollout_obj) + await session.flush() # ensure the object is written to the DB + return attempted_rollout + + @db_retry + async def enqueue_rollout( + self, + input: TaskInput, + mode: Literal["train", "val", "test"] | None = None, + resources_id: str | None = None, + config: RolloutConfig | None = None, + metadata: Dict[str, Any] | None = None, + ) -> Rollout: + async with self._async_session() as session: + async with session.begin(): + rollout_obj = RolloutInDB( + input=input, + mode=mode, + resources_id=resources_id or self._latest_resources_id, + status="queuing", + config=config or RolloutConfig(), + rollout_metadata=metadata, + ) + session.add(rollout_obj) + await session.flush() # ensure the object is written to the DB + return rollout_obj.as_rollout() + + @db_retry + async def dequeue_rollout(self) -> Optional[AttemptedRollout]: + return await self._fifo_dequeue_rollout() + + @db_retry + async def start_attempt(self, rollout_id: str) -> AttemptedRollout: + async with self._async_session() as session: + async with session.begin(): + rollout_obj = await session.get(RolloutInDB, rollout_id) + if rollout_obj is None: + raise ValueError(f"Rollout {rollout_id} not found") + attempted_rollout = await self._start_attempt_for_rollout(session, rollout_obj) + await session.flush() # ensure the object is written to the DB + return attempted_rollout + + @db_retry + async def add_span(self, span: Span) -> Span: + seq_id = await SpanSeqIdInDB.get_next_sequence_id(self._async_session, span.rollout_id, span.attempt_id) + return await self._add_span(span.model_dump(), seq_id=seq_id) + + @db_retry + async def add_otel_span( + self, + rollout_id: str, + attempt_id: str, + readable_span: ReadableSpan, + sequence_id: int | None = None, + ) -> Span: + sequence_id = await SpanSeqIdInDB.get_next_sequence_id(self._async_session, rollout_id, attempt_id, sequence_id) + span = Span.from_opentelemetry( + src=readable_span, + rollout_id=rollout_id, + attempt_id=attempt_id, + sequence_id=sequence_id, + ) + return await self._add_span(span.model_dump(), seq_id=sequence_id) + + 
@db_retry + async def query_rollouts( + self, *, status: Optional[Sequence[RolloutStatus]] = None, rollout_ids: Optional[Sequence[str]] = None + ) -> List[Rollout]: + rollouts = await RolloutInDB.query_rollouts(self._async_session, statuses=status, ids=rollout_ids) # type: ignore + attempt_ids = [r.latest_attempt_id for r in rollouts if r.latest_attempt_id is not None] + async with self._async_session() as session: + async with session.begin(): + scalars = await session.scalars(select(AttemptInDB).where(AttemptInDB.attempt_id.in_(attempt_ids))) + attempts = scalars.all() + attempt_map = {a.attempt_id: a.as_attempt() for a in attempts} + return [ + ( + AttemptedRollout(**r.as_rollout().model_dump(), attempt=attempt_map[r.latest_attempt_id]) + if r.latest_attempt_id in attempt_map + else r.as_rollout() + ) + for r in rollouts + ] # type: ignore + + @db_retry + async def query_attempts(self, rollout_id: str) -> List[Attempt]: + return await AttemptInDB.get_attempts_for_rollout(self._async_session, rollout_id) # type: ignore + + @db_retry + async def get_rollout_by_id(self, rollout_id: str) -> Optional[Union[Rollout, AttemptedRollout]]: + return await RolloutInDB.get_rollout_by_id(self._async_session, rollout_id) + + @db_retry + async def get_latest_attempt(self, rollout_id: str) -> Optional[Attempt]: + return await AttemptInDB.get_latest_attempt_for_rollout(self._async_session, rollout_id) + + @db_retry + async def get_resources_by_id(self, resources_id: str) -> Optional[ResourcesUpdate]: + return await ResourcesUpdateInDB.get_resources_by_id(self._async_session, resources_id) + + @db_retry + async def get_latest_resources(self) -> Optional[ResourcesUpdate]: + if self._latest_resources_id is None: + return None + return await ResourcesUpdateInDB.get_resources_by_id(self._async_session, self._latest_resources_id) + + @db_retry + async def get_next_span_sequence_id(self, rollout_id: str, attempt_id: str) -> int: + return await SpanSeqIdInDB.get_next_sequence_id(self._async_session, rollout_id, attempt_id) + + async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[Rollout]: + # implementation the timeout via tenacity retry mechanism, by a `with` context + strategy = RetryStrategy(**self.retry_for_waiting.asdict()) + if timeout is not None: + strategy.max_retry_delay = timeout + if strategy.max_attempts is not None: + strategy.wait_seconds = min(strategy.wait_seconds, timeout / (strategy.max_attempts + 1)) + else: + strategy.max_attempts = None # infinite retries + + non_completed_ids, non_existing_ids = set(rollout_ids), set(rollout_ids) + completed_rollouts: Dict[str, Rollout] = {} + if len(non_completed_ids) < len(rollout_ids): + logger.warning("Duplicate rollout_ids found in wait_for_rollouts input. 
Duplicates will be ignored.") + + try: + async for attempt in AsyncRetryBlock( + strategy, + reraise=True, + ): + with attempt: + async with self._async_session() as session: + async with session.begin(): + result = await session.scalars( + select(RolloutInDB).where(RolloutInDB.rollout_id.in_(non_completed_ids)) + ) + rollouts = [r.as_rollout() for r in result.all()] + for r in rollouts: + if r.rollout_id in non_existing_ids: + non_existing_ids.discard(r.rollout_id) # found existing rollout + if is_finished(r): + completed_rollouts[r.rollout_id] = r + non_completed_ids.discard(r.rollout_id) + # check termination conditions + if self.wait_for_nonexistent_rollout: + if len(non_completed_ids) == 0: + return [completed_rollouts[rid] for rid in rollout_ids if rid in completed_rollouts] + raise _WaitForRolloutsCompleted( + f"WaitForRolloutsCompleted: requested={len(rollout_ids)}, completed={len(completed_rollouts)}, non_existing={len(non_existing_ids)}" + ) + else: + if len(non_completed_ids) == len(non_existing_ids): + logger.warning(f"All remaining rollouts are non-existing: {non_existing_ids}.") + return [completed_rollouts[rid] for rid in rollout_ids if rid in completed_rollouts] + raise _WaitForRolloutsCompleted( + f"WaitForRolloutsCompleted: requested={len(rollout_ids)}, completed={len(completed_rollouts)}, non_existing={len(non_existing_ids)}" + ) + + except (RetryError, _WaitForRolloutsCompleted): + return [completed_rollouts[rid] for rid in rollout_ids if rid in completed_rollouts] + except Exception as e: + logger.error(f"Error while waiting for rollouts: {e}") + raise e + + # Ensure a return value in case no rollouts are completed + return [completed_rollouts[rid] for rid in rollout_ids if rid in completed_rollouts] + + @db_retry + async def query_spans(self, rollout_id: str, attempt_id: str | Literal["latest"] | None = None) -> List[Span]: + async with self._async_session() as session: + async with session.begin(): + conditions: List[Any] = [SpanInDB.rollout_id == rollout_id] + if attempt_id is not None: + if attempt_id == "latest": + rollout_obj = await session.get(RolloutInDB, rollout_id) + if rollout_obj is None: + logger.warning(f"Rollout {rollout_id} does not exist. Cannot query latest attempt spans.") + return [] + attempt_id = rollout_obj.latest_attempt_id + conditions.append(SpanInDB.attempt_id == attempt_id) + query = select(SpanInDB).where(and_(*conditions)).order_by(SpanInDB.sequence_id.asc()) + result = await session.scalars(query) + span_objs = result.all() + return [obj.as_span() for obj in span_objs] + + @db_retry + async def add_resources(self, resources: NamedResources) -> ResourcesUpdate: + async with self._async_session() as session: + async with session.begin(): + current_time = time.time() + resource_obj = ResourcesUpdateInDB( + resources=resources, + create_time=current_time, + update_time=current_time, + ) + session.add(resource_obj) + await session.flush() # ensure the object is written to the DB + self._latest_resources_id = resource_obj.resources_id + return resource_obj.as_resources_update() + + @db_retry + async def update_resources(self, resources_id: str, resources: NamedResources) -> ResourcesUpdate: + async with self._async_session() as session: + async with session.begin(): + obj = await session.get(ResourcesUpdateInDB, resources_id) + if obj is None: + # raise ValueError(f"Failed to update resources {resources_id}. 
It may not exist.")
+ # NOTE InMemoryLightningStore creates the resources when they do not exist, and the updated
+ # base contract allows adding a missing `resources_id` as a new snapshot, so we follow the
+ # same behavior here for compatibility
+ current_time = time.time()
+ obj = ResourcesUpdateInDB(
+ resources_id=resources_id,
+ resources=resources,
+ create_time=current_time,
+ update_time=current_time,
+ )
+ session.add(obj)
+ else:
+ obj.resources = resources
+ await session.flush()
+ self._latest_resources_id = resources_id
+ return obj.as_resources_update()
+
+ @db_retry
+ async def query_resources(self) -> List[ResourcesUpdate]:
+ async with self._async_session() as session:
+ async with session.begin():
+ result = await session.scalars(
+ select(ResourcesUpdateInDB).order_by(ResourcesUpdateInDB.create_time.asc())
+ )
+ resource_objs = result.all()
+ return [obj.as_resources_update() for obj in resource_objs]
+
+ @db_retry
+ async def update_rollout(
+ self,
+ rollout_id: str | None,
+ input: TaskInput | Unset = UNSET,
+ mode: Optional[Literal["train", "val", "test"]] | Unset = UNSET,
+ resources_id: Optional[str] | Unset = UNSET,
+ status: RolloutStatus | Unset = UNSET,
+ config: RolloutConfig | Unset = UNSET,
+ metadata: Optional[Dict[str, Any]] | Unset = UNSET,
+ ) -> Rollout:
+ if rollout_id is None:
+ raise ValueError("rollout_id must be provided for updating a rollout.")
+
+ async with self._async_session() as session:
+ async with session.begin():
+ rollout_obj = await session.get(RolloutInDB, rollout_id)
+ if rollout_obj is None:
+ raise ValueError(f"Rollout {rollout_id} not found")
+ # update fields
+ if not isinstance(input, Unset):
+ rollout_obj.input = input
+ if not isinstance(mode, Unset):
+ rollout_obj.mode = mode
+ if not isinstance(resources_id, Unset):
+ rollout_obj.resources_id = resources_id
+ if not isinstance(status, Unset):
+ await rollout_obj.update_status(dict(event="user_update", new_status=status))
+ if not isinstance(config, Unset):
+ rollout_obj.config = config
+ if not isinstance(metadata, Unset):
+ rollout_obj.rollout_metadata = metadata
+ await session.flush() # ensure the object is written to the DB
+ return rollout_obj.as_rollout()
+
+ @db_retry
+ async def update_attempt(
+ self,
+ rollout_id: str,
+ attempt_id: str | Literal["latest"],
+ status: AttemptStatus | Unset = UNSET,
+ worker_id: str | Unset = UNSET,
+ last_heartbeat_time: float | Unset = UNSET,
+ metadata: Optional[Dict[str, Any]] | Unset = UNSET,
+ ) -> Attempt:
+ async with self._async_session() as session:
+ async with session.begin():
+ rollout_obj = await session.get(RolloutInDB, rollout_id)
+ if rollout_obj is None:
+ raise ValueError(f"Rollout {rollout_id} not found")
+ if attempt_id == "latest":
+ if rollout_obj.latest_attempt_id is None:
+ raise ValueError(f"Rollout {rollout_id} has no attempts. Cannot update latest attempt.")
+ attempt_id = rollout_obj.latest_attempt_id
+ if attempt_id != rollout_obj.latest_attempt_id:
+ logger.warning(
+ f"Updating attempt {attempt_id} which is not the latest attempt for rollout {rollout_id}. Latest is {rollout_obj.latest_attempt_id}."
+ )
+ attempt_obj = await session.get(AttemptInDB, attempt_id)
+ if attempt_obj is None:
+ raise ValueError(f"Attempt {attempt_id} not found")
+ if attempt_obj.rollout_id != rollout_id:
+ raise ValueError(f"Attempt {attempt_id} does not belong to rollout {rollout_id}.")
+ # update fields
+ if not isinstance(status, Unset):
+ msg = attempt_obj.update_status(dict(event="user_update", new_status=status))
+ if msg is not None:
+ await rollout_obj.update_status(msg)
+ if not isinstance(worker_id, Unset):
+ attempt_obj.worker_id = worker_id
+ if not isinstance(last_heartbeat_time, Unset):
+ attempt_obj.last_heartbeat_time = last_heartbeat_time
+ if not isinstance(metadata, Unset):
+ attempt_obj.attempt_metadata = metadata
+ await session.flush() # ensure the object is written to the DB
+ return attempt_obj.as_attempt()
+
+ # ------------------------------------------------------
+ # periodic background tasks can be added here
+ # ------------------------------------------------------
+
+ async def check_attempt_timeout(self):
+ """Periodically check for attempts that have timed out and update their status accordingly."""
+ # select the timed-out attempts first, then update each one in its own transaction
+ current_time = time.time()
+
+ timed_out_results = await self._attempt_timeout_check(current_time)
+
+ # TODO run the tasks with a wrapper with an asyncio semaphore to limit concurrency and handle exceptions
+ tasks = [self._process_timed_out_attempt(attempt, current_time) for attempt in timed_out_results]
+ await asyncio.gather(*tasks)
+
+ async def _process_timed_out_attempt(self, attempt_ref: AttemptInDB, current_time: float) -> None:
+ async with self._async_session() as session:
+ async with session.begin():
+ # Step 1: Update attempt status
+ attempt_obj = await session.get(
+ AttemptInDB, attempt_ref.attempt_id
+ ) # refresh the object in the new session
+ if attempt_obj is None:
+ raise ValueError(f"Attempt {attempt_ref.attempt_id} not found during timeout processing")
+ if attempt_obj.version_id != attempt_ref.version_id:
+ # version mismatch: another writer updated this attempt; raise so we do not race it
+ raise StaleDataError(f"Attempt {attempt_ref.attempt_id} version mismatch during timeout processing")
+ msg = {}
+ if attempt_obj.is_timed_out(current_time):
+ msg = dict(event="overall_timeout", timestamp=current_time)
+ elif attempt_obj.is_unresponsive(current_time):
+ msg = dict(event="single_step_timeout", timestamp=current_time)
+ else:
+ raise ValueError(f"Attempt {attempt_ref.attempt_id} is not timed out during timeout processing")
+ msg2rollout = attempt_obj.update_status(msg)
+ if msg2rollout is None:
+ return # no further update needed
+
+ # Step 2: Update rollouts
+ rollout_obj = await session.get(RolloutInDB, attempt_obj.rollout_id)
+ if rollout_obj is None:
+ raise ValueError(f"Rollout {attempt_obj.rollout_id} not found during timeout processing")
+ await rollout_obj.update_status(msg2rollout)
+
+ # ------------------------------------------------------
+ # internal helper methods can be added here
+ # ------------------------------------------------------
+
+ async def _add_span(self, span: Dict[str, Any], seq_id: Optional[int] = None) -> Span:
+ """Add a new span to the database."""
+ if seq_id is not None:
+ span["sequence_id"] = seq_id
+ extra_dic: Dict[str, Any] = {}
+ for k in list(span.keys()):
+ if k not in SpanInDB.__table__.columns.keys():
+ extra_dic[k] = span.pop(k)
+ span["extra"] = extra_dic if extra_dic else None
+
+ async with self._async_session() as session:
+ async with session.begin():
+ # create 
SpanInDB object
+ span_obj = SpanInDB(**span)
+ session.add(span_obj)
+ # update attempt's last_heartbeat_time and status
+ attempt_obj = await session.get(AttemptInDB, span["attempt_id"])
+ if attempt_obj is None:
+ raise ValueError(f"Attempt {span['attempt_id']} not found")
+ # ensure the attempt and rollout are in running status
+ msg = attempt_obj.update_status(dict(event="span_received"))
+ if msg is not None:
+ rollout_obj = await session.get(RolloutInDB, attempt_obj.rollout_id)
+ if rollout_obj is None:
+ raise ValueError(f"Rollout {attempt_obj.rollout_id} not found")
+ await rollout_obj.update_status(msg)
+ await session.flush() # ensure the object is written to the DB
+ return span_obj.as_span()
+
+ async def _fifo_dequeue_rollout(self) -> Optional[AttemptedRollout]:
+ """Dequeue the next rollout in FIFO order (the one with the earliest enqueue_time).
+ Claims the rollout by starting a new attempt, which moves its status to 'preparing'.
+ Returns the resulting AttemptedRollout if a rollout was claimed, else None.
+ """
+ async with self._async_session() as session:
+ async with session.begin():
+ # select the queued rollout with the earliest enqueue_time, then claim it below
+ result = await session.scalars(
+ select(RolloutInDB)
+ .where(RolloutInDB.status.in_(["queuing", "requeuing"]), RolloutInDB.enqueue_time.isnot(None))
+ .order_by(RolloutInDB.enqueue_time.asc())
+ .limit(1)
+ )
+ rollout_obj = result.one_or_none()
+ if rollout_obj is None:
+ return None # no rollout available
+ # update the status of the rollout to 'preparing' via Compare-and-Swap to avoid races
+ attempted_rollout = await self._start_attempt_for_rollout(session, rollout_obj)
+ await session.flush() # ensure the object is written to the DB
+ return attempted_rollout
+
+ async def _start_attempt_for_rollout(self, session: AsyncSession, rollout_obj: RolloutInDB) -> AttemptedRollout:
+ """Create a new attempt for the given rollout and update the rollout's fields."""
+ # create a new attempt for this rollout
+ rollout_config = rollout_obj.config
+ attempt_obj = AttemptInDB(
+ rollout_id=rollout_obj.rollout_id,
+ sequence_id=rollout_obj.num_attempts + 1,
+ status="preparing",
+ max_duration=rollout_config.timeout_seconds,
+ max_heartbeat_interval=rollout_config.unresponsive_seconds,
+ )
+ session.add(attempt_obj)
+ # pre-update the rollout_obj fields in the object for CAS
+ rollout_obj.status = attempt_obj.status # type: ignore
+ rollout_obj.enqueue_time = None
+ rollout_obj.num_attempts += 1
+ rollout_obj.latest_attempt_id = attempt_obj.attempt_id
+
+ # FIXME currently InMemoryLightningStore lets all attempts under the same rollout share the same span sequence for sorting
+ # create a sequence id tracker for this rollout, only if one does not exist yet
+ existing = await session.get(SpanSeqIdInDB, rollout_obj.rollout_id)
+ if existing is None:
+ seq_obj = SpanSeqIdInDB(
+ rollout_id=rollout_obj.rollout_id,
+ attempt_id=attempt_obj.attempt_id,
+ )
+ session.add(seq_obj)
+
+ return AttemptedRollout(**rollout_obj.as_rollout().model_dump(), attempt=attempt_obj.as_attempt())
+
+ async def _attempt_timeout_check(self, now: float) -> Sequence[AttemptInDB]:
+ """Scan the table for attempts that have exceeded their maximum duration or heartbeat interval, and return them for further processing. 
+ Returns: + list[AttemptInDB]: + A list of AttemptInDB objects that timed out. + """ + async with self._async_session() as session: + async with session.begin(): + scalars = await session.scalars( + select(AttemptInDB).where( + and_( + AttemptInDB.status.in_(["preparing", "running"]), + or_( + and_( + AttemptInDB.max_duration.isnot(None), + (now - AttemptInDB.start_time) > AttemptInDB.max_duration, + ), + and_( + AttemptInDB.max_heartbeat_interval.isnot(None), + (now - AttemptInDB.last_heartbeat_time) > AttemptInDB.max_heartbeat_interval, + ), + ), + ) + ) + ) + return scalars.all() diff --git a/agentlightning/store/memory.py b/agentlightning/store/memory.py index c3dfd3068..ead88c264 100644 --- a/agentlightning/store/memory.py +++ b/agentlightning/store/memory.py @@ -435,7 +435,7 @@ async def start_attempt(self, rollout_id: str) -> AttemptedRollout: @_healthcheck_wrapper async def query_rollouts( self, *, status: Optional[Sequence[RolloutStatus]] = None, rollout_ids: Optional[Sequence[str]] = None - ) -> List[Rollout]: + ) -> List[Union[Rollout, AttemptedRollout]]: """Retrieves rollouts filtered by their status and rollout ids. If no status is provided, returns all rollouts. diff --git a/agentlightning/store/sqlite.py b/agentlightning/store/sqlite.py deleted file mode 100644 index 2fd777b2f..000000000 --- a/agentlightning/store/sqlite.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -# TODO: Implement this diff --git a/agentlightning/types/core.py b/agentlightning/types/core.py index 57cc316d7..c26d28cf6 100644 --- a/agentlightning/types/core.py +++ b/agentlightning/types/core.py @@ -117,6 +117,7 @@ class RolloutLegacy(BaseModel): ] """The status of an attempt.""" + RolloutMode = Literal["train", "val", "test"] """Possible rollout modes.""" diff --git a/pyproject.toml b/pyproject.toml index 1e94c2dd4..31807759b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,9 @@ dependencies = [ "pydantic>=2.11", "openai", "rich", + "sqlalchemy[asyncio]", + "aiosqlite", + "tenacity", "portpicker", "gunicorn", "uvicorn_worker", diff --git a/tests/store/conftest.py b/tests/store/conftest.py index 697419268..457204f09 100644 --- a/tests/store/conftest.py +++ b/tests/store/conftest.py @@ -1,14 +1,18 @@ # Copyright (c) Microsoft. All rights reserved. +import os import time +import typing +import uuid from unittest.mock import Mock import pytest +import pytest_asyncio from opentelemetry.sdk.trace import ReadableSpan from pytest import FixtureRequest +from agentlightning.store import InMemoryLightningStore, SqlLightningStore from agentlightning.store.base import LightningStore -from agentlightning.store.memory import InMemoryLightningStore __all__ = [ "inmemory_store", @@ -22,15 +26,35 @@ def inmemory_store() -> InMemoryLightningStore: return InMemoryLightningStore() -@pytest.fixture -def sql_store(): +@pytest_asyncio.fixture +async def sql_store() -> typing.AsyncGenerator[SqlLightningStore, None]: """Placeholder fixture for SQL store implementation. 
Returns None until SQL store is ready.""" - return None + """Helper generator to create a SqlLightningStore using a SQLite file for testing.""" + tmp_path = ".pytest_cache" + # Ensure the directory exists and create a random file in it + os.makedirs(tmp_path, exist_ok=True) + db_path = os.path.join(tmp_path, f"test_db_{uuid.uuid4().hex}.sqlite3") + database_url = f"sqlite+aiosqlite:///{db_path}" + store = SqlLightningStore(database_url=database_url) + store.retry_for_waiting.wait_seconds = 0.2 # Set polling interval to 0.2s for test + + # Config db_store with a short time interval for healthcheck + store.add_background_task( + {"name": "test_healthcheck", "method": "check_attempt_timeout", "interval": {"seconds": 0.1}} + ) + + await store.start() + try: + yield store + finally: + await store.stop() + if os.path.exists(db_path): + os.remove(db_path) # Uncomment this when sql store is ready -# @pytest.fixture(params=["inmemory_store", "sql_store"]) -@pytest.fixture(params=["inmemory_store"]) +@pytest.fixture(params=["inmemory_store", "sql_store"]) +# @pytest.fixture(params=["inmemory_store"]) def store_fixture(request: FixtureRequest) -> LightningStore: """Parameterized fixture that provides different store implementations for testing. Currently supports InMemoryLightningStore, with SQL store support planned. diff --git a/tests/store/test_implementation.py b/tests/store/test_implementation.py index 9af8fb3ce..7e374ad49 100644 --- a/tests/store/test_implementation.py +++ b/tests/store/test_implementation.py @@ -897,10 +897,17 @@ async def test_span_triggers_status_transition(store_fixture: LightningStore, mo # Get the attempt attempts = await store_fixture.query_attempts(rollout.rollout_id) attempt_id = attempts[0].attempt_id + assert attempts[0].status == "preparing" # Add first span await store_fixture.add_otel_span(rollout.rollout_id, attempt_id, mock_readable_span) + # Attempt status should be changed + attempt_v2 = await store_fixture.get_latest_attempt(rollout.rollout_id) + assert attempt_v2 is not None + assert attempt_v2.attempt_id == attempt_id + assert attempt_v2.status == "running" + # Status should transition to running rollouts = await store_fixture.query_rollouts(status=["running"]) assert len(rollouts) == 1 @@ -1838,7 +1845,7 @@ async def test_healthcheck_timeout_behavior(store_fixture: LightningStore, mock_ assert len(running_rollouts) == 1 # Wait for timeout to occur - await asyncio.sleep(0.15) # Wait longer than timeout_seconds + await asyncio.sleep(0.3) # Wait longer than timeout_seconds # Trigger healthcheck by calling any decorated method # Verify the attempt was marked as timeout and rollout was requeued @@ -1876,7 +1883,7 @@ async def test_healthcheck_unresponsive_behavior(store_fixture: LightningStore, assert running_attempts[0].last_heartbeat_time is not None # Wait for unresponsive timeout - await asyncio.sleep(0.15) # Wait longer than unresponsive_seconds + await asyncio.sleep(0.3) # Wait longer than unresponsive_seconds # Verify attempt was marked as unresponsive attempts_after = await store_fixture.query_attempts(rollout.rollout_id) diff --git a/uv.lock b/uv.lock index 7d38c349f..62d5be07d 100644 --- a/uv.lock +++ b/uv.lock @@ -128,6 +128,7 @@ dependencies = [ { name = "agentops", version = "0.4.18", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform == 'linux' and extra == 'group-14-agentlightning-core-legacy') or (sys_platform == 'linux' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (sys_platform == 'linux' and extra == 
'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable')" }, { name = "agentops", version = "0.4.21", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform == 'linux' and extra == 'group-14-agentlightning-core-stable') or (sys_platform == 'linux' and extra == 'group-14-agentlightning-tinker') or (sys_platform == 'linux' and extra == 'group-14-agentlightning-torch-gpu-stable') or (sys_platform == 'linux' and extra != 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-cpu' and extra != 'group-14-agentlightning-torch-legacy') or (sys_platform == 'linux' and extra != 'group-14-agentlightning-core-legacy' and extra != 'group-14-agentlightning-torch-gpu-legacy' and extra != 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') 
or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, { name = "aiohttp", marker = "sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, + { name = "aiosqlite", marker = "sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 
'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, { name = "fastapi", marker = "sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, { name = "flask", marker = "sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 
'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, { name = "graphviz", marker = "sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, @@ -146,6 +147,8 @@ dependencies = [ { name = "pydantic", marker = 
"sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, { name = "rich", marker = "sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 
'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, { name = "setproctitle", marker = "sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, + { name = "sqlalchemy", extra = ["asyncio"], marker = "sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or 
(extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, + { name = "tenacity", marker = "sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, { name = "uvicorn", marker = "sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 
'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, { name = "uvicorn-worker", marker = "sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, ] @@ -346,6 +349,7 @@ trl = [ requires-dist = [ { name = "agentops", specifier = ">=0.4.13" }, { name = "aiohttp" }, + { name = "aiosqlite" }, { name = "fastapi" }, { name = "flask" }, { name = "graphviz" }, @@ -362,6 +366,8 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.11" }, { name = "rich" }, { name = "setproctitle" }, + { name = "sqlalchemy", extras = ["asyncio"] }, + { name = "tenacity" }, { name = "uvicorn" }, { name = "uvicorn-worker" }, { name = "verl", marker = "extra == 'verl'", specifier = ">=0.5.0" }, @@ -818,6 +824,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "aiosqlite" +version = "0.21.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3", size = 13454, upload-time = "2025-02-03T07:30:16.235Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" }, +] + [[package]] name = "airportsdata" version = "20250909" @@ -10299,6 +10317,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/5e/6a29fa884d9fb7ddadf6b69490a9d45fded3b38541713010dad16b77d015/sqlalchemy-2.0.44-py3-none-any.whl", hash = "sha256:19de7ca1246fbef9f9d1bff8f1ab25641569df226364a0e40457dc5457c54b05", size = 1928718, upload-time = "2025-10-10T15:29:45.32Z" }, ] +[package.optional-dependencies] +asyncio = [ + { name = "greenlet", marker = "sys_platform == 'linux' or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-core-stable') or (extra == 
'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-tinker') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-core-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-core-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-tinker' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-cu128') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-legacy') or (extra == 'group-14-agentlightning-torch-cpu' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-gpu-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-gpu-legacy' and extra == 'group-14-agentlightning-trl') or (extra == 'group-14-agentlightning-torch-gpu-stable' and extra == 'group-14-agentlightning-torch-legacy') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-torch-stable') or (extra == 'group-14-agentlightning-torch-legacy' and extra == 'group-14-agentlightning-trl')" }, +] + [[package]] name = "sqlparse" version = "0.5.3"