439 lines
17 KiB
Python
439 lines
17 KiB
Python
"""Handles state variable access and change notification."""
|
|
|
|
import asyncio
|
|
import logging
|
|
|
|
from homeassistant.core import Context
|
|
from homeassistant.helpers.restore_state import DATA_RESTORE_STATE
|
|
from homeassistant.helpers.service import async_get_all_descriptions
|
|
|
|
from .const import LOGGER_PATH
|
|
from .entity import PyscriptEntity
|
|
from .function import Function
|
|
|
|
# Module-level logger, named LOGGER_PATH + ".state".
_LOGGER = logging.getLogger(LOGGER_PATH + ".state")
|
|
|
|
# Attribute names that StateVal copies from the state object itself rather than
# from its attributes dict.  They are stripped whenever real attributes are
# extracted from a StateVal, and are accepted as valid attribute names by
# State.exist() and State.completions().
STATE_VIRTUAL_ATTRS = {"entity_id", "last_changed", "last_updated"}
|
|
|
|
|
|
class StateVal(str):
    """String holding a state's value, with the state's attributes exposed as instance attributes."""

    def __new__(cls, state):
        """Build a StateVal from a state object.

        The string value is ``state.state``.  The instance ``__dict__`` is a
        copy of ``state.attributes``, after which ``entity_id``,
        ``last_updated`` and ``last_changed`` are copied from ``state`` (so
        they override same-named entries in the attributes).
        """
        instance = super().__new__(cls, state.state)
        # Replace the instance dict wholesale so each attribute becomes readable as instance.<name>.
        instance.__dict__ = state.attributes.copy()
        for field in ("entity_id", "last_updated", "last_changed"):
            setattr(instance, field, getattr(state, field))
        return instance
|
|
|
|
|
|
class State:
    """Class for state functions.

    All data lives in class attributes and every operation is a classmethod;
    the class is used as a namespace and is not meant to be instantiated.
    """

    #
    # Global hass instance
    #
    hass = None

    #
    # notify message queues by variable: maps "domain.entity" to a dict
    # of {queue: var_names} as registered by notify_add()
    #
    notify = {}

    #
    # Last value of state variable notifications.  We maintain this
    # so that trigger evaluation can use the last notified value,
    # rather than fetching the current value, which is subject to
    # race conditions when multiple state variables are set quickly.
    #
    notify_var_last = {}

    #
    # pyscript yaml configuration
    #
    pyscript_config = {}

    #
    # pyscript vars which have already been registered as persisted
    # (maps entity id to its PyscriptEntity)
    #
    persisted_vars = {}

    #
    # other parameters of all services that have "entity_id" as a parameter:
    # maps domain -> service -> set of field names (excluding "entity_id")
    #
    service2args = {}

    def __init__(self):
        """Warn on State instantiation."""
        _LOGGER.error("State class is not meant to be instantiated")

    @classmethod
    def init(cls, hass):
        """Initialize State by recording the hass instance used by all other methods."""
        cls.hass = hass

    @classmethod
    async def get_service_params(cls):
        """Get parameters for all services.

        Rebuilds ``service2args`` from the hass service descriptions.  Only
        services whose description has an "entity_id" field or a "target" key
        are recorded; each maps to the set of its field names with
        "entity_id" removed.
        """
        cls.service2args = {}
        all_services = await async_get_all_descriptions(cls.hass)
        for domain in all_services:
            cls.service2args[domain] = {}
            for service, desc in all_services[domain].items():
                if "entity_id" not in desc["fields"] and "target" not in desc:
                    continue
                cls.service2args[domain][service] = set(desc["fields"].keys())
                cls.service2args[domain][service].discard("entity_id")

    @classmethod
    async def notify_add(cls, var_names, queue):
        """Register to notify state variables changes to be sent to queue.

        ``var_names`` is a single name or a set of names, each of the form
        'domain.entity' or 'domain.entity.attr'; other forms are ignored.
        The queue is registered under the 'domain.entity' part, and the
        complete ``var_names`` argument is stored with it.

        Returns True if at least one name was registered, else False.
        """

        added = False
        for var_name in var_names if isinstance(var_names, set) else {var_names}:
            parts = var_name.split(".")
            if len(parts) != 2 and len(parts) != 3:
                continue
            state_var_name = f"{parts[0]}.{parts[1]}"
            if state_var_name not in cls.notify:
                cls.notify[state_var_name] = {}
            cls.notify[state_var_name][queue] = var_names
            added = True
        return added

    @classmethod
    def notify_del(cls, var_names, queue):
        """Unregister notify of state variables changes for given queue.

        ``var_names`` is a single name or a set of names, in the same forms
        accepted by notify_add().  Names that are malformed, or for which the
        queue is not (or no longer) registered, are skipped.
        """

        for var_name in var_names if isinstance(var_names, set) else {var_names}:
            parts = var_name.split(".")
            if len(parts) != 2 and len(parts) != 3:
                continue
            state_var_name = f"{parts[0]}.{parts[1]}"
            #
            # Skip, rather than stop at, a name whose queue is already gone:
            # 'domain.entity' and 'domain.entity.attr' share one entry, so the
            # second of the pair always finds the queue removed, and the
            # remaining names must still be unregistered.
            #
            cls.notify.get(state_var_name, {}).pop(queue, None)

    @classmethod
    async def update(cls, new_vars, func_args):
        """Deliver all notifications for state variable changes.

        For every name in ``new_vars`` that has registered queues, the value
        is remembered in ``notify_var_last`` and each registered queue is sent
        one ``["state", [notify_vars, func_args_copy]]`` message.
        """

        notify = {}
        for var_name, var_val in new_vars.items():
            if var_name in cls.notify:
                cls.notify_var_last[var_name] = var_val
                # merging by queue means each queue gets at most one message per update
                notify.update(cls.notify[var_name])

        if notify:
            _LOGGER.debug("state.update(%s, %s)", new_vars, func_args)
            for queue, var_names in notify.items():
                await queue.put(["state", [cls.notify_var_get(var_names, new_vars), func_args.copy()]])

    @classmethod
    def notify_var_get(cls, var_names, new_vars):
        """Add values of var_names to new_vars, or default to None.

        Returns a copy of ``new_vars`` extended with an entry for each name in
        ``var_names`` (which may be None) that is not already present:

        - the last notified value, if the name is in ``notify_var_last``;
        - for 'domain.entity.attr', that attribute of the last notified
          value of 'domain.entity' (None if it has no such attribute);
        - for 'domain.entity.old.attr', that attribute of the
          'domain.entity.old' entry (None if it has no such attribute);
        - None for a name with one to three periods that doesn't exist.

        Names matching none of these are left out of the result.
        """
        notify_vars = new_vars.copy()
        for var_name in var_names if var_names is not None else []:
            if var_name in notify_vars:
                continue
            parts = var_name.split(".")
            if var_name in cls.notify_var_last:
                notify_vars[var_name] = cls.notify_var_last[var_name]
            elif len(parts) == 3 and f"{parts[0]}.{parts[1]}" in cls.notify_var_last:
                notify_vars[var_name] = getattr(
                    cls.notify_var_last[f"{parts[0]}.{parts[1]}"], parts[2], None
                )
            elif len(parts) == 4 and parts[2] == "old" and f"{parts[0]}.{parts[1]}.old" in notify_vars:
                notify_vars[var_name] = getattr(notify_vars[f"{parts[0]}.{parts[1]}.old"], parts[3], None)
            elif 1 <= var_name.count(".") <= 3 and not cls.exist(var_name):
                notify_vars[var_name] = None
        return notify_vars

    @classmethod
    def set(cls, var_name, value=None, new_attributes=None, **kwargs):
        """Set a state variable and optional attributes in hass.

        ``var_name`` must be of the form 'domain.entity', else NameError is
        raised.  If ``value`` is a StateVal it is converted to a plain str,
        and when ``new_attributes`` is None its attributes (minus the virtual
        ones) are used.  A ``value`` or ``new_attributes`` that is still None
        defaults to the current state's value or attributes (attributes
        default to an empty dict if the state does not exist).

        A ``context`` keyword that is a Context instance is used as the hass
        context; otherwise the context associated with the current task is
        used.  All remaining keyword arguments are merged into the attributes.
        """
        if var_name.count(".") != 1:
            raise NameError(f"invalid name {var_name} (should be 'domain.entity')")

        if isinstance(value, StateVal):
            if new_attributes is None:
                #
                # value is a StateVal, so extract the attributes and value
                #
                new_attributes = value.__dict__.copy()
                for discard in STATE_VIRTUAL_ATTRS:
                    new_attributes.pop(discard, None)
            value = str(value)

        state_value = None
        if value is None or new_attributes is None:
            state_value = cls.hass.states.get(var_name)

        if value is None and state_value:
            value = state_value.state

        if new_attributes is None:
            if state_value:
                new_attributes = state_value.attributes.copy()
            else:
                new_attributes = {}

        curr_task = asyncio.current_task()
        if "context" in kwargs and isinstance(kwargs["context"], Context):
            context = kwargs["context"]
            del kwargs["context"]
        else:
            context = Function.task2context.get(curr_task, None)

        if kwargs:
            # copy first so the caller's dict is not modified
            new_attributes = new_attributes.copy()
            new_attributes.update(kwargs)

        _LOGGER.debug("setting %s = %s, attr = %s", var_name, value, new_attributes)
        cls.hass.states.async_set(var_name, value, new_attributes, context=context)
        if var_name in cls.notify_var_last or var_name in cls.notify:
            #
            # immediately update a variable we are monitoring since it could take a while
            # for the state changed event to propagate
            #
            cls.notify_var_last[var_name] = StateVal(cls.hass.states.get(var_name))

        if var_name in cls.persisted_vars:
            cls.persisted_vars[var_name].set_state(value)
            cls.persisted_vars[var_name].set_attributes(new_attributes)

    @classmethod
    def setattr(cls, var_attr_name, value):
        """Set a state variable's attribute in hass.

        ``var_attr_name`` must be of the form 'domain.entity.attr' and the
        state 'domain.entity' must exist; otherwise NameError is raised.
        """
        parts = var_attr_name.split(".")
        if len(parts) != 3:
            raise NameError(f"invalid name {var_attr_name} (should be 'domain.entity.attr')")
        if not cls.exist(f"{parts[0]}.{parts[1]}"):
            raise NameError(f"state {parts[0]}.{parts[1]} doesn't exist")
        cls.set(f"{parts[0]}.{parts[1]}", **{parts[2]: value})

    @classmethod
    async def register_persist(cls, var_name):
        """Register pyscript state variable to be persisted with RestoreState.

        Does nothing unless ``var_name`` starts with "pyscript." and has not
        already been registered.
        """
        if var_name.startswith("pyscript.") and var_name not in cls.persisted_vars:
            # this is a hack accessing hass internals; should re-implement using RestoreEntity
            restore_data = cls.hass.data[DATA_RESTORE_STATE]
            this_entity = PyscriptEntity()
            this_entity.entity_id = var_name
            cls.persisted_vars[var_name] = this_entity
            try:
                restore_data.async_restore_entity_added(this_entity)
            except TypeError:
                # NOTE(review): presumably a fallback for hass versions whose
                # async_restore_entity_added takes the entity id - confirm
                restore_data.async_restore_entity_added(var_name)

    @classmethod
    async def persist(cls, var_name, default_value=None, default_attributes=None):
        """Persist a pyscript domain state variable, and update with optional defaults.

        ``var_name`` must be of the form 'pyscript.entity', else NameError is
        raised.  If the state does not exist and ``default_value`` is given,
        it is created with the defaults.  If it exists and
        ``default_attributes`` is given, only the attributes it is missing
        are added; existing attribute values are kept.
        """
        if var_name.count(".") != 1 or not var_name.startswith("pyscript."):
            raise NameError(f"invalid name {var_name} (should be 'pyscript.entity')")

        await cls.register_persist(var_name)
        exists = cls.exist(var_name)

        if not exists and default_value is not None:
            cls.set(var_name, default_value, default_attributes)
        elif exists and default_attributes is not None:
            # Patch the attributes with new values if necessary
            current = cls.hass.states.get(var_name)
            new_attributes = {k: v for (k, v) in default_attributes.items() if k not in current.attributes}
            cls.set(var_name, current.state, **new_attributes)

    @classmethod
    def exist(cls, var_name):
        """Check if a state variable value or attribute exists in hass.

        Returns False for names that are not 'domain.entity' or
        'domain.entity.attr', or whose state is not in hass.  For the
        three-part form, 'attr' may be a service of the domain recorded in
        ``service2args``, a state attribute, or a virtual attribute.
        """
        parts = var_name.split(".")
        if len(parts) != 2 and len(parts) != 3:
            return False
        value = cls.hass.states.get(f"{parts[0]}.{parts[1]}")
        if value is None:
            return False
        if (
            len(parts) == 2
            or (parts[0] in cls.service2args and parts[2] in cls.service2args[parts[0]])
            or parts[2] in value.attributes
            or parts[2] in STATE_VIRTUAL_ATTRS
        ):
            return True
        return False

    @classmethod
    def get(cls, var_name):
        """Get a state variable value or attribute from hass.

        For 'domain.entity' returns a StateVal.  For 'domain.entity.name',
        returns an async function that calls service 'domain.name' with
        ``entity_id`` set to 'domain.entity' if that service is recorded in
        ``service2args``; otherwise returns the attribute ``name``.

        Raises NameError for a malformed name or an undefined state, and
        AttributeError if the attribute does not exist.
        """
        parts = var_name.split(".")
        if len(parts) != 2 and len(parts) != 3:
            raise NameError(f"invalid name '{var_name}' (should be 'domain.entity' or 'domain.entity.attr')")
        state = cls.hass.states.get(f"{parts[0]}.{parts[1]}")
        if not state:
            raise NameError(f"name '{parts[0]}.{parts[1]}' is not defined")
        #
        # simplest case is just the state value
        #
        state = StateVal(state)
        if len(parts) == 2:
            return state
        #
        # see if this is a service that has an entity_id parameter
        #
        if parts[0] in cls.service2args and parts[2] in cls.service2args[parts[0]]:
            params = cls.service2args[parts[0]][parts[2]]

            def service_call_factory(domain, service, entity_id, params):
                async def service_call(*args, **kwargs):
                    curr_task = asyncio.current_task()
                    hass_args = {}
                    #
                    # keywords of the right type are passed to hass's async_call
                    # itself rather than to the service as data
                    #
                    for keyword, typ, default in [
                        ("context", [Context], Function.task2context.get(curr_task, None)),
                        ("blocking", [bool], None),
                        ("return_response", [bool], None),
                        ("limit", [float, int], None),
                    ]:
                        if keyword in kwargs and type(kwargs[keyword]) in typ:
                            hass_args[keyword] = kwargs.pop(keyword)
                        elif default:
                            hass_args[keyword] = default

                    kwargs["entity_id"] = entity_id
                    if len(args) == 1 and len(params) == 1:
                        #
                        # with just a single parameter and positional argument, create the keyword setting
                        #
                        [param_name] = params
                        kwargs[param_name] = args[0]
                    elif len(args) != 0:
                        raise TypeError(f"service {domain}.{service} takes no positional arguments")

                    # return await Function.hass_services_async_call(domain, service, kwargs, **hass_args)
                    return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)

                return service_call

            return service_call_factory(parts[0], parts[2], f"{parts[0]}.{parts[1]}", params)
        #
        # finally see if it is an attribute
        #
        try:
            return getattr(state, parts[2])
        except AttributeError:
            raise AttributeError(  # pylint: disable=raise-missing-from
                f"state '{parts[0]}.{parts[1]}' has no attribute '{parts[2]}'"
            )

    @classmethod
    def delete(cls, var_name, context=None):
        """Delete a state variable or attribute from hass.

        'domain.entity' removes the whole state; 'domain.entity.attr' removes
        one attribute and re-sets the state with the remaining ones.  If no
        ``context`` is given, the context associated with the current task
        (if any) is used.

        Raises NameError for a malformed name or a state that doesn't exist,
        and AttributeError if the attribute doesn't exist.
        """
        parts = var_name.split(".")
        if not context:
            context = Function.task2context.get(asyncio.current_task(), None)
        context_arg = {"context": context} if context else {}
        if len(parts) == 2:
            if var_name in cls.notify_var_last or var_name in cls.notify:
                #
                # immediately update a variable we are monitoring since it could take a while
                # for the state changed event to propagate
                #
                cls.notify_var_last[var_name] = None
            if not cls.hass.states.async_remove(var_name, **context_arg):
                raise NameError(f"name '{var_name}' not defined")
            return
        if len(parts) == 3:
            var_name = f"{parts[0]}.{parts[1]}"
            value = cls.hass.states.get(var_name)
            if value is None:
                raise NameError(f"state {var_name} doesn't exist")
            new_attr = value.attributes.copy()
            if parts[2] not in new_attr:
                raise AttributeError(f"state '{var_name}' has no attribute '{parts[2]}'")
            del new_attr[parts[2]]
            cls.set(f"{var_name}", value.state, new_attributes=new_attr, **context_arg)
            return
        raise NameError(f"invalid name '{var_name}' (should be 'domain.entity' or 'domain.entity.attr')")

    @classmethod
    def getattr(cls, var_name):
        """Return a dict of attributes for a state variable.

        ``var_name`` is either a StateVal, whose attributes (minus the
        virtual ones) are returned, or a 'domain.entity' name that is looked
        up in hass.  Returns None if the named state doesn't exist; raises
        NameError if the name is malformed.
        """
        if isinstance(var_name, StateVal):
            attrs = var_name.__dict__.copy()
            for discard in STATE_VIRTUAL_ATTRS:
                attrs.pop(discard, None)
            return attrs
        if var_name.count(".") != 1:
            raise NameError(f"invalid name {var_name} (should be 'domain.entity')")
        value = cls.hass.states.get(var_name)
        if not value:
            return None
        return value.attributes.copy()

    @classmethod
    def get_attr(cls, var_name):
        """Return a dict of attributes for a state variable - deprecated."""
        _LOGGER.warning("state.get_attr() is deprecated: use state.getattr() instead")
        return cls.getattr(var_name)

    @classmethod
    def completions(cls, root):
        """Return possible completions of state variables.

        With two periods in ``root``, completes the final component among the
        state's attributes, the virtual attributes and the domain's services
        in ``service2args``.  With fewer, completes among all entity ids.
        Returns a set of full names; empty if nothing matches or ``root`` has
        more than two periods.
        """
        words = set()
        parts = root.split(".")
        num_period = len(parts) - 1
        if num_period == 2:
            #
            # complete state attributes
            #
            last_period = root.rfind(".")
            name = root[0:last_period]
            value = cls.hass.states.get(name)
            if value:
                attr_root = root[last_period + 1 :]
                attrs = set(value.attributes.keys()).union(STATE_VIRTUAL_ATTRS)
                if parts[0] in cls.service2args:
                    attrs.update(set(cls.service2args[parts[0]].keys()))
                for attr_name in attrs:
                    if attr_name.lower().startswith(attr_root):
                        words.add(f"{name}.{attr_name}")
        elif num_period < 2:
            #
            # complete among all state names
            #
            for name in cls.hass.states.async_all():
                if name.entity_id.lower().startswith(root):
                    words.add(name.entity_id)
        return words

    @classmethod
    async def names(cls, domain=None):
        """Implement names, which returns all entity_ids (optionally for one domain)."""
        return cls.hass.states.async_entity_ids(domain)

    @classmethod
    def register_functions(cls):
        """Register state functions and config variable."""
        functions = {
            "state.get": cls.get,
            "state.set": cls.set,
            "state.setattr": cls.setattr,
            "state.names": cls.names,
            "state.getattr": cls.getattr,
            "state.get_attr": cls.get_attr,  # deprecated form; to be removed
            "state.persist": cls.persist,
            "state.delete": cls.delete,
            "pyscript.config": cls.pyscript_config,
        }
        Function.register(functions)

    @classmethod
    def set_pyscript_config(cls, config):
        """Set pyscript yaml config."""
        #
        # have to update inplace, since dest is already used as value
        #
        cls.pyscript_config.clear()
        for name, value in config.items():
            cls.pyscript_config[name] = value
|