Home Assistant Git Exporter

This commit is contained in:
root
2024-08-09 06:45:02 +02:00
parent 60abdd866c
commit 80fc630f5e
624 changed files with 27739 additions and 4497 deletions

View File

@@ -0,0 +1,682 @@
"""Component to allow running Python scripts."""
import asyncio
import glob
import json
import logging
import os
import time
import traceback
from typing import Any, Callable, Dict, List, Set, Union
import voluptuous as vol
from watchdog.events import DirModifiedEvent, FileSystemEvent, FileSystemEventHandler
import watchdog.observers
from homeassistant.config import async_hass_config_yaml
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
from homeassistant.const import (
EVENT_HOMEASSISTANT_STARTED,
EVENT_HOMEASSISTANT_STOP,
EVENT_STATE_CHANGED,
SERVICE_RELOAD,
)
from homeassistant.core import Config, Event as HAEvent, HomeAssistant, ServiceCall
from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.restore_state import DATA_RESTORE_STATE
from homeassistant.loader import bind_hass
from .const import (
CONF_ALLOW_ALL_IMPORTS,
CONF_HASS_IS_GLOBAL,
CONFIG_ENTRY,
CONFIG_ENTRY_OLD,
DOMAIN,
FOLDER,
LOGGER_PATH,
REQUIREMENTS_FILE,
SERVICE_JUPYTER_KERNEL_START,
UNSUB_LISTENERS,
WATCHDOG_TASK,
)
from .eval import AstEval
from .event import Event
from .function import Function
from .global_ctx import GlobalContext, GlobalContextMgr
from .jupyter_kernel import Kernel
from .mqtt import Mqtt
from .requirements import install_requirements
from .state import State, StateVal
from .trigger import TrigTime
from .webhook import Webhook
# Root logger for the pyscript integration; per-module loggers hang off LOGGER_PATH.
_LOGGER = logging.getLogger(LOGGER_PATH)

# Schema for the pyscript section of configuration.yaml; extra keys are
# allowed so app/global configuration can live under the same section.
PYSCRIPT_SCHEMA = vol.Schema(
    {
        vol.Optional(CONF_ALLOW_ALL_IMPORTS, default=False): cv.boolean,
        vol.Optional(CONF_HASS_IS_GLOBAL, default=False): cv.boolean,
    },
    extra=vol.ALLOW_EXTRA,
)

CONFIG_SCHEMA = vol.Schema({DOMAIN: PYSCRIPT_SCHEMA}, extra=vol.ALLOW_EXTRA)
async def async_setup(hass: HomeAssistant, config: Config) -> bool:
    """Set up the component: restore state, then import any yaml config via a flow."""
    await restore_state(hass)
    if DOMAIN in config:
        # Hand the yaml section to the import config flow so it becomes a
        # config entry.
        import_flow = hass.config_entries.flow.async_init(
            DOMAIN, context={"source": SOURCE_IMPORT}, data=config[DOMAIN]
        )
        hass.async_create_task(import_flow)
    return True
async def restore_state(hass: HomeAssistant) -> None:
    """Restore any persisted pyscript.* states from HA's restore-state data."""
    # NOTE: reaches into hass internals; should be re-implemented with RestoreEntity.
    stored_states = hass.data[DATA_RESTORE_STATE].last_states
    for entity_id, stored in stored_states.items():
        if not entity_id.startswith("pyscript."):
            continue
        saved = stored.state
        hass.states.async_set(entity_id, saved.state, saved.attributes)
async def update_yaml_config(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
    """Re-read the yaml config and report whether a full reload is required.

    Triggers a config-import flow when the yaml section no longer matches the
    config entry data, and returns True only when one of the global settings
    (hass_is_global, allow_all_imports) has changed since the previous call,
    since those settings affect every script.  Returns False otherwise,
    including when the yaml configuration cannot be read.
    """
    try:
        conf = await async_hass_config_yaml(hass)
    except HomeAssistantError as err:
        _LOGGER.error(err)
        # fix: function is declared -> bool but previously returned bare None here
        return False
    config = PYSCRIPT_SCHEMA(conf.get(DOMAIN, {}))
    #
    # If data in config doesn't match config entry, trigger a config import
    # so that the config entry can get updated
    #
    if config != config_entry.data:
        await hass.config_entries.flow.async_init(DOMAIN, context={"source": SOURCE_IMPORT}, data=config)
    #
    # if hass_is_global or allow_all_imports have changed, we need to reload all scripts
    # since they affect all scripts
    #
    config_save = {
        param: config_entry.data.get(param, False) for param in [CONF_HASS_IS_GLOBAL, CONF_ALLOW_ALL_IMPORTS]
    }
    if DOMAIN not in hass.data:
        hass.data.setdefault(DOMAIN, {})
    if CONFIG_ENTRY_OLD in hass.data[DOMAIN]:
        old_entry = hass.data[DOMAIN][CONFIG_ENTRY_OLD]
        hass.data[DOMAIN][CONFIG_ENTRY_OLD] = config_save
        # compare the previous snapshot of the global settings against the
        # current entry data
        for param in [CONF_HASS_IS_GLOBAL, CONF_ALLOW_ALL_IMPORTS]:
            if old_entry.get(param, False) != config_entry.data.get(param, False):
                return True
    hass.data[DOMAIN][CONFIG_ENTRY_OLD] = config_save
    return False
def start_global_contexts(global_ctx_only: str | None = None) -> None:
    """Start all the file and apps global contexts.

    Args:
        global_ctx_only: if None or "*", start every file/apps/scripts
            context; otherwise start only the named context and its
            children (names beginning with "<global_ctx_only>.").
    """
    start_list = []
    for global_ctx_name, global_ctx in GlobalContextMgr.items():
        idx = global_ctx_name.find(".")
        if idx < 0 or global_ctx_name[0:idx] not in {"file", "apps", "scripts"}:
            continue
        if global_ctx_only is not None and global_ctx_only != "*":
            if global_ctx_name != global_ctx_only and not global_ctx_name.startswith(global_ctx_only + "."):
                continue
        global_ctx.set_auto_start(True)
        start_list.append(global_ctx)
    # second pass: start each collected context
    for global_ctx in start_list:
        global_ctx.start()
async def watchdog_start(
    hass: HomeAssistant, pyscript_folder: str, reload_scripts_handler: Callable[[None], None]
) -> None:
    """Start watchdog thread to look for changed files in pyscript_folder.

    Only one watchdog task is created per hass instance (keyed by
    WATCHDOG_TASK in hass.data[DOMAIN]).
    """
    if WATCHDOG_TASK in hass.data[DOMAIN]:
        return

    class WatchDogHandler(FileSystemEventHandler):
        """Class for handling watchdog events."""

        def __init__(
            self, watchdog_q: asyncio.Queue, observer: watchdog.observers.Observer, path: str
        ) -> None:
            self.watchdog_q = watchdog_q
            self._observer = observer
            self._observer.schedule(self, path, recursive=True)
            # Defer starting the observer until HA has fully started; if HA
            # is already running, start it immediately.
            if not hass.is_running:
                hass.bus.listen_once(EVENT_HOMEASSISTANT_STARTED, self.startup)
            else:
                self.startup(None)
            hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, self.shutdown)
            _LOGGER.debug("watchdog init path=%s", path)

        # NOTE(review): 'Event' in these annotations resolves to pyscript's
        # Event class (imported from .event), not an HA bus event — confirm
        # whether HAEvent was intended; annotation-only, no runtime effect.
        def startup(self, event: Event | None) -> None:
            """Start the observer."""
            _LOGGER.debug("watchdog startup")
            self._observer.start()

        def shutdown(self, event: Event | None) -> None:
            """Stop the observer."""
            self._observer.stop()
            self._observer.join()
            _LOGGER.debug("watchdog shutdown")

        def process(self, event: FileSystemEvent) -> None:
            """Send watchdog events to main loop task."""
            # Runs in the watchdog thread; hand the event to the event loop
            # thread-safely.
            _LOGGER.debug("watchdog process(%s)", event)
            hass.loop.call_soon_threadsafe(self.watchdog_q.put_nowait, event)

        def on_modified(self, event: FileSystemEvent) -> None:
            """File modified."""
            self.process(event)

        def on_moved(self, event: FileSystemEvent) -> None:
            """File moved."""
            self.process(event)

        def on_created(self, event: FileSystemEvent) -> None:
            """File created."""
            self.process(event)

        def on_deleted(self, event: FileSystemEvent) -> None:
            """File deleted."""
            self.process(event)

    async def task_watchdog(watchdog_q: asyncio.Queue) -> None:
        """Consume queued file-system events and invoke the reload handler."""

        def check_event(event, do_reload: bool) -> bool:
            """Check if event should trigger a reload."""
            if event.is_directory:
                # don't reload if it's just a directory modified
                if isinstance(event, DirModifiedEvent):
                    return do_reload
                return True
            # only reload if it's a script, yaml, or requirements.txt file
            for valid_suffix in [".py", ".yaml", "/" + REQUIREMENTS_FILE]:
                if event.src_path.endswith(valid_suffix):
                    return True
            return do_reload

        while True:
            try:
                #
                # since some file/dir changes create multiple events, we consume all
                # events in a small window; first wait indefinitely for next event
                #
                do_reload = check_event(await watchdog_q.get(), False)
                #
                # now consume all additional events with 50ms timeout or 500ms elapsed
                #
                t_start = time.monotonic()
                while time.monotonic() - t_start < 0.5:
                    try:
                        do_reload = check_event(
                            await asyncio.wait_for(watchdog_q.get(), timeout=0.05), do_reload
                        )
                    except asyncio.TimeoutError:
                        break
                if do_reload:
                    await reload_scripts_handler(None)
            except asyncio.CancelledError:
                raise
            except Exception:
                _LOGGER.error("task_watchdog: got exception %s", traceback.format_exc(-1))

    watchdog_q = asyncio.Queue(0)
    observer = watchdog.observers.Observer()
    if observer is not None:
        # don't run watchdog when we are testing (Observer() patches to None)
        hass.data[DOMAIN][WATCHDOG_TASK] = Function.create_task(task_watchdog(watchdog_q))
        await hass.async_add_executor_job(WatchDogHandler, watchdog_q, observer, pyscript_folder)
        _LOGGER.debug("watchdog started job and task folder=%s", pyscript_folder)
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
    """Initialize the pyscript config entry.

    Initializes all pyscript singletons, loads the scripts, registers the
    reload and jupyter_kernel_start services, and hooks HA start/stop events.
    """
    global_ctx_only = None
    doing_reload = False
    # Function.hass is set by a prior Function.init() call, so it being
    # truthy means this is a reload rather than first-time setup.
    if Function.hass:
        #
        # reload yaml if this isn't the first time (ie, on reload)
        #
        doing_reload = True
        if await update_yaml_config(hass, config_entry):
            global_ctx_only = "*"

    # initialize the pyscript singletons with the hass instance
    Function.init(hass)
    Event.init(hass)
    Mqtt.init(hass)
    TrigTime.init(hass)
    State.init(hass)
    Webhook.init(hass)
    State.register_functions()
    GlobalContextMgr.init()

    pyscript_folder = hass.config.path(FOLDER)
    if not await hass.async_add_executor_job(os.path.isdir, pyscript_folder):
        _LOGGER.debug("Folder %s not found in configuration folder, creating it", FOLDER)
        await hass.async_add_executor_job(os.makedirs, pyscript_folder)

    hass.data.setdefault(DOMAIN, {})
    hass.data[DOMAIN][CONFIG_ENTRY] = config_entry
    hass.data[DOMAIN][UNSUB_LISTENERS] = []

    State.set_pyscript_config(config_entry.data)

    await install_requirements(hass, config_entry, pyscript_folder)
    await load_scripts(hass, config_entry.data, global_ctx_only=global_ctx_only)

    async def reload_scripts_handler(call: ServiceCall) -> None:
        """Handle reload service calls."""
        _LOGGER.debug("reload: yaml, reloading scripts, and restarting")
        # service data may name a single global context; a global-settings
        # change forces "*" (reload everything)
        global_ctx_only = call.data.get("global_ctx", None) if call else None
        if await update_yaml_config(hass, config_entry):
            global_ctx_only = "*"
        State.set_pyscript_config(config_entry.data)
        await State.get_service_params()
        await install_requirements(hass, config_entry, pyscript_folder)
        await load_scripts(hass, config_entry.data, global_ctx_only=global_ctx_only)
        start_global_contexts(global_ctx_only=global_ctx_only)

    hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_scripts_handler)

    async def jupyter_kernel_start(call: ServiceCall) -> None:
        """Handle Jupyter kernel start call."""
        _LOGGER.debug("service call to jupyter_kernel_start: %s", call.data)
        global_ctx_name = GlobalContextMgr.new_name("jupyter_")
        global_ctx = GlobalContext(
            global_ctx_name, global_sym_table={"__name__": global_ctx_name}, manager=GlobalContextMgr
        )
        global_ctx.set_auto_start(True)
        GlobalContextMgr.set(global_ctx_name, global_ctx)
        ast_ctx = AstEval(global_ctx_name, global_ctx)
        Function.install_ast_funcs(ast_ctx)
        kernel = Kernel(call.data, ast_ctx, global_ctx, global_ctx_name)
        await kernel.session_start()
        # publish the kernel's ports in the state variable the caller named
        hass.states.async_set(call.data["state_var"], json.dumps(kernel.get_ports()))

        def state_var_remove():
            hass.states.async_remove(call.data["state_var"])

        kernel.set_session_cleanup_callback(state_var_remove)

    hass.services.async_register(DOMAIN, SERVICE_JUPYTER_KERNEL_START, jupyter_kernel_start)

    async def state_changed(event: HAEvent) -> None:
        """Forward HA state-change events to pyscript's State machinery."""
        var_name = event.data["entity_id"]
        if event.data.get("new_state", None):
            new_val = StateVal(event.data["new_state"])
        else:
            # state variable has been deleted
            new_val = None
        if event.data.get("old_state", None):
            old_val = StateVal(event.data["old_state"])
        else:
            # no previous state
            old_val = None
        new_vars = {var_name: new_val, f"{var_name}.old": old_val}
        func_args = {
            "trigger_type": "state",
            "var_name": var_name,
            "value": new_val,
            "old_value": old_val,
            "context": event.context,
        }
        await State.update(new_vars, func_args)

    async def hass_started(event: HAEvent) -> None:
        """On HA start: subscribe to state changes and start global contexts."""
        _LOGGER.debug("adding state changed listener and starting global contexts")
        await State.get_service_params()
        hass.data[DOMAIN][UNSUB_LISTENERS].append(hass.bus.async_listen(EVENT_STATE_CHANGED, state_changed))
        start_global_contexts()

    async def hass_stop(event: HAEvent) -> None:
        """On HA stop: cancel the watchdog, unload scripts, stop helper tasks."""
        if WATCHDOG_TASK in hass.data[DOMAIN]:
            Function.reaper_cancel(hass.data[DOMAIN][WATCHDOG_TASK])
            del hass.data[DOMAIN][WATCHDOG_TASK]
        _LOGGER.debug("stopping global contexts")
        await unload_scripts(unload_all=True)
        # sync with waiter, and then tell waiter and reaper tasks to exit
        await Function.waiter_sync()
        await Function.waiter_stop()
        await Function.reaper_stop()

    # Store callbacks to event listeners so we can unsubscribe on unload
    hass.data[DOMAIN][UNSUB_LISTENERS].append(
        hass.bus.async_listen(EVENT_HOMEASSISTANT_STARTED, hass_started)
    )
    hass.data[DOMAIN][UNSUB_LISTENERS].append(hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, hass_stop))

    await watchdog_start(hass, pyscript_folder, reload_scripts_handler)

    if doing_reload:
        # on reload, hass_started will not fire again, so start contexts here
        start_global_contexts(global_ctx_only="*")

    return True
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
    """Unload a config entry."""
    _LOGGER.info("Unloading all scripts")
    await unload_scripts(unload_all=True)

    # Unsubscribe every event listener registered during setup.
    unsub_callbacks = hass.data[DOMAIN][UNSUB_LISTENERS]
    hass.data[DOMAIN][UNSUB_LISTENERS] = []
    for unsub in unsub_callbacks:
        unsub()

    # Sync with the waiter task, then tell the waiter and reaper tasks to exit.
    await Function.waiter_sync()
    await Function.waiter_stop()
    await Function.reaper_stop()
    return True
async def unload_scripts(global_ctx_only: str | None = None, unload_all: bool = False) -> None:
    """Unload all scripts from GlobalContextMgr with given name prefixes.

    Args:
        global_ctx_only: if given, only unload the named context and its
            children (names beginning with "<global_ctx_only>.").
        unload_all: if True, unload every context regardless of name prefix.
    """
    ctx_delete = {}
    for global_ctx_name, global_ctx in GlobalContextMgr.items():
        if not unload_all:
            # only file/apps/modules/scripts contexts are script contexts
            idx = global_ctx_name.find(".")
            if idx < 0 or global_ctx_name[0:idx] not in {"file", "apps", "modules", "scripts"}:
                continue
        if global_ctx_only is not None:
            if global_ctx_name != global_ctx_only and not global_ctx_name.startswith(global_ctx_only + "."):
                continue
        global_ctx.stop()
        ctx_delete[global_ctx_name] = global_ctx
    for global_ctx_name, global_ctx in ctx_delete.items():
        GlobalContextMgr.delete(global_ctx_name)
    # sync with the waiter task so stopped tasks have completed
    await Function.waiter_sync()
@bind_hass
async def load_scripts(hass: HomeAssistant, config_data: Dict[str, Any], global_ctx_only: str = None):
    """Load all python scripts in FOLDER.

    Determines which global contexts must be deleted and which source files
    must be (re)loaded, based on global_ctx_only ("*" means everything),
    changed source/config/mtime, and transitive module imports.
    """

    class SourceFile:
        """Class for information about a source file."""

        def __init__(
            self,
            global_ctx_name=None,
            file_path=None,
            rel_path=None,
            rel_import_path=None,
            fq_mod_name=None,
            check_config=None,
            app_config=None,
            source=None,
            mtime=None,
            autoload=None,
        ):
            self.global_ctx_name = global_ctx_name
            self.file_path = file_path
            self.rel_path = rel_path
            self.rel_import_path = rel_import_path
            self.fq_mod_name = fq_mod_name
            self.check_config = check_config
            self.app_config = app_config
            self.source = source
            self.mtime = mtime
            self.autoload = autoload
            # set later when this file (or something it depends on) must be (re)loaded
            self.force = False

    pyscript_dir = hass.config.path(FOLDER)

    def glob_read_files(
        load_paths: List[Set[Union[str, bool]]], apps_config: Dict[str, Any]
    ) -> Dict[str, SourceFile]:
        """Expand globs and read all the source files."""
        ctx2source = {}
        for path, match, check_config, autoload in load_paths:
            for this_path in sorted(glob.glob(os.path.join(pyscript_dir, path, match), recursive=True)):
                rel_import_path = None
                rel_path = this_path
                # compute the path relative to the pyscript folder
                if rel_path.startswith(pyscript_dir):
                    rel_path = rel_path[len(pyscript_dir) :]
                if rel_path.startswith("/"):
                    rel_path = rel_path[1:]
                if rel_path[0] == "#" or rel_path.find("/#") >= 0:
                    # skip "commented" files and directories
                    continue
                # strip the trailing ".py" to get the module name
                mod_name = rel_path[0:-3]
                if mod_name.endswith("/__init__"):
                    rel_import_path = mod_name
                    mod_name = mod_name[: -len("/__init__")]
                mod_name = mod_name.replace("/", ".")
                if path == "":
                    global_ctx_name = f"file.{mod_name}"
                    fq_mod_name = mod_name
                else:
                    fq_mod_name = global_ctx_name = mod_name
                    i = fq_mod_name.find(".")
                    if i >= 0:
                        fq_mod_name = fq_mod_name[i + 1 :]
                app_config = None
                if global_ctx_name in ctx2source:
                    # the globs result in apps/APP/__init__.py matching twice, so skip the 2nd time
                    # also skip apps/APP.py if apps/APP/__init__.py is present
                    continue
                if check_config:
                    # apps are only loaded if an "apps" config section names them
                    app_name = fq_mod_name
                    i = app_name.find(".")
                    if i >= 0:
                        app_name = app_name[0:i]
                    if not isinstance(apps_config, dict) or app_name not in apps_config:
                        _LOGGER.debug(
                            "load_scripts: skipping %s (app_name=%s) because config not present",
                            this_path,
                            app_name,
                        )
                        continue
                    app_config = apps_config[app_name]
                try:
                    with open(this_path, encoding="utf-8") as file_desc:
                        source = file_desc.read()
                    mtime = os.path.getmtime(this_path)
                except Exception as exc:
                    _LOGGER.error("load_scripts: skipping %s due to exception %s", this_path, exc)
                    continue
                ctx2source[global_ctx_name] = SourceFile(
                    global_ctx_name=global_ctx_name,
                    file_path=this_path,
                    rel_path=rel_path,
                    rel_import_path=rel_import_path,
                    fq_mod_name=fq_mod_name,
                    check_config=check_config,
                    app_config=app_config,
                    source=source,
                    mtime=mtime,
                    autoload=autoload,
                )
        return ctx2source

    load_paths = [
        # path, glob, check_config, autoload
        ["", "*.py", False, True],
        ["apps", "*/__init__.py", True, True],
        ["apps", "*.py", True, True],
        ["apps", "*/**/*.py", False, False],
        ["modules", "*/__init__.py", False, False],
        ["modules", "*.py", False, False],
        ["modules", "*/**/*.py", False, False],
        ["scripts", "**/*.py", False, True],
    ]

    #
    # get current global contexts
    #
    ctx_all = {}
    for global_ctx_name, global_ctx in GlobalContextMgr.items():
        idx = global_ctx_name.find(".")
        if idx < 0 or global_ctx_name[0:idx] not in {"file", "apps", "modules", "scripts"}:
            continue
        ctx_all[global_ctx_name] = global_ctx

    #
    # get list and contents of all source files
    #
    apps_config = config_data.get("apps", None)
    ctx2files = await hass.async_add_executor_job(glob_read_files, load_paths, apps_config)

    #
    # figure out what to reload based on global_ctx_only and what's changed
    #
    ctx_delete = set()
    if global_ctx_only is not None and global_ctx_only != "*":
        if global_ctx_only not in ctx_all and global_ctx_only not in ctx2files:
            _LOGGER.error("pyscript.reload: no global context '%s' to reload", global_ctx_only)
            return
        if global_ctx_only not in ctx2files:
            ctx_delete.add(global_ctx_only)
        else:
            ctx2files[global_ctx_only].force = True
    elif global_ctx_only == "*":
        ctx_delete = set(ctx_all.keys())
        for _, src_info in ctx2files.items():
            src_info.force = True
    else:
        # delete all global_ctxs that aren't present in current files
        for global_ctx_name, global_ctx in ctx_all.items():
            if global_ctx_name not in ctx2files:
                ctx_delete.add(global_ctx_name)
        # delete all global_ctxs that have changed source or mtime
        for global_ctx_name, src_info in ctx2files.items():
            if global_ctx_name in ctx_all:
                ctx = ctx_all[global_ctx_name]
                if (
                    src_info.source != ctx.get_source()
                    or src_info.app_config != ctx.get_app_config()
                    or src_info.mtime != ctx.get_mtime()
                ):
                    ctx_delete.add(global_ctx_name)
                    src_info.force = True
            else:
                src_info.force = src_info.autoload

    #
    # force reload if any file uses a module that is being reloaded by
    # recursively following each import; first find which modules are
    # being reloaded
    #
    will_reload = set()
    for global_ctx_name, src_info in ctx2files.items():
        if global_ctx_name.startswith("modules.") and (global_ctx_name in ctx_delete or src_info.force):
            parts = global_ctx_name.split(".")
            root = f"{parts[0]}.{parts[1]}"
            will_reload.add(root)

    if len(will_reload) > 0:

        def import_recurse(ctx_name, visited, ctx2imports):
            """Compute the transitive set of contexts imported by ctx_name."""
            if ctx_name in visited or ctx_name in ctx2imports:
                return ctx2imports.get(ctx_name, set())
            visited.add(ctx_name)
            ctx = GlobalContextMgr.get(ctx_name)
            if not ctx:
                return set()
            ctx2imports[ctx_name] = set()
            for imp_name in ctx.get_imports():
                ctx2imports[ctx_name].add(imp_name)
                ctx2imports[ctx_name].update(import_recurse(imp_name, visited, ctx2imports))
            return ctx2imports[ctx_name]

        ctx2imports = {}
        for global_ctx_name, global_ctx in ctx_all.items():
            if global_ctx_name not in ctx2imports:
                visited = set()
                import_recurse(global_ctx_name, visited, ctx2imports)
            for mod_name in ctx2imports.get(global_ctx_name, set()):
                parts = mod_name.split(".")
                root = f"{parts[0]}.{parts[1]}"
                if root in will_reload:
                    ctx_delete.add(global_ctx_name)
                    if global_ctx_name in ctx2files:
                        ctx2files[global_ctx_name].force = True

    #
    # if any file in an app or module has changed, then reload just the top-level
    # __init__.py or module/app .py file, and delete everything else
    #
    done = set()
    for global_ctx_name, src_info in ctx2files.items():
        if not src_info.force:
            continue
        if not global_ctx_name.startswith("apps.") and not global_ctx_name.startswith("modules."):
            continue
        parts = global_ctx_name.split(".")
        root = f"{parts[0]}.{parts[1]}"
        if root in done:
            continue
        pkg_path = f"{parts[0]}/{parts[1]}/__init__.py"
        mod_path = f"{parts[0]}/{parts[1]}.py"
        for ctx_name, this_src_info in ctx2files.items():
            if ctx_name == root or ctx_name.startswith(f"{root}."):
                # only the top-level package/module file is force-loaded;
                # everything below it is just deleted and re-imported lazily
                if this_src_info.rel_path in {pkg_path, mod_path}:
                    this_src_info.force = True
                else:
                    this_src_info.force = False
                ctx_delete.add(ctx_name)
        done.add(root)

    #
    # delete contexts that are no longer needed or will be reloaded
    #
    for global_ctx_name in ctx_delete:
        if global_ctx_name in ctx_all:
            global_ctx = ctx_all[global_ctx_name]
            global_ctx.stop()
            if global_ctx_name not in ctx2files or not ctx2files[global_ctx_name].autoload:
                _LOGGER.info("Unloaded %s", global_ctx.get_file_path())
            GlobalContextMgr.delete(global_ctx_name)

    # sync with the waiter task so stopped contexts' tasks have completed
    await Function.waiter_sync()

    #
    # now load the requested files, and files that depend on loaded files
    #
    for global_ctx_name, src_info in sorted(ctx2files.items()):
        if not src_info.autoload or not src_info.force:
            continue
        global_ctx = GlobalContext(
            src_info.global_ctx_name,
            global_sym_table={"__name__": src_info.fq_mod_name},
            manager=GlobalContextMgr,
            rel_import_path=src_info.rel_import_path,
            app_config=src_info.app_config,
            source=src_info.source,
            mtime=src_info.mtime,
        )
        reload = src_info.global_ctx_name in ctx_delete
        await GlobalContextMgr.load_file(
            global_ctx, src_info.file_path, source=src_info.source, reload=reload
        )

View File

@@ -0,0 +1,139 @@
"""Config flow for pyscript."""
import json
from typing import Any, Dict
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
from homeassistant.core import callback
from .const import CONF_ALLOW_ALL_IMPORTS, CONF_HASS_IS_GLOBAL, CONF_INSTALLED_PACKAGES, DOMAIN
# Boolean options that can be set via yaml or the options UI flow.
CONF_BOOL_ALL = {CONF_ALLOW_ALL_IMPORTS, CONF_HASS_IS_GLOBAL}

# Schema for the user-initiated config flow; extra keys are allowed.
PYSCRIPT_SCHEMA = vol.Schema(
    {
        vol.Optional(CONF_ALLOW_ALL_IMPORTS, default=False): bool,
        vol.Optional(CONF_HASS_IS_GLOBAL, default=False): bool,
    },
    extra=vol.ALLOW_EXTRA,
)
class PyscriptOptionsConfigFlow(config_entries.OptionsFlow):
    """Handle a pyscript options flow."""

    def __init__(self, config_entry: ConfigEntry) -> None:
        """Initialize pyscript options flow."""
        self.config_entry = config_entry
        # Whether the next informational step should display its form (True)
        # or silently finish the flow (False).
        self._show_form = False

    async def async_step_init(self, user_input: Dict[str, Any] = None) -> Dict[str, Any]:
        """Manage the pyscript options."""
        # Entries imported from yaml cannot be edited via the UI.
        if self.config_entry.source == SOURCE_IMPORT:
            self._show_form = True
            return await self.async_step_no_ui_configuration_allowed()

        if user_input is None:
            # First pass: show the form pre-filled with the current values.
            return self.async_show_form(
                step_id="init",
                data_schema=vol.Schema(
                    {
                        vol.Optional(name, default=self.config_entry.data.get(name, False)): bool
                        for name in CONF_BOOL_ALL
                    },
                    extra=vol.ALLOW_EXTRA,
                ),
            )

        if any(
            name not in self.config_entry.data or user_input[name] != self.config_entry.data[name]
            for name in CONF_BOOL_ALL
        ):
            # At least one boolean option changed: persist the update.
            updated_data = self.config_entry.data.copy()
            updated_data.update(user_input)
            self.hass.config_entries.async_update_entry(entry=self.config_entry, data=updated_data)
            return self.async_create_entry(title="", data={})

        # Nothing changed: show the "no update" notice.
        self._show_form = True
        return await self.async_step_no_update()

    async def async_step_no_ui_configuration_allowed(
        self, user_input: Dict[str, Any] = None
    ) -> Dict[str, Any]:
        """Tell user no UI configuration is allowed."""
        if self._show_form:
            self._show_form = False
            return self.async_show_form(step_id="no_ui_configuration_allowed", data_schema=vol.Schema({}))
        return self.async_create_entry(title="", data={})

    async def async_step_no_update(self, user_input: Dict[str, Any] = None) -> Dict[str, Any]:
        """Tell user no update to process."""
        if self._show_form:
            self._show_form = False
            return self.async_show_form(step_id="no_update", data_schema=vol.Schema({}))
        return self.async_create_entry(title="", data={})
class PyscriptConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
    """Handle a pyscript config flow."""

    VERSION = 1
    CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH

    @staticmethod
    @callback
    def async_get_options_flow(config_entry: ConfigEntry) -> PyscriptOptionsConfigFlow:
        """Get the options flow for this handler."""
        return PyscriptOptionsConfigFlow(config_entry)

    async def async_step_user(self, user_input: Dict[str, Any] = None) -> Dict[str, Any]:
        """Handle a flow initialized by the user."""
        if user_input is not None:
            # only a single pyscript instance is supported
            if len(self.hass.config_entries.async_entries(DOMAIN)) > 0:
                return self.async_abort(reason="single_instance_allowed")
            await self.async_set_unique_id(DOMAIN)
            return self.async_create_entry(title=DOMAIN, data=user_input)
        return self.async_show_form(step_id="user", data_schema=PYSCRIPT_SCHEMA)

    async def async_step_import(self, import_config: Dict[str, Any] = None) -> Dict[str, Any]:
        """Import a config entry from configuration.yaml."""
        # Convert OrderedDict to dict
        import_config = json.loads(json.dumps(import_config))

        # Check if import config entry matches any existing config entries
        # so we can update it if necessary
        entries = self.hass.config_entries.async_entries(DOMAIN)
        if entries:
            entry = entries[0]
            updated_data = entry.data.copy()

            # Update values for all keys, excluding `allow_all_imports` for entries
            # set up through the UI.
            for key, val in import_config.items():
                if entry.source == SOURCE_IMPORT or key not in CONF_BOOL_ALL:
                    updated_data[key] = val

            # Remove values for all keys in entry.data that are not in the imported config,
            # excluding `allow_all_imports` for entries set up through the UI.
            for key in entry.data:
                if (
                    (entry.source == SOURCE_IMPORT or key not in CONF_BOOL_ALL)
                    and key != CONF_INSTALLED_PACKAGES
                    and key not in import_config
                ):
                    updated_data.pop(key)

            # Update and reload entry if data needs to be updated
            if updated_data != entry.data:
                self.hass.config_entries.async_update_entry(entry=entry, data=updated_data)
                return self.async_abort(reason="updated_entry")

            return self.async_abort(reason="already_configured")

        return await self.async_step_user(user_input=import_config)

View File

@@ -0,0 +1,63 @@
"""Define pyscript-wide constants."""
#
# 2023.7 supports service response; handle older versions by defaulting enum
# Should eventually deprecate this and just use SupportsResponse import
#
try:
from homeassistant.core import SupportsResponse
SERVICE_RESPONSE_NONE = SupportsResponse.NONE
SERVICE_RESPONSE_OPTIONAL = SupportsResponse.OPTIONAL
SERVICE_RESPONSE_ONLY = SupportsResponse.ONLY
except ImportError:
SERVICE_RESPONSE_NONE = None
SERVICE_RESPONSE_OPTIONAL = None
SERVICE_RESPONSE_ONLY = None
DOMAIN = "pyscript"
CONFIG_ENTRY = "config_entry"
CONFIG_ENTRY_OLD = "config_entry_old"
UNSUB_LISTENERS = "unsub_listeners"
FOLDER = "pyscript"
UNPINNED_VERSION = "_unpinned_version"
ATTR_INSTALLED_VERSION = "installed_version"
ATTR_SOURCES = "sources"
ATTR_VERSION = "version"
CONF_ALLOW_ALL_IMPORTS = "allow_all_imports"
CONF_HASS_IS_GLOBAL = "hass_is_global"
CONF_INSTALLED_PACKAGES = "_installed_packages"
SERVICE_JUPYTER_KERNEL_START = "jupyter_kernel_start"
LOGGER_PATH = "custom_components.pyscript"
REQUIREMENTS_FILE = "requirements.txt"
REQUIREMENTS_PATHS = ("", "apps/*", "modules/*", "scripts/**")
WATCHDOG_TASK = "watch_dog_task"
ALLOWED_IMPORTS = {
"black",
"cmath",
"datetime",
"decimal",
"fractions",
"functools",
"homeassistant.const",
"isort",
"json",
"math",
"number",
"random",
"re",
"statistics",
"string",
"time",
"voluptuous",
}

View File

@@ -0,0 +1,19 @@
"""Entity Classes."""
from homeassistant.const import STATE_UNKNOWN
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import StateType
class PyscriptEntity(RestoreEntity):
    """Generic Pyscript Entity."""

    # State and attributes are kept on HA's standard _attr_* members so the
    # base Entity/RestoreEntity machinery reports them.
    _attr_extra_state_attributes: dict
    _attr_state: StateType = STATE_UNKNOWN

    def set_state(self, state):
        """Set the state."""
        self._attr_state = state

    def set_attributes(self, attributes):
        """Set Attributes."""
        self._attr_extra_state_attributes = attributes

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,76 @@
"""Handles event firing and notification."""
import logging
from .const import LOGGER_PATH
_LOGGER = logging.getLogger(LOGGER_PATH + ".event")
class Event:
    """Define event functions."""

    #
    # Global hass instance (set by init())
    #
    hass = None

    #
    # Queues to notify, keyed by event type, and the listener-removal
    # callbacks returned by hass.bus.async_listen, also keyed by event type
    #
    notify = {}
    notify_remove = {}

    def __init__(self):
        """Warn on Event instantiation."""
        _LOGGER.error("Event class is not meant to be instantiated")

    @classmethod
    def init(cls, hass):
        """Initialize Event with the hass instance."""
        cls.hass = hass

    @classmethod
    async def event_listener(cls, event):
        """Listen callback for given event which updates any notifications."""
        # event.data entries override the standard keys, matching dict.update order
        func_args = {
            "trigger_type": "event",
            "event_type": event.event_type,
            "context": event.context,
            **event.data,
        }
        await cls.update(event.event_type, func_args)

    @classmethod
    def notify_add(cls, event_type, queue):
        """Register to notify for events of given type to be sent to queue."""
        queues = cls.notify.get(event_type)
        if queues is None:
            # first subscriber for this event type: install a bus listener
            queues = cls.notify[event_type] = set()
            _LOGGER.debug("event.notify_add(%s) -> adding event listener", event_type)
            cls.notify_remove[event_type] = cls.hass.bus.async_listen(event_type, cls.event_listener)
        queues.add(queue)

    @classmethod
    def notify_del(cls, event_type, queue):
        """Unregister to notify for events of given type for given queue."""
        queues = cls.notify.get(event_type)
        if queues is None or queue not in queues:
            return
        queues.discard(queue)
        if not queues:
            # last subscriber gone: remove the bus listener and bookkeeping
            cls.notify_remove[event_type]()
            _LOGGER.debug("event.notify_del(%s) -> removing event listener", event_type)
            del cls.notify[event_type]
            del cls.notify_remove[event_type]

    @classmethod
    async def update(cls, event_type, func_args):
        """Deliver all notifications for an event of the given type."""
        _LOGGER.debug("event.update(%s, %s)", event_type, func_args)
        for queue in cls.notify.get(event_type, ()):
            await queue.put(["event", func_args.copy()])

View File

@@ -0,0 +1,519 @@
"""Function call handling."""
import asyncio
import logging
import traceback
from homeassistant.core import Context
from .const import LOGGER_PATH, SERVICE_RESPONSE_NONE, SERVICE_RESPONSE_ONLY
_LOGGER = logging.getLogger(LOGGER_PATH + ".function")
class Function:
"""Define function handler functions."""
#
# Global hass instance
#
hass = None
#
# Mappings of tasks ids <-> task names
#
unique_task2name = {}
unique_name2task = {}
#
# Mappings of task id to hass contexts
task2context = {}
#
# Set of tasks that are running
#
our_tasks = set()
#
# Done callbacks for each task
#
task2cb = {}
#
# initial list of available functions
#
functions = {}
#
# Functions that take the AstEval context as a first argument,
# which is needed by a handful of special functions that need the
# ast context
#
ast_functions = {}
#
# task id of the task that cancels and waits for other tasks,
#
task_reaper = None
task_reaper_q = None
#
# task id of the task that awaits for coros (used by shutdown triggers)
#
task_waiter = None
task_waiter_q = None
#
# reference counting for service registrations; the new @service trigger
# registers the service call before the old one is removed, so we only
# remove the service registration when the reference count goes to zero
#
service_cnt = {}
#
# save the global_ctx name where a service is registered so we can raise
# an exception if it gets registered by a different global_ctx.
#
service2global_ctx = {}
def __init__(self):
"""Warn on Function instantiation."""
_LOGGER.error("Function class is not meant to be instantiated")
@classmethod
def init(cls, hass):
    """Initialize Function: save hass, register built-ins, start helper tasks."""
    cls.hass = hass
    # regular built-in functions, called with their evaluated arguments
    cls.functions.update(
        {
            "event.fire": cls.event_fire,
            "service.call": cls.service_call,
            "service.has_service": cls.service_has_service,
            "task.cancel": cls.user_task_cancel,
            "task.current_task": cls.user_task_current_task,
            "task.remove_done_callback": cls.user_task_remove_done_callback,
            "task.sleep": cls.async_sleep,
            "task.wait": cls.user_task_wait,
        }
    )
    # factories: each takes the AstEval context and returns the callable
    cls.ast_functions.update(
        {
            "log.debug": lambda ast_ctx: ast_ctx.get_logger().debug,
            "log.error": lambda ast_ctx: ast_ctx.get_logger().error,
            "log.info": lambda ast_ctx: ast_ctx.get_logger().info,
            "log.warning": lambda ast_ctx: ast_ctx.get_logger().warning,
            "print": lambda ast_ctx: ast_ctx.get_logger().debug,
            "task.name2id": cls.task_name2id_factory,
            "task.unique": cls.task_unique_factory,
        }
    )

    #
    # start a task which is a reaper for canceled tasks, since some functions
    # like TrigInfo.stop() can't be async (it's called from a __del__ method)
    #
    async def task_reaper(reaper_q):
        while True:
            try:
                cmd = await reaper_q.get()
                if cmd[0] == "exit":
                    return
                if cmd[0] == "cancel":
                    try:
                        cmd[1].cancel()
                        # await the task so the cancellation completes
                        await cmd[1]
                    except asyncio.CancelledError:
                        pass
                else:
                    _LOGGER.error("task_reaper: unknown command %s", cmd[0])
            except asyncio.CancelledError:
                raise
            except Exception:
                _LOGGER.error("task_reaper: got exception %s", traceback.format_exc(-1))

    if not cls.task_reaper:
        cls.task_reaper_q = asyncio.Queue(0)
        cls.task_reaper = cls.create_task(task_reaper(cls.task_reaper_q))

    #
    # start a task which creates tasks to run coros, and then syncs on their completion;
    # this is used by the shutdown trigger
    #
    async def task_waiter(waiter_q):
        aws = []
        while True:
            try:
                cmd = await waiter_q.get()
                if cmd[0] == "exit":
                    return
                if cmd[0] == "await":
                    aws.append(cls.create_task(cmd[1]))
                elif cmd[0] == "sync":
                    if len(aws) > 0:
                        await asyncio.gather(*aws)
                        aws = []
                    # signal the requester that all pending awaits are done
                    await cmd[1].put(0)
                else:
                    _LOGGER.error("task_waiter: unknown command %s", cmd[0])
            except asyncio.CancelledError:
                raise
            except Exception:
                _LOGGER.error("task_waiter: got exception %s", traceback.format_exc(-1))

    if not cls.task_waiter:
        cls.task_waiter_q = asyncio.Queue(0)
        cls.task_waiter = cls.create_task(task_waiter(cls.task_waiter_q))
@classmethod
def reaper_cancel(cls, task):
    """Send a task to be canceled by the reaper."""
    cls.task_reaper_q.put_nowait(["cancel", task])

@classmethod
async def reaper_stop(cls):
    """Tell the reaper task to exit and wait for it to finish."""
    if cls.task_reaper:
        cls.task_reaper_q.put_nowait(["exit"])
        await cls.task_reaper
        cls.task_reaper = None
        cls.task_reaper_q = None

@classmethod
def waiter_await(cls, coro):
    """Send a coro to be awaited by the waiter task."""
    cls.task_waiter_q.put_nowait(["await", coro])

@classmethod
async def waiter_sync(cls):
    """Wait until all coros sent to the waiter task have completed."""
    if cls.task_waiter:
        sync_q = asyncio.Queue(0)
        cls.task_waiter_q.put_nowait(["sync", sync_q])
        # the waiter task posts to sync_q once pending awaits are done
        await sync_q.get()

@classmethod
async def waiter_stop(cls):
    """Tell the waiter task to exit and wait for it to finish."""
    if cls.task_waiter:
        cls.task_waiter_q.put_nowait(["exit"])
        await cls.task_waiter
        cls.task_waiter = None
        cls.task_waiter_q = None
@classmethod
async def async_sleep(cls, duration):
    """Implement task.sleep(): sleep for duration seconds."""
    await asyncio.sleep(float(duration))

@classmethod
async def event_fire(cls, event_type, **kwargs):
    """Implement event.fire().

    An explicit Context passed via the "context" keyword takes precedence;
    otherwise the context stored for the current task (if any) is used.
    """
    curr_task = asyncio.current_task()
    if "context" in kwargs and isinstance(kwargs["context"], Context):
        context = kwargs["context"]
        del kwargs["context"]
    else:
        context = cls.task2context.get(curr_task, None)
    cls.hass.bus.async_fire(event_type, kwargs, context=context)

@classmethod
def store_hass_context(cls, hass_context):
    """Store a HASS context against the currently running task."""
    curr_task = asyncio.current_task()
    cls.task2context[curr_task] = hass_context
@classmethod
def task_unique_factory(cls, ctx):
    """Define and return task.unique() for this context."""

    async def task_unique(name, kill_me=False):
        """Implement task.unique(): ensure only one task runs under name."""
        # names are scoped by global context so different scripts don't collide
        name = f"{ctx.get_global_ctx_name()}.{name}"
        curr_task = asyncio.current_task()
        if name in cls.unique_name2task:
            task = cls.unique_name2task[name]
            if kill_me:
                if task != curr_task:
                    #
                    # it seems we can't cancel ourselves, so we
                    # tell the reaper task to cancel us
                    #
                    cls.reaper_cancel(curr_task)
                    # wait to be canceled
                    await asyncio.sleep(100000)
            elif task != curr_task and task in cls.our_tasks:
                # only cancel tasks if they are ones we started
                cls.reaper_cancel(task)
        if curr_task in cls.our_tasks:
            # transfer the name to the current task, dropping any prior owner
            if name in cls.unique_name2task:
                task = cls.unique_name2task[name]
                if task in cls.unique_task2name:
                    cls.unique_task2name[task].discard(name)
            cls.unique_name2task[name] = curr_task
            if curr_task not in cls.unique_task2name:
                cls.unique_task2name[curr_task] = set()
            cls.unique_task2name[curr_task].add(name)

    return task_unique
@classmethod
async def user_task_cancel(cls, task=None):
    """Implement task.cancel(): cancel the given task, or the current one."""
    do_sleep = False
    if not task:
        task = asyncio.current_task()
        do_sleep = True
    if task not in cls.our_tasks:
        raise TypeError(f"{task} is not a user-started task")
    # cancellation is delegated to the reaper, since a task can't cancel itself
    cls.reaper_cancel(task)
    if do_sleep:
        # wait to be canceled
        await asyncio.sleep(100000)

@classmethod
async def user_task_current_task(cls):
    """Implement task.current_task()."""
    return asyncio.current_task()
@classmethod
def task_name2id_factory(cls, ctx):
    """Define and return task.name2id() for this context."""

    def user_task_name2id(name=None):
        """Implement task.name2id(): map task.unique() names to task ids."""
        prefix = f"{ctx.get_global_ctx_name()}."
        if name is None:
            # no name given: return all names in this context, without prefix
            return {
                task_name[len(prefix) :]: task_id
                for task_name, task_id in cls.unique_name2task.items()
                if task_name.startswith(prefix)
            }
        full_name = prefix + name
        if full_name in cls.unique_name2task:
            return cls.unique_name2task[full_name]
        raise NameError(f"task name '{name}' is unknown")

    return user_task_name2id
@classmethod
async def user_task_wait(cls, aws, **kwargs):
    """Implement task.wait(): wait for the given awaitables."""
    return await asyncio.wait(aws, **kwargs)

@classmethod
def user_task_remove_done_callback(cls, task, callback):
    """Implement task.remove_done_callback()."""
    cls.task2cb[task]["cb"].pop(callback, None)

@classmethod
def unique_name_used(cls, ctx, name):
    """Return whether the given task.unique() name is in use in this context."""
    name = f"{ctx.get_global_ctx_name()}.{name}"
    return name in cls.unique_name2task
@classmethod
def service_has_service(cls, domain, name):
    """Implement service.has_service()."""
    return cls.hass.services.has_service(domain, name)

@classmethod
async def service_call(cls, domain, name, **kwargs):
    """Implement service.call().

    The special keywords context, blocking and return_response are passed
    through to HASS when given with the correct type; remaining keywords
    become the service data.
    """
    curr_task = asyncio.current_task()
    hass_args = {}
    for keyword, typ, default in [
        # default context is the one stored for the current task, if any
        ("context", [Context], cls.task2context.get(curr_task, None)),
        ("blocking", [bool], None),
        ("return_response", [bool], None),
    ]:
        if keyword in kwargs and type(kwargs[keyword]) in typ:
            hass_args[keyword] = kwargs.pop(keyword)
        elif default:
            hass_args[keyword] = default
    return await cls.hass_services_async_call(domain, name, kwargs, **hass_args)
@classmethod
async def service_completions(cls, root):
    """Return possible completions of HASS services starting with root."""
    services = cls.hass.services.async_services()
    dots = root.count(".")
    if dots == 0:
        # complete on the domain portion
        return {domain for domain in services if domain.lower().startswith(root)}
    if dots == 1:
        # complete on the service portion within the given domain
        domain, svc_prefix = root.split(".")
        if domain in services:
            return {
                f"{domain}.{svc}"
                for svc in services[domain]
                if svc.lower().startswith(svc_prefix)
            }
    return set()

@classmethod
async def func_completions(cls, root):
    """Return possible completions of pyscript functions starting with root."""
    all_funcs = {**cls.functions, **cls.ast_functions}
    return {name for name in all_funcs if name.lower().startswith(root)}
@classmethod
def register(cls, funcs):
    """Register functions to be available for calling."""
    cls.functions.update(funcs)

@classmethod
def register_ast(cls, funcs):
    """Register functions that need ast context to be available for calling."""
    cls.ast_functions.update(funcs)

@classmethod
def install_ast_funcs(cls, ast_ctx):
    """Install ast functions into the local symbol table of ast_ctx."""
    sym_table = {name: func(ast_ctx) for name, func in cls.ast_functions.items()}
    ast_ctx.set_local_sym_table(sym_table)
@classmethod
def get(cls, name):
    """Lookup a function locally and then as a service.

    Returns the built-in function registered under name, or a wrapper
    coroutine that calls the HASS service domain.service, or None if
    neither exists.
    """
    func = cls.functions.get(name, None)
    if func:
        return func
    name_parts = name.split(".")
    if len(name_parts) != 2:
        return None
    domain, service = name_parts
    if not cls.service_has_service(domain, service):
        return None

    def service_call_factory(domain, service):
        async def service_call(*args, **kwargs):
            if len(args) != 0:
                raise TypeError(f"service {domain}.{service} takes only keyword arguments")
            # delegate to the shared implementation so the handling of the
            # special keywords (context, blocking, return_response) stays
            # consistent with service.call instead of being duplicated here
            return await cls.service_call(domain, service, **kwargs)

        return service_call

    return service_call_factory(domain, service)
@classmethod
async def hass_services_async_call(cls, domain, service, kwargs, **hass_args):
    """Call a hass async service, handling service responses when supported."""
    if SERVICE_RESPONSE_ONLY is None:
        # backwards compatibility < 2023.7
        await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
    else:
        # allow service responses >= 2023.7
        if (
            "return_response" in hass_args
            and hass_args["return_response"]
            and "blocking" not in hass_args
        ):
            # a response can only be returned from a blocking call
            hass_args["blocking"] = True
        elif (
            "return_response" not in hass_args
            and cls.hass.services.supports_response(domain, service) == SERVICE_RESPONSE_ONLY
        ):
            # service only returns a response, so request it and block for it
            hass_args["return_response"] = True
            if "blocking" not in hass_args:
                hass_args["blocking"] = True
        return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
@classmethod
async def run_coro(cls, coro, ast_ctx=None):
    """Run coroutine task and update unique task bookkeeping on start and exit.

    All tasks started via create_task() are wrapped in this coroutine so
    done callbacks run and the tracking tables are cleaned up on exit.
    """
    #
    # Add a placeholder for the new task so we know it's one we started
    #
    task: asyncio.Task = None
    try:
        task = asyncio.current_task()
        cls.our_tasks.add(task)
        if ast_ctx is not None:
            cls.task_done_callback_ctx(task, ast_ctx)
        result = await coro
        return result
    except asyncio.CancelledError:
        raise
    except Exception:
        _LOGGER.error("run_coro: got exception %s", traceback.format_exc(-1))
    finally:
        if task in cls.task2cb:
            for callback, info in cls.task2cb[task]["cb"].items():
                ast_ctx, args, kwargs = info
                await ast_ctx.call_func(callback, None, *args, **kwargs)
                if ast_ctx.get_exception_obj():
                    # stop running callbacks after the first one that fails
                    ast_ctx.get_logger().error(ast_ctx.get_exception_long())
                    break
        if task in cls.unique_task2name:
            # release all task.unique() names held by this task
            for name in cls.unique_task2name[task]:
                del cls.unique_name2task[name]
            del cls.unique_task2name[task]
        cls.task2context.pop(task, None)
        cls.task2cb.pop(task, None)
        cls.our_tasks.discard(task)

@classmethod
def create_task(cls, coro, ast_ctx=None):
    """Create a new task that runs a coroutine, wrapped by run_coro."""
    return cls.hass.loop.create_task(cls.run_coro(coro, ast_ctx=ast_ctx))
@classmethod
def service_register(
    cls, global_ctx_name, domain, service, callback, supports_response=SERVICE_RESPONSE_NONE
):
    """Register a new service callback.

    Raises ValueError if the service is already registered by a different
    global context.
    """
    key = f"{domain}.{service}"
    if key not in cls.service_cnt:
        cls.service_cnt[key] = 0
    if key not in cls.service2global_ctx:
        cls.service2global_ctx[key] = global_ctx_name
    if cls.service2global_ctx[key] != global_ctx_name:
        raise ValueError(
            f"{global_ctx_name}: can't register service {key}; already defined in {cls.service2global_ctx[key]}"
        )
    # reference-count registrations; see the service_cnt class comment
    cls.service_cnt[key] += 1
    if SERVICE_RESPONSE_ONLY is None:
        # backwards compatibility < 2023.7
        cls.hass.services.async_register(domain, service, callback)
    else:
        # allow service responses >= 2023.7
        cls.hass.services.async_register(domain, service, callback, supports_response=supports_response)

@classmethod
def service_remove(cls, global_ctx_name, domain, service):
    """Remove a service callback once its reference count drops to zero."""
    key = f"{domain}.{service}"
    if cls.service_cnt.get(key, 0) > 1:
        # another registration still exists; just decrement the count
        cls.service_cnt[key] -= 1
        return
    cls.service_cnt[key] = 0
    cls.hass.services.async_remove(domain, service)
    cls.service2global_ctx.pop(key, None)

@classmethod
def task_done_callback_ctx(cls, task, ast_ctx):
    """Set the ast_ctx for a task, which is needed for done callbacks."""
    if task not in cls.task2cb or "ctx" not in cls.task2cb[task]:
        cls.task2cb[task] = {"ctx": ast_ctx, "cb": {}}

@classmethod
def task_add_done_callback(cls, task, ast_ctx, callback, *args, **kwargs):
    """Add a done callback to the given task; ast_ctx defaults to the task's."""
    if ast_ctx is None:
        ast_ctx = cls.task2cb[task]["ctx"]
    cls.task2cb[task]["cb"][callback] = [ast_ctx, args, kwargs]

View File

@@ -0,0 +1,352 @@
"""Global context handling."""
import logging
import os
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from homeassistant.config_entries import ConfigEntry
from .const import CONF_HASS_IS_GLOBAL, CONFIG_ENTRY, DOMAIN, FOLDER, LOGGER_PATH
from .eval import AstEval, EvalFunc
from .function import Function
from .trigger import TrigInfo
_LOGGER = logging.getLogger(LOGGER_PATH + ".global_ctx")
class GlobalContext:
    """Define class for global variables and trigger context."""

    def __init__(
        self,
        name,
        global_sym_table: Dict[str, Any] = None,
        manager=None,
        rel_import_path: str = None,
        app_config: Dict[str, Any] = None,
        source: str = None,
        mtime: float = None,
    ) -> None:
        """Initialize GlobalContext."""
        self.name: str = name
        self.global_sym_table: Dict[str, Any] = global_sym_table if global_sym_table else {}
        self.triggers: Set[EvalFunc] = set()
        # triggers registered while auto_start is off; started later by start()
        self.triggers_delay_start: Set[EvalFunc] = set()
        self.logger: logging.Logger = logging.getLogger(LOGGER_PATH + "." + name)
        self.manager: GlobalContextMgr = manager
        self.auto_start: bool = False
        self.module: ModuleType = None
        self.rel_import_path: str = rel_import_path
        self.source: str = source
        self.file_path: str = None
        self.mtime: float = mtime
        self.app_config: Dict[str, Any] = app_config
        # names of other global contexts this one imported
        self.imports: Set[str] = set()
        # NOTE(review): when the config entry is missing, the default {} has no
        # .data attribute and the next line would raise - confirm setup order
        config_entry: ConfigEntry = Function.hass.data.get(DOMAIN, {}).get(CONFIG_ENTRY, {})
        if config_entry.data.get(CONF_HASS_IS_GLOBAL, False):
            #
            # expose hass as a global variable if configured
            #
            self.global_sym_table["hass"] = Function.hass
        if app_config:
            self.global_sym_table["pyscript.app_config"] = app_config.copy()

    def trigger_register(self, func: EvalFunc) -> bool:
        """Register a trigger function; return True if start now."""
        self.triggers.add(func)
        if self.auto_start:
            return True
        self.triggers_delay_start.add(func)
        return False

    def trigger_unregister(self, func: EvalFunc) -> None:
        """Unregister a trigger function."""
        self.triggers.discard(func)
        self.triggers_delay_start.discard(func)

    def set_auto_start(self, auto_start: bool) -> None:
        """Set the auto-start flag."""
        self.auto_start = auto_start

    def start(self) -> None:
        """Start any unstarted triggers."""
        for func in self.triggers_delay_start:
            func.trigger_start()
        self.triggers_delay_start = set()

    def stop(self) -> None:
        """Stop all triggers and auto_start."""
        for func in self.triggers:
            func.trigger_stop()
        self.triggers = set()
        self.triggers_delay_start = set()
        self.set_auto_start(False)

    def get_name(self) -> str:
        """Return the global context name."""
        return self.name

    def set_logger_name(self, name) -> None:
        """Set the global context logging name."""
        self.logger = logging.getLogger(LOGGER_PATH + "." + name)

    def get_global_sym_table(self) -> Dict[str, Any]:
        """Return the global symbol table."""
        return self.global_sym_table

    def get_source(self) -> str:
        """Return the source code."""
        return self.source

    def get_app_config(self) -> Dict[str, Any]:
        """Return the app config."""
        return self.app_config

    def get_mtime(self) -> float:
        """Return the mtime."""
        return self.mtime

    def get_file_path(self) -> str:
        """Return the file path."""
        return self.file_path

    def get_imports(self) -> Set[str]:
        """Return the imports."""
        return self.imports

    def get_trig_info(self, name: str, trig_args: Dict[str, Any]) -> TrigInfo:
        """Return a new trigger info instance with the given args."""
        return TrigInfo(name, trig_args, self)

    async def module_import(self, module_name: str, import_level: int) -> List[Any]:
        """Import a pyscript module from the pyscript/modules or apps folder.

        Returns [module, None] on success, [None, error_ast_ctx] on a load
        error, or [None, None] if no candidate file exists.
        """
        pyscript_dir = Function.hass.config.path(FOLDER)
        module_path = module_name.replace(".", "/")
        # each entry is [ctx_name, relative file path, rel import path]
        file_paths = []

        def find_first_file(file_paths: List[List[str]]) -> Optional[List[Optional[str]]]:
            # return the first candidate [ctx_name, abs_path, rel_path] that exists
            for ctx_name, path, rel_path in file_paths:
                abs_path = os.path.join(pyscript_dir, path)
                if os.path.isfile(abs_path):
                    return [ctx_name, abs_path, rel_path]
            return None

        #
        # first build a list of potential import files
        #
        if import_level > 0:
            # relative import: walk up from this context's import path
            if self.rel_import_path is None:
                raise ImportError("attempted relative import with no known parent package")
            path = self.rel_import_path
            if path.endswith("/__init__"):
                path = os.path.dirname(path)
            ctx_name = self.name
            for _ in range(import_level - 1):
                path = os.path.dirname(path)
                idx = ctx_name.rfind(".")
                if path.find("/") < 0 or idx < 0:
                    raise ImportError("attempted relative import above parent package")
                ctx_name = ctx_name[0:idx]
            ctx_name += f".{module_name}"
            module_info = [ctx_name, f"{path}/{module_path}.py", path]
            path += f"/{module_path}"
            # prefer a package (__init__.py) over a plain module file
            file_paths.append([ctx_name, f"{path}/__init__.py", path])
            file_paths.append(module_info)
            module_name = ctx_name[ctx_name.find(".") + 1 :]
        else:
            # absolute import: apps/ is only searched from app contexts
            if self.rel_import_path is not None and self.rel_import_path.startswith("apps/"):
                ctx_name = f"apps.{module_name}"
                file_paths.append([ctx_name, f"apps/{module_path}/__init__.py", f"apps/{module_path}"])
                file_paths.append([ctx_name, f"apps/{module_path}.py", f"apps/{module_path}"])
            ctx_name = f"modules.{module_name}"
            file_paths.append([ctx_name, f"modules/{module_path}/__init__.py", f"modules/{module_path}"])
            file_paths.append([ctx_name, f"modules/{module_path}.py", None])
        #
        # now see if we have loaded it already
        #
        for ctx_name, _, _ in file_paths:
            mod_ctx = self.manager.get(ctx_name)
            if mod_ctx and mod_ctx.module:
                self.imports.add(mod_ctx.get_name())
                return [mod_ctx.module, None]
        #
        # not loaded already, so try to find and import it
        #
        file_info = await Function.hass.async_add_executor_job(find_first_file, file_paths)
        if not file_info:
            return [None, None]
        [ctx_name, file_path, rel_import_path] = file_info
        mod = ModuleType(module_name)
        # the module's namespace doubles as the new context's symbol table
        global_ctx = GlobalContext(
            ctx_name, global_sym_table=mod.__dict__, manager=self.manager, rel_import_path=rel_import_path
        )
        global_ctx.set_auto_start(True)
        _, error_ctx = await self.manager.load_file(global_ctx, file_path)
        if error_ctx:
            _LOGGER.error(
                "module_import: failed to load module %s, ctx = %s, path = %s",
                module_name,
                ctx_name,
                file_path,
            )
            return [None, error_ctx]
        global_ctx.module = mod
        self.imports.add(ctx_name)
        return [mod, None]
class GlobalContextMgr:
    """Define class for all global contexts."""

    #
    # map of context names to contexts
    #
    contexts = {}

    #
    # sequence number for sessions
    #
    name_seq = 0

    def __init__(self) -> None:
        """Report an error if GlobalContextMgr is instantiated."""
        _LOGGER.error("GlobalContextMgr class is not meant to be instantiated")

    @classmethod
    def init(cls) -> None:
        """Initialize GlobalContextMgr: register the pyscript.*_global_ctx functions."""

        def get_global_ctx_factory(ast_ctx: AstEval) -> Callable[[], str]:
            """Generate a pyscript.get_global_ctx() function with given ast_ctx."""

            async def get_global_ctx():
                return ast_ctx.get_global_ctx_name()

            return get_global_ctx

        def list_global_ctx_factory(ast_ctx: AstEval) -> Callable[[], List[str]]:
            """Generate a pyscript.list_global_ctx() function with given ast_ctx."""

            async def list_global_ctx():
                ctx_names = set(cls.contexts.keys())
                curr_ctx_name = ast_ctx.get_global_ctx_name()
                ctx_names.discard(curr_ctx_name)
                # current context first, then the rest sorted by name
                # (the original double sorted() was redundant)
                return [curr_ctx_name] + sorted(ctx_names)

            return list_global_ctx

        def set_global_ctx_factory(ast_ctx: AstEval) -> Callable[[str], None]:
            """Generate a pyscript.set_global_ctx() function with given ast_ctx."""

            async def set_global_ctx(name):
                global_ctx = cls.get(name)
                if global_ctx is None:
                    raise NameError(f"global context '{name}' does not exist")
                ast_ctx.set_global_ctx(global_ctx)
                ast_ctx.set_logger_name(global_ctx.name)

            return set_global_ctx

        ast_funcs = {
            "pyscript.get_global_ctx": get_global_ctx_factory,
            "pyscript.list_global_ctx": list_global_ctx_factory,
            "pyscript.set_global_ctx": set_global_ctx_factory,
        }
        Function.register_ast(ast_funcs)

    @classmethod
    def get(cls, name: str) -> Optional["GlobalContext"]:
        """Return the GlobalContext given a name, or None if unknown."""
        return cls.contexts.get(name, None)

    @classmethod
    def set(cls, name: str, global_ctx: "GlobalContext") -> None:
        """Save the GlobalContext by name."""
        cls.contexts[name] = global_ctx

    @classmethod
    def items(cls) -> List[Tuple[str, "GlobalContext"]]:
        """Return all the global context items, sorted by name."""
        return sorted(cls.contexts.items())

    @classmethod
    def delete(cls, name: str) -> None:
        """Stop and delete the given GlobalContext."""
        if name in cls.contexts:
            global_ctx = cls.contexts[name]
            global_ctx.stop()
            del cls.contexts[name]

    @classmethod
    def new_name(cls, root: str) -> str:
        """Find a unique new name by appending a sequence number to root."""
        while True:
            name = f"{root}{cls.name_seq}"
            cls.name_seq += 1
            if name not in cls.contexts:
                return name

    @classmethod
    async def load_file(
        cls, global_ctx: "GlobalContext", file_path: str, source: str = None, reload: bool = False
    ) -> Tuple[bool, Optional[AstEval]]:
        """Load, parse and run the given script file.

        Returns (True, None) on success, or (False, error_ast_ctx_or_None)
        when reading, parsing or evaluation fails.
        """
        mtime = None
        if source is None:

            def read_file(path: str) -> Tuple[Optional[str], float]:
                """Read the file; return (source, mtime), or (None, 0) on error."""
                try:
                    with open(path, encoding="utf-8") as file_desc:
                        source = file_desc.read()
                        return source, os.path.getmtime(path)
                except Exception as exc:
                    _LOGGER.error("%s", exc)
                    return None, 0

            source, mtime = await Function.hass.async_add_executor_job(read_file, file_path)
            if source is None:
                return False, None
        ctx_curr = cls.get(global_ctx.get_name())
        if ctx_curr:
            # stop triggers and destroy old global context
            ctx_curr.stop()
            cls.delete(global_ctx.get_name())
        #
        # create new ast eval context and parse source file
        #
        ast_ctx = AstEval(global_ctx.get_name(), global_ctx)
        Function.install_ast_funcs(ast_ctx)
        if not ast_ctx.parse(source, filename=file_path):
            exc = ast_ctx.get_exception_long()
            ast_ctx.get_logger().error(exc)
            global_ctx.stop()
            return False, ast_ctx
        await ast_ctx.eval()
        exc = ast_ctx.get_exception_long()
        if exc is not None:
            ast_ctx.get_logger().error(exc)
            global_ctx.stop()
            return False, ast_ctx
        global_ctx.source = source
        global_ctx.file_path = file_path
        if mtime is not None:
            global_ctx.mtime = mtime
        cls.set(global_ctx.get_name(), global_ctx)
        _LOGGER.info("%s %s", "Reloaded" if reload else "Loaded", file_path)
        return True, None

View File

@@ -0,0 +1,921 @@
"""Pyscript Jupyter kernel."""
#
# Based on simple_kernel.py by Doug Blank <doug.blank@gmail.com>
# https://github.com/dsblank/simple_kernel
# license: public domain
# Thanks Doug!
#
import asyncio
import datetime
import hashlib
import hmac
import json
import logging
import logging.handlers
import re
from struct import pack, unpack
import traceback
import uuid
from .const import LOGGER_PATH
from .function import Function
from .global_ctx import GlobalContextMgr
from .state import State
_LOGGER = logging.getLogger(LOGGER_PATH + ".jupyter_kernel")

# Globals:
# delimiter separating routing identities from message frames on the wire
DELIM = b"<IDS|MSG>"
def msg_id():
    """Return a fresh UUID4 string to use as a message id."""
    return f"{uuid.uuid4()}"
def str_to_bytes(string):
    """Encode a string into UTF-8 bytes."""
    return bytes(string, "utf-8")
class KernelBufferingHandler(logging.handlers.BufferingHandler):
    """Memory-based handler for logging; forwards records to the stdout queue."""

    def __init__(self, housekeep_q):
        """Initialize with the housekeeping queue used for delivery."""
        super().__init__(0)
        self.housekeep_q = housekeep_q

    def flush(self):
        """Flush is a no-op; delivery happens in shouldFlush."""

    def shouldFlush(self, record):
        """Send the formatted record to the housekeeping queue as stdout."""
        try:
            message = self.format(record)
            self.housekeep_q.put_nowait(["stdout", message])
        except asyncio.QueueFull:
            _LOGGER.error("housekeep_q unexpectedly full")
################################################################
class ZmqSocket:
    """Defines a minimal implementation of a small subset of ZMQ."""

    #
    # This allows pyscript to work with Jupyter without the real zmq
    # and pyzmq packages, which might not be available or easy to
    # install on the wide set of HASS platforms.
    #
    # Framing follows the ZMTP wire format: each frame starts with a flags
    # byte (0x1 = more frames follow, 0x2 = 8-byte length, 0x4 = command
    # frame), then the length, then the body.
    #

    def __init__(self, reader, writer, sock_type):
        """Initialize a ZMQ socket with the given type and reader/writer streams."""
        self.writer = writer
        self.reader = reader
        self.type = sock_type

    async def read_bytes(self, num_bytes):
        """Read exactly num_bytes from ZMQ socket; raise EOFError on EOF."""
        data = b""
        while len(data) < num_bytes:
            new_data = await self.reader.read(num_bytes - len(data))
            if len(new_data) == 0:
                raise EOFError
            data += new_data
        return data

    async def write_bytes(self, raw_msg):
        """Write bytes to ZMQ socket."""
        self.writer.write(raw_msg)
        await self.writer.drain()

    async def handshake(self):
        """Do initial greeting handshake on a new ZMQ connection."""
        # exchange the signature, version and NULL security mechanism
        await self.write_bytes(b"\xff\x00\x00\x00\x00\x00\x00\x00\x01\x7f")
        _ = await self.read_bytes(10)
        # _LOGGER.debug(f"handshake: got initial greeting {greeting}")
        await self.write_bytes(b"\x03")
        _ = await self.read_bytes(1)
        await self.write_bytes(b"\x00" + "NULL".encode() + b"\x00" * 16 + b"\x00" + b"\x00" * 31)
        _ = await self.read_bytes(53)
        # _LOGGER.debug(f"handshake: got rest of greeting {greeting}")
        params = [["Socket-Type", self.type]]
        if self.type == "ROUTER":
            params.append(["Identity", ""])
        await self.send_cmd("READY", params)

    async def recv(self, multipart=False):
        """Receive a message from ZMQ socket."""
        parts = []
        while 1:
            cmd = (await self.read_bytes(1))[0]
            if cmd & 0x2:
                # long frame: 8-byte big-endian length follows
                msg_len = unpack(">Q", await self.read_bytes(8))[0]
            else:
                msg_len = (await self.read_bytes(1))[0]
            msg_body = await self.read_bytes(msg_len)
            if cmd & 0x4:
                # command frame (eg, READY); parse its params and skip it
                # _LOGGER.debug(f"recv: got cmd {msg_body}")
                cmd_len = msg_body[0]
                cmd = msg_body[1 : cmd_len + 1]
                msg_body = msg_body[cmd_len + 1 :]
                params = []
                while len(msg_body) > 0:
                    param_len = msg_body[0]
                    param = msg_body[1 : param_len + 1]
                    msg_body = msg_body[param_len + 1 :]
                    value_len = unpack(">L", msg_body[0:4])[0]
                    value = msg_body[4 : 4 + value_len]
                    msg_body = msg_body[4 + value_len :]
                    params.append([param, value])
                # _LOGGER.debug(f"recv: got cmd={cmd}, params={params}")
            else:
                parts.append(msg_body)
                if cmd in (0x0, 0x2):
                    # "more" flag (0x1) not set, so the message is complete
                    # _LOGGER.debug(f"recv: got msg {parts}")
                    if not multipart:
                        return b"".join(parts)
                    return parts

    async def recv_multipart(self):
        """Receive a multipart message from ZMQ socket."""
        return await self.recv(multipart=True)

    async def send_cmd(self, cmd, params):
        """Send a command over ZMQ socket."""
        raw_msg = bytearray([len(cmd)]) + cmd.encode()
        for param in params:
            raw_msg += bytearray([len(param[0])]) + param[0].encode()
            raw_msg += pack(">L", len(param[1])) + param[1].encode()
        len_msg = len(raw_msg)
        # short vs long command framing, depending on the payload size
        if len_msg <= 255:
            raw_msg = bytearray([0x4, len_msg]) + raw_msg
        else:
            raw_msg = bytearray([0x6]) + pack(">Q", len_msg) + raw_msg
        # _LOGGER.debug(f"send_cmd: sending {raw_msg}")
        await self.write_bytes(raw_msg)

    async def send(self, msg):
        """Send a message over ZMQ socket."""
        len_msg = len(msg)
        if len_msg <= 255:
            raw_msg = bytearray([0x1, 0x0, 0x0, len_msg]) + msg
        else:
            raw_msg = bytearray([0x1, 0x0, 0x2]) + pack(">Q", len_msg) + msg
        # _LOGGER.debug(f"send: sending {raw_msg}")
        await self.write_bytes(raw_msg)

    async def send_multipart(self, parts):
        """Send multipart messages over ZMQ socket."""
        raw_msg = b""
        for i, part in enumerate(parts):
            len_part = len(part)
            # set the "more" flag on every frame except the last
            cmd = 0x1 if i < len(parts) - 1 else 0x0
            if len_part <= 255:
                raw_msg += bytearray([cmd, len_part]) + part
            else:
                raw_msg += bytearray([cmd + 2]) + pack(">Q", len_part) + part
        # _LOGGER.debug(f"send_multipart: sending {raw_msg}")
        await self.write_bytes(raw_msg)

    def close(self):
        """Close the ZMQ socket."""
        self.writer.close()
##########################################
class Kernel:
"""Define a Jupyter Kernel class."""
def __init__(self, config, ast_ctx, global_ctx, global_ctx_name):
    """Initialize a Kernel object, one instance per session."""
    self.config = config.copy()
    self.global_ctx = global_ctx
    self.global_ctx_name = global_ctx_name
    self.ast_ctx = ast_ctx
    # key and digest used to HMAC-sign messages to/from the Jupyter client
    self.secure_key = str_to_bytes(self.config["key"])
    self.no_connect_timeout = self.config.get("no_connect_timeout", 30)
    self.signature_schemes = {"hmac-sha256": hashlib.sha256}
    self.auth = hmac.HMAC(
        self.secure_key,
        digestmod=self.signature_schemes[self.config["signature_scheme"]],
    )
    self.execution_count = 1
    self.engine_id = str(uuid.uuid4())
    # one server and bound port per Jupyter channel
    self.heartbeat_server = None
    self.iopub_server = None
    self.control_server = None
    self.stdin_server = None
    self.shell_server = None
    self.heartbeat_port = None
    self.iopub_port = None
    self.control_port = None
    self.stdin_port = None
    self.shell_port = None
    # this should probably be a configuration parameter
    self.avail_port = 50321
    # there can be multiple iopub subscribers, with corresponding tasks
    self.iopub_socket = set()
    self.tasks = {}
    self.task_cnt = 0
    self.task_cnt_max = 0
    self.session_cleanup_callback = None
    self.housekeep_q = asyncio.Queue(0)
    self.parent_header = None
    #
    # we create a logging handler so that output from the log functions
    # gets delivered back to Jupyter as stdout
    #
    self.console = KernelBufferingHandler(self.housekeep_q)
    self.console.setLevel(logging.DEBUG)
    # set a format which is just the message
    formatter = logging.Formatter("%(message)s")
    self.console.setFormatter(formatter)
    # match alphanum or "." at end of line
    self.completion_re = re.compile(r".*?([\w.]*)$", re.DOTALL)
    # see if line ends in a ":", with optional whitespace and comment
    # note: this doesn't detect if we are inside a quoted string...
    self.colon_end_re = re.compile(r".*: *(#.*)?$")
def msg_sign(self, msg_lst):
    """Return the hex HMAC signature over the given message frames."""
    digest = self.auth.copy()
    for frame in msg_lst:
        digest.update(frame)
    return str_to_bytes(digest.hexdigest())
def deserialize_wire_msg(self, wire_msg):
    """Split the routing prefix and message frames from a message on the wire.

    Returns (identities, msg) where msg has the header, parent_header,
    metadata and content frames decoded from JSON.  Raises ValueError if
    the HMAC signature does not match.
    """
    delim_idx = wire_msg.index(DELIM)
    identities = wire_msg[:delim_idx]
    m_signature = wire_msg[delim_idx + 1]
    msg_frames = wire_msg[delim_idx + 2 :]

    def decode(msg):
        return json.loads(msg.decode("utf-8"))

    msg = {}
    msg["header"] = decode(msg_frames[0])
    msg["parent_header"] = decode(msg_frames[1])
    msg["metadata"] = decode(msg_frames[2])
    msg["content"] = decode(msg_frames[3])
    check_sig = self.msg_sign(msg_frames)
    # use a constant-time comparison so signature checking doesn't leak
    # timing information (the plain != compare was timing-sensitive)
    if not hmac.compare_digest(check_sig, m_signature):
        _LOGGER.error(
            "signature mismatch: check_sig=%s, m_signature=%s, wire_msg=%s",
            check_sig,
            m_signature,
            wire_msg,
        )
        raise ValueError("Signatures do not match")
    return identities, msg
def new_header(self, msg_type):
    """Construct a fresh message header for the given message type."""
    return dict(
        date=datetime.datetime.now().isoformat(),
        msg_id=msg_id(),
        username="kernel",
        session=self.engine_id,
        msg_type=msg_type,
        version="5.3",
    )
async def send(
    self,
    stream,
    msg_type,
    content=None,
    parent_header=None,
    metadata=None,
    identities=None,
):
    """Send a signed message of the given type to the Jupyter client.

    stream may be a single socket or a set of sockets (eg, the iopub
    subscribers); identities, when given, are prepended as the routing
    prefix.
    """
    header = self.new_header(msg_type)

    def encode(msg):
        return str_to_bytes(json.dumps(msg))

    msg_lst = [
        encode(header),
        encode(parent_header if parent_header else {}),
        encode(metadata if metadata else {}),
        encode(content if content else {}),
    ]
    # the signature covers the four JSON frames after the delimiter
    signature = self.msg_sign(msg_lst)
    parts = [DELIM, signature, msg_lst[0], msg_lst[1], msg_lst[2], msg_lst[3]]
    if identities:
        parts = identities + parts
    if stream:
        # _LOGGER.debug("send %s: %s", msg_type, parts)
        for this_stream in stream if isinstance(stream, set) else {stream}:
            await this_stream.send_multipart(parts)
    async def shell_handler(self, shell_socket, wire_msg):
        """Handle a single message arriving on the shell socket.

        Publishes busy/idle "status" messages on iopub around the request and
        dispatches on msg_type: execute_request (parse/eval the code in this
        kernel's AST context), kernel_info_request, complete_request,
        is_complete_request, comm_info_request and history_request.  Unknown
        comm_* messages are ignored.
        """
        identities, msg = self.deserialize_wire_msg(wire_msg)
        # _LOGGER.debug("shell received %s: %s", msg.get('header', {}).get('msg_type', 'UNKNOWN'), msg)
        # remember the request header so out-of-band stdout (housekeep "stdout"
        # messages) can be attributed to this request as its parent
        self.parent_header = msg["header"]

        # tell the client we are busy processing this request
        content = {
            "execution_state": "busy",
        }
        await self.send(self.iopub_socket, "status", content, parent_header=msg["header"])

        if msg["header"]["msg_type"] == "execute_request":
            # echo the code back on iopub so all frontends see what is running
            content = {
                "execution_count": self.execution_count,
                "code": msg["content"]["code"],
            }
            await self.send(self.iopub_socket, "execute_input", content, parent_header=msg["header"])
            result = None
            code = msg["content"]["code"]
            #
            # replace VSCode initialization code, which depend on iPython % extensions
            #
            if code.startswith("%config "):
                code = "None"
            if code.startswith("_rwho_ls = %who_ls"):
                code = "print([])"
            # defer trigger starts until after the eval completes
            self.global_ctx.set_auto_start(False)
            self.ast_ctx.parse(code)
            exc = self.ast_ctx.get_exception_obj()
            if exc is None:
                result = await self.ast_ctx.eval()
                exc = self.ast_ctx.get_exception_obj()
            await Function.waiter_sync()
            self.global_ctx.set_auto_start(True)
            self.global_ctx.start()
            if exc:
                # error path: send execute_reply (status=error) on shell, an
                # "error" message plus idle status on iopub, then return early
                traceback_mesg = self.ast_ctx.get_exception_long().split("\n")

                metadata = {
                    "dependencies_met": True,
                    "engine": self.engine_id,
                    "status": "error",
                    "started": datetime.datetime.now().isoformat(),
                }
                content = {
                    "execution_count": self.execution_count,
                    "status": "error",
                    "ename": type(exc).__name__,  # Exception name, as a string
                    "evalue": str(exc),  # Exception value, as a string
                    "traceback": traceback_mesg,
                }
                _LOGGER.debug("Executing '%s' got exception: %s", code, content)
                await self.send(
                    shell_socket,
                    "execute_reply",
                    content,
                    metadata=metadata,
                    parent_header=msg["header"],
                    identities=identities,
                )
                # the iopub "error" message uses the same content minus the
                # execute_reply-only fields
                del content["execution_count"], content["status"]
                await self.send(self.iopub_socket, "error", content, parent_header=msg["header"])
                content = {
                    "execution_state": "idle",
                }
                await self.send(self.iopub_socket, "status", content, parent_header=msg["header"])
                if msg["content"].get("store_history", True):
                    self.execution_count += 1
                # return here: idle status was already sent above
                return

            # if True or isinstance(self.ast_ctx.ast, ast.Expr):
            _LOGGER.debug("Executing: '%s' got result %s", code, result)
            if result is not None:
                content = {
                    "execution_count": self.execution_count,
                    "data": {"text/plain": repr(result)},
                    "metadata": {},
                }
                await self.send(
                    self.iopub_socket,
                    "execute_result",
                    content,
                    parent_header=msg["header"],
                )

            metadata = {
                "dependencies_met": True,
                "engine": self.engine_id,
                "status": "ok",
                "started": datetime.datetime.now().isoformat(),
            }
            content = {
                "status": "ok",
                "execution_count": self.execution_count,
                "user_variables": {},
                "payload": [],
                "user_expressions": {},
            }
            await self.send(
                shell_socket,
                "execute_reply",
                content,
                metadata=metadata,
                parent_header=msg["header"],
                identities=identities,
            )
            if msg["content"].get("store_history", True):
                self.execution_count += 1
            #
            # Make sure stdout gets sent before set report execution_state idle on iopub,
            # otherwise VSCode doesn't display stdout. We do a handshake with the
            # housekeep task to ensure any queued messages get processed.
            #
            handshake_q = asyncio.Queue(0)
            await self.housekeep_q.put(["handshake", handshake_q, 0])
            await handshake_q.get()

        elif msg["header"]["msg_type"] == "kernel_info_request":
            # static capability description of this kernel
            content = {
                "protocol_version": "5.3",
                "ipython_version": [1, 1, 0, ""],
                "language_version": [0, 0, 1],
                "language": "python",
                "implementation": "python",
                "implementation_version": "3.7",
                "language_info": {
                    "name": "python",
                    "version": "1.0",
                    "mimetype": "",
                    "file_extension": ".py",
                    "codemirror_mode": "",
                    "nbconvert_exporter": "",
                },
                "banner": "",
            }
            await self.send(
                shell_socket,
                "kernel_info_reply",
                content,
                parent_header=msg["header"],
                identities=identities,
            )

        elif msg["header"]["msg_type"] == "complete_request":
            # tab completion: gather candidates for the word ending at cursor_pos
            root = ""
            words = set()
            code = msg["content"]["code"]
            posn = msg["content"]["cursor_pos"]
            match = self.completion_re.match(code[0:posn].lower())
            if match:
                root = match[1].lower()
                # merge completions from states, services, functions and the AST context
                words = State.completions(root)
                words = words.union(await Function.service_completions(root))
                words = words.union(await Function.func_completions(root))
                words = words.union(self.ast_ctx.completions(root))
            # _LOGGER.debug(f"complete_request code={code}, posn={posn}, root={root}, words={words}")
            content = {
                "status": "ok",
                "matches": sorted(list(words)),
                "cursor_start": msg["content"]["cursor_pos"] - len(root),
                "cursor_end": msg["content"]["cursor_pos"],
                "metadata": {},
            }
            await self.send(
                shell_socket,
                "complete_reply",
                content,
                parent_header=msg["header"],
                identities=identities,
            )

        elif msg["header"]["msg_type"] == "is_complete_request":
            code = msg["content"]["code"]
            self.ast_ctx.parse(code)
            exc = self.ast_ctx.get_exception_obj()

            # determine indent of last line
            indent = 0
            i = code.rfind("\n")
            if i >= 0:
                while i + 1 < len(code) and code[i + 1] == " ":
                    i += 1
                    indent += 1
            if exc is None:
                if indent == 0:
                    content = {
                        # One of 'complete', 'incomplete', 'invalid', 'unknown'
                        "status": "complete",
                        # If status is 'incomplete', indent should contain the characters to use
                        # to indent the next line. This is only a hint: frontends may ignore it
                        # and use their own autoindentation rules. For other statuses, this
                        # field does not exist.
                        # "indent": str,
                    }
                else:
                    content = {
                        "status": "incomplete",
                        "indent": " " * indent,
                    }
            else:
                #
                # if the syntax error is right at the end, then we label it incomplete,
                # otherwise it's invalid
                #
                if "EOF while" in str(exc) or "expected an indented block" in str(exc):
                    # if error is at ":" then increase indent
                    if hasattr(exc, "lineno"):
                        line = code.split("\n")[exc.lineno - 1]
                        if self.colon_end_re.match(line):
                            indent += 4
                    content = {
                        "status": "incomplete",
                        "indent": " " * indent,
                    }
                else:
                    content = {
                        "status": "invalid",
                    }
            # _LOGGER.debug(f"is_complete_request code={code}, exc={exc}, content={content}")
            await self.send(
                shell_socket,
                "is_complete_reply",
                content,
                parent_header=msg["header"],
                identities=identities,
            )

        elif msg["header"]["msg_type"] == "comm_info_request":
            # no comms are supported; reply with an empty set
            content = {"comms": {}}
            await self.send(
                shell_socket,
                "comm_info_reply",
                content,
                parent_header=msg["header"],
                identities=identities,
            )

        elif msg["header"]["msg_type"] == "history_request":
            # history is not tracked; reply with an empty list
            content = {"history": []}
            await self.send(
                shell_socket,
                "history_reply",
                content,
                parent_header=msg["header"],
                identities=identities,
            )

        elif msg["header"]["msg_type"] in {"comm_open", "comm_msg", "comm_close"}:
            # _LOGGER.debug(f"ignore {msg['header']['msg_type']} message ")
            ...

        else:
            _LOGGER.error("unknown msg_type: %s", msg["header"]["msg_type"])

        # report idle now the request is done (error path returned earlier)
        content = {
            "execution_state": "idle",
        }
        await self.send(self.iopub_socket, "status", content, parent_header=msg["header"])
async def control_listen(self, reader, writer):
"""Task that listens to control messages."""
try:
_LOGGER.debug("control_listen connected")
await self.housekeep_q.put(["register", "control", asyncio.current_task()])
control_socket = ZmqSocket(reader, writer, "ROUTER")
await control_socket.handshake()
while 1:
wire_msg = await control_socket.recv_multipart()
identities, msg = self.deserialize_wire_msg(wire_msg)
# _LOGGER.debug("control received %s: %s", msg.get('header', {}).get('msg_type', 'UNKNOWN'), msg)
if msg["header"]["msg_type"] == "shutdown_request":
content = {
"restart": False,
}
await self.send(
control_socket,
"shutdown_reply",
content,
parent_header=msg["header"],
identities=identities,
)
await self.housekeep_q.put(["shutdown"])
except asyncio.CancelledError:
raise
except (EOFError, ConnectionResetError):
_LOGGER.debug("control_listen got eof")
await self.housekeep_q.put(["unregister", "control", asyncio.current_task()])
control_socket.close()
except Exception as err:
_LOGGER.error("control_listen exception %s", err)
await self.housekeep_q.put(["shutdown"])
async def stdin_listen(self, reader, writer):
"""Task that listens to stdin messages."""
try:
_LOGGER.debug("stdin_listen connected")
await self.housekeep_q.put(["register", "stdin", asyncio.current_task()])
stdin_socket = ZmqSocket(reader, writer, "ROUTER")
await stdin_socket.handshake()
while 1:
_ = await stdin_socket.recv_multipart()
# _LOGGER.debug("stdin_listen received %s", _)
except asyncio.CancelledError:
raise
except (EOFError, ConnectionResetError):
_LOGGER.debug("stdin_listen got eof")
await self.housekeep_q.put(["unregister", "stdin", asyncio.current_task()])
stdin_socket.close()
except Exception:
_LOGGER.error("stdin_listen exception %s", traceback.format_exc(-1))
await self.housekeep_q.put(["shutdown"])
async def shell_listen(self, reader, writer):
"""Task that listens to shell messages."""
try:
_LOGGER.debug("shell_listen connected")
await self.housekeep_q.put(["register", "shell", asyncio.current_task()])
shell_socket = ZmqSocket(reader, writer, "ROUTER")
await shell_socket.handshake()
while 1:
msg = await shell_socket.recv_multipart()
await self.shell_handler(shell_socket, msg)
except asyncio.CancelledError:
shell_socket.close()
raise
except (EOFError, ConnectionResetError):
_LOGGER.debug("shell_listen got eof")
await self.housekeep_q.put(["unregister", "shell", asyncio.current_task()])
shell_socket.close()
except Exception:
_LOGGER.error("shell_listen exception %s", traceback.format_exc(-1))
await self.housekeep_q.put(["shutdown"])
async def heartbeat_listen(self, reader, writer):
"""Task that listens and responds to heart beat messages."""
try:
_LOGGER.debug("heartbeat_listen connected")
await self.housekeep_q.put(["register", "heartbeat", asyncio.current_task()])
heartbeat_socket = ZmqSocket(reader, writer, "REP")
await heartbeat_socket.handshake()
while 1:
msg = await heartbeat_socket.recv()
# _LOGGER.debug("heartbeat_listen: got %s", msg)
await heartbeat_socket.send(msg)
except asyncio.CancelledError:
raise
except (EOFError, ConnectionResetError):
_LOGGER.debug("heartbeat_listen got eof")
await self.housekeep_q.put(["unregister", "heartbeat", asyncio.current_task()])
heartbeat_socket.close()
except Exception:
_LOGGER.error("heartbeat_listen exception: %s", traceback.format_exc(-1))
await self.housekeep_q.put(["shutdown"])
async def iopub_listen(self, reader, writer):
"""Task that listens to iopub messages."""
try:
_LOGGER.debug("iopub_listen connected")
await self.housekeep_q.put(["register", "iopub", asyncio.current_task()])
iopub_socket = ZmqSocket(reader, writer, "PUB")
await iopub_socket.handshake()
self.iopub_socket.add(iopub_socket)
while 1:
_ = await iopub_socket.recv_multipart()
# _LOGGER.debug("iopub received %s", _)
except asyncio.CancelledError:
raise
except (EOFError, ConnectionResetError):
await self.housekeep_q.put(["unregister", "iopub", asyncio.current_task()])
iopub_socket.close()
self.iopub_socket.discard(iopub_socket)
_LOGGER.debug("iopub_listen got eof")
except Exception:
_LOGGER.error("iopub_listen exception %s", traceback.format_exc(-1))
await self.housekeep_q.put(["shutdown"])
    async def housekeep_run(self):
        """Housekeeping, including closing servers after startup, and doing orderly shutdown.

        Processes messages from self.housekeep_q:
          - ["stdout", text]: forward print output to the client on iopub
          - ["handshake", queue, value]: echo value back on queue, used to
            flush any queued stdout before a reply is marked idle
          - ["register", name, task]: track a connection task by channel name
          - ["unregister", name, task]: drop a task; shut the session down
            once the last connection goes away
          - ["shutdown"]: start an orderly session shutdown and end this task
        """
        while True:
            try:
                msg = await self.housekeep_q.get()
                if msg[0] == "stdout":
                    # wrap print output as an iopub "stream" message
                    content = {"name": "stdout", "text": msg[1] + "\n"}
                    if self.iopub_socket:
                        await self.send(
                            self.iopub_socket,
                            "stream",
                            content,
                            parent_header=self.parent_header,
                            identities=[b"stream.stdout"],
                        )
                elif msg[0] == "handshake":
                    # reply on the caller's queue; this guarantees all messages
                    # queued before the handshake have been processed
                    await msg[1].put(msg[2])
                elif msg[0] == "register":
                    if msg[1] not in self.tasks:
                        self.tasks[msg[1]] = set()
                    self.tasks[msg[1]].add(msg[2])
                    self.task_cnt += 1
                    # task_cnt_max records the high-water mark of connections
                    self.task_cnt_max = max(self.task_cnt_max, self.task_cnt)
                    #
                    # now a couple of things are connected, call the session_cleanup_callback
                    #
                    if self.task_cnt > 1 and self.session_cleanup_callback:
                        self.session_cleanup_callback()
                        self.session_cleanup_callback = None
                elif msg[0] == "unregister":
                    if msg[1] in self.tasks:
                        self.tasks[msg[1]].discard(msg[2])
                    self.task_cnt -= 1
                    #
                    # if there are no connection tasks left, then shutdown the kernel
                    #
                    # task_cnt_max >= 4 means a real client connected (several
                    # channels registered), so an empty set now means it left
                    if self.task_cnt == 0 and self.task_cnt_max >= 4:
                        asyncio.create_task(self.session_shutdown())
                        await asyncio.sleep(10000)
                elif msg[0] == "shutdown":
                    asyncio.create_task(self.session_shutdown())
                    return
            except asyncio.CancelledError:
                raise
            except Exception:
                _LOGGER.error("housekeep task exception: %s", traceback.format_exc(-1))
async def startup_timeout(self):
"""Shut down the session if nothing connects after 30 seconds."""
await self.housekeep_q.put(["register", "startup_timeout", asyncio.current_task()])
await asyncio.sleep(self.no_connect_timeout)
if self.task_cnt_max <= 1:
#
# nothing started other than us, so shut down the session
#
_LOGGER.error("No connections to session %s; shutting down", self.global_ctx_name)
if self.session_cleanup_callback:
self.session_cleanup_callback()
self.session_cleanup_callback = None
await self.housekeep_q.put(["shutdown"])
await self.housekeep_q.put(["unregister", "startup_timeout", asyncio.current_task()])
async def start_one_server(self, callback):
"""Start a server by finding an available port."""
first_port = self.avail_port
for _ in range(2048):
try:
server = await asyncio.start_server(callback, "0.0.0.0", self.avail_port)
return server, self.avail_port
except OSError:
self.avail_port += 1
_LOGGER.error(
"unable to find an available port from %d to %d",
first_port,
self.avail_port - 1,
)
return None, None
def get_ports(self):
"""Return a dict of the port numbers this kernel session is listening to."""
return {
"iopub_port": self.iopub_port,
"hb_port": self.heartbeat_port,
"control_port": self.control_port,
"stdin_port": self.stdin_port,
"shell_port": self.shell_port,
}
def set_session_cleanup_callback(self, callback):
"""Set a cleanup callback which is called right after the session has started."""
self.session_cleanup_callback = callback
    async def session_start(self):
        """Start the kernel session.

        Spawns the housekeep and startup-timeout tasks, then starts one TCP
        server per Jupyter channel (iopub, heartbeat, control, stdin, shell),
        each on an automatically assigned port recorded on self.
        """
        # route this kernel's log output to the Jupyter console handler
        self.ast_ctx.add_logger_handler(self.console)
        _LOGGER.info("Starting session %s", self.global_ctx_name)

        self.tasks["housekeep"] = {asyncio.create_task(self.housekeep_run())}
        self.tasks["startup_timeout"] = {asyncio.create_task(self.startup_timeout())}

        self.iopub_server, self.iopub_port = await self.start_one_server(self.iopub_listen)
        self.heartbeat_server, self.heartbeat_port = await self.start_one_server(self.heartbeat_listen)
        self.control_server, self.control_port = await self.start_one_server(self.control_listen)
        self.stdin_server, self.stdin_port = await self.start_one_server(self.stdin_listen)
        self.shell_server, self.shell_port = await self.start_one_server(self.shell_listen)

        #
        # For debugging, can use the real ZMQ library instead on certain sockets; comment out
        # the corresponding asyncio.start_server() call above if you enable the ZMQ-based
        # functions here. You can then turn of verbosity level 4 (-vvvv) in hass_pyscript_kernel.py
        # to see all the byte data in case you need to debug the simple ZMQ implementation here.
        # The two most important zmq functions are shown below.
        #
        # import zmq
        # import zmq.asyncio
        #
        # def zmq_bind(socket, connection, port):
        #     """Bind a socket."""
        #     if port <= 0:
        #         return socket.bind_to_random_port(connection)
        #     # _LOGGER.debug(f"binding to %s:%s" % (connection, port))
        #     socket.bind("%s:%s" % (connection, port))
        #     return port
        #
        # zmq_ctx = zmq.asyncio.Context()
        #
        # ##########################################
        # # Shell using real ZMQ for debugging:
        # async def shell_listen_zmq():
        #     """Task that listens to shell messages using ZMQ."""
        #     try:
        #         _LOGGER.debug("shell_listen_zmq connected")
        #         connection = self.config["transport"] + "://" + self.config["ip"]
        #         shell_socket = zmq_ctx.socket(zmq.ROUTER)
        #         self.shell_port = zmq_bind(shell_socket, connection, -1)
        #         _LOGGER.debug("shell_listen_zmq connected")
        #         while 1:
        #             msg = await shell_socket.recv_multipart()
        #             await self.shell_handler(shell_socket, msg)
        #     except asyncio.CancelledError:
        #         raise
        #     except Exception:
        #         _LOGGER.error("shell_listen exception %s", traceback.format_exc(-1))
        #         await self.housekeep_q.put(["shutdown"])
        #
        # ##########################################
        # # IOPub using real ZMQ for debugging:
        # # IOPub/Sub:
        # async def iopub_listen_zmq():
        #     """Task that listens to iopub messages using ZMQ."""
        #     try:
        #         _LOGGER.debug("iopub_listen_zmq connected")
        #         connection = self.config["transport"] + "://" + self.config["ip"]
        #         iopub_socket = zmq_ctx.socket(zmq.PUB)
        #         self.iopub_port = zmq_bind(self.iopub_socket, connection, -1)
        #         self.iopub_socket.add(iopub_socket)
        #         while 1:
        #             wire_msg = await iopub_socket.recv_multipart()
        #             _LOGGER.debug("iopub received %s", wire_msg)
        #     except asyncio.CancelledError:
        #         raise
        #     except EOFError:
        #         await self.housekeep_q.put(["shutdown"])
        #         _LOGGER.debug("iopub_listen got eof")
        #     except Exception as err:
        #         _LOGGER.error("iopub_listen exception %s", err)
        #         await self.housekeep_q.put(["shutdown"])
        #
        # self.tasks["shell"] = {asyncio.create_task(shell_listen_zmq())}
        # self.tasks["iopub"] = {asyncio.create_task(iopub_listen_zmq())}
        #
async def session_shutdown(self):
"""Shutdown the kernel session."""
if not self.iopub_server:
# already shutdown, so quit
return
GlobalContextMgr.delete(self.global_ctx_name)
self.ast_ctx.remove_logger_handler(self.console)
# logging.getLogger("homeassistant.components.pyscript.func.").removeHandler(self.console)
_LOGGER.info("Shutting down session %s", self.global_ctx_name)
for server in [
self.heartbeat_server,
self.control_server,
self.stdin_server,
self.shell_server,
self.iopub_server,
]:
if server:
server.close()
self.heartbeat_server = None
self.iopub_server = None
self.control_server = None
self.stdin_server = None
self.shell_server = None
for task_set in self.tasks.values():
for task in task_set:
try:
task.cancel()
await task
except asyncio.CancelledError:
pass
self.tasks = []
for sock in self.iopub_socket:
try:
sock.close()
except Exception as err:
_LOGGER.error("iopub socket close exception: %s", err)
self.iopub_socket = set()

View File

@@ -0,0 +1,45 @@
"""Describe logbook events."""
import logging
from homeassistant.core import callback
from .const import DOMAIN
_LOGGER = logging.getLogger(__name__)
@callback
def async_describe_events(hass, async_describe_event):  # type: ignore
    """Describe logbook events."""

    @callback
    def async_describe_logbook_event(event):  # type: ignore
        """Describe a logbook event."""
        data = event.data
        func_args = data.get("func_args", {})
        trigger_type = func_args.get("trigger_type", "unknown")
        if trigger_type == "event":
            source = f"event {func_args.get('event_type', 'unknown event')}"
        elif trigger_type == "state":
            source = f"state change {func_args.get('var_name', 'unknown entity')} == {func_args.get('value', 'unknown value')}"
        elif trigger_type == "time":
            trigger_time = func_args.get("trigger_time", "unknown")
            if trigger_time is None:
                trigger_time = "startup"
            source = f"time {trigger_time}"
        else:
            source = trigger_type
        return {
            "name": data.get("name", "unknown"),
            "message": f"has been triggered by {source}",
            "source": source,
            "entity_id": data.get("entity_id", "pyscript.unknown"),
        }

    async_describe_event(DOMAIN, "pyscript_running", async_describe_logbook_event)

View File

@@ -0,0 +1,17 @@
{
"domain": "pyscript",
"name": "Pyscript Python scripting",
"codeowners": [
"@craigbarratt"
],
"config_flow": true,
"dependencies": [],
"documentation": "https://github.com/custom-components/pyscript",
"homekit": {},
"iot_class": "local_push",
"issue_tracker": "https://github.com/custom-components/pyscript/issues",
"requirements": ["croniter==2.0.2", "watchdog==2.3.1"],
"ssdp": [],
"version": "1.6.1",
"zeroconf": []
}

View File

@@ -0,0 +1,91 @@
"""Handles mqtt messages and notification."""
import json
import logging
from homeassistant.components import mqtt
from .const import LOGGER_PATH
_LOGGER = logging.getLogger(LOGGER_PATH + ".mqtt")
class Mqtt:
    """Define mqtt functions.

    Class-level registry that subscribes to MQTT topics on demand and fans
    incoming messages out to interested asyncio queues.
    """

    #
    # Global hass instance
    #
    hass = None
    #
    # notify message queues by mqtt message topic
    #
    notify = {}
    # per-topic unsubscribe callables returned by mqtt.async_subscribe
    notify_remove = {}

    def __init__(self):
        """Warn on Mqtt instantiation."""
        _LOGGER.error("Mqtt class is not meant to be instantiated")

    @classmethod
    def init(cls, hass):
        """Initialize Mqtt."""
        cls.hass = hass

    @classmethod
    def mqtt_message_handler_maker(cls, subscribed_topic):
        """Closure for mqtt_message_handler."""

        async def mqtt_message_handler(mqttmsg):
            """Listen for MQTT messages."""
            func_args = {
                "trigger_type": "mqtt",
                "topic": mqttmsg.topic,
                "payload": mqttmsg.payload,
                "qos": mqttmsg.qos,
            }
            # best-effort JSON decode of the payload; omit on failure
            try:
                func_args["payload_obj"] = json.loads(mqttmsg.payload)
            except ValueError:
                pass
            await cls.update(subscribed_topic, func_args)

        return mqtt_message_handler

    @classmethod
    async def notify_add(cls, topic, queue):
        """Register to notify for mqtt messages of given topic to be sent to queue."""
        if topic not in cls.notify:
            cls.notify[topic] = set()
            _LOGGER.debug("mqtt.notify_add(%s) -> adding mqtt subscription", topic)
            cls.notify_remove[topic] = await mqtt.async_subscribe(
                cls.hass, topic, cls.mqtt_message_handler_maker(topic), encoding="utf-8", qos=0
            )
        cls.notify[topic].add(queue)

    @classmethod
    def notify_del(cls, topic, queue):
        """Unregister to notify for mqtt messages of given topic for given queue."""
        if topic not in cls.notify or queue not in cls.notify[topic]:
            return
        cls.notify[topic].discard(queue)
        if len(cls.notify[topic]) == 0:
            # last listener gone; unsubscribe and drop the bookkeeping
            cls.notify_remove[topic]()
            _LOGGER.debug("mqtt.notify_del(%s) -> removing mqtt subscription", topic)
            del cls.notify[topic]
            del cls.notify_remove[topic]

    @classmethod
    async def update(cls, topic, func_args):
        """Deliver all notifications for an mqtt message on the given topic."""
        # bug fix: the original format had three placeholders and passed the
        # 'vars' builtin as a stray argument (leftover from an old signature)
        _LOGGER.debug("mqtt.update(%s, %s)", topic, func_args)
        if topic in cls.notify:
            for queue in cls.notify[topic]:
                await queue.put(["mqtt", func_args.copy()])

View File

@@ -0,0 +1,323 @@
"""Requirements helpers for pyscript."""
import glob
import logging
import os
import sys
from homeassistant.loader import bind_hass
from homeassistant.requirements import async_process_requirements
from .const import (
ATTR_INSTALLED_VERSION,
ATTR_SOURCES,
ATTR_VERSION,
CONF_ALLOW_ALL_IMPORTS,
CONF_INSTALLED_PACKAGES,
DOMAIN,
LOGGER_PATH,
REQUIREMENTS_FILE,
REQUIREMENTS_PATHS,
UNPINNED_VERSION,
)
if sys.version_info[:2] >= (3, 8):
from importlib.metadata import ( # pylint: disable=no-name-in-module,import-error
PackageNotFoundError,
version as installed_version,
)
else:
from importlib_metadata import ( # pylint: disable=import-error
PackageNotFoundError,
version as installed_version,
)
_LOGGER = logging.getLogger(LOGGER_PATH)
def get_installed_version(pkg_name):
    """Return the installed version of *pkg_name*, or None if it is not installed."""
    try:
        found = installed_version(pkg_name)
    except PackageNotFoundError:
        return None
    return found
def update_unpinned_versions(package_dict):
    """Resolve UNPINNED_VERSION entries to the currently installed version.

    Packages that turn out not to be installed are logged and removed.
    The dict is updated in place and also returned.
    """
    missing = []
    for package, pinned in package_dict.items():
        if pinned != UNPINNED_VERSION:
            continue
        package_dict[package] = get_installed_version(package)
        if not package_dict[package]:
            _LOGGER.error("%s wasn't able to be installed", package)
            missing.append(package)
    for package in missing:
        package_dict.pop(package)
    return package_dict
@bind_hass
def process_all_requirements(pyscript_folder, requirements_paths, requirements_file):
    """
    Load all lines from requirements_file located in requirements_paths.

    Returns a dict keyed by package name; each entry records the requested
    version (or UNPINNED_VERSION), the requirements files it came from, and
    the currently installed version, if any.  Conflicts between files are
    resolved in favor of pinned versions, then of the highest pinned version.
    """
    # Re-import Version to avoid dealing with multiple flake and pylint errors
    from packaging.version import Version  # pylint: disable=import-outside-toplevel

    all_requirements_to_process = {}
    for root in requirements_paths:
        for requirements_path in glob.glob(os.path.join(pyscript_folder, root, requirements_file)):
            with open(requirements_path, "r", encoding="utf-8") as requirements_fp:
                all_requirements_to_process[requirements_path] = requirements_fp.readlines()

    all_requirements_to_install = {}
    for requirements_path, pkg_lines in all_requirements_to_process.items():
        for pkg in pkg_lines:
            # Remove inline comments which are accepted by pip but not by Home
            # Assistant's installation method.
            # https://rosettacode.org/wiki/Strip_comments_from_a_string#Python
            i = pkg.find("#")
            if i >= 0:
                pkg = pkg[:i]
            pkg = pkg.strip()
            if not pkg:
                continue
            try:
                # Attempt to get version of package. Do nothing if it's found since
                # we want to use the version that's already installed to be safe
                parts = pkg.split("==")
                if len(parts) > 2 or "," in pkg or ">" in pkg or "<" in pkg:
                    # bug fix: the original passed (requirements_path, pkg), i.e.
                    # swapped relative to the placeholders; it also lacked a
                    # space between "version" and "is" in the message
                    _LOGGER.error(
                        (
                            "Ignoring invalid requirement '%s' specified in '%s'; if a specific version "
                            "is required, the requirement must use the format 'pkg==version'"
                        ),
                        pkg,
                        requirements_path,
                    )
                    continue
                if len(parts) == 1:
                    new_version = UNPINNED_VERSION
                else:
                    new_version = parts[1]
                pkg_name = parts[0]
                current_pinned_version = all_requirements_to_install.get(pkg_name, {}).get(ATTR_VERSION)
                current_sources = all_requirements_to_install.get(pkg_name, {}).get(ATTR_SOURCES, [])
                # If a version hasn't already been recorded, record this one
                if not current_pinned_version:
                    all_requirements_to_install[pkg_name] = {
                        ATTR_VERSION: new_version,
                        ATTR_SOURCES: [requirements_path],
                        ATTR_INSTALLED_VERSION: get_installed_version(pkg_name),
                    }
                # If the new version is unpinned and there is an existing pinned version, use existing
                # pinned version
                elif new_version == UNPINNED_VERSION and current_pinned_version != UNPINNED_VERSION:
                    _LOGGER.warning(
                        (
                            "Unpinned requirement for package '%s' detected in '%s' will be ignored in "
                            "favor of the pinned version '%s' detected in '%s'"
                        ),
                        pkg_name,
                        requirements_path,
                        current_pinned_version,
                        str(current_sources),
                    )
                # If the new version is pinned and the existing version is unpinned, use the new pinned
                # version
                elif new_version != UNPINNED_VERSION and current_pinned_version == UNPINNED_VERSION:
                    # (fixed a missing closing quote after the second placeholder)
                    _LOGGER.warning(
                        (
                            "Unpinned requirement for package '%s' detected in '%s' will be ignored in "
                            "favor of the pinned version '%s' detected in '%s'"
                        ),
                        pkg_name,
                        str(current_sources),
                        new_version,
                        requirements_path,
                    )
                    all_requirements_to_install[pkg_name] = {
                        ATTR_VERSION: new_version,
                        ATTR_SOURCES: [requirements_path],
                        ATTR_INSTALLED_VERSION: get_installed_version(pkg_name),
                    }
                # If the already recorded version is the same as the new version, append the current
                # path so we can show sources
                elif (
                    new_version == UNPINNED_VERSION and current_pinned_version == UNPINNED_VERSION
                ) or Version(current_pinned_version) == Version(new_version):
                    all_requirements_to_install[pkg_name][ATTR_SOURCES].append(requirements_path)
                # If the already recorded version is lower than the new version, use the new one
                elif Version(current_pinned_version) < Version(new_version):
                    _LOGGER.warning(
                        (
                            "Version '%s' for package '%s' detected in '%s' will be ignored in "
                            "favor of the higher version '%s' detected in '%s'"
                        ),
                        current_pinned_version,
                        pkg_name,
                        str(current_sources),
                        new_version,
                        requirements_path,
                    )
                    all_requirements_to_install[pkg_name].update(
                        {ATTR_VERSION: new_version, ATTR_SOURCES: [requirements_path]}
                    )
                # If the already recorded version is higher than the new version, ignore the new one
                elif Version(current_pinned_version) > Version(new_version):
                    _LOGGER.warning(
                        (
                            "Version '%s' for package '%s' detected in '%s' will be ignored in "
                            "favor of the higher version '%s' detected in '%s'"
                        ),
                        new_version,
                        pkg_name,
                        requirements_path,
                        current_pinned_version,
                        str(current_sources),
                    )
            except ValueError:
                # Not valid requirements line so it can be skipped
                _LOGGER.debug("Ignoring '%s' because it is not a valid package", pkg)
    return all_requirements_to_install
@bind_hass
async def install_requirements(hass, config_entry, pyscript_folder):
    """Install missing requirements from requirements.txt.

    Reconciles the requirements gathered by process_all_requirements with
    what is already installed and with the versions pyscript itself
    previously installed (tracked in the config entry), installs what is
    missing, and updates the tracker in the config entry.
    """
    pyscript_installed_packages = config_entry.data.get(CONF_INSTALLED_PACKAGES, {}).copy()

    # Import packaging inside install_requirements so that we can use Home Assistant to install it
    # if it can't been found
    try:
        from packaging.version import Version  # pylint: disable=import-outside-toplevel
    except ModuleNotFoundError:
        await async_process_requirements(hass, DOMAIN, ["packaging"])
        from packaging.version import Version  # pylint: disable=import-outside-toplevel

    # file scanning does blocking I/O, so run it in the executor
    all_requirements = await hass.async_add_executor_job(
        process_all_requirements, pyscript_folder, REQUIREMENTS_PATHS, REQUIREMENTS_FILE
    )
    requirements_to_install = {}

    if all_requirements and not config_entry.data.get(CONF_ALLOW_ALL_IMPORTS, False):
        _LOGGER.error(
            (
                "Requirements detected but 'allow_all_imports' is set to False, set "
                "'allow_all_imports' to True if you want packages to be installed"
            )
        )
        return

    for package in all_requirements:
        pkg_installed_version = all_requirements[package].get(ATTR_INSTALLED_VERSION)
        version_to_install = all_requirements[package][ATTR_VERSION]
        sources = all_requirements[package][ATTR_SOURCES]
        # If package is already installed, we need to run some checks
        if pkg_installed_version:
            # If the version to install is unpinned and there is already something installed,
            # defer to what is installed
            if version_to_install == UNPINNED_VERSION:
                _LOGGER.debug(
                    (
                        "Skipping unpinned version of package '%s' because version '%s' is "
                        "already installed"
                    ),
                    package,
                    pkg_installed_version,
                )
                # If installed package is not the same version as the one we last installed,
                # that means that the package is externally managed now so we shouldn't touch it
                # and should remove it from our internal tracker
                if (
                    package in pyscript_installed_packages
                    and pyscript_installed_packages[package] != pkg_installed_version
                ):
                    pyscript_installed_packages.pop(package)
                continue
            # If installed package is not the same version as the one we last installed,
            # that means that the package is externally managed now so we shouldn't touch it
            # and should remove it from our internal tracker
            if package in pyscript_installed_packages and Version(
                pyscript_installed_packages[package]
            ) != Version(pkg_installed_version):
                _LOGGER.warning(
                    (
                        "Version '%s' for package '%s' detected in '%s' will be ignored in favor of"
                        " the version '%s' which was installed outside of pyscript"
                    ),
                    version_to_install,
                    package,
                    str(sources),
                    pkg_installed_version,
                )
                pyscript_installed_packages.pop(package)
            # If there is a version mismatch between what we want and what is installed, we
            # can overwrite it since we know it was last installed by us
            elif package in pyscript_installed_packages and Version(version_to_install) != Version(
                pkg_installed_version
            ):
                requirements_to_install[package] = all_requirements[package]
            # If there is an installed version that we have not previously installed, we
            # should not install it
            else:
                _LOGGER.debug(
                    (
                        "Version '%s' for package '%s' detected in '%s' will be ignored because it"
                        " is already installed"
                    ),
                    version_to_install,
                    package,
                    str(sources),
                )
        # Anything not already installed in the environment can be installed
        else:
            requirements_to_install[package] = all_requirements[package]

    if requirements_to_install:
        _LOGGER.info(
            "Installing the following packages: %s",
            str(requirements_to_install),
        )
        # pinned requirements are passed as pkg==version, unpinned as bare names
        await async_process_requirements(
            hass,
            DOMAIN,
            [
                f"{package}=={pkg_info[ATTR_VERSION]}"
                if pkg_info[ATTR_VERSION] != UNPINNED_VERSION
                else package
                for package, pkg_info in requirements_to_install.items()
            ],
        )
    else:
        _LOGGER.debug("No new packages to install")

    # Update package tracker in config entry for next time
    pyscript_installed_packages.update(
        {package: pkg_info[ATTR_VERSION] for package, pkg_info in requirements_to_install.items()}
    )

    # If any requirements were unpinned, get their version now so they can be pinned later
    if any(version == UNPINNED_VERSION for version in pyscript_installed_packages.values()):
        # version lookups do blocking metadata reads, so run in the executor
        pyscript_installed_packages = await hass.async_add_executor_job(
            update_unpinned_versions, pyscript_installed_packages
        )
    # only write the config entry back if the tracker actually changed
    if pyscript_installed_packages != config_entry.data.get(CONF_INSTALLED_PACKAGES, {}):
        new_data = config_entry.data.copy()
        new_data[CONF_INSTALLED_PACKAGES] = pyscript_installed_packages
        hass.config_entries.async_update_entry(entry=config_entry, data=new_data)

View File

@@ -0,0 +1,107 @@
# Describes the format for available pyscript services
reload:
name: Reload pyscript
  description: Reloads all available pyscripts and restarts triggers
fields:
global_ctx:
name: Global Context
description: Only reload this specific global context (file or app)
example: file.example
required: false
selector:
text:
jupyter_kernel_start:
name: Start Jupyter kernel
description: Starts a jupyter kernel for interactive use; Called by Jupyter front end and should generally not be used by users
fields:
shell_port:
name: Shell Port Number
description: Shell port number
example: 63599
required: false
selector:
number:
min: 10240
max: 65535
iopub_port:
name: IOPub Port Number
description: IOPub port number
example: 63598
required: false
selector:
number:
min: 10240
max: 65535
stdin_port:
name: Stdin Port Number
description: Stdin port number
example: 63597
required: false
selector:
number:
min: 10240
max: 65535
control_port:
name: Control Port Number
description: Control port number
example: 63596
required: false
selector:
number:
min: 10240
max: 65535
hb_port:
name: Heartbeat Port Number
description: Heartbeat port number
example: 63595
required: false
selector:
number:
min: 10240
max: 65535
ip:
name: IP Address
description: IP address to connect to Jupyter front end
example: 127.0.0.1
default: 127.0.0.1
required: false
selector:
text:
key:
name: Security Key
description: Used for signing
example: 012345678-9abcdef023456789abcdef
required: true
selector:
text:
transport:
name: Transport Type
description: Transport type
example: tcp
default: tcp
required: false
selector:
select:
options:
- tcp
- udp
signature_scheme:
name: Signing Algorithm
description: Signing algorithm
example: hmac-sha256
required: false
default: hmac-sha256
selector:
select:
options:
- hmac-sha256
kernel_name:
name: Name of Kernel
description: Kernel name
example: pyscript
required: true
default: pyscript
selector:
text:

View File

@@ -0,0 +1,438 @@
"""Handles state variable access and change notification."""
import asyncio
import logging
from homeassistant.core import Context
from homeassistant.helpers.restore_state import DATA_RESTORE_STATE
from homeassistant.helpers.service import async_get_all_descriptions
from .const import LOGGER_PATH
from .entity import PyscriptEntity
from .function import Function
_LOGGER = logging.getLogger(LOGGER_PATH + ".state")
STATE_VIRTUAL_ATTRS = {"entity_id", "last_changed", "last_updated"}
class StateVal(str):
    """A state variable's value as a string, carrying its attributes as instance attributes.

    Subclasses ``str`` so the object compares/prints as the raw state value,
    while the hass state's attributes (plus the virtual attributes
    ``entity_id``, ``last_updated`` and ``last_changed``) are reachable via
    normal attribute access.
    """

    def __new__(cls, state):
        """Build a StateVal from a hass State object."""
        inst = str.__new__(cls, state.state)
        # start from a copy of the state attributes, then layer the virtual
        # attributes on top (overwriting any same-named real attributes)
        attrs = dict(state.attributes)
        attrs["entity_id"] = state.entity_id
        attrs["last_updated"] = state.last_updated
        attrs["last_changed"] = state.last_changed
        inst.__dict__ = attrs
        return inst
class State:
    """Class for state functions.

    Not instantiated: all state is class-level.  Provides pyscript's access to
    hass state variables ("domain.entity" and "domain.entity.attr" names),
    change-notification queues used by triggers, persistence of pyscript.*
    entities, and entity-targeted service shorthand (e.g. light.kitchen.turn_on).
    """
    #
    # Global hass instance
    #
    hass = None
    #
    # notify message queues by variable
    #
    notify = {}
    #
    # Last value of state variable notifications. We maintain this
    # so that trigger evaluation can use the last notified value,
    # rather than fetching the current value, which is subject to
    # race conditions when multiple state variables are set quickly.
    #
    notify_var_last = {}
    #
    # pyscript yaml configuration
    #
    pyscript_config = {}
    #
    # pyscript vars which have already been registered as persisted
    #
    persisted_vars = {}
    #
    # other parameters of all services that have "entity_id" as a parameter
    #
    service2args = {}
    def __init__(self):
        """Warn on State instantiation."""
        _LOGGER.error("State class is not meant to be instantiated")
    @classmethod
    def init(cls, hass):
        """Initialize State with the global hass instance."""
        cls.hass = hass
    @classmethod
    async def get_service_params(cls):
        """Get parameters for all services.

        Builds service2args: domain -> service -> set of field names, for every
        service that can target an entity ("entity_id" field or "target" key).
        """
        cls.service2args = {}
        all_services = await async_get_all_descriptions(cls.hass)
        for domain in all_services:
            cls.service2args[domain] = {}
            for service, desc in all_services[domain].items():
                # skip services that cannot be targeted at an entity
                if "entity_id" not in desc["fields"] and "target" not in desc:
                    continue
                cls.service2args[domain][service] = set(desc["fields"].keys())
                # entity_id is supplied implicitly by the shorthand, so drop it
                cls.service2args[domain][service].discard("entity_id")
    @classmethod
    async def notify_add(cls, var_names, queue):
        """Register to notify state variables changes to be sent to queue.

        Returns True if at least one valid name was registered.
        """
        added = False
        for var_name in var_names if isinstance(var_names, set) else {var_names}:
            parts = var_name.split(".")
            # only "domain.entity" or "domain.entity.attr" forms are watchable
            if len(parts) != 2 and len(parts) != 3:
                continue
            # notifications are keyed per entity, even when an attribute is watched
            state_var_name = f"{parts[0]}.{parts[1]}"
            if state_var_name not in cls.notify:
                cls.notify[state_var_name] = {}
            cls.notify[state_var_name][queue] = var_names
            added = True
        return added
    @classmethod
    def notify_del(cls, var_names, queue):
        """Unregister notify of state variables changes for given queue."""
        for var_name in var_names if isinstance(var_names, set) else {var_names}:
            parts = var_name.split(".")
            if len(parts) != 2 and len(parts) != 3:
                continue
            state_var_name = f"{parts[0]}.{parts[1]}"
            # NOTE(review): returns (rather than continues) on the first name that
            # is not registered, leaving any remaining names registered -- confirm
            # this early exit is intended
            if state_var_name not in cls.notify or queue not in cls.notify[state_var_name]:
                return
            del cls.notify[state_var_name][queue]
    @classmethod
    async def update(cls, new_vars, func_args):
        """Deliver all notifications for state variable changes."""
        notify = {}
        for var_name, var_val in new_vars.items():
            if var_name in cls.notify:
                # remember the notified value so trigger evaluation sees a
                # consistent snapshot (see notify_var_last comment above)
                cls.notify_var_last[var_name] = var_val
                notify.update(cls.notify[var_name])
        if notify:
            _LOGGER.debug("state.update(%s, %s)", new_vars, func_args)
            for queue, var_names in notify.items():
                await queue.put(["state", [cls.notify_var_get(var_names, new_vars), func_args.copy()]])
    @classmethod
    def notify_var_get(cls, var_names, new_vars):
        """Add values of var_names to new_vars, or default to None."""
        notify_vars = new_vars.copy()
        for var_name in var_names if var_names is not None else []:
            if var_name in notify_vars:
                continue
            parts = var_name.split(".")
            if var_name in cls.notify_var_last:
                notify_vars[var_name] = cls.notify_var_last[var_name]
            elif len(parts) == 3 and f"{parts[0]}.{parts[1]}" in cls.notify_var_last:
                # attribute of an entity we hold a last-notified value for
                notify_vars[var_name] = getattr(
                    cls.notify_var_last[f"{parts[0]}.{parts[1]}"], parts[2], None
                )
            elif len(parts) == 4 and parts[2] == "old" and f"{parts[0]}.{parts[1]}.old" in notify_vars:
                # "domain.entity.old.attr": attribute of the previous value
                notify_vars[var_name] = getattr(notify_vars[f"{parts[0]}.{parts[1]}.old"], parts[3], None)
            elif 1 <= var_name.count(".") <= 3 and not cls.exist(var_name):
                notify_vars[var_name] = None
        return notify_vars
    @classmethod
    def set(cls, var_name, value=None, new_attributes=None, **kwargs):
        """Set a state variable and optional attributes in hass.

        Extra keyword args become individual attribute settings; a "context"
        kwarg of type Context is used as the hass context for the change.
        """
        if var_name.count(".") != 1:
            raise NameError(f"invalid name {var_name} (should be 'domain.entity')")
        if isinstance(value, StateVal):
            if new_attributes is None:
                #
                # value is a StateVal, so extract the attributes and value
                #
                new_attributes = value.__dict__.copy()
                for discard in STATE_VIRTUAL_ATTRS:
                    new_attributes.pop(discard, None)
            value = str(value)
        state_value = None
        # fill in a missing value and/or attributes from the current hass state
        if value is None or new_attributes is None:
            state_value = cls.hass.states.get(var_name)
        if value is None and state_value:
            value = state_value.state
        if new_attributes is None:
            if state_value:
                new_attributes = state_value.attributes.copy()
            else:
                new_attributes = {}
        curr_task = asyncio.current_task()
        if "context" in kwargs and isinstance(kwargs["context"], Context):
            context = kwargs["context"]
            del kwargs["context"]
        else:
            # fall back to the context associated with the running pyscript task
            context = Function.task2context.get(curr_task, None)
        if kwargs:
            # remaining keyword args are per-attribute overrides
            new_attributes = new_attributes.copy()
            new_attributes.update(kwargs)
        _LOGGER.debug("setting %s = %s, attr = %s", var_name, value, new_attributes)
        cls.hass.states.async_set(var_name, value, new_attributes, context=context)
        if var_name in cls.notify_var_last or var_name in cls.notify:
            #
            # immediately update a variable we are monitoring since it could take a while
            # for the state changed event to propagate
            #
            cls.notify_var_last[var_name] = StateVal(cls.hass.states.get(var_name))
        if var_name in cls.persisted_vars:
            # keep the RestoreState entity in sync so the value survives restarts
            cls.persisted_vars[var_name].set_state(value)
            cls.persisted_vars[var_name].set_attributes(new_attributes)
    @classmethod
    def setattr(cls, var_attr_name, value):
        """Set a state variable's attribute in hass."""
        parts = var_attr_name.split(".")
        if len(parts) != 3:
            raise NameError(f"invalid name {var_attr_name} (should be 'domain.entity.attr')")
        if not cls.exist(f"{parts[0]}.{parts[1]}"):
            raise NameError(f"state {parts[0]}.{parts[1]} doesn't exist")
        # delegate to set() with a single attribute keyword
        cls.set(f"{parts[0]}.{parts[1]}", **{parts[2]: value})
    @classmethod
    async def register_persist(cls, var_name):
        """Register pyscript state variable to be persisted with RestoreState."""
        if var_name.startswith("pyscript.") and var_name not in cls.persisted_vars:
            # this is a hack accessing hass internals; should re-implement using RestoreEntity
            restore_data = cls.hass.data[DATA_RESTORE_STATE]
            this_entity = PyscriptEntity()
            this_entity.entity_id = var_name
            cls.persisted_vars[var_name] = this_entity
            try:
                restore_data.async_restore_entity_added(this_entity)
            except TypeError:
                # older hass versions take the entity_id instead of the entity
                restore_data.async_restore_entity_added(var_name)
    @classmethod
    async def persist(cls, var_name, default_value=None, default_attributes=None):
        """Persist a pyscript domain state variable, and update with optional defaults."""
        if var_name.count(".") != 1 or not var_name.startswith("pyscript."):
            raise NameError(f"invalid name {var_name} (should be 'pyscript.entity')")
        await cls.register_persist(var_name)
        exists = cls.exist(var_name)
        if not exists and default_value is not None:
            cls.set(var_name, default_value, default_attributes)
        elif exists and default_attributes is not None:
            # Patch the attributes with new values if necessary
            current = cls.hass.states.get(var_name)
            new_attributes = {k: v for (k, v) in default_attributes.items() if k not in current.attributes}
            cls.set(var_name, current.state, **new_attributes)
    @classmethod
    def exist(cls, var_name):
        """Check if a state variable value or attribute exists in hass."""
        parts = var_name.split(".")
        if len(parts) != 2 and len(parts) != 3:
            return False
        value = cls.hass.states.get(f"{parts[0]}.{parts[1]}")
        if value is None:
            return False
        # a 3-part name exists if it is an entity-targeted service, a real
        # attribute, or one of the virtual attributes
        if (
            len(parts) == 2
            or (parts[0] in cls.service2args and parts[2] in cls.service2args[parts[0]])
            or parts[2] in value.attributes
            or parts[2] in STATE_VIRTUAL_ATTRS
        ):
            return True
        return False
    @classmethod
    def get(cls, var_name):
        """Get a state variable value or attribute from hass.

        For "domain.entity" returns a StateVal; for "domain.entity.service"
        (where service can target entities) returns an async callable that
        invokes the service on this entity; otherwise returns the attribute.
        """
        parts = var_name.split(".")
        if len(parts) != 2 and len(parts) != 3:
            raise NameError(f"invalid name '{var_name}' (should be 'domain.entity' or 'domain.entity.attr')")
        state = cls.hass.states.get(f"{parts[0]}.{parts[1]}")
        if not state:
            raise NameError(f"name '{parts[0]}.{parts[1]}' is not defined")
        #
        # simplest case is just the state value
        #
        state = StateVal(state)
        if len(parts) == 2:
            return state
        #
        # see if this is a service that has an entity_id parameter
        #
        if parts[0] in cls.service2args and parts[2] in cls.service2args[parts[0]]:
            params = cls.service2args[parts[0]][parts[2]]
            # factory so the returned closure captures this call's values
            def service_call_factory(domain, service, entity_id, params):
                async def service_call(*args, **kwargs):
                    curr_task = asyncio.current_task()
                    hass_args = {}
                    # pull out hass-level call options from the kwargs
                    for keyword, typ, default in [
                        ("context", [Context], Function.task2context.get(curr_task, None)),
                        ("blocking", [bool], None),
                        ("return_response", [bool], None),
                        ("limit", [float, int], None),
                    ]:
                        if keyword in kwargs and type(kwargs[keyword]) in typ:
                            hass_args[keyword] = kwargs.pop(keyword)
                        elif default:
                            hass_args[keyword] = default
                    kwargs["entity_id"] = entity_id
                    if len(args) == 1 and len(params) == 1:
                        #
                        # with just a single parameter and positional argument, create the keyword setting
                        #
                        [param_name] = params
                        kwargs[param_name] = args[0]
                    elif len(args) != 0:
                        raise TypeError(f"service {domain}.{service} takes no positional arguments")
                    # return await Function.hass_services_async_call(domain, service, kwargs, **hass_args)
                    return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
                return service_call
            return service_call_factory(parts[0], parts[2], f"{parts[0]}.{parts[1]}", params)
        #
        # finally see if it is an attribute
        #
        try:
            return getattr(state, parts[2])
        except AttributeError:
            raise AttributeError(  # pylint: disable=raise-missing-from
                f"state '{parts[0]}.{parts[1]}' has no attribute '{parts[2]}'"
            )
    @classmethod
    def delete(cls, var_name, context=None):
        """Delete a state variable or attribute from hass."""
        parts = var_name.split(".")
        if not context:
            context = Function.task2context.get(asyncio.current_task(), None)
        context_arg = {"context": context} if context else {}
        if len(parts) == 2:
            if var_name in cls.notify_var_last or var_name in cls.notify:
                #
                # immediately update a variable we are monitoring since it could take a while
                # for the state changed event to propagate
                #
                cls.notify_var_last[var_name] = None
            if not cls.hass.states.async_remove(var_name, **context_arg):
                raise NameError(f"name '{var_name}' not defined")
            return
        if len(parts) == 3:
            # deleting one attribute: rewrite the state with that attribute removed
            var_name = f"{parts[0]}.{parts[1]}"
            value = cls.hass.states.get(var_name)
            if value is None:
                raise NameError(f"state {var_name} doesn't exist")
            new_attr = value.attributes.copy()
            if parts[2] not in new_attr:
                raise AttributeError(f"state '{var_name}' has no attribute '{parts[2]}'")
            del new_attr[parts[2]]
            cls.set(f"{var_name}", value.state, new_attributes=new_attr, **context_arg)
            return
        raise NameError(f"invalid name '{var_name}' (should be 'domain.entity' or 'domain.entity.attr')")
    @classmethod
    def getattr(cls, var_name):
        """Return a dict of attributes for a state variable."""
        if isinstance(var_name, StateVal):
            # already have the attributes; just strip the virtual ones
            attrs = var_name.__dict__.copy()
            for discard in STATE_VIRTUAL_ATTRS:
                attrs.pop(discard, None)
            return attrs
        if var_name.count(".") != 1:
            raise NameError(f"invalid name {var_name} (should be 'domain.entity')")
        value = cls.hass.states.get(var_name)
        if not value:
            return None
        return value.attributes.copy()
    @classmethod
    def get_attr(cls, var_name):
        """Return a dict of attributes for a state variable - deprecated."""
        _LOGGER.warning("state.get_attr() is deprecated: use state.getattr() instead")
        return cls.getattr(var_name)
    @classmethod
    def completions(cls, root):
        """Return possible completions of state variables."""
        words = set()
        parts = root.split(".")
        num_period = len(parts) - 1
        if num_period == 2:
            #
            # complete state attributes
            #
            last_period = root.rfind(".")
            name = root[0:last_period]
            value = cls.hass.states.get(name)
            if value:
                attr_root = root[last_period + 1 :]
                attrs = set(value.attributes.keys()).union(STATE_VIRTUAL_ATTRS)
                # entity-targeted services complete like attributes too
                if parts[0] in cls.service2args:
                    attrs.update(set(cls.service2args[parts[0]].keys()))
                for attr_name in attrs:
                    if attr_name.lower().startswith(attr_root):
                        words.add(f"{name}.{attr_name}")
        elif num_period < 2:
            #
            # complete among all state names
            #
            for name in cls.hass.states.async_all():
                if name.entity_id.lower().startswith(root):
                    words.add(name.entity_id)
        return words
    @classmethod
    async def names(cls, domain=None):
        """Implement names, which returns all entity_ids."""
        return cls.hass.states.async_entity_ids(domain)
    @classmethod
    def register_functions(cls):
        """Register state functions and config variable."""
        functions = {
            "state.get": cls.get,
            "state.set": cls.set,
            "state.setattr": cls.setattr,
            "state.names": cls.names,
            "state.getattr": cls.getattr,
            "state.get_attr": cls.get_attr,  # deprecated form; to be removed
            "state.persist": cls.persist,
            "state.delete": cls.delete,
            "pyscript.config": cls.pyscript_config,
        }
        Function.register(functions)
    @classmethod
    def set_pyscript_config(cls, config):
        """Set pyscript yaml config."""
        #
        # have to update inplace, since dest is already used as value
        #
        cls.pyscript_config.clear()
        for name, value in config.items():
            cls.pyscript_config[name] = value

View File

@@ -0,0 +1,38 @@
{
"config": {
"step": {
"user": {
"title": "pyscript",
"description": "Once you have created an entry, refer to the [docs](https://hacs-pyscript.readthedocs.io/en/latest/) to learn how to create scripts and functions.",
"data": {
"allow_all_imports": "Allow All Imports?",
"hass_is_global": "Access hass as a global variable?"
}
}
},
"abort": {
"already_configured": "Already configured.",
"single_instance_allowed": "Already configured. Only a single configuration possible.",
"updated_entry": "This entry has already been setup but the configuration has been updated."
}
},
"options": {
"step": {
"init": {
"title": "Update pyscript configuration",
"data": {
"allow_all_imports": "Allow All Imports?",
"hass_is_global": "Access hass as a global variable?"
}
},
"no_ui_configuration_allowed": {
"title": "No UI configuration allowed",
"description": "This entry was created via `configuration.yaml`, so all configuration parameters must be updated there. The [`pyscript.reload`](developer-tools/service) service will allow you to apply the changes you make to `configuration.yaml` without restarting your Home Assistant instance."
},
"no_update": {
"title": "No update needed",
"description": "There is nothing to update."
}
}
}
}

View File

@@ -0,0 +1,38 @@
{
"config": {
"step": {
"user": {
"title": "pyscript",
"description": "Wenn Sie einen Eintrag angelegt haben, können Sie sich die [Doku (Englisch)](https://hacs-pyscript.readthedocs.io/en/latest/) ansehen, um zu lernen wie Sie Scripts und Funktionen erstellen können.",
"data": {
"allow_all_imports": "Alle Importe erlauben?",
"hass_is_global": "Home Assistant als globale Variable verwenden?"
}
}
},
"abort": {
"already_configured": "Bereits konfiguriert.",
"single_instance_allowed": "Bereits konfiguriert. Es ist nur eine Konfiguration gleichzeitig möglich",
"updated_entry": "Der Eintrag wurde bereits erstellt, aber die Konfiguration wurde aktualisiert."
}
},
"options": {
"step": {
"init": {
        "title": "Pyscript-Konfiguration aktualisieren",
"data": {
          "allow_all_imports": "Alle Importe erlauben?",
"hass_is_global": "Home Assistant als globale Variable verwenden?"
}
},
"no_ui_configuration_allowed": {
        "title": "Die Konfiguration der graphischen Nutzeroberfläche ist deaktiviert",
        "description": "Der Eintrag wurde über die Datei `configuration.yaml` erstellt. Alle Konfigurationsparameter müssen deshalb dort eingestellt werden. Der [`pyscript.reload`](developer-tools/service) Service übernimmt alle Änderungen aus `configuration.yaml`, ohne dass Home Assistant neu gestartet werden muss."
},
"no_update": {
"title": "Keine Aktualisierung notwendig",
"description": "Es gibt nichts zu aktualisieren."
}
}
}
}

View File

@@ -0,0 +1,38 @@
{
"config": {
"step": {
"user": {
"title": "pyscript",
"description": "Once you have created an entry, refer to the [docs](https://hacs-pyscript.readthedocs.io/en/latest/) to learn how to create scripts and functions.",
"data": {
"allow_all_imports": "Allow All Imports?",
"hass_is_global": "Access hass as a global variable?"
}
}
},
"abort": {
"already_configured": "Already configured.",
"single_instance_allowed": "Already configured. Only a single configuration possible.",
"updated_entry": "This entry has already been setup but the configuration has been updated."
}
},
"options": {
"step": {
"init": {
"title": "Update pyscript configuration",
"data": {
"allow_all_imports": "Allow All Imports?",
"hass_is_global": "Access hass as a global variable?"
}
},
"no_ui_configuration_allowed": {
"title": "No UI configuration allowed",
"description": "This entry was created via `configuration.yaml`, so all configuration parameters must be updated there. The [`pyscript.reload`](developer-tools/service) service will allow you to apply the changes you make to `configuration.yaml` without restarting your Home Assistant instance."
},
"no_update": {
"title": "No update needed",
"description": "There is nothing to update."
}
}
}
}

View File

@@ -0,0 +1,38 @@
{
"config": {
"step": {
"user": {
"title": "pyscript",
"description": "Akonáhle ste vytvorili položku, pozrite si [docs](https://hacs-pyscript.readthedocs.io/en/latest/) naučiť sa, ako vytvárať skripty a funkcie.",
"data": {
"allow_all_imports": "Povoliť všetky importy?",
"hass_is_global": "Prístup k globálnej premennej?"
}
}
},
"abort": {
"already_configured": "Už konfigurované.",
"single_instance_allowed": "Už nakonfigurované. Iba jedna možná konfigurácia.",
"updated_entry": "Táto položka už bola nastavená, ale konfigurácia bola aktualizovaná."
}
},
"options": {
"step": {
"init": {
"title": "Aktualizovať pyscript konfiguráciu",
"data": {
          "allow_all_imports": "Povoliť všetky importy?",
"hass_is_global": "Prístup k globálnej premennej?"
}
},
"no_ui_configuration_allowed": {
"title": "Nie je povolená konfigurácia používateľského rozhrania",
"description": "Tento záznam bol vytvorený cez `configuration.yaml`, Takže všetky konfiguračné parametre sa musia aktualizovať. [`pyscript.reload`](developer-tools/service) Služba vám umožní uplatniť zmeny, ktoré vykonáte `configuration.yaml` bez reštartovania inštancie Home Assistant."
},
"no_update": {
"title": "Nie je potrebná aktualizácia",
"description": "Nie je nič na aktualizáciu."
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,95 @@
"""Handles webhooks and notification."""
import logging
from aiohttp import hdrs
from homeassistant.components import webhook
from .const import LOGGER_PATH
_LOGGER = logging.getLogger(LOGGER_PATH + ".webhook")
class Webhook:
    """Define webhook functions.

    Not instantiated: all state is class-level.  Bridges hass webhooks to
    pyscript trigger queues: a hass webhook listener is registered when the
    first queue subscribes to a webhook_id and unregistered when the last
    one unsubscribes.
    """
    #
    # Global hass instance
    #
    hass = None
    #
    # notify message queues by webhook type
    #
    notify = {}
    # per-webhook_id callables that unregister the hass webhook listener
    notify_remove = {}
    def __init__(self):
        """Warn on Webhook instantiation."""
        _LOGGER.error("Webhook class is not meant to be instantiated")
    @classmethod
    def init(cls, hass):
        """Initialize Webhook with the global hass instance."""
        cls.hass = hass
    @classmethod
    async def webhook_handler(cls, hass, webhook_id, request):
        """Listen callback for given webhook which updates any notifications."""
        func_args = {
            "trigger_type": "webhook",
            "webhook_id": webhook_id,
        }
        # decode the request body: JSON content types are parsed as JSON,
        # anything else is treated as form data
        if "json" in request.headers.get(hdrs.CONTENT_TYPE, ""):
            func_args["payload"] = await request.json()
        else:
            # Could potentially return multiples of a key - only take the first
            payload_multidict = await request.post()
            func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()}
        await cls.update(webhook_id, func_args)
    @classmethod
    def notify_add(cls, webhook_id, local_only, methods, queue):
        """Register to notify for webhooks of given type to be sent to queue."""
        if webhook_id not in cls.notify:
            cls.notify[webhook_id] = set()
            _LOGGER.debug("webhook.notify_add(%s) -> adding webhook listener", webhook_id)
            # first subscriber for this id: register the webhook with hass
            webhook.async_register(
                cls.hass,
                "pyscript",  # DOMAIN
                "pyscript",  # NAME
                webhook_id,
                cls.webhook_handler,
                local_only=local_only,
                allowed_methods=methods,
            )
            cls.notify_remove[webhook_id] = lambda: webhook.async_unregister(cls.hass, webhook_id)
        cls.notify[webhook_id].add(queue)
    @classmethod
    def notify_del(cls, webhook_id, queue):
        """Unregister to notify for webhooks of given type for given queue."""
        if webhook_id not in cls.notify or queue not in cls.notify[webhook_id]:
            return
        cls.notify[webhook_id].discard(queue)
        if len(cls.notify[webhook_id]) == 0:
            # last subscriber gone: unregister the hass webhook and drop bookkeeping
            cls.notify_remove[webhook_id]()
            _LOGGER.debug("webhook.notify_del(%s) -> removing webhook listener", webhook_id)
            del cls.notify[webhook_id]
            del cls.notify_remove[webhook_id]
    @classmethod
    async def update(cls, webhook_id, func_args):
        """Deliver all notifications for a webhook of the given type."""
        _LOGGER.debug("webhook.update(%s, %s)", webhook_id, func_args)
        if webhook_id in cls.notify:
            for queue in cls.notify[webhook_id]:
                # each subscribed trigger gets its own copy of the event args
                await queue.put(["webhook", func_args.copy()])