Files
homeassistant_config/config/custom_components/pyscript/function.py
2024-08-09 06:45:02 +02:00

520 lines
18 KiB
Python

"""Function call handling."""
import asyncio
import logging
import traceback
from homeassistant.core import Context
from .const import LOGGER_PATH, SERVICE_RESPONSE_NONE, SERVICE_RESPONSE_ONLY
# module logger, a child of the integration's base logger
_LOGGER = logging.getLogger(f"{LOGGER_PATH}.function")
class Function:
    """Define function handler functions.

    All state is kept in class attributes and every method is a classmethod;
    the class is used as a namespace and is not meant to be instantiated
    (``__init__`` only logs an error).
    """

    #
    # Global hass instance
    #
    hass = None

    #
    # Mappings of tasks ids <-> task names
    # (task -> set of prefixed names, prefixed name -> task)
    #
    unique_task2name = {}
    unique_name2task = {}

    #
    # Mappings of task id to hass contexts
    #
    task2context = {}

    #
    # Set of tasks started through run_coro() that have not finished yet
    #
    our_tasks = set()

    #
    # Done callbacks for each task:
    # task -> {"ctx": ast_ctx, "cb": {callback: [ast_ctx, args, kwargs]}}
    #
    task2cb = {}

    #
    # initial list of available functions (name -> callable)
    #
    functions = {}

    #
    # Functions that take the AstEval context as a first argument,
    # which is needed by a handful of special functions that need the
    # ast context (name -> factory(ast_ctx) returning the callable)
    #
    ast_functions = {}

    #
    # task id of the task that cancels and waits for other tasks,
    #
    task_reaper = None
    task_reaper_q = None

    #
    # task id of the task that awaits for coros (used by shutdown triggers)
    #
    task_waiter = None
    task_waiter_q = None

    #
    # reference counting for service registrations; the new @service trigger
    # registers the service call before the old one is removed, so we only
    # remove the service registration when the reference count goes to zero
    #
    service_cnt = {}

    #
    # save the global_ctx name where a service is registered so we can raise
    # an exception if it gets registered by a different global_ctx.
    #
    service2global_ctx = {}

    def __init__(self):
        """Warn on Function instantiation."""
        _LOGGER.error("Function class is not meant to be instantiated")

    @classmethod
    def init(cls, hass):
        """Initialize Function.

        Stores the hass instance, registers the built-in functions and starts
        the reaper and waiter background tasks if they are not already running.
        """
        cls.hass = hass
        cls.functions.update(
            {
                "event.fire": cls.event_fire,
                "service.call": cls.service_call,
                "service.has_service": cls.service_has_service,
                "task.cancel": cls.user_task_cancel,
                "task.current_task": cls.user_task_current_task,
                "task.remove_done_callback": cls.user_task_remove_done_callback,
                "task.sleep": cls.async_sleep,
                "task.wait": cls.user_task_wait,
            }
        )
        cls.ast_functions.update(
            {
                "log.debug": lambda ast_ctx: ast_ctx.get_logger().debug,
                "log.error": lambda ast_ctx: ast_ctx.get_logger().error,
                "log.info": lambda ast_ctx: ast_ctx.get_logger().info,
                "log.warning": lambda ast_ctx: ast_ctx.get_logger().warning,
                "print": lambda ast_ctx: ast_ctx.get_logger().debug,
                "task.name2id": cls.task_name2id_factory,
                "task.unique": cls.task_unique_factory,
            }
        )

        #
        # start a task which is a reaper for canceled tasks, since some functions
        # like TrigInfo.stop() can't be async (it's called from a __del__ method)
        #
        async def task_reaper(reaper_q):
            while True:
                try:
                    cmd = await reaper_q.get()
                    if cmd[0] == "exit":
                        return
                    if cmd[0] == "cancel":
                        try:
                            cmd[1].cancel()
                            # wait for the cancelled task to actually finish
                            await cmd[1]
                        except asyncio.CancelledError:
                            # NOTE(review): a cancellation of the reaper itself that
                            # arrives during this await may also be swallowed here
                            pass
                    else:
                        _LOGGER.error("task_reaper: unknown command %s", cmd[0])
                except asyncio.CancelledError:
                    raise
                except Exception:
                    _LOGGER.error("task_reaper: got exception %s", traceback.format_exc(-1))

        if not cls.task_reaper:
            cls.task_reaper_q = asyncio.Queue(0)
            cls.task_reaper = cls.create_task(task_reaper(cls.task_reaper_q))

        #
        # start a task which creates tasks to run coros, and then syncs on their completion;
        # this is used by the shutdown trigger
        #
        async def task_waiter(waiter_q):
            aws = []
            while True:
                try:
                    cmd = await waiter_q.get()
                    if cmd[0] == "exit":
                        return
                    if cmd[0] == "await":
                        aws.append(cls.create_task(cmd[1]))
                    elif cmd[0] == "sync":
                        if aws:
                            # return_exceptions=True: a cancelled child task must not
                            # raise CancelledError here, which would kill the waiter
                            # and leave waiter_sync() blocked forever
                            await asyncio.gather(*aws, return_exceptions=True)
                            aws = []
                        await cmd[1].put(0)
                    else:
                        _LOGGER.error("task_waiter: unknown command %s", cmd[0])
                except asyncio.CancelledError:
                    raise
                except Exception:
                    _LOGGER.error("task_waiter: got exception %s", traceback.format_exc(-1))

        if not cls.task_waiter:
            cls.task_waiter_q = asyncio.Queue(0)
            cls.task_waiter = cls.create_task(task_waiter(cls.task_waiter_q))

    @classmethod
    def reaper_cancel(cls, task):
        """Send a task to be canceled by the reaper (usable from sync code)."""
        cls.task_reaper_q.put_nowait(["cancel", task])

    @classmethod
    async def reaper_stop(cls):
        """Tell the reaper task to exit and wait for it."""
        if cls.task_reaper:
            cls.task_reaper_q.put_nowait(["exit"])
            await cls.task_reaper
            cls.task_reaper = None
            cls.task_reaper_q = None

    @classmethod
    def waiter_await(cls, coro):
        """Send a coro to be awaited by the waiter task."""
        cls.task_waiter_q.put_nowait(["await", coro])

    @classmethod
    async def waiter_sync(cls):
        """Wait until all coros sent to the waiter so far have completed."""
        if cls.task_waiter:
            sync_q = asyncio.Queue(0)
            cls.task_waiter_q.put_nowait(["sync", sync_q])
            await sync_q.get()

    @classmethod
    async def waiter_stop(cls):
        """Tell the waiter task to exit and wait for it."""
        if cls.task_waiter:
            cls.task_waiter_q.put_nowait(["exit"])
            await cls.task_waiter
            cls.task_waiter = None
            cls.task_waiter_q = None

    @classmethod
    async def async_sleep(cls, duration):
        """Implement task.sleep(); ``duration`` is converted with float()."""
        await asyncio.sleep(float(duration))

    @classmethod
    async def event_fire(cls, event_type, **kwargs):
        """Implement event.fire().

        A ``context`` keyword holding a ``Context`` is used as the event
        context and removed from the event data; otherwise the context stored
        for the current task (if any) is used.
        """
        curr_task = asyncio.current_task()
        if "context" in kwargs and isinstance(kwargs["context"], Context):
            context = kwargs.pop("context")
        else:
            context = cls.task2context.get(curr_task, None)
        cls.hass.bus.async_fire(event_type, kwargs, context=context)

    @classmethod
    def store_hass_context(cls, hass_context):
        """Store a context against the running task."""
        curr_task = asyncio.current_task()
        cls.task2context[curr_task] = hass_context

    @classmethod
    def task_unique_factory(cls, ctx):
        """Define and return task.unique() for this context.

        Names are prefixed with the global context name, so equal names used
        in different global contexts do not collide.
        """

        async def task_unique(name, kill_me=False):
            """Implement task.unique().

            If another task holds ``name``: with ``kill_me`` the current task
            cancels itself; otherwise the other task is cancelled, but only if
            it was started via run_coro(). The current task then claims
            ``name`` if it was started via run_coro().
            """
            name = f"{ctx.get_global_ctx_name()}.{name}"
            curr_task = asyncio.current_task()
            if name in cls.unique_name2task:
                task = cls.unique_name2task[name]
                if kill_me:
                    if task != curr_task:
                        #
                        # it seems we can't cancel ourselves, so we
                        # tell the reaper task to cancel us
                        #
                        cls.reaper_cancel(curr_task)
                        # wait to be canceled
                        await asyncio.sleep(100000)
                elif task != curr_task and task in cls.our_tasks:
                    # only cancel tasks if they are ones we started
                    cls.reaper_cancel(task)
            if curr_task in cls.our_tasks:
                if name in cls.unique_name2task:
                    # take the name away from its previous owner
                    task = cls.unique_name2task[name]
                    if task in cls.unique_task2name:
                        cls.unique_task2name[task].discard(name)
                cls.unique_name2task[name] = curr_task
                cls.unique_task2name.setdefault(curr_task, set()).add(name)

        return task_unique

    @classmethod
    async def user_task_cancel(cls, task=None):
        """Implement task.cancel().

        With no ``task``, the current task is cancelled and this call waits
        until the cancellation arrives.

        Raises:
            TypeError: if the task was not started via run_coro().
        """
        do_sleep = False
        if not task:
            task = asyncio.current_task()
            do_sleep = True
        if task not in cls.our_tasks:
            raise TypeError(f"{task} is not a user-started task")
        cls.reaper_cancel(task)
        if do_sleep:
            # wait to be canceled
            await asyncio.sleep(100000)

    @classmethod
    async def user_task_current_task(cls):
        """Implement task.current_task()."""
        return asyncio.current_task()

    @classmethod
    def task_name2id_factory(cls, ctx):
        """Define and return task.name2id() for this context."""

        def user_task_name2id(name=None):
            """Implement task.name2id().

            With no ``name``, return a dict of every unique task name in this
            global context (without the prefix) mapped to its task. Otherwise
            return the task holding ``name``.

            Raises:
                NameError: if ``name`` is not in use.
            """
            prefix = f"{ctx.get_global_ctx_name()}."
            if name is None:
                return {
                    task_name[len(prefix) :]: task_id
                    for task_name, task_id in cls.unique_name2task.items()
                    if task_name.startswith(prefix)
                }
            if prefix + name in cls.unique_name2task:
                return cls.unique_name2task[prefix + name]
            raise NameError(f"task name '{name}' is unknown")

        return user_task_name2id

    @classmethod
    async def user_task_wait(cls, aws, **kwargs):
        """Implement task.wait() as a pass-through to asyncio.wait()."""
        return await asyncio.wait(aws, **kwargs)

    @classmethod
    def user_task_remove_done_callback(cls, task, callback):
        """Implement task.remove_done_callback().

        Removing a callback that is not registered, or from a task that has
        no done callbacks at all, is a no-op.
        """
        entry = cls.task2cb.get(task)
        if entry is not None:
            entry["cb"].pop(callback, None)

    @classmethod
    def unique_name_used(cls, ctx, name):
        """Return whether the current unique name is in use."""
        name = f"{ctx.get_global_ctx_name()}.{name}"
        return name in cls.unique_name2task

    @classmethod
    def service_has_service(cls, domain, name):
        """Implement service.has_service()."""
        return cls.hass.services.has_service(domain, name)

    @classmethod
    def _pop_hass_args(cls, kwargs):
        """Split hass call options out of a service call's keyword arguments.

        ``context`` (exact type ``Context``), ``blocking`` and
        ``return_response`` (exact type ``bool``) are popped from ``kwargs``
        into the returned dict; everything else stays in ``kwargs`` as service
        data. Without a valid ``context``, the context stored for the current
        task (if any) is used.
        """
        curr_task = asyncio.current_task()
        hass_args = {}
        for keyword, typ, default in [
            ("context", [Context], cls.task2context.get(curr_task, None)),
            ("blocking", [bool], None),
            ("return_response", [bool], None),
        ]:
            if keyword in kwargs and type(kwargs[keyword]) in typ:
                hass_args[keyword] = kwargs.pop(keyword)
            elif default:
                hass_args[keyword] = default
        return hass_args

    @classmethod
    async def service_call(cls, domain, name, **kwargs):
        """Implement service.call().

        Returns whatever hass_services_async_call() returns (the service
        response when one is requested or required).
        """
        hass_args = cls._pop_hass_args(kwargs)
        return await cls.hass_services_async_call(domain, name, kwargs, **hass_args)

    @classmethod
    async def service_completions(cls, root):
        """Return possible completions of HASS services."""
        words = set()
        services = cls.hass.services.async_services()
        num_period = root.count(".")
        if num_period == 1:
            domain, svc_root = root.split(".")
            if domain in services:
                words |= {f"{domain}.{svc}" for svc in services[domain] if svc.lower().startswith(svc_root)}
        elif num_period == 0:
            words |= {domain for domain in services if domain.lower().startswith(root)}
        return words

    @classmethod
    async def func_completions(cls, root):
        """Return possible completions of functions."""
        funcs = {**cls.functions, **cls.ast_functions}
        return {name for name in funcs if name.lower().startswith(root)}

    @classmethod
    def register(cls, funcs):
        """Register functions to be available for calling."""
        cls.functions.update(funcs)

    @classmethod
    def register_ast(cls, funcs):
        """Register functions that need ast context to be available for calling."""
        cls.ast_functions.update(funcs)

    @classmethod
    def install_ast_funcs(cls, ast_ctx):
        """Install ast functions into the local symbol table.

        Each registered factory is called with ``ast_ctx`` and the results
        replace the context's local symbol table.
        """
        sym_table = {name: func(ast_ctx) for name, func in cls.ast_functions.items()}
        ast_ctx.set_local_sym_table(sym_table)

    @classmethod
    def get(cls, name):
        """Lookup a function locally and then as a service.

        Returns the registered function ``name`` if there is one. Otherwise,
        if ``name`` is ``domain.service`` and hass has that service, returns an
        async wrapper that calls it with keyword arguments as service data.
        Returns None when neither applies.
        """
        func = cls.functions.get(name, None)
        if func:
            return func
        name_parts = name.split(".")
        if len(name_parts) != 2:
            return None
        domain, service = name_parts
        if not cls.service_has_service(domain, service):
            return None

        def service_call_factory(domain, service):
            async def service_call(*args, **kwargs):
                if len(args) != 0:
                    raise TypeError(f"service {domain}.{service} takes only keyword arguments")
                hass_args = cls._pop_hass_args(kwargs)
                return await cls.hass_services_async_call(domain, service, kwargs, **hass_args)

            return service_call

        return service_call_factory(domain, service)

    @classmethod
    async def hass_services_async_call(cls, domain, service, kwargs, **hass_args):
        """Call a hass async service.

        On hass >= 2023.7 (``SERVICE_RESPONSE_ONLY`` is not None) a requested
        response implies ``blocking=True`` unless blocking was given, and
        services that only return responses are always called with
        ``return_response=True`` and (by default) ``blocking=True``. The
        service response, if any, is returned.
        """
        if SERVICE_RESPONSE_ONLY is None:
            # backwards compatibility < 2023.7
            await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
        else:
            # allow service responses >= 2023.7
            if (
                "return_response" in hass_args
                and hass_args["return_response"]
                and "blocking" not in hass_args
            ):
                hass_args["blocking"] = True
            elif (
                "return_response" not in hass_args
                and cls.hass.services.supports_response(domain, service) == SERVICE_RESPONSE_ONLY
            ):
                hass_args["return_response"] = True
                if "blocking" not in hass_args:
                    hass_args["blocking"] = True
            return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)

    @classmethod
    async def run_coro(cls, coro, ast_ctx=None):
        """Run coroutine task and update unique task on start and exit.

        The current task is recorded in ``our_tasks`` so other code knows it
        was started here; when ``ast_ctx`` is given it becomes the context for
        the task's done callbacks. Exceptions other than cancellation are
        logged and swallowed (the task then returns None). On exit the done
        callbacks run and all per-task bookkeeping is removed.
        """
        task = None
        try:
            task = asyncio.current_task()
            cls.our_tasks.add(task)
            if ast_ctx is not None:
                cls.task_done_callback_ctx(task, ast_ctx)
            result = await coro
            return result
        except asyncio.CancelledError:
            raise
        except Exception:
            _LOGGER.error("run_coro: got exception %s", traceback.format_exc(-1))
        finally:
            if task in cls.task2cb:
                callbacks = cls.task2cb[task]["cb"]
                # iterate over a snapshot: a callback may add or remove done
                # callbacks on this task, which would otherwise raise
                # "dictionary changed size during iteration"
                for callback in list(callbacks):
                    info = callbacks.get(callback)
                    if info is None:
                        # removed by an earlier callback
                        continue
                    cb_ctx, args, kwargs = info
                    await cb_ctx.call_func(callback, None, *args, **kwargs)
                    if cb_ctx.get_exception_obj():
                        cb_ctx.get_logger().error(cb_ctx.get_exception_long())
                        # stop running the remaining callbacks after a failure
                        break
            for name in cls.unique_task2name.pop(task, ()):
                # only release the name if it still belongs to this task
                if cls.unique_name2task.get(name) is task:
                    del cls.unique_name2task[name]
            cls.task2context.pop(task, None)
            cls.task2cb.pop(task, None)
            cls.our_tasks.discard(task)

    @classmethod
    def create_task(cls, coro, ast_ctx=None):
        """Create a new task on the hass loop that runs ``coro`` via run_coro()."""
        return cls.hass.loop.create_task(cls.run_coro(coro, ast_ctx=ast_ctx))

    @classmethod
    def service_register(
        cls, global_ctx_name, domain, service, callback, supports_response=SERVICE_RESPONSE_NONE
    ):
        """Register a new service callback.

        Registrations are reference counted per ``domain.service``.

        Raises:
            ValueError: if the service is already owned by a different
                global context.
        """
        key = f"{domain}.{service}"
        cls.service_cnt.setdefault(key, 0)
        owner = cls.service2global_ctx.setdefault(key, global_ctx_name)
        if owner != global_ctx_name:
            raise ValueError(
                f"{global_ctx_name}: can't register service {key}; already defined in {owner}"
            )
        cls.service_cnt[key] += 1
        if SERVICE_RESPONSE_ONLY is None:
            # backwards compatibility < 2023.7
            cls.hass.services.async_register(domain, service, callback)
        else:
            # allow service responses >= 2023.7
            cls.hass.services.async_register(domain, service, callback, supports_response=supports_response)

    @classmethod
    def service_remove(cls, global_ctx_name, domain, service):
        """Remove a service callback.

        Only decrements the reference count while other registrations remain;
        the hass service is removed when the last one goes away.
        """
        key = f"{domain}.{service}"
        if cls.service_cnt.get(key, 0) > 1:
            cls.service_cnt[key] -= 1
            return
        cls.service_cnt[key] = 0
        cls.hass.services.async_remove(domain, service)
        cls.service2global_ctx.pop(key, None)

    @classmethod
    def task_done_callback_ctx(cls, task, ast_ctx):
        """Set the ast_ctx for a task, which is needed for done callbacks.

        An existing entry (and its callbacks) is left untouched.
        """
        if task not in cls.task2cb or "ctx" not in cls.task2cb[task]:
            cls.task2cb[task] = {"ctx": ast_ctx, "cb": {}}

    @classmethod
    def task_add_done_callback(cls, task, ast_ctx, callback, *args, **kwargs):
        """Add a done callback to the given task.

        With ``ast_ctx`` None, the task's stored context is used. The task
        must already have an entry (see task_done_callback_ctx()), otherwise
        KeyError is raised.
        """
        if ast_ctx is None:
            ast_ctx = cls.task2cb[task]["ctx"]
        cls.task2cb[task]["cb"][callback] = [ast_ctx, args, kwargs]