From a3c5fa052d1128b6fb9da5ed98da0b35d78cd747 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Tue, 29 Jul 2025 17:07:10 +0000 Subject: [PATCH 01/29] base implementation sql client --- pyproject.toml | 6 +- src/crawlee/storage_clients/__init__.py | 2 + src/crawlee/storage_clients/_sql/__init__.py | 6 + .../storage_clients/_sql/_dataset_client.py | 289 +++++++++ .../storage_clients/_sql/_db_models.py | 103 ++++ .../_sql/_key_value_store_client.py | 323 ++++++++++ .../_sql/_request_queue_client.py | 554 ++++++++++++++++++ .../storage_clients/_sql/_storage_client.py | 197 +++++++ src/crawlee/storage_clients/_sql/py.typed | 0 src/crawlee/storage_clients/models.py | 20 +- .../_sql/test_sql_dataset_client.py | 165 ++++++ .../_sql/test_sql_kvs_client.py | 211 +++++++ .../_sql/test_sql_rq_client.py | 173 ++++++ tests/unit/storages/test_dataset.py | 6 +- tests/unit/storages/test_key_value_store.py | 6 +- tests/unit/storages/test_request_queue.py | 7 +- uv.lock | 74 ++- 17 files changed, 2122 insertions(+), 20 deletions(-) create mode 100644 src/crawlee/storage_clients/_sql/__init__.py create mode 100644 src/crawlee/storage_clients/_sql/_dataset_client.py create mode 100644 src/crawlee/storage_clients/_sql/_db_models.py create mode 100644 src/crawlee/storage_clients/_sql/_key_value_store_client.py create mode 100644 src/crawlee/storage_clients/_sql/_request_queue_client.py create mode 100644 src/crawlee/storage_clients/_sql/_storage_client.py create mode 100644 src/crawlee/storage_clients/_sql/py.typed create mode 100644 tests/unit/storage_clients/_sql/test_sql_dataset_client.py create mode 100644 tests/unit/storage_clients/_sql/test_sql_kvs_client.py create mode 100644 tests/unit/storage_clients/_sql/test_sql_rq_client.py diff --git a/pyproject.toml b/pyproject.toml index 5dc8c280ad..6f1b112c80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dependencies = [ ] [project.optional-dependencies] -all = ["crawlee[adaptive-crawler,beautifulsoup,cli,curl-impersonate,parsel,playwright,otel]"] +all = ["crawlee[adaptive-crawler,beautifulsoup,cli,curl-impersonate,parsel,playwright,otel,sql]"] adaptive-crawler = [ "jaro-winkler>=2.0.3", "playwright>=1.27.0", @@ -73,6 +73,10 @@ otel = [ "opentelemetry-semantic-conventions>=0.54", "wrapt>=1.17.0", ] +sql = [ + "sqlalchemy[asyncio]>=2.0.42,<3.0.0", + "aiosqlite>=0.21.0", +] [project.scripts] crawlee = "crawlee._cli:cli" diff --git a/src/crawlee/storage_clients/__init__.py b/src/crawlee/storage_clients/__init__.py index ce8c713ca9..6c2d58591d 100644 --- a/src/crawlee/storage_clients/__init__.py +++ b/src/crawlee/storage_clients/__init__.py @@ -1,9 +1,11 @@ from ._base import StorageClient from ._file_system import FileSystemStorageClient from ._memory import MemoryStorageClient +from ._sql import SQLStorageClient __all__ = [ 'FileSystemStorageClient', 'MemoryStorageClient', + 'SQLStorageClient', 'StorageClient', ] diff --git a/src/crawlee/storage_clients/_sql/__init__.py b/src/crawlee/storage_clients/_sql/__init__.py new file mode 100644 index 0000000000..32fb3b6880 --- /dev/null +++ b/src/crawlee/storage_clients/_sql/__init__.py @@ -0,0 +1,6 @@ +from ._dataset_client import SQLDatasetClient +from ._key_value_store_client import SQLKeyValueStoreClient +from ._request_queue_client import SQLRequestQueueClient +from ._storage_client import SQLStorageClient + +__all__ = ['SQLDatasetClient', 'SQLKeyValueStoreClient', 'SQLRequestQueueClient', 'SQLStorageClient'] diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py 
b/src/crawlee/storage_clients/_sql/_dataset_client.py new file mode 100644 index 0000000000..964a5306fc --- /dev/null +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import json +from datetime import datetime, timezone +from logging import getLogger +from typing import TYPE_CHECKING, Any + +from sqlalchemy import delete, select +from typing_extensions import override + +from crawlee._utils.crypto import crypto_random_object_id +from crawlee.storage_clients._base import DatasetClient +from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata + +from ._db_models import DatasetItemDB, DatasetMetadataDB + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from ._storage_client import SQLStorageClient + +logger = getLogger(__name__) + + +class SQLDatasetClient(DatasetClient): + """SQL implementation of the dataset client. + + This client persists dataset items to a SQL database with proper transaction handling and + concurrent access safety. Items are stored in a normalized table structure with automatic + ordering preservation and efficient querying capabilities. + + The SQL implementation provides ACID compliance, supports complex queries, and allows + multiple processes to safely access the same dataset concurrently through database-level + locking mechanisms. + """ + + def __init__( + self, + *, + orm_metadata: DatasetMetadataDB, + storage_client: SQLStorageClient, + ) -> None: + """Initialize a new instance. + + Preferably use the `SqlDatasetClient.open` class method to create a new instance. + """ + self._orm_metadata = orm_metadata + self._storage_client = storage_client + + @override + async def get_metadata(self) -> DatasetMetadata: + return DatasetMetadata.model_validate(self._orm_metadata) + + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + storage_client: SQLStorageClient, + ) -> SQLDatasetClient: + """Open or create a SQL dataset client. + + Args: + id: The ID of the dataset to open. If provided, searches for existing dataset by ID. + name: The name of the dataset to open. If not provided, uses the default dataset. + storage_client: The SQL storage client instance. + + Returns: + An instance for the opened or created storage client. + + Raises: + ValueError: If a dataset with the specified ID is not found. 
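+
+ A minimal usage sketch (illustrative only; it assumes the storage client has already been
+ initialized and that `Configuration` is imported from `crawlee.configuration`):
+
+ ```python
+ storage_client = SQLStorageClient(connection_string='sqlite+aiosqlite:///crawlee.db')
+ await storage_client.initialize(Configuration.get_global_configuration())
+ dataset_client = await SQLDatasetClient.open(id=None, name='my-dataset', storage_client=storage_client)
+ await dataset_client.push_data({'url': 'https://example.com'})
+ ```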
+ """ + async with storage_client.create_session() as session: + if id: + orm_metadata = await session.get(DatasetMetadataDB, id) + if not orm_metadata: + raise ValueError(f'Dataset with ID "{id}" not found.') + + client = cls( + orm_metadata=orm_metadata, + storage_client=storage_client, + ) + await client._update_metadata(update_accessed_at=True) + + else: + orm_metadata = await session.get(DatasetMetadataDB, name) + if orm_metadata: + client = cls( + orm_metadata=orm_metadata, + storage_client=storage_client, + ) + await client._update_metadata(update_accessed_at=True) + + else: + now = datetime.now(timezone.utc) + metadata = DatasetMetadata( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + item_count=0, + ) + orm_metadata = DatasetMetadataDB(**metadata.model_dump()) + client = cls( + orm_metadata=orm_metadata, + storage_client=storage_client, + ) + session.add(orm_metadata) + + await session.commit() + + return client + + @override + async def drop(self) -> None: + async with self._storage_client.create_session() as session: + dataset_db = await session.get(DatasetItemDB, self._orm_metadata.id) + if dataset_db: + await session.delete(dataset_db) + await session.commit() + + @override + async def purge(self) -> None: + async with self._storage_client.create_session() as session: + stmt = delete(DatasetItemDB).where(DatasetItemDB.dataset_id == self._orm_metadata.id) + await session.execute(stmt) + + self._orm_metadata.item_count = 0 + await self._update_metadata(update_accessed_at=True, update_modified_at=True) + await session.commit() + + @override + async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: + if not isinstance(data, list): + data = [data] + + db_items: list[DatasetItemDB] = [] + + for item in data: + json_item = json.dumps(item, default=str, ensure_ascii=False) + db_items.append( + DatasetItemDB( + dataset_id=self._orm_metadata.id, + data=json_item, + created_at=datetime.now(timezone.utc), + ) + ) + + async with self._storage_client.create_session() as session: + session.add_all(db_items) + self._orm_metadata.item_count += len(data) + await self._update_metadata(update_modified_at=True) + + await session.commit() + + @override + async def get_data( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + flatten: list[str] | None = None, + view: str | None = None, + ) -> DatasetItemsListPage: + # Check for unsupported arguments and log a warning if found. + unsupported_args: dict[str, Any] = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + 'flatten': flatten, + 'view': view, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of get_data are not supported by the ' + f'{self.__class__.__name__} client.' 
+ ) + + stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._orm_metadata.id) + + if skip_empty: + stmt = stmt.where(DatasetItemDB.data != '"{}"') + + stmt = stmt.order_by(DatasetItemDB.created_at.desc()) if desc else stmt.order_by(DatasetItemDB.created_at.asc()) + + stmt = stmt.offset(offset).limit(limit) + + async with self._storage_client.create_session() as session: + result = await session.execute(stmt) + db_items = result.scalars().all() + + await self._update_metadata(update_accessed_at=True) + + await session.commit() + + items = [json.loads(db_item.data) for db_item in db_items] + return DatasetItemsListPage( + items=items, + count=len(items), + desc=desc, + limit=limit or 0, + offset=offset or 0, + total=self._orm_metadata.item_count, + ) + + @override + async def iterate_items( + self, + *, + offset: int = 0, + limit: int | None = None, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + ) -> AsyncIterator[dict[str, Any]]: + # Check for unsupported arguments and log a warning if found. + unsupported_args: dict[str, Any] = { + 'clean': clean, + 'fields': fields, + 'omit': omit, + 'unwind': unwind, + 'skip_hidden': skip_hidden, + } + unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} + + if unsupported: + logger.warning( + f'The arguments {list(unsupported.keys())} of iterate are not supported ' + f'by the {self.__class__.__name__} client.' + ) + + stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._orm_metadata.id) + + if skip_empty: + stmt = stmt.where(DatasetItemDB.data != '"{}"') + + stmt = stmt.order_by(DatasetItemDB.created_at.desc()) if desc else stmt.order_by(DatasetItemDB.created_at.asc()) + + stmt = stmt.offset(offset).limit(limit) + + async with self._storage_client.create_session() as session: + result = await session.execute(stmt) + db_items = result.scalars().all() + + await self._update_metadata(update_accessed_at=True) + + await session.commit() + + items = [json.loads(db_item.data) for db_item in db_items] + for item in items: + yield item + + async def _update_metadata( + self, + *, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the KVS metadata in the database. + + Args: + session: The SQLAlchemy AsyncSession to use for the update. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. 
+ """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._orm_metadata.accessed_at = now + if update_modified_at: + self._orm_metadata.modified_at = now diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py new file mode 100644 index 0000000000..32818ab358 --- /dev/null +++ b/src/crawlee/storage_clients/_sql/_db_models.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from datetime import datetime # noqa: TC003 +from typing import Any + +from sqlalchemy import ( + JSON, + Boolean, + DateTime, + ForeignKey, + Integer, + LargeBinary, + String, +) +from sqlalchemy.orm import Mapped, declarative_base, mapped_column, relationship + +Base = declarative_base() + + +class StorageMetadataDB: + """Base database model for storage metadata.""" + + id: Mapped[str] = mapped_column(String(20), nullable=False, primary_key=True) + name: Mapped[str | None] = mapped_column(String(100), nullable=True) + accessed_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + modified_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + + +class DatasetMetadataDB(StorageMetadataDB, Base): # type: ignore[valid-type,misc] + __tablename__ = 'dataset_metadata' + + item_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + items: Mapped[list[DatasetItemDB]] = relationship(back_populates='dataset', cascade='all, delete-orphan') + + +class RequestQueueMetadataDB(StorageMetadataDB, Base): # type: ignore[valid-type,misc] + __tablename__ = 'request_queue_metadata' + + had_multiple_clients: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + handled_request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + pending_request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + stats: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False, default={}) + total_request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + requests: Mapped[list[RequestDB]] = relationship(back_populates='queue', cascade='all, delete-orphan') + + +class KeyValueStoreMetadataDB(StorageMetadataDB, Base): # type: ignore[valid-type,misc] + __tablename__ = 'kvs_metadata' + + records: Mapped[list[KeyValueStoreRecordDB]] = relationship(back_populates='kvs', cascade='all, delete-orphan') + + +class KeyValueStoreRecordDB(Base): # type: ignore[valid-type,misc] + """Database model for key-value store records.""" + + __tablename__ = 'kvs_record' + + kvs_id: Mapped[str] = mapped_column(String(255), ForeignKey('kvs_metadata.id'), primary_key=True, index=True) + + key: Mapped[str] = mapped_column(String(255), primary_key=True) + value: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) + + content_type: Mapped[str] = mapped_column(String(100), nullable=False) + size: Mapped[int | None] = mapped_column(Integer, nullable=False, default=0) + + kvs: Mapped[KeyValueStoreMetadataDB] = relationship(back_populates='records') + + +class DatasetItemDB(Base): # type: ignore[valid-type,misc] + """Database model for dataset items.""" + + __tablename__ = 'dataset_item' + + order_id: Mapped[int] = mapped_column(Integer, primary_key=True) + dataset_id: Mapped[str] = mapped_column(String(20), ForeignKey('dataset_metadata.id'), index=True) + data: Mapped[str] = mapped_column(JSON, nullable=False) + + created_at: Mapped[datetime] = 
mapped_column(DateTime(timezone=True), nullable=False) + + dataset: Mapped[DatasetMetadataDB] = relationship(back_populates='items') + + +class RequestDB(Base): # type: ignore[valid-type,misc] + """Database model for requests in the request queue.""" + + __tablename__ = 'request' + + request_id: Mapped[str] = mapped_column(String(20), primary_key=True) + queue_id: Mapped[str] = mapped_column( + String(20), ForeignKey('request_queue_metadata.id'), index=True, primary_key=True + ) + + data: Mapped[str] = mapped_column(JSON, nullable=False) + unique_key: Mapped[str] = mapped_column(String(512), nullable=False) + + sequence_number: Mapped[int] = mapped_column(Integer, nullable=False, index=True) + + is_handled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + + queue: Mapped[RequestQueueMetadataDB] = relationship(back_populates='requests') diff --git a/src/crawlee/storage_clients/_sql/_key_value_store_client.py b/src/crawlee/storage_clients/_sql/_key_value_store_client.py new file mode 100644 index 0000000000..e4826eaa1e --- /dev/null +++ b/src/crawlee/storage_clients/_sql/_key_value_store_client.py @@ -0,0 +1,323 @@ +from __future__ import annotations + +import json +from datetime import datetime, timezone +from logging import getLogger +from typing import TYPE_CHECKING, Any + +from sqlalchemy import delete, select +from typing_extensions import override + +from crawlee._utils.crypto import crypto_random_object_id +from crawlee._utils.file import infer_mime_type +from crawlee.storage_clients._base import KeyValueStoreClient +from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata + +from ._db_models import KeyValueStoreMetadataDB, KeyValueStoreRecordDB + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from ._storage_client import SQLStorageClient + + +logger = getLogger(__name__) + + +class SQLKeyValueStoreClient(KeyValueStoreClient): + """SQL implementation of the key-value store client. + + This client persists data to a SQL database, making it suitable for scenarios where data needs to + survive process restarts. Keys are mapped to rows in a database table. + + Binary data is stored as-is, while JSON and text data are stored in human-readable format. + The implementation automatically handles serialization based on the content type and + maintains metadata about each record. + + This implementation is ideal for long-running crawlers where persistence is important and + for development environments where you want to easily inspect the stored data between runs. + """ + + def __init__( + self, + *, + storage_client: SQLStorageClient, + orm_metadata: KeyValueStoreMetadataDB, + ) -> None: + """Initialize a new instance. + + Preferably use the `SQLKeyValueStoreClient.open` class method to create a new instance.
+ """ + self._orm_metadata = orm_metadata + + self._storage_client = storage_client + """The storage client used to access the SQL database.""" + + @override + async def get_metadata(self) -> KeyValueStoreMetadata: + return KeyValueStoreMetadata.model_validate(self._orm_metadata) + + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + storage_client: SQLStorageClient, + ) -> SQLKeyValueStoreClient: + """Open or create a SQL key-value store client. + + This method attempts to open an existing key-value store from the SQL database. If a KVS with the specified + ID or name exists, it loads the metadata from the database. If no existing store is found, a new one + is created. + + Args: + id: The ID of the key-value store to open. If provided, searches for existing store by ID. + name: The name of the key-value store to open. If not provided, uses the default store. + storage_client: The SQL storage client used to access the database. + + Returns: + An instance for the opened or created storage client. + + Raises: + ValueError: If a store with the specified ID is not found, or if metadata is invalid. + """ + async with storage_client.create_session() as session: + if id: + orm_metadata = await session.get(KeyValueStoreMetadataDB, id) + if not orm_metadata: + raise ValueError(f'Key-value store with ID "{id}" not found.') + client = cls( + orm_metadata=orm_metadata, + storage_client=storage_client, + ) + client._update_metadata(update_accessed_at=True) + + else: + orm_metadata = await session.get(KeyValueStoreMetadataDB, name) + if orm_metadata: + client = cls( + orm_metadata=orm_metadata, + storage_client=storage_client, + ) + client._update_metadata(update_accessed_at=True) + else: + now = datetime.now(timezone.utc) + metadata = KeyValueStoreMetadata( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + ) + orm_metadata = KeyValueStoreMetadataDB(**metadata.model_dump()) + client = cls( + orm_metadata=orm_metadata, + storage_client=storage_client, + ) + session.add(orm_metadata) + + await session.commit() + + return client + + @override + async def drop(self) -> None: + async with self._storage_client.create_session() as session: + kvs_db = await session.get(KeyValueStoreMetadataDB, self._orm_metadata.id) + if kvs_db: + await session.delete(kvs_db) + await session.commit() + + @override + async def purge(self) -> None: + async with self._storage_client.create_session() as session: + stmt = delete(KeyValueStoreRecordDB).filter_by(kvs_id=self._orm_metadata.id) + await session.execute(stmt) + + self._update_metadata(update_accessed_at=True, update_modified_at=True) + await session.commit() + + @override + async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + # Special handling for None values + if value is None: + content_type = 'application/x-none' # Special content type to identify None values + value_bytes = b'' + else: + content_type = content_type or infer_mime_type(value) + + # Serialize the value to bytes. + if 'application/json' in content_type: + value_bytes = json.dumps(value, default=str, ensure_ascii=False).encode('utf-8') + elif isinstance(value, str): + value_bytes = value.encode('utf-8') + elif isinstance(value, (bytes, bytearray)): + value_bytes = value + else: + # Fallback: attempt to convert to string and encode. 
+ value_bytes = str(value).encode('utf-8') + + size = len(value_bytes) + record_db = KeyValueStoreRecordDB( + kvs_id=self._orm_metadata.id, + key=key, + value=value_bytes, + content_type=content_type, + size=size, + ) + + async with self._storage_client.create_session() as session: + existing_record = await session.get(KeyValueStoreRecordDB, (self._orm_metadata.id, key)) + if existing_record: + # Update existing record + existing_record.value = value_bytes + existing_record.content_type = content_type + existing_record.size = size + else: + session.add(record_db) + self._update_metadata(update_accessed_at=True, update_modified_at=True) + await session.commit() + + @override + async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: + # Update the metadata to record access + async with self._storage_client.create_session() as session: + self._update_metadata(update_accessed_at=True) + + stmt = select(KeyValueStoreRecordDB).where( + KeyValueStoreRecordDB.kvs_id == self._orm_metadata.id, KeyValueStoreRecordDB.key == key + ) + result = await session.execute(stmt) + record_db = result.scalar_one_or_none() + + await session.commit() + + if not record_db: + return None + + # Deserialize the value based on content type + value_bytes = record_db.value + + # Handle None values + if record_db.content_type == 'application/x-none': + value = None + # Handle JSON values + elif 'application/json' in record_db.content_type: + try: + value = json.loads(value_bytes.decode('utf-8')) + except (json.JSONDecodeError, UnicodeDecodeError): + logger.warning(f'Failed to decode JSON value for key "{key}"') + return None + # Handle text values + elif record_db.content_type.startswith('text/'): + try: + value = value_bytes.decode('utf-8') + except UnicodeDecodeError: + logger.warning(f'Failed to decode text value for key "{key}"') + return None + # Handle binary values + else: + value = value_bytes + + return KeyValueStoreRecord( + key=record_db.key, + value=value, + content_type=record_db.content_type, + size=record_db.size, + ) + + @override + async def delete_value(self, *, key: str) -> None: + async with self._storage_client.create_session() as session: + # Delete the record if it exists + stmt = delete(KeyValueStoreRecordDB).where( + KeyValueStoreRecordDB.kvs_id == self._orm_metadata.id, KeyValueStoreRecordDB.key == key + ) + result = await session.execute(stmt) + + # Update metadata if we actually deleted something + if result.rowcount > 0: + self._update_metadata(update_accessed_at=True, update_modified_at=True) + + await session.commit() + + @override + async def iterate_keys( + self, + *, + exclusive_start_key: str | None = None, + limit: int | None = None, + ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + async with self._storage_client.create_session() as session: + # Build query for record metadata + stmt = ( + select(KeyValueStoreRecordDB.key, KeyValueStoreRecordDB.content_type, KeyValueStoreRecordDB.size) + .where(KeyValueStoreRecordDB.kvs_id == self._orm_metadata.id) + .order_by(KeyValueStoreRecordDB.key) + ) + + # Apply exclusive_start_key filter + if exclusive_start_key is not None: + stmt = stmt.where(KeyValueStoreRecordDB.key > exclusive_start_key) + + # Apply limit + if limit is not None: + stmt = stmt.limit(limit) + + result = await session.execute(stmt) + + self._update_metadata(update_accessed_at=True) + await session.commit() + + for row in result: + yield KeyValueStoreRecordMetadata( + key=row.key, + content_type=row.content_type, + size=row.size, + ) + + @override + async def 
record_exists(self, *, key: str) -> bool: + async with self._storage_client.create_session() as session: + # Check if record exists + stmt = select(KeyValueStoreRecordDB.key).where( + KeyValueStoreRecordDB.kvs_id == self._orm_metadata.id, KeyValueStoreRecordDB.key == key + ) + result = await session.execute(stmt) + + self._update_metadata(update_accessed_at=True) + await session.commit() + + return result.scalar_one_or_none() is not None + + @override + async def get_public_url(self, *, key: str) -> str: + raise NotImplementedError('Public URLs are not supported for memory key-value stores.') + + def _update_metadata( + self, + *, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the KVS metadata in the database. + + Args: + session: The SQLAlchemy AsyncSession to use for the update. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._orm_metadata.accessed_at = now + if update_modified_at: + self._orm_metadata.modified_at = now diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py new file mode 100644 index 0000000000..3bf25da152 --- /dev/null +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import asyncio +from collections import deque +from datetime import datetime, timezone +from logging import getLogger +from typing import TYPE_CHECKING + +from pydantic import BaseModel +from sqlalchemy import delete, func, select, update +from sqlalchemy.exc import SQLAlchemyError +from typing_extensions import override + +from crawlee import Request +from crawlee._utils.crypto import crypto_random_object_id +from crawlee._utils.recoverable_state import RecoverableState +from crawlee.storage_clients._base import RequestQueueClient +from crawlee.storage_clients.models import ( + AddRequestsResponse, + ProcessedRequest, + RequestQueueMetadata, + UnprocessedRequest, +) + +from ._db_models import RequestDB, RequestQueueMetadataDB + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ._storage_client import SQLStorageClient + + +logger = getLogger(__name__) + + +class RequestQueueState(BaseModel): + """Simplified state model for SQL implementation.""" + + sequence_counter: int = 1 + """Counter for regular request ordering (positive).""" + + forefront_sequence_counter: int = -1 + """Counter for forefront request ordering (negative).""" + + in_progress_requests: set[str] = set() + """Set of request IDs currently being processed.""" + + +class SQLRequestQueueClient(RequestQueueClient): + """SQL implementation of the request queue client. + + This client persists requests to a SQL database with proper transaction handling and + concurrent access safety. Requests are stored in a normalized table structure with + sequence-based ordering and efficient querying capabilities. + + The implementation uses negative sequence numbers for forefront (high-priority) requests + and positive sequence numbers for regular requests, allowing for efficient single-query + ordering. A cache mechanism reduces database queries for better performance. 
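+
+ For example, with the default counters the first regular request is stored with sequence number 1
+ and the first forefront request with -1, so a single `ORDER BY sequence_number` query yields
+ forefront requests before regular ones (most recently added forefront request first) while regular
+ requests keep their FIFO order.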
+ """ + + _MAX_REQUESTS_IN_CACHE = 100_000 + """Maximum number of requests to keep in cache for faster access.""" + + def __init__( + self, + *, + orm_metadata: RequestQueueMetadataDB, + storage_client: SQLStorageClient, + ) -> None: + """Initialize a new instance. + + Preferably use the `SQLRequestQueueClient.open` class method to create a new instance. + """ + self._orm_metadata = orm_metadata + + self._request_cache: deque[Request] = deque() + """Cache for requests: ordered by sequence number.""" + + self._request_cache_needs_refresh = True + """Flag indicating whether the cache needs to be refreshed from database.""" + + self._is_empty_cache: bool | None = None + """Cache for is_empty result: None means unknown, True/False is cached state.""" + + self._state = RecoverableState[RequestQueueState]( + default_state=RequestQueueState(), + persist_state_key='request_queue_state', + persistence_enabled=True, + persist_state_kvs_name=f'__RQ_STATE_{self._orm_metadata.id}', + logger=logger, + ) + """Recoverable state to maintain request ordering and in-progress status.""" + + self._storage_client = storage_client + """The storage client used to access the SQL database.""" + + self._lock = asyncio.Lock() + + @override + async def get_metadata(self) -> RequestQueueMetadata: + return RequestQueueMetadata.model_validate(self._orm_metadata) + + @classmethod + async def open( + cls, + *, + id: str | None, + name: str | None, + storage_client: SQLStorageClient, + ) -> SQLRequestQueueClient: + """Open or create a SQL request queue client. + + Args: + id: The ID of the request queue to open. If provided, searches for existing queue by ID. + name: The name of the request queue to open. If not provided, uses the default queue. + storage_client: The SQL storage client used to access the database. + + Returns: + An instance for the opened or created storage client. + + Raises: + ValueError: If a queue with the specified ID is not found. 
+ """ + async with storage_client.create_session() as session: + if id: + orm_metadata = await session.get(RequestQueueMetadataDB, id) + if not orm_metadata: + raise ValueError(f'Request queue with ID "{id}" not found.') + client = cls( + orm_metadata=orm_metadata, + storage_client=storage_client, + ) + await client._update_metadata(update_accessed_at=True) + else: + # Try to find by name + orm_metadata = await session.get(RequestQueueMetadataDB, name) + + if orm_metadata: + client = cls( + orm_metadata=orm_metadata, + storage_client=storage_client, + ) + await client._update_metadata(update_accessed_at=True) + else: + now = datetime.now(timezone.utc) + metadata = RequestQueueMetadata( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + had_multiple_clients=False, + handled_request_count=0, + pending_request_count=0, + stats={}, + total_request_count=0, + ) + orm_metadata = RequestQueueMetadataDB(**metadata.model_dump()) + client = cls( + orm_metadata=orm_metadata, + storage_client=storage_client, + ) + + session.add(orm_metadata) + + await session.commit() + + await client._state.initialize() + + return client + + @override + async def drop(self) -> None: + async with self._storage_client.create_session() as session: + # Delete the request queue metadata (cascade will delete requests) + rq_db = await session.get(RequestQueueMetadataDB, self._orm_metadata.id) + if rq_db: + await session.delete(rq_db) + + # Clear recoverable state + await self._state.reset() + await self._state.teardown() + self._request_cache.clear() + self._request_cache_needs_refresh = True + self._is_empty_cache = None + + await session.commit() + + @override + async def purge(self) -> None: + async with self._storage_client.create_session() as session: + # Delete all requests for this queue + stmt = delete(RequestDB).where(RequestDB.queue_id == self._orm_metadata.id) + await session.execute(stmt) + + # Update metadata + self._orm_metadata.pending_request_count = 0 + self._orm_metadata.handled_request_count = 0 + + await self._update_metadata(update_modified_at=True, update_accessed_at=True) + + self._is_empty_cache = None + await session.commit() + + # Clear recoverable state + self._request_cache.clear() + self._request_cache_needs_refresh = True + await self._state.reset() + + @override + async def add_batch_of_requests( + self, + requests: Sequence[Request], + *, + forefront: bool = False, + ) -> AddRequestsResponse: + async with self._storage_client.create_session() as session, self._lock: + self._is_empty_cache = None + processed_requests = [] + unprocessed_requests = [] + state = self._state.current_value + + # Get existing requests by unique keys + unique_keys = {req.unique_key for req in requests} + stmt = select(RequestDB).where( + RequestDB.queue_id == self._orm_metadata.id, RequestDB.unique_key.in_(unique_keys) + ) + result = await session.execute(stmt) + existing_requests = {req.unique_key: req for req in result.scalars()} + result = await session.execute(stmt) + + batch_processed = set() + + # Process each request + for request in requests: + if request.unique_key in batch_processed: + continue + + existing_req_db = existing_requests.get(request.unique_key) + + if existing_req_db is None: + # New request + if forefront: + sequence_number = state.forefront_sequence_counter + state.forefront_sequence_counter -= 1 + else: + sequence_number = state.sequence_counter + state.sequence_counter += 1 + + request_db = RequestDB( + request_id=request.id, + 
queue_id=self._orm_metadata.id, + data=request.model_dump_json(), + unique_key=request.unique_key, + sequence_number=sequence_number, + is_handled=False, + ) + session.add(request_db) + + self._orm_metadata.total_request_count += 1 + self._orm_metadata.pending_request_count += 1 + + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=False, + was_already_handled=False, + ) + ) + + elif existing_req_db.is_handled: + # Already handled + processed_requests.append( + ProcessedRequest( + id=existing_req_db.request_id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) + ) + + else: + # Exists but not handled - might update priority + if forefront and existing_req_db.sequence_number > 0: + existing_req_db.sequence_number = state.forefront_sequence_counter + state.forefront_sequence_counter -= 1 + self._request_cache_needs_refresh = True + + processed_requests.append( + ProcessedRequest( + id=existing_req_db.request_id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + ) + + batch_processed.add(request.unique_key) + + await self._update_metadata(update_modified_at=True, update_accessed_at=True) + + if forefront: + self._request_cache_needs_refresh = True + + try: + await session.commit() + except SQLAlchemyError as e: + logger.warning(f'Failed to commit session: {e}') + await session.rollback() + processed_requests.clear() + unprocessed_requests.extend( + [ + UnprocessedRequest( + unique_key=request.unique_key, + url=request.url, + method=request.method, + ) + for request in requests + if request.unique_key not in existing_requests + ] + ) + + return AddRequestsResponse( + processed_requests=processed_requests, + unprocessed_requests=unprocessed_requests, + ) + + @override + async def get_request(self, request_id: str) -> Request | None: + async with self._storage_client.create_session() as session: + stmt = select(RequestDB).where( + RequestDB.queue_id == self._orm_metadata.id, RequestDB.request_id == request_id + ) + result = await session.execute(stmt) + request_db = result.scalar_one_or_none() + + if request_db is None: + logger.warning(f'Request with ID "{request_id}" not found in the queue.') + return None + + request = Request.model_validate_json(request_db.data) + + state = self._state.current_value + state.in_progress_requests.add(request.id) + + await self._update_metadata(update_accessed_at=True) + await session.commit() + + return request + + @override + async def fetch_next_request(self) -> Request | None: + # Refresh cache if needed + if self._request_cache_needs_refresh or not self._request_cache: + await self._refresh_cache() + + next_request = None + state = self._state.current_value + + # Get from cache + while self._request_cache and next_request is None: + candidate = self._request_cache.popleft() + + # Only check local state + if candidate.id not in state.in_progress_requests: + next_request = candidate + state.in_progress_requests.add(next_request.id) + + if not self._request_cache: + self._is_empty_cache = None + + return next_request + + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + self._is_empty_cache = None + state = self._state.current_value + + if request.id not in state.in_progress_requests: + logger.warning(f'Marking request {request.id} as handled that is not in progress.') + return None + + # Update request in DB + stmt = ( + update(RequestDB) +
.where(RequestDB.queue_id == self._orm_metadata.id, RequestDB.request_id == request.id) + .values(is_handled=True) + ) + + async with self._storage_client.create_session() as session: + result = await session.execute(stmt) + + if result.rowcount == 0: + logger.warning(f'Request {request.id} not found in database.') + return None + + # Update state + state.in_progress_requests.discard(request.id) + + # Update metadata + self._orm_metadata.handled_request_count += 1 + self._orm_metadata.pending_request_count -= 1 + + await self._update_metadata(update_modified_at=True, update_accessed_at=True) + + await session.commit() + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=True, + ) + + @override + async def reclaim_request( + self, + request: Request, + *, + forefront: bool = False, + ) -> ProcessedRequest | None: + self._is_empty_cache = None + state = self._state.current_value + + if request.id not in state.in_progress_requests: + logger.info(f'Reclaiming request {request.id} that is not in progress.') + return None + + # Update sequence number if changing priority + if forefront: + new_sequence = state.forefront_sequence_counter + state.forefront_sequence_counter -= 1 + else: + new_sequence = state.sequence_counter + state.sequence_counter += 1 + + stmt = ( + update(RequestDB) + .where(RequestDB.queue_id == self._orm_metadata.id, RequestDB.request_id == request.id) + .values(sequence_number=new_sequence) + ) + + async with self._storage_client.create_session() as session: + result = await session.execute(stmt) + + if result.rowcount == 0: + logger.warning(f'Request {request.id} not found in database.') + return None + + # Remove from in-progress + state.in_progress_requests.discard(request.id) + + # Invalidate cache or add to cache + if forefront: + self._request_cache_needs_refresh = True + elif len(self._request_cache) < self._MAX_REQUESTS_IN_CACHE: + # For regular requests, we can add to the end if there's space + self._request_cache.append(request) + + await self._update_metadata(update_modified_at=True, update_accessed_at=True) + + await session.commit() + + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) + + @override + async def is_empty(self) -> bool: + if self._is_empty_cache is not None: + return self._is_empty_cache + + state = self._state.current_value + + # If there are in-progress requests, not empty + if len(state.in_progress_requests) > 0: + self._is_empty_cache = False + return False + + # Check database for unhandled requests + async with self._storage_client.create_session() as session: + stmt = ( + select(func.count()) + .select_from(RequestDB) + .where( + RequestDB.queue_id == self._orm_metadata.id, + RequestDB.is_handled == False, # noqa: E712 + ) + ) + result = await session.execute(stmt) + unhandled_count = result.scalar() + self._is_empty_cache = unhandled_count == 0 + return self._is_empty_cache + + async def _refresh_cache(self) -> None: + """Refresh the request cache from database.""" + self._request_cache.clear() + state = self._state.current_value + + async with self._storage_client.create_session() as session: + # Simple query - get unhandled requests not in progress + stmt = ( + select(RequestDB) + .where( + RequestDB.queue_id == self._orm_metadata.id, + RequestDB.is_handled == False, # noqa: E712 + ) + .order_by(RequestDB.sequence_number.asc()) + .limit(self._MAX_REQUESTS_IN_CACHE) + ) + + if 
state.in_progress_requests: + stmt = stmt.where(RequestDB.request_id.notin_(state.in_progress_requests)) + + result = await session.execute(stmt) + request_dbs = result.scalars().all() + + # Add to cache in order + for request_db in request_dbs: + request = Request.model_validate_json(request_db.data) + self._request_cache.append(request) + + self._request_cache_needs_refresh = False + + async def _update_metadata( + self, + *, + update_had_multiple_clients: bool = False, + update_accessed_at: bool = False, + update_modified_at: bool = False, + ) -> None: + """Update the request queue metadata in the database. + + Args: + session: The SQLAlchemy session to use for database operations. + update_had_multiple_clients: If True, set had_multiple_clients to True. + update_accessed_at: If True, update the `accessed_at` timestamp to the current time. + update_modified_at: If True, update the `modified_at` timestamp to the current time. + """ + now = datetime.now(timezone.utc) + + if update_accessed_at: + self._orm_metadata.accessed_at = now + + if update_modified_at: + self._orm_metadata.modified_at = now + + if update_had_multiple_clients: + self._orm_metadata.had_multiple_clients = True diff --git a/src/crawlee/storage_clients/_sql/_storage_client.py b/src/crawlee/storage_clients/_sql/_storage_client.py new file mode 100644 index 0000000000..0d12a7408f --- /dev/null +++ b/src/crawlee/storage_clients/_sql/_storage_client.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.sql import text +from typing_extensions import override + +from crawlee._utils.docs import docs_group +from crawlee.configuration import Configuration +from crawlee.storage_clients._base import StorageClient + +from ._dataset_client import SQLDatasetClient +from ._db_models import Base +from ._key_value_store_client import SQLKeyValueStoreClient +from ._request_queue_client import SQLRequestQueueClient + +if TYPE_CHECKING: + from types import TracebackType + + from crawlee.storage_clients._base import ( + DatasetClient, + KeyValueStoreClient, + RequestQueueClient, + ) + + +@docs_group('Classes') +class SQLStorageClient(StorageClient): + """SQL implementation of the storage client. + + This storage client provides access to datasets, key-value stores, and request queues that persist data + to a SQL database using SQLAlchemy 2+ with Pydantic dataclasses for type safety. Data is stored in + normalized relational tables, providing ACID compliance, concurrent access safety, and the ability to + query data using SQL. + + The SQL implementation supports various database backends including PostgreSQL, MySQL, SQLite, and others + supported by SQLAlchemy. It provides durability, consistency, and supports concurrent access from multiple + processes through database-level locking mechanisms. + + This implementation is ideal for production environments where data persistence, consistency, and + concurrent access are critical requirements. + """ + + _DB_NAME = 'crawlee.db' + """Default database name if not specified in connection string.""" + + def __init__( + self, + *, + connection_string: str | None = None, + engine: AsyncEngine | None = None, + ) -> None: + """Initialize the SQL storage client. + + Args: + connection_string: Database connection string (e.g., "sqlite+aiosqlite:///crawlee.db"). 
+ If not provided, defaults to SQLite database in the storage directory. + engine: Pre-configured AsyncEngine instance. If provided, connection_string is ignored. + """ + if engine is not None and connection_string is not None: + raise ValueError('Either connection_string or engine must be provided, not both.') + + self._connection_string = connection_string + self._engine = engine + self._initialized = False + + def _get_or_create_engine(self, configuration: Configuration) -> AsyncEngine: + """Get or create the database engine based on configuration.""" + if self._engine is not None: + return self._engine + + if self._connection_string is not None: + connection_string = self._connection_string + else: + # Create SQLite database in the storage directory + storage_dir = Path(configuration.storage_dir) + if not storage_dir.exists(): + storage_dir.mkdir(parents=True, exist_ok=True) + + db_path = storage_dir / self._DB_NAME + + connection_string = f'sqlite+aiosqlite:///{db_path}' + + self._engine = create_async_engine( + connection_string, future=True, pool_size=5, max_overflow=10, pool_timeout=30, pool_recycle=600, echo=False + ) + return self._engine + + async def initialize(self, configuration: Configuration) -> None: + """Initialize the database schema. + + This method creates all necessary tables if they don't exist. + Should be called before using the storage client. + """ + if not self._initialized: + engine = self._get_or_create_engine(configuration) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + if 'sqlite' in str(engine.url): + await conn.execute(text('PRAGMA journal_mode=WAL')) + await conn.execute(text('PRAGMA synchronous=NORMAL')) + await conn.execute(text('PRAGMA cache_size=10000')) + await conn.execute(text('PRAGMA temp_store=MEMORY')) + await conn.execute(text('PRAGMA mmap_size=268435456')) + await conn.execute(text('PRAGMA foreign_keys=ON')) + await conn.execute(text('PRAGMA busy_timeout=30000')) + self._initialized = True + + async def close(self) -> None: + """Close the database connection pool.""" + if self._engine is not None: + await self._engine.dispose() + self._engine = None + + def create_session(self) -> AsyncSession: + """Create a new database session. + + Returns: + A new AsyncSession instance. 
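+
+ The session is intended to be used as an async context manager with explicit commits, for
+ example `async with storage_client.create_session() as session: ...` followed by
+ `await session.commit()`.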
+ """ + session = async_sessionmaker(self._engine, expire_on_commit=False, autoflush=False) + return session() + + @override + async def create_dataset_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> DatasetClient: + configuration = configuration or Configuration.get_global_configuration() + await self.initialize(configuration) + + client = await SQLDatasetClient.open( + id=id, + name=name, + storage_client=self, + ) + + await self._purge_if_needed(client, configuration) + return client + + @override + async def create_kvs_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> KeyValueStoreClient: + configuration = configuration or Configuration.get_global_configuration() + await self.initialize(configuration) + + client = await SQLKeyValueStoreClient.open( + id=id, + name=name, + storage_client=self, + ) + + await self._purge_if_needed(client, configuration) + return client + + @override + async def create_rq_client( + self, + *, + id: str | None = None, + name: str | None = None, + configuration: Configuration | None = None, + ) -> RequestQueueClient: + configuration = configuration or Configuration.get_global_configuration() + await self.initialize(configuration) + + client = await SQLRequestQueueClient.open( + id=id, + name=name, + storage_client=self, + ) + + await self._purge_if_needed(client, configuration) + return client + + async def __aenter__(self) -> SQLStorageClient: + """Async context manager entry.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, + ) -> None: + """Async context manager exit.""" + await self.close() diff --git a/src/crawlee/storage_clients/_sql/py.typed b/src/crawlee/storage_clients/_sql/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/crawlee/storage_clients/models.py b/src/crawlee/storage_clients/models.py index 3cb5b67b7a..c5748a4b78 100644 --- a/src/crawlee/storage_clients/models.py +++ b/src/crawlee/storage_clients/models.py @@ -20,7 +20,7 @@ class StorageMetadata(BaseModel): It contains common fields shared across all specific storage types. 
""" - model_config = ConfigDict(populate_by_name=True, extra='allow') + model_config = ConfigDict(populate_by_name=True, extra='allow', from_attributes=True) id: Annotated[str, Field(alias='id')] """The unique identifier of the storage.""" @@ -42,7 +42,7 @@ class StorageMetadata(BaseModel): class DatasetMetadata(StorageMetadata): """Model for a dataset metadata.""" - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, from_attributes=True) item_count: Annotated[int, Field(alias='itemCount')] """The number of items in the dataset.""" @@ -52,14 +52,14 @@ class DatasetMetadata(StorageMetadata): class KeyValueStoreMetadata(StorageMetadata): """Model for a key-value store metadata.""" - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, from_attributes=True) @docs_group('Data structures') class RequestQueueMetadata(StorageMetadata): """Model for a request queue metadata.""" - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, from_attributes=True) had_multiple_clients: Annotated[bool, Field(alias='hadMultipleClients')] """Indicates whether the queue has been accessed by multiple clients (consumers).""" @@ -81,7 +81,7 @@ class RequestQueueMetadata(StorageMetadata): class KeyValueStoreRecordMetadata(BaseModel): """Model for a key-value store record metadata.""" - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, from_attributes=True) key: Annotated[str, Field(alias='key')] """The key of the record. @@ -103,7 +103,7 @@ class KeyValueStoreRecordMetadata(BaseModel): class KeyValueStoreRecord(KeyValueStoreRecordMetadata, Generic[KvsValueType]): """Model for a key-value store record.""" - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, from_attributes=True) value: Annotated[KvsValueType, Field(alias='value')] """The value of the record.""" @@ -113,7 +113,7 @@ class KeyValueStoreRecord(KeyValueStoreRecordMetadata, Generic[KvsValueType]): class DatasetItemsListPage(BaseModel): """Model for a single page of dataset items returned from a collection list method.""" - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, from_attributes=True) count: Annotated[int, Field(default=0)] """The number of objects returned on this page.""" @@ -138,7 +138,7 @@ class DatasetItemsListPage(BaseModel): class ProcessedRequest(BaseModel): """Represents a processed request.""" - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, from_attributes=True) id: Annotated[str, Field(alias='requestId')] unique_key: Annotated[str, Field(alias='uniqueKey')] @@ -150,7 +150,7 @@ class ProcessedRequest(BaseModel): class UnprocessedRequest(BaseModel): """Represents an unprocessed request.""" - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, from_attributes=True) unique_key: Annotated[str, Field(alias='uniqueKey')] url: Annotated[str, BeforeValidator(validate_http_url), Field()] @@ -166,7 +166,7 @@ class AddRequestsResponse(BaseModel): encountered issues during processing. 
""" - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, from_attributes=True) processed_requests: Annotated[list[ProcessedRequest], Field(alias='processedRequests')] """Successfully processed requests, including information about whether they were diff --git a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py new file mode 100644 index 0000000000..c5f31f144e --- /dev/null +++ b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import asyncio +import json +from pathlib import Path +from typing import TYPE_CHECKING + +import pytest + +from crawlee._consts import METADATA_FILENAME +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + from crawlee.storage_clients._file_system import FileSystemDatasetClient + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + + +@pytest.fixture +async def dataset_client(configuration: Configuration) -> AsyncGenerator[FileSystemDatasetClient, None]: + """A fixture for a file system dataset client.""" + client = await FileSystemStorageClient().create_dataset_client( + name='test_dataset', + configuration=configuration, + ) + yield client + await client.drop() + + +async def test_file_and_directory_creation(configuration: Configuration) -> None: + """Test that file system dataset creates proper files and directories.""" + client = await FileSystemStorageClient().create_dataset_client( + name='new_dataset', + configuration=configuration, + ) + + # Verify files were created + assert client.path_to_dataset.exists() + assert client.path_to_metadata.exists() + + # Verify metadata file structure + with client.path_to_metadata.open() as f: + metadata = json.load(f) + client_metadata = await client.get_metadata() + assert metadata['id'] == client_metadata.id + assert metadata['name'] == 'new_dataset' + assert metadata['item_count'] == 0 + + await client.drop() + + +async def test_file_persistence_and_content_verification(dataset_client: FileSystemDatasetClient) -> None: + """Test that data is properly persisted to files with correct content.""" + item = {'key': 'value', 'number': 42} + await dataset_client.push_data(item) + + # Verify files are created on disk + all_files = list(dataset_client.path_to_dataset.glob('*.json')) + assert len(all_files) == 2 # 1 data file + 1 metadata file + + # Verify actual file content + data_files = [item for item in all_files if item.name != METADATA_FILENAME] + assert len(data_files) == 1 + + with Path(data_files[0]).open() as f: + saved_item = json.load(f) + assert saved_item == item + + # Test multiple items file creation + items = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}, {'id': 3, 'name': 'Item 3'}] + await dataset_client.push_data(items) + + all_files = list(dataset_client.path_to_dataset.glob('*.json')) + assert len(all_files) == 5 # 4 data files + 1 metadata file + + data_files = [f for f in all_files if f.name != METADATA_FILENAME] + assert len(data_files) == 4 # Original item + 3 new items + + +async def test_drop_removes_files_from_disk(dataset_client: FileSystemDatasetClient) -> None: + """Test that dropping a dataset removes the entire dataset directory from disk.""" + await 
dataset_client.push_data({'test': 'data'}) + + assert dataset_client.path_to_dataset.exists() + + # Drop the dataset + await dataset_client.drop() + + assert not dataset_client.path_to_dataset.exists() + + +async def test_metadata_file_updates(dataset_client: FileSystemDatasetClient) -> None: + """Test that metadata file is updated correctly after operations.""" + # Record initial timestamps + metadata = await dataset_client.get_metadata() + initial_created = metadata.created_at + initial_accessed = metadata.accessed_at + initial_modified = metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates accessed_at + await dataset_client.get_data() + + # Verify timestamps + metadata = await dataset_client.get_metadata() + assert metadata.created_at == initial_created + assert metadata.accessed_at > initial_accessed + assert metadata.modified_at == initial_modified + + accessed_after_get = metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform an operation that updates modified_at + await dataset_client.push_data({'new': 'item'}) + + # Verify timestamps again + metadata = await dataset_client.get_metadata() + assert metadata.created_at == initial_created + assert metadata.modified_at > initial_modified + assert metadata.accessed_at > accessed_after_get + + # Verify metadata file is updated on disk + with dataset_client.path_to_metadata.open() as f: + metadata_json = json.load(f) + assert metadata_json['item_count'] == 1 + + +async def test_data_persistence_across_reopens(configuration: Configuration) -> None: + """Test that data persists correctly when reopening the same dataset.""" + storage_client = FileSystemStorageClient() + + # Create dataset and add data + original_client = await storage_client.create_dataset_client( + name='persistence-test', + configuration=configuration, + ) + + test_data = {'test_item': 'test_value', 'id': 123} + await original_client.push_data(test_data) + + dataset_id = (await original_client.get_metadata()).id + + # Reopen by ID and verify data persists + reopened_client = await storage_client.create_dataset_client( + id=dataset_id, + configuration=configuration, + ) + + data = await reopened_client.get_data() + assert len(data.items) == 1 + assert data.items[0] == test_data + + await reopened_client.drop() diff --git a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py new file mode 100644 index 0000000000..c5bfa96c47 --- /dev/null +++ b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import asyncio +import json +from typing import TYPE_CHECKING + +import pytest + +from crawlee._consts import METADATA_FILENAME +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from pathlib import Path + + from crawlee.storage_clients._file_system import FileSystemKeyValueStoreClient + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + + +@pytest.fixture +async def kvs_client(configuration: Configuration) -> AsyncGenerator[FileSystemKeyValueStoreClient, None]: + """A fixture for a file system key-value store client.""" + client = await FileSystemStorageClient().create_kvs_client( + 
name='test_kvs', + configuration=configuration, + ) + yield client + await client.drop() + + +async def test_file_and_directory_creation(configuration: Configuration) -> None: + """Test that file system KVS creates proper files and directories.""" + client = await FileSystemStorageClient().create_kvs_client( + name='new_kvs', + configuration=configuration, + ) + + # Verify files were created + assert client.path_to_kvs.exists() + assert client.path_to_metadata.exists() + + # Verify metadata file structure + with client.path_to_metadata.open() as f: + metadata = json.load(f) + assert metadata['id'] == (await client.get_metadata()).id + assert metadata['name'] == 'new_kvs' + + await client.drop() + + +async def test_value_file_creation_and_content(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that values are properly persisted to files with correct content and metadata.""" + test_key = 'test-key' + test_value = 'Hello, world!' + await kvs_client.set_value(key=test_key, value=test_value) + + # Check if the files were created + key_path = kvs_client.path_to_kvs / test_key + key_metadata_path = kvs_client.path_to_kvs / f'{test_key}.{METADATA_FILENAME}' + assert key_path.exists() + assert key_metadata_path.exists() + + # Check file content + content = key_path.read_text(encoding='utf-8') + assert content == test_value + + # Check record metadata file + with key_metadata_path.open() as f: + metadata = json.load(f) + assert metadata['key'] == test_key + assert metadata['content_type'] == 'text/plain; charset=utf-8' + assert metadata['size'] == len(test_value.encode('utf-8')) + + +async def test_binary_data_persistence(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that binary data is stored correctly without corruption.""" + test_key = 'test-binary' + test_value = b'\x00\x01\x02\x03\x04' + await kvs_client.set_value(key=test_key, value=test_value) + + # Verify binary file exists + key_path = kvs_client.path_to_kvs / test_key + assert key_path.exists() + + # Verify binary content is preserved + content = key_path.read_bytes() + assert content == test_value + + # Verify retrieval works correctly + record = await kvs_client.get_value(key=test_key) + assert record is not None + assert record.value == test_value + assert record.content_type == 'application/octet-stream' + + +async def test_json_serialization_to_file(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that JSON objects are properly serialized to files.""" + test_key = 'test-json' + test_value = {'name': 'John', 'age': 30, 'items': [1, 2, 3]} + await kvs_client.set_value(key=test_key, value=test_value) + + # Check if file content is valid JSON + key_path = kvs_client.path_to_kvs / test_key + with key_path.open() as f: + file_content = json.load(f) + assert file_content == test_value + + +async def test_file_deletion_on_value_delete(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that deleting a value removes its files from disk.""" + test_key = 'test-delete' + test_value = 'Delete me' + + # Set a value + await kvs_client.set_value(key=test_key, value=test_value) + + # Verify files exist + key_path = kvs_client.path_to_kvs / test_key + metadata_path = kvs_client.path_to_kvs / f'{test_key}.{METADATA_FILENAME}' + assert key_path.exists() + assert metadata_path.exists() + + # Delete the value + await kvs_client.delete_value(key=test_key) + + # Verify files were deleted + assert not key_path.exists() + assert not metadata_path.exists() + + +async def 
test_drop_removes_directory(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that drop removes the entire store directory from disk.""" + await kvs_client.set_value(key='test', value='test-value') + + assert kvs_client.path_to_kvs.exists() + + # Drop the store + await kvs_client.drop() + + assert not kvs_client.path_to_kvs.exists() + + +async def test_metadata_file_updates(kvs_client: FileSystemKeyValueStoreClient) -> None: + """Test that read/write operations properly update metadata file timestamps.""" + # Record initial timestamps + metadata = await kvs_client.get_metadata() + initial_created = metadata.created_at + initial_accessed = metadata.accessed_at + initial_modified = metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform a read operation + await kvs_client.get_value(key='nonexistent') + + # Verify accessed timestamp was updated + metadata = await kvs_client.get_metadata() + assert metadata.created_at == initial_created + assert metadata.accessed_at > initial_accessed + assert metadata.modified_at == initial_modified + + accessed_after_read = metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform a write operation + await kvs_client.set_value(key='test', value='test-value') + + # Verify modified timestamp was updated + metadata = await kvs_client.get_metadata() + assert metadata.created_at == initial_created + assert metadata.modified_at > initial_modified + assert metadata.accessed_at > accessed_after_read + + +async def test_data_persistence_across_reopens(configuration: Configuration) -> None: + """Test that data persists correctly when reopening the same KVS.""" + storage_client = FileSystemStorageClient() + + # Create KVS and add data + original_client = await storage_client.create_kvs_client( + name='persistence-test', + configuration=configuration, + ) + + test_key = 'persistent-key' + test_value = 'persistent-value' + await original_client.set_value(key=test_key, value=test_value) + + kvs_id = (await original_client.get_metadata()).id + + # Reopen by ID and verify data persists + reopened_client = await storage_client.create_kvs_client( + id=kvs_id, + configuration=configuration, + ) + + record = await reopened_client.get_value(key=test_key) + assert record is not None + assert record.value == test_value + + await reopened_client.drop() diff --git a/tests/unit/storage_clients/_sql/test_sql_rq_client.py b/tests/unit/storage_clients/_sql/test_sql_rq_client.py new file mode 100644 index 0000000000..0be182fcd8 --- /dev/null +++ b/tests/unit/storage_clients/_sql/test_sql_rq_client.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import asyncio +import json +from typing import TYPE_CHECKING + +import pytest + +from crawlee import Request +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from pathlib import Path + + from crawlee.storage_clients._file_system import FileSystemRequestQueueClient + + +@pytest.fixture +def configuration(tmp_path: Path) -> Configuration: + return Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + + +@pytest.fixture +async def rq_client(configuration: Configuration) -> AsyncGenerator[FileSystemRequestQueueClient, None]: + """A fixture for a file system request queue client.""" + client = await FileSystemStorageClient().create_rq_client( + 
name='test_request_queue', + configuration=configuration, + ) + yield client + await client.drop() + + +async def test_file_and_directory_creation(configuration: Configuration) -> None: + """Test that file system RQ creates proper files and directories.""" + client = await FileSystemStorageClient().create_rq_client( + name='new_request_queue', + configuration=configuration, + ) + + # Verify files were created + assert client.path_to_rq.exists() + assert client.path_to_metadata.exists() + + # Verify metadata file structure + with client.path_to_metadata.open() as f: + metadata = json.load(f) + assert metadata['id'] == (await client.get_metadata()).id + assert metadata['name'] == 'new_request_queue' + + await client.drop() + + +async def test_request_file_persistence(rq_client: FileSystemRequestQueueClient) -> None: + """Test that requests are properly persisted to files.""" + requests = [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + Request.from_url('https://example.com/3'), + ] + + await rq_client.add_batch_of_requests(requests) + + # Verify request files are created + request_files = list(rq_client.path_to_rq.glob('*.json')) + # Should have 3 request files + 1 metadata file + assert len(request_files) == 4 + assert rq_client.path_to_metadata in request_files + + # Verify actual request file content + data_files = [f for f in request_files if f != rq_client.path_to_metadata] + assert len(data_files) == 3 + + for req_file in data_files: + with req_file.open() as f: + request_data = json.load(f) + assert 'url' in request_data + assert request_data['url'].startswith('https://example.com/') + + +async def test_drop_removes_directory(rq_client: FileSystemRequestQueueClient) -> None: + """Test that drop removes the entire RQ directory from disk.""" + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + + rq_path = rq_client.path_to_rq + assert rq_path.exists() + + # Drop the request queue + await rq_client.drop() + + assert not rq_path.exists() + + +async def test_metadata_file_updates(rq_client: FileSystemRequestQueueClient) -> None: + """Test that metadata file is updated correctly after operations.""" + # Record initial timestamps + metadata = await rq_client.get_metadata() + initial_created = metadata.created_at + initial_accessed = metadata.accessed_at + initial_modified = metadata.modified_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform a read operation + await rq_client.is_empty() + + # Verify accessed timestamp was updated + metadata = await rq_client.get_metadata() + assert metadata.created_at == initial_created + assert metadata.accessed_at > initial_accessed + assert metadata.modified_at == initial_modified + + accessed_after_read = metadata.accessed_at + + # Wait a moment to ensure timestamps can change + await asyncio.sleep(0.01) + + # Perform a write operation + await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + + # Verify modified timestamp was updated + metadata = await rq_client.get_metadata() + assert metadata.created_at == initial_created + assert metadata.modified_at > initial_modified + assert metadata.accessed_at > accessed_after_read + + # Verify metadata file is updated on disk + with rq_client.path_to_metadata.open() as f: + metadata_json = json.load(f) + assert metadata_json['total_request_count'] == 1 + + +async def test_data_persistence_across_reopens(configuration: Configuration) -> None: + """Test that requests 
persist correctly when reopening the same RQ.""" + storage_client = FileSystemStorageClient() + + # Create RQ and add requests + original_client = await storage_client.create_rq_client( + name='persistence-test', + configuration=configuration, + ) + + test_requests = [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + ] + await original_client.add_batch_of_requests(test_requests) + + rq_id = (await original_client.get_metadata()).id + + # Reopen by ID and verify requests persist + reopened_client = await storage_client.create_rq_client( + id=rq_id, + configuration=configuration, + ) + + metadata = await reopened_client.get_metadata() + assert metadata.total_request_count == 2 + + # Fetch requests to verify they're still there + request1 = await reopened_client.fetch_next_request() + request2 = await reopened_client.fetch_next_request() + + assert request1 is not None + assert request2 is not None + assert {request1.url, request2.url} == {'https://example.com/1', 'https://example.com/2'} + + await reopened_client.drop() diff --git a/tests/unit/storages/test_dataset.py b/tests/unit/storages/test_dataset.py index b4f75bc6b4..315f7538c3 100644 --- a/tests/unit/storages/test_dataset.py +++ b/tests/unit/storages/test_dataset.py @@ -8,7 +8,7 @@ import pytest from crawlee.configuration import Configuration -from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient, SQLStorageClient from crawlee.storages import Dataset, KeyValueStore if TYPE_CHECKING: @@ -19,11 +19,13 @@ from crawlee.storage_clients import StorageClient -@pytest.fixture(params=['memory', 'file_system']) +@pytest.fixture(params=['memory', 'file_system', 'sql']) def storage_client(request: pytest.FixtureRequest) -> StorageClient: """Parameterized fixture to test with different storage clients.""" if request.param == 'memory': return MemoryStorageClient() + if request.param == 'sql': + return SQLStorageClient() return FileSystemStorageClient() diff --git a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py index 25bbcb4fc0..c885562375 100644 --- a/tests/unit/storages/test_key_value_store.py +++ b/tests/unit/storages/test_key_value_store.py @@ -9,7 +9,7 @@ import pytest from crawlee.configuration import Configuration -from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient, SQLStorageClient from crawlee.storages import KeyValueStore if TYPE_CHECKING: @@ -19,11 +19,13 @@ from crawlee.storage_clients import StorageClient -@pytest.fixture(params=['memory', 'file_system']) +@pytest.fixture(params=['memory', 'file_system', 'sql']) def storage_client(request: pytest.FixtureRequest) -> StorageClient: """Parameterized fixture to test with different storage clients.""" if request.param == 'memory': return MemoryStorageClient() + if request.param == 'sql': + return SQLStorageClient() return FileSystemStorageClient() diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py index 8df759a27f..1512d4b78f 100644 --- a/tests/unit/storages/test_request_queue.py +++ b/tests/unit/storages/test_request_queue.py @@ -10,7 +10,7 @@ from crawlee import Request, service_locator from crawlee.configuration import Configuration -from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient, StorageClient +from 
crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient, SQLStorageClient, StorageClient from crawlee.storages import RequestQueue if TYPE_CHECKING: @@ -20,12 +20,13 @@ from crawlee.storage_clients import StorageClient -@pytest.fixture(params=['memory', 'file_system']) +@pytest.fixture(params=['sql']) def storage_client(request: pytest.FixtureRequest) -> StorageClient: """Parameterized fixture to test with different storage clients.""" if request.param == 'memory': return MemoryStorageClient() - + if request.param == 'sql': + return SQLStorageClient() return FileSystemStorageClient() diff --git a/uv.lock b/uv.lock index ea78cf175e..595102aa16 100644 --- a/uv.lock +++ b/uv.lock @@ -7,6 +7,18 @@ resolution-markers = [ "python_full_version < '3.11'", ] +[[package]] +name = "aiosqlite" +version = "0.21.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3", size = 13454, upload-time = "2025-02-03T07:30:16.235Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -581,6 +593,7 @@ adaptive-crawler = [ { name = "scikit-learn" }, ] all = [ + { name = "aiosqlite" }, { name = "beautifulsoup4", extra = ["lxml"] }, { name = "cookiecutter" }, { name = "curl-cffi" }, @@ -598,6 +611,7 @@ all = [ { name = "playwright" }, { name = "rich" }, { name = "scikit-learn" }, + { name = "sqlalchemy", extra = ["asyncio"] }, { name = "typer" }, { name = "wrapt" }, ] @@ -633,6 +647,10 @@ parsel = [ playwright = [ { name = "playwright" }, ] +sql = [ + { name = "aiosqlite" }, + { name = "sqlalchemy", extra = ["asyncio"] }, +] [package.dev-dependencies] dev = [ @@ -661,13 +679,14 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiosqlite", marker = "extra == 'sql'", specifier = ">=0.21.0" }, { name = "apify-fingerprint-datapoints", specifier = ">=0.0.2" }, { name = "beautifulsoup4", extras = ["lxml"], marker = "extra == 'beautifulsoup'", specifier = ">=4.12.0" }, { name = "browserforge", specifier = ">=1.2.3" }, { name = "cachetools", specifier = ">=5.5.0" }, { name = "colorama", specifier = ">=0.4.0" }, { name = "cookiecutter", marker = "extra == 'cli'", specifier = ">=2.6.0" }, - { name = "crawlee", extras = ["adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "parsel", "playwright", "otel"], marker = "extra == 'all'" }, + { name = "crawlee", extras = ["adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "parsel", "playwright", "otel", "sql"], marker = "extra == 'all'" }, { name = "curl-cffi", marker = "extra == 'curl-impersonate'", specifier = ">=0.9.0" }, { name = "html5lib", marker = "extra == 'beautifulsoup'", specifier = ">=1.0" }, { name = "httpx", extras = ["brotli", "http2", "zstd"], specifier = ">=0.27.0" }, @@ -694,13 +713,14 @@ requires-dist = [ { name = "scikit-learn", marker = "extra == 'adaptive-crawler'", specifier = ">=1.6.0" }, { name = "sortedcollections", specifier = ">=2.1.0" }, { name = "sortedcontainers", specifier = ">=2.4.0" }, + { name = 
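For orientation alongside the fixture changes above: with the new `sql` extra installed (for example `pip install 'crawlee[sql]'`), the SQL backend can be used anywhere the memory or file-system clients are used today. A minimal sketch under those assumptions follows; the dataset name and printed output are illustrative, not part of this patch.

    import asyncio

    from crawlee.configuration import Configuration
    from crawlee.storage_clients import SQLStorageClient


    async def main() -> None:
        # The storage client is used as an async context manager, matching the
        # test fixtures above; by default it manages an SQLite database.
        async with SQLStorageClient() as storage_client:
            dataset_client = await storage_client.create_dataset_client(
                name='example-dataset',  # illustrative name
                configuration=Configuration.get_global_configuration(),
            )
            await dataset_client.push_data({'url': 'https://example.com', 'status': 200})
            page = await dataset_client.get_data()
            print(page.items)


    asyncio.run(main())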
"sqlalchemy", extras = ["asyncio"], marker = "extra == 'sql'", specifier = ">=2.0.42,<3.0.0" }, { name = "tldextract", specifier = ">=5.1.0" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.0" }, { name = "typing-extensions", specifier = ">=4.1.0" }, { name = "wrapt", marker = "extra == 'otel'", specifier = ">=1.17.0" }, { name = "yarl", specifier = ">=1.18.0" }, ] -provides-extras = ["all", "adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "impit", "parsel", "playwright", "otel"] +provides-extras = ["all", "adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "impit", "parsel", "playwright", "otel", "sql"] [package.metadata.requires-dev] dev = [ @@ -2805,6 +2825,56 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/9c/0e6afc12c269578be5c0c1c9f4b49a8d32770a080260c333ac04cc1c832d/soupsieve-2.7-py3-none-any.whl", hash = "sha256:6e60cc5c1ffaf1cebcc12e8188320b72071e922c2e897f737cadce79ad5d30c4", size = 36677, upload-time = "2025-04-20T18:50:07.196Z" }, ] +[[package]] +name = "sqlalchemy" +version = "2.0.42" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "greenlet", marker = "(python_full_version < '3.14' and platform_machine == 'AMD64') or (python_full_version < '3.14' and platform_machine == 'WIN32') or (python_full_version < '3.14' and platform_machine == 'aarch64') or (python_full_version < '3.14' and platform_machine == 'amd64') or (python_full_version < '3.14' and platform_machine == 'ppc64le') or (python_full_version < '3.14' and platform_machine == 'win32') or (python_full_version < '3.14' and platform_machine == 'x86_64')" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5a/03/a0af991e3a43174d6b83fca4fb399745abceddd1171bdabae48ce877ff47/sqlalchemy-2.0.42.tar.gz", hash = "sha256:160bedd8a5c28765bd5be4dec2d881e109e33b34922e50a3b881a7681773ac5f", size = 9749972, upload-time = "2025-07-29T12:48:09.323Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3a/12/33ff43214c2c6cc87499b402fe419869d2980a08101c991daae31345e901/sqlalchemy-2.0.42-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:172b244753e034d91a826f80a9a70f4cbac690641207f2217f8404c261473efe", size = 2130469, upload-time = "2025-07-29T13:25:15.215Z" }, + { url = "https://files.pythonhosted.org/packages/63/c4/4d2f2c21ddde9a2c7f7b258b202d6af0bac9fc5abfca5de367461c86d766/sqlalchemy-2.0.42-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:be28f88abd74af8519a4542185ee80ca914933ca65cdfa99504d82af0e4210df", size = 2120393, upload-time = "2025-07-29T13:25:16.367Z" }, + { url = "https://files.pythonhosted.org/packages/a8/0d/5ff2f2dfbac10e4a9ade1942f8985ffc4bd8f157926b1f8aed553dfe3b88/sqlalchemy-2.0.42-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98b344859d282fde388047f1710860bb23f4098f705491e06b8ab52a48aafea9", size = 3206173, upload-time = "2025-07-29T13:29:00.623Z" }, + { url = "https://files.pythonhosted.org/packages/1f/59/71493fe74bd76a773ae8fa0c50bfc2ccac1cbf7cfa4f9843ad92897e6dcf/sqlalchemy-2.0.42-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97978d223b11f1d161390a96f28c49a13ce48fdd2fed7683167c39bdb1b8aa09", size = 3206910, upload-time = "2025-07-29T13:24:50.58Z" }, + { url = "https://files.pythonhosted.org/packages/a9/51/01b1d85bbb492a36b25df54a070a0f887052e9b190dff71263a09f48576b/sqlalchemy-2.0.42-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e35b9b000c59fcac2867ab3a79fc368a6caca8706741beab3b799d47005b3407", 
size = 3145479, upload-time = "2025-07-29T13:29:02.3Z" }, + { url = "https://files.pythonhosted.org/packages/fa/78/10834f010e2a3df689f6d1888ea6ea0074ff10184e6a550b8ed7f9189a89/sqlalchemy-2.0.42-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bc7347ad7a7b1c78b94177f2d57263113bb950e62c59b96ed839b131ea4234e1", size = 3169605, upload-time = "2025-07-29T13:24:52.135Z" }, + { url = "https://files.pythonhosted.org/packages/0c/75/e6fdd66d237582c8488dd1dfa90899f6502822fbd866363ab70e8ac4a2ce/sqlalchemy-2.0.42-cp310-cp310-win32.whl", hash = "sha256:739e58879b20a179156b63aa21f05ccacfd3e28e08e9c2b630ff55cd7177c4f1", size = 2098759, upload-time = "2025-07-29T13:23:55.809Z" }, + { url = "https://files.pythonhosted.org/packages/a5/a8/366db192641c2c2d1ea8977e7c77b65a0d16a7858907bb76ea68b9dd37af/sqlalchemy-2.0.42-cp310-cp310-win_amd64.whl", hash = "sha256:1aef304ada61b81f1955196f584b9e72b798ed525a7c0b46e09e98397393297b", size = 2122423, upload-time = "2025-07-29T13:23:56.968Z" }, + { url = "https://files.pythonhosted.org/packages/ea/3c/7bfd65f3c2046e2fb4475b21fa0b9d7995f8c08bfa0948df7a4d2d0de869/sqlalchemy-2.0.42-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c34100c0b7ea31fbc113c124bcf93a53094f8951c7bf39c45f39d327bad6d1e7", size = 2133779, upload-time = "2025-07-29T13:25:18.446Z" }, + { url = "https://files.pythonhosted.org/packages/66/17/19be542fe9dd64a766090e90e789e86bdaa608affda6b3c1e118a25a2509/sqlalchemy-2.0.42-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ad59dbe4d1252448c19d171dfba14c74e7950b46dc49d015722a4a06bfdab2b0", size = 2123843, upload-time = "2025-07-29T13:25:19.749Z" }, + { url = "https://files.pythonhosted.org/packages/14/fc/83e45fc25f0acf1c26962ebff45b4c77e5570abb7c1a425a54b00bcfa9c7/sqlalchemy-2.0.42-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9187498c2149919753a7fd51766ea9c8eecdec7da47c1b955fa8090bc642eaa", size = 3294824, upload-time = "2025-07-29T13:29:03.879Z" }, + { url = "https://files.pythonhosted.org/packages/b9/81/421efc09837104cd1a267d68b470e5b7b6792c2963b8096ca1e060ba0975/sqlalchemy-2.0.42-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f092cf83ebcafba23a247f5e03f99f5436e3ef026d01c8213b5eca48ad6efa9", size = 3294662, upload-time = "2025-07-29T13:24:53.715Z" }, + { url = "https://files.pythonhosted.org/packages/2f/ba/55406e09d32ed5e5f9e8aaec5ef70c4f20b4ae25b9fa9784f4afaa28e7c3/sqlalchemy-2.0.42-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:fc6afee7e66fdba4f5a68610b487c1f754fccdc53894a9567785932dbb6a265e", size = 3229413, upload-time = "2025-07-29T13:29:05.638Z" }, + { url = "https://files.pythonhosted.org/packages/d4/c4/df596777fce27bde2d1a4a2f5a7ddea997c0c6d4b5246aafba966b421cc0/sqlalchemy-2.0.42-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:260ca1d2e5910f1f1ad3fe0113f8fab28657cee2542cb48c2f342ed90046e8ec", size = 3255563, upload-time = "2025-07-29T13:24:55.17Z" }, + { url = "https://files.pythonhosted.org/packages/16/ed/b9c4a939b314400f43f972c9eb0091da59d8466ef9c51d0fd5b449edc495/sqlalchemy-2.0.42-cp311-cp311-win32.whl", hash = "sha256:2eb539fd83185a85e5fcd6b19214e1c734ab0351d81505b0f987705ba0a1e231", size = 2098513, upload-time = "2025-07-29T13:23:58.946Z" }, + { url = "https://files.pythonhosted.org/packages/91/72/55b0c34e39feb81991aa3c974d85074c356239ac1170dfb81a474b4c23b3/sqlalchemy-2.0.42-cp311-cp311-win_amd64.whl", hash = "sha256:9193fa484bf00dcc1804aecbb4f528f1123c04bad6a08d7710c909750fa76aeb", size = 2123380, upload-time = "2025-07-29T13:24:00.155Z" }, + { url = 
"https://files.pythonhosted.org/packages/61/66/ac31a9821fc70a7376321fb2c70fdd7eadbc06dadf66ee216a22a41d6058/sqlalchemy-2.0.42-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:09637a0872689d3eb71c41e249c6f422e3e18bbd05b4cd258193cfc7a9a50da2", size = 2132203, upload-time = "2025-07-29T13:29:19.291Z" }, + { url = "https://files.pythonhosted.org/packages/fc/ba/fd943172e017f955d7a8b3a94695265b7114efe4854feaa01f057e8f5293/sqlalchemy-2.0.42-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a3cb3ec67cc08bea54e06b569398ae21623534a7b1b23c258883a7c696ae10df", size = 2120373, upload-time = "2025-07-29T13:29:21.049Z" }, + { url = "https://files.pythonhosted.org/packages/ea/a2/b5f7d233d063ffadf7e9fff3898b42657ba154a5bec95a96f44cba7f818b/sqlalchemy-2.0.42-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e87e6a5ef6f9d8daeb2ce5918bf5fddecc11cae6a7d7a671fcc4616c47635e01", size = 3317685, upload-time = "2025-07-29T13:26:40.837Z" }, + { url = "https://files.pythonhosted.org/packages/86/00/fcd8daab13a9119d41f3e485a101c29f5d2085bda459154ba354c616bf4e/sqlalchemy-2.0.42-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b718011a9d66c0d2f78e1997755cd965f3414563b31867475e9bc6efdc2281d", size = 3326967, upload-time = "2025-07-29T13:22:31.009Z" }, + { url = "https://files.pythonhosted.org/packages/a3/85/e622a273d648d39d6771157961956991a6d760e323e273d15e9704c30ccc/sqlalchemy-2.0.42-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:16d9b544873fe6486dddbb859501a07d89f77c61d29060bb87d0faf7519b6a4d", size = 3255331, upload-time = "2025-07-29T13:26:42.579Z" }, + { url = "https://files.pythonhosted.org/packages/3a/a0/2c2338b592c7b0a61feffd005378c084b4c01fabaf1ed5f655ab7bd446f0/sqlalchemy-2.0.42-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:21bfdf57abf72fa89b97dd74d3187caa3172a78c125f2144764a73970810c4ee", size = 3291791, upload-time = "2025-07-29T13:22:32.454Z" }, + { url = "https://files.pythonhosted.org/packages/41/19/b8a2907972a78285fdce4c880ecaab3c5067eb726882ca6347f7a4bf64f6/sqlalchemy-2.0.42-cp312-cp312-win32.whl", hash = "sha256:78b46555b730a24901ceb4cb901c6b45c9407f8875209ed3c5d6bcd0390a6ed1", size = 2096180, upload-time = "2025-07-29T13:16:08.952Z" }, + { url = "https://files.pythonhosted.org/packages/48/1f/67a78f3dfd08a2ed1c7be820fe7775944f5126080b5027cc859084f8e223/sqlalchemy-2.0.42-cp312-cp312-win_amd64.whl", hash = "sha256:4c94447a016f36c4da80072e6c6964713b0af3c8019e9c4daadf21f61b81ab53", size = 2123533, upload-time = "2025-07-29T13:16:11.705Z" }, + { url = "https://files.pythonhosted.org/packages/e9/7e/25d8c28b86730c9fb0e09156f601d7a96d1c634043bf8ba36513eb78887b/sqlalchemy-2.0.42-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:941804f55c7d507334da38133268e3f6e5b0340d584ba0f277dd884197f4ae8c", size = 2127905, upload-time = "2025-07-29T13:29:22.249Z" }, + { url = "https://files.pythonhosted.org/packages/e5/a1/9d8c93434d1d983880d976400fcb7895a79576bd94dca61c3b7b90b1ed0d/sqlalchemy-2.0.42-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:95d3d06a968a760ce2aa6a5889fefcbdd53ca935735e0768e1db046ec08cbf01", size = 2115726, upload-time = "2025-07-29T13:29:23.496Z" }, + { url = "https://files.pythonhosted.org/packages/a2/cc/d33646fcc24c87cc4e30a03556b611a4e7bcfa69a4c935bffb923e3c89f4/sqlalchemy-2.0.42-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4cf10396a8a700a0f38ccd220d940be529c8f64435c5d5b29375acab9267a6c9", size = 3246007, upload-time = "2025-07-29T13:26:44.166Z" }, + { url = 
"https://files.pythonhosted.org/packages/67/08/4e6c533d4c7f5e7c4cbb6fe8a2c4e813202a40f05700d4009a44ec6e236d/sqlalchemy-2.0.42-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9cae6c2b05326d7c2c7c0519f323f90e0fb9e8afa783c6a05bb9ee92a90d0f04", size = 3250919, upload-time = "2025-07-29T13:22:33.74Z" }, + { url = "https://files.pythonhosted.org/packages/5c/82/f680e9a636d217aece1b9a8030d18ad2b59b5e216e0c94e03ad86b344af3/sqlalchemy-2.0.42-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f50f7b20677b23cfb35b6afcd8372b2feb348a38e3033f6447ee0704540be894", size = 3180546, upload-time = "2025-07-29T13:26:45.648Z" }, + { url = "https://files.pythonhosted.org/packages/7d/a2/8c8f6325f153894afa3775584c429cc936353fb1db26eddb60a549d0ff4b/sqlalchemy-2.0.42-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9d88a1c0d66d24e229e3938e1ef16ebdbd2bf4ced93af6eff55225f7465cf350", size = 3216683, upload-time = "2025-07-29T13:22:34.977Z" }, + { url = "https://files.pythonhosted.org/packages/39/44/3a451d7fa4482a8ffdf364e803ddc2cfcafc1c4635fb366f169ecc2c3b11/sqlalchemy-2.0.42-cp313-cp313-win32.whl", hash = "sha256:45c842c94c9ad546c72225a0c0d1ae8ef3f7c212484be3d429715a062970e87f", size = 2093990, upload-time = "2025-07-29T13:16:13.036Z" }, + { url = "https://files.pythonhosted.org/packages/4b/9e/9bce34f67aea0251c8ac104f7bdb2229d58fb2e86a4ad8807999c4bee34b/sqlalchemy-2.0.42-cp313-cp313-win_amd64.whl", hash = "sha256:eb9905f7f1e49fd57a7ed6269bc567fcbbdac9feadff20ad6bd7707266a91577", size = 2120473, upload-time = "2025-07-29T13:16:14.502Z" }, + { url = "https://files.pythonhosted.org/packages/ee/55/ba2546ab09a6adebc521bf3974440dc1d8c06ed342cceb30ed62a8858835/sqlalchemy-2.0.42-py3-none-any.whl", hash = "sha256:defcdff7e661f0043daa381832af65d616e060ddb54d3fe4476f51df7eaa1835", size = 1922072, upload-time = "2025-07-29T13:09:17.061Z" }, +] + +[package.optional-dependencies] +asyncio = [ + { name = "greenlet" }, +] + [[package]] name = "text-unidecode" version = "1.3" From b056505ea2b425c70c13d7afb54c6c3ea824fe25 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Tue, 29 Jul 2025 23:25:31 +0000 Subject: [PATCH 02/29] add dataset tests --- .../storage_clients/_sql/_dataset_client.py | 27 ++- .../_sql/_key_value_store_client.py | 6 + .../storage_clients/_sql/_storage_client.py | 23 +- .../_sql/test_sql_dataset_client.py | 210 ++++++++++++------ 4 files changed, 174 insertions(+), 92 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py index 964a5306fc..746f8e38d6 100644 --- a/src/crawlee/storage_clients/_sql/_dataset_client.py +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -17,6 +17,8 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator + from sqlalchemy.ext.asyncio import AsyncSession + from ._storage_client import SQLStorageClient logger = getLogger(__name__) @@ -47,6 +49,10 @@ def __init__( self._orm_metadata = orm_metadata self._storage_client = storage_client + def create_session(self) -> AsyncSession: + """Create a new SQLAlchemy session for this key-value store.""" + return self._storage_client.create_session() + @override async def get_metadata(self) -> DatasetMetadata: return DatasetMetadata.model_validate(self._orm_metadata) @@ -85,7 +91,9 @@ async def open( await client._update_metadata(update_accessed_at=True) else: - orm_metadata = await session.get(DatasetMetadataDB, name) + stmt = select(DatasetMetadataDB).where(DatasetMetadataDB.name == name) + result = await 
session.execute(stmt) + orm_metadata = result.scalar_one_or_none() if orm_metadata: client = cls( orm_metadata=orm_metadata, @@ -116,15 +124,15 @@ async def open( @override async def drop(self) -> None: - async with self._storage_client.create_session() as session: - dataset_db = await session.get(DatasetItemDB, self._orm_metadata.id) + async with self.create_session() as session: + dataset_db = await session.get(DatasetMetadataDB, self._orm_metadata.id) if dataset_db: await session.delete(dataset_db) await session.commit() @override async def purge(self) -> None: - async with self._storage_client.create_session() as session: + async with self.create_session() as session: stmt = delete(DatasetItemDB).where(DatasetItemDB.dataset_id == self._orm_metadata.id) await session.execute(stmt) @@ -149,10 +157,13 @@ async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: ) ) - async with self._storage_client.create_session() as session: + async with self.create_session() as session: session.add_all(db_items) self._orm_metadata.item_count += len(data) - await self._update_metadata(update_modified_at=True) + await self._update_metadata( + update_accessed_at=True, + update_modified_at=True, + ) await session.commit() @@ -199,7 +210,7 @@ async def get_data( stmt = stmt.offset(offset).limit(limit) - async with self._storage_client.create_session() as session: + async with self.create_session() as session: result = await session.execute(stmt) db_items = result.scalars().all() @@ -256,7 +267,7 @@ async def iterate_items( stmt = stmt.offset(offset).limit(limit) - async with self._storage_client.create_session() as session: + async with self.create_session() as session: result = await session.execute(stmt) db_items = result.scalars().all() diff --git a/src/crawlee/storage_clients/_sql/_key_value_store_client.py b/src/crawlee/storage_clients/_sql/_key_value_store_client.py index e4826eaa1e..bbacc19dbe 100644 --- a/src/crawlee/storage_clients/_sql/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_sql/_key_value_store_client.py @@ -18,6 +18,8 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator + from sqlalchemy.ext.asyncio import AsyncSession + from ._storage_client import SQLStorageClient @@ -60,6 +62,10 @@ def __init__( self._storage_client = storage_client """The storage client used to access the SQL database.""" + async def create_session(self) -> AsyncSession: + """Create a new SQLAlchemy session for this key-value store.""" + return self._storage_client.create_session() + @override async def get_metadata(self) -> KeyValueStoreMetadata: return KeyValueStoreMetadata.model_validate(self._orm_metadata) diff --git a/src/crawlee/storage_clients/_sql/_storage_client.py b/src/crawlee/storage_clients/_sql/_storage_client.py index 1bb29a290f..f7c201e82f 100644 --- a/src/crawlee/storage_clients/_sql/_storage_client.py +++ b/src/crawlee/storage_clients/_sql/_storage_client.py @@ -19,12 +19,6 @@ if TYPE_CHECKING: from types import TracebackType - from crawlee.storage_clients._base import ( - DatasetClient, - KeyValueStoreClient, - RequestQueueClient, - ) - @docs_group('Storage clients') class SQLStorageClient(StorageClient): @@ -66,6 +60,15 @@ def __init__( self._engine = engine self._initialized = False + self._default_flag = self._engine is None and self._connection_string is None + + @property + def engine(self) -> AsyncEngine: + """Get the SQLAlchemy AsyncEngine instance.""" + if self._engine is None: + raise ValueError('Engine is not initialized. 
Call initialize() before accessing the engine.') + return self._engine + def _get_or_create_engine(self, configuration: Configuration) -> AsyncEngine: """Get or create the database engine based on configuration.""" if self._engine is not None: @@ -98,7 +101,7 @@ async def initialize(self, configuration: Configuration) -> None: engine = self._get_or_create_engine(configuration) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) - if 'sqlite' in str(engine.url): + if self._default_flag: await conn.execute(text('PRAGMA journal_mode=WAL')) await conn.execute(text('PRAGMA synchronous=NORMAL')) await conn.execute(text('PRAGMA cache_size=10000')) @@ -130,7 +133,7 @@ async def create_dataset_client( id: str | None = None, name: str | None = None, configuration: Configuration | None = None, - ) -> DatasetClient: + ) -> SQLDatasetClient: configuration = configuration or Configuration.get_global_configuration() await self.initialize(configuration) @@ -150,7 +153,7 @@ async def create_kvs_client( id: str | None = None, name: str | None = None, configuration: Configuration | None = None, - ) -> KeyValueStoreClient: + ) -> SQLKeyValueStoreClient: configuration = configuration or Configuration.get_global_configuration() await self.initialize(configuration) @@ -170,7 +173,7 @@ async def create_rq_client( id: str | None = None, name: str | None = None, configuration: Configuration | None = None, - ) -> RequestQueueClient: + ) -> SQLRequestQueueClient: configuration = configuration or Configuration.get_global_configuration() await self.initialize(configuration) diff --git a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py index c5f31f144e..609df4cfc2 100644 --- a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py @@ -2,19 +2,24 @@ import asyncio import json -from pathlib import Path from typing import TYPE_CHECKING import pytest +from sqlalchemy import inspect, select +from sqlalchemy.ext.asyncio import create_async_engine -from crawlee._consts import METADATA_FILENAME from crawlee.configuration import Configuration -from crawlee.storage_clients import FileSystemStorageClient +from crawlee.storage_clients import SQLStorageClient +from crawlee.storage_clients._sql._db_models import DatasetItemDB, DatasetMetadataDB +from crawlee.storage_clients.models import DatasetMetadata if TYPE_CHECKING: from collections.abc import AsyncGenerator + from pathlib import Path - from crawlee.storage_clients._file_system import FileSystemDatasetClient + from sqlalchemy import Connection + + from crawlee.storage_clients._sql import SQLDatasetClient @pytest.fixture @@ -24,80 +29,139 @@ def configuration(tmp_path: Path) -> Configuration: ) -@pytest.fixture -async def dataset_client(configuration: Configuration) -> AsyncGenerator[FileSystemDatasetClient, None]: - """A fixture for a file system dataset client.""" - client = await FileSystemStorageClient().create_dataset_client( - name='test_dataset', - configuration=configuration, - ) - yield client - await client.drop() - +# Helper function that allows you to use inspect with an asynchronous engine +def get_tables(sync_conn: Connection) -> list[str]: + inspector = inspect(sync_conn) + return inspector.get_table_names() -async def test_file_and_directory_creation(configuration: Configuration) -> None: - """Test that file system dataset creates proper files and directories.""" - client = await 
FileSystemStorageClient().create_dataset_client( - name='new_dataset', - configuration=configuration, - ) - # Verify files were created - assert client.path_to_dataset.exists() - assert client.path_to_metadata.exists() +@pytest.fixture +async def dataset_client(configuration: Configuration) -> AsyncGenerator[SQLDatasetClient, None]: + """A fixture for a SQL dataset client.""" + async with SQLStorageClient() as storage_client: + client = await storage_client.create_dataset_client( + name='test_dataset', + configuration=configuration, + ) + yield client + await client.drop() + + +async def test_create_tables_with_connection_string(configuration: Configuration, tmp_path: Path) -> None: + """Test that SQL dataset client creates tables with a connection string.""" + storage_dir = tmp_path / 'test_table.db' + + async with SQLStorageClient(connection_string=f'sqlite+aiosqlite:///{storage_dir}') as storage_client: + await storage_client.create_dataset_client( + name='new_dataset', + configuration=configuration, + ) + + async with storage_client.engine.begin() as conn: + tables = await conn.run_sync(get_tables) + assert 'dataset_item' in tables + assert 'dataset_metadata' in tables + + +async def test_create_tables_with_engine(configuration: Configuration, tmp_path: Path) -> None: + """Test that SQL dataset client creates tables with a pre-configured engine.""" + storage_dir = tmp_path / 'test_table.db' + + engine = create_async_engine(f'sqlite+aiosqlite:///{storage_dir}', future=True, echo=False) + + async with SQLStorageClient(engine=engine) as storage_client: + await storage_client.create_dataset_client( + name='new_dataset', + configuration=configuration, + ) + + async with engine.begin() as conn: + tables = await conn.run_sync(get_tables) + assert 'dataset_item' in tables + assert 'dataset_metadata' in tables + + +async def test_tables_and_metadata_record(configuration: Configuration) -> None: + """Test that SQL dataset creates proper tables and metadata records.""" + async with SQLStorageClient() as storage_client: + client = await storage_client.create_dataset_client( + name='new_dataset', + configuration=configuration, + ) - # Verify metadata file structure - with client.path_to_metadata.open() as f: - metadata = json.load(f) client_metadata = await client.get_metadata() - assert metadata['id'] == client_metadata.id - assert metadata['name'] == 'new_dataset' - assert metadata['item_count'] == 0 - await client.drop() + async with storage_client.engine.begin() as conn: + tables = await conn.run_sync(get_tables) + assert 'dataset_item' in tables + assert 'dataset_metadata' in tables + async with client.create_session() as session: + stmt = select(DatasetMetadataDB).where(DatasetMetadataDB.name == 'new_dataset') + result = await session.execute(stmt) + orm_metadata = result.scalar_one_or_none() + metadata = DatasetMetadata.model_validate(orm_metadata) + assert metadata.id == client_metadata.id + assert metadata.name == 'new_dataset' + assert metadata.item_count == 0 -async def test_file_persistence_and_content_verification(dataset_client: FileSystemDatasetClient) -> None: - """Test that data is properly persisted to files with correct content.""" + +async def test_record_and_content_verification(dataset_client: SQLDatasetClient) -> None: + """Test that dataset client can push data and verify its content.""" item = {'key': 'value', 'number': 42} await dataset_client.push_data(item) - # Verify files are created on disk - all_files = list(dataset_client.path_to_dataset.glob('*.json')) - assert 
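The two tests above exercise the construction paths the storage client accepts. As a compact sketch (the `crawlee.db` path is illustrative), the three ways to configure it are:

    from sqlalchemy.ext.asyncio import create_async_engine

    from crawlee.storage_clients import SQLStorageClient

    # 1. Default: the client creates and manages an SQLite database itself; per
    #    the initialize() change above, the SQLite PRAGMA tuning applies only here.
    default_client = SQLStorageClient()

    # 2. Explicit connection string.
    string_client = SQLStorageClient(connection_string='sqlite+aiosqlite:///crawlee.db')

    # 3. Caller-supplied AsyncEngine, exposed via the `engine` property once initialized.
    engine = create_async_engine('sqlite+aiosqlite:///crawlee.db', future=True, echo=False)
    engine_client = SQLStorageClient(engine=engine)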
len(all_files) == 2 # 1 data file + 1 metadata file - - # Verify actual file content - data_files = [item for item in all_files if item.name != METADATA_FILENAME] - assert len(data_files) == 1 - - with Path(data_files[0]).open() as f: - saved_item = json.load(f) + # Verify metadata record + metadata = await dataset_client.get_metadata() + assert metadata.item_count == 1 + assert metadata.created_at is not None + assert metadata.modified_at is not None + assert metadata.accessed_at is not None + + async with dataset_client.create_session() as session: + stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == metadata.id) + result = await session.execute(stmt) + records = result.scalars().all() + assert len(records) == 1 + saved_item = json.loads(records[0].data) assert saved_item == item # Test multiple items file creation items = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}, {'id': 3, 'name': 'Item 3'}] await dataset_client.push_data(items) - all_files = list(dataset_client.path_to_dataset.glob('*.json')) - assert len(all_files) == 5 # 4 data files + 1 metadata file + async with dataset_client.create_session() as session: + stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == metadata.id) + result = await session.execute(stmt) + records = result.scalars().all() + assert len(records) == 4 - data_files = [f for f in all_files if f.name != METADATA_FILENAME] - assert len(data_files) == 4 # Original item + 3 new items - -async def test_drop_removes_files_from_disk(dataset_client: FileSystemDatasetClient) -> None: - """Test that dropping a dataset removes the entire dataset directory from disk.""" +async def test_drop_removes_records(dataset_client: SQLDatasetClient) -> None: + """Test that dropping a dataset removes all records from the database.""" await dataset_client.push_data({'test': 'data'}) - assert dataset_client.path_to_dataset.exists() + client_metadata = await dataset_client.get_metadata() + + async with dataset_client.create_session() as session: + stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == client_metadata.id) + result = await session.execute(stmt) + records = result.scalars().all() + assert len(records) == 1 # Drop the dataset await dataset_client.drop() - assert not dataset_client.path_to_dataset.exists() + async with dataset_client.create_session() as session: + stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == client_metadata.id) + result = await session.execute(stmt) + records = result.scalars().all() + assert len(records) == 0 + metadata = await session.get(DatasetMetadataDB, client_metadata.id) + assert metadata is None -async def test_metadata_file_updates(dataset_client: FileSystemDatasetClient) -> None: +async def test_metadata_recaord_updates(dataset_client: SQLDatasetClient) -> None: """Test that metadata file is updated correctly after operations.""" # Record initial timestamps metadata = await dataset_client.get_metadata() @@ -132,34 +196,32 @@ async def test_metadata_file_updates(dataset_client: FileSystemDatasetClient) -> assert metadata.accessed_at > accessed_after_get # Verify metadata file is updated on disk - with dataset_client.path_to_metadata.open() as f: - metadata_json = json.load(f) - assert metadata_json['item_count'] == 1 + async with dataset_client.create_session() as session: + orm_metadata = await session.get(DatasetMetadataDB, metadata.id) + record_metadata = DatasetMetadata.model_validate(orm_metadata) + record_metadata.item_count = 1 async def 
test_data_persistence_across_reopens(configuration: Configuration) -> None: """Test that data persists correctly when reopening the same dataset.""" - storage_client = FileSystemStorageClient() + async with SQLStorageClient() as storage_client: + original_client = await storage_client.create_dataset_client( + name='persistence-test', + configuration=configuration, + ) - # Create dataset and add data - original_client = await storage_client.create_dataset_client( - name='persistence-test', - configuration=configuration, - ) - - test_data = {'test_item': 'test_value', 'id': 123} - await original_client.push_data(test_data) + test_data = {'test_item': 'test_value', 'id': 123} + await original_client.push_data(test_data) - dataset_id = (await original_client.get_metadata()).id + dataset_id = (await original_client.get_metadata()).id - # Reopen by ID and verify data persists - reopened_client = await storage_client.create_dataset_client( - id=dataset_id, - configuration=configuration, - ) + reopened_client = await storage_client.create_dataset_client( + id=dataset_id, + configuration=configuration, + ) - data = await reopened_client.get_data() - assert len(data.items) == 1 - assert data.items[0] == test_data + data = await reopened_client.get_data() + assert len(data.items) == 1 + assert data.items[0] == test_data - await reopened_client.drop() + await reopened_client.drop() From ae3bc3da65b638e41659609708a329e4e96fb1c5 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Wed, 30 Jul 2025 00:34:54 +0000 Subject: [PATCH 03/29] add kvs tests --- .../storage_clients/_sql/_dataset_client.py | 7 +- .../_sql/_key_value_store_client.py | 11 +- .../storage_clients/_sql/_storage_client.py | 4 +- .../_sql/test_sql_dataset_client.py | 12 +- .../_sql/test_sql_kvs_client.py | 286 +++++++++++------- 5 files changed, 203 insertions(+), 117 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py index 746f8e38d6..a3c126ef39 100644 --- a/src/crawlee/storage_clients/_sql/_dataset_client.py +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -138,6 +138,7 @@ async def purge(self) -> None: self._orm_metadata.item_count = 0 await self._update_metadata(update_accessed_at=True, update_modified_at=True) + await session.merge(self._orm_metadata) await session.commit() @override @@ -164,7 +165,7 @@ async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: update_accessed_at=True, update_modified_at=True, ) - + await session.merge(self._orm_metadata) await session.commit() @override @@ -215,7 +216,7 @@ async def get_data( db_items = result.scalars().all() await self._update_metadata(update_accessed_at=True) - + await session.merge(self._orm_metadata) await session.commit() items = [json.loads(db_item.data) for db_item in db_items] @@ -272,7 +273,7 @@ async def iterate_items( db_items = result.scalars().all() await self._update_metadata(update_accessed_at=True) - + await session.merge(self._orm_metadata) await session.commit() items = [json.loads(db_item.data) for db_item in db_items] diff --git a/src/crawlee/storage_clients/_sql/_key_value_store_client.py b/src/crawlee/storage_clients/_sql/_key_value_store_client.py index bbacc19dbe..30d57ca444 100644 --- a/src/crawlee/storage_clients/_sql/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_sql/_key_value_store_client.py @@ -62,7 +62,7 @@ def __init__( self._storage_client = storage_client """The storage client used to access the SQL database.""" - async def 
create_session(self) -> AsyncSession: + def create_session(self) -> AsyncSession: """Create a new SQLAlchemy session for this key-value store.""" return self._storage_client.create_session() @@ -190,20 +190,22 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No else: session.add(record_db) self._update_metadata(update_accessed_at=True, update_modified_at=True) + await session.merge(self._orm_metadata) await session.commit() @override async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: # Update the metadata to record access async with self._storage_client.create_session() as session: - self._update_metadata(update_accessed_at=True) - stmt = select(KeyValueStoreRecordDB).where( KeyValueStoreRecordDB.kvs_id == self._orm_metadata.id, KeyValueStoreRecordDB.key == key ) result = await session.execute(stmt) record_db = result.scalar_one_or_none() + self._update_metadata(update_accessed_at=True) + + await session.merge(self._orm_metadata) await session.commit() if not record_db: @@ -252,6 +254,7 @@ async def delete_value(self, *, key: str) -> None: # Update metadata if we actually deleted something if result.rowcount > 0: self._update_metadata(update_accessed_at=True, update_modified_at=True) + await session.merge(self._orm_metadata) await session.commit() @@ -281,6 +284,7 @@ async def iterate_keys( result = await session.execute(stmt) self._update_metadata(update_accessed_at=True) + await session.merge(self._orm_metadata) await session.commit() for row in result: @@ -300,6 +304,7 @@ async def record_exists(self, *, key: str) -> bool: result = await session.execute(stmt) self._update_metadata(update_accessed_at=True) + await session.merge(self._orm_metadata) await session.commit() return result.scalar_one_or_none() is not None diff --git a/src/crawlee/storage_clients/_sql/_storage_client.py b/src/crawlee/storage_clients/_sql/_storage_client.py index f7c201e82f..9e29f7b481 100644 --- a/src/crawlee/storage_clients/_sql/_storage_client.py +++ b/src/crawlee/storage_clients/_sql/_storage_client.py @@ -60,6 +60,7 @@ def __init__( self._engine = engine self._initialized = False + # Default flag to indicate if the default database should be created self._default_flag = self._engine is None and self._connection_string is None @property @@ -101,10 +102,11 @@ async def initialize(self, configuration: Configuration) -> None: engine = self._get_or_create_engine(configuration) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) + # Set SQLite pragmas for performance and consistency if self._default_flag: await conn.execute(text('PRAGMA journal_mode=WAL')) await conn.execute(text('PRAGMA synchronous=NORMAL')) - await conn.execute(text('PRAGMA cache_size=10000')) + await conn.execute(text('PRAGMA cache_size=100000')) await conn.execute(text('PRAGMA temp_store=MEMORY')) await conn.execute(text('PRAGMA mmap_size=268435456')) await conn.execute(text('PRAGMA foreign_keys=ON')) diff --git a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py index 609df4cfc2..ced147f3f8 100644 --- a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py @@ -2,6 +2,7 @@ import asyncio import json +from datetime import timezone from typing import TYPE_CHECKING import pytest @@ -105,6 +106,8 @@ async def test_tables_and_metadata_record(configuration: Configuration) -> None: assert metadata.name == 'new_dataset' assert 
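A recurring change in this commit is the `session.merge(self._orm_metadata)` call placed before each `commit()`: the client keeps its metadata row as an ORM object outside any single session, so the pattern suggests it must be re-attached before timestamp and count updates are flushed. A small illustrative helper with the same shape; the function itself is not part of the patch.

    from sqlalchemy.ext.asyncio import AsyncSession

    from crawlee.storage_clients._sql._db_models import DatasetMetadataDB


    async def persist_detached_metadata(session: AsyncSession, metadata: DatasetMetadataDB) -> None:
        # merge() re-attaches the detached row to this session so that pending
        # attribute changes (timestamps, item counts) are written on commit.
        await session.merge(metadata)
        await session.commit()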
metadata.item_count == 0 + await client.drop() + async def test_record_and_content_verification(dataset_client: SQLDatasetClient) -> None: """Test that dataset client can push data and verify its content.""" @@ -177,9 +180,9 @@ async def test_metadata_recaord_updates(dataset_client: SQLDatasetClient) -> Non # Verify timestamps metadata = await dataset_client.get_metadata() - assert metadata.created_at == initial_created - assert metadata.accessed_at > initial_accessed - assert metadata.modified_at == initial_modified + assert metadata.created_at.replace(tzinfo=timezone.utc) == initial_created + assert metadata.accessed_at.replace(tzinfo=timezone.utc) > initial_accessed + assert metadata.modified_at.replace(tzinfo=timezone.utc) == initial_modified accessed_after_get = metadata.accessed_at @@ -200,6 +203,9 @@ async def test_metadata_recaord_updates(dataset_client: SQLDatasetClient) -> Non orm_metadata = await session.get(DatasetMetadataDB, metadata.id) record_metadata = DatasetMetadata.model_validate(orm_metadata) record_metadata.item_count = 1 + assert record_metadata.created_at.replace(tzinfo=timezone.utc) == initial_created + assert record_metadata.accessed_at.replace(tzinfo=timezone.utc) == metadata.accessed_at + assert record_metadata.modified_at.replace(tzinfo=timezone.utc) == metadata.modified_at async def test_data_persistence_across_reopens(configuration: Configuration) -> None: diff --git a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py index c5bfa96c47..852132cdff 100644 --- a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py @@ -2,19 +2,25 @@ import asyncio import json +from datetime import timezone from typing import TYPE_CHECKING import pytest +from sqlalchemy import inspect, select +from sqlalchemy.ext.asyncio import create_async_engine -from crawlee._consts import METADATA_FILENAME from crawlee.configuration import Configuration -from crawlee.storage_clients import FileSystemStorageClient +from crawlee.storage_clients import SQLStorageClient +from crawlee.storage_clients._sql._db_models import KeyValueStoreMetadataDB, KeyValueStoreRecordDB +from crawlee.storage_clients.models import KeyValueStoreMetadata if TYPE_CHECKING: from collections.abc import AsyncGenerator from pathlib import Path - from crawlee.storage_clients._file_system import FileSystemKeyValueStoreClient + from sqlalchemy import Connection + + from crawlee.storage_clients._sql import SQLKeyValueStoreClient @pytest.fixture @@ -25,130 +31,189 @@ def configuration(tmp_path: Path) -> Configuration: @pytest.fixture -async def kvs_client(configuration: Configuration) -> AsyncGenerator[FileSystemKeyValueStoreClient, None]: - """A fixture for a file system key-value store client.""" - client = await FileSystemStorageClient().create_kvs_client( - name='test_kvs', - configuration=configuration, - ) - yield client - await client.drop() +async def kvs_client(configuration: Configuration) -> AsyncGenerator[SQLKeyValueStoreClient, None]: + """A fixture for a SQL key-value store client.""" + async with SQLStorageClient() as storage_client: + client = await storage_client.create_kvs_client( + name='test_kvs', + configuration=configuration, + ) + yield client + await client.drop() -async def test_file_and_directory_creation(configuration: Configuration) -> None: - """Test that file system KVS creates proper files and directories.""" - client = await FileSystemStorageClient().create_kvs_client( - 
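The assertions above now attach `timezone.utc` before comparing, presumably because datetimes round-tripped through SQLite come back naive while the in-memory metadata was created timezone-aware. A hypothetical helper (not part of the patch) expressing the same normalization:

    from datetime import datetime, timezone


    def as_utc(dt: datetime) -> datetime:
        # Treat naive datetimes returned by SQLite as UTC; pass aware values through.
        return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)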
name='new_kvs', - configuration=configuration, - ) +# Helper function that allows you to use inspect with an asynchronous engine +def get_tables(sync_conn: Connection) -> list[str]: + inspector = inspect(sync_conn) + return inspector.get_table_names() - # Verify files were created - assert client.path_to_kvs.exists() - assert client.path_to_metadata.exists() - # Verify metadata file structure - with client.path_to_metadata.open() as f: - metadata = json.load(f) - assert metadata['id'] == (await client.get_metadata()).id - assert metadata['name'] == 'new_kvs' +async def test_create_tables_with_connection_string(configuration: Configuration, tmp_path: Path) -> None: + """Test that SQL dataset client creates tables with a connection string.""" + storage_dir = tmp_path / 'test_table.db' - await client.drop() + async with SQLStorageClient(connection_string=f'sqlite+aiosqlite:///{storage_dir}') as storage_client: + await storage_client.create_kvs_client( + name='new_kvs', + configuration=configuration, + ) + async with storage_client.engine.begin() as conn: + tables = await conn.run_sync(get_tables) + assert 'kvs_metadata' in tables + assert 'kvs_record' in tables -async def test_value_file_creation_and_content(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test that values are properly persisted to files with correct content and metadata.""" - test_key = 'test-key' - test_value = 'Hello, world!' - await kvs_client.set_value(key=test_key, value=test_value) - # Check if the files were created - key_path = kvs_client.path_to_kvs / test_key - key_metadata_path = kvs_client.path_to_kvs / f'{test_key}.{METADATA_FILENAME}' - assert key_path.exists() - assert key_metadata_path.exists() +async def test_create_tables_with_engine(configuration: Configuration, tmp_path: Path) -> None: + """Test that SQL dataset client creates tables with a pre-configured engine.""" + storage_dir = tmp_path / 'test_table.db' + + engine = create_async_engine(f'sqlite+aiosqlite:///{storage_dir}', future=True, echo=False) - # Check file content - content = key_path.read_text(encoding='utf-8') - assert content == test_value + async with SQLStorageClient(engine=engine) as storage_client: + await storage_client.create_kvs_client( + name='new_kvs', + configuration=configuration, + ) - # Check record metadata file - with key_metadata_path.open() as f: - metadata = json.load(f) - assert metadata['key'] == test_key - assert metadata['content_type'] == 'text/plain; charset=utf-8' - assert metadata['size'] == len(test_value.encode('utf-8')) + async with engine.begin() as conn: + tables = await conn.run_sync(get_tables) + assert 'kvs_metadata' in tables + assert 'kvs_record' in tables -async def test_binary_data_persistence(kvs_client: FileSystemKeyValueStoreClient) -> None: +async def test_tables_and_metadata_record(configuration: Configuration) -> None: + """Test that SQL dataset creates proper tables and metadata records.""" + async with SQLStorageClient() as storage_client: + client = await storage_client.create_kvs_client( + name='new_kvs', + configuration=configuration, + ) + + client_metadata = await client.get_metadata() + + async with storage_client.engine.begin() as conn: + tables = await conn.run_sync(get_tables) + assert 'kvs_metadata' in tables + assert 'kvs_record' in tables + + async with client.create_session() as session: + stmt = select(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.name == 'new_kvs') + result = await session.execute(stmt) + orm_metadata = result.scalar_one_or_none() + metadata = 
KeyValueStoreMetadata.model_validate(orm_metadata) + assert metadata.id == client_metadata.id + assert metadata.name == 'new_kvs' + + await client.drop() + + +async def test_value_record_creation(kvs_client: SQLKeyValueStoreClient) -> None: + """Test that key-value store client can create a record.""" + test_key = 'test-key' + test_value = 'Hello, world!' + await kvs_client.set_value(key=test_key, value=test_value) + async with kvs_client.create_session() as session: + stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == test_key) + result = await session.execute(stmt) + record = result.scalar_one_or_none() + assert record is not None + assert record.key == test_key + assert record.content_type == 'text/plain; charset=utf-8' + assert record.size == len(test_value.encode('utf-8')) + assert record.value == test_value.encode('utf-8') + + +async def test_binary_data_persistence(kvs_client: SQLKeyValueStoreClient) -> None: """Test that binary data is stored correctly without corruption.""" test_key = 'test-binary' test_value = b'\x00\x01\x02\x03\x04' await kvs_client.set_value(key=test_key, value=test_value) - # Verify binary file exists - key_path = kvs_client.path_to_kvs / test_key - assert key_path.exists() + async with kvs_client.create_session() as session: + stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == test_key) + result = await session.execute(stmt) + record = result.scalar_one_or_none() + assert record is not None + assert record.key == test_key + assert record.content_type == 'application/octet-stream' + assert record.size == len(test_value) + assert record.value == test_value - # Verify binary content is preserved - content = key_path.read_bytes() - assert content == test_value + verify_record = await kvs_client.get_value(key=test_key) + assert verify_record is not None + assert verify_record.value == test_value + assert verify_record.content_type == 'application/octet-stream' - # Verify retrieval works correctly - record = await kvs_client.get_value(key=test_key) - assert record is not None - assert record.value == test_value - assert record.content_type == 'application/octet-stream' - -async def test_json_serialization_to_file(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test that JSON objects are properly serialized to files.""" +async def test_json_serialization_to_record(kvs_client: SQLKeyValueStoreClient) -> None: + """Test that JSON objects are properly serialized to records.""" test_key = 'test-json' test_value = {'name': 'John', 'age': 30, 'items': [1, 2, 3]} await kvs_client.set_value(key=test_key, value=test_value) - # Check if file content is valid JSON - key_path = kvs_client.path_to_kvs / test_key - with key_path.open() as f: - file_content = json.load(f) - assert file_content == test_value + async with kvs_client.create_session() as session: + stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == test_key) + result = await session.execute(stmt) + record = result.scalar_one_or_none() + assert record is not None + assert record.key == test_key + assert json.loads(record.value.decode('utf-8')) == test_value -async def test_file_deletion_on_value_delete(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test that deleting a value removes its files from disk.""" +async def test_record_deletion_on_value_delete(kvs_client: SQLKeyValueStoreClient) -> None: + """Test that deleting a value removes its record from the database.""" test_key = 'test-delete' test_value = 'Delete me' # Set a value await 
kvs_client.set_value(key=test_key, value=test_value) - # Verify files exist - key_path = kvs_client.path_to_kvs / test_key - metadata_path = kvs_client.path_to_kvs / f'{test_key}.{METADATA_FILENAME}' - assert key_path.exists() - assert metadata_path.exists() + async with kvs_client.create_session() as session: + stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == test_key) + result = await session.execute(stmt) + record = result.scalar_one_or_none() + assert record is not None + assert record.key == test_key + assert record.value == test_value.encode('utf-8') # Delete the value await kvs_client.delete_value(key=test_key) - # Verify files were deleted - assert not key_path.exists() - assert not metadata_path.exists() + # Verify record was deleted + async with kvs_client.create_session() as session: + stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == test_key) + result = await session.execute(stmt) + record = result.scalar_one_or_none() + assert record is None -async def test_drop_removes_directory(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test that drop removes the entire store directory from disk.""" +async def test_drop_removes_records(kvs_client: SQLKeyValueStoreClient) -> None: + """Test that drop removes all records from the database.""" await kvs_client.set_value(key='test', value='test-value') - assert kvs_client.path_to_kvs.exists() + client_metadata = await kvs_client.get_metadata() + + async with kvs_client.create_session() as session: + stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == 'test') + result = await session.execute(stmt) + record = result.scalar_one_or_none() + assert record is not None # Drop the store await kvs_client.drop() - assert not kvs_client.path_to_kvs.exists() + async with kvs_client.create_session() as session: + stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == 'test') + result = await session.execute(stmt) + record = result.scalar_one_or_none() + assert record is None + metadata = await session.get(KeyValueStoreMetadataDB, client_metadata.id) + assert metadata is None -async def test_metadata_file_updates(kvs_client: FileSystemKeyValueStoreClient) -> None: - """Test that read/write operations properly update metadata file timestamps.""" +async def test_metadata_record_updates(kvs_client: SQLKeyValueStoreClient) -> None: + """Test that read/write operations properly update metadata record timestamps.""" # Record initial timestamps metadata = await kvs_client.get_metadata() initial_created = metadata.created_at @@ -181,31 +246,38 @@ async def test_metadata_file_updates(kvs_client: FileSystemKeyValueStoreClient) assert metadata.modified_at > initial_modified assert metadata.accessed_at > accessed_after_read + async with kvs_client.create_session() as session: + stmt = select(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.id == metadata.id) + result = await session.execute(stmt) + orm_metadata = result.scalar_one_or_none() + assert orm_metadata is not None + assert orm_metadata.created_at.replace(tzinfo=timezone.utc) == metadata.created_at + assert orm_metadata.accessed_at.replace(tzinfo=timezone.utc) == metadata.accessed_at + assert orm_metadata.modified_at.replace(tzinfo=timezone.utc) == metadata.modified_at -async def test_data_persistence_across_reopens(configuration: Configuration) -> None: - """Test that data persists correctly when reopening the same KVS.""" - storage_client = FileSystemStorageClient() - - # Create KVS and add data - 
original_client = await storage_client.create_kvs_client( - name='persistence-test', - configuration=configuration, - ) - test_key = 'persistent-key' - test_value = 'persistent-value' - await original_client.set_value(key=test_key, value=test_value) - - kvs_id = (await original_client.get_metadata()).id - - # Reopen by ID and verify data persists - reopened_client = await storage_client.create_kvs_client( - id=kvs_id, - configuration=configuration, - ) - - record = await reopened_client.get_value(key=test_key) - assert record is not None - assert record.value == test_value - - await reopened_client.drop() +async def test_data_persistence_across_reopens(configuration: Configuration) -> None: + """Test that data persists correctly when reopening the same dataset.""" + async with SQLStorageClient() as storage_client: + original_client = await storage_client.create_kvs_client( + name='persistence-test', + configuration=configuration, + ) + + test_key = 'persistent-key' + test_value = 'persistent-value' + await original_client.set_value(key=test_key, value=test_value) + + kvs_id = (await original_client.get_metadata()).id + + # Reopen by ID and verify data persists + reopened_client = await storage_client.create_kvs_client( + id=kvs_id, + configuration=configuration, + ) + + record = await reopened_client.get_value(key=test_key) + assert record is not None + assert record.value == test_value + + await reopened_client.drop() From 49f2643a10f09c260ee32a2aab82557b4bf794fb Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Wed, 30 Jul 2025 16:28:40 +0000 Subject: [PATCH 04/29] add rq tests --- .../storage_clients/_sql/_db_models.py | 2 - .../_sql/_request_queue_client.py | 22 +- .../_sql/test_sql_dataset_client.py | 19 +- .../_sql/test_sql_kvs_client.py | 4 +- .../_sql/test_sql_rq_client.py | 234 +++++++++++------- 5 files changed, 172 insertions(+), 109 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py index 32818ab358..5a48a79c66 100644 --- a/src/crawlee/storage_clients/_sql/_db_models.py +++ b/src/crawlee/storage_clients/_sql/_db_models.py @@ -1,7 +1,6 @@ from __future__ import annotations from datetime import datetime # noqa: TC003 -from typing import Any from sqlalchemy import ( JSON, @@ -41,7 +40,6 @@ class RequestQueueMetadataDB(StorageMetadataDB, Base): # type: ignore[valid-typ had_multiple_clients: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) handled_request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) pending_request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - stats: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False, default={}) total_request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) requests: Mapped[list[RequestDB]] = relationship(back_populates='queue', cascade='all, delete-orphan') diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index 3bf25da152..67b21b8a62 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -27,6 +27,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from sqlalchemy.ext.asyncio import AsyncSession + from ._storage_client import SQLStorageClient @@ -96,6 +98,10 @@ def __init__( self._lock = asyncio.Lock() + def create_session(self) -> AsyncSession: + """Create a new SQLAlchemy session for this key-value 
store.""" + return self._storage_client.create_session() + @override async def get_metadata(self) -> RequestQueueMetadata: return RequestQueueMetadata.model_validate(self._orm_metadata) @@ -152,7 +158,6 @@ async def open( had_multiple_clients=False, handled_request_count=0, pending_request_count=0, - stats={}, total_request_count=0, ) orm_metadata = RequestQueueMetadataDB(**metadata.model_dump()) @@ -200,6 +205,8 @@ async def purge(self) -> None: await self._update_metadata(update_modified_at=True, update_accessed_at=True) self._is_empty_cache = None + + await session.merge(self._orm_metadata) await session.commit() # Clear recoverable state @@ -304,11 +311,10 @@ async def add_batch_of_requests( self._request_cache_needs_refresh = True try: + await session.merge(self._orm_metadata) await session.commit() except SQLAlchemyError as e: logger.warning(f'Failed to commit session: {e}') - await session.rollback() - input() processed_requests.clear() unprocessed_requests.extend( [ @@ -346,6 +352,7 @@ async def get_request(self, request_id: str) -> Request | None: state.in_progress_requests.add(request.id) await self._update_metadata(update_accessed_at=True) + await session.merge(self._orm_metadata) await session.commit() return request @@ -404,7 +411,7 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | self._orm_metadata.pending_request_count -= 1 await self._update_metadata(update_modified_at=True, update_accessed_at=True) - + await session.merge(self._orm_metadata) await session.commit() return ProcessedRequest( @@ -460,7 +467,7 @@ async def reclaim_request( self._request_cache.append(request) await self._update_metadata(update_modified_at=True, update_accessed_at=True) - + await session.merge(self._orm_metadata) await session.commit() return ProcessedRequest( @@ -495,6 +502,11 @@ async def is_empty(self) -> bool: result = await session.execute(stmt) unhandled_count = result.scalar() self._is_empty_cache = unhandled_count == 0 + + await self._update_metadata(update_accessed_at=True) + await session.merge(self._orm_metadata) + await session.commit() + return self._is_empty_cache async def _refresh_cache(self) -> None: diff --git a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py index ced147f3f8..0ff0c13bdf 100644 --- a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py @@ -12,7 +12,6 @@ from crawlee.configuration import Configuration from crawlee.storage_clients import SQLStorageClient from crawlee.storage_clients._sql._db_models import DatasetItemDB, DatasetMetadataDB -from crawlee.storage_clients.models import DatasetMetadata if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -101,10 +100,10 @@ async def test_tables_and_metadata_record(configuration: Configuration) -> None: stmt = select(DatasetMetadataDB).where(DatasetMetadataDB.name == 'new_dataset') result = await session.execute(stmt) orm_metadata = result.scalar_one_or_none() - metadata = DatasetMetadata.model_validate(orm_metadata) - assert metadata.id == client_metadata.id - assert metadata.name == 'new_dataset' - assert metadata.item_count == 0 + assert orm_metadata is not None + assert orm_metadata.id == client_metadata.id + assert orm_metadata.name == 'new_dataset' + assert orm_metadata.item_count == 0 await client.drop() @@ -201,11 +200,11 @@ async def test_metadata_recaord_updates(dataset_client: SQLDatasetClient) -> Non # Verify metadata file 
is updated on disk async with dataset_client.create_session() as session: orm_metadata = await session.get(DatasetMetadataDB, metadata.id) - record_metadata = DatasetMetadata.model_validate(orm_metadata) - record_metadata.item_count = 1 - assert record_metadata.created_at.replace(tzinfo=timezone.utc) == initial_created - assert record_metadata.accessed_at.replace(tzinfo=timezone.utc) == metadata.accessed_at - assert record_metadata.modified_at.replace(tzinfo=timezone.utc) == metadata.modified_at + assert orm_metadata is not None + orm_metadata.item_count = 1 + assert orm_metadata.created_at.replace(tzinfo=timezone.utc) == initial_created + assert orm_metadata.accessed_at.replace(tzinfo=timezone.utc) == metadata.accessed_at + assert orm_metadata.modified_at.replace(tzinfo=timezone.utc) == metadata.modified_at async def test_data_persistence_across_reopens(configuration: Configuration) -> None: diff --git a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py index 852132cdff..9dcc388f86 100644 --- a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py @@ -247,9 +247,7 @@ async def test_metadata_record_updates(kvs_client: SQLKeyValueStoreClient) -> No assert metadata.accessed_at > accessed_after_read async with kvs_client.create_session() as session: - stmt = select(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.id == metadata.id) - result = await session.execute(stmt) - orm_metadata = result.scalar_one_or_none() + orm_metadata = await session.get(KeyValueStoreMetadataDB, metadata.id) assert orm_metadata is not None assert orm_metadata.created_at.replace(tzinfo=timezone.utc) == metadata.created_at assert orm_metadata.accessed_at.replace(tzinfo=timezone.utc) == metadata.accessed_at diff --git a/tests/unit/storage_clients/_sql/test_sql_rq_client.py b/tests/unit/storage_clients/_sql/test_sql_rq_client.py index 0be182fcd8..2607d29c29 100644 --- a/tests/unit/storage_clients/_sql/test_sql_rq_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_rq_client.py @@ -2,19 +2,26 @@ import asyncio import json +from datetime import timezone from typing import TYPE_CHECKING import pytest +from sqlalchemy import inspect, select +from sqlalchemy.ext.asyncio import create_async_engine from crawlee import Request from crawlee.configuration import Configuration -from crawlee.storage_clients import FileSystemStorageClient +from crawlee.storage_clients import SQLStorageClient +from crawlee.storage_clients._sql._db_models import RequestDB, RequestQueueMetadataDB +from crawlee.storage_clients.models import RequestQueueMetadata if TYPE_CHECKING: from collections.abc import AsyncGenerator from pathlib import Path - from crawlee.storage_clients._file_system import FileSystemRequestQueueClient + from sqlalchemy import Connection + + from crawlee.storage_clients._sql import SQLRequestQueueClient @pytest.fixture @@ -25,38 +32,84 @@ def configuration(tmp_path: Path) -> Configuration: @pytest.fixture -async def rq_client(configuration: Configuration) -> AsyncGenerator[FileSystemRequestQueueClient, None]: - """A fixture for a file system request queue client.""" - client = await FileSystemStorageClient().create_rq_client( - name='test_request_queue', - configuration=configuration, - ) - yield client - await client.drop() +async def rq_client(configuration: Configuration) -> AsyncGenerator[SQLRequestQueueClient, None]: + """A fixture for a SQL request queue client.""" + async with 
SQLStorageClient() as storage_client: + client = await storage_client.create_rq_client( + name='test_request_queue', + configuration=configuration, + ) + yield client + await client.drop() -async def test_file_and_directory_creation(configuration: Configuration) -> None: - """Test that file system RQ creates proper files and directories.""" - client = await FileSystemStorageClient().create_rq_client( - name='new_request_queue', - configuration=configuration, - ) +# Helper function that allows you to use inspect with an asynchronous engine +def get_tables(sync_conn: Connection) -> list[str]: + inspector = inspect(sync_conn) + return inspector.get_table_names() + + +async def test_create_tables_with_connection_string(configuration: Configuration, tmp_path: Path) -> None: + """Test that SQL dataset client creates tables with a connection string.""" + storage_dir = tmp_path / 'test_table.db' + + async with SQLStorageClient(connection_string=f'sqlite+aiosqlite:///{storage_dir}') as storage_client: + await storage_client.create_rq_client( + name='test_request_queue', + configuration=configuration, + ) + + async with storage_client.engine.begin() as conn: + tables = await conn.run_sync(get_tables) + assert 'request_queue_metadata' in tables + assert 'request' in tables + - # Verify files were created - assert client.path_to_rq.exists() - assert client.path_to_metadata.exists() +async def test_create_tables_with_engine(configuration: Configuration, tmp_path: Path) -> None: + """Test that SQL dataset client creates tables with a pre-configured engine.""" + storage_dir = tmp_path / 'test_table.db' - # Verify metadata file structure - with client.path_to_metadata.open() as f: - metadata = json.load(f) - assert metadata['id'] == (await client.get_metadata()).id - assert metadata['name'] == 'new_request_queue' + engine = create_async_engine(f'sqlite+aiosqlite:///{storage_dir}', future=True, echo=False) - await client.drop() + async with SQLStorageClient(engine=engine) as storage_client: + await storage_client.create_rq_client( + name='test_request_queue', + configuration=configuration, + ) + async with engine.begin() as conn: + tables = await conn.run_sync(get_tables) + assert 'request_queue_metadata' in tables + assert 'request' in tables -async def test_request_file_persistence(rq_client: FileSystemRequestQueueClient) -> None: - """Test that requests are properly persisted to files.""" + +async def test_tables_and_metadata_record(configuration: Configuration) -> None: + """Test that SQL dataset creates proper tables and metadata records.""" + async with SQLStorageClient() as storage_client: + client = await storage_client.create_rq_client( + name='test_request_queue', + configuration=configuration, + ) + + client_metadata = await client.get_metadata() + + async with storage_client.engine.begin() as conn: + tables = await conn.run_sync(get_tables) + assert 'request_queue_metadata' in tables + assert 'request' in tables + + async with client.create_session() as session: + stmt = select(RequestQueueMetadataDB).where(RequestQueueMetadataDB.name == 'test_request_queue') + result = await session.execute(stmt) + orm_metadata = result.scalar_one_or_none() + metadata = RequestQueueMetadata.model_validate(orm_metadata) + assert metadata.id == client_metadata.id + assert metadata.name == 'test_request_queue' + + await client.drop() + + +async def test_request_records_persistence(rq_client: SQLRequestQueueClient) -> None: requests = [ Request.from_url('https://example.com/1'), 
Request.from_url('https://example.com/2'), @@ -65,38 +118,41 @@ async def test_request_file_persistence(rq_client: FileSystemRequestQueueClient) await rq_client.add_batch_of_requests(requests) - # Verify request files are created - request_files = list(rq_client.path_to_rq.glob('*.json')) - # Should have 3 request files + 1 metadata file - assert len(request_files) == 4 - assert rq_client.path_to_metadata in request_files - - # Verify actual request file content - data_files = [f for f in request_files if f != rq_client.path_to_metadata] - assert len(data_files) == 3 + metadata_client = await rq_client.get_metadata() - for req_file in data_files: - with req_file.open() as f: - request_data = json.load(f) - assert 'url' in request_data - assert request_data['url'].startswith('https://example.com/') + async with rq_client.create_session() as session: + stmt = select(RequestDB).where(RequestDB.queue_id == metadata_client.id) + result = await session.execute(stmt) + db_requests = result.scalars().all() + assert len(db_requests) == 3 + for db_request in db_requests: + request = json.loads(db_request.data) + assert request['url'] in ['https://example.com/1', 'https://example.com/2', 'https://example.com/3'] -async def test_drop_removes_directory(rq_client: FileSystemRequestQueueClient) -> None: - """Test that drop removes the entire RQ directory from disk.""" +async def test_drop_removes_records(rq_client: SQLRequestQueueClient) -> None: + """Test that drop removes all records from the database.""" await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) + metadata = await rq_client.get_metadata() + async with rq_client.create_session() as session: + stmt = select(RequestDB).where(RequestDB.queue_id == metadata.id) + result = await session.execute(stmt) + records = result.scalars().all() + assert len(records) == 1 - rq_path = rq_client.path_to_rq - assert rq_path.exists() - - # Drop the request queue await rq_client.drop() - assert not rq_path.exists() + async with rq_client.create_session() as session: + stmt = select(RequestDB).where(RequestDB.queue_id == metadata.id) + result = await session.execute(stmt) + records = result.scalars().all() + assert len(records) == 0 + db_metadata = await session.get(RequestQueueMetadataDB, metadata.id) + assert db_metadata is None -async def test_metadata_file_updates(rq_client: FileSystemRequestQueueClient) -> None: - """Test that metadata file is updated correctly after operations.""" +async def test_metadata_record_updates(rq_client: SQLRequestQueueClient) -> None: + """Test that metadata record updates correctly after operations.""" # Record initial timestamps metadata = await rq_client.get_metadata() initial_created = metadata.created_at @@ -129,45 +185,45 @@ async def test_metadata_file_updates(rq_client: FileSystemRequestQueueClient) -> assert metadata.modified_at > initial_modified assert metadata.accessed_at > accessed_after_read - # Verify metadata file is updated on disk - with rq_client.path_to_metadata.open() as f: - metadata_json = json.load(f) - assert metadata_json['total_request_count'] == 1 + async with rq_client.create_session() as session: + orm_metadata = await session.get(RequestQueueMetadataDB, metadata.id) + assert orm_metadata is not None + assert orm_metadata.created_at.replace(tzinfo=timezone.utc) == metadata.created_at + assert orm_metadata.accessed_at.replace(tzinfo=timezone.utc) == metadata.accessed_at + assert orm_metadata.modified_at.replace(tzinfo=timezone.utc) == metadata.modified_at async def 
test_data_persistence_across_reopens(configuration: Configuration) -> None: - """Test that requests persist correctly when reopening the same RQ.""" - storage_client = FileSystemStorageClient() - - # Create RQ and add requests - original_client = await storage_client.create_rq_client( - name='persistence-test', - configuration=configuration, - ) - - test_requests = [ - Request.from_url('https://example.com/1'), - Request.from_url('https://example.com/2'), - ] - await original_client.add_batch_of_requests(test_requests) - - rq_id = (await original_client.get_metadata()).id - - # Reopen by ID and verify requests persist - reopened_client = await storage_client.create_rq_client( - id=rq_id, - configuration=configuration, - ) - - metadata = await reopened_client.get_metadata() - assert metadata.total_request_count == 2 - - # Fetch requests to verify they're still there - request1 = await reopened_client.fetch_next_request() - request2 = await reopened_client.fetch_next_request() - - assert request1 is not None - assert request2 is not None - assert {request1.url, request2.url} == {'https://example.com/1', 'https://example.com/2'} - - await reopened_client.drop() + """Test that data persists correctly when reopening the same dataset.""" + async with SQLStorageClient() as storage_client: + original_client = await storage_client.create_rq_client( + name='persistence-test', + configuration=configuration, + ) + + test_requests = [ + Request.from_url('https://example.com/1'), + Request.from_url('https://example.com/2'), + ] + await original_client.add_batch_of_requests(test_requests) + + rq_id = (await original_client.get_metadata()).id + + # Reopen by ID and verify data persists + reopened_client = await storage_client.create_rq_client( + id=rq_id, + configuration=configuration, + ) + + metadata = await reopened_client.get_metadata() + assert metadata.total_request_count == 2 + + # Fetch requests to verify they're still there + request1 = await reopened_client.fetch_next_request() + request2 = await reopened_client.fetch_next_request() + + assert request1 is not None + assert request2 is not None + assert {request1.url, request2.url} == {'https://example.com/1', 'https://example.com/2'} + + await reopened_client.drop() From 35a27fcd4208deebfb3be21eda1ee75580c3144d Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Wed, 30 Jul 2025 17:15:52 +0000 Subject: [PATCH 05/29] fix docs in tests --- .../storage_clients/_sql/test_sql_dataset_client.py | 9 +++++---- .../unit/storage_clients/_sql/test_sql_kvs_client.py | 11 ++++++----- tests/unit/storage_clients/_sql/test_sql_rq_client.py | 10 ++++++---- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py index 0ff0c13bdf..4b195fe0a3 100644 --- a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py @@ -24,6 +24,7 @@ @pytest.fixture def configuration(tmp_path: Path) -> Configuration: + """Temporary configuration for tests.""" return Configuration( crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] ) @@ -128,7 +129,7 @@ async def test_record_and_content_verification(dataset_client: SQLDatasetClient) saved_item = json.loads(records[0].data) assert saved_item == item - # Test multiple items file creation + # Test pushing multiple items and verify total count items = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}, {'id': 3, 'name': 'Item 3'}] 
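    # Note: push_data accepts either a single dict or a list of dicts (its signature is
    # `data: list[dict[str, Any]] | dict[str, Any]`), so passing a list here results in one
    # dataset_item row per item, which the row-count assertion below relies on.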
await dataset_client.push_data(items) @@ -163,8 +164,8 @@ async def test_drop_removes_records(dataset_client: SQLDatasetClient) -> None: assert metadata is None -async def test_metadata_recaord_updates(dataset_client: SQLDatasetClient) -> None: - """Test that metadata file is updated correctly after operations.""" +async def test_metadata_record_updates(dataset_client: SQLDatasetClient) -> None: + """Test that metadata record is updated correctly after operations.""" # Record initial timestamps metadata = await dataset_client.get_metadata() initial_created = metadata.created_at @@ -197,7 +198,7 @@ async def test_metadata_recaord_updates(dataset_client: SQLDatasetClient) -> Non assert metadata.modified_at > initial_modified assert metadata.accessed_at > accessed_after_get - # Verify metadata file is updated on disk + # Verify metadata record is updated in db async with dataset_client.create_session() as session: orm_metadata = await session.get(DatasetMetadataDB, metadata.id) assert orm_metadata is not None diff --git a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py index 9dcc388f86..378322cc56 100644 --- a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py @@ -25,6 +25,7 @@ @pytest.fixture def configuration(tmp_path: Path) -> Configuration: + """Temporary configuration for tests.""" return Configuration( crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] ) @@ -49,7 +50,7 @@ def get_tables(sync_conn: Connection) -> list[str]: async def test_create_tables_with_connection_string(configuration: Configuration, tmp_path: Path) -> None: - """Test that SQL dataset client creates tables with a connection string.""" + """Test that SQL key-value store client creates tables with a connection string.""" storage_dir = tmp_path / 'test_table.db' async with SQLStorageClient(connection_string=f'sqlite+aiosqlite:///{storage_dir}') as storage_client: @@ -65,7 +66,7 @@ async def test_create_tables_with_connection_string(configuration: Configuration async def test_create_tables_with_engine(configuration: Configuration, tmp_path: Path) -> None: - """Test that SQL dataset client creates tables with a pre-configured engine.""" + """Test that SQL key-value store client creates tables with a pre-configured engine.""" storage_dir = tmp_path / 'test_table.db' engine = create_async_engine(f'sqlite+aiosqlite:///{storage_dir}', future=True, echo=False) @@ -83,7 +84,7 @@ async def test_create_tables_with_engine(configuration: Configuration, tmp_path: async def test_tables_and_metadata_record(configuration: Configuration) -> None: - """Test that SQL dataset creates proper tables and metadata records.""" + """Test that SQL key-value store creates proper tables and metadata records.""" async with SQLStorageClient() as storage_client: client = await storage_client.create_kvs_client( name='new_kvs', @@ -109,7 +110,7 @@ async def test_tables_and_metadata_record(configuration: Configuration) -> None: async def test_value_record_creation(kvs_client: SQLKeyValueStoreClient) -> None: - """Test that key-value store client can create a record.""" + """Test that SQL key-value store client can create a record.""" test_key = 'test-key' test_value = 'Hello, world!' 
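    # Note: set_value persists the value as bytes in the kvs_record table; the assertions that
    # follow check the stored key, content type, size, and raw bytes for this record.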
await kvs_client.set_value(key=test_key, value=test_value) @@ -255,7 +256,7 @@ async def test_metadata_record_updates(kvs_client: SQLKeyValueStoreClient) -> No async def test_data_persistence_across_reopens(configuration: Configuration) -> None: - """Test that data persists correctly when reopening the same dataset.""" + """Test that data persists correctly when reopening the same key-value store.""" async with SQLStorageClient() as storage_client: original_client = await storage_client.create_kvs_client( name='persistence-test', diff --git a/tests/unit/storage_clients/_sql/test_sql_rq_client.py b/tests/unit/storage_clients/_sql/test_sql_rq_client.py index 2607d29c29..c90c91235d 100644 --- a/tests/unit/storage_clients/_sql/test_sql_rq_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_rq_client.py @@ -26,6 +26,7 @@ @pytest.fixture def configuration(tmp_path: Path) -> Configuration: + """Temporary configuration for tests.""" return Configuration( crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] ) @@ -50,7 +51,7 @@ def get_tables(sync_conn: Connection) -> list[str]: async def test_create_tables_with_connection_string(configuration: Configuration, tmp_path: Path) -> None: - """Test that SQL dataset client creates tables with a connection string.""" + """Test that SQL request queue client creates tables with a connection string.""" storage_dir = tmp_path / 'test_table.db' async with SQLStorageClient(connection_string=f'sqlite+aiosqlite:///{storage_dir}') as storage_client: @@ -66,7 +67,7 @@ async def test_create_tables_with_connection_string(configuration: Configuration async def test_create_tables_with_engine(configuration: Configuration, tmp_path: Path) -> None: - """Test that SQL dataset client creates tables with a pre-configured engine.""" + """Test that SQL request queue client creates tables with a pre-configured engine.""" storage_dir = tmp_path / 'test_table.db' engine = create_async_engine(f'sqlite+aiosqlite:///{storage_dir}', future=True, echo=False) @@ -84,7 +85,7 @@ async def test_create_tables_with_engine(configuration: Configuration, tmp_path: async def test_tables_and_metadata_record(configuration: Configuration) -> None: - """Test that SQL dataset creates proper tables and metadata records.""" + """Test that SQL request queue creates proper tables and metadata records.""" async with SQLStorageClient() as storage_client: client = await storage_client.create_rq_client( name='test_request_queue', @@ -110,6 +111,7 @@ async def test_tables_and_metadata_record(configuration: Configuration) -> None: async def test_request_records_persistence(rq_client: SQLRequestQueueClient) -> None: + """Test that all added requests are persisted and can be retrieved from the database.""" requests = [ Request.from_url('https://example.com/1'), Request.from_url('https://example.com/2'), @@ -194,7 +196,7 @@ async def test_metadata_record_updates(rq_client: SQLRequestQueueClient) -> None async def test_data_persistence_across_reopens(configuration: Configuration) -> None: - """Test that data persists correctly when reopening the same dataset.""" + """Test that data persists correctly when reopening the same request queue.""" async with SQLStorageClient() as storage_client: original_client = await storage_client.create_rq_client( name='persistence-test', From 52e1ad2d83ff4f61fced223e091c241ee284cec8 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Wed, 30 Jul 2025 17:20:00 +0000 Subject: [PATCH 06/29] wrap `SQLStorageClient` in _try_import --- src/crawlee/storage_clients/__init__.py | 
12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/crawlee/storage_clients/__init__.py b/src/crawlee/storage_clients/__init__.py index 6c2d58591d..484960eb4b 100644 --- a/src/crawlee/storage_clients/__init__.py +++ b/src/crawlee/storage_clients/__init__.py @@ -1,7 +1,17 @@ +from crawlee._utils.try_import import install_import_hook as _install_import_hook +from crawlee._utils.try_import import try_import as _try_import + +# These imports have only mandatory dependencies, so they are imported directly. from ._base import StorageClient from ._file_system import FileSystemStorageClient from ._memory import MemoryStorageClient -from ._sql import SQLStorageClient + +_install_import_hook(__name__) + +# The following imports are wrapped in try_import to handle optional dependencies, +# ensuring the module can still function even if these dependencies are missing. +with _try_import(__name__, 'SQLStorageClient'): + from ._sql import SQLStorageClient __all__ = [ 'FileSystemStorageClient', From df41c4572aeac6abba7638a99188106dfa457e22 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Wed, 30 Jul 2025 18:48:39 +0000 Subject: [PATCH 07/29] update db models --- .../storage_clients/_sql/_dataset_client.py | 5 +- .../storage_clients/_sql/_db_models.py | 52 +++++++++++-------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py index a3c126ef39..c08217c209 100644 --- a/src/crawlee/storage_clients/_sql/_dataset_client.py +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -154,7 +154,6 @@ async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: DatasetItemDB( dataset_id=self._orm_metadata.id, data=json_item, - created_at=datetime.now(timezone.utc), ) ) @@ -207,7 +206,7 @@ async def get_data( if skip_empty: stmt = stmt.where(DatasetItemDB.data != '"{}"') - stmt = stmt.order_by(DatasetItemDB.created_at.desc()) if desc else stmt.order_by(DatasetItemDB.created_at.asc()) + stmt = stmt.order_by(DatasetItemDB.order_id.desc()) if desc else stmt.order_by(DatasetItemDB.order_id.asc()) stmt = stmt.offset(offset).limit(limit) @@ -264,7 +263,7 @@ async def iterate_items( if skip_empty: stmt = stmt.where(DatasetItemDB.data != '"{}"') - stmt = stmt.order_by(DatasetItemDB.created_at.desc()) if desc else stmt.order_by(DatasetItemDB.created_at.asc()) + stmt = stmt.order_by(DatasetItemDB.order_id.desc()) if desc else stmt.order_by(DatasetItemDB.order_id.asc()) stmt = stmt.offset(offset).limit(limit) diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py index 5a48a79c66..d1b4f45ae2 100644 --- a/src/crawlee/storage_clients/_sql/_db_models.py +++ b/src/crawlee/storage_clients/_sql/_db_models.py @@ -11,30 +11,34 @@ LargeBinary, String, ) -from sqlalchemy.orm import Mapped, declarative_base, mapped_column, relationship +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship -Base = declarative_base() + +class Base(DeclarativeBase): + """Base class for all database models for correct type annotations.""" class StorageMetadataDB: """Base database model for storage metadata.""" id: Mapped[str] = mapped_column(String(20), nullable=False, primary_key=True) - name: Mapped[str | None] = mapped_column(String(100), nullable=True) + name: Mapped[str | None] = mapped_column(String(100), nullable=True, index=True) accessed_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), 
nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) modified_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) -class DatasetMetadataDB(StorageMetadataDB, Base): # type: ignore[valid-type,misc] +class DatasetMetadataDB(StorageMetadataDB, Base): __tablename__ = 'dataset_metadata' item_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - items: Mapped[list[DatasetItemDB]] = relationship(back_populates='dataset', cascade='all, delete-orphan') + items: Mapped[list[DatasetItemDB]] = relationship( + back_populates='dataset', cascade='all, delete-orphan', lazy='select' + ) -class RequestQueueMetadataDB(StorageMetadataDB, Base): # type: ignore[valid-type,misc] +class RequestQueueMetadataDB(StorageMetadataDB, Base): __tablename__ = 'request_queue_metadata' had_multiple_clients: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) @@ -42,60 +46,62 @@ class RequestQueueMetadataDB(StorageMetadataDB, Base): # type: ignore[valid-typ pending_request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) total_request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - requests: Mapped[list[RequestDB]] = relationship(back_populates='queue', cascade='all, delete-orphan') + requests: Mapped[list[RequestDB]] = relationship( + back_populates='queue', cascade='all, delete-orphan', lazy='select' + ) -class KeyValueStoreMetadataDB(StorageMetadataDB, Base): # type: ignore[valid-type,misc] +class KeyValueStoreMetadataDB(StorageMetadataDB, Base): __tablename__ = 'kvs_metadata' - records: Mapped[list[KeyValueStoreRecordDB]] = relationship(back_populates='kvs', cascade='all, delete-orphan') + records: Mapped[list[KeyValueStoreRecordDB]] = relationship( + back_populates='kvs', cascade='all, delete-orphan', lazy='select' + ) -class KeyValueStoreRecordDB(Base): # type: ignore[valid-type,misc] +class KeyValueStoreRecordDB(Base): """Database model for key-value store records.""" __tablename__ = 'kvs_record' - kvs_id: Mapped[str] = mapped_column(String(255), ForeignKey('kvs_metadata.id'), primary_key=True, index=True) - + kvs_id: Mapped[str] = mapped_column( + String(255), ForeignKey('kvs_metadata.id', ondelete='CASCADE'), primary_key=True, index=True + ) key: Mapped[str] = mapped_column(String(255), primary_key=True) value: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) - content_type: Mapped[str] = mapped_column(String(100), nullable=False) size: Mapped[int | None] = mapped_column(Integer, nullable=False, default=0) kvs: Mapped[KeyValueStoreMetadataDB] = relationship(back_populates='records') -class DatasetItemDB(Base): # type: ignore[valid-type,misc] +class DatasetItemDB(Base): """Database model for dataset items.""" __tablename__ = 'dataset_item' order_id: Mapped[int] = mapped_column(Integer, primary_key=True) - dataset_id: Mapped[str] = mapped_column(String(20), ForeignKey('dataset_metadata.id'), index=True) + dataset_id: Mapped[str] = mapped_column( + String(20), ForeignKey('dataset_metadata.id', ondelete='CASCADE'), index=True + ) data: Mapped[str] = mapped_column(JSON, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) - dataset: Mapped[DatasetMetadataDB] = relationship(back_populates='items') -class RequestDB(Base): # type: ignore[valid-type,misc] +class RequestDB(Base): """Database model for requests in the request queue.""" __tablename__ = 'request' request_id: Mapped[str] = mapped_column(String(20), 
primary_key=True) queue_id: Mapped[str] = mapped_column( - String(20), ForeignKey('request_queue_metadata.id'), index=True, primary_key=True + String(20), ForeignKey('request_queue_metadata.id', ondelete='CASCADE'), index=True, primary_key=True ) data: Mapped[str] = mapped_column(JSON, nullable=False) - unique_key: Mapped[str] = mapped_column(String(512), nullable=False) - + unique_key: Mapped[str] = mapped_column(String(512), nullable=False, index=True) sequence_number: Mapped[int] = mapped_column(Integer, nullable=False, index=True) - - is_handled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + is_handled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, index=True) queue: Mapped[RequestQueueMetadataDB] = relationship(back_populates='requests') From 342c65adb0bbad6053594341fcbe939e029f35d7 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Wed, 30 Jul 2025 21:25:12 +0000 Subject: [PATCH 08/29] dataset optimization --- .../storage_clients/_sql/_dataset_client.py | 147 ++++++++++-------- .../storage_clients/_sql/_db_models.py | 24 ++- .../_sql/test_sql_dataset_client.py | 18 +-- 3 files changed, 109 insertions(+), 80 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py index c08217c209..3e4fb7e0bb 100644 --- a/src/crawlee/storage_clients/_sql/_dataset_client.py +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -1,11 +1,12 @@ from __future__ import annotations import json +from contextlib import asynccontextmanager from datetime import datetime, timezone from logging import getLogger from typing import TYPE_CHECKING, Any -from sqlalchemy import delete, select +from sqlalchemy import delete, select, update from typing_extensions import override from crawlee._utils.crypto import crypto_random_object_id @@ -39,23 +40,34 @@ class SQLDatasetClient(DatasetClient): def __init__( self, *, - orm_metadata: DatasetMetadataDB, + metadata: DatasetMetadata, storage_client: SQLStorageClient, ) -> None: """Initialize a new instance. Preferably use the `SqlDatasetClient.open` class method to create a new instance. """ - self._orm_metadata = orm_metadata + self._metadata = metadata self._storage_client = storage_client - def create_session(self) -> AsyncSession: - """Create a new SQLAlchemy session for this key-value store.""" + def get_session(self) -> AsyncSession: + """Create a new SQLAlchemy session for this dataset.""" return self._storage_client.create_session() + @asynccontextmanager + async def get_autocommit_session(self) -> AsyncIterator[AsyncSession]: + """Create a new SQLAlchemy autocommit session to insert, delete, or modify data.""" + async with self.get_session() as session: + try: + yield session + await session.commit() + except Exception as e: + logger.warning(f'Error occurred during session transaction: {e}') + await session.rollback() + @override async def get_metadata(self) -> DatasetMetadata: - return DatasetMetadata.model_validate(self._orm_metadata) + return self._metadata @classmethod async def open( @@ -79,44 +91,38 @@ async def open( ValueError: If a dataset with the specified ID is not found. 
""" async with storage_client.create_session() as session: + orm_metadata: DatasetMetadataDB | None = None if id: orm_metadata = await session.get(DatasetMetadataDB, id) if not orm_metadata: raise ValueError(f'Dataset with ID "{id}" not found.') + else: + stmt = select(DatasetMetadataDB).where(DatasetMetadataDB.name == name) + result = await session.execute(stmt) + orm_metadata = result.scalar_one_or_none() + if orm_metadata: client = cls( - orm_metadata=orm_metadata, + metadata=DatasetMetadata.model_validate(orm_metadata), storage_client=storage_client, ) - await client._update_metadata(update_accessed_at=True) - + await client._update_metadata(session, update_accessed_at=True) else: - stmt = select(DatasetMetadataDB).where(DatasetMetadataDB.name == name) - result = await session.execute(stmt) - orm_metadata = result.scalar_one_or_none() - if orm_metadata: - client = cls( - orm_metadata=orm_metadata, - storage_client=storage_client, - ) - await client._update_metadata(update_accessed_at=True) - - else: - now = datetime.now(timezone.utc) - metadata = DatasetMetadata( - id=crypto_random_object_id(), - name=name, - created_at=now, - accessed_at=now, - modified_at=now, - item_count=0, - ) - orm_metadata = DatasetMetadataDB(**metadata.model_dump()) - client = cls( - orm_metadata=orm_metadata, - storage_client=storage_client, - ) - session.add(orm_metadata) + now = datetime.now(timezone.utc) + metadata = DatasetMetadata( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + item_count=0, + ) + + client = cls( + metadata=metadata, + storage_client=storage_client, + ) + session.add(DatasetMetadataDB(**metadata.model_dump())) await session.commit() @@ -124,22 +130,17 @@ async def open( @override async def drop(self) -> None: - async with self.create_session() as session: - dataset_db = await session.get(DatasetMetadataDB, self._orm_metadata.id) - if dataset_db: - await session.delete(dataset_db) - await session.commit() + stmt = delete(DatasetMetadataDB).where(DatasetMetadataDB.id == self._metadata.id) + async with self.get_autocommit_session() as autocommit: + await autocommit.execute(stmt) @override async def purge(self) -> None: - async with self.create_session() as session: - stmt = delete(DatasetItemDB).where(DatasetItemDB.dataset_id == self._orm_metadata.id) - await session.execute(stmt) + stmt = delete(DatasetItemDB).where(DatasetItemDB.dataset_id == self._metadata.id) + async with self.get_autocommit_session() as autocommit: + await autocommit.execute(stmt) - self._orm_metadata.item_count = 0 - await self._update_metadata(update_accessed_at=True, update_modified_at=True) - await session.merge(self._orm_metadata) - await session.commit() + await self._update_metadata(autocommit, new_item_count=0, update_accessed_at=True, update_modified_at=True) @override async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: @@ -152,20 +153,16 @@ async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: json_item = json.dumps(item, default=str, ensure_ascii=False) db_items.append( DatasetItemDB( - dataset_id=self._orm_metadata.id, + dataset_id=self._metadata.id, data=json_item, ) ) - async with self.create_session() as session: - session.add_all(db_items) - self._orm_metadata.item_count += len(data) + async with self.get_autocommit_session() as autocommit: + autocommit.add_all(db_items) await self._update_metadata( - update_accessed_at=True, - update_modified_at=True, + autocommit, update_accessed_at=True, 
update_modified_at=True, delta_item_count=len(data) ) - await session.merge(self._orm_metadata) - await session.commit() @override async def get_data( @@ -201,7 +198,7 @@ async def get_data( f'{self.__class__.__name__} client.' ) - stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._orm_metadata.id) + stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._metadata.id) if skip_empty: stmt = stmt.where(DatasetItemDB.data != '"{}"') @@ -210,12 +207,11 @@ async def get_data( stmt = stmt.offset(offset).limit(limit) - async with self.create_session() as session: + async with self.get_session() as session: result = await session.execute(stmt) db_items = result.scalars().all() - await self._update_metadata(update_accessed_at=True) - await session.merge(self._orm_metadata) + await self._update_metadata(session, update_accessed_at=True) await session.commit() items = [json.loads(db_item.data) for db_item in db_items] @@ -225,7 +221,7 @@ async def get_data( desc=desc, limit=limit or 0, offset=offset or 0, - total=self._orm_metadata.item_count, + total=self._metadata.item_count, ) @override @@ -258,7 +254,7 @@ async def iterate_items( f'by the {self.__class__.__name__} client.' ) - stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._orm_metadata.id) + stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._metadata.id) if skip_empty: stmt = stmt.where(DatasetItemDB.data != '"{}"') @@ -267,12 +263,11 @@ async def iterate_items( stmt = stmt.offset(offset).limit(limit) - async with self.create_session() as session: + async with self.get_session() as session: result = await session.execute(stmt) db_items = result.scalars().all() - await self._update_metadata(update_accessed_at=True) - await session.merge(self._orm_metadata) + await self._update_metadata(session, update_accessed_at=True) await session.commit() items = [json.loads(db_item.data) for db_item in db_items] @@ -281,20 +276,40 @@ async def iterate_items( async def _update_metadata( self, + session: AsyncSession, *, + new_item_count: int | None = None, update_accessed_at: bool = False, update_modified_at: bool = False, + delta_item_count: int | None = None, ) -> None: """Update the KVS metadata in the database. Args: session: The SQLAlchemy AsyncSession to use for the update. + new_item_count: If provided, update the item count to this value. update_accessed_at: If True, update the `accessed_at` timestamp to the current time. update_modified_at: If True, update the `modified_at` timestamp to the current time. + delta_item_count: If provided, increment the item count by this value. 
""" now = datetime.now(timezone.utc) + values_to_set: dict[str, Any] = {} if update_accessed_at: - self._orm_metadata.accessed_at = now + self._metadata.accessed_at = now + values_to_set['accessed_at'] = now if update_modified_at: - self._orm_metadata.modified_at = now + self._metadata.modified_at = now + values_to_set['modified_at'] = now + + if new_item_count is not None: + self._metadata.item_count = new_item_count + values_to_set['item_count'] = new_item_count + + if delta_item_count: + self._metadata.item_count += delta_item_count + values_to_set['item_count'] = DatasetMetadataDB.item_count + self._metadata.item_count + + if values_to_set: + stmt = update(DatasetMetadataDB).where(DatasetMetadataDB.id == self._metadata.id).values(**values_to_set) + await session.execute(stmt) diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py index d1b4f45ae2..791d627025 100644 --- a/src/crawlee/storage_clients/_sql/_db_models.py +++ b/src/crawlee/storage_clients/_sql/_db_models.py @@ -1,17 +1,31 @@ from __future__ import annotations -from datetime import datetime # noqa: TC003 +from datetime import datetime, timezone +from typing import TYPE_CHECKING from sqlalchemy import ( JSON, Boolean, - DateTime, ForeignKey, Integer, LargeBinary, String, ) from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy.types import DateTime, TypeDecorator + +if TYPE_CHECKING: + from sqlalchemy.engine import Dialect + + +class AwareDateTime(TypeDecorator): + impl = DateTime(timezone=True) + cache_ok = True + + def process_result_value(self, value: datetime | None, _dialect: Dialect) -> datetime | None: + if value is not None and value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value class Base(DeclarativeBase): @@ -23,9 +37,9 @@ class StorageMetadataDB: id: Mapped[str] = mapped_column(String(20), nullable=False, primary_key=True) name: Mapped[str | None] = mapped_column(String(100), nullable=True, index=True) - accessed_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) - modified_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + accessed_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) + created_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) + modified_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) class DatasetMetadataDB(StorageMetadataDB, Base): diff --git a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py index 4b195fe0a3..b8c403b8b1 100644 --- a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py @@ -97,7 +97,7 @@ async def test_tables_and_metadata_record(configuration: Configuration) -> None: assert 'dataset_item' in tables assert 'dataset_metadata' in tables - async with client.create_session() as session: + async with client.get_session() as session: stmt = select(DatasetMetadataDB).where(DatasetMetadataDB.name == 'new_dataset') result = await session.execute(stmt) orm_metadata = result.scalar_one_or_none() @@ -121,7 +121,7 @@ async def test_record_and_content_verification(dataset_client: SQLDatasetClient) assert metadata.modified_at is not None assert metadata.accessed_at is not None - async with dataset_client.create_session() as session: + async 
with dataset_client.get_session() as session: stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == metadata.id) result = await session.execute(stmt) records = result.scalars().all() @@ -133,7 +133,7 @@ async def test_record_and_content_verification(dataset_client: SQLDatasetClient) items = [{'id': 1, 'name': 'Item 1'}, {'id': 2, 'name': 'Item 2'}, {'id': 3, 'name': 'Item 3'}] await dataset_client.push_data(items) - async with dataset_client.create_session() as session: + async with dataset_client.get_session() as session: stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == metadata.id) result = await session.execute(stmt) records = result.scalars().all() @@ -146,7 +146,7 @@ async def test_drop_removes_records(dataset_client: SQLDatasetClient) -> None: client_metadata = await dataset_client.get_metadata() - async with dataset_client.create_session() as session: + async with dataset_client.get_session() as session: stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == client_metadata.id) result = await session.execute(stmt) records = result.scalars().all() @@ -155,7 +155,7 @@ async def test_drop_removes_records(dataset_client: SQLDatasetClient) -> None: # Drop the dataset await dataset_client.drop() - async with dataset_client.create_session() as session: + async with dataset_client.get_session() as session: stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == client_metadata.id) result = await session.execute(stmt) records = result.scalars().all() @@ -199,13 +199,13 @@ async def test_metadata_record_updates(dataset_client: SQLDatasetClient) -> None assert metadata.accessed_at > accessed_after_get # Verify metadata record is updated in db - async with dataset_client.create_session() as session: + async with dataset_client.get_session() as session: orm_metadata = await session.get(DatasetMetadataDB, metadata.id) assert orm_metadata is not None orm_metadata.item_count = 1 - assert orm_metadata.created_at.replace(tzinfo=timezone.utc) == initial_created - assert orm_metadata.accessed_at.replace(tzinfo=timezone.utc) == metadata.accessed_at - assert orm_metadata.modified_at.replace(tzinfo=timezone.utc) == metadata.modified_at + assert orm_metadata.created_at == initial_created + assert orm_metadata.accessed_at == metadata.accessed_at + assert orm_metadata.modified_at == metadata.modified_at async def test_data_persistence_across_reopens(configuration: Configuration) -> None: From 61a26660250f6b2cdefa836cbe279e997a583a11 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Thu, 31 Jul 2025 01:08:16 +0000 Subject: [PATCH 09/29] kvs optimization --- .../storage_clients/_sql/_dataset_client.py | 13 +- .../storage_clients/_sql/_db_models.py | 2 +- .../_sql/_key_value_store_client.py | 217 ++++++++++-------- .../_sql/test_sql_kvs_client.py | 25 +- 4 files changed, 141 insertions(+), 116 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py index 3e4fb7e0bb..648828195e 100644 --- a/src/crawlee/storage_clients/_sql/_dataset_client.py +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -50,6 +50,10 @@ def __init__( self._metadata = metadata self._storage_client = storage_client + @override + async def get_metadata(self) -> DatasetMetadata: + return self._metadata + def get_session(self) -> AsyncSession: """Create a new SQLAlchemy session for this dataset.""" return self._storage_client.create_session() @@ -65,10 +69,6 @@ async def get_autocommit_session(self) -> 
AsyncIterator[AsyncSession]: logger.warning(f'Error occurred during session transaction: {e}') await session.rollback() - @override - async def get_metadata(self) -> DatasetMetadata: - return self._metadata - @classmethod async def open( cls, @@ -124,6 +124,7 @@ async def open( ) session.add(DatasetMetadataDB(**metadata.model_dump())) + # Commit the insert or update metadata to the database await session.commit() return client @@ -212,6 +213,8 @@ async def get_data( db_items = result.scalars().all() await self._update_metadata(session, update_accessed_at=True) + + # Commit updates to the metadata await session.commit() items = [json.loads(db_item.data) for db_item in db_items] @@ -268,6 +271,8 @@ async def iterate_items( db_items = result.scalars().all() await self._update_metadata(session, update_accessed_at=True) + + # Commit updates to the metadata await session.commit() items = [json.loads(db_item.data) for db_item in db_items] diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py index 791d627025..a12633fcdc 100644 --- a/src/crawlee/storage_clients/_sql/_db_models.py +++ b/src/crawlee/storage_clients/_sql/_db_models.py @@ -36,7 +36,7 @@ class StorageMetadataDB: """Base database model for storage metadata.""" id: Mapped[str] = mapped_column(String(20), nullable=False, primary_key=True) - name: Mapped[str | None] = mapped_column(String(100), nullable=True, index=True) + name: Mapped[str | None] = mapped_column(String(100), nullable=True, index=True, unique=True) accessed_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) created_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) modified_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) diff --git a/src/crawlee/storage_clients/_sql/_key_value_store_client.py b/src/crawlee/storage_clients/_sql/_key_value_store_client.py index 30d57ca444..d29e24427a 100644 --- a/src/crawlee/storage_clients/_sql/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_sql/_key_value_store_client.py @@ -1,11 +1,12 @@ from __future__ import annotations import json +from contextlib import asynccontextmanager from datetime import datetime, timezone from logging import getLogger from typing import TYPE_CHECKING, Any -from sqlalchemy import delete, select +from sqlalchemy import delete, select, update from typing_extensions import override from crawlee._utils.crypto import crypto_random_object_id @@ -51,24 +52,36 @@ def __init__( self, *, storage_client: SQLStorageClient, - orm_metadata: KeyValueStoreMetadataDB, + metadata: KeyValueStoreMetadata, ) -> None: """Initialize a new instance. Preferably use the `SQLKeyValueStoreClient.open` class method to create a new instance. 
""" - self._orm_metadata = orm_metadata + self._metadata = metadata self._storage_client = storage_client """The storage client used to access the SQL database.""" - def create_session(self) -> AsyncSession: + @override + async def get_metadata(self) -> KeyValueStoreMetadata: + return self._metadata + + def get_session(self) -> AsyncSession: """Create a new SQLAlchemy session for this key-value store.""" return self._storage_client.create_session() - @override - async def get_metadata(self) -> KeyValueStoreMetadata: - return KeyValueStoreMetadata.model_validate(self._orm_metadata) + @asynccontextmanager + async def get_autocommit_session(self) -> AsyncIterator[AsyncSession]: + """Create a new SQLAlchemy autocommit session to insert, delete, or modify data.""" + async with self.get_session() as session: + try: + yield session + await session.commit() + except Exception as e: + logger.warning(f'Error occurred during session transaction: {e}') + # Rollback the session in case of an error + await session.rollback() @classmethod async def open( @@ -96,60 +109,55 @@ async def open( ValueError: If a store with the specified ID is not found, or if metadata is invalid. """ async with storage_client.create_session() as session: + orm_metadata: KeyValueStoreMetadataDB | None = None if id: orm_metadata = await session.get(KeyValueStoreMetadataDB, id) if not orm_metadata: raise ValueError(f'Key-value store with ID "{id}" not found.') + else: + stmt = select(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.name == name) + result = await session.execute(stmt) + orm_metadata = result.scalar_one_or_none() + if orm_metadata: client = cls( - orm_metadata=orm_metadata, + metadata=KeyValueStoreMetadata.model_validate(orm_metadata), storage_client=storage_client, ) - client._update_metadata(update_accessed_at=True) - + await client._update_metadata(session, update_accessed_at=True) else: - orm_metadata = await session.get(KeyValueStoreMetadataDB, name) - if orm_metadata: - client = cls( - orm_metadata=orm_metadata, - storage_client=storage_client, - ) - client._update_metadata(update_accessed_at=True) - else: - now = datetime.now(timezone.utc) - metadata = KeyValueStoreMetadata( - id=crypto_random_object_id(), - name=name, - created_at=now, - accessed_at=now, - modified_at=now, - ) - orm_metadata = KeyValueStoreMetadataDB(**metadata.model_dump()) - client = cls( - orm_metadata=orm_metadata, - storage_client=storage_client, - ) - session.add(orm_metadata) + now = datetime.now(timezone.utc) + metadata = KeyValueStoreMetadata( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + ) + client = cls( + metadata=metadata, + storage_client=storage_client, + ) + orm_metadata = KeyValueStoreMetadataDB(**metadata.model_dump()) + session.add(orm_metadata) + # Commit the insert or update metadata to the database await session.commit() return client @override async def drop(self) -> None: - async with self._storage_client.create_session() as session: - kvs_db = await session.get(KeyValueStoreMetadataDB, self._orm_metadata.id) - if kvs_db: - await session.delete(kvs_db) - await session.commit() + stmt = delete(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.id == self._metadata.id) + async with self.get_autocommit_session() as autosession: + await autosession.execute(stmt) @override async def purge(self) -> None: - async with self._storage_client.create_session() as session: - stmt = delete(KeyValueStoreRecordDB).filter_by(kvs_id=self._orm_metadata.id) - await 
session.execute(stmt) + stmt = delete(KeyValueStoreRecordDB).filter_by(kvs_id=self._metadata.id) + async with self.get_autocommit_session() as autosession: + await autosession.execute(stmt) - self._update_metadata(update_accessed_at=True, update_modified_at=True) - await session.commit() + await self._update_metadata(autosession, update_accessed_at=True, update_modified_at=True) @override async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: @@ -173,39 +181,41 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No size = len(value_bytes) record_db = KeyValueStoreRecordDB( - kvs_id=self._orm_metadata.id, + kvs_id=self._metadata.id, key=key, value=value_bytes, content_type=content_type, size=size, ) - async with self._storage_client.create_session() as session: - existing_record = await session.get(KeyValueStoreRecordDB, (self._orm_metadata.id, key)) - if existing_record: - # Update existing record - existing_record.value = value_bytes - existing_record.content_type = content_type - existing_record.size = size - else: - session.add(record_db) - self._update_metadata(update_accessed_at=True, update_modified_at=True) - await session.merge(self._orm_metadata) - await session.commit() + stmt = ( + update(KeyValueStoreRecordDB) + .where(KeyValueStoreRecordDB.kvs_id == self._metadata.id, KeyValueStoreRecordDB.key == key) + .values(value=value_bytes, content_type=content_type, size=size) + ) + + # A race condition is possible if several clients work with one kvs. + # Unfortunately, there is no implementation of atomic Upsert that is independent of specific dialects. + # https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-upsert-statements + async with self.get_autocommit_session() as autocommit: + result = await autocommit.execute(stmt) + if result.rowcount == 0: + autocommit.add(record_db) + + await self._update_metadata(autocommit, update_accessed_at=True, update_modified_at=True) @override async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: - # Update the metadata to record access - async with self._storage_client.create_session() as session: - stmt = select(KeyValueStoreRecordDB).where( - KeyValueStoreRecordDB.kvs_id == self._orm_metadata.id, KeyValueStoreRecordDB.key == key - ) + stmt = select(KeyValueStoreRecordDB).where( + KeyValueStoreRecordDB.kvs_id == self._metadata.id, KeyValueStoreRecordDB.key == key + ) + async with self.get_session() as session: result = await session.execute(stmt) record_db = result.scalar_one_or_none() - self._update_metadata(update_accessed_at=True) + await self._update_metadata(session, update_accessed_at=True) - await session.merge(self._orm_metadata) + # Commit updates to the metadata await session.commit() if not record_db: @@ -244,19 +254,18 @@ async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: @override async def delete_value(self, *, key: str) -> None: - async with self._storage_client.create_session() as session: + stmt = delete(KeyValueStoreRecordDB).where( + KeyValueStoreRecordDB.kvs_id == self._metadata.id, KeyValueStoreRecordDB.key == key + ) + async with self.get_autocommit_session() as autocommit: # Delete the record if it exists - stmt = delete(KeyValueStoreRecordDB).where( - KeyValueStoreRecordDB.kvs_id == self._orm_metadata.id, KeyValueStoreRecordDB.key == key - ) - result = await session.execute(stmt) + result = await autocommit.execute(stmt) # Update metadata if we actually deleted something if result.rowcount > 0: - 
self._update_metadata(update_accessed_at=True, update_modified_at=True) - await session.merge(self._orm_metadata) + await self._update_metadata(autocommit, update_accessed_at=True, update_modified_at=True) - await session.commit() + await autocommit.commit() @override async def iterate_keys( @@ -265,46 +274,46 @@ async def iterate_keys( exclusive_start_key: str | None = None, limit: int | None = None, ) -> AsyncIterator[KeyValueStoreRecordMetadata]: - async with self._storage_client.create_session() as session: - # Build query for record metadata - stmt = ( - select(KeyValueStoreRecordDB.key, KeyValueStoreRecordDB.content_type, KeyValueStoreRecordDB.size) - .where(KeyValueStoreRecordDB.kvs_id == self._orm_metadata.id) - .order_by(KeyValueStoreRecordDB.key) - ) + # Build query for record metadata + stmt = ( + select(KeyValueStoreRecordDB.key, KeyValueStoreRecordDB.content_type, KeyValueStoreRecordDB.size) + .where(KeyValueStoreRecordDB.kvs_id == self._metadata.id) + .order_by(KeyValueStoreRecordDB.key) + ) - # Apply exclusive_start_key filter - if exclusive_start_key is not None: - stmt = stmt.where(KeyValueStoreRecordDB.key > exclusive_start_key) + # Apply exclusive_start_key filter + if exclusive_start_key is not None: + stmt = stmt.where(KeyValueStoreRecordDB.key > exclusive_start_key) - # Apply limit - if limit is not None: - stmt = stmt.limit(limit) + # Apply limit + if limit is not None: + stmt = stmt.limit(limit) - result = await session.execute(stmt) + async with self.get_session() as session: + result = await session.stream(stmt.execution_options(stream_results=True)) - self._update_metadata(update_accessed_at=True) - await session.merge(self._orm_metadata) - await session.commit() - - for row in result: + async for row in result: yield KeyValueStoreRecordMetadata( key=row.key, content_type=row.content_type, size=row.size, ) + await self._update_metadata(session, update_accessed_at=True) + + # Commit updates to the metadata + await session.commit() + @override async def record_exists(self, *, key: str) -> bool: - async with self._storage_client.create_session() as session: + stmt = select(KeyValueStoreRecordDB.key).where( + KeyValueStoreRecordDB.kvs_id == self._metadata.id, KeyValueStoreRecordDB.key == key + ) + async with self.get_session() as session: # Check if record exists - stmt = select(KeyValueStoreRecordDB.key).where( - KeyValueStoreRecordDB.kvs_id == self._orm_metadata.id, KeyValueStoreRecordDB.key == key - ) result = await session.execute(stmt) - self._update_metadata(update_accessed_at=True) - await session.merge(self._orm_metadata) + await self._update_metadata(session, update_accessed_at=True) await session.commit() return result.scalar_one_or_none() is not None @@ -313,8 +322,9 @@ async def record_exists(self, *, key: str) -> bool: async def get_public_url(self, *, key: str) -> str: raise NotImplementedError('Public URLs are not supported for memory key-value stores.') - def _update_metadata( + async def _update_metadata( self, + session: AsyncSession, *, update_accessed_at: bool = False, update_modified_at: bool = False, @@ -327,8 +337,19 @@ def _update_metadata( update_modified_at: If True, update the `modified_at` timestamp to the current time. 
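As the comment in `set_value` notes, the update-then-insert pair is kept dialect-independent, so a narrow race remains when several clients write the same key. If the backend is known to be SQLite, a dialect-specific upsert could close that window; a minimal sketch, assuming the `KeyValueStoreRecordDB` model from `_db_models.py` and an open `AsyncSession` (an alternative for comparison, not what this patch does):

    from sqlalchemy.dialects.sqlite import insert as sqlite_insert
    from sqlalchemy.ext.asyncio import AsyncSession

    from crawlee.storage_clients._sql._db_models import KeyValueStoreRecordDB


    async def upsert_record(
        session: AsyncSession, kvs_id: str, key: str, value: bytes, content_type: str, size: int
    ) -> None:
        # Single-statement upsert on the (kvs_id, key) primary key; SQLite dialect only.
        stmt = sqlite_insert(KeyValueStoreRecordDB).values(
            kvs_id=kvs_id, key=key, value=value, content_type=content_type, size=size
        )
        stmt = stmt.on_conflict_do_update(
            index_elements=[KeyValueStoreRecordDB.kvs_id, KeyValueStoreRecordDB.key],
            set_={'value': value, 'content_type': content_type, 'size': size},
        )
        await session.execute(stmt)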
""" now = datetime.now(timezone.utc) + values_to_set: dict[str, Any] = {} if update_accessed_at: - self._orm_metadata.accessed_at = now + self._metadata.accessed_at = now + values_to_set['accessed_at'] = now if update_modified_at: - self._orm_metadata.modified_at = now + self._metadata.modified_at = now + values_to_set['modified_at'] = now + + if values_to_set: + stmt = ( + update(KeyValueStoreMetadataDB) + .where(KeyValueStoreMetadataDB.id == self._metadata.id) + .values(**values_to_set) + ) + await session.execute(stmt) diff --git a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py index 378322cc56..8fc4342362 100644 --- a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py @@ -2,7 +2,6 @@ import asyncio import json -from datetime import timezone from typing import TYPE_CHECKING import pytest @@ -98,7 +97,7 @@ async def test_tables_and_metadata_record(configuration: Configuration) -> None: assert 'kvs_metadata' in tables assert 'kvs_record' in tables - async with client.create_session() as session: + async with client.get_session() as session: stmt = select(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.name == 'new_kvs') result = await session.execute(stmt) orm_metadata = result.scalar_one_or_none() @@ -114,7 +113,7 @@ async def test_value_record_creation(kvs_client: SQLKeyValueStoreClient) -> None test_key = 'test-key' test_value = 'Hello, world!' await kvs_client.set_value(key=test_key, value=test_value) - async with kvs_client.create_session() as session: + async with kvs_client.get_session() as session: stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == test_key) result = await session.execute(stmt) record = result.scalar_one_or_none() @@ -131,7 +130,7 @@ async def test_binary_data_persistence(kvs_client: SQLKeyValueStoreClient) -> No test_value = b'\x00\x01\x02\x03\x04' await kvs_client.set_value(key=test_key, value=test_value) - async with kvs_client.create_session() as session: + async with kvs_client.get_session() as session: stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == test_key) result = await session.execute(stmt) record = result.scalar_one_or_none() @@ -153,7 +152,7 @@ async def test_json_serialization_to_record(kvs_client: SQLKeyValueStoreClient) test_value = {'name': 'John', 'age': 30, 'items': [1, 2, 3]} await kvs_client.set_value(key=test_key, value=test_value) - async with kvs_client.create_session() as session: + async with kvs_client.get_session() as session: stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == test_key) result = await session.execute(stmt) record = result.scalar_one_or_none() @@ -170,7 +169,7 @@ async def test_record_deletion_on_value_delete(kvs_client: SQLKeyValueStoreClien # Set a value await kvs_client.set_value(key=test_key, value=test_value) - async with kvs_client.create_session() as session: + async with kvs_client.get_session() as session: stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == test_key) result = await session.execute(stmt) record = result.scalar_one_or_none() @@ -182,7 +181,7 @@ async def test_record_deletion_on_value_delete(kvs_client: SQLKeyValueStoreClien await kvs_client.delete_value(key=test_key) # Verify record was deleted - async with kvs_client.create_session() as session: + async with kvs_client.get_session() as session: stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == test_key) 
result = await session.execute(stmt) record = result.scalar_one_or_none() @@ -195,7 +194,7 @@ async def test_drop_removes_records(kvs_client: SQLKeyValueStoreClient) -> None: client_metadata = await kvs_client.get_metadata() - async with kvs_client.create_session() as session: + async with kvs_client.get_session() as session: stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == 'test') result = await session.execute(stmt) record = result.scalar_one_or_none() @@ -204,7 +203,7 @@ async def test_drop_removes_records(kvs_client: SQLKeyValueStoreClient) -> None: # Drop the store await kvs_client.drop() - async with kvs_client.create_session() as session: + async with kvs_client.get_session() as session: stmt = select(KeyValueStoreRecordDB).where(KeyValueStoreRecordDB.key == 'test') result = await session.execute(stmt) record = result.scalar_one_or_none() @@ -247,12 +246,12 @@ async def test_metadata_record_updates(kvs_client: SQLKeyValueStoreClient) -> No assert metadata.modified_at > initial_modified assert metadata.accessed_at > accessed_after_read - async with kvs_client.create_session() as session: + async with kvs_client.get_session() as session: orm_metadata = await session.get(KeyValueStoreMetadataDB, metadata.id) assert orm_metadata is not None - assert orm_metadata.created_at.replace(tzinfo=timezone.utc) == metadata.created_at - assert orm_metadata.accessed_at.replace(tzinfo=timezone.utc) == metadata.accessed_at - assert orm_metadata.modified_at.replace(tzinfo=timezone.utc) == metadata.modified_at + assert orm_metadata.created_at == metadata.created_at + assert orm_metadata.accessed_at == metadata.accessed_at + assert orm_metadata.modified_at == metadata.modified_at async def test_data_persistence_across_reopens(configuration: Configuration) -> None: From 7055f7d36a0baabaa506f4b1ccd271b1a2287e9b Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Fri, 1 Aug 2025 03:18:16 +0000 Subject: [PATCH 10/29] optimization --- .../storage_clients/_sql/_dataset_client.py | 73 ++- .../storage_clients/_sql/_db_models.py | 38 +- .../_sql/_key_value_store_client.py | 76 +-- .../_sql/_request_queue_client.py | 442 +++++++++++------- .../storage_clients/_sql/_storage_client.py | 20 +- .../_sql/test_sql_rq_client.py | 10 +- 6 files changed, 408 insertions(+), 251 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py index 648828195e..dc829c2ab1 100644 --- a/src/crawlee/storage_clients/_sql/_dataset_client.py +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -6,7 +6,8 @@ from logging import getLogger from typing import TYPE_CHECKING, Any -from sqlalchemy import delete, select, update +from sqlalchemy import delete, select, text, update +from sqlalchemy.exc import SQLAlchemyError from typing_extensions import override from crawlee._utils.crypto import crypto_random_object_id @@ -37,22 +38,30 @@ class SQLDatasetClient(DatasetClient): locking mechanisms. """ + _DEFAULT_NAME_DB = 'default' + """Default dataset name used when no name is provided.""" + def __init__( self, *, - metadata: DatasetMetadata, + id: str, storage_client: SQLStorageClient, ) -> None: """Initialize a new instance. Preferably use the `SqlDatasetClient.open` class method to create a new instance. 
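For orientation, the intended flow for these clients is: initialize the `SQLStorageClient`, open a per-storage client through its `open()` class method, and work through that client. A sketch using only calls that appear in this patch; the dataset name and pushed items are illustrative, and the default `SQLStorageClient()` and `Configuration()` constructors are assumed:

    import asyncio

    from crawlee.configuration import Configuration
    from crawlee.storage_clients import SQLStorageClient
    from crawlee.storage_clients._sql import SQLDatasetClient


    async def main() -> None:
        storage_client = SQLStorageClient()
        await storage_client.initialize(Configuration())

        # Opens the dataset named 'example', creating it on first use.
        dataset = await SQLDatasetClient.open(id=None, name='example', storage_client=storage_client)
        await dataset.push_data([{'url': 'https://example.com', 'status': 200}])

        page = await dataset.get_data()
        print(page.count, page.items)

        await storage_client.close()


    asyncio.run(main())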
""" - self._metadata = metadata + self._id = id self._storage_client = storage_client @override async def get_metadata(self) -> DatasetMetadata: - return self._metadata + async with self.get_session() as session: + orm_metadata: DatasetMetadataDB | None = await session.get(DatasetMetadataDB, self._id) + if not orm_metadata: + raise ValueError(f'Dataset with ID "{self._id}" not found.') + + return DatasetMetadata.model_validate(orm_metadata) def get_session(self) -> AsyncSession: """Create a new SQLAlchemy session for this dataset.""" @@ -65,8 +74,9 @@ async def get_autocommit_session(self) -> AsyncIterator[AsyncSession]: try: yield session await session.commit() - except Exception as e: + except SQLAlchemyError as e: logger.warning(f'Error occurred during session transaction: {e}') + # Rollback the session in case of an error await session.rollback() @classmethod @@ -97,13 +107,14 @@ async def open( if not orm_metadata: raise ValueError(f'Dataset with ID "{id}" not found.') else: - stmt = select(DatasetMetadataDB).where(DatasetMetadataDB.name == name) + search_name = name or cls._DEFAULT_NAME_DB + stmt = select(DatasetMetadataDB).where(DatasetMetadataDB.name == search_name) result = await session.execute(stmt) orm_metadata = result.scalar_one_or_none() if orm_metadata: client = cls( - metadata=DatasetMetadata.model_validate(orm_metadata), + id=orm_metadata.id, storage_client=storage_client, ) await client._update_metadata(session, update_accessed_at=True) @@ -119,25 +130,43 @@ async def open( ) client = cls( - metadata=metadata, + id=metadata.id, storage_client=storage_client, ) session.add(DatasetMetadataDB(**metadata.model_dump())) - # Commit the insert or update metadata to the database - await session.commit() + try: + # Commit the insert or update metadata to the database + await session.commit() + except SQLAlchemyError: + # Attempt to open simultaneously by different clients. + # The commit that created the record has already been executed, make rollback and get by name. + await session.rollback() + search_name = name or cls._DEFAULT_NAME_DB + stmt = select(DatasetMetadataDB).where(DatasetMetadataDB.name == search_name) + result = await session.execute(stmt) + orm_metadata = result.scalar_one_or_none() + if not orm_metadata: + raise ValueError(f'Dataset with Name "{search_name}" not found.') from None + client = cls( + id=orm_metadata.id, + storage_client=storage_client, + ) return client @override async def drop(self) -> None: - stmt = delete(DatasetMetadataDB).where(DatasetMetadataDB.id == self._metadata.id) + stmt = delete(DatasetMetadataDB).where(DatasetMetadataDB.id == self._id) async with self.get_autocommit_session() as autocommit: + if self._storage_client.get_default_flag(): + # foreign_keys=ON is set at the connection level. Required for cascade deletion. + await autocommit.execute(text('PRAGMA foreign_keys=ON')) await autocommit.execute(stmt) @override async def purge(self) -> None: - stmt = delete(DatasetItemDB).where(DatasetItemDB.dataset_id == self._metadata.id) + stmt = delete(DatasetItemDB).where(DatasetItemDB.dataset_id == self._id) async with self.get_autocommit_session() as autocommit: await autocommit.execute(stmt) @@ -154,7 +183,7 @@ async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: json_item = json.dumps(item, default=str, ensure_ascii=False) db_items.append( DatasetItemDB( - dataset_id=self._metadata.id, + dataset_id=self._id, data=json_item, ) ) @@ -199,7 +228,7 @@ async def get_data( f'{self.__class__.__name__} client.' 
) - stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._metadata.id) + stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._id) if skip_empty: stmt = stmt.where(DatasetItemDB.data != '"{}"') @@ -218,13 +247,14 @@ async def get_data( await session.commit() items = [json.loads(db_item.data) for db_item in db_items] + metadata = await self.get_metadata() return DatasetItemsListPage( items=items, count=len(items), desc=desc, limit=limit or 0, offset=offset or 0, - total=self._metadata.item_count, + total=metadata.item_count, ) @override @@ -257,7 +287,7 @@ async def iterate_items( f'by the {self.__class__.__name__} client.' ) - stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._metadata.id) + stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._id) if skip_empty: stmt = stmt.where(DatasetItemDB.data != '"{}"') @@ -301,20 +331,15 @@ async def _update_metadata( values_to_set: dict[str, Any] = {} if update_accessed_at: - self._metadata.accessed_at = now values_to_set['accessed_at'] = now if update_modified_at: - self._metadata.modified_at = now values_to_set['modified_at'] = now if new_item_count is not None: - self._metadata.item_count = new_item_count values_to_set['item_count'] = new_item_count - - if delta_item_count: - self._metadata.item_count += delta_item_count - values_to_set['item_count'] = DatasetMetadataDB.item_count + self._metadata.item_count + elif delta_item_count: + values_to_set['item_count'] = DatasetMetadataDB.item_count + delta_item_count if values_to_set: - stmt = update(DatasetMetadataDB).where(DatasetMetadataDB.id == self._metadata.id).values(**values_to_set) + stmt = update(DatasetMetadataDB).where(DatasetMetadataDB.id == self._id).values(**values_to_set) await session.execute(stmt) diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py index a12633fcdc..8f6d631a8e 100644 --- a/src/crawlee/storage_clients/_sql/_db_models.py +++ b/src/crawlee/storage_clients/_sql/_db_models.py @@ -3,14 +3,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING -from sqlalchemy import ( - JSON, - Boolean, - ForeignKey, - Integer, - LargeBinary, - String, -) +from sqlalchemy import JSON, Boolean, ForeignKey, Index, Integer, LargeBinary, String from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.types import DateTime, TypeDecorator @@ -18,6 +11,17 @@ from sqlalchemy.engine import Dialect +class NameDefaultNone(TypeDecorator): + impl = String(100) + cache_ok = True + + def process_bind_param(self, value: str | None, _dialect: Dialect) -> str | None: + return 'default' if value is None else value + + def process_result_value(self, value: str | None, _dialect: Dialect) -> str | None: + return None if value == 'default' else value + + class AwareDateTime(TypeDecorator): impl = DateTime(timezone=True) cache_ok = True @@ -36,7 +40,7 @@ class StorageMetadataDB: """Base database model for storage metadata.""" id: Mapped[str] = mapped_column(String(20), nullable=False, primary_key=True) - name: Mapped[str | None] = mapped_column(String(100), nullable=True, index=True, unique=True) + name: Mapped[str | None] = mapped_column(NameDefaultNone, nullable=False, index=True, unique=True) accessed_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) created_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) modified_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) @@ -96,7 
+100,9 @@ class DatasetItemDB(Base): order_id: Mapped[int] = mapped_column(Integer, primary_key=True) dataset_id: Mapped[str] = mapped_column( - String(20), ForeignKey('dataset_metadata.id', ondelete='CASCADE'), index=True + String(20), + ForeignKey('dataset_metadata.id', ondelete='CASCADE'), + index=True, ) data: Mapped[str] = mapped_column(JSON, nullable=False) @@ -107,15 +113,19 @@ class RequestDB(Base): """Database model for requests in the request queue.""" __tablename__ = 'request' + __table_args__ = ( + Index('idx_queue_handled_seq', 'queue_id', 'is_handled', 'sequence_number'), + Index('idx_queue_unique_key', 'queue_id', 'unique_key'), + ) request_id: Mapped[str] = mapped_column(String(20), primary_key=True) queue_id: Mapped[str] = mapped_column( - String(20), ForeignKey('request_queue_metadata.id', ondelete='CASCADE'), index=True, primary_key=True + String(20), ForeignKey('request_queue_metadata.id', ondelete='CASCADE'), primary_key=True ) data: Mapped[str] = mapped_column(JSON, nullable=False) - unique_key: Mapped[str] = mapped_column(String(512), nullable=False, index=True) - sequence_number: Mapped[int] = mapped_column(Integer, nullable=False, index=True) - is_handled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, index=True) + unique_key: Mapped[str] = mapped_column(String(512), nullable=False) + sequence_number: Mapped[int] = mapped_column(Integer, nullable=False) + is_handled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) queue: Mapped[RequestQueueMetadataDB] = relationship(back_populates='requests') diff --git a/src/crawlee/storage_clients/_sql/_key_value_store_client.py b/src/crawlee/storage_clients/_sql/_key_value_store_client.py index d29e24427a..97e0f9fa96 100644 --- a/src/crawlee/storage_clients/_sql/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_sql/_key_value_store_client.py @@ -6,7 +6,8 @@ from logging import getLogger from typing import TYPE_CHECKING, Any -from sqlalchemy import delete, select, update +from sqlalchemy import delete, select, text, update +from sqlalchemy.exc import SQLAlchemyError from typing_extensions import override from crawlee._utils.crypto import crypto_random_object_id @@ -48,24 +49,32 @@ class SQLKeyValueStoreClient(KeyValueStoreClient): for development environments where you want to easily inspect the stored data between runs. """ + _DEFAULT_NAME_DB = 'default' + """Default dataset name used when no name is provided.""" + def __init__( self, *, storage_client: SQLStorageClient, - metadata: KeyValueStoreMetadata, + id: str, ) -> None: """Initialize a new instance. Preferably use the `SQLKeyValueStoreClient.open` class method to create a new instance. 
""" - self._metadata = metadata + self._id = id self._storage_client = storage_client """The storage client used to access the SQL database.""" @override async def get_metadata(self) -> KeyValueStoreMetadata: - return self._metadata + async with self.get_session() as session: + orm_metadata: KeyValueStoreMetadataDB | None = await session.get(KeyValueStoreMetadataDB, self._id) + if not orm_metadata: + raise ValueError(f'Key-value store with ID "{self._id}" not found.') + + return KeyValueStoreMetadata.model_validate(orm_metadata) def get_session(self) -> AsyncSession: """Create a new SQLAlchemy session for this key-value store.""" @@ -78,7 +87,7 @@ async def get_autocommit_session(self) -> AsyncIterator[AsyncSession]: try: yield session await session.commit() - except Exception as e: + except SQLAlchemyError as e: logger.warning(f'Error occurred during session transaction: {e}') # Rollback the session in case of an error await session.rollback() @@ -115,12 +124,13 @@ async def open( if not orm_metadata: raise ValueError(f'Key-value store with ID "{id}" not found.') else: - stmt = select(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.name == name) + search_name = name or cls._DEFAULT_NAME_DB + stmt = select(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.name == search_name) result = await session.execute(stmt) orm_metadata = result.scalar_one_or_none() if orm_metadata: client = cls( - metadata=KeyValueStoreMetadata.model_validate(orm_metadata), + id=orm_metadata.id, storage_client=storage_client, ) await client._update_metadata(session, update_accessed_at=True) @@ -134,26 +144,42 @@ async def open( modified_at=now, ) client = cls( - metadata=metadata, + id=metadata.id, storage_client=storage_client, ) - orm_metadata = KeyValueStoreMetadataDB(**metadata.model_dump()) - session.add(orm_metadata) - - # Commit the insert or update metadata to the database - await session.commit() + session.add(KeyValueStoreMetadataDB(**metadata.model_dump())) + try: + # Commit the insert or update metadata to the database + await session.commit() + except SQLAlchemyError: + # Attempt to open simultaneously by different clients. + # The commit that created the record has already been executed, make rollback and get by name. + await session.rollback() + search_name = name or cls._DEFAULT_NAME_DB + stmt = select(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.name == search_name) + result = await session.execute(stmt) + orm_metadata = result.scalar_one_or_none() + if not orm_metadata: + raise ValueError(f'Key-value store with Name "{search_name}" not found.') from None + client = cls( + id=orm_metadata.id, + storage_client=storage_client, + ) return client @override async def drop(self) -> None: - stmt = delete(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.id == self._metadata.id) + stmt = delete(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.id == self._id) async with self.get_autocommit_session() as autosession: + if self._storage_client.get_default_flag(): + # foreign_keys=ON is set at the connection level. Required for cascade deletion. 
+ await autosession.execute(text('PRAGMA foreign_keys=ON')) await autosession.execute(stmt) @override async def purge(self) -> None: - stmt = delete(KeyValueStoreRecordDB).filter_by(kvs_id=self._metadata.id) + stmt = delete(KeyValueStoreRecordDB).filter_by(kvs_id=self._id) async with self.get_autocommit_session() as autosession: await autosession.execute(stmt) @@ -181,7 +207,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No size = len(value_bytes) record_db = KeyValueStoreRecordDB( - kvs_id=self._metadata.id, + kvs_id=self._id, key=key, value=value_bytes, content_type=content_type, @@ -190,7 +216,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No stmt = ( update(KeyValueStoreRecordDB) - .where(KeyValueStoreRecordDB.kvs_id == self._metadata.id, KeyValueStoreRecordDB.key == key) + .where(KeyValueStoreRecordDB.kvs_id == self._id, KeyValueStoreRecordDB.key == key) .values(value=value_bytes, content_type=content_type, size=size) ) @@ -207,7 +233,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No @override async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: stmt = select(KeyValueStoreRecordDB).where( - KeyValueStoreRecordDB.kvs_id == self._metadata.id, KeyValueStoreRecordDB.key == key + KeyValueStoreRecordDB.kvs_id == self._id, KeyValueStoreRecordDB.key == key ) async with self.get_session() as session: result = await session.execute(stmt) @@ -255,7 +281,7 @@ async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: @override async def delete_value(self, *, key: str) -> None: stmt = delete(KeyValueStoreRecordDB).where( - KeyValueStoreRecordDB.kvs_id == self._metadata.id, KeyValueStoreRecordDB.key == key + KeyValueStoreRecordDB.kvs_id == self._id, KeyValueStoreRecordDB.key == key ) async with self.get_autocommit_session() as autocommit: # Delete the record if it exists @@ -277,7 +303,7 @@ async def iterate_keys( # Build query for record metadata stmt = ( select(KeyValueStoreRecordDB.key, KeyValueStoreRecordDB.content_type, KeyValueStoreRecordDB.size) - .where(KeyValueStoreRecordDB.kvs_id == self._metadata.id) + .where(KeyValueStoreRecordDB.kvs_id == self._id) .order_by(KeyValueStoreRecordDB.key) ) @@ -307,7 +333,7 @@ async def iterate_keys( @override async def record_exists(self, *, key: str) -> bool: stmt = select(KeyValueStoreRecordDB.key).where( - KeyValueStoreRecordDB.kvs_id == self._metadata.id, KeyValueStoreRecordDB.key == key + KeyValueStoreRecordDB.kvs_id == self._id, KeyValueStoreRecordDB.key == key ) async with self.get_session() as session: # Check if record exists @@ -340,16 +366,10 @@ async def _update_metadata( values_to_set: dict[str, Any] = {} if update_accessed_at: - self._metadata.accessed_at = now values_to_set['accessed_at'] = now if update_modified_at: - self._metadata.modified_at = now values_to_set['modified_at'] = now if values_to_set: - stmt = ( - update(KeyValueStoreMetadataDB) - .where(KeyValueStoreMetadataDB.id == self._metadata.id) - .values(**values_to_set) - ) + stmt = update(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.id == self._id).values(**values_to_set) await session.execute(stmt) diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index 67b21b8a62..8139cecd2a 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -2,13 +2,15 @@ import asyncio from 
collections import deque +from contextlib import asynccontextmanager from datetime import datetime, timezone from logging import getLogger -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from pydantic import BaseModel -from sqlalchemy import delete, func, select, update +from sqlalchemy import delete, select, text, update from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import load_only from typing_extensions import override from crawlee import Request @@ -25,7 +27,7 @@ from ._db_models import RequestDB, RequestQueueMetadataDB if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import AsyncIterator, Sequence from sqlalchemy.ext.asyncio import AsyncSession @@ -44,9 +46,6 @@ class RequestQueueState(BaseModel): forefront_sequence_counter: int = -1 """Counter for forefront request ordering (negative).""" - in_progress_requests: set[str] = set() - """Set of request IDs currently being processed.""" - class SQLRequestQueueClient(RequestQueueClient): """SQL implementation of the request queue client. @@ -60,24 +59,30 @@ class SQLRequestQueueClient(RequestQueueClient): ordering. A cache mechanism reduces database queries for better performance. """ + _DEFAULT_NAME_DB = 'default' + """Default dataset name used when no name is provided.""" + _MAX_REQUESTS_IN_CACHE = 100_000 """Maximum number of requests to keep in cache for faster access.""" def __init__( self, *, - orm_metadata: RequestQueueMetadataDB, + id: str, storage_client: SQLStorageClient, ) -> None: """Initialize a new instance. Preferably use the `SQLRequestQueueClient.open` class method to create a new instance. """ - self._orm_metadata = orm_metadata + self._id = id self._request_cache: deque[Request] = deque() """Cache for requests: ordered by sequence number.""" + self.in_progress_requests: set[str] = set() + """Set of request IDs currently being processed.""" + self._request_cache_needs_refresh = True """Flag indicating whether the cache needs to be refreshed from database.""" @@ -88,7 +93,7 @@ def __init__( default_state=RequestQueueState(), persist_state_key='request_queue_state', persistence_enabled=True, - persist_state_kvs_name=f'__RQ_STATE_{self._orm_metadata.id}', + persist_state_kvs_name=f'__RQ_STATE_{self._id}', logger=logger, ) """Recoverable state to maintain request ordering and in-progress status.""" @@ -98,13 +103,30 @@ def __init__( self._lock = asyncio.Lock() - def create_session(self) -> AsyncSession: + @override + async def get_metadata(self) -> RequestQueueMetadata: + async with self.get_session() as session: + orm_metadata: RequestQueueMetadataDB | None = await session.get(RequestQueueMetadataDB, self._id) + if not orm_metadata: + raise ValueError(f'Request queue with ID "{self._id}" not found.') + + return RequestQueueMetadata.model_validate(orm_metadata) + + def get_session(self) -> AsyncSession: """Create a new SQLAlchemy session for this key-value store.""" return self._storage_client.create_session() - @override - async def get_metadata(self) -> RequestQueueMetadata: - return RequestQueueMetadata.model_validate(self._orm_metadata) + @asynccontextmanager + async def get_autocommit_session(self) -> AsyncIterator[AsyncSession]: + """Create a new SQLAlchemy autocommit session to insert, delete, or modify data.""" + async with self.get_session() as session: + try: + yield session + await session.commit() + except SQLAlchemyError as e: + logger.warning(f'Error occurred during session transaction: {e}') + # Rollback the session in case of an error + 
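The commit-or-rollback wrapper above is duplicated across the three clients; in isolation the pattern looks roughly like the sketch below. Unlike the helper in this patch, the sketch re-raises after rolling back instead of only logging a warning, so callers see the failure; the `session_maker` argument is assumed to come from `async_sessionmaker`.

    from collections.abc import AsyncIterator
    from contextlib import asynccontextmanager

    from sqlalchemy.exc import SQLAlchemyError
    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker


    @asynccontextmanager
    async def autocommit(session_maker: async_sessionmaker[AsyncSession]) -> AsyncIterator[AsyncSession]:
        async with session_maker() as session:
            try:
                yield session
                # Commit only if the wrapped block completed without raising.
                await session.commit()
            except SQLAlchemyError:
                # Undo any partial writes from the failed block, then surface the error.
                await session.rollback()
                raise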
await session.rollback() @classmethod async def open( @@ -128,47 +150,60 @@ async def open( ValueError: If a queue with the specified ID is not found. """ async with storage_client.create_session() as session: + orm_metadata: RequestQueueMetadataDB | None = None if id: orm_metadata = await session.get(RequestQueueMetadataDB, id) if not orm_metadata: raise ValueError(f'Request queue with ID "{id}" not found.') + else: + # Try to find by name + search_name = name or cls._DEFAULT_NAME_DB + stmt = select(RequestQueueMetadataDB).where(RequestQueueMetadataDB.name == search_name) + result = await session.execute(stmt) + orm_metadata = result.scalar_one_or_none() + if orm_metadata: client = cls( - orm_metadata=orm_metadata, + id=orm_metadata.id, storage_client=storage_client, ) - await client._update_metadata(update_accessed_at=True) + await client._update_metadata(session, update_accessed_at=True) else: - # Try to find by name - orm_metadata = await session.get(RequestQueueMetadataDB, name) - - if orm_metadata: - client = cls( - orm_metadata=orm_metadata, - storage_client=storage_client, - ) - await client._update_metadata(update_accessed_at=True) - else: - now = datetime.now(timezone.utc) - metadata = RequestQueueMetadata( - id=crypto_random_object_id(), - name=name, - created_at=now, - accessed_at=now, - modified_at=now, - had_multiple_clients=False, - handled_request_count=0, - pending_request_count=0, - total_request_count=0, - ) - orm_metadata = RequestQueueMetadataDB(**metadata.model_dump()) - client = cls( - orm_metadata=orm_metadata, - storage_client=storage_client, - ) + now = datetime.now(timezone.utc) + metadata = RequestQueueMetadata( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + had_multiple_clients=False, + handled_request_count=0, + pending_request_count=0, + total_request_count=0, + ) - session.add(orm_metadata) + client = cls( + id=metadata.id, + storage_client=storage_client, + ) + session.add(RequestQueueMetadataDB(**metadata.model_dump())) - await session.commit() + try: + # Commit the insert or update metadata to the database + await session.commit() + except SQLAlchemyError: + # Attempt to open simultaneously by different clients. + # The commit that created the record has already been executed, make rollback and get by name. + await session.rollback() + search_name = name or cls._DEFAULT_NAME_DB + stmt = select(RequestQueueMetadataDB).where(RequestQueueMetadataDB.name == search_name) + result = await session.execute(stmt) + orm_metadata = result.scalar_one_or_none() + if not orm_metadata: + raise ValueError(f'Request queue with Name "{search_name}" not found.') from None + client = cls( + id=orm_metadata.id, + storage_client=storage_client, + ) await client._state.initialize() @@ -176,38 +211,37 @@ async def open( @override async def drop(self) -> None: - async with self._storage_client.create_session() as session: + stmt = delete(RequestQueueMetadataDB).where(RequestQueueMetadataDB.id == self._id) + async with self.get_autocommit_session() as autocommit: + if self._storage_client.get_default_flag(): + # foreign_keys=ON is set at the connection level. Required for cascade deletion. 
+ await autocommit.execute(text('PRAGMA foreign_keys=ON')) # Delete the request queue metadata (cascade will delete requests) - rq_db = await session.get(RequestQueueMetadataDB, self._orm_metadata.id) - if rq_db: - await session.delete(rq_db) - - # Clear recoverable state - await self._state.reset() - await self._state.teardown() - self._request_cache.clear() - self._request_cache_needs_refresh = True - self._is_empty_cache = None + await autocommit.execute(stmt) - await session.commit() + # Clear recoverable state + await self._state.reset() + await self._state.teardown() + self._request_cache.clear() + self._request_cache_needs_refresh = True + self._is_empty_cache = None @override async def purge(self) -> None: - async with self._storage_client.create_session() as session: + stmt = delete(RequestDB).where(RequestDB.queue_id == self._id) + async with self.get_autocommit_session() as autocommit: # Delete all requests for this queue - stmt = delete(RequestDB).where(RequestDB.queue_id == self._orm_metadata.id) - await session.execute(stmt) - - # Update metadata - self._orm_metadata.pending_request_count = 0 - self._orm_metadata.handled_request_count = 0 - - await self._update_metadata(update_modified_at=True, update_accessed_at=True) - - self._is_empty_cache = None + await autocommit.execute(stmt) + + await self._update_metadata( + autocommit, + new_pending_request_count=0, + new_handled_request_count=0, + update_modified_at=True, + update_accessed_at=True, + ) - await session.merge(self._orm_metadata) - await session.commit() + self._is_empty_cache = None # Clear recoverable state self._request_cache.clear() @@ -221,32 +255,52 @@ async def add_batch_of_requests( *, forefront: bool = False, ) -> AddRequestsResponse: - async with self._storage_client.create_session() as session, self._lock: - self._is_empty_cache = None - processed_requests = [] - unprocessed_requests = [] - state = self._state.current_value - - # Get existing requests by unique keys - unique_keys = {req.unique_key for req in requests} - stmt = select(RequestDB).where( - RequestDB.queue_id == self._orm_metadata.id, RequestDB.unique_key.in_(unique_keys) + if not requests: + return AddRequestsResponse(processed_requests=[], unprocessed_requests=[]) + + self._is_empty_cache = None + processed_requests = [] + unprocessed_requests = [] + + delta_total_request_count = 0 + delta_pending_request_count = 0 + + # Deduplicate requests by unique_key upfront + unique_requests = {} + for req in requests: + if req.unique_key not in unique_requests: + unique_requests[req.unique_key] = req + + unique_keys = list(unique_requests.keys()) + + # Get existing requests by unique keys + stmt = ( + select(RequestDB) + .where(RequestDB.queue_id == self._id, RequestDB.unique_key.in_(unique_keys)) + .options( + load_only( + RequestDB.request_id, + RequestDB.unique_key, + RequestDB.is_handled, + RequestDB.sequence_number, + ) ) + ) + + state = self._state.current_value + + async with self.get_session() as session, self._lock: result = await session.execute(stmt) existing_requests = {req.unique_key: req for req in result.scalars()} - result = await session.execute(stmt) - batch_processed = set() + new_request_objects = [] # Process each request - for request in requests: - if request.unique_key in batch_processed: - continue - - existing_req_db = existing_requests.get(request.unique_key) + for unique_key, request in unique_requests.items(): + existing_req_db = existing_requests.get(unique_key) + # New request if existing_req_db is None: - # New request 
if forefront: sequence_number = state.forefront_sequence_counter state.forefront_sequence_counter -= 1 @@ -254,18 +308,19 @@ async def add_batch_of_requests( sequence_number = state.sequence_counter state.sequence_counter += 1 - request_db = RequestDB( - request_id=request.id, - queue_id=self._orm_metadata.id, - data=request.model_dump_json(), - unique_key=request.unique_key, - sequence_number=sequence_number, - is_handled=False, + new_request_objects.append( + RequestDB( + request_id=request.id, + queue_id=self._id, + data=request.model_dump_json(), + unique_key=request.unique_key, + sequence_number=sequence_number, + is_handled=False, + ) ) - session.add(request_db) - self._orm_metadata.total_request_count += 1 - self._orm_metadata.pending_request_count += 1 + delta_total_request_count += 1 + delta_pending_request_count += 1 processed_requests.append( ProcessedRequest( @@ -292,7 +347,6 @@ async def add_batch_of_requests( if forefront and existing_req_db.sequence_number > 0: existing_req_db.sequence_number = state.forefront_sequence_counter state.forefront_sequence_counter -= 1 - self._request_cache_needs_refresh = True processed_requests.append( ProcessedRequest( @@ -303,15 +357,18 @@ async def add_batch_of_requests( ) ) - batch_processed.add(request.unique_key) + if new_request_objects: + session.add_all(new_request_objects) - await self._update_metadata(update_modified_at=True, update_accessed_at=True) - - if forefront: - self._request_cache_needs_refresh = True + await self._update_metadata( + session, + delta_total_request_count=delta_total_request_count, + delta_pending_request_count=delta_pending_request_count, + update_modified_at=True, + update_accessed_at=True, + ) try: - await session.merge(self._orm_metadata) await session.commit() except SQLAlchemyError as e: logger.warning(f'Failed to commit session: {e}') @@ -328,17 +385,18 @@ async def add_batch_of_requests( ] ) - return AddRequestsResponse( - processed_requests=processed_requests, - unprocessed_requests=unprocessed_requests, - ) + if forefront: + self._request_cache_needs_refresh = True + + return AddRequestsResponse( + processed_requests=processed_requests, + unprocessed_requests=unprocessed_requests, + ) @override async def get_request(self, request_id: str) -> Request | None: - async with self._storage_client.create_session() as session: - stmt = select(RequestDB).where( - RequestDB.queue_id == self._orm_metadata.id, RequestDB.request_id == request_id - ) + stmt = select(RequestDB).where(RequestDB.queue_id == self._id, RequestDB.request_id == request_id) + async with self.get_session() as session: result = await session.execute(stmt) request_db = result.scalar_one_or_none() @@ -346,34 +404,34 @@ async def get_request(self, request_id: str) -> Request | None: logger.warning(f'Request with ID "{request_id}" not found in the queue.') return None - request = Request.model_validate_json(request_db.data) + await self._update_metadata(session, update_accessed_at=True) - state = self._state.current_value - state.in_progress_requests.add(request.id) - - await self._update_metadata(update_accessed_at=True) - await session.merge(self._orm_metadata) + # Commit updates to the metadata await session.commit() - return request + request = Request.model_validate_json(request_db.data) + + self.in_progress_requests.add(request.id) + + return request @override async def fetch_next_request(self) -> Request | None: # Refresh cache if needed - if self._request_cache_needs_refresh or not self._request_cache: - await self._refresh_cache() + 
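Sequence numbers are handed out from two counters: a positive, increasing one for regular requests and a negative, decreasing one for forefront requests, so `ORDER BY sequence_number ASC` yields forefront requests first (the most recently added forefront request at the very front) and regular requests in insertion order. A toy illustration of the scheme, outside any database:

    # Counters as kept in RequestQueueState.
    sequence_counter = 0
    forefront_sequence_counter = -1

    rows: list[tuple[int, str]] = []

    for url in ('https://example.com/a', 'https://example.com/b'):
        rows.append((sequence_counter, url))
        sequence_counter += 1

    # A forefront request added later still sorts ahead of everything queued so far.
    rows.append((forefront_sequence_counter, 'https://example.com/urgent'))
    forefront_sequence_counter -= 1

    assert [url for _, url in sorted(rows)] == [
        'https://example.com/urgent',
        'https://example.com/a',
        'https://example.com/b',
    ]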
async with self._lock: + if self._request_cache_needs_refresh or not self._request_cache: + await self._refresh_cache() next_request = None - state = self._state.current_value # Get from cache while self._request_cache and next_request is None: candidate = self._request_cache.popleft() # Only check local state - if candidate.id not in state.in_progress_requests: + if candidate.id not in self.in_progress_requests: next_request = candidate - state.in_progress_requests.add(next_request.id) + self.in_progress_requests.add(next_request.id) if not self._request_cache: self._is_empty_cache = None @@ -383,36 +441,40 @@ async def fetch_next_request(self) -> Request | None: @override async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: self._is_empty_cache = None - state = self._state.current_value - if request.id not in state.in_progress_requests: + if request.id not in self.in_progress_requests: logger.warning(f'Marking request {request.id} as handled that is not in progress.') return None # Update request in DB stmt = ( update(RequestDB) - .where(RequestDB.queue_id == self._orm_metadata.id, RequestDB.request_id == request.id) + .where(RequestDB.queue_id == self._id, RequestDB.request_id == request.id) .values(is_handled=True) ) - async with self._storage_client.create_session() as session: + async with self.get_session() as session: result = await session.execute(stmt) if result.rowcount == 0: logger.warning(f'Request {request.id} not found in database.') return None - # Update state - state.in_progress_requests.discard(request.id) + await self._update_metadata( + session, + delta_handled_request_count=1, + delta_pending_request_count=-1, + update_modified_at=True, + update_accessed_at=True, + ) - # Update metadata - self._orm_metadata.handled_request_count += 1 - self._orm_metadata.pending_request_count -= 1 + try: + await session.commit() + except SQLAlchemyError: + await session.rollback() + return None - await self._update_metadata(update_modified_at=True, update_accessed_at=True) - await session.merge(self._orm_metadata) - await session.commit() + self.in_progress_requests.discard(request.id) return ProcessedRequest( id=request.id, @@ -431,7 +493,7 @@ async def reclaim_request( self._is_empty_cache = None state = self._state.current_value - if request.id not in state.in_progress_requests: + if request.id not in self.in_progress_requests: logger.info(f'Reclaiming request {request.id} that is not in progress.') return None @@ -445,19 +507,19 @@ async def reclaim_request( stmt = ( update(RequestDB) - .where(RequestDB.queue_id == self._orm_metadata.id, RequestDB.request_id == request.id) + .where(RequestDB.queue_id == self._id, RequestDB.request_id == request.id) .values(sequence_number=new_sequence) ) - async with self._storage_client.create_session() as session: - result = await session.execute(stmt) + async with self.get_autocommit_session() as autocommit: + result = await autocommit.execute(stmt) if result.rowcount == 0: logger.warning(f'Request {request.id} not found in database.') return None # Remove from in-progress - state.in_progress_requests.discard(request.id) + self.in_progress_requests.discard(request.id) # Invalidate cache or add to cache if forefront: @@ -466,69 +528,56 @@ async def reclaim_request( # For regular requests, we can add to the end if there's space self._request_cache.append(request) - await self._update_metadata(update_modified_at=True, update_accessed_at=True) - await session.merge(self._orm_metadata) - await session.commit() + 
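Taken together, `fetch_next_request`, `mark_request_as_handled`, and `reclaim_request` support the usual fetch/handle/ack loop. A sketch of such a consumer, assuming an already opened `SQLRequestQueueClient` named `rq` and a user-supplied `handle(request)` coroutine (both names are hypothetical):

    import asyncio


    async def consume(rq, handle) -> None:
        """Drain the queue: fetch, process, then mark handled or reclaim for a retry."""
        while not await rq.is_empty():
            request = await rq.fetch_next_request()
            if request is None:
                # Remaining requests are in progress elsewhere; try again shortly.
                await asyncio.sleep(1)
                continue
            try:
                await handle(request)
            except Exception:
                # Return the request to the queue; forefront=True would retry it before new work.
                await rq.reclaim_request(request, forefront=False)
            else:
                await rq.mark_request_as_handled(request)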
await self._update_metadata(autocommit, update_modified_at=True, update_accessed_at=True) - return ProcessedRequest( - id=request.id, - unique_key=request.unique_key, - was_already_present=True, - was_already_handled=False, - ) + return ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=False, + ) @override async def is_empty(self) -> bool: if self._is_empty_cache is not None: return self._is_empty_cache - state = self._state.current_value - # If there are in-progress requests, not empty - if len(state.in_progress_requests) > 0: + if len(self.in_progress_requests) > 0: self._is_empty_cache = False return False # Check database for unhandled requests - async with self._storage_client.create_session() as session: - stmt = ( - select(func.count()) - .select_from(RequestDB) - .where( - RequestDB.queue_id == self._orm_metadata.id, - RequestDB.is_handled == False, # noqa: E712 - ) - ) - result = await session.execute(stmt) - unhandled_count = result.scalar() - self._is_empty_cache = unhandled_count == 0 + async with self.get_session() as session: + metadata_orm = await session.get(RequestQueueMetadataDB, self._id) + if not metadata_orm: + raise ValueError(f'Request queue with ID "{self._id}" not found.') - await self._update_metadata(update_accessed_at=True) - await session.merge(self._orm_metadata) + self._is_empty_cache = metadata_orm.pending_request_count == 0 + await self._update_metadata(session, update_accessed_at=True) + + # Commit updates to the metadata await session.commit() - return self._is_empty_cache + return self._is_empty_cache async def _refresh_cache(self) -> None: """Refresh the request cache from database.""" self._request_cache.clear() - state = self._state.current_value - async with self._storage_client.create_session() as session: + async with self.get_session() as session: # Simple query - get unhandled requests not in progress stmt = ( select(RequestDB) .where( - RequestDB.queue_id == self._orm_metadata.id, + RequestDB.queue_id == self._id, RequestDB.is_handled == False, # noqa: E712 + RequestDB.request_id.notin_(self.in_progress_requests), ) .order_by(RequestDB.sequence_number.asc()) .limit(self._MAX_REQUESTS_IN_CACHE) ) - if state.in_progress_requests: - stmt = stmt.where(RequestDB.request_id.notin_(state.in_progress_requests)) - result = await session.execute(stmt) request_dbs = result.scalars().all() @@ -541,7 +590,14 @@ async def _refresh_cache(self) -> None: async def _update_metadata( self, + session: AsyncSession, *, + new_handled_request_count: int | None = None, + new_pending_request_count: int | None = None, + new_total_request_count: int | None = None, + delta_handled_request_count: int | None = None, + delta_pending_request_count: int | None = None, + delta_total_request_count: int | None = None, update_had_multiple_clients: bool = False, update_accessed_at: bool = False, update_modified_at: bool = False, @@ -550,17 +606,49 @@ async def _update_metadata( Args: session: The SQLAlchemy session to use for database operations. + new_handled_request_count: If provided, update the handled_request_count to this value. + new_pending_request_count: If provided, update the pending_request_count to this value. + new_total_request_count: If provided, update the total_request_count to this value. + delta_handled_request_count: If provided, add this value to the handled_request_count. + delta_pending_request_count: If provided, add this value to the pending_request_count. 
+ delta_total_request_count: If provided, add this value to the total_request_count. update_had_multiple_clients: If True, set had_multiple_clients to True. update_accessed_at: If True, update the `accessed_at` timestamp to the current time. update_modified_at: If True, update the `modified_at` timestamp to the current time. """ now = datetime.now(timezone.utc) + values_to_set: dict[str, Any] = {} if update_accessed_at: - self._orm_metadata.accessed_at = now + values_to_set['accessed_at'] = now if update_modified_at: - self._orm_metadata.modified_at = now + values_to_set['modified_at'] = now if update_had_multiple_clients: - self._orm_metadata.had_multiple_clients = True + values_to_set['had_multiple_clients'] = True + + if new_handled_request_count is not None: + values_to_set['handled_request_count'] = new_handled_request_count + elif delta_handled_request_count is not None: + values_to_set['handled_request_count'] = ( + RequestQueueMetadataDB.handled_request_count + delta_handled_request_count + ) + + if new_pending_request_count is not None: + values_to_set['pending_request_count'] = new_pending_request_count + elif delta_pending_request_count is not None: + values_to_set['pending_request_count'] = ( + RequestQueueMetadataDB.pending_request_count + delta_pending_request_count + ) + + if new_total_request_count is not None: + values_to_set['total_request_count'] = new_total_request_count + elif delta_total_request_count is not None: + values_to_set['total_request_count'] = ( + RequestQueueMetadataDB.total_request_count + delta_total_request_count + ) + + if values_to_set: + stmt = update(RequestQueueMetadataDB).where(RequestQueueMetadataDB.id == self._id).values(**values_to_set) + await session.execute(stmt) diff --git a/src/crawlee/storage_clients/_sql/_storage_client.py b/src/crawlee/storage_clients/_sql/_storage_client.py index 9e29f7b481..98bd5b2c8e 100644 --- a/src/crawlee/storage_clients/_sql/_storage_client.py +++ b/src/crawlee/storage_clients/_sql/_storage_client.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import TYPE_CHECKING -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from sqlalchemy.sql import text from typing_extensions import override @@ -19,6 +19,8 @@ if TYPE_CHECKING: from types import TracebackType + from sqlalchemy.ext.asyncio import AsyncSession + @docs_group('Storage clients') class SQLStorageClient(StorageClient): @@ -70,6 +72,10 @@ def engine(self) -> AsyncEngine: raise ValueError('Engine is not initialized. 
Call initialize() before accessing the engine.') return self._engine + def get_default_flag(self) -> bool: + """Check if the default database should be created.""" + return self._default_flag + def _get_or_create_engine(self, configuration: Configuration) -> AsyncEngine: """Get or create the database engine based on configuration.""" if self._engine is not None: @@ -88,7 +94,15 @@ def _get_or_create_engine(self, configuration: Configuration) -> AsyncEngine: connection_string = f'sqlite+aiosqlite:///{db_path}' self._engine = create_async_engine( - connection_string, future=True, pool_size=5, max_overflow=10, pool_timeout=30, pool_recycle=600, echo=False + connection_string, + future=True, + pool_size=5, + max_overflow=10, + pool_timeout=30, + pool_recycle=600, + pool_pre_ping=True, + echo=False, + connect_args={'timeout': 30}, ) return self._engine @@ -101,7 +115,6 @@ async def initialize(self, configuration: Configuration) -> None: if not self._initialized: engine = self._get_or_create_engine(configuration) async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) # Set SQLite pragmas for performance and consistency if self._default_flag: await conn.execute(text('PRAGMA journal_mode=WAL')) @@ -111,6 +124,7 @@ async def initialize(self, configuration: Configuration) -> None: await conn.execute(text('PRAGMA mmap_size=268435456')) await conn.execute(text('PRAGMA foreign_keys=ON')) await conn.execute(text('PRAGMA busy_timeout=30000')) + await conn.run_sync(Base.metadata.create_all) self._initialized = True async def close(self) -> None: diff --git a/tests/unit/storage_clients/_sql/test_sql_rq_client.py b/tests/unit/storage_clients/_sql/test_sql_rq_client.py index c90c91235d..99d31e2090 100644 --- a/tests/unit/storage_clients/_sql/test_sql_rq_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_rq_client.py @@ -99,7 +99,7 @@ async def test_tables_and_metadata_record(configuration: Configuration) -> None: assert 'request_queue_metadata' in tables assert 'request' in tables - async with client.create_session() as session: + async with client.get_session() as session: stmt = select(RequestQueueMetadataDB).where(RequestQueueMetadataDB.name == 'test_request_queue') result = await session.execute(stmt) orm_metadata = result.scalar_one_or_none() @@ -122,7 +122,7 @@ async def test_request_records_persistence(rq_client: SQLRequestQueueClient) -> metadata_client = await rq_client.get_metadata() - async with rq_client.create_session() as session: + async with rq_client.get_session() as session: stmt = select(RequestDB).where(RequestDB.queue_id == metadata_client.id) result = await session.execute(stmt) db_requests = result.scalars().all() @@ -136,7 +136,7 @@ async def test_drop_removes_records(rq_client: SQLRequestQueueClient) -> None: """Test that drop removes all records from the database.""" await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) metadata = await rq_client.get_metadata() - async with rq_client.create_session() as session: + async with rq_client.get_session() as session: stmt = select(RequestDB).where(RequestDB.queue_id == metadata.id) result = await session.execute(stmt) records = result.scalars().all() @@ -144,7 +144,7 @@ async def test_drop_removes_records(rq_client: SQLRequestQueueClient) -> None: await rq_client.drop() - async with rq_client.create_session() as session: + async with rq_client.get_session() as session: stmt = select(RequestDB).where(RequestDB.queue_id == metadata.id) result = await session.execute(stmt) records = 
result.scalars().all() @@ -187,7 +187,7 @@ async def test_metadata_record_updates(rq_client: SQLRequestQueueClient) -> None assert metadata.modified_at > initial_modified assert metadata.accessed_at > accessed_after_read - async with rq_client.create_session() as session: + async with rq_client.get_session() as session: orm_metadata = await session.get(RequestQueueMetadataDB, metadata.id) assert orm_metadata is not None assert orm_metadata.created_at.replace(tzinfo=timezone.utc) == metadata.created_at From 1884f7d34bf62cf70afd361f2717c956b1273997 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Fri, 1 Aug 2025 17:00:26 +0000 Subject: [PATCH 11/29] reduce the refresh rate of `accessed_at` --- .../_file_system/_storage_client.py | 6 +-- .../storage_clients/_sql/_dataset_client.py | 30 ++++++++--- .../_sql/_key_value_store_client.py | 54 ++++++++++++++----- .../_sql/_request_queue_client.py | 29 +++++++--- .../storage_clients/_sql/_storage_client.py | 11 ++++ .../_sql/test_sql_dataset_client.py | 10 ++-- .../_sql/test_sql_kvs_client.py | 3 +- .../_sql/test_sql_rq_client.py | 10 ++-- 8 files changed, 109 insertions(+), 44 deletions(-) diff --git a/src/crawlee/storage_clients/_file_system/_storage_client.py b/src/crawlee/storage_clients/_file_system/_storage_client.py index 86903ea4e7..2d3563125d 100644 --- a/src/crawlee/storage_clients/_file_system/_storage_client.py +++ b/src/crawlee/storage_clients/_file_system/_storage_client.py @@ -31,11 +31,7 @@ class FileSystemStorageClient(StorageClient): @override async def create_dataset_client( - self, - *, - id: str | None = None, - name: str | None = None, - configuration: Configuration | None = None, + self, *, id: str | None = None, name: str | None = None, configuration: Configuration | None = None ) -> FileSystemDatasetClient: configuration = configuration or Configuration.get_global_configuration() client = await FileSystemDatasetClient.open(id=id, name=name, configuration=configuration) diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py index dc829c2ab1..319ea78bd9 100644 --- a/src/crawlee/storage_clients/_sql/_dataset_client.py +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -54,6 +54,10 @@ def __init__( self._id = id self._storage_client = storage_client + self._last_accessed_at: datetime | None = None + self._last_modified_at: datetime | None = None + self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval() + @override async def get_metadata(self) -> DatasetMetadata: async with self.get_session() as session: @@ -241,10 +245,11 @@ async def get_data( result = await session.execute(stmt) db_items = result.scalars().all() - await self._update_metadata(session, update_accessed_at=True) + updated = await self._update_metadata(session, update_accessed_at=True) # Commit updates to the metadata - await session.commit() + if updated: + await session.commit() items = [json.loads(db_item.data) for db_item in db_items] metadata = await self.get_metadata() @@ -300,10 +305,11 @@ async def iterate_items( result = await session.execute(stmt) db_items = result.scalars().all() - await self._update_metadata(session, update_accessed_at=True) + updated = await self._update_metadata(session, update_accessed_at=True) # Commit updates to the metadata - await session.commit() + if updated: + await session.commit() items = [json.loads(db_item.data) for db_item in db_items] for item in items: @@ -317,7 +323,7 @@ async def _update_metadata( 
update_accessed_at: bool = False, update_modified_at: bool = False, delta_item_count: int | None = None, - ) -> None: + ) -> bool: """Update the KVS metadata in the database. Args: @@ -330,10 +336,17 @@ async def _update_metadata( now = datetime.now(timezone.utc) values_to_set: dict[str, Any] = {} - if update_accessed_at: + if update_accessed_at and ( + self._last_accessed_at is None or (now - self._last_accessed_at) > self._accessed_modified_update_interval + ): values_to_set['accessed_at'] = now - if update_modified_at: + self._last_accessed_at = now + + if update_modified_at and ( + self._last_modified_at is None or (now - self._last_modified_at) > self._accessed_modified_update_interval + ): values_to_set['modified_at'] = now + self._last_modified_at = now if new_item_count is not None: values_to_set['item_count'] = new_item_count @@ -343,3 +356,6 @@ async def _update_metadata( if values_to_set: stmt = update(DatasetMetadataDB).where(DatasetMetadataDB.id == self._id).values(**values_to_set) await session.execute(stmt) + return True + + return False diff --git a/src/crawlee/storage_clients/_sql/_key_value_store_client.py b/src/crawlee/storage_clients/_sql/_key_value_store_client.py index 97e0f9fa96..4ce0c43623 100644 --- a/src/crawlee/storage_clients/_sql/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_sql/_key_value_store_client.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any from sqlalchemy import delete, select, text, update -from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.exc import IntegrityError, SQLAlchemyError from typing_extensions import override from crawlee._utils.crypto import crypto_random_object_id @@ -67,6 +67,12 @@ def __init__( self._storage_client = storage_client """The storage client used to access the SQL database.""" + self._last_accessed_at: datetime | None = None + self._last_modified_at: datetime | None = None + + self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval() + """Interval for updating metadata in the database.""" + @override async def get_metadata(self) -> KeyValueStoreMetadata: async with self.get_session() as session: @@ -223,12 +229,17 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No # A race condition is possible if several clients work with one kvs. # Unfortunately, there is no implementation of atomic Upsert that is independent of specific dialects. # https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-upsert-statements - async with self.get_autocommit_session() as autocommit: - result = await autocommit.execute(stmt) + async with self.get_session() as session: + result = await session.execute(stmt) if result.rowcount == 0: - autocommit.add(record_db) + session.add(record_db) - await self._update_metadata(autocommit, update_accessed_at=True, update_modified_at=True) + await self._update_metadata(session, update_accessed_at=True, update_modified_at=True) + try: + await session.commit() + except IntegrityError: + # Race condition when attempting to INSERT the same key. Ignore duplicates. 
+ await session.rollback() @override async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: @@ -239,10 +250,11 @@ async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: result = await session.execute(stmt) record_db = result.scalar_one_or_none() - await self._update_metadata(session, update_accessed_at=True) + updated = await self._update_metadata(session, update_accessed_at=True) # Commit updates to the metadata - await session.commit() + if updated: + await session.commit() if not record_db: return None @@ -325,10 +337,11 @@ async def iterate_keys( size=row.size, ) - await self._update_metadata(session, update_accessed_at=True) + updated = await self._update_metadata(session, update_accessed_at=True) # Commit updates to the metadata - await session.commit() + if updated: + await session.commit() @override async def record_exists(self, *, key: str) -> bool: @@ -339,8 +352,11 @@ async def record_exists(self, *, key: str) -> bool: # Check if record exists result = await session.execute(stmt) - await self._update_metadata(session, update_accessed_at=True) - await session.commit() + updated = await self._update_metadata(session, update_accessed_at=True) + + # Commit updates to the metadata + if updated: + await session.commit() return result.scalar_one_or_none() is not None @@ -354,7 +370,7 @@ async def _update_metadata( *, update_accessed_at: bool = False, update_modified_at: bool = False, - ) -> None: + ) -> bool: """Update the KVS metadata in the database. Args: @@ -365,11 +381,21 @@ async def _update_metadata( now = datetime.now(timezone.utc) values_to_set: dict[str, Any] = {} - if update_accessed_at: + if update_accessed_at and ( + self._last_accessed_at is None or (now - self._last_accessed_at) > self._accessed_modified_update_interval + ): values_to_set['accessed_at'] = now - if update_modified_at: + self._last_accessed_at = now + + if update_modified_at and ( + self._last_modified_at is None or (now - self._last_modified_at) > self._accessed_modified_update_interval + ): values_to_set['modified_at'] = now + self._last_modified_at = now if values_to_set: stmt = update(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.id == self._id).values(**values_to_set) await session.execute(stmt) + return True + + return False diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index 8139cecd2a..15c64f65c3 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -89,6 +89,10 @@ def __init__( self._is_empty_cache: bool | None = None """Cache for is_empty result: None means unknown, True/False is cached state.""" + self._last_accessed_at: datetime | None = None + self._last_modified_at: datetime | None = None + self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval() + self._state = RecoverableState[RequestQueueState]( default_state=RequestQueueState(), persist_state_key='request_queue_state', @@ -404,10 +408,11 @@ async def get_request(self, request_id: str) -> Request | None: logger.warning(f'Request with ID "{request_id}" not found in the queue.') return None - await self._update_metadata(session, update_accessed_at=True) + updated = await self._update_metadata(session, update_accessed_at=True) # Commit updates to the metadata - await session.commit() + if updated: + await session.commit() request = Request.model_validate_json(request_db.data) @@ -554,10 +559,11 @@ 
async def is_empty(self) -> bool: raise ValueError(f'Request queue with ID "{self._id}" not found.') self._is_empty_cache = metadata_orm.pending_request_count == 0 - await self._update_metadata(session, update_accessed_at=True) + updated = await self._update_metadata(session, update_accessed_at=True) # Commit updates to the metadata - await session.commit() + if updated: + await session.commit() return self._is_empty_cache @@ -601,7 +607,7 @@ async def _update_metadata( update_had_multiple_clients: bool = False, update_accessed_at: bool = False, update_modified_at: bool = False, - ) -> None: + ) -> bool: """Update the request queue metadata in the database. Args: @@ -619,11 +625,17 @@ async def _update_metadata( now = datetime.now(timezone.utc) values_to_set: dict[str, Any] = {} - if update_accessed_at: + if update_accessed_at and ( + self._last_accessed_at is None or (now - self._last_accessed_at) > self._accessed_modified_update_interval + ): values_to_set['accessed_at'] = now + self._last_accessed_at = now - if update_modified_at: + if update_modified_at and ( + self._last_modified_at is None or (now - self._last_modified_at) > self._accessed_modified_update_interval + ): values_to_set['modified_at'] = now + self._last_modified_at = now if update_had_multiple_clients: values_to_set['had_multiple_clients'] = True @@ -652,3 +664,6 @@ async def _update_metadata( if values_to_set: stmt = update(RequestQueueMetadataDB).where(RequestQueueMetadataDB.id == self._id).values(**values_to_set) await session.execute(stmt) + return True + + return False diff --git a/src/crawlee/storage_clients/_sql/_storage_client.py b/src/crawlee/storage_clients/_sql/_storage_client.py index 98bd5b2c8e..f6f92a2de3 100644 --- a/src/crawlee/storage_clients/_sql/_storage_client.py +++ b/src/crawlee/storage_clients/_sql/_storage_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import timedelta from pathlib import Path from typing import TYPE_CHECKING @@ -47,6 +48,7 @@ def __init__( *, connection_string: str | None = None, engine: AsyncEngine | None = None, + accessed_modified_update_interval: timedelta = timedelta(seconds=1), ) -> None: """Initialize the SQL storage client. @@ -54,6 +56,9 @@ def __init__( connection_string: Database connection string (e.g., "sqlite+aiosqlite:///crawlee.db"). If not provided, defaults to SQLite database in the storage directory. engine: Pre-configured AsyncEngine instance. If provided, connection_string is ignored. + accessed_modified_update_interval: Minimum interval between updates of accessed_at and modified_at + timestamps in metadata tables. Used to reduce frequency of timestamp updates during frequent + read/write operations. Default is 1 second. 
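# A minimal sketch of the throttling check behind accessed_modified_update_interval,
# assuming the 1-second default; the names here are illustrative, the real logic lives
# in each client's _update_metadata().
from datetime import datetime, timedelta, timezone


class TimestampThrottle:
    """Illustrative helper: allow a metadata timestamp write at most once per interval."""

    def __init__(self, interval: timedelta = timedelta(seconds=1)) -> None:
        self._interval = interval
        self._last_update: datetime | None = None

    def should_update(self) -> bool:
        # Skip the write when the previous one happened less than `interval` ago,
        # which keeps frequent reads from hammering the metadata tables.
        now = datetime.now(timezone.utc)
        if self._last_update is None or (now - self._last_update) > self._interval:
            self._last_update = now
            return True
        return False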
""" if engine is not None and connection_string is not None: raise ValueError('Either connection_string or engine must be provided, not both.') @@ -62,6 +67,8 @@ def __init__( self._engine = engine self._initialized = False + self._accessed_modified_update_interval = accessed_modified_update_interval + # Default flag to indicate if the default database should be created self._default_flag = self._engine is None and self._connection_string is None @@ -76,6 +83,10 @@ def get_default_flag(self) -> bool: """Check if the default database should be created.""" return self._default_flag + def get_accessed_modified_update_interval(self) -> timedelta: + """Get the interval for accessed and modified updates.""" + return self._accessed_modified_update_interval + def _get_or_create_engine(self, configuration: Configuration) -> AsyncEngine: """Get or create the database engine based on configuration.""" if self._engine is not None: diff --git a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py index b8c403b8b1..88f38cbfc5 100644 --- a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py @@ -2,7 +2,7 @@ import asyncio import json -from datetime import timezone +from datetime import timedelta from typing import TYPE_CHECKING import pytest @@ -39,7 +39,7 @@ def get_tables(sync_conn: Connection) -> list[str]: @pytest.fixture async def dataset_client(configuration: Configuration) -> AsyncGenerator[SQLDatasetClient, None]: """A fixture for a SQL dataset client.""" - async with SQLStorageClient() as storage_client: + async with SQLStorageClient(accessed_modified_update_interval=timedelta(seconds=0)) as storage_client: client = await storage_client.create_dataset_client( name='test_dataset', configuration=configuration, @@ -180,9 +180,9 @@ async def test_metadata_record_updates(dataset_client: SQLDatasetClient) -> None # Verify timestamps metadata = await dataset_client.get_metadata() - assert metadata.created_at.replace(tzinfo=timezone.utc) == initial_created - assert metadata.accessed_at.replace(tzinfo=timezone.utc) > initial_accessed - assert metadata.modified_at.replace(tzinfo=timezone.utc) == initial_modified + assert metadata.created_at == initial_created + assert metadata.accessed_at > initial_accessed + assert metadata.modified_at == initial_modified accessed_after_get = metadata.accessed_at diff --git a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py index 8fc4342362..8f4a5a3ddc 100644 --- a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py @@ -2,6 +2,7 @@ import asyncio import json +from datetime import timedelta from typing import TYPE_CHECKING import pytest @@ -33,7 +34,7 @@ def configuration(tmp_path: Path) -> Configuration: @pytest.fixture async def kvs_client(configuration: Configuration) -> AsyncGenerator[SQLKeyValueStoreClient, None]: """A fixture for a SQL key-value store client.""" - async with SQLStorageClient() as storage_client: + async with SQLStorageClient(accessed_modified_update_interval=timedelta(seconds=0)) as storage_client: client = await storage_client.create_kvs_client( name='test_kvs', configuration=configuration, diff --git a/tests/unit/storage_clients/_sql/test_sql_rq_client.py b/tests/unit/storage_clients/_sql/test_sql_rq_client.py index 99d31e2090..7e3350719c 100644 --- 
a/tests/unit/storage_clients/_sql/test_sql_rq_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_rq_client.py @@ -2,7 +2,7 @@ import asyncio import json -from datetime import timezone +from datetime import timedelta from typing import TYPE_CHECKING import pytest @@ -35,7 +35,7 @@ def configuration(tmp_path: Path) -> Configuration: @pytest.fixture async def rq_client(configuration: Configuration) -> AsyncGenerator[SQLRequestQueueClient, None]: """A fixture for a SQL request queue client.""" - async with SQLStorageClient() as storage_client: + async with SQLStorageClient(accessed_modified_update_interval=timedelta(seconds=0)) as storage_client: client = await storage_client.create_rq_client( name='test_request_queue', configuration=configuration, @@ -190,9 +190,9 @@ async def test_metadata_record_updates(rq_client: SQLRequestQueueClient) -> None async with rq_client.get_session() as session: orm_metadata = await session.get(RequestQueueMetadataDB, metadata.id) assert orm_metadata is not None - assert orm_metadata.created_at.replace(tzinfo=timezone.utc) == metadata.created_at - assert orm_metadata.accessed_at.replace(tzinfo=timezone.utc) == metadata.accessed_at - assert orm_metadata.modified_at.replace(tzinfo=timezone.utc) == metadata.modified_at + assert orm_metadata.created_at == metadata.created_at + assert orm_metadata.accessed_at == metadata.accessed_at + assert orm_metadata.modified_at == metadata.modified_at async def test_data_persistence_across_reopens(configuration: Configuration) -> None: From a10e3cf3f415fb68cdeae1317812ec5d3a403475 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Fri, 1 Aug 2025 20:54:55 +0000 Subject: [PATCH 12/29] up docs --- .../storage_clients/_sql/_dataset_client.py | 48 +++++++++--- .../storage_clients/_sql/_db_models.py | 76 ++++++++++++++++++- .../_sql/_key_value_store_client.py | 49 +++++++----- .../_sql/_request_queue_client.py | 39 +++++++--- .../storage_clients/_sql/_storage_client.py | 64 +++++++++++----- 5 files changed, 217 insertions(+), 59 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py index 319ea78bd9..3c0bb28f21 100644 --- a/src/crawlee/storage_clients/_sql/_dataset_client.py +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -29,13 +29,16 @@ class SQLDatasetClient(DatasetClient): """SQL implementation of the dataset client. - This client persists dataset items to a SQL database with proper transaction handling and - concurrent access safety. Items are stored in a normalized table structure with automatic - ordering preservation and efficient querying capabilities. + This client persists dataset items to a SQL database using two tables for storage + and retrieval. Items are stored as JSON with automatic ordering preservation. - The SQL implementation provides ACID compliance, supports complex queries, and allows - multiple processes to safely access the same dataset concurrently through database-level - locking mechanisms. + The dataset data is stored in SQL database tables following the pattern: + - `dataset_metadata` table: Contains dataset metadata (id, name, timestamps, item_count) + - `dataset_item` table: Contains individual items with JSON data and auto-increment ordering + + Items are serialized to JSON with `default=str` to handle non-serializable types like datetime + objects. The `order_id` auto-increment primary key ensures insertion order is preserved. 
+ All operations are wrapped in database transactions with CASCADE deletion support. """ _DEFAULT_NAME_DB = 'default' @@ -54,12 +57,15 @@ def __init__( self._id = id self._storage_client = storage_client + # Time tracking to reduce database writes during frequent operation self._last_accessed_at: datetime | None = None self._last_modified_at: datetime | None = None self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval() @override async def get_metadata(self) -> DatasetMetadata: + """Get dataset metadata from the database.""" + # The database is a single place of truth async with self.get_session() as session: orm_metadata: DatasetMetadataDB | None = await session.get(DatasetMetadataDB, self._id) if not orm_metadata: @@ -91,7 +97,7 @@ async def open( name: str | None, storage_client: SQLStorageClient, ) -> SQLDatasetClient: - """Open or create a SQL dataset client. + """Open an existing dataset or create a new one. Args: id: The ID of the dataset to open. If provided, searches for existing dataset by ID. @@ -161,6 +167,10 @@ async def open( @override async def drop(self) -> None: + """Delete this dataset and all its items from the database. + + This operation is irreversible. Uses CASCADE deletion to remove all related items. + """ stmt = delete(DatasetMetadataDB).where(DatasetMetadataDB.id == self._id) async with self.get_autocommit_session() as autocommit: if self._storage_client.get_default_flag(): @@ -170,6 +180,10 @@ async def drop(self) -> None: @override async def purge(self) -> None: + """Remove all items from this dataset while keeping the dataset structure. + + Resets item_count to 0 and deletes all records from dataset_item table. + """ stmt = delete(DatasetItemDB).where(DatasetItemDB.dataset_id == self._id) async with self.get_autocommit_session() as autocommit: await autocommit.execute(stmt) @@ -178,12 +192,14 @@ async def purge(self) -> None: @override async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: + """Add new items to the dataset.""" if not isinstance(data, list): data = [data] db_items: list[DatasetItemDB] = [] for item in data: + # Serialize with default=str to handle non-serializable types like datetime json_item = json.dumps(item, default=str, ensure_ascii=False) db_items.append( DatasetItemDB( @@ -235,8 +251,10 @@ async def get_data( stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._id) if skip_empty: + # Skip items that are empty JSON objects stmt = stmt.where(DatasetItemDB.data != '"{}"') + # Apply ordering by insertion order (order_id) stmt = stmt.order_by(DatasetItemDB.order_id.desc()) if desc else stmt.order_by(DatasetItemDB.order_id.asc()) stmt = stmt.offset(offset).limit(limit) @@ -251,6 +269,7 @@ async def get_data( if updated: await session.commit() + # Deserialize JSON items items = [json.loads(db_item.data) for db_item in db_items] metadata = await self.get_metadata() return DatasetItemsListPage( @@ -276,6 +295,7 @@ async def iterate_items( skip_empty: bool = False, skip_hidden: bool = False, ) -> AsyncIterator[dict[str, Any]]: + """Iterate over dataset items with optional filtering and ordering.""" # Check for unsupported arguments and log a warning if found. 
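# A tiny illustration of the json.dumps(..., default=str) behaviour that push_data
# relies on: non-JSON-serializable values such as datetime are coerced to strings
# instead of raising TypeError. The item values are examples only.
import json
from datetime import datetime, timezone

item = {'url': 'https://example.com', 'crawled_at': datetime(2025, 8, 1, tzinfo=timezone.utc)}
print(json.dumps(item, default=str, ensure_ascii=False))
# -> {"url": "https://example.com", "crawled_at": "2025-08-01 00:00:00+00:00"}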
unsupported_args: dict[str, Any] = { 'clean': clean, @@ -295,10 +315,12 @@ async def iterate_items( stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._id) if skip_empty: + # Skip items that are empty JSON objects stmt = stmt.where(DatasetItemDB.data != '"{}"') stmt = stmt.order_by(DatasetItemDB.order_id.desc()) if desc else stmt.order_by(DatasetItemDB.order_id.asc()) + # Apply ordering by insertion order (order_id) stmt = stmt.offset(offset).limit(limit) async with self.get_session() as session: @@ -311,6 +333,7 @@ async def iterate_items( if updated: await session.commit() + # Deserialize and yield items items = [json.loads(db_item.data) for db_item in db_items] for item in items: yield item @@ -324,14 +347,14 @@ async def _update_metadata( update_modified_at: bool = False, delta_item_count: int | None = None, ) -> bool: - """Update the KVS metadata in the database. + """Update the dataset metadata in the database. Args: session: The SQLAlchemy AsyncSession to use for the update. - new_item_count: If provided, update the item count to this value. - update_accessed_at: If True, update the `accessed_at` timestamp to the current time. - update_modified_at: If True, update the `modified_at` timestamp to the current time. - delta_item_count: If provided, increment the item count by this value. + new_item_count: If provided, set item count to this value. + update_accessed_at: If True, update the accessed_at timestamp. + update_modified_at: If True, update the modified_at timestamp. + delta_item_count: If provided, add this value to the current item count. """ now = datetime.now(timezone.utc) values_to_set: dict[str, Any] = {} @@ -351,6 +374,7 @@ async def _update_metadata( if new_item_count is not None: values_to_set['item_count'] = new_item_count elif delta_item_count: + # Use database-level for atomic updates values_to_set['item_count'] = DatasetMetadataDB.item_count + delta_item_count if values_to_set: diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py index 8f6d631a8e..248065054a 100644 --- a/src/crawlee/storage_clients/_sql/_db_models.py +++ b/src/crawlee/storage_clients/_sql/_db_models.py @@ -11,22 +11,37 @@ from sqlalchemy.engine import Dialect +# This is necessary because unique constraints don't apply to NULL values in SQL. class NameDefaultNone(TypeDecorator): + """Custom SQLAlchemy type for handling default name values. + + Converts None values to 'default' on storage and back to None on retrieval. + """ + impl = String(100) cache_ok = True def process_bind_param(self, value: str | None, _dialect: Dialect) -> str | None: + """Convert Python value to database value.""" return 'default' if value is None else value def process_result_value(self, value: str | None, _dialect: Dialect) -> str | None: + """Convert database value to Python value.""" return None if value == 'default' else value class AwareDateTime(TypeDecorator): + """Custom SQLAlchemy type for timezone-aware datetime handling. + + Ensures all datetime values are timezone-aware by adding UTC timezone to + naive datetime values from databases that don't store timezone information. 
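# A short illustration of the guarantee AwareDateTime provides on reads, assuming the
# driver returned a naive datetime; it mirrors process_result_value() above.
from datetime import datetime, timezone

naive = datetime(2025, 8, 1, 12, 0, 0)      # e.g. what SQLite/aiosqlite hands back
aware = naive.replace(tzinfo=timezone.utc)  # what AwareDateTime yields to the ORM
assert aware.tzinfo is timezone.utc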
+ """ + impl = DateTime(timezone=True) cache_ok = True def process_result_value(self, value: datetime | None, _dialect: Dialect) -> datetime | None: + """Add UTC timezone to naive datetime values.""" if value is not None and value.tzinfo is None: return value.replace(tzinfo=timezone.utc) return value @@ -40,92 +55,147 @@ class StorageMetadataDB: """Base database model for storage metadata.""" id: Mapped[str] = mapped_column(String(20), nullable=False, primary_key=True) + """Unique identifier.""" + name: Mapped[str | None] = mapped_column(NameDefaultNone, nullable=False, index=True, unique=True) + """Human-readable name. None becomes 'default' in database to enforce uniqueness.""" + accessed_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) + """Last access datetime for usage tracking.""" + created_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) + """Creation datetime.""" + modified_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) + """Last modification datetime.""" class DatasetMetadataDB(StorageMetadataDB, Base): + """Metadata table for datasets.""" + __tablename__ = 'dataset_metadata' item_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + """Number of items in the dataset.""" + # Relationship to dataset items with cascade deletion items: Mapped[list[DatasetItemDB]] = relationship( back_populates='dataset', cascade='all, delete-orphan', lazy='select' ) class RequestQueueMetadataDB(StorageMetadataDB, Base): + """Metadata table for request queues.""" + __tablename__ = 'request_queue_metadata' had_multiple_clients: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + """Flag indicating if multiple clients have accessed this queue.""" + handled_request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + """Number of requests processed.""" + pending_request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + """Number of requests waiting to be processed.""" + total_request_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + """Total number of requests ever added to this queue.""" + # Relationship to queue requests with cascade deletion requests: Mapped[list[RequestDB]] = relationship( back_populates='queue', cascade='all, delete-orphan', lazy='select' ) class KeyValueStoreMetadataDB(StorageMetadataDB, Base): + """Metadata table for key-value stores.""" + __tablename__ = 'kvs_metadata' + # Relationship to store records with cascade deletion records: Mapped[list[KeyValueStoreRecordDB]] = relationship( back_populates='kvs', cascade='all, delete-orphan', lazy='select' ) class KeyValueStoreRecordDB(Base): - """Database model for key-value store records.""" + """Records table for key-value stores.""" __tablename__ = 'kvs_record' kvs_id: Mapped[str] = mapped_column( String(255), ForeignKey('kvs_metadata.id', ondelete='CASCADE'), primary_key=True, index=True ) + """Foreign key to metadata key-value store record.""" + key: Mapped[str] = mapped_column(String(255), primary_key=True) + """The key part of the key-value pair.""" + value: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) + """Value stored as binary data to support any content type.""" + content_type: Mapped[str] = mapped_column(String(100), nullable=False) + """MIME type for proper value deserialization.""" + size: Mapped[int | None] = mapped_column(Integer, nullable=False, default=0) + """Size of stored value in bytes.""" + # Relationship back to parent store kvs: 
Mapped[KeyValueStoreMetadataDB] = relationship(back_populates='records') class DatasetItemDB(Base): - """Database model for dataset items.""" + """Items table for datasets.""" __tablename__ = 'dataset_item' order_id: Mapped[int] = mapped_column(Integer, primary_key=True) + """Auto-increment primary key preserving insertion order.""" + dataset_id: Mapped[str] = mapped_column( String(20), ForeignKey('dataset_metadata.id', ondelete='CASCADE'), index=True, ) + """Foreign key to metadata dataset record.""" + data: Mapped[str] = mapped_column(JSON, nullable=False) + """JSON-serialized item data.""" + # Relationship back to parent dataset dataset: Mapped[DatasetMetadataDB] = relationship(back_populates='items') class RequestDB(Base): - """Database model for requests in the request queue.""" + """Requests table for request queues.""" __tablename__ = 'request' __table_args__ = ( + # Index for efficient SELECT to cache Index('idx_queue_handled_seq', 'queue_id', 'is_handled', 'sequence_number'), + # Deduplication index Index('idx_queue_unique_key', 'queue_id', 'unique_key'), ) request_id: Mapped[str] = mapped_column(String(20), primary_key=True) + """Unique identifier for the request.""" + queue_id: Mapped[str] = mapped_column( String(20), ForeignKey('request_queue_metadata.id', ondelete='CASCADE'), primary_key=True ) + """Foreign key to metadata request queue record.""" data: Mapped[str] = mapped_column(JSON, nullable=False) + """JSON-serialized Request object.""" + unique_key: Mapped[str] = mapped_column(String(512), nullable=False) + """Request unique key for deduplication within queue.""" + sequence_number: Mapped[int] = mapped_column(Integer, nullable=False) + """Ordering sequence: negative for forefront, positive for regular.""" + is_handled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + """Processing status flag.""" + # Relationship back to parent queue queue: Mapped[RequestQueueMetadataDB] = relationship(back_populates='requests') diff --git a/src/crawlee/storage_clients/_sql/_key_value_store_client.py b/src/crawlee/storage_clients/_sql/_key_value_store_client.py index 4ce0c43623..d983af2899 100644 --- a/src/crawlee/storage_clients/_sql/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_sql/_key_value_store_client.py @@ -31,22 +31,23 @@ class SQLKeyValueStoreClient(KeyValueStoreClient): """SQL implementation of the key-value store client. - This client persists data to a SQL database, making it suitable for scenarios where data needs to - survive process restarts. Keys are mapped to rows in a database table. - - Binary data is stored as-is, while JSON and text data are stored in human-readable format. - The implementation automatically handles serialization based on the content type and - maintains metadata about each record. - - This implementation is ideal for long-running crawlers where persistence is important and - for development environments where you want to easily inspect the stored data between runs. - - Binary data is stored as-is, while JSON and text data are stored in human-readable format. - The implementation automatically handles serialization based on the content type and - maintains metadata about each record. - - This implementation is ideal for long-running crawlers where persistence is important and - for development environments where you want to easily inspect the stored data between runs. + This client persists key-value data to a SQL database with transaction support and + concurrent access safety. 
Keys are mapped to rows in database tables with proper indexing + for efficient retrieval. + + The key-value store data is stored in SQL database tables following the pattern: + - `kvs_metadata` table: Contains store metadata (id, name, timestamps) + - `kvs_record` table: Contains individual key-value pairs with binary value storage, content type, and size + information + + Values are serialized based on their type: JSON objects are stored as formatted JSON, + text values as UTF-8 encoded strings, and binary data as-is in the `LargeBinary` column. + The implementation automatically handles content type detection and maintains metadata + about each record including size and MIME type information. + + All database operations are wrapped in transactions with proper error handling and rollback + mechanisms. The client supports atomic upsert operations and handles race conditions when + multiple clients access the same store using composite primary keys (kvs_id, key). """ _DEFAULT_NAME_DB = 'default' @@ -67,14 +68,15 @@ def __init__( self._storage_client = storage_client """The storage client used to access the SQL database.""" + # Time tracking to reduce database writes during frequent operation self._last_accessed_at: datetime | None = None self._last_modified_at: datetime | None = None - self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval() - """Interval for updating metadata in the database.""" @override async def get_metadata(self) -> KeyValueStoreMetadata: + """Get the metadata for this key-value store.""" + # The database is a single place of truth async with self.get_session() as session: orm_metadata: KeyValueStoreMetadataDB | None = await session.get(KeyValueStoreMetadataDB, self._id) if not orm_metadata: @@ -176,6 +178,10 @@ async def open( @override async def drop(self) -> None: + """Delete this key-value store and all its records from the database. + + This operation is irreversible. Uses CASCADE deletion to remove all related records. 
+ """ stmt = delete(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.id == self._id) async with self.get_autocommit_session() as autosession: if self._storage_client.get_default_flag(): @@ -185,6 +191,7 @@ async def drop(self) -> None: @override async def purge(self) -> None: + """Remove all items from this key-value store while keeping the key-value store structure.""" stmt = delete(KeyValueStoreRecordDB).filter_by(kvs_id=self._id) async with self.get_autocommit_session() as autosession: await autosession.execute(stmt) @@ -193,6 +200,7 @@ async def purge(self) -> None: @override async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: + """Set a value in the key-value store.""" # Special handling for None values if value is None: content_type = 'application/x-none' # Special content type to identify None values @@ -243,6 +251,8 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No @override async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: + """Get a value from the key-value store.""" + # Query the record by key stmt = select(KeyValueStoreRecordDB).where( KeyValueStoreRecordDB.kvs_id == self._id, KeyValueStoreRecordDB.key == key ) @@ -292,6 +302,7 @@ async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: @override async def delete_value(self, *, key: str) -> None: + """Delete a value from the key-value store.""" stmt = delete(KeyValueStoreRecordDB).where( KeyValueStoreRecordDB.kvs_id == self._id, KeyValueStoreRecordDB.key == key ) @@ -312,6 +323,7 @@ async def iterate_keys( exclusive_start_key: str | None = None, limit: int | None = None, ) -> AsyncIterator[KeyValueStoreRecordMetadata]: + """Iterate over the existing keys in the key-value store.""" # Build query for record metadata stmt = ( select(KeyValueStoreRecordDB.key, KeyValueStoreRecordDB.content_type, KeyValueStoreRecordDB.size) @@ -345,6 +357,7 @@ async def iterate_keys( @override async def record_exists(self, *, key: str) -> bool: + """Check if a record with the given key exists in the key-value store.""" stmt = select(KeyValueStoreRecordDB.key).where( KeyValueStoreRecordDB.kvs_id == self._id, KeyValueStoreRecordDB.key == key ) diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index 15c64f65c3..34df1ed9d6 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -50,13 +50,22 @@ class RequestQueueState(BaseModel): class SQLRequestQueueClient(RequestQueueClient): """SQL implementation of the request queue client. - This client persists requests to a SQL database with proper transaction handling and - concurrent access safety. Requests are stored in a normalized table structure with - sequence-based ordering and efficient querying capabilities. + This client persists requests to a SQL database with transaction handling and + concurrent access safety. Requests are stored with sequence-based ordering and + efficient querying capabilities. The implementation uses negative sequence numbers for forefront (high-priority) requests and positive sequence numbers for regular requests, allowing for efficient single-query - ordering. A cache mechanism reduces database queries for better performance. + ordering. A cache mechanism reduces database queries. 
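# A rough sketch of the sequence-number scheme described above; the counter names are
# illustrative. Forefront requests get decreasing negative numbers and regular requests
# increasing positive ones, so a single ORDER BY sequence_number ASC returns forefront
# items first, followed by regular items in insertion order.
forefront_counter = 0
regular_counter = 0


def next_sequence(*, forefront: bool) -> int:
    global forefront_counter, regular_counter
    if forefront:
        forefront_counter -= 1
        return forefront_counter  # ..., -3, -2, -1
    regular_counter += 1
    return regular_counter  # 1, 2, 3, ...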
+ + The request queue data is stored in SQL database tables following the pattern: + - `request_queue_metadata` table: Contains queue metadata (id, name, timestamps, request counts, multi-client flag) + - `request` table: Contains individual requests with JSON data, unique keys for deduplication, sequence numbers for + ordering, and processing status flags + + Requests are serialized to JSON for storage and maintain proper ordering through sequence + numbers. The implementation provides concurrent access safety through transaction + handling, locking mechanisms, and optimized database indexes for efficient querying. """ _DEFAULT_NAME_DB = 'default' @@ -109,6 +118,8 @@ def __init__( @override async def get_metadata(self) -> RequestQueueMetadata: + """Get the metadata for this request queue.""" + # The database is a single place of truth async with self.get_session() as session: orm_metadata: RequestQueueMetadataDB | None = await session.get(RequestQueueMetadataDB, self._id) if not orm_metadata: @@ -117,7 +128,7 @@ async def get_metadata(self) -> RequestQueueMetadata: return RequestQueueMetadata.model_validate(orm_metadata) def get_session(self) -> AsyncSession: - """Create a new SQLAlchemy session for this key-value store.""" + """Create a new SQLAlchemy session for thi s request queue.""" return self._storage_client.create_session() @asynccontextmanager @@ -140,15 +151,19 @@ async def open( name: str | None, storage_client: SQLStorageClient, ) -> SQLRequestQueueClient: - """Open or create a SQL request queue client. + """Open an existing request queue or create a new one. + + This method first tries to find an existing queue by ID or name. + If found, it returns a client for that queue. If not found, it creates + a new queue with the specified parameters. Args: - id: The ID of the request queue to open. If provided, searches for existing queue by ID. - name: The name of the request queue to open. If not provided, uses the default queue. + id: The ID of the request queue to open. Takes precedence over name. + name: The name of the request queue to open. Uses 'default' if None. storage_client: The SQL storage client used to access the database. Returns: - An instance for the opened or created storage client. + An instance for the opened or created request queue. Raises: ValueError: If a queue with the specified ID is not found. @@ -215,6 +230,10 @@ async def open( @override async def drop(self) -> None: + """Delete this request queue and all its records from the database. + + This operation is irreversible. Uses CASCADE deletion to remove all related records. 
+ """ stmt = delete(RequestQueueMetadataDB).where(RequestQueueMetadataDB.id == self._id) async with self.get_autocommit_session() as autocommit: if self._storage_client.get_default_flag(): @@ -232,6 +251,7 @@ async def drop(self) -> None: @override async def purge(self) -> None: + """Purge all requests from this request queue.""" stmt = delete(RequestDB).where(RequestDB.queue_id == self._id) async with self.get_autocommit_session() as autocommit: # Delete all requests for this queue @@ -262,6 +282,7 @@ async def add_batch_of_requests( if not requests: return AddRequestsResponse(processed_requests=[], unprocessed_requests=[]) + # Clear empty cache since we're adding requests self._is_empty_cache = None processed_requests = [] unprocessed_requests = [] diff --git a/src/crawlee/storage_clients/_sql/_storage_client.py b/src/crawlee/storage_clients/_sql/_storage_client.py index f6f92a2de3..168d41635f 100644 --- a/src/crawlee/storage_clients/_sql/_storage_client.py +++ b/src/crawlee/storage_clients/_sql/_storage_client.py @@ -28,16 +28,14 @@ class SQLStorageClient(StorageClient): """SQL implementation of the storage client. This storage client provides access to datasets, key-value stores, and request queues that persist data - to a SQL database using SQLAlchemy 2+ with Pydantic dataclasses for type safety. Data is stored in - normalized relational tables, providing ACID compliance, concurrent access safety, and the ability to - query data using SQL. + to a SQL database using SQLAlchemy 2+. Each storage type uses two tables: one for metadata and one for + records/items. - The SQL implementation supports various database backends including PostgreSQL, MySQL, SQLite, and others - supported by SQLAlchemy. It provides durability, consistency, and supports concurrent access from multiple - processes through database-level locking mechanisms. + The client accepts either a database connection string or a pre-configured AsyncEngine. If neither is + provided, it creates a default SQLite database 'crawlee.db' in the storage directory. - This implementation is ideal for production environments where data persistence, consistency, and - concurrent access are critical requirements. + Database schema is automatically created during initialization. SQLite databases receive performance + optimizations including WAL mode and increased cache size. 
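# A minimal usage sketch based on the test fixtures in this patch series; passing no
# connection string falls back to the default SQLite database in the storage directory,
# and the dataset name and URL below are illustrative only.
import asyncio

from crawlee.storage_clients import SQLStorageClient


async def main() -> None:
    async with SQLStorageClient() as storage_client:
        dataset = await storage_client.create_dataset_client(name='example_dataset')
        await dataset.push_data({'url': 'https://example.com'})


asyncio.run(main())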
""" _DB_NAME = 'crawlee.db' @@ -67,9 +65,10 @@ def __init__( self._engine = engine self._initialized = False + # Minimum interval to reduce database load from frequent concurrent metadata updates self._accessed_modified_update_interval = accessed_modified_update_interval - # Default flag to indicate if the default database should be created + # Flag needed to apply optimizations only for default database self._default_flag = self._engine is None and self._connection_string is None @property @@ -80,7 +79,7 @@ def engine(self) -> AsyncEngine: return self._engine def get_default_flag(self) -> bool: - """Check if the default database should be created.""" + """Check if the default database is being used.""" return self._default_flag def get_accessed_modified_update_interval(self) -> timedelta: @@ -102,6 +101,7 @@ def _get_or_create_engine(self, configuration: Configuration) -> AsyncEngine: db_path = storage_dir / self._DB_NAME + # Create connection string with path to default database connection_string = f'sqlite+aiosqlite:///{db_path}' self._engine = create_async_engine( @@ -128,13 +128,13 @@ async def initialize(self, configuration: Configuration) -> None: async with engine.begin() as conn: # Set SQLite pragmas for performance and consistency if self._default_flag: - await conn.execute(text('PRAGMA journal_mode=WAL')) - await conn.execute(text('PRAGMA synchronous=NORMAL')) - await conn.execute(text('PRAGMA cache_size=100000')) - await conn.execute(text('PRAGMA temp_store=MEMORY')) - await conn.execute(text('PRAGMA mmap_size=268435456')) - await conn.execute(text('PRAGMA foreign_keys=ON')) - await conn.execute(text('PRAGMA busy_timeout=30000')) + await conn.execute(text('PRAGMA journal_mode=WAL')) # Better concurrency + await conn.execute(text('PRAGMA synchronous=NORMAL')) # Balanced safety/speed + await conn.execute(text('PRAGMA cache_size=100000')) # 100MB cache + await conn.execute(text('PRAGMA temp_store=MEMORY')) # Memory temp storage + await conn.execute(text('PRAGMA mmap_size=268435456')) # 256MB memory mapping + await conn.execute(text('PRAGMA foreign_keys=ON')) # Enforce constraints + await conn.execute(text('PRAGMA busy_timeout=30000')) # 30s busy timeout await conn.run_sync(Base.metadata.create_all) self._initialized = True @@ -161,6 +161,16 @@ async def create_dataset_client( name: str | None = None, configuration: Configuration | None = None, ) -> SQLDatasetClient: + """Create or open a SQL dataset client. + + Args: + id: Specific dataset ID to open. If provided, name is ignored. + name: Dataset name to open or create. Uses 'default' if not specified. + configuration: Configuration object. Uses global config if not provided. + + Returns: + Configured dataset client ready for use. + """ configuration = configuration or Configuration.get_global_configuration() await self.initialize(configuration) @@ -181,6 +191,16 @@ async def create_kvs_client( name: str | None = None, configuration: Configuration | None = None, ) -> SQLKeyValueStoreClient: + """Create or open a SQL key-value store client. + + Args: + id: Specific store ID to open. If provided, name is ignored. + name: Store name to open or create. Uses 'default' if not specified. + configuration: Configuration object. Uses global config if not provided. + + Returns: + Configured key-value store client ready for use. 
+ """ configuration = configuration or Configuration.get_global_configuration() await self.initialize(configuration) @@ -201,6 +221,16 @@ async def create_rq_client( name: str | None = None, configuration: Configuration | None = None, ) -> SQLRequestQueueClient: + """Create or open a SQL request queue client. + + Args: + id: Specific queue ID to open. If provided, name is ignored. + name: Queue name to open or create. Uses 'default' if not specified. + configuration: Configuration object. Uses global config if not provided. + + Returns: + Configured request queue client ready for use. + """ configuration = configuration or Configuration.get_global_configuration() await self.initialize(configuration) From f7ebbe520b40b776dc33fa66f374f162fd0214c0 Mon Sep 17 00:00:00 2001 From: Max Bohomolov <34358312+Mantisus@users.noreply.github.com> Date: Sat, 2 Aug 2025 00:25:50 +0300 Subject: [PATCH 13/29] Update src/crawlee/storage_clients/_sql/_request_queue_client.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/crawlee/storage_clients/_sql/_request_queue_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index 34df1ed9d6..0e053c1a62 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -128,7 +128,7 @@ async def get_metadata(self) -> RequestQueueMetadata: return RequestQueueMetadata.model_validate(orm_metadata) def get_session(self) -> AsyncSession: - """Create a new SQLAlchemy session for thi s request queue.""" + """Create a new SQLAlchemy session for this request queue.""" return self._storage_client.create_session() @asynccontextmanager From 83ca6d34cc60894e57998b22de2e29fdbd98882f Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Fri, 1 Aug 2025 21:28:25 +0000 Subject: [PATCH 14/29] fix tests --- tests/unit/storages/test_request_queue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py index 1512d4b78f..fb4b2500be 100644 --- a/tests/unit/storages/test_request_queue.py +++ b/tests/unit/storages/test_request_queue.py @@ -20,7 +20,7 @@ from crawlee.storage_clients import StorageClient -@pytest.fixture(params=['sql']) +@pytest.fixture(params=['memory', 'file_system', 'sql']) def storage_client(request: pytest.FixtureRequest) -> StorageClient: """Parameterized fixture to test with different storage clients.""" if request.param == 'memory': From 8086ab25aeac51b2aa0b4f4284c58c336848eefc Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Tue, 19 Aug 2025 10:46:47 +0000 Subject: [PATCH 15/29] same updates --- .../storage_clients/_sql/_client_mixin.py | 90 +++++++ .../storage_clients/_sql/_dataset_client.py | 125 +++++----- .../storage_clients/_sql/_db_models.py | 29 ++- .../_sql/_key_value_store_client.py | 86 ++++--- .../_sql/_request_queue_client.py | 232 ++++++++---------- .../storage_clients/_sql/_storage_client.py | 8 +- uv.lock | 4 +- 7 files changed, 331 insertions(+), 243 deletions(-) create mode 100644 src/crawlee/storage_clients/_sql/_client_mixin.py diff --git a/src/crawlee/storage_clients/_sql/_client_mixin.py b/src/crawlee/storage_clients/_sql/_client_mixin.py new file mode 100644 index 0000000000..d7025db820 --- /dev/null +++ b/src/crawlee/storage_clients/_sql/_client_mixin.py @@ -0,0 +1,90 @@ +from __future__ import annotations + 
+from contextlib import asynccontextmanager +from logging import getLogger +from typing import TYPE_CHECKING, Any + +from sqlalchemy.dialects.mysql import insert as mysql_insert +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.dialects.sqlite import insert as lite_insert +from sqlalchemy.exc import SQLAlchemyError + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from sqlalchemy import Insert + from sqlalchemy.ext.asyncio import AsyncSession + + from ._storage_client import SQLStorageClient + + +logger = getLogger(__name__) + + +class SQLClientMixin: + """Mixin class for SQL clients.""" + + _storage_client: SQLStorageClient + + def get_session(self) -> AsyncSession: + """Create a new SQLAlchemy session for this request queue.""" + return self._storage_client.create_session() + + @asynccontextmanager + async def get_autocommit_session(self) -> AsyncIterator[AsyncSession]: + """Create a new SQLAlchemy autocommit session to insert, delete, or modify data.""" + async with self.get_session() as session: + try: + yield session + await session.commit() + except SQLAlchemyError as e: + logger.warning(f'Error occurred during session transaction: {e}') + # Rollback the session in case of an error + await session.rollback() + + def build_insert_stmt_with_ignore(self, table_model: Any, insert_values: dict | list[dict]) -> Insert: + """Build an insert statement with ignore for the SQL dialect.""" + if isinstance(insert_values, dict): + insert_values = [insert_values] + + dialect = self._storage_client.get_dialect_name() + + if dialect == 'postgresql': + return pg_insert(table_model).values(insert_values).on_conflict_do_nothing() + + if dialect == 'mysql': + return mysql_insert(table_model).values(insert_values).on_duplicate_key_update() + + if dialect == 'sqlite': + return lite_insert(table_model).values(insert_values).on_conflict_do_nothing() + + raise NotImplementedError(f'Insert with ignore not supported for dialect: {dialect}') + + def build_upsert_stmt( + self, + table_model: Any, + insert_values: dict | list[dict], + update_columns: list[str], + conflict_cols: list[str] | None = None, + ) -> Insert: + if isinstance(insert_values, dict): + insert_values = [insert_values] + + dialect = self._storage_client.get_dialect_name() + + if dialect == 'postgresql': + pg_stmt = pg_insert(table_model).values(insert_values) + set_ = {col: getattr(pg_stmt.excluded, col) for col in update_columns} + return pg_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=set_) + + if dialect == 'sqlite': + lite_stmt = lite_insert(table_model).values(insert_values) + set_ = {col: getattr(lite_stmt.excluded, col) for col in update_columns} + return lite_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=set_) + + if dialect == 'mysql': + mysql_stmt = mysql_insert(table_model).values(insert_values) + set_ = {col: mysql_stmt.inserted[col] for col in update_columns} + return mysql_stmt.on_duplicate_key_update(**set_) + + raise NotImplementedError(f'Upsert not supported for dialect: {dialect}') diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py index 3c0bb28f21..3fe43edfae 100644 --- a/src/crawlee/storage_clients/_sql/_dataset_client.py +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -1,12 +1,11 @@ from __future__ import annotations import json -from contextlib import asynccontextmanager from datetime import datetime, timezone from logging import getLogger from typing import 
TYPE_CHECKING, Any -from sqlalchemy import delete, select, text, update +from sqlalchemy import Select, delete, insert, select, text, update from sqlalchemy.exc import SQLAlchemyError from typing_extensions import override @@ -14,11 +13,13 @@ from crawlee.storage_clients._base import DatasetClient from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata +from ._client_mixin import SQLClientMixin from ._db_models import DatasetItemDB, DatasetMetadataDB if TYPE_CHECKING: from collections.abc import AsyncIterator + from sqlalchemy import Select from sqlalchemy.ext.asyncio import AsyncSession from ._storage_client import SQLStorageClient @@ -26,7 +27,7 @@ logger = getLogger(__name__) -class SQLDatasetClient(DatasetClient): +class SQLDatasetClient(DatasetClient, SQLClientMixin): """SQL implementation of the dataset client. This client persists dataset items to a SQL database using two tables for storage @@ -73,22 +74,6 @@ async def get_metadata(self) -> DatasetMetadata: return DatasetMetadata.model_validate(orm_metadata) - def get_session(self) -> AsyncSession: - """Create a new SQLAlchemy session for this dataset.""" - return self._storage_client.create_session() - - @asynccontextmanager - async def get_autocommit_session(self) -> AsyncIterator[AsyncSession]: - """Create a new SQLAlchemy autocommit session to insert, delete, or modify data.""" - async with self.get_session() as session: - try: - yield session - await session.commit() - except SQLAlchemyError as e: - logger.warning(f'Error occurred during session transaction: {e}') - # Rollback the session in case of an error - await session.rollback() - @classmethod async def open( cls, @@ -173,7 +158,7 @@ async def drop(self) -> None: """ stmt = delete(DatasetMetadataDB).where(DatasetMetadataDB.id == self._id) async with self.get_autocommit_session() as autocommit: - if self._storage_client.get_default_flag(): + if self._storage_client.get_dialect_name() == 'sqlite': # foreign_keys=ON is set at the connection level. Required for cascade deletion. await autocommit.execute(text('PRAGMA foreign_keys=ON')) await autocommit.execute(stmt) @@ -196,26 +181,28 @@ async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: if not isinstance(data, list): data = [data] - db_items: list[DatasetItemDB] = [] + db_items: list[dict[str, Any]] = [] for item in data: # Serialize with default=str to handle non-serializable types like datetime json_item = json.dumps(item, default=str, ensure_ascii=False) db_items.append( - DatasetItemDB( - dataset_id=self._id, - data=json_item, - ) + { + 'dataset_id': self._id, + 'data': json_item, + } ) + stmt = insert(DatasetItemDB).values(db_items) + async with self.get_autocommit_session() as autocommit: - autocommit.add_all(db_items) + await autocommit.execute(stmt) + await self._update_metadata( autocommit, update_accessed_at=True, update_modified_at=True, delta_item_count=len(data) ) - @override - async def get_data( + def _prepare_get_stmt( self, *, offset: int = 0, @@ -229,7 +216,7 @@ async def get_data( skip_hidden: bool = False, flatten: list[str] | None = None, view: str | None = None, - ) -> DatasetItemsListPage: + ) -> Select: # Check for unsupported arguments and log a warning if found. 
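The _prepare_get_stmt helper introduced here (its body continues below) lets get_data and iterate_items share one SELECT builder. As a hedged, stand-in sketch of the statement shape it produces, with a purely illustrative Item model rather than the patch's real one:

from sqlalchemy import Integer, String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Item(Base):  # stand-in for the real dataset item model
    __tablename__ = 'item'

    order_id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String(20))
    data: Mapped[str] = mapped_column(String)


def build_page_query(dataset_id: str, *, offset: int = 0, limit: int | None = 100, desc: bool = False):
    stmt = select(Item).where(Item.dataset_id == dataset_id)
    # The autoincrement primary key preserves insertion order, so ordering by it
    # reproduces append order (or reverse order when desc=True).
    stmt = stmt.order_by(Item.order_id.desc() if desc else Item.order_id.asc())
    return stmt.offset(offset).limit(limit)


print(build_page_query('dataset-1', desc=True))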
unsupported_args: dict[str, Any] = { 'clean': clean, @@ -257,7 +244,37 @@ async def get_data( # Apply ordering by insertion order (order_id) stmt = stmt.order_by(DatasetItemDB.order_id.desc()) if desc else stmt.order_by(DatasetItemDB.order_id.asc()) - stmt = stmt.offset(offset).limit(limit) + return stmt.offset(offset).limit(limit) + + @override + async def get_data( + self, + *, + offset: int = 0, + limit: int | None = 999_999_999_999, + clean: bool = False, + desc: bool = False, + fields: list[str] | None = None, + omit: list[str] | None = None, + unwind: str | None = None, + skip_empty: bool = False, + skip_hidden: bool = False, + flatten: list[str] | None = None, + view: str | None = None, + ) -> DatasetItemsListPage: + stmt = self._prepare_get_stmt( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + flatten=flatten, + view=view, + ) async with self.get_session() as session: result = await session.execute(stmt) @@ -296,36 +313,23 @@ async def iterate_items( skip_hidden: bool = False, ) -> AsyncIterator[dict[str, Any]]: """Iterate over dataset items with optional filtering and ordering.""" - # Check for unsupported arguments and log a warning if found. - unsupported_args: dict[str, Any] = { - 'clean': clean, - 'fields': fields, - 'omit': omit, - 'unwind': unwind, - 'skip_hidden': skip_hidden, - } - unsupported = {k: v for k, v in unsupported_args.items() if v not in (False, None)} - - if unsupported: - logger.warning( - f'The arguments {list(unsupported.keys())} of iterate are not supported ' - f'by the {self.__class__.__name__} client.' - ) - - stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._id) - - if skip_empty: - # Skip items that are empty JSON objects - stmt = stmt.where(DatasetItemDB.data != '"{}"') - - stmt = stmt.order_by(DatasetItemDB.order_id.desc()) if desc else stmt.order_by(DatasetItemDB.order_id.asc()) - - # Apply ordering by insertion order (order_id) - stmt = stmt.offset(offset).limit(limit) + stmt = self._prepare_get_stmt( + offset=offset, + limit=limit, + clean=clean, + desc=desc, + fields=fields, + omit=omit, + unwind=unwind, + skip_empty=skip_empty, + skip_hidden=skip_hidden, + ) async with self.get_session() as session: - result = await session.execute(stmt) - db_items = result.scalars().all() + db_items = await session.stream_scalars(stmt) + + async for db_item in db_items: + yield json.loads(db_item.data) updated = await self._update_metadata(session, update_accessed_at=True) @@ -333,11 +337,6 @@ async def iterate_items( if updated: await session.commit() - # Deserialize and yield items - items = [json.loads(db_item.data) for db_item in db_items] - for item in items: - yield item - async def _update_metadata( self, session: AsyncSession, diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py index 248065054a..b41eecc3a0 100644 --- a/src/crawlee/storage_clients/_sql/_db_models.py +++ b/src/crawlee/storage_clients/_sql/_db_models.py @@ -6,6 +6,7 @@ from sqlalchemy import JSON, Boolean, ForeignKey, Index, Integer, LargeBinary, String from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.types import DateTime, TypeDecorator +from typing_extensions import override if TYPE_CHECKING: from sqlalchemy.engine import Dialect @@ -21,10 +22,12 @@ class NameDefaultNone(TypeDecorator): impl = String(100) cache_ok = True - def process_bind_param(self, 
value: str | None, _dialect: Dialect) -> str | None: + @override + def process_bind_param(self, value: str | None, _dialect: Dialect) -> str: """Convert Python value to database value.""" return 'default' if value is None else value + @override def process_result_value(self, value: str | None, _dialect: Dialect) -> str | None: """Convert database value to Python value.""" return None if value == 'default' else value @@ -105,6 +108,10 @@ class RequestQueueMetadataDB(StorageMetadataDB, Base): requests: Mapped[list[RequestDB]] = relationship( back_populates='queue', cascade='all, delete-orphan', lazy='select' ) + # Relationship to queue state + state: Mapped[RequestQueueStateDB] = relationship( + back_populates='queue', cascade='all, delete-orphan', lazy='select' + ) class KeyValueStoreMetadataDB(StorageMetadataDB, Base): @@ -199,3 +206,23 @@ class RequestDB(Base): # Relationship back to parent queue queue: Mapped[RequestQueueMetadataDB] = relationship(back_populates='requests') + + +class RequestQueueStateDB(Base): + """State table for request queues.""" + + __tablename__ = 'request_queue_state' + + queue_id: Mapped[str] = mapped_column( + String(20), ForeignKey('request_queue_metadata.id', ondelete='CASCADE'), primary_key=True + ) + """Foreign key to metadata request queue record.""" + + sequence_counter: Mapped[int] = mapped_column(Integer, nullable=False, default=1) + """Counter for regular request ordering (positive).""" + + forefront_sequence_counter: Mapped[int] = mapped_column(Integer, nullable=False, default=-1) + """Counter for forefront request ordering (negative).""" + + # Relationship back to parent queue + queue: Mapped[RequestQueueMetadataDB] = relationship(back_populates='state') diff --git a/src/crawlee/storage_clients/_sql/_key_value_store_client.py b/src/crawlee/storage_clients/_sql/_key_value_store_client.py index d983af2899..9cb432b0a4 100644 --- a/src/crawlee/storage_clients/_sql/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_sql/_key_value_store_client.py @@ -1,12 +1,11 @@ from __future__ import annotations import json -from contextlib import asynccontextmanager from datetime import datetime, timezone from logging import getLogger from typing import TYPE_CHECKING, Any -from sqlalchemy import delete, select, text, update +from sqlalchemy import delete, insert, select, text, update from sqlalchemy.exc import IntegrityError, SQLAlchemyError from typing_extensions import override @@ -15,6 +14,7 @@ from crawlee.storage_clients._base import KeyValueStoreClient from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata +from ._client_mixin import SQLClientMixin from ._db_models import KeyValueStoreMetadataDB, KeyValueStoreRecordDB if TYPE_CHECKING: @@ -28,7 +28,7 @@ logger = getLogger(__name__) -class SQLKeyValueStoreClient(KeyValueStoreClient): +class SQLKeyValueStoreClient(KeyValueStoreClient, SQLClientMixin): """SQL implementation of the key-value store client. 
This client persists key-value data to a SQL database with transaction support and @@ -84,22 +84,6 @@ async def get_metadata(self) -> KeyValueStoreMetadata: return KeyValueStoreMetadata.model_validate(orm_metadata) - def get_session(self) -> AsyncSession: - """Create a new SQLAlchemy session for this key-value store.""" - return self._storage_client.create_session() - - @asynccontextmanager - async def get_autocommit_session(self) -> AsyncIterator[AsyncSession]: - """Create a new SQLAlchemy autocommit session to insert, delete, or modify data.""" - async with self.get_session() as session: - try: - yield session - await session.commit() - except SQLAlchemyError as e: - logger.warning(f'Error occurred during session transaction: {e}') - # Rollback the session in case of an error - await session.rollback() - @classmethod async def open( cls, @@ -183,20 +167,20 @@ async def drop(self) -> None: This operation is irreversible. Uses CASCADE deletion to remove all related records. """ stmt = delete(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.id == self._id) - async with self.get_autocommit_session() as autosession: - if self._storage_client.get_default_flag(): + async with self.get_autocommit_session() as autocommit: + if self._storage_client.get_dialect_name() == 'sqlite': # foreign_keys=ON is set at the connection level. Required for cascade deletion. - await autosession.execute(text('PRAGMA foreign_keys=ON')) - await autosession.execute(stmt) + await autocommit.execute(text('PRAGMA foreign_keys=ON')) + await autocommit.execute(stmt) @override async def purge(self) -> None: """Remove all items from this key-value store while keeping the key-value store structure.""" stmt = delete(KeyValueStoreRecordDB).filter_by(kvs_id=self._id) - async with self.get_autocommit_session() as autosession: - await autosession.execute(stmt) + async with self.get_autocommit_session() as autocommit: + await autocommit.execute(stmt) - await self._update_metadata(autosession, update_accessed_at=True, update_modified_at=True) + await self._update_metadata(autocommit, update_accessed_at=True, update_modified_at=True) @override async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: @@ -220,29 +204,41 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No value_bytes = str(value).encode('utf-8') size = len(value_bytes) - record_db = KeyValueStoreRecordDB( - kvs_id=self._id, - key=key, - value=value_bytes, - content_type=content_type, - size=size, - ) - - stmt = ( - update(KeyValueStoreRecordDB) - .where(KeyValueStoreRecordDB.kvs_id == self._id, KeyValueStoreRecordDB.key == key) - .values(value=value_bytes, content_type=content_type, size=size) - ) + insert_values = { + 'kvs_id': self._id, + 'key': key, + 'value': value_bytes, + 'content_type': content_type, + 'size': size, + } + try: + # Trying to build a statement for Upsert + upsert_stmt = self.build_upsert_stmt( + KeyValueStoreRecordDB, + insert_values=insert_values, + update_columns=['value', 'content_type', 'size'], + conflict_cols=['kvs_id', 'key'], + ) + except NotImplementedError: + # If it is not possible to build an upsert for the current dialect, build an update + insert. 
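The fallback comment above is the crux of set_value: where the dialect supports a native upsert, one statement does the job, otherwise the client falls back to UPDATE-then-INSERT. A hedged sketch of the SQLite upsert that build_upsert_stmt assembles, against a stand-in table whose composite primary key matches the conflict columns (names are illustrative):

from sqlalchemy import Column, LargeBinary, MetaData, String, Table
from sqlalchemy.dialects import sqlite
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

metadata = MetaData()
record = Table(  # stand-in for the kvs_record table
    'kvs_record', metadata,
    Column('kvs_id', String(255), primary_key=True),
    Column('key', String(255), primary_key=True),
    Column('value', LargeBinary),
    Column('content_type', String(50)),
)

stmt = sqlite_insert(record).values(kvs_id='kvs-1', key='state', value=b'{}', content_type='application/json')
stmt = stmt.on_conflict_do_update(
    # The conflict target is the composite primary key, so a second write to the
    # same key in the same store becomes an UPDATE instead of an error.
    index_elements=['kvs_id', 'key'],
    set_={'value': stmt.excluded.value, 'content_type': stmt.excluded.content_type},
)
print(stmt.compile(dialect=sqlite.dialect()))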
+ upsert_stmt = None + update_stmt = ( + update(KeyValueStoreRecordDB) + .where(KeyValueStoreRecordDB.kvs_id == self._id, KeyValueStoreRecordDB.key == key) + .values(value=value_bytes, content_type=content_type, size=size) + ) + insert_stmt = insert(KeyValueStoreRecordDB).values(**insert_values) - # A race condition is possible if several clients work with one kvs. - # Unfortunately, there is no implementation of atomic Upsert that is independent of specific dialects. - # https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-upsert-statements async with self.get_session() as session: - result = await session.execute(stmt) - if result.rowcount == 0: - session.add(record_db) + if upsert_stmt is not None: + result = await session.execute(upsert_stmt) + else: + result = await session.execute(update_stmt) + if result.rowcount == 0: + await session.execute(insert_stmt) await self._update_metadata(session, update_accessed_at=True, update_modified_at=True) + try: await session.commit() except IntegrityError: diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index 0e053c1a62..2462776f17 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -2,12 +2,10 @@ import asyncio from collections import deque -from contextlib import asynccontextmanager from datetime import datetime, timezone from logging import getLogger from typing import TYPE_CHECKING, Any -from pydantic import BaseModel from sqlalchemy import delete, select, text, update from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import load_only @@ -15,7 +13,6 @@ from crawlee import Request from crawlee._utils.crypto import crypto_random_object_id -from crawlee._utils.recoverable_state import RecoverableState from crawlee.storage_clients._base import RequestQueueClient from crawlee.storage_clients.models import ( AddRequestsResponse, @@ -24,10 +21,11 @@ UnprocessedRequest, ) -from ._db_models import RequestDB, RequestQueueMetadataDB +from ._client_mixin import SQLClientMixin +from ._db_models import RequestDB, RequestQueueMetadataDB, RequestQueueStateDB if TYPE_CHECKING: - from collections.abc import AsyncIterator, Sequence + from collections.abc import Sequence from sqlalchemy.ext.asyncio import AsyncSession @@ -37,17 +35,7 @@ logger = getLogger(__name__) -class RequestQueueState(BaseModel): - """Simplified state model for SQL implementation.""" - - sequence_counter: int = 1 - """Counter for regular request ordering (positive).""" - - forefront_sequence_counter: int = -1 - """Counter for forefront request ordering (negative).""" - - -class SQLRequestQueueClient(RequestQueueClient): +class SQLRequestQueueClient(RequestQueueClient, SQLClientMixin): """SQL implementation of the request queue client. 
This client persists requests to a SQL database with transaction handling and @@ -71,7 +59,7 @@ class SQLRequestQueueClient(RequestQueueClient): _DEFAULT_NAME_DB = 'default' """Default dataset name used when no name is provided.""" - _MAX_REQUESTS_IN_CACHE = 100_000 + _MAX_REQUESTS_IN_CACHE = 1000 """Maximum number of requests to keep in cache for faster access.""" def __init__( @@ -102,20 +90,20 @@ def __init__( self._last_modified_at: datetime | None = None self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval() - self._state = RecoverableState[RequestQueueState]( - default_state=RequestQueueState(), - persist_state_key='request_queue_state', - persistence_enabled=True, - persist_state_kvs_name=f'__RQ_STATE_{self._id}', - logger=logger, - ) - """Recoverable state to maintain request ordering and in-progress status.""" - self._storage_client = storage_client """The storage client used to access the SQL database.""" self._lock = asyncio.Lock() + async def _get_state(self, session: AsyncSession) -> RequestQueueStateDB: + """Get the current state of the request queue.""" + orm_state: RequestQueueStateDB | None = await session.get(RequestQueueStateDB, self._id) + if not orm_state: + orm_state = RequestQueueStateDB(queue_id=self._id) + session.add(orm_state) + await session.flush() + return orm_state + @override async def get_metadata(self) -> RequestQueueMetadata: """Get the metadata for this request queue.""" @@ -127,22 +115,6 @@ async def get_metadata(self) -> RequestQueueMetadata: return RequestQueueMetadata.model_validate(orm_metadata) - def get_session(self) -> AsyncSession: - """Create a new SQLAlchemy session for this request queue.""" - return self._storage_client.create_session() - - @asynccontextmanager - async def get_autocommit_session(self) -> AsyncIterator[AsyncSession]: - """Create a new SQLAlchemy autocommit session to insert, delete, or modify data.""" - async with self.get_session() as session: - try: - yield session - await session.commit() - except SQLAlchemyError as e: - logger.warning(f'Error occurred during session transaction: {e}') - # Rollback the session in case of an error - await session.rollback() - @classmethod async def open( cls, @@ -224,8 +196,6 @@ async def open( storage_client=storage_client, ) - await client._state.initialize() - return client @override @@ -236,15 +206,12 @@ async def drop(self) -> None: """ stmt = delete(RequestQueueMetadataDB).where(RequestQueueMetadataDB.id == self._id) async with self.get_autocommit_session() as autocommit: - if self._storage_client.get_default_flag(): + if self._storage_client.get_dialect_name() == 'sqlite': # foreign_keys=ON is set at the connection level. Required for cascade deletion. 
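A note on the new RequestQueueStateDB counters used throughout this client: priority is encoded purely in sequence_number, with forefront requests counting down from -1 and regular requests counting up from 1, so ORDER BY sequence_number ASC yields the newest forefront request first and everything else in FIFO order. A toy, database-free illustration of that numbering:

from dataclasses import dataclass


@dataclass
class State:  # mirrors the two counters kept per queue
    sequence_counter: int = 1
    forefront_sequence_counter: int = -1


def next_sequence(state: State, *, forefront: bool) -> int:
    if forefront:
        seq = state.forefront_sequence_counter
        state.forefront_sequence_counter -= 1
    else:
        seq = state.sequence_counter
        state.sequence_counter += 1
    return seq


state = State()
rows = []
for url, forefront in [('a', False), ('b', False), ('c', True), ('d', True)]:
    rows.append((next_sequence(state, forefront=forefront), url))

# Equivalent of ORDER BY sequence_number ASC:
print([url for _, url in sorted(rows)])  # ['d', 'c', 'a', 'b']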
await autocommit.execute(text('PRAGMA foreign_keys=ON')) # Delete the request queue metadata (cascade will delete requests) await autocommit.execute(stmt) - # Clear recoverable state - await self._state.reset() - await self._state.teardown() self._request_cache.clear() self._request_cache_needs_refresh = True self._is_empty_cache = None @@ -270,10 +237,8 @@ async def purge(self) -> None: # Clear recoverable state self._request_cache.clear() self._request_cache_needs_refresh = True - await self._state.reset() - @override - async def add_batch_of_requests( + async def _add_batch_of_requests_optimization( self, requests: Sequence[Request], *, @@ -296,94 +261,89 @@ async def add_batch_of_requests( if req.unique_key not in unique_requests: unique_requests[req.unique_key] = req - unique_keys = list(unique_requests.keys()) - # Get existing requests by unique keys stmt = ( select(RequestDB) - .where(RequestDB.queue_id == self._id, RequestDB.unique_key.in_(unique_keys)) + .where(RequestDB.queue_id == self._id, RequestDB.unique_key.in_(set(unique_requests.keys()))) .options( load_only( RequestDB.request_id, RequestDB.unique_key, RequestDB.is_handled, - RequestDB.sequence_number, ) ) ) - state = self._state.current_value - - async with self.get_session() as session, self._lock: + async with self.get_session() as session: result = await session.execute(stmt) existing_requests = {req.unique_key: req for req in result.scalars()} - - new_request_objects = [] - - # Process each request + state = await self._get_state(session) + insert_values: list[dict] = [] for unique_key, request in unique_requests.items(): existing_req_db = existing_requests.get(unique_key) - - # New request - if existing_req_db is None: + if existing_req_db is None or not existing_req_db.is_handled: + value = { + 'request_id': request.id, + 'queue_id': self._id, + 'data': request.model_dump_json(), + 'unique_key': request.unique_key, + 'is_handled': False, + } if forefront: - sequence_number = state.forefront_sequence_counter + value['sequence_number'] = state.forefront_sequence_counter state.forefront_sequence_counter -= 1 else: - sequence_number = state.sequence_counter + value['sequence_number'] = state.sequence_counter state.sequence_counter += 1 - new_request_objects.append( - RequestDB( - request_id=request.id, - queue_id=self._id, - data=request.model_dump_json(), - unique_key=request.unique_key, - sequence_number=sequence_number, - is_handled=False, + insert_values.append(value) + + if existing_req_db is None: + delta_total_request_count += 1 + delta_pending_request_count += 1 + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=False, + was_already_handled=False, + ) ) - ) - - delta_total_request_count += 1 - delta_pending_request_count += 1 - - processed_requests.append( - ProcessedRequest( - id=request.id, - unique_key=request.unique_key, - was_already_present=False, - was_already_handled=False, + else: + processed_requests.append( + ProcessedRequest( + id=request.id, + unique_key=request.unique_key, + was_already_present=True, + was_already_handled=existing_req_db.is_handled, + ) ) - ) - elif existing_req_db.is_handled: - # Already handled + else: + # Already handled request, skip adding processed_requests.append( ProcessedRequest( id=existing_req_db.request_id, - unique_key=request.unique_key, + unique_key=unique_key, was_already_present=True, was_already_handled=True, ) ) - else: - # Exists but not handled - might update priority - if forefront and 
existing_req_db.sequence_number > 0: - existing_req_db.sequence_number = state.forefront_sequence_counter - state.forefront_sequence_counter -= 1 - - processed_requests.append( - ProcessedRequest( - id=existing_req_db.request_id, - unique_key=request.unique_key, - was_already_present=True, - was_already_handled=False, - ) + if insert_values: + if forefront: + # If the request already exists in the database, we update the sequence_number by shifting request + # to the left. + upsert_stmt = self.build_upsert_stmt( + RequestDB, + insert_values, + update_columns=['sequence_number'], ) - - if new_request_objects: - session.add_all(new_request_objects) + await session.execute(upsert_stmt) + else: + # If the request already exists in the database, we ignore this request when inserting. + insert_stmt_with_ignore = self.build_insert_stmt_with_ignore(RequestDB, insert_values) + await session.execute(insert_stmt_with_ignore) await self._update_metadata( session, @@ -396,6 +356,7 @@ async def add_batch_of_requests( try: await session.commit() except SQLAlchemyError as e: + await session.rollback() logger.warning(f'Failed to commit session: {e}') processed_requests.clear() unprocessed_requests.extend( @@ -418,6 +379,18 @@ async def add_batch_of_requests( unprocessed_requests=unprocessed_requests, ) + @override + async def add_batch_of_requests( + self, + requests: Sequence[Request], + *, + forefront: bool = False, + ) -> AddRequestsResponse: + if self._storage_client.get_dialect_name() in {'sqlite', 'postgresql', 'mysql'}: + return await self._add_batch_of_requests_optimization(requests, forefront=forefront) + + raise NotImplementedError('Batch addition is not supported for this database dialect.') + @override async def get_request(self, request_id: str) -> Request | None: stmt = select(RequestDB).where(RequestDB.queue_id == self._id, RequestDB.request_id == request_id) @@ -517,44 +490,45 @@ async def reclaim_request( forefront: bool = False, ) -> ProcessedRequest | None: self._is_empty_cache = None - state = self._state.current_value if request.id not in self.in_progress_requests: logger.info(f'Reclaiming request {request.id} that is not in progress.') return None - # Update sequence number if changing priority - if forefront: - new_sequence = state.forefront_sequence_counter - state.forefront_sequence_counter -= 1 - else: - new_sequence = state.sequence_counter - state.sequence_counter += 1 + async with self.get_autocommit_session() as autocommit: + state = await self._get_state(autocommit) - stmt = ( - update(RequestDB) - .where(RequestDB.queue_id == self._id, RequestDB.request_id == request.id) - .values(sequence_number=new_sequence) - ) + # Update sequence number if changing priority + if forefront: + new_sequence = state.forefront_sequence_counter + state.forefront_sequence_counter -= 1 + else: + new_sequence = state.sequence_counter + state.sequence_counter += 1 + + stmt = ( + update(RequestDB) + .where(RequestDB.queue_id == self._id, RequestDB.request_id == request.id) + .values(sequence_number=new_sequence) + ) - async with self.get_autocommit_session() as autocommit: result = await autocommit.execute(stmt) if result.rowcount == 0: logger.warning(f'Request {request.id} not found in database.') return None - # Remove from in-progress - self.in_progress_requests.discard(request.id) + await self._update_metadata(autocommit, update_modified_at=True, update_accessed_at=True) - # Invalidate cache or add to cache - if forefront: - self._request_cache_needs_refresh = True - elif 
len(self._request_cache) < self._MAX_REQUESTS_IN_CACHE: - # For regular requests, we can add to the end if there's space - self._request_cache.append(request) + # Remove from in-progress + self.in_progress_requests.discard(request.id) - await self._update_metadata(autocommit, update_modified_at=True, update_accessed_at=True) + # Invalidate cache or add to cache + if forefront: + self._request_cache_needs_refresh = True + elif len(self._request_cache) < self._MAX_REQUESTS_IN_CACHE: + # For regular requests, we can add to the end if there's space + self._request_cache.append(request) return ProcessedRequest( id=request.id, diff --git a/src/crawlee/storage_clients/_sql/_storage_client.py b/src/crawlee/storage_clients/_sql/_storage_client.py index 168d41635f..1bb8df58ad 100644 --- a/src/crawlee/storage_clients/_sql/_storage_client.py +++ b/src/crawlee/storage_clients/_sql/_storage_client.py @@ -70,6 +70,7 @@ def __init__( # Flag needed to apply optimizations only for default database self._default_flag = self._engine is None and self._connection_string is None + self._dialect_name: str | None = None @property def engine(self) -> AsyncEngine: @@ -78,9 +79,9 @@ def engine(self) -> AsyncEngine: raise ValueError('Engine is not initialized. Call initialize() before accessing the engine.') return self._engine - def get_default_flag(self) -> bool: - """Check if the default database is being used.""" - return self._default_flag + def get_dialect_name(self) -> str | None: + """Get the database dialect name.""" + return self._dialect_name def get_accessed_modified_update_interval(self) -> timedelta: """Get the interval for accessed and modified updates.""" @@ -127,6 +128,7 @@ async def initialize(self, configuration: Configuration) -> None: engine = self._get_or_create_engine(configuration) async with engine.begin() as conn: # Set SQLite pragmas for performance and consistency + self._dialect_name = engine.dialect.name if self._default_flag: await conn.execute(text('PRAGMA journal_mode=WAL')) # Better concurrency await conn.execute(text('PRAGMA synchronous=NORMAL')) # Balanced safety/speed diff --git a/uv.lock b/uv.lock index 06d260ffd8..82fd9fee77 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.13'", @@ -586,7 +586,7 @@ toml = [ [[package]] name = "crawlee" -version = "0.6.12" +version = "0.6.13" source = { editable = "." 
} dependencies = [ { name = "cachetools" }, From 6401b654416ca028f8d8f8546ee5e091d6b96340 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Wed, 20 Aug 2025 17:12:03 +0000 Subject: [PATCH 16/29] up pyproject --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 727cb56ebf..974589a2f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ otel = [ "wrapt>=1.17.0", ] sql = [ - "sqlalchemy[asyncio]>=2.0.42,<3.0.0", + "sqlalchemy[asyncio]~=2.0.0,<3.0.0", "aiosqlite>=0.21.0", ] diff --git a/uv.lock b/uv.lock index a875eabf2c..b8d1657093 100644 --- a/uv.lock +++ b/uv.lock @@ -740,7 +740,7 @@ requires-dist = [ { name = "pyee", specifier = ">=9.0.0" }, { name = "rich", marker = "extra == 'cli'", specifier = ">=13.9.0" }, { name = "scikit-learn", marker = "extra == 'adaptive-crawler'", specifier = ">=1.6.0" }, - { name = "sqlalchemy", extras = ["asyncio"], marker = "extra == 'sql'", specifier = ">=2.0.42,<3.0.0" }, + { name = "sqlalchemy", extras = ["asyncio"], marker = "extra == 'sql'", specifier = "~=2.0.0,<3.0.0" }, { name = "tldextract", specifier = ">=5.1.0" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.0" }, { name = "typing-extensions", specifier = ">=4.1.0" }, From df927d10e66657eaa8c367598c45efe2e7517dfc Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Thu, 21 Aug 2025 20:03:15 +0000 Subject: [PATCH 17/29] refactor --- .../storage_clients/_sql/_client_mixin.py | 268 ++++++++++++++- .../storage_clients/_sql/_dataset_client.py | 155 ++------- .../storage_clients/_sql/_db_models.py | 25 +- .../_sql/_key_value_store_client.py | 180 +++------- .../_sql/_request_queue_client.py | 321 +++++++++--------- .../_sql/test_sql_dataset_client.py | 8 +- .../_sql/test_sql_rq_client.py | 6 +- 7 files changed, 515 insertions(+), 448 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_client_mixin.py b/src/crawlee/storage_clients/_sql/_client_mixin.py index d7025db820..eba510a36c 100644 --- a/src/crawlee/storage_clients/_sql/_client_mixin.py +++ b/src/crawlee/storage_clients/_sql/_client_mixin.py @@ -1,20 +1,36 @@ from __future__ import annotations from contextlib import asynccontextmanager +from datetime import datetime, timezone from logging import getLogger -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar +from sqlalchemy import delete, select, text, update from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.sqlite import insert as lite_insert from sqlalchemy.exc import SQLAlchemyError +from crawlee._utils.crypto import crypto_random_object_id + if TYPE_CHECKING: from collections.abc import AsyncIterator from sqlalchemy import Insert from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.orm import DeclarativeBase + from typing_extensions import Self + + from crawlee.storage_clients.models import DatasetMetadata, KeyValueStoreMetadata, RequestQueueMetadata + from ._db_models import ( + DatasetItemDB, + DatasetMetadataDB, + KeyValueStoreMetadataDB, + KeyValueStoreRecordDB, + RequestDB, + RequestQueueMetadataDB, + ) from ._storage_client import SQLStorageClient @@ -22,14 +38,46 @@ class SQLClientMixin: - """Mixin class for SQL clients.""" + """Mixin class for SQL clients. + + This mixin provides common SQL operations and basic methods for SQL storage clients. 
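The key idea of this refactor is that the mixin is parameterized by class-level hooks (_METADATA_TABLE, _ITEM_TABLE, _CLIENT_TYPE, declared just below) that each concrete client fills in, so the shared SQL code never names a specific storage. A minimal sketch of that pattern with purely illustrative stand-ins:

from typing import ClassVar


class StorageClientMixinSketch:
    _ITEM_TABLE: ClassVar[type]   # filled in by each concrete client
    _CLIENT_TYPE: ClassVar[str]

    def describe(self) -> str:
        # Shared code can refer to the concrete client's table via the class attribute.
        return f'{self._CLIENT_TYPE} backed by {self._ITEM_TABLE.__name__}'


class DatasetItemRow:  # stand-in for an ORM model
    pass


class DatasetClientSketch(StorageClientMixinSketch):
    _ITEM_TABLE = DatasetItemRow
    _CLIENT_TYPE = 'Dataset'


print(DatasetClientSketch().describe())  # Dataset backed by DatasetItemRow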
+ """ + + _DEFAULT_NAME: ClassVar[str] + """Default name when none provided.""" + + _METADATA_TABLE: ClassVar[type[DatasetMetadataDB | KeyValueStoreMetadataDB | RequestQueueMetadataDB]] + """SQLAlchemy model for metadata.""" + + _ITEM_TABLE: ClassVar[type[DatasetItemDB | KeyValueStoreRecordDB | RequestDB]] + """SQLAlchemy model for items.""" - _storage_client: SQLStorageClient + _CLIENT_TYPE: ClassVar[str] + """Human-readable client type for error messages.""" + + def __init__(self, *, id: str, storage_client: SQLStorageClient) -> None: + self._id = id + self._storage_client = storage_client + + # Time tracking to reduce database writes during frequent operation + self._accessed_at_allow_update_after: datetime | None = None + self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval() def get_session(self) -> AsyncSession: - """Create a new SQLAlchemy session for this request queue.""" + """Create a new SQLAlchemy session for this storage.""" return self._storage_client.create_session() + async def _get_metadata( + self, metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata] + ) -> DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata: + """Retrieve client metadata.""" + async with self.get_session() as session: + orm_metadata = await session.get(self._METADATA_TABLE, self._id) + if not orm_metadata: + raise ValueError(f'Dataset with ID "{self._id}" not found.') + + return metadata_model.model_validate(orm_metadata) + @asynccontextmanager async def get_autocommit_session(self) -> AsyncIterator[AsyncSession]: """Create a new SQLAlchemy autocommit session to insert, delete, or modify data.""" @@ -39,11 +87,17 @@ async def get_autocommit_session(self) -> AsyncIterator[AsyncSession]: await session.commit() except SQLAlchemyError as e: logger.warning(f'Error occurred during session transaction: {e}') - # Rollback the session in case of an error await session.rollback() - def build_insert_stmt_with_ignore(self, table_model: Any, insert_values: dict | list[dict]) -> Insert: - """Build an insert statement with ignore for the SQL dialect.""" + def build_insert_stmt_with_ignore( + self, table_model: type[DeclarativeBase], insert_values: dict[str, Any] | list[dict[str, Any]] + ) -> Insert: + """Build an insert statement with ignore for the SQL dialect. + + Args: + table_model: SQLAlchemy table model + insert_values: Single dict or list of dicts to insert + """ if isinstance(insert_values, dict): insert_values = [insert_values] @@ -62,11 +116,20 @@ def build_insert_stmt_with_ignore(self, table_model: Any, insert_values: dict | def build_upsert_stmt( self, - table_model: Any, - insert_values: dict | list[dict], + table_model: type[DeclarativeBase], + insert_values: dict[str, Any] | list[dict[str, Any]], update_columns: list[str], conflict_cols: list[str] | None = None, ) -> Insert: + """Build an upsert statement for the SQL dialect. + + Args: + table_model: SQLAlchemy table model + insert_values: Single dict or list of dicts to upsert + update_columns: Column names to update on conflict + conflict_cols: Column names that define uniqueness (for PostgreSQL/SQLite) + + """ if isinstance(insert_values, dict): insert_values = [insert_values] @@ -88,3 +151,190 @@ def build_upsert_stmt( return mysql_stmt.on_duplicate_key_update(**set_) raise NotImplementedError(f'Upsert not supported for dialect: {dialect}') + + async def _purge(self, metadata_kwargs: dict[str, Any]) -> None: + """Drop all items in storage and update metadata. 
+ + Args: + metadata_kwargs: Arguments to pass to _update_metadata + """ + stmt = delete(self._ITEM_TABLE).where(self._ITEM_TABLE.metadata_id == self._id) + async with self.get_autocommit_session() as autocommit: + await autocommit.execute(stmt) + await self._update_metadata(autocommit, **metadata_kwargs) + + async def _drop(self) -> None: + """Delete this storage and all its data. + + This operation is irreversible. Uses CASCADE deletion to remove all related items. + """ + stmt = delete(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id) + async with self.get_autocommit_session() as autocommit: + if self._storage_client.get_dialect_name() == 'sqlite': + # foreign_keys=ON is set at the connection level. Required for cascade deletion. + await autocommit.execute(text('PRAGMA foreign_keys=ON')) + await autocommit.execute(stmt) + + def _default_update_metadata( + self, *, update_accessed_at: bool = False, update_modified_at: bool = False + ) -> dict[str, Any]: + """Prepare common metadata updates with rate limiting. + + Args: + update_accessed_at: Whether to update accessed_at timestamp + update_modified_at: Whether to update modified_at timestamp + """ + now = datetime.now(timezone.utc) + values_to_set: dict[str, Any] = {} + + if update_accessed_at and ( + self._accessed_at_allow_update_after is None or (now >= self._accessed_at_allow_update_after) + ): + values_to_set['accessed_at'] = now + self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval + + if update_modified_at: + values_to_set['modified_at'] = now + + return values_to_set + + def _specific_update_metadata(self, **kwargs: Any) -> dict[str, Any]: + """Prepare storage-specific metadata updates. + + Must be implemented by concrete classes. + + Args: + **kwargs: Storage-specific update parameters + """ + raise NotImplementedError('Method _specific_update_metadata must be implemented in the client class.') + + async def _update_metadata( + self, + session: AsyncSession, + *, + update_accessed_at: bool = False, + update_modified_at: bool = False, + **kwargs: Any, + ) -> bool: + """Update storage metadata combining common and specific fields. + + Args: + session: Active database session + update_accessed_at: Whether to update accessed_at timestamp + update_modified_at: Whether to update modified_at timestamp + **kwargs: Additional arguments for _specific_update_metadata + + Returns: + True if any updates were made, False otherwise + """ + values_to_set = self._default_update_metadata( + update_accessed_at=update_accessed_at, update_modified_at=update_modified_at + ) + + values_to_set.update(self._specific_update_metadata(**kwargs)) + + if values_to_set: + stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id).values(**values_to_set) + await session.execute(stmt) + return True + + return False + + @classmethod + async def _open( + cls, + *, + id: str | None, + name: str | None, + storage_client: SQLStorageClient, + metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata], + session: AsyncSession, + extra_metadata_fields: dict[str, Any], + ) -> Self: + """Open existing storage or create new one. + + Internal method used by _safely_open. 
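The accessed_at handling in _default_update_metadata above is a write throttle: the timestamp is only refreshed once per configured interval, so hot code paths do not hammer the metadata table. The same logic in isolation (the interval value is illustrative):

from datetime import datetime, timedelta, timezone


class AccessedAtThrottle:
    """Allow an accessed_at write at most once per interval (sketch of the logic above)."""

    def __init__(self, interval: timedelta) -> None:
        self._interval = interval
        self._allow_update_after: datetime | None = None

    def should_write(self) -> bool:
        now = datetime.now(timezone.utc)
        if self._allow_update_after is None or now >= self._allow_update_after:
            self._allow_update_after = now + self._interval
            return True
        return False


throttle = AccessedAtThrottle(timedelta(seconds=1))
print([throttle.should_write() for _ in range(3)])  # [True, False, False]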
+ + Args: + id: Storage ID to open (takes precedence over name) + name: Storage name to open + storage_client: SQL storage client instance + metadata_model: Pydantic model for metadata validation + session: Active database session + extra_metadata_fields: Storage-specific metadata fields + """ + orm_metadata: DatasetMetadataDB | KeyValueStoreMetadataDB | RequestQueueMetadataDB | None = None + if id: + orm_metadata = await session.get(cls._METADATA_TABLE, id) + if not orm_metadata: + raise ValueError(f'{cls._CLIENT_TYPE} with ID "{id}" not found.') + else: + search_name = name or cls._DEFAULT_NAME + stmt = select(cls._METADATA_TABLE).where(cls._METADATA_TABLE.name == search_name) + result = await session.execute(stmt) + orm_metadata = result.scalar_one_or_none() # type: ignore[assignment] + + if orm_metadata: + client = cls(id=orm_metadata.id, storage_client=storage_client) + await client._update_metadata(session, update_accessed_at=True) + else: + now = datetime.now(timezone.utc) + metadata = metadata_model( + id=crypto_random_object_id(), + name=name, + created_at=now, + accessed_at=now, + modified_at=now, + **extra_metadata_fields, + ) + client = cls(id=metadata.id, storage_client=storage_client) + session.add(cls._METADATA_TABLE(**metadata.model_dump())) + + return client + + @classmethod + async def _safely_open( + cls, + *, + id: str | None, + name: str | None, + storage_client: SQLStorageClient, + metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata], + extra_metadata_fields: dict[str, Any], + ) -> Self: + """Safely open storage with transaction handling. + + Args: + id: Storage ID to open (takes precedence over name) + name: Storage name to open + storage_client: SQL storage client instance + client_class: Concrete client class to instantiate + metadata_model: Pydantic model for metadata validation + extra_metadata_fields: Storage-specific metadata fields + """ + async with storage_client.create_session() as session: + try: + client = await cls._open( + id=id, + name=name, + storage_client=storage_client, + metadata_model=metadata_model, + session=session, + extra_metadata_fields=extra_metadata_fields, + ) + await session.commit() + except SQLAlchemyError: + await session.rollback() + + search_name = name or cls._DEFAULT_NAME + stmt = select(cls._METADATA_TABLE).where(cls._METADATA_TABLE.name == search_name) + result = await session.execute(stmt) + orm_metadata: DatasetMetadataDB | KeyValueStoreMetadataDB | RequestQueueMetadataDB | None + orm_metadata = result.scalar_one_or_none() # type: ignore[assignment] + + if not orm_metadata: + raise ValueError(f'{cls._CLIENT_TYPE} with Name "{search_name}" not found.') from None + + client = cls(id=orm_metadata.id, storage_client=storage_client) + + return client diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py index 9ac20531f1..171cdadbc9 100644 --- a/src/crawlee/storage_clients/_sql/_dataset_client.py +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -1,15 +1,12 @@ from __future__ import annotations import json -from datetime import datetime, timezone from logging import getLogger -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast -from sqlalchemy import Select, delete, insert, select, text, update -from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy import Select, insert, select from typing_extensions import override -from crawlee._utils.crypto import crypto_random_object_id from 
crawlee.storage_clients._base import DatasetClient from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata @@ -20,10 +17,10 @@ from collections.abc import AsyncIterator from sqlalchemy import Select - from sqlalchemy.ext.asyncio import AsyncSession from ._storage_client import SQLStorageClient + logger = getLogger(__name__) @@ -42,9 +39,18 @@ class SQLDatasetClient(DatasetClient, SQLClientMixin): All operations are wrapped in database transactions with CASCADE deletion support. """ - _DEFAULT_NAME_DB = 'default' + _DEFAULT_NAME = 'default' """Default dataset name used when no name is provided.""" + _METADATA_TABLE = DatasetMetadataDB + """SQLAlchemy model for dataset metadata.""" + + _ITEM_TABLE = DatasetItemDB + """SQLAlchemy model for dataset items.""" + + _CLIENT_TYPE = 'Dataset' + """Human-readable client type for error messages.""" + def __init__( self, *, @@ -55,24 +61,14 @@ def __init__( Preferably use the `SqlDatasetClient.open` class method to create a new instance. """ - self._id = id - self._storage_client = storage_client - - # Time tracking to reduce database writes during frequent operation - self._last_accessed_at: datetime | None = None - self._last_modified_at: datetime | None = None - self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval() + super().__init__(id=id, storage_client=storage_client) @override async def get_metadata(self) -> DatasetMetadata: """Get dataset metadata from the database.""" # The database is a single place of truth - async with self.get_session() as session: - orm_metadata: DatasetMetadataDB | None = await session.get(DatasetMetadataDB, self._id) - if not orm_metadata: - raise ValueError(f'Dataset with ID "{self._id}" not found.') - - return DatasetMetadata.model_validate(orm_metadata) + metadata = await self._get_metadata(DatasetMetadata) + return cast('DatasetMetadata', metadata) @classmethod async def open( @@ -95,60 +91,13 @@ async def open( Raises: ValueError: If a dataset with the specified ID is not found. """ - async with storage_client.create_session() as session: - orm_metadata: DatasetMetadataDB | None = None - if id: - orm_metadata = await session.get(DatasetMetadataDB, id) - if not orm_metadata: - raise ValueError(f'Dataset with ID "{id}" not found.') - else: - search_name = name or cls._DEFAULT_NAME_DB - stmt = select(DatasetMetadataDB).where(DatasetMetadataDB.name == search_name) - result = await session.execute(stmt) - orm_metadata = result.scalar_one_or_none() - - if orm_metadata: - client = cls( - id=orm_metadata.id, - storage_client=storage_client, - ) - await client._update_metadata(session, update_accessed_at=True) - else: - now = datetime.now(timezone.utc) - metadata = DatasetMetadata( - id=crypto_random_object_id(), - name=name, - created_at=now, - accessed_at=now, - modified_at=now, - item_count=0, - ) - - client = cls( - id=metadata.id, - storage_client=storage_client, - ) - session.add(DatasetMetadataDB(**metadata.model_dump())) - - try: - # Commit the insert or update metadata to the database - await session.commit() - except SQLAlchemyError: - # Attempt to open simultaneously by different clients. - # The commit that created the record has already been executed, make rollback and get by name. 
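The block being removed here is the per-client copy of the open-or-create race handling; the shared _safely_open now does the same thing once: try to insert the metadata row and, if another client won the race, roll back and re-read by name. A hedged sketch of that pattern, assuming a unique constraint on the name column and using illustrative helper names:

from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError


async def open_or_create(session_factory, table, name: str, make_row):
    """Illustrative helper: return the row named `name`, creating it if missing."""
    async with session_factory() as session:
        stmt = select(table).where(table.name == name)
        row = (await session.execute(stmt)).scalar_one_or_none()
        if row is not None:
            return row
        try:
            row = make_row(name)
            session.add(row)
            await session.commit()
        except SQLAlchemyError:
            # Another client inserted the same name first: fall back to reading theirs.
            await session.rollback()
            row = (await session.execute(stmt)).scalar_one()
        return row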
- await session.rollback() - search_name = name or cls._DEFAULT_NAME_DB - stmt = select(DatasetMetadataDB).where(DatasetMetadataDB.name == search_name) - result = await session.execute(stmt) - orm_metadata = result.scalar_one_or_none() - if not orm_metadata: - raise ValueError(f'Dataset with Name "{search_name}" not found.') from None - client = cls( - id=orm_metadata.id, - storage_client=storage_client, - ) - - return client + return await cls._safely_open( + id=id, + name=name, + storage_client=storage_client, + metadata_model=DatasetMetadata, + extra_metadata_fields={'itemCount': 0}, + ) @override async def drop(self) -> None: @@ -156,12 +105,7 @@ async def drop(self) -> None: This operation is irreversible. Uses CASCADE deletion to remove all related items. """ - stmt = delete(DatasetMetadataDB).where(DatasetMetadataDB.id == self._id) - async with self.get_autocommit_session() as autocommit: - if self._storage_client.get_dialect_name() == 'sqlite': - # foreign_keys=ON is set at the connection level. Required for cascade deletion. - await autocommit.execute(text('PRAGMA foreign_keys=ON')) - await autocommit.execute(stmt) + await self._drop() @override async def purge(self) -> None: @@ -169,11 +113,7 @@ async def purge(self) -> None: Resets item_count to 0 and deletes all records from dataset_item table. """ - stmt = delete(DatasetItemDB).where(DatasetItemDB.dataset_id == self._id) - async with self.get_autocommit_session() as autocommit: - await autocommit.execute(stmt) - - await self._update_metadata(autocommit, new_item_count=0, update_accessed_at=True, update_modified_at=True) + await self._purge(metadata_kwargs={'new_item_count': 0, 'update_accessed_at': True, 'update_modified_at': True}) @override async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: @@ -188,12 +128,12 @@ async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: json_item = json.dumps(item, default=str, ensure_ascii=False) db_items.append( { - 'dataset_id': self._id, + 'metadata_id': self._id, 'data': json_item, } ) - stmt = insert(DatasetItemDB).values(db_items) + stmt = insert(self._ITEM_TABLE).values(db_items) async with self.get_autocommit_session() as autocommit: await autocommit.execute(stmt) @@ -235,14 +175,16 @@ def _prepare_get_stmt( f'{self.__class__.__name__} client.' ) - stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == self._id) + stmt = select(self._ITEM_TABLE).where(self._ITEM_TABLE.metadata_id == self._id) if skip_empty: # Skip items that are empty JSON objects - stmt = stmt.where(DatasetItemDB.data != '"{}"') + stmt = stmt.where(self._ITEM_TABLE.data != '"{}"') # Apply ordering by insertion order (order_id) - stmt = stmt.order_by(DatasetItemDB.order_id.desc()) if desc else stmt.order_by(DatasetItemDB.order_id.asc()) + stmt = ( + stmt.order_by(self._ITEM_TABLE.order_id.desc()) if desc else stmt.order_by(self._ITEM_TABLE.order_id.asc()) + ) return stmt.offset(offset).limit(limit) @@ -337,48 +279,25 @@ async def iterate_items( if updated: await session.commit() - async def _update_metadata( + def _specific_update_metadata( self, - session: AsyncSession, - *, new_item_count: int | None = None, - update_accessed_at: bool = False, - update_modified_at: bool = False, delta_item_count: int | None = None, - ) -> bool: + **_kwargs: dict[str, Any], + ) -> dict[str, Any]: """Update the dataset metadata in the database. Args: session: The SQLAlchemy AsyncSession to use for the update. new_item_count: If provided, set item count to this value. 
- update_accessed_at: If True, update the accessed_at timestamp. - update_modified_at: If True, update the modified_at timestamp. delta_item_count: If provided, add this value to the current item count. """ - now = datetime.now(timezone.utc) values_to_set: dict[str, Any] = {} - if update_accessed_at and ( - self._last_accessed_at is None or (now - self._last_accessed_at) > self._accessed_modified_update_interval - ): - values_to_set['accessed_at'] = now - self._last_accessed_at = now - - if update_modified_at and ( - self._last_modified_at is None or (now - self._last_modified_at) > self._accessed_modified_update_interval - ): - values_to_set['modified_at'] = now - self._last_modified_at = now - if new_item_count is not None: values_to_set['item_count'] = new_item_count elif delta_item_count: # Use database-level for atomic updates - values_to_set['item_count'] = DatasetMetadataDB.item_count + delta_item_count - - if values_to_set: - stmt = update(DatasetMetadataDB).where(DatasetMetadataDB.id == self._id).values(**values_to_set) - await session.execute(stmt) - return True + values_to_set['item_count'] = self._METADATA_TABLE.item_count + delta_item_count - return False + return values_to_set diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py index b41eecc3a0..e5aee1b27a 100644 --- a/src/crawlee/storage_clients/_sql/_db_models.py +++ b/src/crawlee/storage_clients/_sql/_db_models.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING -from sqlalchemy import JSON, Boolean, ForeignKey, Index, Integer, LargeBinary, String +from sqlalchemy import JSON, BigInteger, Boolean, ForeignKey, Index, Integer, LargeBinary, String from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.types import DateTime, TypeDecorator from typing_extensions import override @@ -130,7 +130,7 @@ class KeyValueStoreRecordDB(Base): __tablename__ = 'kvs_record' - kvs_id: Mapped[str] = mapped_column( + metadata_id: Mapped[str] = mapped_column( String(255), ForeignKey('kvs_metadata.id', ondelete='CASCADE'), primary_key=True, index=True ) """Foreign key to metadata key-value store record.""" @@ -159,7 +159,7 @@ class DatasetItemDB(Base): order_id: Mapped[int] = mapped_column(Integer, primary_key=True) """Auto-increment primary key preserving insertion order.""" - dataset_id: Mapped[str] = mapped_column( + metadata_id: Mapped[str] = mapped_column( String(20), ForeignKey('dataset_metadata.id', ondelete='CASCADE'), index=True, @@ -179,15 +179,13 @@ class RequestDB(Base): __tablename__ = 'request' __table_args__ = ( # Index for efficient SELECT to cache - Index('idx_queue_handled_seq', 'queue_id', 'is_handled', 'sequence_number'), - # Deduplication index - Index('idx_queue_unique_key', 'queue_id', 'unique_key'), + Index('idx_queue_handled_seq', 'metadata_id', 'is_handled', 'sequence_number'), ) - request_id: Mapped[str] = mapped_column(String(20), primary_key=True) - """Unique identifier for the request.""" + request_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + """Unique identifier for the request representing the unique_key.""" - queue_id: Mapped[str] = mapped_column( + metadata_id: Mapped[str] = mapped_column( String(20), ForeignKey('request_queue_metadata.id', ondelete='CASCADE'), primary_key=True ) """Foreign key to metadata request queue record.""" @@ -195,16 +193,13 @@ class RequestDB(Base): data: Mapped[str] = mapped_column(JSON, nullable=False) """JSON-serialized Request 
object.""" - unique_key: Mapped[str] = mapped_column(String(512), nullable=False) - """Request unique key for deduplication within queue.""" - sequence_number: Mapped[int] = mapped_column(Integer, nullable=False) """Ordering sequence: negative for forefront, positive for regular.""" is_handled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) """Processing status flag.""" - # Relationship back to parent queue + # Relationship back to metadata table queue: Mapped[RequestQueueMetadataDB] = relationship(back_populates='requests') @@ -213,7 +208,7 @@ class RequestQueueStateDB(Base): __tablename__ = 'request_queue_state' - queue_id: Mapped[str] = mapped_column( + metadata_id: Mapped[str] = mapped_column( String(20), ForeignKey('request_queue_metadata.id', ondelete='CASCADE'), primary_key=True ) """Foreign key to metadata request queue record.""" @@ -224,5 +219,5 @@ class RequestQueueStateDB(Base): forefront_sequence_counter: Mapped[int] = mapped_column(Integer, nullable=False, default=-1) """Counter for forefront request ordering (negative).""" - # Relationship back to parent queue + # Relationship back to metadata table queue: Mapped[RequestQueueMetadataDB] = relationship(back_populates='state') diff --git a/src/crawlee/storage_clients/_sql/_key_value_store_client.py b/src/crawlee/storage_clients/_sql/_key_value_store_client.py index 9cb432b0a4..6b4cda2219 100644 --- a/src/crawlee/storage_clients/_sql/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_sql/_key_value_store_client.py @@ -1,15 +1,13 @@ from __future__ import annotations import json -from datetime import datetime, timezone from logging import getLogger -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast -from sqlalchemy import delete, insert, select, text, update -from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy import delete, insert, select, update +from sqlalchemy.exc import IntegrityError from typing_extensions import override -from crawlee._utils.crypto import crypto_random_object_id from crawlee._utils.file import infer_mime_type from crawlee.storage_clients._base import KeyValueStoreClient from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata @@ -20,8 +18,6 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from sqlalchemy.ext.asyncio import AsyncSession - from ._storage_client import SQLStorageClient @@ -47,12 +43,21 @@ class SQLKeyValueStoreClient(KeyValueStoreClient, SQLClientMixin): All database operations are wrapped in transactions with proper error handling and rollback mechanisms. The client supports atomic upsert operations and handles race conditions when - multiple clients access the same store using composite primary keys (kvs_id, key). + multiple clients access the same store using composite primary keys (metadata_id, key). """ - _DEFAULT_NAME_DB = 'default' + _DEFAULT_NAME = 'default' """Default dataset name used when no name is provided.""" + _METADATA_TABLE = KeyValueStoreMetadataDB + """SQLAlchemy model for key-value store metadata.""" + + _ITEM_TABLE = KeyValueStoreRecordDB + """SQLAlchemy model for key-value store items.""" + + _CLIENT_TYPE = 'Key-value store' + """Human-readable client type for error messages.""" + def __init__( self, *, @@ -63,26 +68,14 @@ def __init__( Preferably use the `SQLKeyValueStoreClient.open` class method to create a new instance. 
""" - self._id = id - - self._storage_client = storage_client - """The storage client used to access the SQL database.""" - - # Time tracking to reduce database writes during frequent operation - self._last_accessed_at: datetime | None = None - self._last_modified_at: datetime | None = None - self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval() + super().__init__(id=id, storage_client=storage_client) @override async def get_metadata(self) -> KeyValueStoreMetadata: """Get the metadata for this key-value store.""" # The database is a single place of truth - async with self.get_session() as session: - orm_metadata: KeyValueStoreMetadataDB | None = await session.get(KeyValueStoreMetadataDB, self._id) - if not orm_metadata: - raise ValueError(f'Key-value store with ID "{self._id}" not found.') - - return KeyValueStoreMetadata.model_validate(orm_metadata) + metadata = await self._get_metadata(KeyValueStoreMetadata) + return cast('KeyValueStoreMetadata', metadata) @classmethod async def open( @@ -109,56 +102,13 @@ async def open( Raises: ValueError: If a store with the specified ID is not found, or if metadata is invalid. """ - async with storage_client.create_session() as session: - orm_metadata: KeyValueStoreMetadataDB | None = None - if id: - orm_metadata = await session.get(KeyValueStoreMetadataDB, id) - if not orm_metadata: - raise ValueError(f'Key-value store with ID "{id}" not found.') - else: - search_name = name or cls._DEFAULT_NAME_DB - stmt = select(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.name == search_name) - result = await session.execute(stmt) - orm_metadata = result.scalar_one_or_none() - if orm_metadata: - client = cls( - id=orm_metadata.id, - storage_client=storage_client, - ) - await client._update_metadata(session, update_accessed_at=True) - else: - now = datetime.now(timezone.utc) - metadata = KeyValueStoreMetadata( - id=crypto_random_object_id(), - name=name, - created_at=now, - accessed_at=now, - modified_at=now, - ) - client = cls( - id=metadata.id, - storage_client=storage_client, - ) - session.add(KeyValueStoreMetadataDB(**metadata.model_dump())) - - try: - # Commit the insert or update metadata to the database - await session.commit() - except SQLAlchemyError: - # Attempt to open simultaneously by different clients. - # The commit that created the record has already been executed, make rollback and get by name. - await session.rollback() - search_name = name or cls._DEFAULT_NAME_DB - stmt = select(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.name == search_name) - result = await session.execute(stmt) - orm_metadata = result.scalar_one_or_none() - if not orm_metadata: - raise ValueError(f'Key-value store with Name "{search_name}" not found.') from None - client = cls( - id=orm_metadata.id, - storage_client=storage_client, - ) - return client + return await cls._safely_open( + id=id, + name=name, + storage_client=storage_client, + metadata_model=KeyValueStoreMetadata, + extra_metadata_fields={}, + ) @override async def drop(self) -> None: @@ -166,21 +116,12 @@ async def drop(self) -> None: This operation is irreversible. Uses CASCADE deletion to remove all related records. """ - stmt = delete(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.id == self._id) - async with self.get_autocommit_session() as autocommit: - if self._storage_client.get_dialect_name() == 'sqlite': - # foreign_keys=ON is set at the connection level. Required for cascade deletion. 
- await autocommit.execute(text('PRAGMA foreign_keys=ON')) - await autocommit.execute(stmt) + await self._drop() @override async def purge(self) -> None: """Remove all items from this key-value store while keeping the key-value store structure.""" - stmt = delete(KeyValueStoreRecordDB).filter_by(kvs_id=self._id) - async with self.get_autocommit_session() as autocommit: - await autocommit.execute(stmt) - - await self._update_metadata(autocommit, update_accessed_at=True, update_modified_at=True) + await self._purge(metadata_kwargs={'update_accessed_at': True, 'update_modified_at': True}) @override async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: @@ -205,7 +146,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No size = len(value_bytes) insert_values = { - 'kvs_id': self._id, + 'metadata_id': self._id, 'key': key, 'value': value_bytes, 'content_type': content_type, @@ -214,20 +155,20 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No try: # Trying to build a statement for Upsert upsert_stmt = self.build_upsert_stmt( - KeyValueStoreRecordDB, + self._ITEM_TABLE, insert_values=insert_values, update_columns=['value', 'content_type', 'size'], - conflict_cols=['kvs_id', 'key'], + conflict_cols=['metadata_id', 'key'], ) except NotImplementedError: # If it is not possible to build an upsert for the current dialect, build an update + insert. upsert_stmt = None update_stmt = ( - update(KeyValueStoreRecordDB) - .where(KeyValueStoreRecordDB.kvs_id == self._id, KeyValueStoreRecordDB.key == key) + update(self._ITEM_TABLE) + .where(self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.key == key) .values(value=value_bytes, content_type=content_type, size=size) ) - insert_stmt = insert(KeyValueStoreRecordDB).values(**insert_values) + insert_stmt = insert(self._ITEM_TABLE).values(**insert_values) async with self.get_session() as session: if upsert_stmt is not None: @@ -249,9 +190,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: """Get a value from the key-value store.""" # Query the record by key - stmt = select(KeyValueStoreRecordDB).where( - KeyValueStoreRecordDB.kvs_id == self._id, KeyValueStoreRecordDB.key == key - ) + stmt = select(self._ITEM_TABLE).where(self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.key == key) async with self.get_session() as session: result = await session.execute(stmt) record_db = result.scalar_one_or_none() @@ -299,9 +238,7 @@ async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: @override async def delete_value(self, *, key: str) -> None: """Delete a value from the key-value store.""" - stmt = delete(KeyValueStoreRecordDB).where( - KeyValueStoreRecordDB.kvs_id == self._id, KeyValueStoreRecordDB.key == key - ) + stmt = delete(self._ITEM_TABLE).where(self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.key == key) async with self.get_autocommit_session() as autocommit: # Delete the record if it exists result = await autocommit.execute(stmt) @@ -322,14 +259,14 @@ async def iterate_keys( """Iterate over the existing keys in the key-value store.""" # Build query for record metadata stmt = ( - select(KeyValueStoreRecordDB.key, KeyValueStoreRecordDB.content_type, KeyValueStoreRecordDB.size) - .where(KeyValueStoreRecordDB.kvs_id == self._id) - .order_by(KeyValueStoreRecordDB.key) + select(self._ITEM_TABLE.key, 
self._ITEM_TABLE.content_type, self._ITEM_TABLE.size) + .where(self._ITEM_TABLE.metadata_id == self._id) + .order_by(self._ITEM_TABLE.key) ) # Apply exclusive_start_key filter if exclusive_start_key is not None: - stmt = stmt.where(KeyValueStoreRecordDB.key > exclusive_start_key) + stmt = stmt.where(self._ITEM_TABLE.key > exclusive_start_key) # Apply limit if limit is not None: @@ -354,9 +291,7 @@ async def iterate_keys( @override async def record_exists(self, *, key: str) -> bool: """Check if a record with the given key exists in the key-value store.""" - stmt = select(KeyValueStoreRecordDB.key).where( - KeyValueStoreRecordDB.kvs_id == self._id, KeyValueStoreRecordDB.key == key - ) + stmt = select(self._ITEM_TABLE.key).where(self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.key == key) async with self.get_session() as session: # Check if record exists result = await session.execute(stmt) @@ -373,38 +308,5 @@ async def record_exists(self, *, key: str) -> bool: async def get_public_url(self, *, key: str) -> str: raise NotImplementedError('Public URLs are not supported for memory key-value stores.') - async def _update_metadata( - self, - session: AsyncSession, - *, - update_accessed_at: bool = False, - update_modified_at: bool = False, - ) -> bool: - """Update the KVS metadata in the database. - - Args: - session: The SQLAlchemy AsyncSession to use for the update. - update_accessed_at: If True, update the `accessed_at` timestamp to the current time. - update_modified_at: If True, update the `modified_at` timestamp to the current time. - """ - now = datetime.now(timezone.utc) - values_to_set: dict[str, Any] = {} - - if update_accessed_at and ( - self._last_accessed_at is None or (now - self._last_accessed_at) > self._accessed_modified_update_interval - ): - values_to_set['accessed_at'] = now - self._last_accessed_at = now - - if update_modified_at and ( - self._last_modified_at is None or (now - self._last_modified_at) > self._accessed_modified_update_interval - ): - values_to_set['modified_at'] = now - self._last_modified_at = now - - if values_to_set: - stmt = update(KeyValueStoreMetadataDB).where(KeyValueStoreMetadataDB.id == self._id).values(**values_to_set) - await session.execute(stmt) - return True - - return False + def _specific_update_metadata(self, **_kwargs: dict[str, Any]) -> dict[str, Any]: + return {} diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index 2462776f17..e8732ba510 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -2,17 +2,16 @@ import asyncio from collections import deque -from datetime import datetime, timezone +from hashlib import sha256 from logging import getLogger -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast -from sqlalchemy import delete, select, text, update +from sqlalchemy import select, update from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import load_only from typing_extensions import override from crawlee import Request -from crawlee._utils.crypto import crypto_random_object_id from crawlee.storage_clients._base import RequestQueueClient from crawlee.storage_clients.models import ( AddRequestsResponse, @@ -50,18 +49,28 @@ class SQLRequestQueueClient(RequestQueueClient, SQLClientMixin): - `request_queue_metadata` table: Contains queue metadata (id, name, timestamps, request counts, multi-client flag) - `request` table: 
Contains individual requests with JSON data, unique keys for deduplication, sequence numbers for ordering, and processing status flags + - `request_queue_state` table: Maintains counters for sequence numbers to ensure proper ordering of requests. Requests are serialized to JSON for storage and maintain proper ordering through sequence numbers. The implementation provides concurrent access safety through transaction handling, locking mechanisms, and optimized database indexes for efficient querying. """ - _DEFAULT_NAME_DB = 'default' + _DEFAULT_NAME = 'default' """Default dataset name used when no name is provided.""" _MAX_REQUESTS_IN_CACHE = 1000 """Maximum number of requests to keep in cache for faster access.""" + _METADATA_TABLE = RequestQueueMetadataDB + """SQLAlchemy model for request queue metadata.""" + + _ITEM_TABLE = RequestDB + """SQLAlchemy model for request items.""" + + _CLIENT_TYPE = 'Request queue' + """Human-readable client type for error messages.""" + def __init__( self, *, @@ -72,7 +81,7 @@ def __init__( Preferably use the `SQLRequestQueueClient.open` class method to create a new instance. """ - self._id = id + super().__init__(id=id, storage_client=storage_client) self._request_cache: deque[Request] = deque() """Cache for requests: ordered by sequence number.""" @@ -86,20 +95,13 @@ def __init__( self._is_empty_cache: bool | None = None """Cache for is_empty result: None means unknown, True/False is cached state.""" - self._last_accessed_at: datetime | None = None - self._last_modified_at: datetime | None = None - self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval() - - self._storage_client = storage_client - """The storage client used to access the SQL database.""" - self._lock = asyncio.Lock() async def _get_state(self, session: AsyncSession) -> RequestQueueStateDB: """Get the current state of the request queue.""" orm_state: RequestQueueStateDB | None = await session.get(RequestQueueStateDB, self._id) if not orm_state: - orm_state = RequestQueueStateDB(queue_id=self._id) + orm_state = RequestQueueStateDB(metadata_id=self._id) session.add(orm_state) await session.flush() return orm_state @@ -108,12 +110,8 @@ async def _get_state(self, session: AsyncSession) -> RequestQueueStateDB: async def get_metadata(self) -> RequestQueueMetadata: """Get the metadata for this request queue.""" # The database is a single place of truth - async with self.get_session() as session: - orm_metadata: RequestQueueMetadataDB | None = await session.get(RequestQueueMetadataDB, self._id) - if not orm_metadata: - raise ValueError(f'Request queue with ID "{self._id}" not found.') - - return RequestQueueMetadata.model_validate(orm_metadata) + metadata = await self._get_metadata(RequestQueueMetadata) + return cast('RequestQueueMetadata', metadata) @classmethod async def open( @@ -140,63 +138,18 @@ async def open( Raises: ValueError: If a queue with the specified ID is not found. 
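
The `request_queue_state` row keeps two counters so that queue ordering collapses into a single `ORDER BY sequence_number ASC`: forefront requests take ever more negative numbers, regular requests ever larger positive ones. A toy, database-free illustration of the effect (starting values are illustrative):

    # Counters as kept per queue in request_queue_state (starting values illustrative).
    sequence_counter = 0
    forefront_sequence_counter = -1

    assigned = []
    for unique_key, forefront in [('a', False), ('b', False), ('c', True), ('d', True)]:
        if forefront:
            assigned.append((forefront_sequence_counter, unique_key))
            forefront_sequence_counter -= 1
        else:
            assigned.append((sequence_counter, unique_key))
            sequence_counter += 1

    # Ascending sequence_number yields forefront requests first, newest forefront first.
    print(sorted(assigned))  # [(-2, 'd'), (-1, 'c'), (0, 'a'), (1, 'b')]
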
""" - async with storage_client.create_session() as session: - orm_metadata: RequestQueueMetadataDB | None = None - if id: - orm_metadata = await session.get(RequestQueueMetadataDB, id) - if not orm_metadata: - raise ValueError(f'Request queue with ID "{id}" not found.') - else: - # Try to find by name - search_name = name or cls._DEFAULT_NAME_DB - stmt = select(RequestQueueMetadataDB).where(RequestQueueMetadataDB.name == search_name) - result = await session.execute(stmt) - orm_metadata = result.scalar_one_or_none() - if orm_metadata: - client = cls( - id=orm_metadata.id, - storage_client=storage_client, - ) - await client._update_metadata(session, update_accessed_at=True) - else: - now = datetime.now(timezone.utc) - metadata = RequestQueueMetadata( - id=crypto_random_object_id(), - name=name, - created_at=now, - accessed_at=now, - modified_at=now, - had_multiple_clients=False, - handled_request_count=0, - pending_request_count=0, - total_request_count=0, - ) - - client = cls( - id=metadata.id, - storage_client=storage_client, - ) - session.add(RequestQueueMetadataDB(**metadata.model_dump())) - - try: - # Commit the insert or update metadata to the database - await session.commit() - except SQLAlchemyError: - # Attempt to open simultaneously by different clients. - # The commit that created the record has already been executed, make rollback and get by name. - await session.rollback() - search_name = name or cls._DEFAULT_NAME_DB - stmt = select(RequestQueueMetadataDB).where(RequestQueueMetadataDB.name == search_name) - result = await session.execute(stmt) - orm_metadata = result.scalar_one_or_none() - if not orm_metadata: - raise ValueError(f'Request queue with Name "{search_name}" not found.') from None - client = cls( - id=orm_metadata.id, - storage_client=storage_client, - ) - - return client + return await cls._safely_open( + id=id, + name=name, + storage_client=storage_client, + metadata_model=RequestQueueMetadata, + extra_metadata_fields={ + 'had_multiple_clients': False, + 'handled_request_count': 0, + 'pending_request_count': 0, + 'total_request_count': 0, + }, + ) @override async def drop(self) -> None: @@ -204,13 +157,7 @@ async def drop(self) -> None: This operation is irreversible. Uses CASCADE deletion to remove all related records. """ - stmt = delete(RequestQueueMetadataDB).where(RequestQueueMetadataDB.id == self._id) - async with self.get_autocommit_session() as autocommit: - if self._storage_client.get_dialect_name() == 'sqlite': - # foreign_keys=ON is set at the connection level. Required for cascade deletion. 
- await autocommit.execute(text('PRAGMA foreign_keys=ON')) - # Delete the request queue metadata (cascade will delete requests) - await autocommit.execute(stmt) + await self._drop() self._request_cache.clear() self._request_cache_needs_refresh = True @@ -219,18 +166,14 @@ async def drop(self) -> None: @override async def purge(self) -> None: """Purge all requests from this request queue.""" - stmt = delete(RequestDB).where(RequestDB.queue_id == self._id) - async with self.get_autocommit_session() as autocommit: - # Delete all requests for this queue - await autocommit.execute(stmt) - - await self._update_metadata( - autocommit, - new_pending_request_count=0, - new_handled_request_count=0, - update_modified_at=True, - update_accessed_at=True, - ) + await self._purge( + metadata_kwargs={ + 'update_accessed_at': True, + 'update_modified_at': True, + 'new_pending_request_count': 0, + 'new_handled_request_count': 0, + } + ) self._is_empty_cache = None @@ -257,36 +200,39 @@ async def _add_batch_of_requests_optimization( # Deduplicate requests by unique_key upfront unique_requests = {} + unique_key_by_request_id = {} for req in requests: if req.unique_key not in unique_requests: - unique_requests[req.unique_key] = req + request_id = self._get_int_id_from_unique_key(req.unique_key) + unique_requests[request_id] = req + unique_key_by_request_id[request_id] = req.unique_key # Get existing requests by unique keys stmt = ( - select(RequestDB) - .where(RequestDB.queue_id == self._id, RequestDB.unique_key.in_(set(unique_requests.keys()))) + select(self._ITEM_TABLE) + .where( + self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.request_id.in_(set(unique_requests.keys())) + ) .options( load_only( - RequestDB.request_id, - RequestDB.unique_key, - RequestDB.is_handled, + self._ITEM_TABLE.request_id, + self._ITEM_TABLE.is_handled, ) ) ) async with self.get_session() as session: result = await session.execute(stmt) - existing_requests = {req.unique_key: req for req in result.scalars()} + existing_requests = {req.request_id: req for req in result.scalars()} state = await self._get_state(session) insert_values: list[dict] = [] - for unique_key, request in unique_requests.items(): - existing_req_db = existing_requests.get(unique_key) + for request_id, request in unique_requests.items(): + existing_req_db = existing_requests.get(request_id) if existing_req_db is None or not existing_req_db.is_handled: value = { - 'request_id': request.id, - 'queue_id': self._id, + 'request_id': request_id, + 'metadata_id': self._id, 'data': request.model_dump_json(), - 'unique_key': request.unique_key, 'is_handled': False, } if forefront: @@ -303,7 +249,6 @@ async def _add_batch_of_requests_optimization( delta_pending_request_count += 1 processed_requests.append( ProcessedRequest( - id=request.id, unique_key=request.unique_key, was_already_present=False, was_already_handled=False, @@ -312,7 +257,6 @@ async def _add_batch_of_requests_optimization( else: processed_requests.append( ProcessedRequest( - id=request.id, unique_key=request.unique_key, was_already_present=True, was_already_handled=existing_req_db.is_handled, @@ -323,8 +267,7 @@ async def _add_batch_of_requests_optimization( # Already handled request, skip adding processed_requests.append( ProcessedRequest( - id=existing_req_db.request_id, - unique_key=unique_key, + unique_key=unique_key_by_request_id[request_id], was_already_present=True, was_already_handled=True, ) @@ -335,14 +278,14 @@ async def _add_batch_of_requests_optimization( # If the request already 
exists in the database, we update the sequence_number by shifting request # to the left. upsert_stmt = self.build_upsert_stmt( - RequestDB, + self._ITEM_TABLE, insert_values, update_columns=['sequence_number'], ) await session.execute(upsert_stmt) else: # If the request already exists in the database, we ignore this request when inserting. - insert_stmt_with_ignore = self.build_insert_stmt_with_ignore(RequestDB, insert_values) + insert_stmt_with_ignore = self.build_insert_stmt_with_ignore(self._ITEM_TABLE, insert_values) await session.execute(insert_stmt_with_ignore) await self._update_metadata( @@ -392,14 +335,17 @@ async def add_batch_of_requests( raise NotImplementedError('Batch addition is not supported for this database dialect.') @override - async def get_request(self, request_id: str) -> Request | None: - stmt = select(RequestDB).where(RequestDB.queue_id == self._id, RequestDB.request_id == request_id) + async def get_request(self, unique_key: str) -> Request | None: + request_id = self._get_int_id_from_unique_key(unique_key) + stmt = select(self._ITEM_TABLE).where( + self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.request_id == request_id + ) async with self.get_session() as session: result = await session.execute(stmt) request_db = result.scalar_one_or_none() if request_db is None: - logger.warning(f'Request with ID "{request_id}" not found in the queue.') + logger.warning(f'Request with ID "{unique_key}" not found in the queue.') return None updated = await self._update_metadata(session, update_accessed_at=True) @@ -410,7 +356,7 @@ async def get_request(self, request_id: str) -> Request | None: request = Request.model_validate_json(request_db.data) - self.in_progress_requests.add(request.id) + self.in_progress_requests.add(request.unique_key) return request @@ -428,9 +374,9 @@ async def fetch_next_request(self) -> Request | None: candidate = self._request_cache.popleft() # Only check local state - if candidate.id not in self.in_progress_requests: + if candidate.unique_key not in self.in_progress_requests: next_request = candidate - self.in_progress_requests.add(next_request.id) + self.in_progress_requests.add(next_request.unique_key) if not self._request_cache: self._is_empty_cache = None @@ -441,14 +387,16 @@ async def fetch_next_request(self) -> Request | None: async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: self._is_empty_cache = None - if request.id not in self.in_progress_requests: - logger.warning(f'Marking request {request.id} as handled that is not in progress.') + if request.unique_key not in self.in_progress_requests: + logger.warning(f'Marking request {request.unique_key} as handled that is not in progress.') return None + request_id = self._get_int_id_from_unique_key(request.unique_key) + # Update request in DB stmt = ( - update(RequestDB) - .where(RequestDB.queue_id == self._id, RequestDB.request_id == request.id) + update(self._ITEM_TABLE) + .where(self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.request_id == request_id) .values(is_handled=True) ) @@ -456,7 +404,7 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | result = await session.execute(stmt) if result.rowcount == 0: - logger.warning(f'Request {request.id} not found in database.') + logger.warning(f'Request {request.unique_key} not found in database.') return None await self._update_metadata( @@ -473,10 +421,9 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | await session.rollback() 
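
`build_upsert_stmt` and `build_insert_stmt_with_ignore` are dialect-aware helpers on the mixin; on PostgreSQL and SQLite they come down to the `INSERT ... ON CONFLICT` forms that SQLAlchemy exposes through its dialect modules. A hedged sketch of what such helpers can emit, using public SQLAlchemy APIs rather than the mixin's actual implementation:

    from sqlalchemy.dialects.postgresql import insert as pg_insert
    from sqlalchemy.dialects.sqlite import insert as sqlite_insert

    _INSERTS = {'postgresql': pg_insert, 'sqlite': sqlite_insert}


    def build_upsert_stmt(dialect_name, table, insert_values, *, update_columns, conflict_cols):
        insert_fn = _INSERTS.get(dialect_name)
        if insert_fn is None:
            # Dialects without ON CONFLICT support fall back to separate UPDATE + INSERT.
            raise NotImplementedError(f'Upsert is not supported for dialect {dialect_name!r}.')
        stmt = insert_fn(table).values(insert_values)
        return stmt.on_conflict_do_update(
            index_elements=conflict_cols,
            set_={col: getattr(stmt.excluded, col) for col in update_columns},
        )


    def build_insert_stmt_with_ignore(dialect_name, table, insert_values):
        insert_fn = _INSERTS.get(dialect_name)
        if insert_fn is None:
            raise NotImplementedError(f'Insert-with-ignore is not supported for dialect {dialect_name!r}.')
        # ON CONFLICT DO NOTHING silently skips rows whose key already exists.
        return insert_fn(table).values(insert_values).on_conflict_do_nothing()
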
return None - self.in_progress_requests.discard(request.id) + self.in_progress_requests.discard(request.unique_key) return ProcessedRequest( - id=request.id, unique_key=request.unique_key, was_already_present=True, was_already_handled=True, @@ -491,10 +438,12 @@ async def reclaim_request( ) -> ProcessedRequest | None: self._is_empty_cache = None - if request.id not in self.in_progress_requests: - logger.info(f'Reclaiming request {request.id} that is not in progress.') + if request.unique_key not in self.in_progress_requests: + logger.info(f'Reclaiming request {request.unique_key} that is not in progress.') return None + request_id = self._get_int_id_from_unique_key(request.unique_key) + async with self.get_autocommit_session() as autocommit: state = await self._get_state(autocommit) @@ -507,21 +456,21 @@ async def reclaim_request( state.sequence_counter += 1 stmt = ( - update(RequestDB) - .where(RequestDB.queue_id == self._id, RequestDB.request_id == request.id) + update(self._ITEM_TABLE) + .where(self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.request_id == request_id) .values(sequence_number=new_sequence) ) result = await autocommit.execute(stmt) if result.rowcount == 0: - logger.warning(f'Request {request.id} not found in database.') + logger.warning(f'Request {request.unique_key} not found in database.') return None await self._update_metadata(autocommit, update_modified_at=True, update_accessed_at=True) # Remove from in-progress - self.in_progress_requests.discard(request.id) + self.in_progress_requests.discard(request.unique_key) # Invalidate cache or add to cache if forefront: @@ -531,7 +480,6 @@ async def reclaim_request( self._request_cache.append(request) return ProcessedRequest( - id=request.id, unique_key=request.unique_key, was_already_present=True, was_already_handled=False, @@ -549,7 +497,7 @@ async def is_empty(self) -> bool: # Check database for unhandled requests async with self.get_session() as session: - metadata_orm = await session.get(RequestQueueMetadataDB, self._id) + metadata_orm = await session.get(self._METADATA_TABLE, self._id) if not metadata_orm: raise ValueError(f'Request queue with ID "{self._id}" not found.') @@ -569,13 +517,13 @@ async def _refresh_cache(self) -> None: async with self.get_session() as session: # Simple query - get unhandled requests not in progress stmt = ( - select(RequestDB) + select(self._ITEM_TABLE) .where( - RequestDB.queue_id == self._id, - RequestDB.is_handled == False, # noqa: E712 - RequestDB.request_id.notin_(self.in_progress_requests), + self._ITEM_TABLE.metadata_id == self._id, + self._ITEM_TABLE.is_handled == False, # noqa: E712 + self._ITEM_TABLE.request_id.notin_(self.in_progress_requests), ) - .order_by(RequestDB.sequence_number.asc()) + .order_by(self._ITEM_TABLE.sequence_number.asc()) .limit(self._MAX_REQUESTS_IN_CACHE) ) @@ -602,6 +550,7 @@ async def _update_metadata( update_had_multiple_clients: bool = False, update_accessed_at: bool = False, update_modified_at: bool = False, + **_kwargs: dict[str, Any], ) -> bool: """Update the request queue metadata in the database. @@ -617,20 +566,10 @@ async def _update_metadata( update_accessed_at: If True, update the `accessed_at` timestamp to the current time. update_modified_at: If True, update the `modified_at` timestamp to the current time. 
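
Passing a column expression such as `MetadataTable.handled_request_count + delta` into `values()` pushes the arithmetic into the database, so two clients incrementing the same counter concurrently cannot overwrite each other with stale Python-side numbers. A small standalone illustration with a hypothetical table:

    from sqlalchemy import Integer, String, update
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        pass


    class QueueMetadataExample(Base):
        __tablename__ = 'queue_metadata_example'

        id: Mapped[str] = mapped_column(String(20), primary_key=True)
        handled_request_count: Mapped[int] = mapped_column(Integer, default=0)


    # Compiles to roughly:
    #   UPDATE queue_metadata_example
    #   SET handled_request_count = handled_request_count + :delta
    #   WHERE id = :id
    stmt = (
        update(QueueMetadataExample)
        .where(QueueMetadataExample.id == 'some-id')
        .values(handled_request_count=QueueMetadataExample.handled_request_count + 1)
    )
    print(stmt)
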
""" - now = datetime.now(timezone.utc) - values_to_set: dict[str, Any] = {} - - if update_accessed_at and ( - self._last_accessed_at is None or (now - self._last_accessed_at) > self._accessed_modified_update_interval - ): - values_to_set['accessed_at'] = now - self._last_accessed_at = now - - if update_modified_at and ( - self._last_modified_at is None or (now - self._last_modified_at) > self._accessed_modified_update_interval - ): - values_to_set['modified_at'] = now - self._last_modified_at = now + values_to_set = self._default_update_metadata( + update_accessed_at=update_accessed_at, + update_modified_at=update_modified_at, + ) if update_had_multiple_clients: values_to_set['had_multiple_clients'] = True @@ -639,26 +578,88 @@ async def _update_metadata( values_to_set['handled_request_count'] = new_handled_request_count elif delta_handled_request_count is not None: values_to_set['handled_request_count'] = ( - RequestQueueMetadataDB.handled_request_count + delta_handled_request_count + self._METADATA_TABLE.handled_request_count + delta_handled_request_count ) if new_pending_request_count is not None: values_to_set['pending_request_count'] = new_pending_request_count elif delta_pending_request_count is not None: values_to_set['pending_request_count'] = ( - RequestQueueMetadataDB.pending_request_count + delta_pending_request_count + self._METADATA_TABLE.pending_request_count + delta_pending_request_count ) if new_total_request_count is not None: values_to_set['total_request_count'] = new_total_request_count elif delta_total_request_count is not None: - values_to_set['total_request_count'] = ( - RequestQueueMetadataDB.total_request_count + delta_total_request_count - ) + values_to_set['total_request_count'] = self._METADATA_TABLE.total_request_count + delta_total_request_count if values_to_set: - stmt = update(RequestQueueMetadataDB).where(RequestQueueMetadataDB.id == self._id).values(**values_to_set) + stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id).values(**values_to_set) await session.execute(stmt) return True return False + + def _specific_update_metadata( + self, + new_handled_request_count: int | None = None, + new_pending_request_count: int | None = None, + new_total_request_count: int | None = None, + delta_handled_request_count: int | None = None, + delta_pending_request_count: int | None = None, + delta_total_request_count: int | None = None, + *, + update_had_multiple_clients: bool = False, + **_kwargs: dict[str, Any], + ) -> dict[str, Any]: + """Update the request queue metadata in the database. + + Args: + session: The SQLAlchemy session to use for database operations. + new_handled_request_count: If provided, update the handled_request_count to this value. + new_pending_request_count: If provided, update the pending_request_count to this value. + new_total_request_count: If provided, update the total_request_count to this value. + delta_handled_request_count: If provided, add this value to the handled_request_count. + delta_pending_request_count: If provided, add this value to the pending_request_count. + delta_total_request_count: If provided, add this value to the total_request_count. + update_had_multiple_clients: If True, set had_multiple_clients to True. 
+ """ + values_to_set: dict[str, Any] = {} + + if update_had_multiple_clients: + values_to_set['had_multiple_clients'] = True + + if new_handled_request_count is not None: + values_to_set['handled_request_count'] = new_handled_request_count + elif delta_handled_request_count is not None: + values_to_set['handled_request_count'] = ( + self._METADATA_TABLE.handled_request_count + delta_handled_request_count + ) + + if new_pending_request_count is not None: + values_to_set['pending_request_count'] = new_pending_request_count + elif delta_pending_request_count is not None: + values_to_set['pending_request_count'] = ( + self._METADATA_TABLE.pending_request_count + delta_pending_request_count + ) + + if new_total_request_count is not None: + values_to_set['total_request_count'] = new_total_request_count + elif delta_total_request_count is not None: + values_to_set['total_request_count'] = self._METADATA_TABLE.total_request_count + delta_total_request_count + + return values_to_set + + @staticmethod + def _get_int_id_from_unique_key(unique_key: str) -> int: + """Generate a deterministic integer ID for a unique_key. + + Args: + unique_key: Unique key to be used to generate ID. + + Returns: + An integer ID based on the unique_key. + """ + hashed_key = sha256(unique_key.encode('utf-8')).hexdigest() + name_length = 15 + return int(hashed_key[:name_length], 16) diff --git a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py index 88f38cbfc5..98c0927e4e 100644 --- a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py @@ -122,7 +122,7 @@ async def test_record_and_content_verification(dataset_client: SQLDatasetClient) assert metadata.accessed_at is not None async with dataset_client.get_session() as session: - stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == metadata.id) + stmt = select(DatasetItemDB).where(DatasetItemDB.metadata_id == metadata.id) result = await session.execute(stmt) records = result.scalars().all() assert len(records) == 1 @@ -134,7 +134,7 @@ async def test_record_and_content_verification(dataset_client: SQLDatasetClient) await dataset_client.push_data(items) async with dataset_client.get_session() as session: - stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == metadata.id) + stmt = select(DatasetItemDB).where(DatasetItemDB.metadata_id == metadata.id) result = await session.execute(stmt) records = result.scalars().all() assert len(records) == 4 @@ -147,7 +147,7 @@ async def test_drop_removes_records(dataset_client: SQLDatasetClient) -> None: client_metadata = await dataset_client.get_metadata() async with dataset_client.get_session() as session: - stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == client_metadata.id) + stmt = select(DatasetItemDB).where(DatasetItemDB.metadata_id == client_metadata.id) result = await session.execute(stmt) records = result.scalars().all() assert len(records) == 1 @@ -156,7 +156,7 @@ async def test_drop_removes_records(dataset_client: SQLDatasetClient) -> None: await dataset_client.drop() async with dataset_client.get_session() as session: - stmt = select(DatasetItemDB).where(DatasetItemDB.dataset_id == client_metadata.id) + stmt = select(DatasetItemDB).where(DatasetItemDB.metadata_id == client_metadata.id) result = await session.execute(stmt) records = result.scalars().all() assert len(records) == 0 diff --git a/tests/unit/storage_clients/_sql/test_sql_rq_client.py 
b/tests/unit/storage_clients/_sql/test_sql_rq_client.py index 7e3350719c..a8274a93da 100644 --- a/tests/unit/storage_clients/_sql/test_sql_rq_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_rq_client.py @@ -123,7 +123,7 @@ async def test_request_records_persistence(rq_client: SQLRequestQueueClient) -> metadata_client = await rq_client.get_metadata() async with rq_client.get_session() as session: - stmt = select(RequestDB).where(RequestDB.queue_id == metadata_client.id) + stmt = select(RequestDB).where(RequestDB.metadata_id == metadata_client.id) result = await session.execute(stmt) db_requests = result.scalars().all() assert len(db_requests) == 3 @@ -137,7 +137,7 @@ async def test_drop_removes_records(rq_client: SQLRequestQueueClient) -> None: await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) metadata = await rq_client.get_metadata() async with rq_client.get_session() as session: - stmt = select(RequestDB).where(RequestDB.queue_id == metadata.id) + stmt = select(RequestDB).where(RequestDB.metadata_id == metadata.id) result = await session.execute(stmt) records = result.scalars().all() assert len(records) == 1 @@ -145,7 +145,7 @@ async def test_drop_removes_records(rq_client: SQLRequestQueueClient) -> None: await rq_client.drop() async with rq_client.get_session() as session: - stmt = select(RequestDB).where(RequestDB.queue_id == metadata.id) + stmt = select(RequestDB).where(RequestDB.metadata_id == metadata.id) result = await session.execute(stmt) records = result.scalars().all() assert len(records) == 0 From 9f5e640e2b745c39347c29cad76e9d5599ba2f9c Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Thu, 21 Aug 2025 21:10:01 +0000 Subject: [PATCH 18/29] fix len strict for metadata_id in kvs_record --- src/crawlee/storage_clients/_sql/_db_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py index e5aee1b27a..627cada7b9 100644 --- a/src/crawlee/storage_clients/_sql/_db_models.py +++ b/src/crawlee/storage_clients/_sql/_db_models.py @@ -131,7 +131,7 @@ class KeyValueStoreRecordDB(Base): __tablename__ = 'kvs_record' metadata_id: Mapped[str] = mapped_column( - String(255), ForeignKey('kvs_metadata.id', ondelete='CASCADE'), primary_key=True, index=True + String(20), ForeignKey('kvs_metadata.id', ondelete='CASCADE'), primary_key=True, index=True ) """Foreign key to metadata key-value store record.""" @@ -141,7 +141,7 @@ class KeyValueStoreRecordDB(Base): value: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) """Value stored as binary data to support any content type.""" - content_type: Mapped[str] = mapped_column(String(100), nullable=False) + content_type: Mapped[str] = mapped_column(String(50), nullable=False) """MIME type for proper value deserialization.""" size: Mapped[int | None] = mapped_column(Integer, nullable=False, default=0) From 77c1894aa37ebc46c32dd305ede3d6eae92fb41b Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Fri, 22 Aug 2025 14:54:46 +0000 Subject: [PATCH 19/29] fix cache --- .../_sql/_request_queue_client.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index e8732ba510..4947354e3e 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -86,7 +86,7 @@ def __init__( 
self._request_cache: deque[Request] = deque() """Cache for requests: ordered by sequence number.""" - self.in_progress_requests: set[str] = set() + self.in_progress_requests: set[int] = set() """Set of request IDs currently being processed.""" self._request_cache_needs_refresh = True @@ -356,7 +356,9 @@ async def get_request(self, unique_key: str) -> Request | None: request = Request.model_validate_json(request_db.data) - self.in_progress_requests.add(request.unique_key) + request_id = self._get_int_id_from_unique_key(request.unique_key) + + self.in_progress_requests.add(request_id) return request @@ -372,11 +374,11 @@ async def fetch_next_request(self) -> Request | None: # Get from cache while self._request_cache and next_request is None: candidate = self._request_cache.popleft() - + request_id = self._get_int_id_from_unique_key(candidate.unique_key) # Only check local state - if candidate.unique_key not in self.in_progress_requests: + if request_id not in self.in_progress_requests: next_request = candidate - self.in_progress_requests.add(next_request.unique_key) + self.in_progress_requests.add(request_id) if not self._request_cache: self._is_empty_cache = None @@ -387,12 +389,11 @@ async def fetch_next_request(self) -> Request | None: async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: self._is_empty_cache = None - if request.unique_key not in self.in_progress_requests: + request_id = self._get_int_id_from_unique_key(request.unique_key) + if request_id not in self.in_progress_requests: logger.warning(f'Marking request {request.unique_key} as handled that is not in progress.') return None - request_id = self._get_int_id_from_unique_key(request.unique_key) - # Update request in DB stmt = ( update(self._ITEM_TABLE) @@ -421,7 +422,7 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | await session.rollback() return None - self.in_progress_requests.discard(request.unique_key) + self.in_progress_requests.discard(request_id) return ProcessedRequest( unique_key=request.unique_key, @@ -438,12 +439,12 @@ async def reclaim_request( ) -> ProcessedRequest | None: self._is_empty_cache = None - if request.unique_key not in self.in_progress_requests: + request_id = self._get_int_id_from_unique_key(request.unique_key) + + if request_id not in self.in_progress_requests: logger.info(f'Reclaiming request {request.unique_key} that is not in progress.') return None - request_id = self._get_int_id_from_unique_key(request.unique_key) - async with self.get_autocommit_session() as autocommit: state = await self._get_state(autocommit) @@ -470,7 +471,7 @@ async def reclaim_request( await self._update_metadata(autocommit, update_modified_at=True, update_accessed_at=True) # Remove from in-progress - self.in_progress_requests.discard(request.unique_key) + self.in_progress_requests.discard(request_id) # Invalidate cache or add to cache if forefront: From b3c1aad6cdb90ff2dcbdc95bb635ea63292a3e4f Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Sat, 23 Aug 2025 02:28:52 +0000 Subject: [PATCH 20/29] update queue for support multi-clients --- .../storage_clients/_sql/_db_models.py | 6 +- .../_sql/_request_queue_client.py | 231 ++++++++---------- 2 files changed, 109 insertions(+), 128 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py index 627cada7b9..fface773fc 100644 --- a/src/crawlee/storage_clients/_sql/_db_models.py +++ b/src/crawlee/storage_clients/_sql/_db_models.py @@ -178,8 
+178,7 @@ class RequestDB(Base): __tablename__ = 'request' __table_args__ = ( - # Index for efficient SELECT to cache - Index('idx_queue_handled_seq', 'metadata_id', 'is_handled', 'sequence_number'), + Index('idx_fetch_available', 'metadata_id', 'is_handled', 'time_blocked_until', 'sequence_number'), ) request_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) @@ -199,6 +198,9 @@ class RequestDB(Base): is_handled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) """Processing status flag.""" + time_blocked_until: Mapped[datetime | None] = mapped_column(AwareDateTime, nullable=True) + """Timestamp until which this request is considered blocked for processing by other clients.""" + # Relationship back to metadata table queue: Mapped[RequestQueueMetadataDB] = relationship(back_populates='requests') diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index 4947354e3e..aefe34d637 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -1,12 +1,13 @@ from __future__ import annotations -import asyncio from collections import deque +from datetime import datetime, timedelta, timezone from hashlib import sha256 from logging import getLogger from typing import TYPE_CHECKING, Any, cast -from sqlalchemy import select, update +from cachetools import LRUCache +from sqlalchemy import or_, select, update from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import load_only from typing_extensions import override @@ -59,8 +60,12 @@ class SQLRequestQueueClient(RequestQueueClient, SQLClientMixin): _DEFAULT_NAME = 'default' """Default dataset name used when no name is provided.""" - _MAX_REQUESTS_IN_CACHE = 1000 - """Maximum number of requests to keep in cache for faster access.""" + _MAX_BATCH_FETCH_SIZE = 10 + """Maximum number of requests to fetch from the database in a single batch operation. + + Used to limit the number of requests loaded and locked for processing at once (improves efficiency and reduces + database load). + """ _METADATA_TABLE = RequestQueueMetadataDB """SQLAlchemy model for request queue metadata.""" @@ -71,6 +76,13 @@ class SQLRequestQueueClient(RequestQueueClient, SQLClientMixin): _CLIENT_TYPE = 'Request queue' """Human-readable client type for error messages.""" + _REQUEST_ID_BY_KEY: LRUCache[str, int] = LRUCache(maxsize=10000) + """Cache mapping unique keys to integer IDs.""" + + _BLOCK_REQUEST_TIME = 300 + """Number of seconds for which a request is considered blocked in the database after being fetched for processing. 
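
`_REQUEST_ID_BY_KEY` memoizes the sha256-derived integer IDs so hot unique keys are not re-hashed on every lookup; `cachetools.LRUCache` behaves like a bounded mapping that evicts the least recently used entry. A small usage sketch of the same combination (the function name is illustrative):

    from __future__ import annotations

    from hashlib import sha256

    from cachetools import LRUCache

    _id_by_key: LRUCache[str, int] = LRUCache(maxsize=10_000)


    def int_id_for(unique_key: str) -> int:
        # The first 15 hex digits of the sha256 digest fit into a signed 64-bit BIGINT.
        cached = _id_by_key.get(unique_key)
        if cached is not None:
            return cached
        value = int(sha256(unique_key.encode('utf-8')).hexdigest()[:15], 16)
        _id_by_key[unique_key] = value
        return value


    print(int_id_for('https://example.com'))
    print(int_id_for('https://example.com'))  # second call is served from the LRU cache
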
+ """ + def __init__( self, *, @@ -83,20 +95,9 @@ def __init__( """ super().__init__(id=id, storage_client=storage_client) - self._request_cache: deque[Request] = deque() + self._pending_fetch_cache: deque[Request] = deque() """Cache for requests: ordered by sequence number.""" - self.in_progress_requests: set[int] = set() - """Set of request IDs currently being processed.""" - - self._request_cache_needs_refresh = True - """Flag indicating whether the cache needs to be refreshed from database.""" - - self._is_empty_cache: bool | None = None - """Cache for is_empty result: None means unknown, True/False is cached state.""" - - self._lock = asyncio.Lock() - async def _get_state(self, session: AsyncSession) -> RequestQueueStateDB: """Get the current state of the request queue.""" orm_state: RequestQueueStateDB | None = await session.get(RequestQueueStateDB, self._id) @@ -159,9 +160,7 @@ async def drop(self) -> None: """ await self._drop() - self._request_cache.clear() - self._request_cache_needs_refresh = True - self._is_empty_cache = None + self._pending_fetch_cache.clear() @override async def purge(self) -> None: @@ -175,11 +174,8 @@ async def purge(self) -> None: } ) - self._is_empty_cache = None - # Clear recoverable state - self._request_cache.clear() - self._request_cache_needs_refresh = True + self._pending_fetch_cache.clear() async def _add_batch_of_requests_optimization( self, @@ -191,7 +187,6 @@ async def _add_batch_of_requests_optimization( return AddRequestsResponse(processed_requests=[], unprocessed_requests=[]) # Clear empty cache since we're adding requests - self._is_empty_cache = None processed_requests = [] unprocessed_requests = [] @@ -206,6 +201,7 @@ async def _add_batch_of_requests_optimization( request_id = self._get_int_id_from_unique_key(req.unique_key) unique_requests[request_id] = req unique_key_by_request_id[request_id] = req.unique_key + self._REQUEST_ID_BY_KEY[req.unique_key] = request_id # Get existing requests by unique keys stmt = ( @@ -314,9 +310,6 @@ async def _add_batch_of_requests_optimization( ] ) - if forefront: - self._request_cache_needs_refresh = True - return AddRequestsResponse( processed_requests=processed_requests, unprocessed_requests=unprocessed_requests, @@ -336,7 +329,10 @@ async def add_batch_of_requests( @override async def get_request(self, unique_key: str) -> Request | None: - request_id = self._get_int_id_from_unique_key(unique_key) + if not (request_id := self._REQUEST_ID_BY_KEY.get(unique_key)): + request_id = self._get_int_id_from_unique_key(unique_key) + self._REQUEST_ID_BY_KEY[unique_key] = request_id + stmt = select(self._ITEM_TABLE).where( self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.request_id == request_id ) @@ -354,60 +350,92 @@ async def get_request(self, unique_key: str) -> Request | None: if updated: await session.commit() - request = Request.model_validate_json(request_db.data) + return Request.model_validate_json(request_db.data) + + @override + async def fetch_next_request(self) -> Request | None: + if self._pending_fetch_cache: + return self._pending_fetch_cache.popleft() + + now = datetime.now(timezone.utc) + block_until = now + timedelta(seconds=self._BLOCK_REQUEST_TIME) + dialect = self._storage_client.get_dialect_name() + + # Get available requests not blocked by another client + stmt = ( + select(self._ITEM_TABLE) + .where( + self._ITEM_TABLE.metadata_id == self._id, + self._ITEM_TABLE.is_handled.is_(False), + or_(self._ITEM_TABLE.time_blocked_until.is_(None), self._ITEM_TABLE.time_blocked_until < now), 
+ ) + .order_by(self._ITEM_TABLE.sequence_number.asc()) + .limit(self._MAX_BATCH_FETCH_SIZE) + ) + + # We use the `skip_locked` database mechanism to prevent the “interception” of requests by another client + if dialect in {'postgresql', 'mysql'}: + stmt = stmt.with_for_update(skip_locked=True) - request_id = self._get_int_id_from_unique_key(request.unique_key) + async with self.get_session() as session: + result = await session.execute(stmt) + requests_db = result.scalars().all() + if not requests_db: + return None - self.in_progress_requests.add(request_id) + request_ids = {r.request_id for r in requests_db} - return request + # Mark the requests as blocked + update_stmt = ( + update(self._ITEM_TABLE) + .where( + self._ITEM_TABLE.metadata_id == self._id, + self._ITEM_TABLE.request_id.in_(request_ids), + self._ITEM_TABLE.is_handled.is_(False), + or_(self._ITEM_TABLE.time_blocked_until.is_(None), self._ITEM_TABLE.time_blocked_until < now), + ) + .values(time_blocked_until=block_until) + .returning(self._ITEM_TABLE.request_id) + ) - @override - async def fetch_next_request(self) -> Request | None: - # Refresh cache if needed - async with self._lock: - if self._request_cache_needs_refresh or not self._request_cache: - await self._refresh_cache() + update_result = await session.execute(update_stmt) - next_request = None + # Get IDs of successfully blocked requests + blocked_ids = {row[0] for row in update_result.fetchall()} - # Get from cache - while self._request_cache and next_request is None: - candidate = self._request_cache.popleft() - request_id = self._get_int_id_from_unique_key(candidate.unique_key) - # Only check local state - if request_id not in self.in_progress_requests: - next_request = candidate - self.in_progress_requests.add(request_id) + if not blocked_ids: + await session.rollback() + return None - if not self._request_cache: - self._is_empty_cache = None + await self._update_metadata(session, update_accessed_at=True) - return next_request + await session.commit() - @override - async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: - self._is_empty_cache = None + requests = [Request.model_validate_json(r.data) for r in requests_db if r.request_id in blocked_ids] - request_id = self._get_int_id_from_unique_key(request.unique_key) - if request_id not in self.in_progress_requests: - logger.warning(f'Marking request {request.unique_key} as handled that is not in progress.') + if not requests: return None + self._pending_fetch_cache.extend(requests[1:]) + + return requests[0] + + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + if not (request_id := self._REQUEST_ID_BY_KEY.get(request.unique_key)): + request_id = self._get_int_id_from_unique_key(request.unique_key) + # Update request in DB stmt = ( update(self._ITEM_TABLE) .where(self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.request_id == request_id) - .values(is_handled=True) + .values(is_handled=True, time_blocked_until=None) ) - async with self.get_session() as session: result = await session.execute(stmt) - if result.rowcount == 0: logger.warning(f'Request {request.unique_key} not found in database.') return None - await self._update_metadata( session, delta_handled_request_count=1, @@ -415,15 +443,7 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | update_modified_at=True, update_accessed_at=True, ) - - try: - await session.commit() - except SQLAlchemyError: - await session.rollback() - return None 
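
On PostgreSQL and MySQL, `with_for_update(skip_locked=True)` turns the fetch into the classic multi-consumer queue idiom: rows locked by another client's open transaction are skipped instead of waited on, so several crawler processes can claim disjoint batches from the same queue. Roughly what the fetch statement looks like when built in isolation against the RequestDB model from this patch:

    from crawlee.storage_clients._sql._db_models import RequestDB
    from sqlalchemy import select

    stmt = (
        select(RequestDB)
        .where(RequestDB.is_handled.is_(False))
        .order_by(RequestDB.sequence_number.asc())
        .limit(10)
        .with_for_update(skip_locked=True)
    )
    # On PostgreSQL the rendered statement ends with: ... LIMIT 10 FOR UPDATE SKIP LOCKED.
    # SQLite has no row-level locking, which is why the client only adds this clause for
    # postgresql/mysql and relies on time_blocked_until everywhere else.
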
- - self.in_progress_requests.discard(request_id) - + await session.commit() return ProcessedRequest( unique_key=request.unique_key, was_already_present=True, @@ -437,13 +457,12 @@ async def reclaim_request( *, forefront: bool = False, ) -> ProcessedRequest | None: - self._is_empty_cache = None - - request_id = self._get_int_id_from_unique_key(request.unique_key) + if not (request_id := self._REQUEST_ID_BY_KEY.get(request.unique_key)): + request_id = self._get_int_id_from_unique_key(request.unique_key) - if request_id not in self.in_progress_requests: - logger.info(f'Reclaiming request {request.unique_key} that is not in progress.') - return None + stmt = update(self._ITEM_TABLE).where( + self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.request_id == request_id + ) async with self.get_autocommit_session() as autocommit: state = await self._get_state(autocommit) @@ -452,33 +471,24 @@ async def reclaim_request( if forefront: new_sequence = state.forefront_sequence_counter state.forefront_sequence_counter -= 1 + now = datetime.now(timezone.utc) + block_until = now + timedelta(seconds=self._BLOCK_REQUEST_TIME) + # Extend blocking for forefront request, it is considered blocked by the current client. + stmt = stmt.values(sequence_number=new_sequence, time_blocked_until=block_until) else: new_sequence = state.sequence_counter state.sequence_counter += 1 - - stmt = ( - update(self._ITEM_TABLE) - .where(self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.request_id == request_id) - .values(sequence_number=new_sequence) - ) + stmt = stmt.values(sequence_number=new_sequence, time_blocked_until=None) result = await autocommit.execute(stmt) - if result.rowcount == 0: logger.warning(f'Request {request.unique_key} not found in database.') return None - await self._update_metadata(autocommit, update_modified_at=True, update_accessed_at=True) - # Remove from in-progress - self.in_progress_requests.discard(request_id) - - # Invalidate cache or add to cache + # put the forefront request at the beginning of the cache if forefront: - self._request_cache_needs_refresh = True - elif len(self._request_cache) < self._MAX_REQUESTS_IN_CACHE: - # For regular requests, we can add to the end if there's space - self._request_cache.append(request) + self._pending_fetch_cache.appendleft(request) return ProcessedRequest( unique_key=request.unique_key, @@ -488,12 +498,8 @@ async def reclaim_request( @override async def is_empty(self) -> bool: - if self._is_empty_cache is not None: - return self._is_empty_cache - - # If there are in-progress requests, not empty - if len(self.in_progress_requests) > 0: - self._is_empty_cache = False + # Check in-memory cache for requests + if self._pending_fetch_cache: return False # Check database for unhandled requests @@ -502,41 +508,14 @@ async def is_empty(self) -> bool: if not metadata_orm: raise ValueError(f'Request queue with ID "{self._id}" not found.') - self._is_empty_cache = metadata_orm.pending_request_count == 0 + empty = metadata_orm.pending_request_count == 0 updated = await self._update_metadata(session, update_accessed_at=True) # Commit updates to the metadata if updated: await session.commit() - return self._is_empty_cache - - async def _refresh_cache(self) -> None: - """Refresh the request cache from database.""" - self._request_cache.clear() - - async with self.get_session() as session: - # Simple query - get unhandled requests not in progress - stmt = ( - select(self._ITEM_TABLE) - .where( - self._ITEM_TABLE.metadata_id == self._id, - 
self._ITEM_TABLE.is_handled == False, # noqa: E712 - self._ITEM_TABLE.request_id.notin_(self.in_progress_requests), - ) - .order_by(self._ITEM_TABLE.sequence_number.asc()) - .limit(self._MAX_REQUESTS_IN_CACHE) - ) - - result = await session.execute(stmt) - request_dbs = result.scalars().all() - - # Add to cache in order - for request_db in request_dbs: - request = Request.model_validate_json(request_db.data) - self._request_cache.append(request) - - self._request_cache_needs_refresh = False + return empty async def _update_metadata( self, From fb8ce7d12cb64e0aa70d91f923cfda5a148a8852 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Sat, 23 Aug 2025 05:35:29 +0000 Subject: [PATCH 21/29] fix metadata calculate --- .../storage_clients/_sql/_client_mixin.py | 5 +- .../_sql/_request_queue_client.py | 116 +++++++----------- 2 files changed, 47 insertions(+), 74 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_client_mixin.py b/src/crawlee/storage_clients/_sql/_client_mixin.py index eba510a36c..4ff911d63c 100644 --- a/src/crawlee/storage_clients/_sql/_client_mixin.py +++ b/src/crawlee/storage_clients/_sql/_client_mixin.py @@ -234,7 +234,10 @@ async def _update_metadata( values_to_set.update(self._specific_update_metadata(**kwargs)) if values_to_set: - stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id).values(**values_to_set) + if (stmt := values_to_set.pop('custom_stmt', None)) is None: + stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id) + + stmt = stmt.values(**values_to_set) await session.execute(stmt) return True diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index aefe34d637..7d5d0ae0a9 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, cast from cachetools import LRUCache -from sqlalchemy import or_, select, update +from sqlalchemy import func, or_, select, update from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import load_only from typing_extensions import override @@ -190,8 +190,7 @@ async def _add_batch_of_requests_optimization( processed_requests = [] unprocessed_requests = [] - delta_total_request_count = 0 - delta_pending_request_count = 0 + metadata_recalculate = False # Deduplicate requests by unique_key upfront unique_requests = {} @@ -222,7 +221,8 @@ async def _add_batch_of_requests_optimization( existing_requests = {req.request_id: req for req in result.scalars()} state = await self._get_state(session) insert_values: list[dict] = [] - for request_id, request in unique_requests.items(): + + for request_id, request in sorted(unique_requests.items()): existing_req_db = existing_requests.get(request_id) if existing_req_db is None or not existing_req_db.is_handled: value = { @@ -241,8 +241,7 @@ async def _add_batch_of_requests_optimization( insert_values.append(value) if existing_req_db is None: - delta_total_request_count += 1 - delta_pending_request_count += 1 + metadata_recalculate = True processed_requests.append( ProcessedRequest( unique_key=request.unique_key, @@ -277,6 +276,7 @@ async def _add_batch_of_requests_optimization( self._ITEM_TABLE, insert_values, update_columns=['sequence_number'], + conflict_cols=['request_id', 'metadata_id'], ) await session.execute(upsert_stmt) else: @@ -284,10 +284,12 @@ async def _add_batch_of_requests_optimization( insert_stmt_with_ignore = 
self.build_insert_stmt_with_ignore(self._ITEM_TABLE, insert_values) await session.execute(insert_stmt_with_ignore) + if metadata_recalculate: + await self._block_metadata_for_update(session) + await self._update_metadata( session, - delta_total_request_count=delta_total_request_count, - delta_pending_request_count=delta_pending_request_count, + recalculate=metadata_recalculate, update_modified_at=True, update_accessed_at=True, ) @@ -297,6 +299,13 @@ async def _add_batch_of_requests_optimization( except SQLAlchemyError as e: await session.rollback() logger.warning(f'Failed to commit session: {e}') + await self._block_metadata_for_update(session) + await self._update_metadata( + session, + recalculate=True, + update_modified_at=True, + update_accessed_at=True, + ) processed_requests.clear() unprocessed_requests.extend( [ @@ -517,68 +526,10 @@ async def is_empty(self) -> bool: return empty - async def _update_metadata( - self, - session: AsyncSession, - *, - new_handled_request_count: int | None = None, - new_pending_request_count: int | None = None, - new_total_request_count: int | None = None, - delta_handled_request_count: int | None = None, - delta_pending_request_count: int | None = None, - delta_total_request_count: int | None = None, - update_had_multiple_clients: bool = False, - update_accessed_at: bool = False, - update_modified_at: bool = False, - **_kwargs: dict[str, Any], - ) -> bool: - """Update the request queue metadata in the database. - - Args: - session: The SQLAlchemy session to use for database operations. - new_handled_request_count: If provided, update the handled_request_count to this value. - new_pending_request_count: If provided, update the pending_request_count to this value. - new_total_request_count: If provided, update the total_request_count to this value. - delta_handled_request_count: If provided, add this value to the handled_request_count. - delta_pending_request_count: If provided, add this value to the pending_request_count. - delta_total_request_count: If provided, add this value to the total_request_count. - update_had_multiple_clients: If True, set had_multiple_clients to True. - update_accessed_at: If True, update the `accessed_at` timestamp to the current time. - update_modified_at: If True, update the `modified_at` timestamp to the current time. 
- """ - values_to_set = self._default_update_metadata( - update_accessed_at=update_accessed_at, - update_modified_at=update_modified_at, - ) - - if update_had_multiple_clients: - values_to_set['had_multiple_clients'] = True - - if new_handled_request_count is not None: - values_to_set['handled_request_count'] = new_handled_request_count - elif delta_handled_request_count is not None: - values_to_set['handled_request_count'] = ( - self._METADATA_TABLE.handled_request_count + delta_handled_request_count - ) - - if new_pending_request_count is not None: - values_to_set['pending_request_count'] = new_pending_request_count - elif delta_pending_request_count is not None: - values_to_set['pending_request_count'] = ( - self._METADATA_TABLE.pending_request_count + delta_pending_request_count - ) - - if new_total_request_count is not None: - values_to_set['total_request_count'] = new_total_request_count - elif delta_total_request_count is not None: - values_to_set['total_request_count'] = self._METADATA_TABLE.total_request_count + delta_total_request_count - - if values_to_set: - stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id).values(**values_to_set) + async def _block_metadata_for_update(self, session: AsyncSession) -> None: + if self._storage_client.get_dialect_name() in {'postgresql', 'mysql'}: + stmt = select(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id).with_for_update() await session.execute(stmt) - return True - - return False def _specific_update_metadata( self, @@ -587,8 +538,8 @@ def _specific_update_metadata( new_total_request_count: int | None = None, delta_handled_request_count: int | None = None, delta_pending_request_count: int | None = None, - delta_total_request_count: int | None = None, *, + recalculate: bool = False, update_had_multiple_clients: bool = False, **_kwargs: dict[str, Any], ) -> dict[str, Any]: @@ -601,7 +552,7 @@ def _specific_update_metadata( new_total_request_count: If provided, update the total_request_count to this value. delta_handled_request_count: If provided, add this value to the handled_request_count. delta_pending_request_count: If provided, add this value to the pending_request_count. - delta_total_request_count: If provided, add this value to the total_request_count. + recalculate: If True, recalculate the pending_request_count, and total_request_count on request table. update_had_multiple_clients: If True, set had_multiple_clients to True. 
""" values_to_set: dict[str, Any] = {} @@ -625,8 +576,27 @@ def _specific_update_metadata( if new_total_request_count is not None: values_to_set['total_request_count'] = new_total_request_count - elif delta_total_request_count is not None: - values_to_set['total_request_count'] = self._METADATA_TABLE.total_request_count + delta_total_request_count + + if recalculate: + pending_count = ( + select(func.count()) + .select_from(self._ITEM_TABLE) + .where(self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.is_handled.is_(False)) + .scalar_subquery() + ) + + total_count = ( + select(func.count()) + .select_from(self._ITEM_TABLE) + .where(self._ITEM_TABLE.metadata_id == self._id) + .scalar_subquery() + ) + + stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id) + + values_to_set['custom_stmt'] = stmt + values_to_set['pending_request_count'] = pending_count + values_to_set['total_request_count'] = total_count return values_to_set From 63249bba4efd334a84df45e5c93a4516fc6ff85b Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Sat, 23 Aug 2025 05:39:39 +0000 Subject: [PATCH 22/29] Add experimental warning --- src/crawlee/storage_clients/_sql/_storage_client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/crawlee/storage_clients/_sql/_storage_client.py b/src/crawlee/storage_clients/_sql/_storage_client.py index 1bb8df58ad..f8189961c0 100644 --- a/src/crawlee/storage_clients/_sql/_storage_client.py +++ b/src/crawlee/storage_clients/_sql/_storage_client.py @@ -36,6 +36,9 @@ class SQLStorageClient(StorageClient): Database schema is automatically created during initialization. SQLite databases receive performance optimizations including WAL mode and increased cache size. + + Warning: + This is an experimental feature. The behavior and interface may change in future versions. 
""" _DB_NAME = 'crawlee.db' From 0d62dcf157f170c62de2f5bf154fdfe6b9afd120 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Sun, 24 Aug 2025 01:24:26 +0000 Subject: [PATCH 23/29] remove mysql --- src/crawlee/storage_clients/_sql/_client_mixin.py | 9 --------- .../storage_clients/_sql/_request_queue_client.py | 6 +++--- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_client_mixin.py b/src/crawlee/storage_clients/_sql/_client_mixin.py index 4ff911d63c..869403cab0 100644 --- a/src/crawlee/storage_clients/_sql/_client_mixin.py +++ b/src/crawlee/storage_clients/_sql/_client_mixin.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, ClassVar from sqlalchemy import delete, select, text, update -from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.sqlite import insert as lite_insert from sqlalchemy.exc import SQLAlchemyError @@ -106,9 +105,6 @@ def build_insert_stmt_with_ignore( if dialect == 'postgresql': return pg_insert(table_model).values(insert_values).on_conflict_do_nothing() - if dialect == 'mysql': - return mysql_insert(table_model).values(insert_values).on_duplicate_key_update() - if dialect == 'sqlite': return lite_insert(table_model).values(insert_values).on_conflict_do_nothing() @@ -145,11 +141,6 @@ def build_upsert_stmt( set_ = {col: getattr(lite_stmt.excluded, col) for col in update_columns} return lite_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=set_) - if dialect == 'mysql': - mysql_stmt = mysql_insert(table_model).values(insert_values) - set_ = {col: mysql_stmt.inserted[col] for col in update_columns} - return mysql_stmt.on_duplicate_key_update(**set_) - raise NotImplementedError(f'Upsert not supported for dialect: {dialect}') async def _purge(self, metadata_kwargs: dict[str, Any]) -> None: diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index 7d5d0ae0a9..a18117a327 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -331,7 +331,7 @@ async def add_batch_of_requests( *, forefront: bool = False, ) -> AddRequestsResponse: - if self._storage_client.get_dialect_name() in {'sqlite', 'postgresql', 'mysql'}: + if self._storage_client.get_dialect_name() in {'sqlite', 'postgresql'}: return await self._add_batch_of_requests_optimization(requests, forefront=forefront) raise NotImplementedError('Batch addition is not supported for this database dialect.') @@ -383,7 +383,7 @@ async def fetch_next_request(self) -> Request | None: ) # We use the `skip_locked` database mechanism to prevent the “interception” of requests by another client - if dialect in {'postgresql', 'mysql'}: + if dialect == 'postgresql': stmt = stmt.with_for_update(skip_locked=True) async with self.get_session() as session: @@ -527,7 +527,7 @@ async def is_empty(self) -> bool: return empty async def _block_metadata_for_update(self, session: AsyncSession) -> None: - if self._storage_client.get_dialect_name() in {'postgresql', 'mysql'}: + if self._storage_client.get_dialect_name() == 'postgresql': stmt = select(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id).with_for_update() await session.execute(stmt) From dffeb7619abf1e339e2c93d55dd9f850850bb6ee Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Sun, 24 Aug 2025 16:28:14 +0000 Subject: [PATCH 24/29] raise Error for unsupported 
dialects --- pyproject.toml | 4 ++-- src/crawlee/storage_clients/_sql/_storage_client.py | 9 ++++++++- uv.lock | 10 +++++----- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 974589a2f2..025443c142 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ dependencies = [ ] [project.optional-dependencies] -all = ["crawlee[adaptive-crawler,beautifulsoup,cli,curl-impersonate,httpx,parsel,playwright,otel,sql]"] +all = ["crawlee[adaptive-crawler,beautifulsoup,cli,curl-impersonate,httpx,parsel,playwright,otel,sql_storage]"] adaptive-crawler = [ "jaro-winkler>=2.0.3", "playwright>=1.27.0", @@ -71,7 +71,7 @@ otel = [ "opentelemetry-semantic-conventions>=0.54", "wrapt>=1.17.0", ] -sql = [ +sql_storage = [ "sqlalchemy[asyncio]~=2.0.0,<3.0.0", "aiosqlite>=0.21.0", ] diff --git a/src/crawlee/storage_clients/_sql/_storage_client.py b/src/crawlee/storage_clients/_sql/_storage_client.py index f8189961c0..f053d764b0 100644 --- a/src/crawlee/storage_clients/_sql/_storage_client.py +++ b/src/crawlee/storage_clients/_sql/_storage_client.py @@ -130,8 +130,15 @@ async def initialize(self, configuration: Configuration) -> None: if not self._initialized: engine = self._get_or_create_engine(configuration) async with engine.begin() as conn: - # Set SQLite pragmas for performance and consistency self._dialect_name = engine.dialect.name + + if self._dialect_name not in ('sqlite', 'postgresql'): + raise ValueError( + f'Unsupported database dialect: {self._dialect_name}. Supported: sqlite, postgresql.\n', + 'Consider using a different database.', + ) + + # Set SQLite pragmas for performance and consistency if self._default_flag: await conn.execute(text('PRAGMA journal_mode=WAL')) # Better concurrency await conn.execute(text('PRAGMA synchronous=NORMAL')) # Balanced safety/speed diff --git a/uv.lock b/uv.lock index b8d1657093..66247da9bf 100644 --- a/uv.lock +++ b/uv.lock @@ -674,7 +674,7 @@ playwright = [ { name = "browserforge" }, { name = "playwright" }, ] -sql = [ +sql-storage = [ { name = "aiosqlite" }, { name = "sqlalchemy", extra = ["asyncio"] }, ] @@ -705,7 +705,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "aiosqlite", marker = "extra == 'sql'", specifier = ">=0.21.0" }, + { name = "aiosqlite", marker = "extra == 'sql-storage'", specifier = ">=0.21.0" }, { name = "apify-fingerprint-datapoints", marker = "extra == 'adaptive-crawler'", specifier = ">=0.0.2" }, { name = "apify-fingerprint-datapoints", marker = "extra == 'httpx'", specifier = ">=0.0.2" }, { name = "apify-fingerprint-datapoints", marker = "extra == 'playwright'", specifier = ">=0.0.2" }, @@ -716,7 +716,7 @@ requires-dist = [ { name = "cachetools", specifier = ">=5.5.0" }, { name = "colorama", specifier = ">=0.4.0" }, { name = "cookiecutter", marker = "extra == 'cli'", specifier = ">=2.6.0" }, - { name = "crawlee", extras = ["adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "httpx", "parsel", "playwright", "otel", "sql"], marker = "extra == 'all'" }, + { name = "crawlee", extras = ["adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "httpx", "parsel", "playwright", "otel", "sql-storage"], marker = "extra == 'all'" }, { name = "curl-cffi", marker = "extra == 'curl-impersonate'", specifier = ">=0.9.0" }, { name = "html5lib", marker = "extra == 'beautifulsoup'", specifier = ">=1.0" }, { name = "httpx", extras = ["brotli", "http2", "zstd"], marker = "extra == 'httpx'", specifier = ">=0.27.0" }, @@ -740,14 +740,14 @@ requires-dist = [ { name = "pyee", 
specifier = ">=9.0.0" }, { name = "rich", marker = "extra == 'cli'", specifier = ">=13.9.0" }, { name = "scikit-learn", marker = "extra == 'adaptive-crawler'", specifier = ">=1.6.0" }, - { name = "sqlalchemy", extras = ["asyncio"], marker = "extra == 'sql'", specifier = "~=2.0.0,<3.0.0" }, + { name = "sqlalchemy", extras = ["asyncio"], marker = "extra == 'sql-storage'", specifier = "~=2.0.0,<3.0.0" }, { name = "tldextract", specifier = ">=5.1.0" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.0" }, { name = "typing-extensions", specifier = ">=4.1.0" }, { name = "wrapt", marker = "extra == 'otel'", specifier = ">=1.17.0" }, { name = "yarl", specifier = ">=1.18.0" }, ] -provides-extras = ["all", "adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "httpx", "parsel", "playwright", "otel", "sql"] +provides-extras = ["all", "adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "httpx", "parsel", "playwright", "otel", "sql-storage"] [package.metadata.requires-dev] dev = [ From 61ba5129b7e58e7851d169e22db90b8895ca68d6 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Sun, 24 Aug 2025 17:45:05 +0000 Subject: [PATCH 25/29] optimize update timestamps in metadata --- pyproject.toml | 4 ++ .../storage_clients/_sql/_client_mixin.py | 36 ++++++++--- .../storage_clients/_sql/_dataset_client.py | 6 +- .../_sql/_request_queue_client.py | 8 +-- uv.lock | 60 ++++++++++++++++++- 5 files changed, 99 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 025443c142..484a76fc57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,10 @@ otel = [ "opentelemetry-semantic-conventions>=0.54", "wrapt>=1.17.0", ] +sql_storage_posgres = [ + "sqlalchemy[asyncio]~=2.0.0,<3.0.0", + "asyncpg>=0.24.0" +] sql_storage = [ "sqlalchemy[asyncio]~=2.0.0,<3.0.0", "aiosqlite>=0.21.0", diff --git a/src/crawlee/storage_clients/_sql/_client_mixin.py b/src/crawlee/storage_clients/_sql/_client_mixin.py index 869403cab0..a4d03aad8b 100644 --- a/src/crawlee/storage_clients/_sql/_client_mixin.py +++ b/src/crawlee/storage_clients/_sql/_client_mixin.py @@ -60,6 +60,7 @@ def __init__(self, *, id: str, storage_client: SQLStorageClient) -> None: # Time tracking to reduce database writes during frequent operation self._accessed_at_allow_update_after: datetime | None = None + self._modified_at_allow_update_after: datetime | None = None self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval() def get_session(self) -> AsyncSession: @@ -167,26 +168,43 @@ async def _drop(self) -> None: await autocommit.execute(stmt) def _default_update_metadata( - self, *, update_accessed_at: bool = False, update_modified_at: bool = False + self, *, update_accessed_at: bool = False, update_modified_at: bool = False, force: bool = False ) -> dict[str, Any]: """Prepare common metadata updates with rate limiting. Args: update_accessed_at: Whether to update accessed_at timestamp update_modified_at: Whether to update modified_at timestamp + force: Whether to force the update regardless of rate limiting """ - now = datetime.now(timezone.utc) values_to_set: dict[str, Any] = {} + now = datetime.now(timezone.utc) - if update_accessed_at and ( - self._accessed_at_allow_update_after is None or (now >= self._accessed_at_allow_update_after) + # If the record must be updated (for example, when updating counters), we update timestamps and shift the time. 
+ if force: + if update_modified_at: + values_to_set['modified_at'] = now + self._modified_at_allow_update_after = now + self._accessed_modified_update_interval + if update_accessed_at: + values_to_set['accessed_at'] = now + self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval + + elif update_modified_at and ( + self._modified_at_allow_update_after is None or now >= self._modified_at_allow_update_after + ): + values_to_set['modified_at'] = now + self._modified_at_allow_update_after = now + self._accessed_modified_update_interval + # The record will be updated, we can update `accessed_at` and shift the time. + if update_accessed_at: + values_to_set['accessed_at'] = now + self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval + + elif update_accessed_at and ( + self._accessed_at_allow_update_after is None or now >= self._accessed_at_allow_update_after ): values_to_set['accessed_at'] = now self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval - if update_modified_at: - values_to_set['modified_at'] = now - return values_to_set def _specific_update_metadata(self, **kwargs: Any) -> dict[str, Any]: @@ -205,6 +223,7 @@ async def _update_metadata( *, update_accessed_at: bool = False, update_modified_at: bool = False, + force: bool = False, **kwargs: Any, ) -> bool: """Update storage metadata combining common and specific fields. @@ -213,13 +232,14 @@ async def _update_metadata( session: Active database session update_accessed_at: Whether to update accessed_at timestamp update_modified_at: Whether to update modified_at timestamp + force: Whether to force the update timestamps regardless of rate limiting **kwargs: Additional arguments for _specific_update_metadata Returns: True if any updates were made, False otherwise """ values_to_set = self._default_update_metadata( - update_accessed_at=update_accessed_at, update_modified_at=update_modified_at + update_accessed_at=update_accessed_at, update_modified_at=update_modified_at, force=force ) values_to_set.update(self._specific_update_metadata(**kwargs)) diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py index 171cdadbc9..554f656215 100644 --- a/src/crawlee/storage_clients/_sql/_dataset_client.py +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -113,7 +113,9 @@ async def purge(self) -> None: Resets item_count to 0 and deletes all records from dataset_item table. 
""" - await self._purge(metadata_kwargs={'new_item_count': 0, 'update_accessed_at': True, 'update_modified_at': True}) + await self._purge( + metadata_kwargs={'new_item_count': 0, 'update_accessed_at': True, 'update_modified_at': True, 'force': True} + ) @override async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: @@ -139,7 +141,7 @@ async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: await autocommit.execute(stmt) await self._update_metadata( - autocommit, update_accessed_at=True, update_modified_at=True, delta_item_count=len(data) + autocommit, update_accessed_at=True, update_modified_at=True, delta_item_count=len(data), force=True ) def _prepare_get_stmt( diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index a18117a327..cce4bc9e10 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -171,6 +171,7 @@ async def purge(self) -> None: 'update_modified_at': True, 'new_pending_request_count': 0, 'new_handled_request_count': 0, + 'force': True, } ) @@ -292,6 +293,7 @@ async def _add_batch_of_requests_optimization( recalculate=metadata_recalculate, update_modified_at=True, update_accessed_at=True, + force=metadata_recalculate, ) try: @@ -301,10 +303,7 @@ async def _add_batch_of_requests_optimization( logger.warning(f'Failed to commit session: {e}') await self._block_metadata_for_update(session) await self._update_metadata( - session, - recalculate=True, - update_modified_at=True, - update_accessed_at=True, + session, recalculate=True, update_modified_at=True, update_accessed_at=True, force=True ) processed_requests.clear() unprocessed_requests.extend( @@ -451,6 +450,7 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | delta_pending_request_count=-1, update_modified_at=True, update_accessed_at=True, + force=True, ) await session.commit() return ProcessedRequest( diff --git a/uv.lock b/uv.lock index 66247da9bf..6160a3b296 100644 --- a/uv.lock +++ b/uv.lock @@ -98,6 +98,58 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/ed/e97229a566617f2ae958a6b13e7cc0f585470eac730a73e9e82c32a3cdd2/arrow-1.3.0-py3-none-any.whl", hash = "sha256:c728b120ebc00eb84e01882a6f5e7927a53960aa990ce7dd2b10f39005a67f80", size = 66419, upload-time = "2023-09-30T22:11:16.072Z" }, ] +[[package]] +name = "async-timeout" +version = "5.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a5/ae/136395dfbfe00dfc94da3f3e136d0b13f394cba8f4841120e34226265780/async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3", size = 9274, upload-time = "2024-11-06T16:41:39.6Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233, upload-time = "2024-11-06T16:41:37.9Z" }, +] + +[[package]] +name = "asyncpg" +version = "0.30.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-timeout", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2f/4c/7c991e080e106d854809030d8584e15b2e996e26f16aee6d757e387bc17d/asyncpg-0.30.0.tar.gz", hash = 
"sha256:c551e9928ab6707602f44811817f82ba3c446e018bfe1d3abecc8ba5f3eac851", size = 957746, upload-time = "2024-10-20T00:30:41.127Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/07/1650a8c30e3a5c625478fa8aafd89a8dd7d85999bf7169b16f54973ebf2c/asyncpg-0.30.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bfb4dd5ae0699bad2b233672c8fc5ccbd9ad24b89afded02341786887e37927e", size = 673143, upload-time = "2024-10-20T00:29:08.846Z" }, + { url = "https://files.pythonhosted.org/packages/a0/9a/568ff9b590d0954553c56806766914c149609b828c426c5118d4869111d3/asyncpg-0.30.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dc1f62c792752a49f88b7e6f774c26077091b44caceb1983509edc18a2222ec0", size = 645035, upload-time = "2024-10-20T00:29:12.02Z" }, + { url = "https://files.pythonhosted.org/packages/de/11/6f2fa6c902f341ca10403743701ea952bca896fc5b07cc1f4705d2bb0593/asyncpg-0.30.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3152fef2e265c9c24eec4ee3d22b4f4d2703d30614b0b6753e9ed4115c8a146f", size = 2912384, upload-time = "2024-10-20T00:29:13.644Z" }, + { url = "https://files.pythonhosted.org/packages/83/83/44bd393919c504ffe4a82d0aed8ea0e55eb1571a1dea6a4922b723f0a03b/asyncpg-0.30.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7255812ac85099a0e1ffb81b10dc477b9973345793776b128a23e60148dd1af", size = 2947526, upload-time = "2024-10-20T00:29:15.871Z" }, + { url = "https://files.pythonhosted.org/packages/08/85/e23dd3a2b55536eb0ded80c457b0693352262dc70426ef4d4a6fc994fa51/asyncpg-0.30.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:578445f09f45d1ad7abddbff2a3c7f7c291738fdae0abffbeb737d3fc3ab8b75", size = 2895390, upload-time = "2024-10-20T00:29:19.346Z" }, + { url = "https://files.pythonhosted.org/packages/9b/26/fa96c8f4877d47dc6c1864fef5500b446522365da3d3d0ee89a5cce71a3f/asyncpg-0.30.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c42f6bb65a277ce4d93f3fba46b91a265631c8df7250592dd4f11f8b0152150f", size = 3015630, upload-time = "2024-10-20T00:29:21.186Z" }, + { url = "https://files.pythonhosted.org/packages/34/00/814514eb9287614188a5179a8b6e588a3611ca47d41937af0f3a844b1b4b/asyncpg-0.30.0-cp310-cp310-win32.whl", hash = "sha256:aa403147d3e07a267ada2ae34dfc9324e67ccc4cdca35261c8c22792ba2b10cf", size = 568760, upload-time = "2024-10-20T00:29:22.769Z" }, + { url = "https://files.pythonhosted.org/packages/f0/28/869a7a279400f8b06dd237266fdd7220bc5f7c975348fea5d1e6909588e9/asyncpg-0.30.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb622c94db4e13137c4c7f98834185049cc50ee01d8f657ef898b6407c7b9c50", size = 625764, upload-time = "2024-10-20T00:29:25.882Z" }, + { url = "https://files.pythonhosted.org/packages/4c/0e/f5d708add0d0b97446c402db7e8dd4c4183c13edaabe8a8500b411e7b495/asyncpg-0.30.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5e0511ad3dec5f6b4f7a9e063591d407eee66b88c14e2ea636f187da1dcfff6a", size = 674506, upload-time = "2024-10-20T00:29:27.988Z" }, + { url = "https://files.pythonhosted.org/packages/6a/a0/67ec9a75cb24a1d99f97b8437c8d56da40e6f6bd23b04e2f4ea5d5ad82ac/asyncpg-0.30.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:915aeb9f79316b43c3207363af12d0e6fd10776641a7de8a01212afd95bdf0ed", size = 645922, upload-time = "2024-10-20T00:29:29.391Z" }, + { url = "https://files.pythonhosted.org/packages/5c/d9/a7584f24174bd86ff1053b14bb841f9e714380c672f61c906eb01d8ec433/asyncpg-0.30.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:1c198a00cce9506fcd0bf219a799f38ac7a237745e1d27f0e1f66d3707c84a5a", size = 3079565, upload-time = "2024-10-20T00:29:30.832Z" }, + { url = "https://files.pythonhosted.org/packages/a0/d7/a4c0f9660e333114bdb04d1a9ac70db690dd4ae003f34f691139a5cbdae3/asyncpg-0.30.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3326e6d7381799e9735ca2ec9fd7be4d5fef5dcbc3cb555d8a463d8460607956", size = 3109962, upload-time = "2024-10-20T00:29:33.114Z" }, + { url = "https://files.pythonhosted.org/packages/3c/21/199fd16b5a981b1575923cbb5d9cf916fdc936b377e0423099f209e7e73d/asyncpg-0.30.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:51da377487e249e35bd0859661f6ee2b81db11ad1f4fc036194bc9cb2ead5056", size = 3064791, upload-time = "2024-10-20T00:29:34.677Z" }, + { url = "https://files.pythonhosted.org/packages/77/52/0004809b3427534a0c9139c08c87b515f1c77a8376a50ae29f001e53962f/asyncpg-0.30.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bc6d84136f9c4d24d358f3b02be4b6ba358abd09f80737d1ac7c444f36108454", size = 3188696, upload-time = "2024-10-20T00:29:36.389Z" }, + { url = "https://files.pythonhosted.org/packages/52/cb/fbad941cd466117be58b774a3f1cc9ecc659af625f028b163b1e646a55fe/asyncpg-0.30.0-cp311-cp311-win32.whl", hash = "sha256:574156480df14f64c2d76450a3f3aaaf26105869cad3865041156b38459e935d", size = 567358, upload-time = "2024-10-20T00:29:37.915Z" }, + { url = "https://files.pythonhosted.org/packages/3c/0a/0a32307cf166d50e1ad120d9b81a33a948a1a5463ebfa5a96cc5606c0863/asyncpg-0.30.0-cp311-cp311-win_amd64.whl", hash = "sha256:3356637f0bd830407b5597317b3cb3571387ae52ddc3bca6233682be88bbbc1f", size = 629375, upload-time = "2024-10-20T00:29:39.987Z" }, + { url = "https://files.pythonhosted.org/packages/4b/64/9d3e887bb7b01535fdbc45fbd5f0a8447539833b97ee69ecdbb7a79d0cb4/asyncpg-0.30.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c902a60b52e506d38d7e80e0dd5399f657220f24635fee368117b8b5fce1142e", size = 673162, upload-time = "2024-10-20T00:29:41.88Z" }, + { url = "https://files.pythonhosted.org/packages/6e/eb/8b236663f06984f212a087b3e849731f917ab80f84450e943900e8ca4052/asyncpg-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aca1548e43bbb9f0f627a04666fedaca23db0a31a84136ad1f868cb15deb6e3a", size = 637025, upload-time = "2024-10-20T00:29:43.352Z" }, + { url = "https://files.pythonhosted.org/packages/cc/57/2dc240bb263d58786cfaa60920779af6e8d32da63ab9ffc09f8312bd7a14/asyncpg-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c2a2ef565400234a633da0eafdce27e843836256d40705d83ab7ec42074efb3", size = 3496243, upload-time = "2024-10-20T00:29:44.922Z" }, + { url = "https://files.pythonhosted.org/packages/f4/40/0ae9d061d278b10713ea9021ef6b703ec44698fe32178715a501ac696c6b/asyncpg-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1292b84ee06ac8a2ad8e51c7475aa309245874b61333d97411aab835c4a2f737", size = 3575059, upload-time = "2024-10-20T00:29:46.891Z" }, + { url = "https://files.pythonhosted.org/packages/c3/75/d6b895a35a2c6506952247640178e5f768eeb28b2e20299b6a6f1d743ba0/asyncpg-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0f5712350388d0cd0615caec629ad53c81e506b1abaaf8d14c93f54b35e3595a", size = 3473596, upload-time = "2024-10-20T00:29:49.201Z" }, + { url = "https://files.pythonhosted.org/packages/c8/e7/3693392d3e168ab0aebb2d361431375bd22ffc7b4a586a0fc060d519fae7/asyncpg-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = 
"sha256:db9891e2d76e6f425746c5d2da01921e9a16b5a71a1c905b13f30e12a257c4af", size = 3641632, upload-time = "2024-10-20T00:29:50.768Z" }, + { url = "https://files.pythonhosted.org/packages/32/ea/15670cea95745bba3f0352341db55f506a820b21c619ee66b7d12ea7867d/asyncpg-0.30.0-cp312-cp312-win32.whl", hash = "sha256:68d71a1be3d83d0570049cd1654a9bdfe506e794ecc98ad0873304a9f35e411e", size = 560186, upload-time = "2024-10-20T00:29:52.394Z" }, + { url = "https://files.pythonhosted.org/packages/7e/6b/fe1fad5cee79ca5f5c27aed7bd95baee529c1bf8a387435c8ba4fe53d5c1/asyncpg-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a0292c6af5c500523949155ec17b7fe01a00ace33b68a476d6b5059f9630305", size = 621064, upload-time = "2024-10-20T00:29:53.757Z" }, + { url = "https://files.pythonhosted.org/packages/3a/22/e20602e1218dc07692acf70d5b902be820168d6282e69ef0d3cb920dc36f/asyncpg-0.30.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05b185ebb8083c8568ea8a40e896d5f7af4b8554b64d7719c0eaa1eb5a5c3a70", size = 670373, upload-time = "2024-10-20T00:29:55.165Z" }, + { url = "https://files.pythonhosted.org/packages/3d/b3/0cf269a9d647852a95c06eb00b815d0b95a4eb4b55aa2d6ba680971733b9/asyncpg-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c47806b1a8cbb0a0db896f4cd34d89942effe353a5035c62734ab13b9f938da3", size = 634745, upload-time = "2024-10-20T00:29:57.14Z" }, + { url = "https://files.pythonhosted.org/packages/8e/6d/a4f31bf358ce8491d2a31bfe0d7bcf25269e80481e49de4d8616c4295a34/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b6fde867a74e8c76c71e2f64f80c64c0f3163e687f1763cfaf21633ec24ec33", size = 3512103, upload-time = "2024-10-20T00:29:58.499Z" }, + { url = "https://files.pythonhosted.org/packages/96/19/139227a6e67f407b9c386cb594d9628c6c78c9024f26df87c912fabd4368/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46973045b567972128a27d40001124fbc821c87a6cade040cfcd4fa8a30bcdc4", size = 3592471, upload-time = "2024-10-20T00:30:00.354Z" }, + { url = "https://files.pythonhosted.org/packages/67/e4/ab3ca38f628f53f0fd28d3ff20edff1c975dd1cb22482e0061916b4b9a74/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9110df111cabc2ed81aad2f35394a00cadf4f2e0635603db6ebbd0fc896f46a4", size = 3496253, upload-time = "2024-10-20T00:30:02.794Z" }, + { url = "https://files.pythonhosted.org/packages/ef/5f/0bf65511d4eeac3a1f41c54034a492515a707c6edbc642174ae79034d3ba/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04ff0785ae7eed6cc138e73fc67b8e51d54ee7a3ce9b63666ce55a0bf095f7ba", size = 3662720, upload-time = "2024-10-20T00:30:04.501Z" }, + { url = "https://files.pythonhosted.org/packages/e7/31/1513d5a6412b98052c3ed9158d783b1e09d0910f51fbe0e05f56cc370bc4/asyncpg-0.30.0-cp313-cp313-win32.whl", hash = "sha256:ae374585f51c2b444510cdf3595b97ece4f233fde739aa14b50e0d64e8a7a590", size = 560404, upload-time = "2024-10-20T00:30:06.537Z" }, + { url = "https://files.pythonhosted.org/packages/c8/a4/cec76b3389c4c5ff66301cd100fe88c318563ec8a520e0b2e792b5b84972/asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e", size = 621623, upload-time = "2024-10-20T00:30:09.024Z" }, +] + [[package]] name = "backports-asyncio-runner" version = "1.2.0" @@ -678,6 +730,10 @@ sql-storage = [ { name = "aiosqlite" }, { name = "sqlalchemy", extra = ["asyncio"] }, ] +sql-storage-posgres = [ + { name = "asyncpg" }, + { name = "sqlalchemy", extra = ["asyncio"] }, +] 
[package.dev-dependencies] dev = [ @@ -709,6 +765,7 @@ requires-dist = [ { name = "apify-fingerprint-datapoints", marker = "extra == 'adaptive-crawler'", specifier = ">=0.0.2" }, { name = "apify-fingerprint-datapoints", marker = "extra == 'httpx'", specifier = ">=0.0.2" }, { name = "apify-fingerprint-datapoints", marker = "extra == 'playwright'", specifier = ">=0.0.2" }, + { name = "asyncpg", marker = "extra == 'sql-storage-posgres'", specifier = ">=0.24.0" }, { name = "beautifulsoup4", extras = ["lxml"], marker = "extra == 'beautifulsoup'", specifier = ">=4.12.0" }, { name = "browserforge", marker = "extra == 'adaptive-crawler'", specifier = ">=1.2.3" }, { name = "browserforge", marker = "extra == 'httpx'", specifier = ">=1.2.3" }, @@ -741,13 +798,14 @@ requires-dist = [ { name = "rich", marker = "extra == 'cli'", specifier = ">=13.9.0" }, { name = "scikit-learn", marker = "extra == 'adaptive-crawler'", specifier = ">=1.6.0" }, { name = "sqlalchemy", extras = ["asyncio"], marker = "extra == 'sql-storage'", specifier = "~=2.0.0,<3.0.0" }, + { name = "sqlalchemy", extras = ["asyncio"], marker = "extra == 'sql-storage-posgres'", specifier = "~=2.0.0,<3.0.0" }, { name = "tldextract", specifier = ">=5.1.0" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.0" }, { name = "typing-extensions", specifier = ">=4.1.0" }, { name = "wrapt", marker = "extra == 'otel'", specifier = ">=1.17.0" }, { name = "yarl", specifier = ">=1.18.0" }, ] -provides-extras = ["all", "adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "httpx", "parsel", "playwright", "otel", "sql-storage"] +provides-extras = ["all", "adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "httpx", "parsel", "playwright", "otel", "sql-storage-posgres", "sql-storage"] [package.metadata.requires-dev] dev = [ From 46e12b42686f81119a6bd0e9e98956aa409ee8c9 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Sun, 24 Aug 2025 18:36:20 +0000 Subject: [PATCH 26/29] add docs --- .../sql_storage_client_basic_example.py | 10 ++ ...ql_storage_client_configuration_example.py | 32 +++++++ docs/guides/storage_clients.mdx | 93 +++++++++++++++++++ 3 files changed, 135 insertions(+) create mode 100644 docs/guides/code_examples/storage_clients/sql_storage_client_basic_example.py create mode 100644 docs/guides/code_examples/storage_clients/sql_storage_client_configuration_example.py diff --git a/docs/guides/code_examples/storage_clients/sql_storage_client_basic_example.py b/docs/guides/code_examples/storage_clients/sql_storage_client_basic_example.py new file mode 100644 index 0000000000..12e4ac5aa5 --- /dev/null +++ b/docs/guides/code_examples/storage_clients/sql_storage_client_basic_example.py @@ -0,0 +1,10 @@ +from crawlee.crawlers import ParselCrawler +from crawlee.storage_clients import SQLStorageClient + + +async def main() -> None: + # Create a new instance of storage client. + # Use the context manager to ensure that connections are properly cleaned up. + async with SQLStorageClient() as storage_client: + # And pass it to the crawler. 
+ crawler = ParselCrawler(storage_client=storage_client) diff --git a/docs/guides/code_examples/storage_clients/sql_storage_client_configuration_example.py b/docs/guides/code_examples/storage_clients/sql_storage_client_configuration_example.py new file mode 100644 index 0000000000..ca4d0a6e06 --- /dev/null +++ b/docs/guides/code_examples/storage_clients/sql_storage_client_configuration_example.py @@ -0,0 +1,32 @@ +from sqlalchemy.ext.asyncio import create_async_engine + +from crawlee.configuration import Configuration +from crawlee.crawlers import ParselCrawler +from crawlee.storage_clients import SQLStorageClient + + +async def main() -> None: + # Create a new instance of storage client. + # Use the context manager to ensure that connections are properly cleaned up. + async with SQLStorageClient( + # Create an `engine` with the desired configuration + engine=create_async_engine( + 'postgresql+asyncpg://myuser:mypassword@localhost:5432/postgres', + future=True, + pool_size=5, + max_overflow=10, + pool_recycle=3600, + pool_pre_ping=True, + echo=False, + ) + ) as storage_client: + # Create a configuration with custom settings. + configuration = Configuration( + purge_on_start=False, + ) + + # And pass them to the crawler. + crawler = ParselCrawler( + storage_client=storage_client, + configuration=configuration, + ) diff --git a/docs/guides/storage_clients.mdx b/docs/guides/storage_clients.mdx index 0c2a14ffe9..fd592b7a1f 100644 --- a/docs/guides/storage_clients.mdx +++ b/docs/guides/storage_clients.mdx @@ -14,6 +14,8 @@ import FileSystemStorageClientBasicExample from '!!raw-loader!roa-loader!./code_ import FileSystemStorageClientConfigurationExample from '!!raw-loader!roa-loader!./code_examples/storage_clients/file_system_storage_client_configuration_example.py'; import CustomStorageClientExample from '!!raw-loader!roa-loader!./code_examples/storage_clients/custom_storage_client_example.py'; import RegisteringStorageClientsExample from '!!raw-loader!roa-loader!./code_examples/storage_clients/registering_storage_clients_example.py'; +import SQLStorageClientBasicExample from '!!raw-loader!roa-loader!./code_examples/storage_clients/sql_storage_client_basic_example.py'; +import SQLStorageClientConfigurationExample from '!!raw-loader!roa-loader!./code_examples/storage_clients/sql_storage_client_configuration_example.py'; Storage clients provide a unified interface for interacting with `Dataset`, `KeyValueStore`, and `RequestQueue`, regardless of the underlying implementation. They handle operations like creating, reading, updating, and deleting storage instances, as well as managing data persistence and cleanup. This abstraction makes it easy to switch between different environments, such as local development and cloud production setups. @@ -50,6 +52,8 @@ class FileSystemStorageClient class MemoryStorageClient +class SQLStorageClient + class ApifyStorageClient %% ======================== @@ -58,6 +62,7 @@ class ApifyStorageClient StorageClient --|> FileSystemStorageClient StorageClient --|> MemoryStorageClient +StorageClient --|> SQLStorageClient StorageClient --|> ApifyStorageClient ``` @@ -125,6 +130,94 @@ The `MemoryStorageClient` does not persist data between runs. All data is lost w {MemoryStorageClientBasicExample} +### SQL storage client + +:::warning Experimental feature +The `SQLStorageClient` is experimental. Its API and behavior may change in future releases. 
+::: + +The `SQLStorageClient` provides persistent storage using a SQL database ([SQLite](https://sqlite.org/) by default, or [PostgreSQL](https://www.postgresql.org/)). It supports all Crawlee storage types and enables concurrent access from multiple independent clients or processes. + +:::note dependencies +Use crawlee['sql_storage'] for SQLite and crawlee['sql_storage_posgres'] for PostgreSQL +::: + + + {SQLStorageClientBasicExample} + + +Data is organized in relational tables. Below are the main tables and columns used for each storage type: + +```text +{DATABASE} +├── dataset_metadata +│ ├── id (PK) +│ ├── name +│ ├── accessed_at +│ ├── created_at +│ ├── modified_at +│ └── item_count +├── dataset_item +│ ├── order_id (PK) +│ ├── metadata_id (FK) +│ └── data +├── kvs_metadata +│ ├── id (PK) +│ ├── name +│ ├── accessed_at +│ ├── created_at +│ └── modified_at +├── kvs_record +│ ├── metadata_id (FK, PK) +│ ├── key (PK) +│ ├── value +│ ├── content_type +│ └── size +├── request_queue_metadata +│ ├── id (PK) +│ ├── name +│ ├── accessed_at +│ ├── created_at +│ ├── modified_at +│ ├── had_multiple_clients +│ ├── handled_request_count +│ ├── pending_request_count +│ └── total_request_count +├── request +│ ├── request_id (PK) +│ ├── metadata_id (FK, PK) +│ ├── data +│ ├── sequence_number +│ ├── is_handled +│ └── time_blocked_until +└── request_queue_state + ├── metadata_id (PK, FK) + ├── sequence_counter + └── forefront_sequence_counter +``` + +Where: + +- Each *_metadata table stores metadata for the storage instance. +- Data tables (dataset_item, kvs_record, request) store the actual items, records, or requests. +- request_queue_state is technical table for queue. +- Foreign keys (`metadata_id`) link data to the corresponding metadata record. + +Configuration options for the `SQLStorageClient` can be set through environment variables or the `Configuration` class: + +- **`storage_dir`** (env: `CRAWLEE_STORAGE_DIR`, default: `'./storage'`) - The root directory where the default SQLite database will be created if no connection string is provided. +- **`purge_on_start`** (env: `CRAWLEE_PURGE_ON_START`, default: `True`) - Whether to purge default storages on start. + +Configuration options for the SQLStorageClient can be set via constructor arguments: + +- **`connection_string`** (default: SQLite in Configuration storage dir) – SQLAlchemy connection string, e.g. 'sqlite+aiosqlite:///my.db' or 'postgresql+asyncpg://user:pass@host/db'. +- **`engine`** – Pre-configured SQLAlchemy AsyncEngine (optional). +- **`accessed_modified_update_interval`** – Minimum interval between metadata timestamp updates (default: 1 second). Reducing this parameter can significantly increase the load on the database. + + + {SQLStorageClientConfigurationExample} + + ## Creating a custom storage client A storage client consists of two parts: the storage client factory and individual storage type clients. The `StorageClient` acts as a factory that creates specific clients (`DatasetClient`, `KeyValueStoreClient`, `RequestQueueClient`) where the actual storage logic is implemented. 
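
The `SQLStorageClient` constructor options documented in this patch (`connection_string`, `engine`, `accessed_modified_update_interval`) can also be exercised without building an engine by hand: passing a `connection_string` lets the client create the async engine internally. A minimal sketch of that variant — the database file name, target URL, and handler below are illustrative assumptions, not part of the patch:

```python
import asyncio

from crawlee.crawlers import ParselCrawler, ParselCrawlingContext
from crawlee.storage_clients import SQLStorageClient


async def main() -> None:
    # Let the storage client build its own aiosqlite engine from a connection string.
    # 'sqlite+aiosqlite:///my_crawl.db' is an illustrative path, not taken from the patch.
    async with SQLStorageClient(connection_string='sqlite+aiosqlite:///my_crawl.db') as storage_client:
        crawler = ParselCrawler(storage_client=storage_client)

        @crawler.router.default_handler
        async def handler(context: ParselCrawlingContext) -> None:
            # Push the page title into the SQL-backed dataset.
            await context.push_data(
                {'url': context.request.url, 'title': context.selector.css('title::text').get()}
            )

        await crawler.run(['https://crawlee.dev'])


if __name__ == '__main__':
    asyncio.run(main())
```
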
From b92e385ef18c53335985519567063c19530da8c9 Mon Sep 17 00:00:00 2001 From: Max Bohomolov <34358312+Mantisus@users.noreply.github.com> Date: Tue, 26 Aug 2025 00:57:27 +0300 Subject: [PATCH 27/29] Update pyproject.toml Co-authored-by: Jan Buchar --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 484a76fc57..52ef8c0701 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,7 @@ otel = [ "opentelemetry-semantic-conventions>=0.54", "wrapt>=1.17.0", ] -sql_storage_posgres = [ +sql_storage_postgres = [ "sqlalchemy[asyncio]~=2.0.0,<3.0.0", "asyncpg>=0.24.0" ] From 045fe9cda4925ae09c617678ebcbde3701376243 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Mon, 25 Aug 2025 23:19:49 +0000 Subject: [PATCH 28/29] up docs --- .../sql_storage_client_basic_example.py | 1 + ...ql_storage_client_configuration_example.py | 1 + docs/guides/storage_clients.mdx | 181 +++++++++++++----- pyproject.toml | 2 +- uv.lock | 11 +- 5 files changed, 139 insertions(+), 57 deletions(-) diff --git a/docs/guides/code_examples/storage_clients/sql_storage_client_basic_example.py b/docs/guides/code_examples/storage_clients/sql_storage_client_basic_example.py index 12e4ac5aa5..17b23ca96f 100644 --- a/docs/guides/code_examples/storage_clients/sql_storage_client_basic_example.py +++ b/docs/guides/code_examples/storage_clients/sql_storage_client_basic_example.py @@ -4,6 +4,7 @@ async def main() -> None: # Create a new instance of storage client. + # This will also create an SQLite database file crawlee.db. # Use the context manager to ensure that connections are properly cleaned up. async with SQLStorageClient() as storage_client: # And pass it to the crawler. diff --git a/docs/guides/code_examples/storage_clients/sql_storage_client_configuration_example.py b/docs/guides/code_examples/storage_clients/sql_storage_client_configuration_example.py index ca4d0a6e06..6cdc7a3919 100644 --- a/docs/guides/code_examples/storage_clients/sql_storage_client_configuration_example.py +++ b/docs/guides/code_examples/storage_clients/sql_storage_client_configuration_example.py @@ -7,6 +7,7 @@ async def main() -> None: # Create a new instance of storage client. + # On first run, also creates tables in your PostgreSQL database. # Use the context manager to ensure that connections are properly cleaned up. async with SQLStorageClient( # Create an `engine` with the desired configuration diff --git a/docs/guides/storage_clients.mdx b/docs/guides/storage_clients.mdx index fd592b7a1f..403a6022dc 100644 --- a/docs/guides/storage_clients.mdx +++ b/docs/guides/storage_clients.mdx @@ -148,60 +148,139 @@ Use crawlee['sql_storage'] for SQLite and crawlee['sql_storage_posgres'] for Pos Data is organized in relational tables. 
Below are the main tables and columns used for each storage type: -```text -{DATABASE} -├── dataset_metadata -│ ├── id (PK) -│ ├── name -│ ├── accessed_at -│ ├── created_at -│ ├── modified_at -│ └── item_count -├── dataset_item -│ ├── order_id (PK) -│ ├── metadata_id (FK) -│ └── data -├── kvs_metadata -│ ├── id (PK) -│ ├── name -│ ├── accessed_at -│ ├── created_at -│ └── modified_at -├── kvs_record -│ ├── metadata_id (FK, PK) -│ ├── key (PK) -│ ├── value -│ ├── content_type -│ └── size -├── request_queue_metadata -│ ├── id (PK) -│ ├── name -│ ├── accessed_at -│ ├── created_at -│ ├── modified_at -│ ├── had_multiple_clients -│ ├── handled_request_count -│ ├── pending_request_count -│ └── total_request_count -├── request -│ ├── request_id (PK) -│ ├── metadata_id (FK, PK) -│ ├── data -│ ├── sequence_number -│ ├── is_handled -│ └── time_blocked_until -└── request_queue_state - ├── metadata_id (PK, FK) - ├── sequence_counter - └── forefront_sequence_counter +```mermaid +--- +config: + class: + hideEmptyMembersBox: true +--- + +classDiagram + +%% ======================== +%% Storage Clients +%% ======================== + +class SQLDatasetClient { + <> +} + +class SQLKeyValueStoreClient { + <> +} + +%% ======================== +%% Dataset Tables +%% ======================== + +class dataset_metadata { + <> + + id (PK) + + name + + accessed_at + + created_at + + modified_at + + item_count +} + +class dataset_item { + <
> + + order_id (PK) + + metadata_id (FK) + + data +} + +%% ======================== +%% Key-Value Store Tables +%% ======================== + +class kvs_metadata { + <
> + + id (PK) + + name + + accessed_at + + created_at + + modified_at +} + +class kvs_record { + <
> + + metadata_id (FK, PK) + + key (PK) + + value + + content_type + + size +} + +%% ======================== +%% Client to Table arrows +%% ======================== + +SQLDatasetClient --> dataset_metadata +SQLDatasetClient --> dataset_item + +SQLKeyValueStoreClient --> kvs_metadata +SQLKeyValueStoreClient --> kvs_record ``` +```mermaid +--- +config: + class: + hideEmptyMembersBox: true +--- -Where: +classDiagram + +%% ======================== +%% Storage Clients +%% ======================== -- Each *_metadata table stores metadata for the storage instance. -- Data tables (dataset_item, kvs_record, request) store the actual items, records, or requests. -- request_queue_state is technical table for queue. -- Foreign keys (`metadata_id`) link data to the corresponding metadata record. +class SQLRequestQueueClient { + <> +} + +%% ======================== +%% Request Queue Tables +%% ======================== + +class request_queue_metadata { + <
> + + id (PK) + + name + + accessed_at + + created_at + + modified_at + + had_multiple_clients + + handled_request_count + + pending_request_count + + total_request_count +} + +class request { + <
> + + request_id (PK) + + metadata_id (FK, PK) + + data + + sequence_number + + is_handled + + time_blocked_until +} + +class request_queue_state { + <
> + + metadata_id (FK, PK) + + sequence_counter + + forefront_sequence_counter +} + +%% ======================== +%% Client to Table arrows +%% ======================== + +SQLRequestQueueClient --> request_queue_metadata +SQLRequestQueueClient --> request +SQLRequestQueueClient --> request_queue_state +``` Configuration options for the `SQLStorageClient` can be set through environment variables or the `Configuration` class: diff --git a/pyproject.toml b/pyproject.toml index 52ef8c0701..3494edba7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ dependencies = [ ] [project.optional-dependencies] -all = ["crawlee[adaptive-crawler,beautifulsoup,cli,curl-impersonate,httpx,parsel,playwright,otel,sql_storage]"] +all = ["crawlee[adaptive-crawler,beautifulsoup,cli,curl-impersonate,httpx,parsel,playwright,otel,sql_storage,sql_storage_postgres]"] adaptive-crawler = [ "jaro-winkler>=2.0.3", "playwright>=1.27.0", diff --git a/uv.lock b/uv.lock index 6160a3b296..f1228baa24 100644 --- a/uv.lock +++ b/uv.lock @@ -669,6 +669,7 @@ adaptive-crawler = [ all = [ { name = "aiosqlite" }, { name = "apify-fingerprint-datapoints" }, + { name = "asyncpg" }, { name = "beautifulsoup4", extra = ["lxml"] }, { name = "browserforge" }, { name = "cookiecutter" }, @@ -730,7 +731,7 @@ sql-storage = [ { name = "aiosqlite" }, { name = "sqlalchemy", extra = ["asyncio"] }, ] -sql-storage-posgres = [ +sql-storage-postgres = [ { name = "asyncpg" }, { name = "sqlalchemy", extra = ["asyncio"] }, ] @@ -765,7 +766,7 @@ requires-dist = [ { name = "apify-fingerprint-datapoints", marker = "extra == 'adaptive-crawler'", specifier = ">=0.0.2" }, { name = "apify-fingerprint-datapoints", marker = "extra == 'httpx'", specifier = ">=0.0.2" }, { name = "apify-fingerprint-datapoints", marker = "extra == 'playwright'", specifier = ">=0.0.2" }, - { name = "asyncpg", marker = "extra == 'sql-storage-posgres'", specifier = ">=0.24.0" }, + { name = "asyncpg", marker = "extra == 'sql-storage-postgres'", specifier = ">=0.24.0" }, { name = "beautifulsoup4", extras = ["lxml"], marker = "extra == 'beautifulsoup'", specifier = ">=4.12.0" }, { name = "browserforge", marker = "extra == 'adaptive-crawler'", specifier = ">=1.2.3" }, { name = "browserforge", marker = "extra == 'httpx'", specifier = ">=1.2.3" }, @@ -773,7 +774,7 @@ requires-dist = [ { name = "cachetools", specifier = ">=5.5.0" }, { name = "colorama", specifier = ">=0.4.0" }, { name = "cookiecutter", marker = "extra == 'cli'", specifier = ">=2.6.0" }, - { name = "crawlee", extras = ["adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "httpx", "parsel", "playwright", "otel", "sql-storage"], marker = "extra == 'all'" }, + { name = "crawlee", extras = ["adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "httpx", "parsel", "playwright", "otel", "sql-storage", "sql-storage-postgres"], marker = "extra == 'all'" }, { name = "curl-cffi", marker = "extra == 'curl-impersonate'", specifier = ">=0.9.0" }, { name = "html5lib", marker = "extra == 'beautifulsoup'", specifier = ">=1.0" }, { name = "httpx", extras = ["brotli", "http2", "zstd"], marker = "extra == 'httpx'", specifier = ">=0.27.0" }, @@ -798,14 +799,14 @@ requires-dist = [ { name = "rich", marker = "extra == 'cli'", specifier = ">=13.9.0" }, { name = "scikit-learn", marker = "extra == 'adaptive-crawler'", specifier = ">=1.6.0" }, { name = "sqlalchemy", extras = ["asyncio"], marker = "extra == 'sql-storage'", specifier = "~=2.0.0,<3.0.0" }, - { name = "sqlalchemy", extras = ["asyncio"], marker = "extra == 
'sql-storage-posgres'", specifier = "~=2.0.0,<3.0.0" }, + { name = "sqlalchemy", extras = ["asyncio"], marker = "extra == 'sql-storage-postgres'", specifier = "~=2.0.0,<3.0.0" }, { name = "tldextract", specifier = ">=5.1.0" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.0" }, { name = "typing-extensions", specifier = ">=4.1.0" }, { name = "wrapt", marker = "extra == 'otel'", specifier = ">=1.17.0" }, { name = "yarl", specifier = ">=1.18.0" }, ] -provides-extras = ["all", "adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "httpx", "parsel", "playwright", "otel", "sql-storage-posgres", "sql-storage"] +provides-extras = ["all", "adaptive-crawler", "beautifulsoup", "cli", "curl-impersonate", "httpx", "parsel", "playwright", "otel", "sql-storage-postgres", "sql-storage"] [package.metadata.requires-dev] dev = [ From 1a7618eb6ec0c349bad94964b1536400798012c6 Mon Sep 17 00:00:00 2001 From: Max Bohomolov Date: Tue, 26 Aug 2025 01:17:28 +0000 Subject: [PATCH 29/29] up database types --- .../storage_clients/_sql/_dataset_client.py | 24 +++++-------------- .../storage_clients/_sql/_db_models.py | 23 ++++++++++++++---- .../_sql/_request_queue_client.py | 6 ++++- .../_sql/test_sql_dataset_client.py | 3 +-- 4 files changed, 31 insertions(+), 25 deletions(-) diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py index 554f656215..a1a5dcdcb0 100644 --- a/src/crawlee/storage_clients/_sql/_dataset_client.py +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json from logging import getLogger from typing import TYPE_CHECKING, Any, cast @@ -34,8 +33,8 @@ class SQLDatasetClient(DatasetClient, SQLClientMixin): - `dataset_metadata` table: Contains dataset metadata (id, name, timestamps, item_count) - `dataset_item` table: Contains individual items with JSON data and auto-increment ordering - Items are serialized to JSON with `default=str` to handle non-serializable types like datetime - objects. The `order_id` auto-increment primary key ensures insertion order is preserved. + Items are stored as a JSON object in SQLite and as JSONB in PostgreSQL. These objects must be JSON-serializable. + The `order_id` auto-increment primary key ensures insertion order is preserved. All operations are wrapped in database transactions with CASCADE deletion support. 
""" @@ -124,17 +123,7 @@ async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: data = [data] db_items: list[dict[str, Any]] = [] - - for item in data: - # Serialize with default=str to handle non-serializable types like datetime - json_item = json.dumps(item, default=str, ensure_ascii=False) - db_items.append( - { - 'metadata_id': self._id, - 'data': json_item, - } - ) - + db_items = [{'metadata_id': self._id, 'data': item} for item in data] stmt = insert(self._ITEM_TABLE).values(db_items) async with self.get_autocommit_session() as autocommit: @@ -181,7 +170,7 @@ def _prepare_get_stmt( if skip_empty: # Skip items that are empty JSON objects - stmt = stmt.where(self._ITEM_TABLE.data != '"{}"') + stmt = stmt.where(self._ITEM_TABLE.data != {}) # Apply ordering by insertion order (order_id) stmt = ( @@ -230,8 +219,7 @@ async def get_data( if updated: await session.commit() - # Deserialize JSON items - items = [json.loads(db_item.data) for db_item in db_items] + items = [db_item.data for db_item in db_items] metadata = await self.get_metadata() return DatasetItemsListPage( items=items, @@ -273,7 +261,7 @@ async def iterate_items( db_items = await session.stream_scalars(stmt) async for db_item in db_items: - yield json.loads(db_item.data) + yield db_item.data updated = await self._update_metadata(session, update_accessed_at=True) diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py index fface773fc..e681c06744 100644 --- a/src/crawlee/storage_clients/_sql/_db_models.py +++ b/src/crawlee/storage_clients/_sql/_db_models.py @@ -1,15 +1,17 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from sqlalchemy import JSON, BigInteger, Boolean, ForeignKey, Index, Integer, LargeBinary, String +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.types import DateTime, TypeDecorator from typing_extensions import override if TYPE_CHECKING: from sqlalchemy.engine import Dialect + from sqlalchemy.types import TypeEngine # This is necessary because unique constraints don't apply to NULL values in SQL. 
@@ -50,6 +52,19 @@ def process_result_value(self, value: datetime | None, _dialect: Dialect) -> dat return value +class JSONField(TypeDecorator): + """Uses JSONB for PostgreSQL and JSON for other databases.""" + + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[JSON | JSONB]: + """Load the appropriate dialect implementation for the JSON type.""" + if dialect.name == 'postgresql': + return dialect.type_descriptor(JSONB()) + return dialect.type_descriptor(JSON()) + + class Base(DeclarativeBase): """Base class for all database models for correct type annotations.""" @@ -166,8 +181,8 @@ class DatasetItemDB(Base): ) """Foreign key to metadata dataset record.""" - data: Mapped[str] = mapped_column(JSON, nullable=False) - """JSON-serialized item data.""" + data: Mapped[list[dict[str, Any]] | dict[str, Any]] = mapped_column(JSONField, nullable=False) + """JSON serializable item data.""" # Relationship back to parent dataset dataset: Mapped[DatasetMetadataDB] = relationship(back_populates='items') @@ -189,7 +204,7 @@ class RequestDB(Base): ) """Foreign key to metadata request queue record.""" - data: Mapped[str] = mapped_column(JSON, nullable=False) + data: Mapped[str] = mapped_column(String, nullable=False) """JSON-serialized Request object.""" sequence_number: Mapped[int] = mapped_column(Integer, nullable=False) diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index cce4bc9e10..08fe1676e5 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -433,11 +433,15 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | if not (request_id := self._REQUEST_ID_BY_KEY.get(request.unique_key)): request_id = self._get_int_id_from_unique_key(request.unique_key) + # Update the request's handled_at timestamp. + if request.handled_at is None: + request.handled_at = datetime.now(timezone.utc) + # Update request in DB stmt = ( update(self._ITEM_TABLE) .where(self._ITEM_TABLE.metadata_id == self._id, self._ITEM_TABLE.request_id == request_id) - .values(is_handled=True, time_blocked_until=None) + .values(is_handled=True, time_blocked_until=None, data=request.model_dump_json()) ) async with self.get_session() as session: result = await session.execute(stmt) diff --git a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py index 98c0927e4e..79c302060a 100644 --- a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import json from datetime import timedelta from typing import TYPE_CHECKING @@ -126,7 +125,7 @@ async def test_record_and_content_verification(dataset_client: SQLDatasetClient) result = await session.execute(stmt) records = result.scalars().all() assert len(records) == 1 - saved_item = json.loads(records[0].data) + saved_item = records[0].data assert saved_item == item # Test pushing multiple items and verify total count
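
The final patch's `JSONField` type decorator picks the concrete column type per dialect (JSONB on PostgreSQL, plain JSON elsewhere). A self-contained sketch reusing the same pattern — the table and column names below are illustrative — that compiles the DDL against both dialects to show the difference:

```python
from sqlalchemy import JSON, Column, Integer, MetaData, Table
from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.schema import CreateTable
from sqlalchemy.types import TypeDecorator


class JSONField(TypeDecorator):
    """Use JSONB on PostgreSQL and plain JSON on other databases (same pattern as in the patch)."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        if dialect.name == 'postgresql':
            return dialect.type_descriptor(JSONB())
        return dialect.type_descriptor(JSON())


# Compile the same table definition against both dialects to see the column type differ.
metadata = MetaData()
demo = Table(
    'demo_item',  # illustrative table name, not part of the patch
    metadata,
    Column('id', Integer, primary_key=True),
    Column('data', JSONField(), nullable=False),
)

print(CreateTable(demo).compile(dialect=postgresql.dialect()))  # renders the column as JSONB
print(CreateTable(demo).compile(dialect=sqlite.dialect()))      # falls back to plain JSON
```
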