Files
home_stock/backend/app/repositories/base.py
2026-01-28 19:22:30 +01:00

155 lines
4.3 KiB
Python

"""Repository de base générique.
Fournit les opérations CRUD de base pour tous les modèles.
"""
from typing import Any, Generic, TypeVar
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import Base
ModelType = TypeVar("ModelType", bound=Base)
class BaseRepository(Generic[ModelType]):
"""Repository générique avec opérations CRUD de base.
Attributes:
model: Classe du modèle SQLAlchemy
db: Session de base de données
"""
def __init__(self, model: type[ModelType], db: AsyncSession) -> None:
"""Initialise le repository.
Args:
model: Classe du modèle SQLAlchemy
db: Session de base de données async
"""
self.model = model
self.db = db
async def get(self, id: int) -> ModelType | None:
"""Récupère un élément par son ID.
Args:
id: Identifiant de l'élément
Returns:
L'élément trouvé ou None
"""
result = await self.db.execute(select(self.model).where(self.model.id == id))
return result.scalar_one_or_none()
async def get_all(
self, skip: int = 0, limit: int = 100, **filters: Any
) -> list[ModelType]:
"""Récupère tous les éléments avec pagination et filtres optionnels.
Args:
skip: Nombre d'éléments à sauter
limit: Nombre max d'éléments à retourner
**filters: Filtres additionnels (ex: status="active")
Returns:
Liste des éléments
"""
query = select(self.model)
# Appliquer les filtres
for field, value in filters.items():
if value is not None and hasattr(self.model, field):
query = query.where(getattr(self.model, field) == value)
query = query.offset(skip).limit(limit)
result = await self.db.execute(query)
return list(result.scalars().all())
async def count(self, **filters: Any) -> int:
"""Compte le nombre d'éléments avec filtres optionnels.
Args:
**filters: Filtres additionnels
Returns:
Nombre total d'éléments
"""
query = select(func.count(self.model.id))
for field, value in filters.items():
if value is not None and hasattr(self.model, field):
query = query.where(getattr(self.model, field) == value)
result = await self.db.execute(query)
return result.scalar_one()
async def create(self, **data: Any) -> ModelType:
"""Crée un nouvel élément.
Args:
**data: Données de l'élément
Returns:
L'élément créé
"""
instance = self.model(**data)
self.db.add(instance)
await self.db.flush()
await self.db.refresh(instance)
return instance
async def update(self, id: int, **data: Any) -> ModelType | None:
"""Met à jour un élément existant.
Args:
id: Identifiant de l'élément
**data: Données à mettre à jour (seules les valeurs non-None)
Returns:
L'élément mis à jour ou None si non trouvé
"""
instance = await self.get(id)
if instance is None:
return None
for field, value in data.items():
if value is not None and hasattr(instance, field):
setattr(instance, field, value)
await self.db.flush()
await self.db.refresh(instance)
return instance
async def delete(self, id: int) -> bool:
"""Supprime un élément.
Args:
id: Identifiant de l'élément
Returns:
True si supprimé, False si non trouvé
"""
instance = await self.get(id)
if instance is None:
return False
await self.db.delete(instance)
await self.db.flush()
return True
async def exists(self, id: int) -> bool:
"""Vérifie si un élément existe.
Args:
id: Identifiant de l'élément
Returns:
True si existe, False sinon
"""
result = await self.db.execute(
select(func.count(self.model.id)).where(self.model.id == id)
)
return result.scalar_one() > 0