first
This commit is contained in:
0
backend/app/__init__.py
Normal file
0
backend/app/__init__.py
Normal file
13
backend/app/config.py
Normal file
13
backend/app/config.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application configuration loaded from environment variables via pydantic-settings."""

    # Base URL of the Home Assistant instance.
    ha_base_url: str = "http://10.0.0.2:8123"
    # Long-lived access token for the HA REST/WebSocket APIs.
    ha_token: str = ""
    # SQLAlchemy database URL; defaults to a local SQLite file.
    database_url: str = "sqlite:///./data/ha_explorer.db"
    # Origins allowed by the CORS middleware (frontend dev server by default).
    cors_origins: list[str] = ["http://localhost:5173"]

    # Environment variables map directly onto field names (no prefix).
    model_config = {"env_prefix": ""}


# Module-level singleton read throughout the app.
settings = Settings()
|
||||
35
backend/app/database.py
Normal file
35
backend/app/database.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from pathlib import Path
|
||||
from collections.abc import Generator
|
||||
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from app.config import settings
|
||||
|
||||
# Strip the SQLAlchemy scheme to get the on-disk SQLite path, then make sure
# its parent directory exists before the engine first connects.
_db_path = settings.database_url.replace("sqlite:///", "")
Path(_db_path).parent.mkdir(parents=True, exist_ok=True)

_default_engine = create_engine(
    settings.database_url,
    # SQLite connections are shared across FastAPI worker threads.
    connect_args={"check_same_thread": False},
    echo=False,
)

# Mutable engine holder so tests can swap in their own engine (see set_engine).
_engine_holder: dict = {"engine": _default_engine}
|
||||
|
||||
|
||||
def get_engine():
    """Return the currently active engine (swappable for tests via set_engine)."""
    return _engine_holder["engine"]
|
||||
|
||||
|
||||
def set_engine(engine):
    """Replace the active engine (intended for tests)."""
    _engine_holder["engine"] = engine
|
||||
|
||||
|
||||
def create_db_and_tables():
    """Create all SQLModel tables on the active engine (no-op if they exist)."""
    SQLModel.metadata.create_all(get_engine())
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
    """FastAPI dependency yielding a session that closes after the request."""
    with Session(get_engine()) as session:
        yield session
|
||||
135
backend/app/ha_client.py
Normal file
135
backend/app/ha_client.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import json
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class HAClient:
    """Async client for the Home Assistant REST and WebSocket APIs.

    Reads base URL and token from settings at construction time. A fresh
    aiohttp session (and, for WS commands, a fresh WebSocket connection)
    is opened per call.
    """

    def __init__(self):
        # Normalize the base URL so path joins below never produce "//".
        self.base_url = settings.ha_base_url.rstrip("/")
        self.token = settings.ha_token
        self._headers = {
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json",
        }
        # Counter used to give each WebSocket command a distinct id.
        self._ws_id_counter = 0

    def _next_ws_id(self) -> int:
        """Return the next WebSocket command id (monotonically increasing)."""
        self._ws_id_counter += 1
        return self._ws_id_counter

    async def check_connection(self) -> tuple[bool, str]:
        """Probe GET /api/ and return (connected, human-readable French message)."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.base_url}/api/",
                    headers=self._headers,
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as resp:
                    if resp.status == 200:
                        return True, "Connecté"
                    elif resp.status == 401:
                        return False, "Token invalide (401)"
                    else:
                        return False, f"Erreur HTTP {resp.status}"
        # NOTE(review): some aiohttp timeout errors subclass ClientError, so a
        # timeout may be reported by this branch instead of the one below — confirm.
        except aiohttp.ClientError as e:
            return False, f"Connexion impossible : {e}"
        except asyncio.TimeoutError:
            return False, "Timeout de connexion"

    async def fetch_all_states(self) -> list[dict[str, Any]]:
        """Fetch every entity state via GET /api/states; raises on HTTP errors."""
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{self.base_url}/api/states",
                headers=self._headers,
                timeout=aiohttp.ClientTimeout(total=30),
            ) as resp:
                resp.raise_for_status()
                return await resp.json()

    async def _ws_command(self, command: dict[str, Any]) -> dict[str, Any]:
        """Open a WebSocket, authenticate, send one command and return its result.

        Raises ConnectionError on auth failure and RuntimeError when HA reports
        the command as unsuccessful.
        """
        async with aiohttp.ClientSession() as session:
            async with session.ws_connect(
                f"{self.base_url}/api/websocket",
                # NOTE(review): newer aiohttp expects ClientWSTimeout for
                # ws_connect timeouts — confirm against the pinned version.
                timeout=aiohttp.ClientTimeout(total=30),
            ) as ws:
                # First frame is the server's auth_required greeting (content unchecked).
                msg = await ws.receive_json()

                # Authenticate with the long-lived token.
                await ws.send_json({"type": "auth", "access_token": self.token})
                msg = await ws.receive_json()
                if msg.get("type") != "auth_ok":
                    raise ConnectionError(f"Authentification WS échouée : {msg}")

                # Send the command with a unique id.
                cmd_id = self._next_ws_id()
                command["id"] = cmd_id
                await ws.send_json(command)

                # Wait for the reply (assumed to be the next frame).
                msg = await ws.receive_json()
                if not msg.get("success"):
                    raise RuntimeError(
                        f"Commande WS échouée : {msg.get('error', {}).get('message', 'Erreur inconnue')}"
                    )
                return msg.get("result", {})

    async def fetch_entity_registry(self) -> list[dict[str, Any]]:
        """Return the full entity registry via the WS API."""
        return await self._ws_command({"type": "config/entity_registry/list"})

    async def update_entity_registry(
        self, entity_id: str, **updates: Any
    ) -> dict[str, Any]:
        """Update registry fields (e.g. disabled_by) for one entity via the WS API."""
        return await self._ws_command(
            {
                "type": "config/entity_registry/update",
                "entity_id": entity_id,
                **updates,
            }
        )
|
||||
|
||||
|
||||
def _parse_dt(value: str | None) -> datetime | None:
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def normalize_entity(
    state: dict[str, Any],
    registry_entry: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Flatten an HA state dict (plus optional registry entry) into a cache row dict.

    Keys mirror the EntityCache columns. Registry-derived fields fall back to
    None/False when no registry entry is available.
    """
    entity_id = state.get("entity_id", "")
    attrs = state.get("attributes", {})
    reg = registry_entry or {}

    # Domain is the part of the entity_id before the first dot ("light.kitchen" -> "light").
    domain = entity_id.partition(".")[0] if "." in entity_id else ""

    # A missing state and an empty state are both treated as available here.
    current_state = state.get("state", "")

    row: dict[str, Any] = {
        "entity_id": entity_id,
        "domain": domain,
        "state": current_state,
        "friendly_name": attrs.get("friendly_name", ""),
        "attrs_json": json.dumps(attrs, ensure_ascii=False),
        "device_class": attrs.get("device_class"),
        "unit_of_measurement": attrs.get("unit_of_measurement"),
    }
    # Registry-derived fields.
    row.update(
        area_id=reg.get("area_id"),
        device_id=reg.get("device_id"),
        integration=reg.get("platform"),
        is_disabled=reg.get("disabled_by") is not None,
        is_hidden=reg.get("hidden_by") is not None,
    )
    # Availability and timestamps.
    row.update(
        is_available=current_state not in ("unavailable", "unknown"),
        last_changed=_parse_dt(state.get("last_changed")),
        last_updated=_parse_dt(state.get("last_updated")),
    )
    return row
|
||||
|
||||
|
||||
# Module-level singleton shared by routers and services.
ha_client = HAClient()
|
||||
31
backend/app/main.py
Normal file
31
backend/app/main.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.config import settings
|
||||
from app.database import create_db_and_tables
|
||||
from app.routers import health, scan, entities, actions, audit
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: create DB tables at startup; no shutdown work."""
    create_db_and_tables()
    yield
|
||||
|
||||
|
||||
app = FastAPI(title="HA Entity Scanner", lifespan=lifespan)

# Allow the configured frontend origins to call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# All routes are served under the /api prefix.
app.include_router(health.router, prefix="/api")
app.include_router(scan.router, prefix="/api")
app.include_router(entities.router, prefix="/api")
app.include_router(actions.router, prefix="/api")
app.include_router(audit.router, prefix="/api")
|
||||
47
backend/app/models.py
Normal file
47
backend/app/models.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class EntityCache(SQLModel, table=True):
    """Local snapshot of a Home Assistant entity, refreshed by the scanner."""

    __tablename__ = "entities_cache"

    entity_id: str = Field(primary_key=True)
    domain: str = ""  # part of entity_id before the dot, e.g. "light"
    friendly_name: str = ""
    state: str = ""
    attrs_json: str = "{}"  # full HA attributes dict, JSON-serialized
    device_class: Optional[str] = None
    unit_of_measurement: Optional[str] = None
    # Registry-derived fields; None when the entity has no registry entry.
    area_id: Optional[str] = None
    device_id: Optional[str] = None
    integration: Optional[str] = None  # registry "platform" field
    is_disabled: bool = False  # registry disabled_by is set
    is_hidden: bool = False  # registry hidden_by is set
    is_available: bool = True  # state is neither "unavailable" nor "unknown"
    last_changed: Optional[datetime] = None
    last_updated: Optional[datetime] = None
    # NOTE(review): naive UTC via the deprecated datetime.utcnow — kept as-is
    # because stored rows and API consumers see naive timestamps.
    fetched_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class EntityFlag(SQLModel, table=True):
    """Local-only per-entity flags; never pushed back to Home Assistant."""

    __tablename__ = "entity_flags"

    entity_id: str = Field(primary_key=True)
    ignored_local: bool = False  # locally ignored (also the fallback when HA disable fails)
    favorite: bool = False
    notes: str = ""
    original_state: Optional[str] = None  # state captured before disable/ignore
    disabled_at: Optional[datetime] = None  # naive UTC time of disable/ignore
|
||||
|
||||
|
||||
class AuditLog(SQLModel, table=True):
    """Append-only record of entity actions (disable/enable/flag changes)."""

    __tablename__ = "audit_log"

    id: Optional[int] = Field(default=None, primary_key=True)
    ts: datetime = Field(default_factory=datetime.utcnow)  # naive UTC
    action: str = ""
    entity_ids_json: str = "[]"  # affected entity_ids, JSON-serialized list
    result: str = ""  # e.g. "ha_registry", "local_flag", "ok"
    error: str = ""  # empty string when the action succeeded
|
||||
0
backend/app/routers/__init__.py
Normal file
0
backend/app/routers/__init__.py
Normal file
33
backend/app/routers/actions.py
Normal file
33
backend/app/routers/actions.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.services.entity_actions import disable_entity, enable_entity, set_flag
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class BulkActionRequest(BaseModel):
    """Request body for POST /entities/actions."""

    action: str  # disable, enable, favorite, unfavorite, ignore, unignore
    entity_ids: list[str]
|
||||
|
||||
|
||||
@router.post("/entities/actions")
async def bulk_action(req: BulkActionRequest):
    """Apply one action to a list of entities and return per-entity results.

    Flag actions are handled in a single DB pass; disable/enable go through
    the HA registry one entity at a time. Unknown actions return an error
    payload (still HTTP 200).
    """
    action = req.action

    if action in ("favorite", "unfavorite", "ignore", "unignore"):
        results = set_flag(req.entity_ids, action)
    elif action in ("disable", "enable"):
        toggle = disable_entity if action == "disable" else enable_entity
        results = [await toggle(eid) for eid in req.entity_ids]
    else:
        return {"error": f"Action inconnue : {action}"}

    return {"action": action, "results": results}
|
||||
38
backend/app/routers/audit.py
Normal file
38
backend/app/routers/audit.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlmodel import Session, col, func, select
|
||||
|
||||
from app.database import get_session
|
||||
from app.models import AuditLog
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/audit")
def list_audit(
    page: int = Query(1, ge=1),
    per_page: int = Query(50, ge=1, le=200),
    action: Optional[str] = None,
    session: Session = Depends(get_session),
):
    """Paginated audit-log listing, newest first, optionally filtered by action."""
    query = select(AuditLog)

    if action:
        query = query.where(AuditLog.action == action)

    # Count before pagination so `total` reflects the whole filtered set.
    count_query = select(func.count()).select_from(query.subquery())
    total = session.exec(count_query).one()

    query = query.order_by(col(AuditLog.ts).desc())
    offset = (page - 1) * per_page
    query = query.offset(offset).limit(per_page)

    logs = session.exec(query).all()

    return {
        "items": [log.model_dump() for log in logs],
        "total": total,
        "page": page,
        "per_page": per_page,
    }
|
||||
150
backend/app/routers/entities.py
Normal file
150
backend/app/routers/entities.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlmodel import Session, col, func, or_, select
|
||||
|
||||
from app.database import get_session
|
||||
from app.models import EntityCache, EntityFlag
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/entities")
def list_entities(
    page: int = Query(1, ge=1),
    per_page: int = Query(50, ge=1, le=500),
    domain: Optional[str] = None,
    state: Optional[str] = None,
    search: Optional[str] = None,
    available: Optional[bool] = None,
    sort_by: str = Query("entity_id"),
    sort_dir: str = Query("asc", pattern="^(asc|desc)$"),
    favorite: Optional[bool] = None,
    ignored: Optional[bool] = None,
    device_class: Optional[str] = None,
    integration: Optional[str] = None,
    area_id: Optional[str] = None,
    session: Session = Depends(get_session),
):
    """List cached entities with filtering, sorting and pagination.

    `domain` and `state` accept comma-separated lists; `search` is a
    case-insensitive substring match on entity_id or friendly_name.
    `favorite`/`ignored` join against EntityFlag. Each returned item is the
    EntityCache row merged with its local flags (defaults when absent).
    """
    query = select(EntityCache)

    # Filters
    if domain:
        domains = [d.strip() for d in domain.split(",")]
        query = query.where(col(EntityCache.domain).in_(domains))

    if state:
        states = [s.strip() for s in state.split(",")]
        query = query.where(col(EntityCache.state).in_(states))

    if search:
        pattern = f"%{search}%"
        query = query.where(
            or_(
                col(EntityCache.entity_id).ilike(pattern),
                col(EntityCache.friendly_name).ilike(pattern),
            )
        )

    if available is not None:
        query = query.where(EntityCache.is_available == available)

    if device_class:
        query = query.where(EntityCache.device_class == device_class)

    if integration:
        query = query.where(EntityCache.integration == integration)

    if area_id:
        query = query.where(EntityCache.area_id == area_id)

    # Flag filters require an outer join on EntityFlag
    if favorite is not None or ignored is not None:
        query = query.outerjoin(
            EntityFlag, EntityCache.entity_id == EntityFlag.entity_id
        )
        if favorite is not None:
            query = query.where(EntityFlag.favorite == favorite)
        if ignored is not None:
            query = query.where(EntityFlag.ignored_local == ignored)

    # Total count before pagination
    count_query = select(func.count()).select_from(query.subquery())
    total = session.exec(count_query).one()

    # Sorting — only accept real model fields: a bare getattr on an arbitrary
    # client string could return a method and break the query (HTTP 500).
    if sort_by in EntityCache.model_fields:
        sort_column = getattr(EntityCache, sort_by)
    else:
        sort_column = EntityCache.entity_id
    if sort_dir == "desc":
        query = query.order_by(col(sort_column).desc())
    else:
        query = query.order_by(col(sort_column).asc())

    # Pagination
    offset = (page - 1) * per_page
    query = query.offset(offset).limit(per_page)

    entities = session.exec(query).all()

    # Fetch the flags for the returned page; skip the query entirely for an
    # empty page (avoids an `IN ()` clause).
    entity_ids = [e.entity_id for e in entities]
    flags: dict[str, EntityFlag] = {}
    if entity_ids:
        flags_query = select(EntityFlag).where(
            col(EntityFlag.entity_id).in_(entity_ids)
        )
        flags = {f.entity_id: f for f in session.exec(flags_query).all()}

    results = []
    for e in entities:
        d = e.model_dump()
        flag = flags.get(e.entity_id)
        d["favorite"] = flag.favorite if flag else False
        d["ignored_local"] = flag.ignored_local if flag else False
        d["notes"] = flag.notes if flag else ""
        d["original_state"] = flag.original_state if flag else None
        d["disabled_at"] = flag.disabled_at.isoformat() if flag and flag.disabled_at else None
        results.append(d)

    return {
        "items": results,
        "total": total,
        "page": page,
        "per_page": per_page,
        "pages": (total + per_page - 1) // per_page if per_page > 0 else 0,
    }
|
||||
|
||||
|
||||
@router.get("/entities/filters")
def get_filter_values(session: Session = Depends(get_session)):
    """Return the distinct values available for each entity filter."""
    domains = session.exec(
        select(EntityCache.domain).distinct().order_by(EntityCache.domain)
    ).all()
    # NULLs are excluded from the optional columns below.
    areas = session.exec(
        select(EntityCache.area_id).where(EntityCache.area_id.is_not(None)).distinct().order_by(EntityCache.area_id)  # type: ignore
    ).all()
    integrations = session.exec(
        select(EntityCache.integration).where(EntityCache.integration.is_not(None)).distinct().order_by(EntityCache.integration)  # type: ignore
    ).all()
    device_classes = session.exec(
        select(EntityCache.device_class).where(EntityCache.device_class.is_not(None)).distinct().order_by(EntityCache.device_class)  # type: ignore
    ).all()
    return {
        "domains": domains,
        "areas": areas,
        "integrations": integrations,
        "device_classes": device_classes,
    }
|
||||
|
||||
|
||||
@router.get("/entities/{entity_id}")
def get_entity(entity_id: str, session: Session = Depends(get_session)):
    """Return one cached entity merged with its local flags; 404 if unknown."""
    entity = session.get(EntityCache, entity_id)
    if not entity:
        raise HTTPException(status_code=404, detail="Entité non trouvée")

    d = entity.model_dump()
    # Merge local flags (defaults when no flag row exists).
    flag = session.get(EntityFlag, entity_id)
    d["favorite"] = flag.favorite if flag else False
    d["ignored_local"] = flag.ignored_local if flag else False
    d["notes"] = flag.notes if flag else ""
    d["original_state"] = flag.original_state if flag else None
    d["disabled_at"] = flag.disabled_at.isoformat() if flag and flag.disabled_at else None

    return d
|
||||
23
backend/app/routers/health.py
Normal file
23
backend/app/routers/health.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlmodel import Session, func, select
|
||||
|
||||
from app.database import get_session
|
||||
from app.ha_client import ha_client
|
||||
from app.models import EntityCache
|
||||
from app.scan_state import scan_state
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/health")
async def health(session: Session = Depends(get_session)):
    """Report HA connectivity, cached entity count and current scan state."""
    connected, message = await ha_client.check_connection()
    count = session.exec(select(func.count()).select_from(EntityCache)).one()

    return {
        "status": "ok",
        "ha_connected": connected,
        "ha_message": message,  # human-readable, French
        "entity_count": count,
        **scan_state.to_dict(),
    }
|
||||
20
backend/app/routers/scan.py
Normal file
20
backend/app/routers/scan.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from fastapi import APIRouter, BackgroundTasks
|
||||
|
||||
from app.scan_state import scan_state
|
||||
from app.services.scanner import run_scan
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _run_scan_sync():
    """Synchronous wrapper around run_scan.

    NOTE(review): not referenced by trigger_scan below, which schedules the
    async run_scan directly — candidate for removal if truly unused.
    """
    import asyncio
    asyncio.run(run_scan())
|
||||
|
||||
|
||||
@router.post("/scan", status_code=202)
async def trigger_scan(background_tasks: BackgroundTasks):
    """Start a background scan; always returns 202 with the current scan state.

    NOTE(review): the status check is best-effort — a concurrent request
    between the check and task start could schedule two scans.
    """
    if scan_state.status == "scanning":
        return {"message": "Scan déjà en cours", **scan_state.to_dict()}

    background_tasks.add_task(run_scan)
    return {"message": "Scan lancé", **scan_state.to_dict()}
|
||||
39
backend/app/scan_state.py
Normal file
39
backend/app/scan_state.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ScanState:
    """In-memory progress tracker for the entity scan (process-wide, not persisted).

    Lifecycle of `status`: idle -> scanning -> done | error. `progress` and
    `total` are also written directly by the scanner while it runs.
    """

    def __init__(self):
        self.status: str = "idle"
        self.last_scan: Optional[datetime] = None
        self.progress: int = 0
        self.total: int = 0
        self.error: str = ""

    def start(self):
        """Mark a scan as started, clearing previous progress and error."""
        self.progress = self.total = 0
        self.error = ""
        self.status = "scanning"

    def finish(self, count: int):
        """Record a successful scan of `count` entities and stamp the time."""
        self.progress = self.total = count
        self.last_scan = datetime.utcnow()
        self.status = "done"

    def fail(self, error: str):
        """Record a failed scan with its error message."""
        self.error = error
        self.status = "error"

    def to_dict(self) -> dict:
        """Serialize for API responses (merged into /health and /scan payloads)."""
        stamp = self.last_scan
        return {
            "scan_status": self.status,
            "last_scan": stamp.isoformat() if stamp else None,
            "progress": self.progress,
            "total": self.total,
            "error": self.error,
        }
|
||||
|
||||
|
||||
# Module-level singleton shared by the scanner and the routers.
scan_state = ScanState()
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
134
backend/app/services/entity_actions.py
Normal file
134
backend/app/services/entity_actions.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.database import get_engine
|
||||
from app.ha_client import ha_client
|
||||
from app.models import EntityCache, EntityFlag, AuditLog
|
||||
|
||||
|
||||
def _get_current_state(session: Session, entity_id: str) -> str | None:
    """Return the entity's current state from the local cache, or None if uncached."""
    entity = session.get(EntityCache, entity_id)
    return entity.state if entity else None
|
||||
|
||||
|
||||
def _save_original_state(session: Session, entity_id: str):
    """Record the entity's current cached state before it is disabled.

    Only the first disable records the state, so repeated disables keep the
    true original. Adds the flag to the session but does not commit; the
    caller owns the transaction.
    """
    flag = session.get(EntityFlag, entity_id)
    if not flag:
        flag = EntityFlag(entity_id=entity_id)
    # Only save when not already disabled (keep the real original state).
    if not flag.original_state:
        flag.original_state = _get_current_state(session, entity_id)
        flag.disabled_at = datetime.utcnow()
    session.add(flag)
    return flag
|
||||
|
||||
|
||||
def _clear_original_state(session: Session, entity_id: str):
    """Drop the saved original state when the entity is re-enabled.

    No-op when the entity has no flag row; does not commit.
    """
    flag = session.get(EntityFlag, entity_id)
    if flag:
        flag.original_state = None
        flag.disabled_at = None
        session.add(flag)
|
||||
|
||||
|
||||
async def disable_entity(entity_id: str) -> dict:
    """Disable an entity, preferring the HA registry with a local-flag fallback.

    Returns {"entity_id", "mode", "error"}: mode is "ha_registry" when HA
    accepted the update, otherwise "local_flag" with the error message.
    """
    mode = "local_flag"
    error = ""

    # Capture the current state first so it can be shown/restored later.
    with Session(get_engine()) as session:
        _save_original_state(session, entity_id)
        session.commit()

    # Try the real disable through the HA entity registry.
    try:
        await ha_client.update_entity_registry(entity_id, disabled_by="user")
        mode = "ha_registry"
    except Exception as e:
        error = str(e)
        # Fallback: mark the entity as locally ignored.
        with Session(get_engine()) as session:
            flag = session.get(EntityFlag, entity_id)
            if not flag:
                flag = EntityFlag(entity_id=entity_id)
            flag.ignored_local = True
            session.add(flag)
            session.commit()

    _log_action("disable", [entity_id], mode, error)
    return {"entity_id": entity_id, "mode": mode, "error": error}
|
||||
|
||||
|
||||
async def enable_entity(entity_id: str) -> dict:
    """Re-enable an entity via the HA registry, with a local-flag fallback.

    The saved original state is cleared afterwards regardless of which path
    succeeded. Returns {"entity_id", "mode", "error"}.
    """
    mode = "local_flag"
    error = ""

    try:
        await ha_client.update_entity_registry(entity_id, disabled_by=None)
        mode = "ha_registry"
    except Exception as e:
        error = str(e)
        # Fallback: just drop the local ignore flag if one exists.
        with Session(get_engine()) as session:
            flag = session.get(EntityFlag, entity_id)
            if flag:
                flag.ignored_local = False
                session.add(flag)
                session.commit()

    # Clear the saved original state (entity is considered active again).
    with Session(get_engine()) as session:
        _clear_original_state(session, entity_id)
        session.commit()

    _log_action("enable", [entity_id], mode, error)
    return {"entity_id": entity_id, "mode": mode, "error": error}
|
||||
|
||||
|
||||
def set_flag(entity_ids: list[str], action: str) -> list[dict]:
    """Apply a local flag action to each entity in a single transaction.

    Supported actions: "favorite", "unfavorite", "ignore", "unignore";
    any other value leaves the flag row unchanged (still reported ok).
    Returns one {"entity_id", "action", "ok"} dict per entity.
    """
    results = []
    with Session(get_engine()) as session:
        for eid in entity_ids:
            flag = session.get(EntityFlag, eid)
            if not flag:
                flag = EntityFlag(entity_id=eid)

            if action == "favorite":
                flag.favorite = True
            elif action == "unfavorite":
                flag.favorite = False
            elif action == "ignore":
                # Capture the current state once, before the first ignore.
                if not flag.original_state:
                    flag.original_state = _get_current_state(session, eid)
                    flag.disabled_at = datetime.utcnow()
                flag.ignored_local = True
            elif action == "unignore":
                flag.ignored_local = False
                flag.original_state = None
                flag.disabled_at = None

            session.add(flag)
            results.append({"entity_id": eid, "action": action, "ok": True})

        session.commit()

    _log_action(action, entity_ids, "ok", "")
    return results
|
||||
|
||||
|
||||
def _log_action(action: str, entity_ids: list[str], result: str, error: str):
    """Append one row to the audit log in its own session/commit."""
    with Session(get_engine()) as session:
        log = AuditLog(
            ts=datetime.utcnow(),
            action=action,
            entity_ids_json=json.dumps(entity_ids),
            result=result,
            error=error,
        )
        session.add(log)
        session.commit()
|
||||
53
backend/app/services/scanner.py
Normal file
53
backend/app/services/scanner.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.database import get_engine
|
||||
from app.ha_client import ha_client, normalize_entity
|
||||
from app.models import EntityCache
|
||||
from app.scan_state import scan_state
|
||||
|
||||
|
||||
async def run_scan():
    """Fetch all HA states (plus registry data when available) and upsert the cache.

    Progress is reported through the module-level scan_state; any failure is
    recorded there and re-raised.
    """
    scan_state.start()
    try:
        states = await ha_client.fetch_all_states()
        scan_state.total = len(states)

        # Registry enrichment is best-effort (the WebSocket API may be unavailable).
        registry_map: dict[str, dict] = {}
        try:
            registry = await ha_client.fetch_entity_registry()
            registry_map = {e["entity_id"]: e for e in registry}
        except Exception:
            pass  # continue without registry data

        with Session(get_engine()) as session:
            for i, state in enumerate(states):
                entity_id = state.get("entity_id", "")
                reg_entry = registry_map.get(entity_id)
                normalized = normalize_entity(state, reg_entry)

                # Upsert: update the cached row in place, or insert a new one.
                existing = session.get(EntityCache, entity_id)
                if existing:
                    for key, value in normalized.items():
                        if key != "entity_id":
                            setattr(existing, key, value)
                    existing.fetched_at = datetime.utcnow()
                else:
                    entity = EntityCache(
                        **normalized,
                        fetched_at=datetime.utcnow(),
                    )
                    session.add(entity)

                scan_state.progress = i + 1

            session.commit()

        scan_state.finish(len(states))

    except Exception as e:
        scan_state.fail(str(e))
        raise
|
||||
Reference in New Issue
Block a user