Files
homeassistant_config/config/custom_components/pyscript/trigger.py
2024-08-09 06:45:02 +02:00

1397 lines
60 KiB
Python

"""Implements all the trigger logic."""
import asyncio
import datetime as dt
import functools
import locale
import logging
import math
import re
import time
from croniter import croniter
from homeassistant.core import Context
from homeassistant.helpers import sun
from homeassistant.util import dt as dt_util
from .const import LOGGER_PATH
from .eval import AstEval, EvalFunc, EvalFuncVar
from .event import Event
from .function import Function
from .mqtt import Mqtt
from .state import STATE_VIRTUAL_ATTRS, State
from .webhook import Webhook
# Module logger, namespaced under the integration's logger path.
_LOGGER = logging.getLogger(LOGGER_PATH + ".trigger")
# Matches a bare state variable reference of the form "word.word", optionally
# followed by ".word" or ".*" (used with re.match, so it must span the whole string).
STATE_RE = re.compile(r"\w+\.\w+(\.((\w+)|\*))?$")
def dt_now():
    """Return the current local time as a naive ``datetime``."""
    current = dt.datetime.now()
    return current
def parse_time_offset(offset_str):
    """Convert a time offset string such as ``"5 min"`` or ``"-1.5h"`` to seconds.

    The string must contain exactly one signed number optionally followed by a
    unit word.  A missing unit or a seconds unit means seconds.  An unknown unit
    is logged as an error and the number is treated as seconds.  A string that
    cannot be parsed is logged as an error and yields ``0``.
    """
    seconds_per_unit = {
        "": 1,
        "s": 1,
        "sec": 1,
        "second": 1,
        "seconds": 1,
        "m": 60,
        "min": 60,
        "mins": 60,
        "minute": 60,
        "minutes": 60,
        "h": 60 * 60,
        "hr": 60 * 60,
        "hour": 60 * 60,
        "hours": 60 * 60,
        "d": 60 * 60 * 24,
        "day": 60 * 60 * 24,
        "days": 60 * 60 * 24,
        "w": 60 * 60 * 24 * 7,
        "week": 60 * 60 * 24 * 7,
        "weeks": 60 * 60 * 24 * 7,
    }
    pieces = re.split(r"([-+]?\s*\d*\.?\d+(?:[eE][-+]?\d+)?)\s*(\w*)", offset_str)
    # One match produces exactly [prefix, number, unit, suffix].
    if len(pieces) != 4:
        _LOGGER.error("can't parse time offset %s", offset_str)
        return 0

    magnitude = float(pieces[1].replace(" ", ""))
    multiplier = seconds_per_unit.get(pieces[2])
    if multiplier is None:
        _LOGGER.error("can't parse time offset %s", offset_str)
        multiplier = 1
    return magnitude * multiplier
def ident_any_values_changed(func_args, ident):
    """Report whether a state change touches any of the watched names in ``ident``.

    ``func_args`` carries ``var_name``, ``value`` and ``old_value`` for the
    change; without a ``var_name`` nothing is considered changed.  A watched
    name may be the variable itself, ``var.attr`` for a single attribute, or
    ``var.*`` for any attribute outside ``STATE_VIRTUAL_ATTRS``.
    """
    changed_var = func_args.get("var_name", None)
    if changed_var is None:
        return False

    new_val = func_args["value"]
    prev_val = func_args["old_value"]
    attr_prefix = f"{changed_var}."

    for watched in ident:
        if watched == changed_var and prev_val != new_val:
            return True
        if not watched.startswith(attr_prefix):
            continue
        parts = watched.split(".")
        if len(parts) != 3 or f"{parts[0]}.{parts[1]}" != changed_var:
            continue
        attr_name = parts[2]
        if attr_name == "*":
            # catch-all: compare every attribute present on either value
            candidates = set()
            if new_val is not None:
                candidates |= set(new_val.__dict__.keys())
            if prev_val is not None:
                candidates |= set(prev_val.__dict__.keys())
            if any(
                getattr(new_val, attr, None) != getattr(prev_val, attr, None)
                for attr in candidates - STATE_VIRTUAL_ATTRS
            ):
                return True
        elif getattr(new_val, attr_name, None) != getattr(prev_val, attr_name, None):
            return True
    return False
def ident_values_changed(func_args, ident):
    """Report whether a state change affects any name referenced in ``ident``.

    Names with fewer than two or more than three dotted pieces are ignored.
    ``var`` and ``var.old`` count as changed when the whole value differs;
    ``var.attr`` counts as changed when that attribute differs between the new
    and old value.
    """
    changed_var = func_args.get("var_name", None)
    if changed_var is None:
        return False

    new_val = func_args["value"]
    prev_val = func_args["old_value"]

    for watched in ident:
        parts = watched.split(".")
        if not 2 <= len(parts) <= 3:
            continue
        if f"{parts[0]}.{parts[1]}" != changed_var:
            continue
        if len(parts) == 2 or parts[2] == "old":
            if new_val != prev_val:
                return True
        elif getattr(new_val, parts[2], None) != getattr(prev_val, parts[2], None):
            return True
    return False
class TrigTime:
    """Class for trigger time functions.

    Used purely as a namespace: all state lives in class attributes and every
    entry point is a classmethod.
    """

    #
    # Global hass instance
    #
    hass = None

    #
    # Mappings of day of week name to number, using US convention of sunday is 0.
    # Initialized based on locale at startup.
    #
    dow2int = {}

    def __init__(self):
        """Warn on TrigTime instantiation."""
        _LOGGER.error("TrigTime class is not meant to be instantiated")

    @classmethod
    def init(cls, hass):
        """Initialize TrigTime.

        Stores the hass instance, registers the ``task.wait_until``,
        ``task.create``, ``task.add_done_callback`` and ``task.executor``
        functions, and fills ``dow2int`` from the locale's day names (falling
        back to English names when ``locale.nl_langinfo`` is unavailable).
        """
        cls.hass = hass

        def wait_until_factory(ast_ctx):
            """Return wrapper that calls wait_until with the ast context."""

            async def wait_until_call(*arg, **kw):
                return await cls.wait_until(ast_ctx, *arg, **kw)

            return wait_until_call

        def user_task_create_factory(ast_ctx):
            """Return wrapper that implements task.create() for the ast context."""

            async def user_task_create(func, *args, **kwargs):
                """Implement task.create()."""

                async def func_call(func, func_name, new_ast_ctx, *args, **kwargs):
                    """Call user function inside task.create()."""
                    ret = await new_ast_ctx.call_func(func, func_name, *args, **kwargs)
                    if new_ast_ctx.get_exception_obj():
                        new_ast_ctx.get_logger().error(new_ast_ctx.get_exception_long())
                    return ret

                try:
                    if isinstance(func, (EvalFunc, EvalFuncVar)):
                        func_name = func.get_name()
                    else:
                        func_name = func.__name__
                except Exception:
                    # best effort: the name is only used to label the new context
                    func_name = "<function>"

                new_ast_ctx = AstEval(
                    f"{ast_ctx.get_global_ctx_name()}.{func_name}", ast_ctx.get_global_ctx()
                )
                Function.install_ast_funcs(new_ast_ctx)
                task = Function.create_task(
                    func_call(func, func_name, new_ast_ctx, *args, **kwargs), ast_ctx=new_ast_ctx
                )
                Function.task_done_callback_ctx(task, new_ast_ctx)
                return task

            return user_task_create

        ast_funcs = {
            "task.wait_until": wait_until_factory,
            "task.create": user_task_create_factory,
        }
        Function.register_ast(ast_funcs)

        async def user_task_add_done_callback(task, callback, *args, **kwargs):
            """Implement task.add_done_callback()."""
            ast_ctx = None
            if type(callback) is EvalFuncVar:
                ast_ctx = callback.get_ast_ctx()
            Function.task_add_done_callback(task, ast_ctx, callback, *args, **kwargs)

        funcs = {
            "task.add_done_callback": user_task_add_done_callback,
            "task.executor": cls.user_task_executor,
        }
        Function.register(funcs)

        try:
            for i in range(0, 7):
                cls.dow2int[locale.nl_langinfo(getattr(locale, f"ABDAY_{i + 1}")).lower()] = i
                cls.dow2int[locale.nl_langinfo(getattr(locale, f"DAY_{i + 1}")).lower()] = i
        except AttributeError:
            # Win10 Python doesn't have locale.nl_langinfo, so default to English days of week
            dow = [
                "sunday",
                "monday",
                "tuesday",
                "wednesday",
                "thursday",
                "friday",
                "saturday",
            ]
            for idx, name in enumerate(dow):
                cls.dow2int[name] = idx
                cls.dow2int[name[0:3]] = idx

    @classmethod
    async def wait_until(
        cls,
        ast_ctx,
        state_trigger=None,
        state_check_now=True,
        time_trigger=None,
        event_trigger=None,
        mqtt_trigger=None,
        webhook_trigger=None,
        webhook_local_only=True,
        webhook_methods=None,
        timeout=None,
        state_hold=None,
        state_hold_false=None,
        __test_handshake__=None,
    ):
        """Wait for zero or more triggers, until an optional timeout.

        Args:
            ast_ctx: evaluation context of the caller; supplies the name,
                global context and logger name for the trigger expressions.
            state_trigger: a string, list or set of strings.  Entries matching
                ``STATE_RE`` mean "trigger on any change"; the rest are
                expressions, combined with ``any([...])`` when more than one.
            state_check_now: evaluate the state expression immediately.
            time_trigger: time spec(s) passed to ``timer_trigger_next``.
            event_trigger, mqtt_trigger, webhook_trigger: a string, or a
                sequence whose first item is the event type / topic / webhook
                id and whose optional second item is a filter expression.
            webhook_local_only, webhook_methods: forwarded to
                ``Webhook.notify_add``; methods default to POST and PUT.
            timeout: overall limit in seconds.
            state_hold: seconds the state expression must stay true.
            state_hold_false: seconds the state expression must have been
                false before a true value counts.
            __test_handshake__: optional ``(var, value)`` pair set via
                ``State.set`` just before listening (test synchronization).

        Returns:
            A dict describing what fired: ``{"trigger_type": "none"}``,
            ``{"trigger_type": "timeout"}``, ``{"trigger_type": "state"}``,
            ``{"trigger_type": "time", "trigger_time": ...}``, or the
            notification payload received from the queue.

        Raises:
            Any exception recorded by a trigger expression while parsing or
            evaluating it.
        """
        if (
            state_trigger is None
            and time_trigger is None
            and event_trigger is None
            and mqtt_trigger is None
            and webhook_trigger is None
        ):
            if timeout is not None:
                await asyncio.sleep(timeout)
                return {"trigger_type": "timeout"}
            return {"trigger_type": "none"}

        state_trig_ident = set()
        state_trig_ident_any = set()
        state_trig_eval = None
        event_trig_expr = None
        mqtt_trig_expr = None
        webhook_trig_expr = None
        exc = None
        notify_q = asyncio.Queue(0)

        last_state_trig_time = None
        state_trig_waiting = False
        state_trig_notify_info = [None, None]
        state_false_time = None
        check_state_expr_on_start = state_check_now or state_hold_false is not None

        # (notify_del callable, key) for every registration made on notify_q, so
        # that all of them are removed on any exit path.
        registered = []

        def unregister_all():
            """Remove every notification registered on notify_q so far."""
            while registered:
                notify_del, key = registered.pop(0)
                notify_del(key, notify_q)

        if state_trigger is not None:
            state_trig = []
            if isinstance(state_trigger, str):
                state_trigger = [state_trigger]
            elif isinstance(state_trigger, set):
                state_trigger = list(state_trigger)
            #
            # separate out the entries that are just state var names, which mean trigger
            # on any change (no expr)
            #
            for trig in state_trigger:
                if STATE_RE.match(trig):
                    state_trig_ident_any.add(trig)
                else:
                    state_trig.append(trig)

            if len(state_trig) > 0:
                if len(state_trig) == 1:
                    state_trig_expr = state_trig[0]
                else:
                    state_trig_expr = f"any([{', '.join(state_trig)}])"
                state_trig_eval = AstEval(
                    f"{ast_ctx.name} state_trigger",
                    ast_ctx.get_global_ctx(),
                    logger_name=ast_ctx.get_logger_name(),
                )
                Function.install_ast_funcs(state_trig_eval)
                state_trig_eval.parse(state_trig_expr, mode="eval")
                state_trig_ident = await state_trig_eval.get_names()
                exc = state_trig_eval.get_exception_obj()
                if exc is not None:
                    raise exc

            state_trig_ident.update(state_trig_ident_any)
            if check_state_expr_on_start and state_trig_eval:
                #
                # check straight away to see if the condition is met
                #
                new_vars = State.notify_var_get(state_trig_ident, {})
                state_trig_ok = await state_trig_eval.eval(new_vars)
                exc = state_trig_eval.get_exception_obj()
                if exc is not None:
                    raise exc
                if state_hold_false is not None and not state_check_now:
                    #
                    # if state_trig_ok we wait until it is false;
                    # otherwise we consider now to be the start of the false hold time
                    #
                    state_false_time = None if state_trig_ok else time.monotonic()
                elif state_hold is not None and state_trig_ok:
                    state_trig_waiting = True
                    state_trig_notify_info = [None, {"trigger_type": "state"}]
                    last_state_trig_time = time.monotonic()
                    _LOGGER.debug(
                        "trigger %s wait_until: state trigger immediately true; now waiting for state_hold of %g seconds",
                        ast_ctx.name,
                        state_hold,
                    )
                elif state_trig_ok:
                    return {"trigger_type": "state"}

            _LOGGER.debug(
                "trigger %s wait_until: watching vars %s",
                ast_ctx.name,
                state_trig_ident,
            )
            if len(state_trig_ident) > 0:
                await State.notify_add(state_trig_ident, notify_q)
                registered.append((State.notify_del, state_trig_ident))

        if event_trigger is not None:
            if isinstance(event_trigger, str):
                event_trigger = [event_trigger]
            if len(event_trigger) > 1:
                event_trig_expr = AstEval(
                    f"{ast_ctx.name} event_trigger",
                    ast_ctx.get_global_ctx(),
                    logger_name=ast_ctx.get_logger_name(),
                )
                Function.install_ast_funcs(event_trig_expr)
                event_trig_expr.parse(event_trigger[1], mode="eval")
                exc = event_trig_expr.get_exception_obj()
                if exc is not None:
                    unregister_all()
                    raise exc
            Event.notify_add(event_trigger[0], notify_q)
            registered.append((Event.notify_del, event_trigger[0]))

        if mqtt_trigger is not None:
            if isinstance(mqtt_trigger, str):
                mqtt_trigger = [mqtt_trigger]
            if len(mqtt_trigger) > 1:
                mqtt_trig_expr = AstEval(
                    f"{ast_ctx.name} mqtt_trigger",
                    ast_ctx.get_global_ctx(),
                    logger_name=ast_ctx.get_logger_name(),
                )
                Function.install_ast_funcs(mqtt_trig_expr)
                mqtt_trig_expr.parse(mqtt_trigger[1], mode="eval")
                exc = mqtt_trig_expr.get_exception_obj()
                if exc is not None:
                    unregister_all()
                    raise exc
            await Mqtt.notify_add(mqtt_trigger[0], notify_q)
            registered.append((Mqtt.notify_del, mqtt_trigger[0]))

        if webhook_trigger is not None:
            if isinstance(webhook_trigger, str):
                webhook_trigger = [webhook_trigger]
            if len(webhook_trigger) > 1:
                webhook_trig_expr = AstEval(
                    f"{ast_ctx.name} webhook_trigger",
                    ast_ctx.get_global_ctx(),
                    logger_name=ast_ctx.get_logger_name(),
                )
                Function.install_ast_funcs(webhook_trig_expr)
                webhook_trig_expr.parse(webhook_trigger[1], mode="eval")
                exc = webhook_trig_expr.get_exception_obj()
                if exc is not None:
                    unregister_all()
                    raise exc
            if webhook_methods is None:
                webhook_methods = {"POST", "PUT"}
            Webhook.notify_add(webhook_trigger[0], webhook_local_only, webhook_methods, notify_q)
            registered.append((Webhook.notify_del, webhook_trigger[0]))

        # optional filter expression for each queue message type that carries one
        filter_exprs = {
            "event": event_trig_expr,
            "mqtt": mqtt_trig_expr,
            "webhook": webhook_trig_expr,
        }

        # The finally clause guarantees the queue is unregistered even when this
        # task is cancelled or an exception escapes while waiting.
        try:
            time0 = time.monotonic()

            if __test_handshake__:
                #
                # used for testing to avoid race conditions
                # we use this as a handshake that we are about to
                # listen to the queue
                #
                State.set(__test_handshake__[0], __test_handshake__[1])

            # captured on the first pass only, so "now" in a time spec keeps
            # referring to the moment the wait started
            startup_time = None
            while True:
                ret = None
                this_timeout = None
                state_trig_timeout = False
                time_next = None
                now = dt_now()
                if startup_time is None:
                    startup_time = now
                if time_trigger is not None:
                    time_next, time_next_adj = await cls.timer_trigger_next(time_trigger, now, startup_time)
                    _LOGGER.debug(
                        "trigger %s wait_until time_next = %s, now = %s",
                        ast_ctx.name,
                        time_next,
                        now,
                    )
                    if time_next is not None:
                        this_timeout = (time_next_adj - now).total_seconds()
                if timeout is not None:
                    time_left = time0 + timeout - time.monotonic()
                    if time_left <= 0:
                        ret = {"trigger_type": "timeout"}
                        break
                    if this_timeout is None or this_timeout > time_left:
                        ret = {"trigger_type": "timeout"}
                        this_timeout = time_left
                        time_next = now + dt.timedelta(seconds=this_timeout)
                if state_trig_waiting:
                    time_left = last_state_trig_time + state_hold - time.monotonic()
                    if this_timeout is None or time_left < this_timeout:
                        this_timeout = time_left
                        state_trig_timeout = True
                        time_next = now + dt.timedelta(seconds=this_timeout)
                if this_timeout is None:
                    if (
                        state_trigger is None
                        and event_trigger is None
                        and mqtt_trigger is None
                        and webhook_trigger is None
                    ):
                        _LOGGER.debug(
                            "trigger %s wait_until no next time - returning with none",
                            ast_ctx.name,
                        )
                        ret = {"trigger_type": "none"}
                        break
                    _LOGGER.debug("trigger %s wait_until no timeout", ast_ctx.name)
                    notify_type, notify_info = await notify_q.get()
                else:
                    timeout_occured = False
                    while True:
                        try:
                            this_timeout = max(0, this_timeout)
                            _LOGGER.debug("trigger %s wait_until %.6g secs", ast_ctx.name, this_timeout)
                            notify_type, notify_info = await asyncio.wait_for(
                                notify_q.get(), timeout=this_timeout
                            )
                            state_trig_timeout = False
                        except asyncio.TimeoutError:
                            actual_now = dt_now()
                            if actual_now < time_next:
                                this_timeout = (time_next - actual_now).total_seconds()
                                # tests/tests_function's simple now() requires us to ignore
                                # timeouts that are up to 1us too early; otherwise wait for
                                # longer until we are sure we are at or past time_next
                                if this_timeout > 1e-6:
                                    continue
                            if not state_trig_timeout:
                                if not ret:
                                    ret = {"trigger_type": "time"}
                                    if time_next is not None:
                                        ret["trigger_time"] = time_next
                                timeout_occured = True
                        break
                    if timeout_occured:
                        break
                if state_trig_timeout:
                    # the state_hold period elapsed with the expression still true
                    ret = state_trig_notify_info[1]
                    state_trig_waiting = False
                    break
                if notify_type == "state":
                    if notify_info:
                        new_vars, func_args = notify_info
                    else:
                        new_vars, func_args = None, {}

                    state_trig_ok = True

                    if not ident_any_values_changed(func_args, state_trig_ident_any):
                        # if var_name not in func_args we are state_check_now
                        if "var_name" in func_args and not ident_values_changed(func_args, state_trig_ident):
                            continue

                        if state_trig_eval:
                            state_trig_ok = await state_trig_eval.eval(new_vars)
                            exc = state_trig_eval.get_exception_obj()
                            if exc is not None:
                                break

                            if state_hold_false is not None:
                                if state_false_time is None:
                                    if state_trig_ok:
                                        #
                                        # wasn't False, so ignore
                                        #
                                        continue
                                    #
                                    # first False, so remember when it is
                                    #
                                    state_false_time = time.monotonic()
                                elif state_trig_ok:
                                    too_soon = time.monotonic() - state_false_time < state_hold_false
                                    state_false_time = None
                                    if too_soon:
                                        #
                                        # was False but not for long enough, so start over
                                        #
                                        continue

                    if state_hold is not None:
                        if state_trig_ok:
                            if not state_trig_waiting:
                                state_trig_waiting = True
                                state_trig_notify_info = notify_info
                                last_state_trig_time = time.monotonic()
                                _LOGGER.debug(
                                    "trigger %s wait_until: got %s trigger; now waiting for state_hold of %g seconds",
                                    ast_ctx.name,
                                    notify_type,
                                    state_hold,
                                )
                            else:
                                _LOGGER.debug(
                                    "trigger %s wait_until: got %s trigger; still waiting for state_hold of %g seconds",
                                    ast_ctx.name,
                                    notify_type,
                                    state_hold,
                                )
                            continue
                        if state_trig_waiting:
                            state_trig_waiting = False
                            _LOGGER.debug(
                                "trigger %s wait_until: %s trigger now false during state_hold; waiting for new trigger",
                                ast_ctx.name,
                                notify_type,
                            )
                        continue
                    if state_trig_ok:
                        ret = notify_info[1] if notify_info else None
                        break
                elif notify_type in filter_exprs:
                    filter_expr = filter_exprs[notify_type]
                    if filter_expr is None:
                        ret = notify_info
                        break
                    filter_ok = await filter_expr.eval(notify_info)
                    exc = filter_expr.get_exception_obj()
                    if exc is not None:
                        break
                    if filter_ok:
                        ret = notify_info
                        break
                else:
                    _LOGGER.error(
                        "trigger %s wait_until got unexpected queue message %s",
                        ast_ctx.name,
                        notify_type,
                    )
        finally:
            unregister_all()

        if exc:
            raise exc
        return ret

    @classmethod
    async def user_task_executor(cls, func, *args, **kwargs):
        """Implement task.executor().

        Runs ``func(*args, **kwargs)`` via ``hass.async_add_executor_job`` and
        returns its result.

        Raises:
            TypeError: if ``func`` is a coroutine function, is not callable, or
                is a pyscript function (``EvalFuncVar``).
        """
        if asyncio.iscoroutinefunction(func) or not callable(func):
            raise TypeError(f"function {func} is not callable by task.executor")
        if isinstance(func, EvalFuncVar):
            raise TypeError(
                "pyscript functions can't be called from task.executor - must be a regular python function"
            )
        # async_add_executor_job only forwards positional args, so bind kwargs first
        return await cls.hass.async_add_executor_job(functools.partial(func, **kwargs), *args)

    @classmethod
    async def parse_date_time(cls, date_time_str, day_offset, now, startup_time):
        """Parse a date time string, returning datetime.

        The string is an optional date (``[yyyy/]mm/dd`` with ``/`` or ``-``,
        a day-of-week name from ``dow2int``, ``today`` or ``tomorrow``),
        followed by an optional time (``hh:mm[:ss]``, ``sunrise``, ``sunset``,
        ``noon``, ``midnight`` or ``now``), followed by an optional offset
        understood by ``parse_time_offset``.

        Args:
            date_time_str: the specification to parse.
            day_offset: days to add to today's date when the string carries no
                explicit date; an explicit date or day name overrides it.
            now: current (naive) datetime supplying the default date.
            startup_time: datetime that ``now`` in the string refers to.

        Returns:
            The resulting naive datetime.  If sunrise/sunset cannot be
            computed, a time 100 days in the past is returned so it is ignored.
        """
        year = now.year
        month = now.month
        day = now.day

        dt_str_orig = dt_str = date_time_str.strip().lower()
        #
        # parse the date
        #
        match0 = re.match(r"0*(\d+)[-/]0*(\d+)(?:[-/]0*(\d+))?", dt_str)
        match1 = re.match(r"(\w+)", dt_str)
        if match0:
            if match0[3]:
                year, month, day = int(match0[1]), int(match0[2]), int(match0[3])
            else:
                month, day = int(match0[1]), int(match0[2])
            day_offset = 0  # explicit date means no offset
            dt_str = dt_str[len(match0.group(0)) :]
        elif match1:
            skip = True
            if match1[1] in cls.dow2int:
                # days until the next occurrence of that weekday (0 if today)
                day_offset = (cls.dow2int[match1[1]] - now.isoweekday() % 7) % 7
            elif match1[1] == "today":
                day_offset = 0
            elif match1[1] == "tomorrow":
                day_offset = 1
            else:
                skip = False
            if skip:
                dt_str = dt_str[len(match1.group(0)) :]
        if day_offset != 0:
            now = dt.datetime(year, month, day) + dt.timedelta(days=day_offset)
            year = now.year
            month = now.month
            day = now.day
        else:
            now = dt.datetime(year, month, day)
        dt_str = dt_str.strip()
        if len(dt_str) == 0:
            return now

        #
        # parse the time
        #
        match0 = re.match(r"0*(\d+):0*(\d+)(?::0*(\d*\.?\d+(?:[eE][-+]?\d+)?))?", dt_str)
        if match0:
            if match0[3]:
                hour, mins, sec = int(match0[1]), int(match0[2]), float(match0[3])
            else:
                hour, mins, sec = int(match0[1]), int(match0[2]), 0
            dt_str = dt_str[len(match0.group(0)) :]
        elif dt_str.startswith(("sunrise", "sunset")):
            location = sun.get_astral_location(cls.hass)
            if isinstance(location, tuple):
                # HA core-2021.5.0 included this breaking change: https://github.com/home-assistant/core/pull/48573.
                # As part of the upgrade to astral 2.2, sun.get_astral_location() now returns a tuple including the
                # elevation. We just want the astral.location.Location object.
                location = location[0]
            try:
                if dt_str.startswith("sunrise"):
                    time_sun = await cls.hass.async_add_executor_job(
                        location.sunrise, dt.date(year, month, day)
                    )
                    dt_str = dt_str[7:]
                else:
                    time_sun = await cls.hass.async_add_executor_job(
                        location.sunset, dt.date(year, month, day)
                    )
                    dt_str = dt_str[6:]
            except Exception:
                _LOGGER.warning("'%s' not defined at this latitude", dt_str)
                # return something in the past so it is ignored
                return now - dt.timedelta(days=100)
            now += time_sun.date() - now.date()
            hour, mins, sec = time_sun.hour, time_sun.minute, time_sun.second
        elif dt_str.startswith("noon"):
            hour, mins, sec = 12, 0, 0
            dt_str = dt_str[4:]
        elif dt_str.startswith("midnight"):
            hour, mins, sec = 0, 0, 0
            dt_str = dt_str[8:]
        elif dt_str.startswith("now") and dt_str_orig == dt_str:
            #
            # "now" means the first time, and only matches if there was no date specification
            #
            hour, mins, sec = 0, 0, 0
            now = startup_time
            dt_str = dt_str[3:]
        else:
            hour, mins, sec = 0, 0, 0
        now += dt.timedelta(seconds=sec + 60 * (mins + 60 * hour))

        #
        # parse the offset
        #
        dt_str = dt_str.strip()
        if len(dt_str) > 0:
            now = now + dt.timedelta(seconds=parse_time_offset(dt_str))
        return now

    @classmethod
    async def timer_active_check(cls, time_spec, now, startup_time):
        """Check if the given time matches the time specification.

        Each entry of ``time_spec`` (a string or list of strings) is
        ``cron(...)`` or ``range(start, end)``, optionally prefixed by ``not``.

        Returns:
            True when at least one positive entry matches (or there are no
            positive entries) and every negated entry holds.  False when any
            entry is invalid.
        """
        results = {"+": [], "-": []}
        for entry in time_spec if isinstance(time_spec, list) else [time_spec]:
            this_match = False
            negate = False
            active_str = entry.strip()
            if active_str.startswith("not"):
                negate = True
                # drop only the leading "not" and whatever whitespace follows it
                active_str = re.sub(r"^not\s+", "", active_str)

            cron_match = re.match(r"cron\((?P<cron_expr>.*)\)", active_str)
            range_expr = re.match(r"range\(([^,]+),\s?([^,]+)\)", active_str)
            if cron_match:
                cron_expr = cron_match.group("cron_expr")
                if not croniter.is_valid(cron_expr):
                    _LOGGER.error("Invalid cron expression: %s", cron_expr)
                    return False

                this_match = croniter.match(cron_expr, now)

            elif range_expr:
                try:
                    dt_start, dt_end = range_expr.groups()
                except ValueError as exc:
                    _LOGGER.error("Invalid range expression: %s", exc)
                    return False

                start = await cls.parse_date_time(dt_start.strip(), 0, now, startup_time)
                end = await cls.parse_date_time(dt_end.strip(), 0, start, startup_time)

                if start <= end:
                    this_match = start <= now <= end
                else:  # Over midnight
                    this_match = now >= start or now <= end
            else:
                _LOGGER.error("Invalid time_active expression: %s", active_str)
                return False

            if negate:
                results["-"].append(not this_match)
            else:
                results["+"].append(this_match)

        # An empty spec, or only neg specs, is True
        result = (any(results["+"]) if results["+"] else True) and all(results["-"])

        return result

    @classmethod
    async def timer_trigger_next(cls, time_spec, now, startup_time):
        """Return the next trigger time based on the given time and time specification.

        Each spec is ``cron(...)``, ``once(datetime)`` or
        ``period(start, interval[, end])``; unparsable specs are logged and
        skipped.

        Returns:
            ``(next_time, next_time_adj)``: the earliest upcoming trigger time
            across all specs, and the time to use when computing how long to
            wait from ``now`` (differs from ``next_time`` only for cron specs).
            Both are None when nothing is upcoming.
        """
        next_time = None
        next_time_adj = None
        if not isinstance(time_spec, list):
            time_spec = [time_spec]
        for spec in time_spec:
            cron_match = re.search(r"cron\((?P<cron_expr>.*)\)", spec)
            match1 = re.split(r"once\((.*)\)", spec)
            match2 = re.split(r"period\(([^,]*),([^,]*)(?:,([^,]*))?\)", spec)
            if cron_match:
                cron_expr = cron_match.group("cron_expr")
                if not croniter.is_valid(cron_expr):
                    _LOGGER.error("Invalid cron expression: %s", cron_expr)
                    continue

                #
                # Handling DST changes is tricky; all times in pyscript are naive (no timezone). This is the
                # one part of the code where we do check timezones, in case now and next_time bracket a DST
                # change. We return next_time as the local time of the next trigger according to the cron
                # spec, and next_time_adj is potentially adjusted so that (next_time_adj - now) is the correct
                # timedelta to wait (eg: if cron is a daily trigger at 6am, next_time will always be 6am
                # tomorrow, and next_time_adj will also be 6am, except on the day of a DST change, when it
                # will be 5am or 7am, such that (next_time_adj - now) is 23 hours or 25 hours.
                #
                # We might have to fetch multiple croniter times, in case (next_time_adj - now) is non-positive
                # after a DST change.
                #
                # Also, datetime doesn't correctly subtract datetimes in different timezones, so we need to compute
                # the difference in UTC. See https://blog.ganssle.io/articles/2018/02/aware-datetime-arithmetic.html.
                #
                cron_iter = croniter(cron_expr, now, dt.datetime)
                delta = None
                while delta is None or delta.total_seconds() <= 0:
                    val = cron_iter.get_next()
                    delta = dt_util.as_local(val).astimezone(dt_util.UTC) - dt_util.as_local(now).astimezone(
                        dt_util.UTC
                    )

                if next_time is None or val < next_time:
                    next_time = val
                    next_time_adj = now + delta

            elif len(match1) == 3:
                this_t = await cls.parse_date_time(match1[1].strip(), 0, now, startup_time)
                day_offset = (now - this_t).days + 1
                if day_offset != 0 and this_t != startup_time:
                    #
                    # Try a day offset (won't make a difference if spec has full date)
                    #
                    this_t = await cls.parse_date_time(match1[1].strip(), day_offset, now, startup_time)
                startup = now == this_t and now == startup_time
                if (now < this_t or startup) and (next_time is None or this_t < next_time):
                    next_time_adj = next_time = this_t

            elif len(match2) == 5:
                start_str, period_str = match2[1].strip(), match2[2].strip()
                start = await cls.parse_date_time(start_str, 0, now, startup_time)
                period = parse_time_offset(period_str)
                if period <= 0:
                    _LOGGER.error("Invalid non-positive period %s in period(): %s", period, time_spec)
                    continue

                if match2[3] is None:
                    # open-ended period: next multiple of period after start
                    startup = now == start and now == startup_time
                    if (now < start or startup) and (next_time is None or start < next_time):
                        next_time_adj = next_time = start
                    if now >= start and not startup:
                        secs = period * (1.0 + math.floor((now - start).total_seconds() / period))
                        this_t = start + dt.timedelta(seconds=secs)
                        if now < this_t and (next_time is None or this_t < next_time):
                            next_time_adj = next_time = this_t
                    continue
                end_str = match2[3].strip()
                end = await cls.parse_date_time(end_str, 0, now, startup_time)
                # an end earlier than start means the window spans midnight
                end_offset = 1 if end < start else 0
                for day in [-1, 0, 1]:
                    start = await cls.parse_date_time(start_str, day, now, startup_time)
                    end = await cls.parse_date_time(end_str, day + end_offset, now, startup_time)
                    if now < start or (now == start and now == startup_time):
                        if next_time is None or start < next_time:
                            next_time_adj = next_time = start
                        break
                    secs = period * (1.0 + math.floor((now - start).total_seconds() / period))
                    this_t = start + dt.timedelta(seconds=secs)
                    if start <= this_t <= end:
                        if next_time is None or this_t < next_time:
                            next_time_adj = next_time = this_t
                        break

            else:
                _LOGGER.warning("Can't parse %s in time_trigger check", spec)

        return next_time, next_time_adj
class TrigInfo:
"""Class for all trigger-decorated functions."""
def __init__(
self,
name,
trig_cfg,
global_ctx=None,
):
"""Create a new TrigInfo."""
self.name = name
self.task = None
self.global_ctx = global_ctx
self.trig_cfg = trig_cfg
self.state_trigger = trig_cfg.get("state_trigger", {}).get("args", None)
self.state_trigger_kwargs = trig_cfg.get("state_trigger", {}).get("kwargs", {})
self.state_hold = self.state_trigger_kwargs.get("state_hold", None)
self.state_hold_false = self.state_trigger_kwargs.get("state_hold_false", None)
self.state_check_now = self.state_trigger_kwargs.get("state_check_now", False)
self.state_user_watch = self.state_trigger_kwargs.get("watch", None)
self.time_trigger = trig_cfg.get("time_trigger", {}).get("args", None)
self.time_trigger_kwargs = trig_cfg.get("time_trigger", {}).get("kwargs", {})
self.event_trigger = trig_cfg.get("event_trigger", {}).get("args", None)
self.event_trigger_kwargs = trig_cfg.get("event_trigger", {}).get("kwargs", {})
self.mqtt_trigger = trig_cfg.get("mqtt_trigger", {}).get("args", None)
self.mqtt_trigger_kwargs = trig_cfg.get("mqtt_trigger", {}).get("kwargs", {})
self.webhook_trigger = trig_cfg.get("webhook_trigger", {}).get("args", None)
self.webhook_trigger_kwargs = trig_cfg.get("webhook_trigger", {}).get("kwargs", {})
self.webhook_local_only = self.webhook_trigger_kwargs.get("local_only", True)
self.webhook_methods = self.webhook_trigger_kwargs.get("methods", {"POST", "PUT"})
self.state_active = trig_cfg.get("state_active", {}).get("args", None)
self.time_active = trig_cfg.get("time_active", {}).get("args", None)
self.time_active_hold_off = trig_cfg.get("time_active", {}).get("kwargs", {}).get("hold_off", None)
self.task_unique = trig_cfg.get("task_unique", {}).get("args", None)
self.task_unique_kwargs = trig_cfg.get("task_unique", {}).get("kwargs", None)
self.action = trig_cfg.get("action")
self.global_sym_table = trig_cfg.get("global_sym_table", {})
self.notify_q = asyncio.Queue(0)
self.active_expr = None
self.state_active_ident = None
self.state_trig_expr = None
self.state_trig_eval = None
self.state_trig_ident = None
self.state_trig_ident_any = set()
self.event_trig_expr = None
self.mqtt_trig_expr = None
self.webhook_trig_expr = None
self.have_trigger = False
self.setup_ok = False
self.run_on_startup = False
self.run_on_shutdown = False
if self.state_active is not None:
self.active_expr = AstEval(
f"{self.name} @state_active()", self.global_ctx, logger_name=self.name
)
Function.install_ast_funcs(self.active_expr)
self.active_expr.parse(self.state_active, mode="eval")
exc = self.active_expr.get_exception_long()
if exc is not None:
self.active_expr.get_logger().error(exc)
return
if "time_trigger" in trig_cfg and self.time_trigger is None:
self.run_on_startup = True
if self.time_trigger is not None:
while "startup" in self.time_trigger:
self.run_on_startup = True
self.time_trigger.remove("startup")
while "shutdown" in self.time_trigger:
self.run_on_shutdown = True
self.time_trigger.remove("shutdown")
if len(self.time_trigger) == 0:
self.time_trigger = None
if self.state_trigger is not None:
state_trig = []
for triggers in self.state_trigger:
if isinstance(triggers, str):
triggers = [triggers]
elif isinstance(triggers, set):
triggers = list(triggers)
#
# separate out the entries that are just state var names, which mean trigger
# on any change (no expr)
#
for trig in triggers:
if STATE_RE.match(trig):
self.state_trig_ident_any.add(trig)
else:
state_trig.append(trig)
if len(state_trig) > 0:
if len(state_trig) == 1:
self.state_trig_expr = state_trig[0]
else:
self.state_trig_expr = f"any([{', '.join(state_trig)}])"
self.state_trig_eval = AstEval(
f"{self.name} @state_trigger()", self.global_ctx, logger_name=self.name
)
Function.install_ast_funcs(self.state_trig_eval)
self.state_trig_eval.parse(self.state_trig_expr, mode="eval")
exc = self.state_trig_eval.get_exception_long()
if exc is not None:
self.state_trig_eval.get_logger().error(exc)
return
self.have_trigger = True
if self.event_trigger is not None:
if len(self.event_trigger) == 2:
self.event_trig_expr = AstEval(
f"{self.name} @event_trigger()",
self.global_ctx,
logger_name=self.name,
)
Function.install_ast_funcs(self.event_trig_expr)
self.event_trig_expr.parse(self.event_trigger[1], mode="eval")
exc = self.event_trig_expr.get_exception_long()
if exc is not None:
self.event_trig_expr.get_logger().error(exc)
return
self.have_trigger = True
if self.mqtt_trigger is not None:
if len(self.mqtt_trigger) == 2:
self.mqtt_trig_expr = AstEval(
f"{self.name} @mqtt_trigger()",
self.global_ctx,
logger_name=self.name,
)
Function.install_ast_funcs(self.mqtt_trig_expr)
self.mqtt_trig_expr.parse(self.mqtt_trigger[1], mode="eval")
exc = self.mqtt_trig_expr.get_exception_long()
if exc is not None:
self.mqtt_trig_expr.get_logger().error(exc)
return
self.have_trigger = True
if self.webhook_trigger is not None:
if len(self.webhook_trigger) == 2:
self.webhook_trig_expr = AstEval(
f"{self.name} @webhook_trigger()",
self.global_ctx,
logger_name=self.name,
)
Function.install_ast_funcs(self.webhook_trig_expr)
self.webhook_trig_expr.parse(self.webhook_trigger[1], mode="eval")
exc = self.webhook_trig_expr.get_exception_long()
if exc is not None:
self.webhook_trig_expr.get_logger().error(exc)
return
self.have_trigger = True
self.setup_ok = True
def stop(self):
"""Stop this trigger task."""
if self.task:
if self.state_trig_ident:
State.notify_del(self.state_trig_ident, self.notify_q)
if self.event_trigger is not None:
Event.notify_del(self.event_trigger[0], self.notify_q)
if self.mqtt_trigger is not None:
Mqtt.notify_del(self.mqtt_trigger[0], self.notify_q)
if self.webhook_trigger is not None:
Webhook.notify_del(self.webhook_trigger[0], self.notify_q)
if self.task:
Function.reaper_cancel(self.task)
self.task = None
if self.run_on_shutdown:
notify_type = "shutdown"
notify_info = {"trigger_type": "time", "trigger_time": "shutdown"}
notify_info.update(self.time_trigger_kwargs.get("kwargs", {}))
action_future = self.call_action(notify_type, notify_info, run_task=False)
Function.waiter_await(action_future)
def start(self):
"""Start this trigger task."""
if not self.task and self.setup_ok:
self.task = Function.create_task(self.trigger_watch())
_LOGGER.debug("trigger %s is active", self.name)
async def trigger_watch(self):
    """Task that runs for each trigger, waiting for the next trigger and calling the function.

    The task first registers ``self.notify_q`` with each configured source
    (state variables, event, mqtt, webhook).  It then loops forever:

    1. wait for a notification on the queue, or for the next time trigger /
       ``state_hold`` expiry, whichever comes first;
    2. evaluate the trigger-specific expression for the notification type;
    3. apply the ``state_active`` expression, ``time_active`` and hold-off
       checks;
    4. invoke ``call_action`` and remember when it last fired.

    The loop returns on its own only when there is neither a timeout nor any
    queue-based trigger left to wait for.  ``asyncio.CancelledError`` is
    re-raised; any other exception is logged, the notify registrations are
    removed and the task ends.
    """
    try:
        if self.state_trigger is not None:
            self.state_trig_ident = set()
            if self.state_user_watch:
                if isinstance(self.state_user_watch, list):
                    self.state_trig_ident = set(self.state_user_watch)
                else:
                    self.state_trig_ident = self.state_user_watch
            else:
                if self.state_trig_eval:
                    self.state_trig_ident = await self.state_trig_eval.get_names()
                self.state_trig_ident.update(self.state_trig_ident_any)
            _LOGGER.debug("trigger %s: watching vars %s", self.name, self.state_trig_ident)
            if len(self.state_trig_ident) == 0 or not await State.notify_add(
                self.state_trig_ident, self.notify_q
            ):
                _LOGGER.error(
                    "trigger %s: @state_trigger is not watching any variables; will never trigger",
                    self.name,
                )

        if self.active_expr:
            self.state_active_ident = await self.active_expr.get_names()

        if self.event_trigger is not None:
            _LOGGER.debug("trigger %s adding event_trigger %s", self.name, self.event_trigger[0])
            Event.notify_add(self.event_trigger[0], self.notify_q)
        if self.mqtt_trigger is not None:
            _LOGGER.debug("trigger %s adding mqtt_trigger %s", self.name, self.mqtt_trigger[0])
            await Mqtt.notify_add(self.mqtt_trigger[0], self.notify_q)
        if self.webhook_trigger is not None:
            _LOGGER.debug("trigger %s adding webhook_trigger %s", self.name, self.webhook_trigger[0])
            Webhook.notify_add(
                self.webhook_trigger[0], self.webhook_local_only, self.webhook_methods, self.notify_q
            )

        last_trig_time = None
        last_state_trig_time = None
        state_trig_waiting = False
        state_trig_notify_info = [None, None]
        state_false_time = None
        now = startup_time = None
        check_state_expr_on_start = self.state_check_now or self.state_hold_false is not None

        while True:
            timeout = None
            state_trig_timeout = False
            notify_info = None
            notify_type = None
            now = dt_now()
            if startup_time is None:
                startup_time = now
            if self.run_on_startup:
                #
                # first time only - skip waiting for other triggers
                #
                notify_type = "startup"
                notify_info = {"trigger_type": "time", "trigger_time": "startup"}
                self.run_on_startup = False
            elif check_state_expr_on_start:
                #
                # first time only - skip wait and check state trigger
                #
                notify_type = "state"
                if self.state_trig_ident:
                    notify_vars = State.notify_var_get(self.state_trig_ident, {})
                else:
                    notify_vars = {}
                notify_info = [notify_vars, {"trigger_type": notify_type}]
                check_state_expr_on_start = False
            else:
                if self.time_trigger:
                    time_next, time_next_adj = await TrigTime.timer_trigger_next(
                        self.time_trigger, now, startup_time
                    )
                    _LOGGER.debug(
                        "trigger %s time_next = %s, now = %s",
                        self.name,
                        time_next,
                        now,
                    )
                    if time_next is not None:
                        timeout = (time_next_adj - now).total_seconds()
                if state_trig_waiting:
                    #
                    # a state_hold is pending; wake up for it if it expires
                    # before the next time trigger
                    #
                    time_left = last_state_trig_time + self.state_hold - time.monotonic()
                    if timeout is None or time_left < timeout:
                        timeout = time_left
                        time_next = now + dt.timedelta(seconds=timeout)
                        state_trig_timeout = True
                if timeout is not None:
                    while True:
                        try:
                            timeout = max(0, timeout)
                            _LOGGER.debug("trigger %s waiting for %.6g secs", self.name, timeout)
                            notify_type, notify_info = await asyncio.wait_for(
                                self.notify_q.get(), timeout=timeout
                            )
                            # a notification arrived before the timeout expired
                            state_trig_timeout = False
                            now = dt_now()
                        except asyncio.TimeoutError:
                            actual_now = dt_now()
                            if actual_now < time_next:
                                #
                                # woke up before the target time, so wait
                                # for the remainder
                                #
                                timeout = (time_next - actual_now).total_seconds()
                                continue
                            now = time_next
                            if not state_trig_timeout:
                                notify_type = "time"
                                notify_info = {
                                    "trigger_type": "time",
                                    "trigger_time": time_next,
                                }
                        break
                elif self.have_trigger:
                    _LOGGER.debug("trigger %s waiting for state change or event", self.name)
                    notify_type, notify_info = await self.notify_q.get()
                    now = dt_now()
                else:
                    _LOGGER.debug("trigger %s finished", self.name)
                    return

            #
            # check the trigger-specific expressions
            #
            trig_ok = True
            new_vars = {}
            user_kwargs = {}
            if state_trig_timeout:
                #
                # state_hold expired: use the notification saved when the
                # hold started
                #
                new_vars, func_args = state_trig_notify_info
                state_trig_waiting = False
            elif notify_type == "state":
                new_vars, func_args = notify_info
                user_kwargs = self.state_trigger_kwargs.get("kwargs", {})

                if not ident_any_values_changed(func_args, self.state_trig_ident_any):
                    #
                    # if var_name not in func_args we are check_state_expr_on_start
                    #
                    if "var_name" in func_args and not ident_values_changed(
                        func_args, self.state_trig_ident
                    ):
                        continue

                    if self.state_trig_eval:
                        trig_ok = await self.state_trig_eval.eval(new_vars)
                        exc = self.state_trig_eval.get_exception_long()
                        if exc is not None:
                            self.state_trig_eval.get_logger().error(exc)
                            trig_ok = False

                        if self.state_hold_false is not None:
                            if "var_name" not in func_args:
                                #
                                # this is check_state_expr_on_start check
                                # if immediately true, force wait until False
                                # otherwise start False wait now
                                #
                                state_false_time = None if trig_ok else time.monotonic()
                                if not self.state_check_now:
                                    continue
                            if state_false_time is None:
                                if trig_ok:
                                    #
                                    # wasn't False, so ignore after initial check
                                    #
                                    if "var_name" in func_args:
                                        continue
                                else:
                                    #
                                    # first False, so remember when it is
                                    #
                                    state_false_time = time.monotonic()
                            elif trig_ok and "var_name" in func_args:
                                too_soon = time.monotonic() - state_false_time < self.state_hold_false
                                state_false_time = None
                                if too_soon:
                                    #
                                    # was False but not for long enough, so start over
                                    #
                                    continue
                    else:
                        trig_ok = False

                if self.state_hold is not None:
                    if trig_ok:
                        if not state_trig_waiting:
                            state_trig_waiting = True
                            state_trig_notify_info = notify_info
                            last_state_trig_time = time.monotonic()
                            _LOGGER.debug(
                                "trigger %s got %s trigger; now waiting for state_hold of %g seconds",
                                self.name,
                                notify_type,
                                self.state_hold,
                            )
                        else:
                            _LOGGER.debug(
                                "trigger %s got %s trigger; still waiting for state_hold of %g seconds",
                                self.name,
                                notify_type,
                                self.state_hold,
                            )
                        #
                        # merge the user kwargs now: when the hold later
                        # expires, user_kwargs is empty on that pass and
                        # func_args comes from the saved notify info
                        #
                        func_args.update(user_kwargs)
                        continue

                    if state_trig_waiting:
                        state_trig_waiting = False
                        _LOGGER.debug(
                            "trigger %s %s trigger now false during state_hold; waiting for new trigger",
                            self.name,
                            notify_type,
                        )
                        continue

            elif notify_type == "event":
                func_args = notify_info
                user_kwargs = self.event_trigger_kwargs.get("kwargs", {})
                if self.event_trig_expr:
                    trig_ok = await self.event_trig_expr.eval(notify_info)

            elif notify_type == "mqtt":
                func_args = notify_info
                user_kwargs = self.mqtt_trigger_kwargs.get("kwargs", {})
                if self.mqtt_trig_expr:
                    trig_ok = await self.mqtt_trig_expr.eval(notify_info)

            elif notify_type == "webhook":
                func_args = notify_info
                user_kwargs = self.webhook_trigger_kwargs.get("kwargs", {})
                if self.webhook_trig_expr:
                    trig_ok = await self.webhook_trig_expr.eval(notify_info)

            else:
                user_kwargs = self.time_trigger_kwargs.get("kwargs", {})
                func_args = notify_info

            #
            # now check the state and time active expressions
            #
            if trig_ok and self.active_expr:
                active_vars = State.notify_var_get(self.state_active_ident, new_vars)
                trig_ok = await self.active_expr.eval(active_vars)
                exc = self.active_expr.get_exception_long()
                if exc is not None:
                    self.active_expr.get_logger().error(exc)
                    trig_ok = False
            if trig_ok and self.time_active:
                trig_ok = await TrigTime.timer_active_check(self.time_active, now, startup_time)

            if not trig_ok:
                _LOGGER.debug(
                    "trigger %s got %s trigger, but not active",
                    self.name,
                    notify_type,
                )
                continue

            if (
                self.time_active_hold_off is not None
                and last_trig_time is not None
                and time.monotonic() < last_trig_time + self.time_active_hold_off
            ):
                _LOGGER.debug(
                    "trigger %s got %s trigger, but less than %s seconds since last trigger, so skipping",
                    self.name,
                    notify_type,
                    self.time_active_hold_off,
                )
                continue

            func_args.update(user_kwargs)
            if self.call_action(notify_type, func_args):
                last_trig_time = time.monotonic()

    except asyncio.CancelledError:
        raise

    except Exception as exc:
        # any other failure ends this trigger: log it and unhook the queue
        _LOGGER.error("%s: %s", self.name, exc)
        if self.state_trig_ident:
            State.notify_del(self.state_trig_ident, self.notify_q)
        if self.event_trigger is not None:
            Event.notify_del(self.event_trigger[0], self.notify_q)
        if self.mqtt_trigger is not None:
            Mqtt.notify_del(self.mqtt_trigger[0], self.notify_q)
        if self.webhook_trigger is not None:
            Webhook.notify_del(self.webhook_trigger[0], self.notify_q)
        return
def call_action(self, notify_type, func_args, run_task=True):
    """Call the trigger action function.

    Args:
        notify_type: name of the trigger type that fired; used for logging.
        func_args: dict of keyword arguments passed to the action.  It is
            also included in the ``pyscript_running`` event, and if it has
            a ``"context"`` entry that is a ``Context``, the new context is
            created with that one as parent.
        run_task: when true, the call is scheduled with
            ``Function.create_task``; when false the un-awaited coroutine
            is handed back to the caller instead.

    Returns:
        ``False`` if ``@task_unique`` with ``kill_me=True`` prevented the
        action, ``True`` if a task was created, otherwise (``run_task``
        false) the coroutine that performs the call.
    """
    action_ast_ctx = AstEval(f"{self.action.global_ctx_name}.{self.action.name}", self.action.global_ctx)
    Function.install_ast_funcs(action_ast_ctx)
    task_unique_func = None
    if self.task_unique is not None:
        task_unique_func = Function.task_unique_factory(action_ast_ctx)

    #
    # check for @task_unique with kill_me=True
    #
    if (
        self.task_unique is not None
        and self.task_unique_kwargs
        and self.task_unique_kwargs["kill_me"]
        and Function.unique_name_used(action_ast_ctx, self.task_unique)
    ):
        _LOGGER.debug(
            "trigger %s got %s trigger, @task_unique kill_me=True prevented new action",
            self.name,
            notify_type,
        )
        return False

    # Create new HASS Context with incoming as parent
    if "context" in func_args and isinstance(func_args["context"], Context):
        hass_context = Context(parent_id=func_args["context"].id)
    else:
        hass_context = Context()

    # Fire an event indicating that pyscript is running
    # Note: the event must have an entity_id for logbook to work correctly.
    ev_name = self.name.replace(".", "_")
    ev_entity_id = f"pyscript.{ev_name}"

    event_data = {"name": ev_name, "entity_id": ev_entity_id, "func_args": func_args}
    Function.hass.bus.async_fire("pyscript_running", event_data, context=hass_context)

    _LOGGER.debug(
        "trigger %s got %s trigger, running action (kwargs = %s)",
        self.name,
        notify_type,
        func_args,
    )

    async def do_func_call(func, ast_ctx, task_unique, task_unique_func, hass_context, **kwargs):
        # Store HASS Context for this Task
        Function.store_hass_context(hass_context)

        if task_unique and task_unique_func:
            await task_unique_func(task_unique)
        await ast_ctx.call_func(func, None, **kwargs)
        # Errors raised by the action are logged rather than propagated.
        if ast_ctx.get_exception_obj():
            ast_ctx.get_logger().error(ast_ctx.get_exception_long())

    func = do_func_call(
        self.action,
        action_ast_ctx,
        self.task_unique,
        task_unique_func,
        hass_context,
        **func_args,
    )
    if run_task:
        task = Function.create_task(func, ast_ctx=action_ast_ctx)
        Function.task_done_callback_ctx(task, action_ast_ctx)
        return True
    return func