Home Assistant Git Exporter
This commit is contained in:
682
config/custom_components/pyscript/__init__.py
Normal file
682
config/custom_components/pyscript/__init__.py
Normal file
@@ -0,0 +1,682 @@
|
||||
"""Component to allow running Python scripts."""
|
||||
|
||||
import asyncio
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Callable, Dict, List, Set, Union
|
||||
|
||||
import voluptuous as vol
|
||||
from watchdog.events import DirModifiedEvent, FileSystemEvent, FileSystemEventHandler
|
||||
import watchdog.observers
|
||||
|
||||
from homeassistant.config import async_hass_config_yaml
|
||||
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
|
||||
from homeassistant.const import (
|
||||
EVENT_HOMEASSISTANT_STARTED,
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
EVENT_STATE_CHANGED,
|
||||
SERVICE_RELOAD,
|
||||
)
|
||||
from homeassistant.core import Config, Event as HAEvent, HomeAssistant, ServiceCall
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.restore_state import DATA_RESTORE_STATE
|
||||
from homeassistant.loader import bind_hass
|
||||
|
||||
from .const import (
|
||||
CONF_ALLOW_ALL_IMPORTS,
|
||||
CONF_HASS_IS_GLOBAL,
|
||||
CONFIG_ENTRY,
|
||||
CONFIG_ENTRY_OLD,
|
||||
DOMAIN,
|
||||
FOLDER,
|
||||
LOGGER_PATH,
|
||||
REQUIREMENTS_FILE,
|
||||
SERVICE_JUPYTER_KERNEL_START,
|
||||
UNSUB_LISTENERS,
|
||||
WATCHDOG_TASK,
|
||||
)
|
||||
from .eval import AstEval
|
||||
from .event import Event
|
||||
from .function import Function
|
||||
from .global_ctx import GlobalContext, GlobalContextMgr
|
||||
from .jupyter_kernel import Kernel
|
||||
from .mqtt import Mqtt
|
||||
from .requirements import install_requirements
|
||||
from .state import State, StateVal
|
||||
from .trigger import TrigTime
|
||||
from .webhook import Webhook
|
||||
|
||||
_LOGGER = logging.getLogger(LOGGER_PATH)
|
||||
|
||||
PYSCRIPT_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_ALLOW_ALL_IMPORTS, default=False): cv.boolean,
|
||||
vol.Optional(CONF_HASS_IS_GLOBAL, default=False): cv.boolean,
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema({DOMAIN: PYSCRIPT_SCHEMA}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: Config) -> bool:
|
||||
"""Component setup, run import config flow for each entry in config."""
|
||||
await restore_state(hass)
|
||||
if DOMAIN in config:
|
||||
hass.async_create_task(
|
||||
hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_IMPORT}, data=config[DOMAIN]
|
||||
)
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def restore_state(hass: HomeAssistant) -> None:
|
||||
"""Restores the persisted pyscript state."""
|
||||
# this is a hack accessing hass internals; should re-implement using RestoreEntity
|
||||
restore_data = hass.data[DATA_RESTORE_STATE]
|
||||
for entity_id, value in restore_data.last_states.items():
|
||||
if entity_id.startswith("pyscript."):
|
||||
last_state = value.state
|
||||
hass.states.async_set(entity_id, last_state.state, last_state.attributes)
|
||||
|
||||
|
||||
async def update_yaml_config(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
|
||||
"""Update the yaml config."""
|
||||
try:
|
||||
conf = await async_hass_config_yaml(hass)
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error(err)
|
||||
return
|
||||
|
||||
config = PYSCRIPT_SCHEMA(conf.get(DOMAIN, {}))
|
||||
|
||||
#
|
||||
# If data in config doesn't match config entry, trigger a config import
|
||||
# so that the config entry can get updated
|
||||
#
|
||||
if config != config_entry.data:
|
||||
await hass.config_entries.flow.async_init(DOMAIN, context={"source": SOURCE_IMPORT}, data=config)
|
||||
|
||||
#
|
||||
# if hass_is_global or allow_all_imports have changed, we need to reload all scripts
|
||||
# since they affect all scripts
|
||||
#
|
||||
config_save = {
|
||||
param: config_entry.data.get(param, False) for param in [CONF_HASS_IS_GLOBAL, CONF_ALLOW_ALL_IMPORTS]
|
||||
}
|
||||
if DOMAIN not in hass.data:
|
||||
hass.data.setdefault(DOMAIN, {})
|
||||
if CONFIG_ENTRY_OLD in hass.data[DOMAIN]:
|
||||
old_entry = hass.data[DOMAIN][CONFIG_ENTRY_OLD]
|
||||
hass.data[DOMAIN][CONFIG_ENTRY_OLD] = config_save
|
||||
for param in [CONF_HASS_IS_GLOBAL, CONF_ALLOW_ALL_IMPORTS]:
|
||||
if old_entry.get(param, False) != config_entry.data.get(param, False):
|
||||
return True
|
||||
hass.data[DOMAIN][CONFIG_ENTRY_OLD] = config_save
|
||||
return False
|
||||
|
||||
|
||||
def start_global_contexts(global_ctx_only: str = None) -> None:
|
||||
"""Start all the file and apps global contexts."""
|
||||
start_list = []
|
||||
for global_ctx_name, global_ctx in GlobalContextMgr.items():
|
||||
idx = global_ctx_name.find(".")
|
||||
if idx < 0 or global_ctx_name[0:idx] not in {"file", "apps", "scripts"}:
|
||||
continue
|
||||
if global_ctx_only is not None and global_ctx_only != "*":
|
||||
if global_ctx_name != global_ctx_only and not global_ctx_name.startswith(global_ctx_only + "."):
|
||||
continue
|
||||
global_ctx.set_auto_start(True)
|
||||
start_list.append(global_ctx)
|
||||
for global_ctx in start_list:
|
||||
global_ctx.start()
|
||||
|
||||
|
||||
async def watchdog_start(
|
||||
hass: HomeAssistant, pyscript_folder: str, reload_scripts_handler: Callable[[None], None]
|
||||
) -> None:
|
||||
"""Start watchdog thread to look for changed files in pyscript_folder."""
|
||||
if WATCHDOG_TASK in hass.data[DOMAIN]:
|
||||
return
|
||||
|
||||
class WatchDogHandler(FileSystemEventHandler):
|
||||
"""Class for handling watchdog events."""
|
||||
|
||||
def __init__(
|
||||
self, watchdog_q: asyncio.Queue, observer: watchdog.observers.Observer, path: str
|
||||
) -> None:
|
||||
self.watchdog_q = watchdog_q
|
||||
self._observer = observer
|
||||
self._observer.schedule(self, path, recursive=True)
|
||||
if not hass.is_running:
|
||||
hass.bus.listen_once(EVENT_HOMEASSISTANT_STARTED, self.startup)
|
||||
else:
|
||||
self.startup(None)
|
||||
|
||||
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, self.shutdown)
|
||||
_LOGGER.debug("watchdog init path=%s", path)
|
||||
|
||||
def startup(self, event: Event | None) -> None:
|
||||
"""Start the observer."""
|
||||
_LOGGER.debug("watchdog startup")
|
||||
self._observer.start()
|
||||
|
||||
def shutdown(self, event: Event | None) -> None:
|
||||
"""Stop the observer."""
|
||||
self._observer.stop()
|
||||
self._observer.join()
|
||||
_LOGGER.debug("watchdog shutdown")
|
||||
|
||||
def process(self, event: FileSystemEvent) -> None:
|
||||
"""Send watchdog events to main loop task."""
|
||||
_LOGGER.debug("watchdog process(%s)", event)
|
||||
hass.loop.call_soon_threadsafe(self.watchdog_q.put_nowait, event)
|
||||
|
||||
def on_modified(self, event: FileSystemEvent) -> None:
|
||||
"""File modified."""
|
||||
self.process(event)
|
||||
|
||||
def on_moved(self, event: FileSystemEvent) -> None:
|
||||
"""File moved."""
|
||||
self.process(event)
|
||||
|
||||
def on_created(self, event: FileSystemEvent) -> None:
|
||||
"""File created."""
|
||||
self.process(event)
|
||||
|
||||
def on_deleted(self, event: FileSystemEvent) -> None:
|
||||
"""File deleted."""
|
||||
self.process(event)
|
||||
|
||||
async def task_watchdog(watchdog_q: asyncio.Queue) -> None:
|
||||
def check_event(event, do_reload: bool) -> bool:
|
||||
"""Check if event should trigger a reload."""
|
||||
if event.is_directory:
|
||||
# don't reload if it's just a directory modified
|
||||
if isinstance(event, DirModifiedEvent):
|
||||
return do_reload
|
||||
return True
|
||||
# only reload if it's a script, yaml, or requirements.txt file
|
||||
for valid_suffix in [".py", ".yaml", "/" + REQUIREMENTS_FILE]:
|
||||
if event.src_path.endswith(valid_suffix):
|
||||
return True
|
||||
return do_reload
|
||||
|
||||
while True:
|
||||
try:
|
||||
#
|
||||
# since some file/dir changes create multiple events, we consume all
|
||||
# events in a small window; first wait indefinitely for next event
|
||||
#
|
||||
do_reload = check_event(await watchdog_q.get(), False)
|
||||
#
|
||||
# now consume all additional events with 50ms timeout or 500ms elapsed
|
||||
#
|
||||
t_start = time.monotonic()
|
||||
while time.monotonic() - t_start < 0.5:
|
||||
try:
|
||||
do_reload = check_event(
|
||||
await asyncio.wait_for(watchdog_q.get(), timeout=0.05), do_reload
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
if do_reload:
|
||||
await reload_scripts_handler(None)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
_LOGGER.error("task_watchdog: got exception %s", traceback.format_exc(-1))
|
||||
|
||||
watchdog_q = asyncio.Queue(0)
|
||||
observer = watchdog.observers.Observer()
|
||||
if observer is not None:
|
||||
# don't run watchdog when we are testing (Observer() patches to None)
|
||||
hass.data[DOMAIN][WATCHDOG_TASK] = Function.create_task(task_watchdog(watchdog_q))
|
||||
|
||||
await hass.async_add_executor_job(WatchDogHandler, watchdog_q, observer, pyscript_folder)
|
||||
_LOGGER.debug("watchdog started job and task folder=%s", pyscript_folder)
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
|
||||
"""Initialize the pyscript config entry."""
|
||||
global_ctx_only = None
|
||||
doing_reload = False
|
||||
if Function.hass:
|
||||
#
|
||||
# reload yaml if this isn't the first time (ie, on reload)
|
||||
#
|
||||
doing_reload = True
|
||||
if await update_yaml_config(hass, config_entry):
|
||||
global_ctx_only = "*"
|
||||
|
||||
Function.init(hass)
|
||||
Event.init(hass)
|
||||
Mqtt.init(hass)
|
||||
TrigTime.init(hass)
|
||||
State.init(hass)
|
||||
Webhook.init(hass)
|
||||
State.register_functions()
|
||||
GlobalContextMgr.init()
|
||||
|
||||
pyscript_folder = hass.config.path(FOLDER)
|
||||
if not await hass.async_add_executor_job(os.path.isdir, pyscript_folder):
|
||||
_LOGGER.debug("Folder %s not found in configuration folder, creating it", FOLDER)
|
||||
await hass.async_add_executor_job(os.makedirs, pyscript_folder)
|
||||
|
||||
hass.data.setdefault(DOMAIN, {})
|
||||
hass.data[DOMAIN][CONFIG_ENTRY] = config_entry
|
||||
hass.data[DOMAIN][UNSUB_LISTENERS] = []
|
||||
|
||||
State.set_pyscript_config(config_entry.data)
|
||||
|
||||
await install_requirements(hass, config_entry, pyscript_folder)
|
||||
await load_scripts(hass, config_entry.data, global_ctx_only=global_ctx_only)
|
||||
|
||||
async def reload_scripts_handler(call: ServiceCall) -> None:
|
||||
"""Handle reload service calls."""
|
||||
_LOGGER.debug("reload: yaml, reloading scripts, and restarting")
|
||||
|
||||
global_ctx_only = call.data.get("global_ctx", None) if call else None
|
||||
|
||||
if await update_yaml_config(hass, config_entry):
|
||||
global_ctx_only = "*"
|
||||
State.set_pyscript_config(config_entry.data)
|
||||
|
||||
await State.get_service_params()
|
||||
|
||||
await install_requirements(hass, config_entry, pyscript_folder)
|
||||
await load_scripts(hass, config_entry.data, global_ctx_only=global_ctx_only)
|
||||
|
||||
start_global_contexts(global_ctx_only=global_ctx_only)
|
||||
|
||||
hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_scripts_handler)
|
||||
|
||||
async def jupyter_kernel_start(call: ServiceCall) -> None:
|
||||
"""Handle Jupyter kernel start call."""
|
||||
_LOGGER.debug("service call to jupyter_kernel_start: %s", call.data)
|
||||
|
||||
global_ctx_name = GlobalContextMgr.new_name("jupyter_")
|
||||
global_ctx = GlobalContext(
|
||||
global_ctx_name, global_sym_table={"__name__": global_ctx_name}, manager=GlobalContextMgr
|
||||
)
|
||||
global_ctx.set_auto_start(True)
|
||||
GlobalContextMgr.set(global_ctx_name, global_ctx)
|
||||
|
||||
ast_ctx = AstEval(global_ctx_name, global_ctx)
|
||||
Function.install_ast_funcs(ast_ctx)
|
||||
kernel = Kernel(call.data, ast_ctx, global_ctx, global_ctx_name)
|
||||
await kernel.session_start()
|
||||
hass.states.async_set(call.data["state_var"], json.dumps(kernel.get_ports()))
|
||||
|
||||
def state_var_remove():
|
||||
hass.states.async_remove(call.data["state_var"])
|
||||
|
||||
kernel.set_session_cleanup_callback(state_var_remove)
|
||||
|
||||
hass.services.async_register(DOMAIN, SERVICE_JUPYTER_KERNEL_START, jupyter_kernel_start)
|
||||
|
||||
async def state_changed(event: HAEvent) -> None:
|
||||
var_name = event.data["entity_id"]
|
||||
if event.data.get("new_state", None):
|
||||
new_val = StateVal(event.data["new_state"])
|
||||
else:
|
||||
# state variable has been deleted
|
||||
new_val = None
|
||||
|
||||
if event.data.get("old_state", None):
|
||||
old_val = StateVal(event.data["old_state"])
|
||||
else:
|
||||
# no previous state
|
||||
old_val = None
|
||||
|
||||
new_vars = {var_name: new_val, f"{var_name}.old": old_val}
|
||||
func_args = {
|
||||
"trigger_type": "state",
|
||||
"var_name": var_name,
|
||||
"value": new_val,
|
||||
"old_value": old_val,
|
||||
"context": event.context,
|
||||
}
|
||||
await State.update(new_vars, func_args)
|
||||
|
||||
async def hass_started(event: HAEvent) -> None:
|
||||
_LOGGER.debug("adding state changed listener and starting global contexts")
|
||||
await State.get_service_params()
|
||||
hass.data[DOMAIN][UNSUB_LISTENERS].append(hass.bus.async_listen(EVENT_STATE_CHANGED, state_changed))
|
||||
start_global_contexts()
|
||||
|
||||
async def hass_stop(event: HAEvent) -> None:
|
||||
if WATCHDOG_TASK in hass.data[DOMAIN]:
|
||||
Function.reaper_cancel(hass.data[DOMAIN][WATCHDOG_TASK])
|
||||
del hass.data[DOMAIN][WATCHDOG_TASK]
|
||||
|
||||
_LOGGER.debug("stopping global contexts")
|
||||
await unload_scripts(unload_all=True)
|
||||
# sync with waiter, and then tell waiter and reaper tasks to exit
|
||||
await Function.waiter_sync()
|
||||
await Function.waiter_stop()
|
||||
await Function.reaper_stop()
|
||||
|
||||
# Store callbacks to event listeners so we can unsubscribe on unload
|
||||
hass.data[DOMAIN][UNSUB_LISTENERS].append(
|
||||
hass.bus.async_listen(EVENT_HOMEASSISTANT_STARTED, hass_started)
|
||||
)
|
||||
hass.data[DOMAIN][UNSUB_LISTENERS].append(hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, hass_stop))
|
||||
|
||||
await watchdog_start(hass, pyscript_folder, reload_scripts_handler)
|
||||
|
||||
if doing_reload:
|
||||
start_global_contexts(global_ctx_only="*")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
_LOGGER.info("Unloading all scripts")
|
||||
await unload_scripts(unload_all=True)
|
||||
|
||||
for unsub_listener in hass.data[DOMAIN][UNSUB_LISTENERS]:
|
||||
unsub_listener()
|
||||
hass.data[DOMAIN][UNSUB_LISTENERS] = []
|
||||
|
||||
# sync with waiter, and then tell waiter and reaper tasks to exit
|
||||
await Function.waiter_sync()
|
||||
await Function.waiter_stop()
|
||||
await Function.reaper_stop()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def unload_scripts(global_ctx_only: str = None, unload_all: bool = False) -> None:
|
||||
"""Unload all scripts from GlobalContextMgr with given name prefixes."""
|
||||
ctx_delete = {}
|
||||
for global_ctx_name, global_ctx in GlobalContextMgr.items():
|
||||
if not unload_all:
|
||||
idx = global_ctx_name.find(".")
|
||||
if idx < 0 or global_ctx_name[0:idx] not in {"file", "apps", "modules", "scripts"}:
|
||||
continue
|
||||
if global_ctx_only is not None:
|
||||
if global_ctx_name != global_ctx_only and not global_ctx_name.startswith(global_ctx_only + "."):
|
||||
continue
|
||||
global_ctx.stop()
|
||||
ctx_delete[global_ctx_name] = global_ctx
|
||||
for global_ctx_name, global_ctx in ctx_delete.items():
|
||||
GlobalContextMgr.delete(global_ctx_name)
|
||||
await Function.waiter_sync()
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def load_scripts(hass: HomeAssistant, config_data: Dict[str, Any], global_ctx_only: str = None):
|
||||
"""Load all python scripts in FOLDER."""
|
||||
|
||||
class SourceFile:
|
||||
"""Class for information about a source file."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
global_ctx_name=None,
|
||||
file_path=None,
|
||||
rel_path=None,
|
||||
rel_import_path=None,
|
||||
fq_mod_name=None,
|
||||
check_config=None,
|
||||
app_config=None,
|
||||
source=None,
|
||||
mtime=None,
|
||||
autoload=None,
|
||||
):
|
||||
self.global_ctx_name = global_ctx_name
|
||||
self.file_path = file_path
|
||||
self.rel_path = rel_path
|
||||
self.rel_import_path = rel_import_path
|
||||
self.fq_mod_name = fq_mod_name
|
||||
self.check_config = check_config
|
||||
self.app_config = app_config
|
||||
self.source = source
|
||||
self.mtime = mtime
|
||||
self.autoload = autoload
|
||||
self.force = False
|
||||
|
||||
pyscript_dir = hass.config.path(FOLDER)
|
||||
|
||||
def glob_read_files(
|
||||
load_paths: List[Set[Union[str, bool]]], apps_config: Dict[str, Any]
|
||||
) -> Dict[str, SourceFile]:
|
||||
"""Expand globs and read all the source files."""
|
||||
ctx2source = {}
|
||||
for path, match, check_config, autoload in load_paths:
|
||||
for this_path in sorted(glob.glob(os.path.join(pyscript_dir, path, match), recursive=True)):
|
||||
rel_import_path = None
|
||||
rel_path = this_path
|
||||
if rel_path.startswith(pyscript_dir):
|
||||
rel_path = rel_path[len(pyscript_dir) :]
|
||||
if rel_path.startswith("/"):
|
||||
rel_path = rel_path[1:]
|
||||
if rel_path[0] == "#" or rel_path.find("/#") >= 0:
|
||||
# skip "commented" files and directories
|
||||
continue
|
||||
mod_name = rel_path[0:-3]
|
||||
if mod_name.endswith("/__init__"):
|
||||
rel_import_path = mod_name
|
||||
mod_name = mod_name[: -len("/__init__")]
|
||||
mod_name = mod_name.replace("/", ".")
|
||||
if path == "":
|
||||
global_ctx_name = f"file.{mod_name}"
|
||||
fq_mod_name = mod_name
|
||||
else:
|
||||
fq_mod_name = global_ctx_name = mod_name
|
||||
i = fq_mod_name.find(".")
|
||||
if i >= 0:
|
||||
fq_mod_name = fq_mod_name[i + 1 :]
|
||||
app_config = None
|
||||
|
||||
if global_ctx_name in ctx2source:
|
||||
# the globs result in apps/APP/__init__.py matching twice, so skip the 2nd time
|
||||
# also skip apps/APP.py if apps/APP/__init__.py is present
|
||||
continue
|
||||
|
||||
if check_config:
|
||||
app_name = fq_mod_name
|
||||
i = app_name.find(".")
|
||||
if i >= 0:
|
||||
app_name = app_name[0:i]
|
||||
if not isinstance(apps_config, dict) or app_name not in apps_config:
|
||||
_LOGGER.debug(
|
||||
"load_scripts: skipping %s (app_name=%s) because config not present",
|
||||
this_path,
|
||||
app_name,
|
||||
)
|
||||
continue
|
||||
app_config = apps_config[app_name]
|
||||
|
||||
try:
|
||||
with open(this_path, encoding="utf-8") as file_desc:
|
||||
source = file_desc.read()
|
||||
mtime = os.path.getmtime(this_path)
|
||||
except Exception as exc:
|
||||
_LOGGER.error("load_scripts: skipping %s due to exception %s", this_path, exc)
|
||||
continue
|
||||
|
||||
ctx2source[global_ctx_name] = SourceFile(
|
||||
global_ctx_name=global_ctx_name,
|
||||
file_path=this_path,
|
||||
rel_path=rel_path,
|
||||
rel_import_path=rel_import_path,
|
||||
fq_mod_name=fq_mod_name,
|
||||
check_config=check_config,
|
||||
app_config=app_config,
|
||||
source=source,
|
||||
mtime=mtime,
|
||||
autoload=autoload,
|
||||
)
|
||||
|
||||
return ctx2source
|
||||
|
||||
load_paths = [
|
||||
# path, glob, check_config, autoload
|
||||
["", "*.py", False, True],
|
||||
["apps", "*/__init__.py", True, True],
|
||||
["apps", "*.py", True, True],
|
||||
["apps", "*/**/*.py", False, False],
|
||||
["modules", "*/__init__.py", False, False],
|
||||
["modules", "*.py", False, False],
|
||||
["modules", "*/**/*.py", False, False],
|
||||
["scripts", "**/*.py", False, True],
|
||||
]
|
||||
|
||||
#
|
||||
# get current global contexts
|
||||
#
|
||||
ctx_all = {}
|
||||
for global_ctx_name, global_ctx in GlobalContextMgr.items():
|
||||
idx = global_ctx_name.find(".")
|
||||
if idx < 0 or global_ctx_name[0:idx] not in {"file", "apps", "modules", "scripts"}:
|
||||
continue
|
||||
ctx_all[global_ctx_name] = global_ctx
|
||||
|
||||
#
|
||||
# get list and contents of all source files
|
||||
#
|
||||
apps_config = config_data.get("apps", None)
|
||||
ctx2files = await hass.async_add_executor_job(glob_read_files, load_paths, apps_config)
|
||||
|
||||
#
|
||||
# figure out what to reload based on global_ctx_only and what's changed
|
||||
#
|
||||
ctx_delete = set()
|
||||
if global_ctx_only is not None and global_ctx_only != "*":
|
||||
if global_ctx_only not in ctx_all and global_ctx_only not in ctx2files:
|
||||
_LOGGER.error("pyscript.reload: no global context '%s' to reload", global_ctx_only)
|
||||
return
|
||||
if global_ctx_only not in ctx2files:
|
||||
ctx_delete.add(global_ctx_only)
|
||||
else:
|
||||
ctx2files[global_ctx_only].force = True
|
||||
elif global_ctx_only == "*":
|
||||
ctx_delete = set(ctx_all.keys())
|
||||
for _, src_info in ctx2files.items():
|
||||
src_info.force = True
|
||||
else:
|
||||
# delete all global_ctxs that aren't present in current files
|
||||
for global_ctx_name, global_ctx in ctx_all.items():
|
||||
if global_ctx_name not in ctx2files:
|
||||
ctx_delete.add(global_ctx_name)
|
||||
# delete all global_ctxs that have changeed source or mtime
|
||||
for global_ctx_name, src_info in ctx2files.items():
|
||||
if global_ctx_name in ctx_all:
|
||||
ctx = ctx_all[global_ctx_name]
|
||||
if (
|
||||
src_info.source != ctx.get_source()
|
||||
or src_info.app_config != ctx.get_app_config()
|
||||
or src_info.mtime != ctx.get_mtime()
|
||||
):
|
||||
ctx_delete.add(global_ctx_name)
|
||||
src_info.force = True
|
||||
else:
|
||||
src_info.force = src_info.autoload
|
||||
|
||||
#
|
||||
# force reload if any files uses a module that is bring reloaded by
|
||||
# recursively following each import; first find which modules are
|
||||
# being reloaded
|
||||
#
|
||||
will_reload = set()
|
||||
for global_ctx_name, src_info in ctx2files.items():
|
||||
if global_ctx_name.startswith("modules.") and (global_ctx_name in ctx_delete or src_info.force):
|
||||
parts = global_ctx_name.split(".")
|
||||
root = f"{parts[0]}.{parts[1]}"
|
||||
will_reload.add(root)
|
||||
|
||||
if len(will_reload) > 0:
|
||||
|
||||
def import_recurse(ctx_name, visited, ctx2imports):
|
||||
if ctx_name in visited or ctx_name in ctx2imports:
|
||||
return ctx2imports.get(ctx_name, set())
|
||||
visited.add(ctx_name)
|
||||
ctx = GlobalContextMgr.get(ctx_name)
|
||||
if not ctx:
|
||||
return set()
|
||||
ctx2imports[ctx_name] = set()
|
||||
for imp_name in ctx.get_imports():
|
||||
ctx2imports[ctx_name].add(imp_name)
|
||||
ctx2imports[ctx_name].update(import_recurse(imp_name, visited, ctx2imports))
|
||||
return ctx2imports[ctx_name]
|
||||
|
||||
ctx2imports = {}
|
||||
for global_ctx_name, global_ctx in ctx_all.items():
|
||||
if global_ctx_name not in ctx2imports:
|
||||
visited = set()
|
||||
import_recurse(global_ctx_name, visited, ctx2imports)
|
||||
for mod_name in ctx2imports.get(global_ctx_name, set()):
|
||||
parts = mod_name.split(".")
|
||||
root = f"{parts[0]}.{parts[1]}"
|
||||
if root in will_reload:
|
||||
ctx_delete.add(global_ctx_name)
|
||||
if global_ctx_name in ctx2files:
|
||||
ctx2files[global_ctx_name].force = True
|
||||
|
||||
#
|
||||
# if any file in an app or module has changed, then reload just the top-level
|
||||
# __init__.py or module/app .py file, and delete everything else
|
||||
#
|
||||
done = set()
|
||||
for global_ctx_name, src_info in ctx2files.items():
|
||||
if not src_info.force:
|
||||
continue
|
||||
if not global_ctx_name.startswith("apps.") and not global_ctx_name.startswith("modules."):
|
||||
continue
|
||||
parts = global_ctx_name.split(".")
|
||||
root = f"{parts[0]}.{parts[1]}"
|
||||
if root in done:
|
||||
continue
|
||||
pkg_path = f"{parts[0]}/{parts[1]}/__init__.py"
|
||||
mod_path = f"{parts[0]}/{parts[1]}.py"
|
||||
for ctx_name, this_src_info in ctx2files.items():
|
||||
if ctx_name == root or ctx_name.startswith(f"{root}."):
|
||||
if this_src_info.rel_path in {pkg_path, mod_path}:
|
||||
this_src_info.force = True
|
||||
else:
|
||||
this_src_info.force = False
|
||||
ctx_delete.add(ctx_name)
|
||||
done.add(root)
|
||||
|
||||
#
|
||||
# delete contexts that are no longer needed or will be reloaded
|
||||
#
|
||||
for global_ctx_name in ctx_delete:
|
||||
if global_ctx_name in ctx_all:
|
||||
global_ctx = ctx_all[global_ctx_name]
|
||||
global_ctx.stop()
|
||||
if global_ctx_name not in ctx2files or not ctx2files[global_ctx_name].autoload:
|
||||
_LOGGER.info("Unloaded %s", global_ctx.get_file_path())
|
||||
GlobalContextMgr.delete(global_ctx_name)
|
||||
await Function.waiter_sync()
|
||||
|
||||
#
|
||||
# now load the requested files, and files that depend on loaded files
|
||||
#
|
||||
for global_ctx_name, src_info in sorted(ctx2files.items()):
|
||||
if not src_info.autoload or not src_info.force:
|
||||
continue
|
||||
global_ctx = GlobalContext(
|
||||
src_info.global_ctx_name,
|
||||
global_sym_table={"__name__": src_info.fq_mod_name},
|
||||
manager=GlobalContextMgr,
|
||||
rel_import_path=src_info.rel_import_path,
|
||||
app_config=src_info.app_config,
|
||||
source=src_info.source,
|
||||
mtime=src_info.mtime,
|
||||
)
|
||||
reload = src_info.global_ctx_name in ctx_delete
|
||||
await GlobalContextMgr.load_file(
|
||||
global_ctx, src_info.file_path, source=src_info.source, reload=reload
|
||||
)
|
||||
139
config/custom_components/pyscript/config_flow.py
Normal file
139
config/custom_components/pyscript/config_flow.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Config flow for pyscript."""
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
|
||||
from homeassistant.core import callback
|
||||
|
||||
from .const import CONF_ALLOW_ALL_IMPORTS, CONF_HASS_IS_GLOBAL, CONF_INSTALLED_PACKAGES, DOMAIN
|
||||
|
||||
CONF_BOOL_ALL = {CONF_ALLOW_ALL_IMPORTS, CONF_HASS_IS_GLOBAL}
|
||||
|
||||
PYSCRIPT_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_ALLOW_ALL_IMPORTS, default=False): bool,
|
||||
vol.Optional(CONF_HASS_IS_GLOBAL, default=False): bool,
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
|
||||
class PyscriptOptionsConfigFlow(config_entries.OptionsFlow):
|
||||
"""Handle a pyscript options flow."""
|
||||
|
||||
def __init__(self, config_entry: ConfigEntry) -> None:
|
||||
"""Initialize pyscript options flow."""
|
||||
self.config_entry = config_entry
|
||||
self._show_form = False
|
||||
|
||||
async def async_step_init(self, user_input: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""Manage the pyscript options."""
|
||||
if self.config_entry.source == SOURCE_IMPORT:
|
||||
self._show_form = True
|
||||
return await self.async_step_no_ui_configuration_allowed()
|
||||
|
||||
if user_input is None:
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Optional(name, default=self.config_entry.data.get(name, False)): bool
|
||||
for name in CONF_BOOL_ALL
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
),
|
||||
)
|
||||
|
||||
if any(
|
||||
name not in self.config_entry.data or user_input[name] != self.config_entry.data[name]
|
||||
for name in CONF_BOOL_ALL
|
||||
):
|
||||
updated_data = self.config_entry.data.copy()
|
||||
updated_data.update(user_input)
|
||||
self.hass.config_entries.async_update_entry(entry=self.config_entry, data=updated_data)
|
||||
return self.async_create_entry(title="", data={})
|
||||
|
||||
self._show_form = True
|
||||
return await self.async_step_no_update()
|
||||
|
||||
async def async_step_no_ui_configuration_allowed(
|
||||
self, user_input: Dict[str, Any] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Tell user no UI configuration is allowed."""
|
||||
if self._show_form:
|
||||
self._show_form = False
|
||||
return self.async_show_form(step_id="no_ui_configuration_allowed", data_schema=vol.Schema({}))
|
||||
|
||||
return self.async_create_entry(title="", data={})
|
||||
|
||||
async def async_step_no_update(self, user_input: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""Tell user no update to process."""
|
||||
if self._show_form:
|
||||
self._show_form = False
|
||||
return self.async_show_form(step_id="no_update", data_schema=vol.Schema({}))
|
||||
|
||||
return self.async_create_entry(title="", data={})
|
||||
|
||||
|
||||
class PyscriptConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a pyscript config flow."""
|
||||
|
||||
VERSION = 1
|
||||
CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH
|
||||
|
||||
@staticmethod
|
||||
@callback
|
||||
def async_get_options_flow(config_entry: ConfigEntry) -> PyscriptOptionsConfigFlow:
|
||||
"""Get the options flow for this handler."""
|
||||
return PyscriptOptionsConfigFlow(config_entry)
|
||||
|
||||
async def async_step_user(self, user_input: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""Handle a flow initialized by the user."""
|
||||
if user_input is not None:
|
||||
if len(self.hass.config_entries.async_entries(DOMAIN)) > 0:
|
||||
return self.async_abort(reason="single_instance_allowed")
|
||||
|
||||
await self.async_set_unique_id(DOMAIN)
|
||||
return self.async_create_entry(title=DOMAIN, data=user_input)
|
||||
|
||||
return self.async_show_form(step_id="user", data_schema=PYSCRIPT_SCHEMA)
|
||||
|
||||
async def async_step_import(self, import_config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""Import a config entry from configuration.yaml."""
|
||||
# Convert OrderedDict to dict
|
||||
import_config = json.loads(json.dumps(import_config))
|
||||
|
||||
# Check if import config entry matches any existing config entries
|
||||
# so we can update it if necessary
|
||||
entries = self.hass.config_entries.async_entries(DOMAIN)
|
||||
if entries:
|
||||
entry = entries[0]
|
||||
updated_data = entry.data.copy()
|
||||
|
||||
# Update values for all keys, excluding `allow_all_imports` for entries
|
||||
# set up through the UI.
|
||||
for key, val in import_config.items():
|
||||
if entry.source == SOURCE_IMPORT or key not in CONF_BOOL_ALL:
|
||||
updated_data[key] = val
|
||||
|
||||
# Remove values for all keys in entry.data that are not in the imported config,
|
||||
# excluding `allow_all_imports` for entries set up through the UI.
|
||||
for key in entry.data:
|
||||
if (
|
||||
(entry.source == SOURCE_IMPORT or key not in CONF_BOOL_ALL)
|
||||
and key != CONF_INSTALLED_PACKAGES
|
||||
and key not in import_config
|
||||
):
|
||||
updated_data.pop(key)
|
||||
|
||||
# Update and reload entry if data needs to be updated
|
||||
if updated_data != entry.data:
|
||||
self.hass.config_entries.async_update_entry(entry=entry, data=updated_data)
|
||||
return self.async_abort(reason="updated_entry")
|
||||
|
||||
return self.async_abort(reason="already_configured")
|
||||
|
||||
return await self.async_step_user(user_input=import_config)
|
||||
63
config/custom_components/pyscript/const.py
Normal file
63
config/custom_components/pyscript/const.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Define pyscript-wide constants."""
|
||||
|
||||
#
|
||||
# 2023.7 supports service response; handle older versions by defaulting enum
|
||||
# Should eventually deprecate this and just use SupportsResponse import
|
||||
#
|
||||
try:
|
||||
from homeassistant.core import SupportsResponse
|
||||
|
||||
SERVICE_RESPONSE_NONE = SupportsResponse.NONE
|
||||
SERVICE_RESPONSE_OPTIONAL = SupportsResponse.OPTIONAL
|
||||
SERVICE_RESPONSE_ONLY = SupportsResponse.ONLY
|
||||
except ImportError:
|
||||
SERVICE_RESPONSE_NONE = None
|
||||
SERVICE_RESPONSE_OPTIONAL = None
|
||||
SERVICE_RESPONSE_ONLY = None
|
||||
|
||||
DOMAIN = "pyscript"
|
||||
|
||||
CONFIG_ENTRY = "config_entry"
|
||||
CONFIG_ENTRY_OLD = "config_entry_old"
|
||||
UNSUB_LISTENERS = "unsub_listeners"
|
||||
|
||||
FOLDER = "pyscript"
|
||||
|
||||
UNPINNED_VERSION = "_unpinned_version"
|
||||
|
||||
ATTR_INSTALLED_VERSION = "installed_version"
|
||||
ATTR_SOURCES = "sources"
|
||||
ATTR_VERSION = "version"
|
||||
|
||||
CONF_ALLOW_ALL_IMPORTS = "allow_all_imports"
|
||||
CONF_HASS_IS_GLOBAL = "hass_is_global"
|
||||
CONF_INSTALLED_PACKAGES = "_installed_packages"
|
||||
|
||||
SERVICE_JUPYTER_KERNEL_START = "jupyter_kernel_start"
|
||||
|
||||
LOGGER_PATH = "custom_components.pyscript"
|
||||
|
||||
REQUIREMENTS_FILE = "requirements.txt"
|
||||
REQUIREMENTS_PATHS = ("", "apps/*", "modules/*", "scripts/**")
|
||||
|
||||
WATCHDOG_TASK = "watch_dog_task"
|
||||
|
||||
ALLOWED_IMPORTS = {
|
||||
"black",
|
||||
"cmath",
|
||||
"datetime",
|
||||
"decimal",
|
||||
"fractions",
|
||||
"functools",
|
||||
"homeassistant.const",
|
||||
"isort",
|
||||
"json",
|
||||
"math",
|
||||
"number",
|
||||
"random",
|
||||
"re",
|
||||
"statistics",
|
||||
"string",
|
||||
"time",
|
||||
"voluptuous",
|
||||
}
|
||||
19
config/custom_components/pyscript/entity.py
Normal file
19
config/custom_components/pyscript/entity.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Entity Classes."""
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.helpers.typing import StateType
|
||||
|
||||
|
||||
class PyscriptEntity(RestoreEntity):
|
||||
"""Generic Pyscript Entity."""
|
||||
|
||||
_attr_extra_state_attributes: dict
|
||||
_attr_state: StateType = STATE_UNKNOWN
|
||||
|
||||
def set_state(self, state):
|
||||
"""Set the state."""
|
||||
self._attr_state = state
|
||||
|
||||
def set_attributes(self, attributes):
|
||||
"""Set Attributes."""
|
||||
self._attr_extra_state_attributes = attributes
|
||||
2313
config/custom_components/pyscript/eval.py
Normal file
2313
config/custom_components/pyscript/eval.py
Normal file
File diff suppressed because it is too large
Load Diff
76
config/custom_components/pyscript/event.py
Normal file
76
config/custom_components/pyscript/event.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Handles event firing and notification."""
|
||||
|
||||
import logging
|
||||
|
||||
from .const import LOGGER_PATH
|
||||
|
||||
_LOGGER = logging.getLogger(LOGGER_PATH + ".event")
|
||||
|
||||
|
||||
class Event:
|
||||
"""Define event functions."""
|
||||
|
||||
#
|
||||
# Global hass instance
|
||||
#
|
||||
hass = None
|
||||
|
||||
#
|
||||
# notify message queues by event type
|
||||
#
|
||||
notify = {}
|
||||
notify_remove = {}
|
||||
|
||||
def __init__(self):
|
||||
"""Warn on Event instantiation."""
|
||||
_LOGGER.error("Event class is not meant to be instantiated")
|
||||
|
||||
@classmethod
|
||||
def init(cls, hass):
|
||||
"""Initialize Event."""
|
||||
|
||||
cls.hass = hass
|
||||
|
||||
@classmethod
|
||||
async def event_listener(cls, event):
|
||||
"""Listen callback for given event which updates any notifications."""
|
||||
|
||||
func_args = {
|
||||
"trigger_type": "event",
|
||||
"event_type": event.event_type,
|
||||
"context": event.context,
|
||||
}
|
||||
func_args.update(event.data)
|
||||
await cls.update(event.event_type, func_args)
|
||||
|
||||
@classmethod
|
||||
def notify_add(cls, event_type, queue):
|
||||
"""Register to notify for events of given type to be sent to queue."""
|
||||
|
||||
if event_type not in cls.notify:
|
||||
cls.notify[event_type] = set()
|
||||
_LOGGER.debug("event.notify_add(%s) -> adding event listener", event_type)
|
||||
cls.notify_remove[event_type] = cls.hass.bus.async_listen(event_type, cls.event_listener)
|
||||
cls.notify[event_type].add(queue)
|
||||
|
||||
@classmethod
|
||||
def notify_del(cls, event_type, queue):
|
||||
"""Unregister to notify for events of given type for given queue."""
|
||||
|
||||
if event_type not in cls.notify or queue not in cls.notify[event_type]:
|
||||
return
|
||||
cls.notify[event_type].discard(queue)
|
||||
if len(cls.notify[event_type]) == 0:
|
||||
cls.notify_remove[event_type]()
|
||||
_LOGGER.debug("event.notify_del(%s) -> removing event listener", event_type)
|
||||
del cls.notify[event_type]
|
||||
del cls.notify_remove[event_type]
|
||||
|
||||
@classmethod
|
||||
async def update(cls, event_type, func_args):
|
||||
"""Deliver all notifications for an event of the given type."""
|
||||
|
||||
_LOGGER.debug("event.update(%s, %s)", event_type, func_args)
|
||||
if event_type in cls.notify:
|
||||
for queue in cls.notify[event_type]:
|
||||
await queue.put(["event", func_args.copy()])
|
||||
519
config/custom_components/pyscript/function.py
Normal file
519
config/custom_components/pyscript/function.py
Normal file
@@ -0,0 +1,519 @@
|
||||
"""Function call handling."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from homeassistant.core import Context
|
||||
|
||||
from .const import LOGGER_PATH, SERVICE_RESPONSE_NONE, SERVICE_RESPONSE_ONLY
|
||||
|
||||
_LOGGER = logging.getLogger(LOGGER_PATH + ".function")
|
||||
|
||||
|
||||
class Function:
|
||||
"""Define function handler functions."""
|
||||
|
||||
#
|
||||
# Global hass instance
|
||||
#
|
||||
hass = None
|
||||
|
||||
#
|
||||
# Mappings of tasks ids <-> task names
|
||||
#
|
||||
unique_task2name = {}
|
||||
unique_name2task = {}
|
||||
|
||||
#
|
||||
# Mappings of task id to hass contexts
|
||||
task2context = {}
|
||||
|
||||
#
|
||||
# Set of tasks that are running
|
||||
#
|
||||
our_tasks = set()
|
||||
|
||||
#
|
||||
# Done callbacks for each task
|
||||
#
|
||||
task2cb = {}
|
||||
|
||||
#
|
||||
# initial list of available functions
|
||||
#
|
||||
functions = {}
|
||||
|
||||
#
|
||||
# Functions that take the AstEval context as a first argument,
|
||||
# which is needed by a handful of special functions that need the
|
||||
# ast context
|
||||
#
|
||||
ast_functions = {}
|
||||
|
||||
#
|
||||
# task id of the task that cancels and waits for other tasks,
|
||||
#
|
||||
task_reaper = None
|
||||
task_reaper_q = None
|
||||
|
||||
#
|
||||
# task id of the task that awaits for coros (used by shutdown triggers)
|
||||
#
|
||||
task_waiter = None
|
||||
task_waiter_q = None
|
||||
|
||||
#
|
||||
# reference counting for service registrations; the new @service trigger
|
||||
# registers the service call before the old one is removed, so we only
|
||||
# remove the service registration when the reference count goes to zero
|
||||
#
|
||||
service_cnt = {}
|
||||
|
||||
#
|
||||
# save the global_ctx name where a service is registered so we can raise
|
||||
# an exception if it gets registered by a different global_ctx.
|
||||
#
|
||||
service2global_ctx = {}
|
||||
|
||||
def __init__(self):
|
||||
"""Warn on Function instantiation."""
|
||||
_LOGGER.error("Function class is not meant to be instantiated")
|
||||
|
||||
@classmethod
|
||||
def init(cls, hass):
|
||||
"""Initialize Function."""
|
||||
cls.hass = hass
|
||||
cls.functions.update(
|
||||
{
|
||||
"event.fire": cls.event_fire,
|
||||
"service.call": cls.service_call,
|
||||
"service.has_service": cls.service_has_service,
|
||||
"task.cancel": cls.user_task_cancel,
|
||||
"task.current_task": cls.user_task_current_task,
|
||||
"task.remove_done_callback": cls.user_task_remove_done_callback,
|
||||
"task.sleep": cls.async_sleep,
|
||||
"task.wait": cls.user_task_wait,
|
||||
}
|
||||
)
|
||||
cls.ast_functions.update(
|
||||
{
|
||||
"log.debug": lambda ast_ctx: ast_ctx.get_logger().debug,
|
||||
"log.error": lambda ast_ctx: ast_ctx.get_logger().error,
|
||||
"log.info": lambda ast_ctx: ast_ctx.get_logger().info,
|
||||
"log.warning": lambda ast_ctx: ast_ctx.get_logger().warning,
|
||||
"print": lambda ast_ctx: ast_ctx.get_logger().debug,
|
||||
"task.name2id": cls.task_name2id_factory,
|
||||
"task.unique": cls.task_unique_factory,
|
||||
}
|
||||
)
|
||||
|
||||
#
|
||||
# start a task which is a reaper for canceled tasks, since some # functions
|
||||
# like TrigInfo.stop() can't be async (it's called from a __del__ method)
|
||||
#
|
||||
async def task_reaper(reaper_q):
|
||||
while True:
|
||||
try:
|
||||
cmd = await reaper_q.get()
|
||||
if cmd[0] == "exit":
|
||||
return
|
||||
if cmd[0] == "cancel":
|
||||
try:
|
||||
cmd[1].cancel()
|
||||
await cmd[1]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
else:
|
||||
_LOGGER.error("task_reaper: unknown command %s", cmd[0])
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
_LOGGER.error("task_reaper: got exception %s", traceback.format_exc(-1))
|
||||
|
||||
if not cls.task_reaper:
|
||||
cls.task_reaper_q = asyncio.Queue(0)
|
||||
cls.task_reaper = cls.create_task(task_reaper(cls.task_reaper_q))
|
||||
|
||||
#
|
||||
# start a task which creates tasks to run coros, and then syncs on their completion;
|
||||
# this is used by the shutdown trigger
|
||||
#
|
||||
async def task_waiter(waiter_q):
|
||||
aws = []
|
||||
while True:
|
||||
try:
|
||||
cmd = await waiter_q.get()
|
||||
if cmd[0] == "exit":
|
||||
return
|
||||
if cmd[0] == "await":
|
||||
aws.append(cls.create_task(cmd[1]))
|
||||
elif cmd[0] == "sync":
|
||||
if len(aws) > 0:
|
||||
await asyncio.gather(*aws)
|
||||
aws = []
|
||||
await cmd[1].put(0)
|
||||
else:
|
||||
_LOGGER.error("task_waiter: unknown command %s", cmd[0])
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
_LOGGER.error("task_waiter: got exception %s", traceback.format_exc(-1))
|
||||
|
||||
if not cls.task_waiter:
|
||||
cls.task_waiter_q = asyncio.Queue(0)
|
||||
cls.task_waiter = cls.create_task(task_waiter(cls.task_waiter_q))
|
||||
|
||||
@classmethod
|
||||
def reaper_cancel(cls, task):
|
||||
"""Send a task to be canceled by the reaper."""
|
||||
cls.task_reaper_q.put_nowait(["cancel", task])
|
||||
|
||||
@classmethod
|
||||
async def reaper_stop(cls):
|
||||
"""Tell the reaper task to exit."""
|
||||
if cls.task_reaper:
|
||||
cls.task_reaper_q.put_nowait(["exit"])
|
||||
await cls.task_reaper
|
||||
cls.task_reaper = None
|
||||
cls.task_reaper_q = None
|
||||
|
||||
@classmethod
|
||||
def waiter_await(cls, coro):
|
||||
"""Send a coro to be awaited by the waiter task."""
|
||||
cls.task_waiter_q.put_nowait(["await", coro])
|
||||
|
||||
@classmethod
|
||||
async def waiter_sync(cls):
|
||||
"""Wait until the waiter queue is empty."""
|
||||
if cls.task_waiter:
|
||||
sync_q = asyncio.Queue(0)
|
||||
cls.task_waiter_q.put_nowait(["sync", sync_q])
|
||||
await sync_q.get()
|
||||
|
||||
@classmethod
|
||||
async def waiter_stop(cls):
|
||||
"""Tell the waiter task to exit."""
|
||||
if cls.task_waiter:
|
||||
cls.task_waiter_q.put_nowait(["exit"])
|
||||
await cls.task_waiter
|
||||
cls.task_waiter = None
|
||||
cls.task_waiter_q = None
|
||||
|
||||
@classmethod
|
||||
async def async_sleep(cls, duration):
|
||||
"""Implement task.sleep()."""
|
||||
await asyncio.sleep(float(duration))
|
||||
|
||||
@classmethod
|
||||
async def event_fire(cls, event_type, **kwargs):
|
||||
"""Implement event.fire()."""
|
||||
curr_task = asyncio.current_task()
|
||||
if "context" in kwargs and isinstance(kwargs["context"], Context):
|
||||
context = kwargs["context"]
|
||||
del kwargs["context"]
|
||||
else:
|
||||
context = cls.task2context.get(curr_task, None)
|
||||
|
||||
cls.hass.bus.async_fire(event_type, kwargs, context=context)
|
||||
|
||||
@classmethod
|
||||
def store_hass_context(cls, hass_context):
|
||||
"""Store a context against the running task."""
|
||||
curr_task = asyncio.current_task()
|
||||
cls.task2context[curr_task] = hass_context
|
||||
|
||||
@classmethod
|
||||
def task_unique_factory(cls, ctx):
|
||||
"""Define and return task.unique() for this context."""
|
||||
|
||||
async def task_unique(name, kill_me=False):
|
||||
"""Implement task.unique()."""
|
||||
name = f"{ctx.get_global_ctx_name()}.{name}"
|
||||
curr_task = asyncio.current_task()
|
||||
if name in cls.unique_name2task:
|
||||
task = cls.unique_name2task[name]
|
||||
if kill_me:
|
||||
if task != curr_task:
|
||||
#
|
||||
# it seems we can't cancel ourselves, so we
|
||||
# tell the reaper task to cancel us
|
||||
#
|
||||
cls.reaper_cancel(curr_task)
|
||||
# wait to be canceled
|
||||
await asyncio.sleep(100000)
|
||||
elif task != curr_task and task in cls.our_tasks:
|
||||
# only cancel tasks if they are ones we started
|
||||
cls.reaper_cancel(task)
|
||||
if curr_task in cls.our_tasks:
|
||||
if name in cls.unique_name2task:
|
||||
task = cls.unique_name2task[name]
|
||||
if task in cls.unique_task2name:
|
||||
cls.unique_task2name[task].discard(name)
|
||||
cls.unique_name2task[name] = curr_task
|
||||
if curr_task not in cls.unique_task2name:
|
||||
cls.unique_task2name[curr_task] = set()
|
||||
cls.unique_task2name[curr_task].add(name)
|
||||
|
||||
return task_unique
|
||||
|
||||
@classmethod
|
||||
async def user_task_cancel(cls, task=None):
|
||||
"""Implement task.cancel()."""
|
||||
do_sleep = False
|
||||
if not task:
|
||||
task = asyncio.current_task()
|
||||
do_sleep = True
|
||||
if task not in cls.our_tasks:
|
||||
raise TypeError(f"{task} is not a user-started task")
|
||||
cls.reaper_cancel(task)
|
||||
if do_sleep:
|
||||
# wait to be canceled
|
||||
await asyncio.sleep(100000)
|
||||
|
||||
@classmethod
|
||||
async def user_task_current_task(cls):
|
||||
"""Implement task.current_task()."""
|
||||
return asyncio.current_task()
|
||||
|
||||
@classmethod
|
||||
def task_name2id_factory(cls, ctx):
|
||||
"""Define and return task.name2id() for this context."""
|
||||
|
||||
def user_task_name2id(name=None):
|
||||
"""Implement task.name2id()."""
|
||||
prefix = f"{ctx.get_global_ctx_name()}."
|
||||
if name is None:
|
||||
ret = {}
|
||||
for task_name, task_id in cls.unique_name2task.items():
|
||||
if task_name.startswith(prefix):
|
||||
ret[task_name[len(prefix) :]] = task_id
|
||||
return ret
|
||||
if prefix + name in cls.unique_name2task:
|
||||
return cls.unique_name2task[prefix + name]
|
||||
raise NameError(f"task name '{name}' is unknown")
|
||||
|
||||
return user_task_name2id
|
||||
|
||||
@classmethod
|
||||
async def user_task_wait(cls, aws, **kwargs):
|
||||
"""Implement task.wait()."""
|
||||
return await asyncio.wait(aws, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def user_task_remove_done_callback(cls, task, callback):
|
||||
"""Implement task.remove_done_callback()."""
|
||||
cls.task2cb[task]["cb"].pop(callback, None)
|
||||
|
||||
@classmethod
|
||||
def unique_name_used(cls, ctx, name):
|
||||
"""Return whether the current unique name is in use."""
|
||||
name = f"{ctx.get_global_ctx_name()}.{name}"
|
||||
return name in cls.unique_name2task
|
||||
|
||||
@classmethod
|
||||
def service_has_service(cls, domain, name):
|
||||
"""Implement service.has_service()."""
|
||||
return cls.hass.services.has_service(domain, name)
|
||||
|
||||
@classmethod
|
||||
async def service_call(cls, domain, name, **kwargs):
|
||||
"""Implement service.call()."""
|
||||
curr_task = asyncio.current_task()
|
||||
hass_args = {}
|
||||
for keyword, typ, default in [
|
||||
("context", [Context], cls.task2context.get(curr_task, None)),
|
||||
("blocking", [bool], None),
|
||||
("return_response", [bool], None),
|
||||
]:
|
||||
if keyword in kwargs and type(kwargs[keyword]) in typ:
|
||||
hass_args[keyword] = kwargs.pop(keyword)
|
||||
elif default:
|
||||
hass_args[keyword] = default
|
||||
|
||||
return await cls.hass_services_async_call(domain, name, kwargs, **hass_args)
|
||||
|
||||
@classmethod
|
||||
async def service_completions(cls, root):
|
||||
"""Return possible completions of HASS services."""
|
||||
words = set()
|
||||
services = cls.hass.services.async_services()
|
||||
num_period = root.count(".")
|
||||
if num_period == 1:
|
||||
domain, svc_root = root.split(".")
|
||||
if domain in services:
|
||||
words |= {f"{domain}.{svc}" for svc in services[domain] if svc.lower().startswith(svc_root)}
|
||||
elif num_period == 0:
|
||||
words |= {domain for domain in services if domain.lower().startswith(root)}
|
||||
|
||||
return words
|
||||
|
||||
@classmethod
|
||||
async def func_completions(cls, root):
|
||||
"""Return possible completions of functions."""
|
||||
funcs = {**cls.functions, **cls.ast_functions}
|
||||
words = {name for name in funcs if name.lower().startswith(root)}
|
||||
|
||||
return words
|
||||
|
||||
@classmethod
|
||||
def register(cls, funcs):
|
||||
"""Register functions to be available for calling."""
|
||||
cls.functions.update(funcs)
|
||||
|
||||
@classmethod
|
||||
def register_ast(cls, funcs):
|
||||
"""Register functions that need ast context to be available for calling."""
|
||||
cls.ast_functions.update(funcs)
|
||||
|
||||
@classmethod
|
||||
def install_ast_funcs(cls, ast_ctx):
|
||||
"""Install ast functions into the local symbol table."""
|
||||
sym_table = {name: func(ast_ctx) for name, func in cls.ast_functions.items()}
|
||||
ast_ctx.set_local_sym_table(sym_table)
|
||||
|
||||
@classmethod
|
||||
def get(cls, name):
|
||||
"""Lookup a function locally and then as a service."""
|
||||
func = cls.functions.get(name, None)
|
||||
if func:
|
||||
return func
|
||||
|
||||
name_parts = name.split(".")
|
||||
if len(name_parts) != 2:
|
||||
return None
|
||||
|
||||
domain, service = name_parts
|
||||
if not cls.service_has_service(domain, service):
|
||||
return None
|
||||
|
||||
def service_call_factory(domain, service):
|
||||
async def service_call(*args, **kwargs):
|
||||
curr_task = asyncio.current_task()
|
||||
hass_args = {}
|
||||
for keyword, typ, default in [
|
||||
("context", [Context], cls.task2context.get(curr_task, None)),
|
||||
("blocking", [bool], None),
|
||||
("return_response", [bool], None),
|
||||
]:
|
||||
if keyword in kwargs and type(kwargs[keyword]) in typ:
|
||||
hass_args[keyword] = kwargs.pop(keyword)
|
||||
elif default:
|
||||
hass_args[keyword] = default
|
||||
|
||||
if len(args) != 0:
|
||||
raise TypeError(f"service {domain}.{service} takes only keyword arguments")
|
||||
|
||||
return await cls.hass_services_async_call(domain, service, kwargs, **hass_args)
|
||||
|
||||
return service_call
|
||||
|
||||
return service_call_factory(domain, service)
|
||||
|
||||
@classmethod
|
||||
async def hass_services_async_call(cls, domain, service, kwargs, **hass_args):
|
||||
"""Call a hass async service."""
|
||||
if SERVICE_RESPONSE_ONLY is None:
|
||||
# backwards compatibility < 2023.7
|
||||
await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
|
||||
else:
|
||||
# allow service responses >= 2023.7
|
||||
if (
|
||||
"return_response" in hass_args
|
||||
and hass_args["return_response"]
|
||||
and "blocking" not in hass_args
|
||||
):
|
||||
hass_args["blocking"] = True
|
||||
elif (
|
||||
"return_response" not in hass_args
|
||||
and cls.hass.services.supports_response(domain, service) == SERVICE_RESPONSE_ONLY
|
||||
):
|
||||
hass_args["return_response"] = True
|
||||
if "blocking" not in hass_args:
|
||||
hass_args["blocking"] = True
|
||||
return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
|
||||
|
||||
@classmethod
|
||||
async def run_coro(cls, coro, ast_ctx=None):
|
||||
"""Run coroutine task and update unique task on start and exit."""
|
||||
#
|
||||
# Add a placeholder for the new task so we know it's one we started
|
||||
#
|
||||
task: asyncio.Task = None
|
||||
try:
|
||||
task = asyncio.current_task()
|
||||
cls.our_tasks.add(task)
|
||||
if ast_ctx is not None:
|
||||
cls.task_done_callback_ctx(task, ast_ctx)
|
||||
result = await coro
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
_LOGGER.error("run_coro: got exception %s", traceback.format_exc(-1))
|
||||
finally:
|
||||
if task in cls.task2cb:
|
||||
for callback, info in cls.task2cb[task]["cb"].items():
|
||||
ast_ctx, args, kwargs = info
|
||||
await ast_ctx.call_func(callback, None, *args, **kwargs)
|
||||
if ast_ctx.get_exception_obj():
|
||||
ast_ctx.get_logger().error(ast_ctx.get_exception_long())
|
||||
break
|
||||
if task in cls.unique_task2name:
|
||||
for name in cls.unique_task2name[task]:
|
||||
del cls.unique_name2task[name]
|
||||
del cls.unique_task2name[task]
|
||||
cls.task2context.pop(task, None)
|
||||
cls.task2cb.pop(task, None)
|
||||
cls.our_tasks.discard(task)
|
||||
|
||||
@classmethod
|
||||
def create_task(cls, coro, ast_ctx=None):
|
||||
"""Create a new task that runs a coroutine."""
|
||||
return cls.hass.loop.create_task(cls.run_coro(coro, ast_ctx=ast_ctx))
|
||||
|
||||
@classmethod
|
||||
def service_register(
|
||||
cls, global_ctx_name, domain, service, callback, supports_response=SERVICE_RESPONSE_NONE
|
||||
):
|
||||
"""Register a new service callback."""
|
||||
key = f"{domain}.{service}"
|
||||
if key not in cls.service_cnt:
|
||||
cls.service_cnt[key] = 0
|
||||
if key not in cls.service2global_ctx:
|
||||
cls.service2global_ctx[key] = global_ctx_name
|
||||
if cls.service2global_ctx[key] != global_ctx_name:
|
||||
raise ValueError(
|
||||
f"{global_ctx_name}: can't register service {key}; already defined in {cls.service2global_ctx[key]}"
|
||||
)
|
||||
cls.service_cnt[key] += 1
|
||||
if SERVICE_RESPONSE_ONLY is None:
|
||||
# backwards compatibility < 2023.7
|
||||
cls.hass.services.async_register(domain, service, callback)
|
||||
else:
|
||||
# allow service responses >= 2023.7
|
||||
cls.hass.services.async_register(domain, service, callback, supports_response=supports_response)
|
||||
|
||||
@classmethod
|
||||
def service_remove(cls, global_ctx_name, domain, service):
|
||||
"""Remove a service callback."""
|
||||
key = f"{domain}.{service}"
|
||||
if cls.service_cnt.get(key, 0) > 1:
|
||||
cls.service_cnt[key] -= 1
|
||||
return
|
||||
cls.service_cnt[key] = 0
|
||||
cls.hass.services.async_remove(domain, service)
|
||||
cls.service2global_ctx.pop(key, None)
|
||||
|
||||
@classmethod
|
||||
def task_done_callback_ctx(cls, task, ast_ctx):
|
||||
"""Set the ast_ctx for a task, which is needed for done callbacks."""
|
||||
if task not in cls.task2cb or "ctx" not in cls.task2cb[task]:
|
||||
cls.task2cb[task] = {"ctx": ast_ctx, "cb": {}}
|
||||
|
||||
@classmethod
|
||||
def task_add_done_callback(cls, task, ast_ctx, callback, *args, **kwargs):
|
||||
"""Add a done callback to the given task."""
|
||||
if ast_ctx is None:
|
||||
ast_ctx = cls.task2cb[task]["ctx"]
|
||||
cls.task2cb[task]["cb"][callback] = [ast_ctx, args, kwargs]
|
||||
352
config/custom_components/pyscript/global_ctx.py
Normal file
352
config/custom_components/pyscript/global_ctx.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""Global context handling."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
||||
from .const import CONF_HASS_IS_GLOBAL, CONFIG_ENTRY, DOMAIN, FOLDER, LOGGER_PATH
|
||||
from .eval import AstEval, EvalFunc
|
||||
from .function import Function
|
||||
from .trigger import TrigInfo
|
||||
|
||||
_LOGGER = logging.getLogger(LOGGER_PATH + ".global_ctx")
|
||||
|
||||
|
||||
class GlobalContext:
|
||||
"""Define class for global variables and trigger context."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
global_sym_table: Dict[str, Any] = None,
|
||||
manager=None,
|
||||
rel_import_path: str = None,
|
||||
app_config: Dict[str, Any] = None,
|
||||
source: str = None,
|
||||
mtime: float = None,
|
||||
) -> None:
|
||||
"""Initialize GlobalContext."""
|
||||
self.name: str = name
|
||||
self.global_sym_table: Dict[str, Any] = global_sym_table if global_sym_table else {}
|
||||
self.triggers: Set[EvalFunc] = set()
|
||||
self.triggers_delay_start: Set[EvalFunc] = set()
|
||||
self.logger: logging.Logger = logging.getLogger(LOGGER_PATH + "." + name)
|
||||
self.manager: GlobalContextMgr = manager
|
||||
self.auto_start: bool = False
|
||||
self.module: ModuleType = None
|
||||
self.rel_import_path: str = rel_import_path
|
||||
self.source: str = source
|
||||
self.file_path: str = None
|
||||
self.mtime: float = mtime
|
||||
self.app_config: Dict[str, Any] = app_config
|
||||
self.imports: Set[str] = set()
|
||||
config_entry: ConfigEntry = Function.hass.data.get(DOMAIN, {}).get(CONFIG_ENTRY, {})
|
||||
if config_entry.data.get(CONF_HASS_IS_GLOBAL, False):
|
||||
#
|
||||
# expose hass as a global variable if configured
|
||||
#
|
||||
self.global_sym_table["hass"] = Function.hass
|
||||
if app_config:
|
||||
self.global_sym_table["pyscript.app_config"] = app_config.copy()
|
||||
|
||||
def trigger_register(self, func: EvalFunc) -> bool:
|
||||
"""Register a trigger function; return True if start now."""
|
||||
self.triggers.add(func)
|
||||
if self.auto_start:
|
||||
return True
|
||||
self.triggers_delay_start.add(func)
|
||||
return False
|
||||
|
||||
def trigger_unregister(self, func: EvalFunc) -> None:
|
||||
"""Unregister a trigger function."""
|
||||
self.triggers.discard(func)
|
||||
self.triggers_delay_start.discard(func)
|
||||
|
||||
def set_auto_start(self, auto_start: bool) -> None:
|
||||
"""Set the auto-start flag."""
|
||||
self.auto_start = auto_start
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start any unstarted triggers."""
|
||||
for func in self.triggers_delay_start:
|
||||
func.trigger_start()
|
||||
self.triggers_delay_start = set()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop all triggers and auto_start."""
|
||||
for func in self.triggers:
|
||||
func.trigger_stop()
|
||||
self.triggers = set()
|
||||
self.triggers_delay_start = set()
|
||||
self.set_auto_start(False)
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""Return the global context name."""
|
||||
return self.name
|
||||
|
||||
def set_logger_name(self, name) -> None:
|
||||
"""Set the global context logging name."""
|
||||
self.logger = logging.getLogger(LOGGER_PATH + "." + name)
|
||||
|
||||
def get_global_sym_table(self) -> Dict[str, Any]:
|
||||
"""Return the global symbol table."""
|
||||
return self.global_sym_table
|
||||
|
||||
def get_source(self) -> str:
|
||||
"""Return the source code."""
|
||||
return self.source
|
||||
|
||||
def get_app_config(self) -> Dict[str, Any]:
|
||||
"""Return the app config."""
|
||||
return self.app_config
|
||||
|
||||
def get_mtime(self) -> float:
|
||||
"""Return the mtime."""
|
||||
return self.mtime
|
||||
|
||||
def get_file_path(self) -> str:
|
||||
"""Return the file path."""
|
||||
return self.file_path
|
||||
|
||||
def get_imports(self) -> Set[str]:
|
||||
"""Return the imports."""
|
||||
return self.imports
|
||||
|
||||
def get_trig_info(self, name: str, trig_args: Dict[str, Any]) -> TrigInfo:
|
||||
"""Return a new trigger info instance with the given args."""
|
||||
return TrigInfo(name, trig_args, self)
|
||||
|
||||
async def module_import(self, module_name: str, import_level: int) -> List[Optional[str]]:
|
||||
"""Import a pyscript module from the pyscript/modules or apps folder."""
|
||||
|
||||
pyscript_dir = Function.hass.config.path(FOLDER)
|
||||
module_path = module_name.replace(".", "/")
|
||||
file_paths = []
|
||||
|
||||
def find_first_file(file_paths: List[Set[str]]) -> List[Optional[Union[str, ModuleType]]]:
|
||||
for ctx_name, path, rel_path in file_paths:
|
||||
abs_path = os.path.join(pyscript_dir, path)
|
||||
if os.path.isfile(abs_path):
|
||||
return [ctx_name, abs_path, rel_path]
|
||||
return None
|
||||
|
||||
#
|
||||
# first build a list of potential import files
|
||||
#
|
||||
if import_level > 0:
|
||||
if self.rel_import_path is None:
|
||||
raise ImportError("attempted relative import with no known parent package")
|
||||
path = self.rel_import_path
|
||||
if path.endswith("/__init__"):
|
||||
path = os.path.dirname(path)
|
||||
ctx_name = self.name
|
||||
for _ in range(import_level - 1):
|
||||
path = os.path.dirname(path)
|
||||
idx = ctx_name.rfind(".")
|
||||
if path.find("/") < 0 or idx < 0:
|
||||
raise ImportError("attempted relative import above parent package")
|
||||
ctx_name = ctx_name[0:idx]
|
||||
ctx_name += f".{module_name}"
|
||||
module_info = [ctx_name, f"{path}/{module_path}.py", path]
|
||||
path += f"/{module_path}"
|
||||
file_paths.append([ctx_name, f"{path}/__init__.py", path])
|
||||
file_paths.append(module_info)
|
||||
module_name = ctx_name[ctx_name.find(".") + 1 :]
|
||||
|
||||
else:
|
||||
if self.rel_import_path is not None and self.rel_import_path.startswith("apps/"):
|
||||
ctx_name = f"apps.{module_name}"
|
||||
file_paths.append([ctx_name, f"apps/{module_path}/__init__.py", f"apps/{module_path}"])
|
||||
file_paths.append([ctx_name, f"apps/{module_path}.py", f"apps/{module_path}"])
|
||||
|
||||
ctx_name = f"modules.{module_name}"
|
||||
file_paths.append([ctx_name, f"modules/{module_path}/__init__.py", f"modules/{module_path}"])
|
||||
file_paths.append([ctx_name, f"modules/{module_path}.py", None])
|
||||
|
||||
#
|
||||
# now see if we have loaded it already
|
||||
#
|
||||
for ctx_name, _, _ in file_paths:
|
||||
mod_ctx = self.manager.get(ctx_name)
|
||||
if mod_ctx and mod_ctx.module:
|
||||
self.imports.add(mod_ctx.get_name())
|
||||
return [mod_ctx.module, None]
|
||||
|
||||
#
|
||||
# not loaded already, so try to find and import it
|
||||
#
|
||||
file_info = await Function.hass.async_add_executor_job(find_first_file, file_paths)
|
||||
if not file_info:
|
||||
return [None, None]
|
||||
|
||||
[ctx_name, file_path, rel_import_path] = file_info
|
||||
|
||||
mod = ModuleType(module_name)
|
||||
global_ctx = GlobalContext(
|
||||
ctx_name, global_sym_table=mod.__dict__, manager=self.manager, rel_import_path=rel_import_path
|
||||
)
|
||||
global_ctx.set_auto_start(True)
|
||||
_, error_ctx = await self.manager.load_file(global_ctx, file_path)
|
||||
if error_ctx:
|
||||
_LOGGER.error(
|
||||
"module_import: failed to load module %s, ctx = %s, path = %s",
|
||||
module_name,
|
||||
ctx_name,
|
||||
file_path,
|
||||
)
|
||||
return [None, error_ctx]
|
||||
global_ctx.module = mod
|
||||
self.imports.add(ctx_name)
|
||||
return [mod, None]
|
||||
|
||||
|
||||
class GlobalContextMgr:
|
||||
"""Define class for all global contexts."""
|
||||
|
||||
#
|
||||
# map of context names to contexts
|
||||
#
|
||||
contexts = {}
|
||||
|
||||
#
|
||||
# sequence number for sessions
|
||||
#
|
||||
name_seq = 0
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Report an error if GlobalContextMgr in instantiated."""
|
||||
_LOGGER.error("GlobalContextMgr class is not meant to be instantiated")
|
||||
|
||||
@classmethod
|
||||
def init(cls) -> None:
|
||||
"""Initialize GlobalContextMgr."""
|
||||
|
||||
def get_global_ctx_factory(ast_ctx: AstEval) -> Callable[[], str]:
|
||||
"""Generate a pyscript.get_global_ctx() function with given ast_ctx."""
|
||||
|
||||
async def get_global_ctx():
|
||||
return ast_ctx.get_global_ctx_name()
|
||||
|
||||
return get_global_ctx
|
||||
|
||||
def list_global_ctx_factory(ast_ctx: AstEval) -> Callable[[], List[str]]:
|
||||
"""Generate a pyscript.list_global_ctx() function with given ast_ctx."""
|
||||
|
||||
async def list_global_ctx():
|
||||
ctx_names = set(cls.contexts.keys())
|
||||
curr_ctx_name = ast_ctx.get_global_ctx_name()
|
||||
ctx_names.discard(curr_ctx_name)
|
||||
return [curr_ctx_name] + sorted(sorted(ctx_names))
|
||||
|
||||
return list_global_ctx
|
||||
|
||||
def set_global_ctx_factory(ast_ctx: AstEval) -> Callable[[str], None]:
|
||||
"""Generate a pyscript.set_global_ctx() function with given ast_ctx."""
|
||||
|
||||
async def set_global_ctx(name):
|
||||
global_ctx = cls.get(name)
|
||||
if global_ctx is None:
|
||||
raise NameError(f"global context '{name}' does not exist")
|
||||
ast_ctx.set_global_ctx(global_ctx)
|
||||
ast_ctx.set_logger_name(global_ctx.name)
|
||||
|
||||
return set_global_ctx
|
||||
|
||||
ast_funcs = {
|
||||
"pyscript.get_global_ctx": get_global_ctx_factory,
|
||||
"pyscript.list_global_ctx": list_global_ctx_factory,
|
||||
"pyscript.set_global_ctx": set_global_ctx_factory,
|
||||
}
|
||||
|
||||
Function.register_ast(ast_funcs)
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str) -> Optional[str]:
|
||||
"""Return the GlobalContext given a name."""
|
||||
return cls.contexts.get(name, None)
|
||||
|
||||
@classmethod
|
||||
def set(cls, name: str, global_ctx: GlobalContext) -> None:
|
||||
"""Save the GlobalContext by name."""
|
||||
cls.contexts[name] = global_ctx
|
||||
|
||||
@classmethod
|
||||
def items(cls) -> List[Set[Union[str, GlobalContext]]]:
|
||||
"""Return all the global context items."""
|
||||
return sorted(cls.contexts.items())
|
||||
|
||||
@classmethod
|
||||
def delete(cls, name: str) -> None:
|
||||
"""Delete the given GlobalContext."""
|
||||
if name in cls.contexts:
|
||||
global_ctx = cls.contexts[name]
|
||||
global_ctx.stop()
|
||||
del cls.contexts[name]
|
||||
|
||||
@classmethod
|
||||
def new_name(cls, root: str) -> str:
|
||||
"""Find a unique new name by appending a sequence number to root."""
|
||||
while True:
|
||||
name = f"{root}{cls.name_seq}"
|
||||
cls.name_seq += 1
|
||||
if name not in cls.contexts:
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
async def load_file(
|
||||
cls, global_ctx: GlobalContext, file_path: str, source: str = None, reload: bool = False
|
||||
) -> Set[Union[bool, AstEval]]:
|
||||
"""Load, parse and run the given script file; returns error ast_ctx on error, or None if ok."""
|
||||
|
||||
mtime = None
|
||||
if source is None:
|
||||
|
||||
def read_file(path: str) -> Set[Union[str, float]]:
|
||||
try:
|
||||
with open(path, encoding="utf-8") as file_desc:
|
||||
source = file_desc.read()
|
||||
return source, os.path.getmtime(path)
|
||||
except Exception as exc:
|
||||
_LOGGER.error("%s", exc)
|
||||
return None, 0
|
||||
|
||||
source, mtime = await Function.hass.async_add_executor_job(read_file, file_path)
|
||||
|
||||
if source is None:
|
||||
return False, None
|
||||
|
||||
ctx_curr = cls.get(global_ctx.get_name())
|
||||
if ctx_curr:
|
||||
# stop triggers and destroy old global context
|
||||
ctx_curr.stop()
|
||||
cls.delete(global_ctx.get_name())
|
||||
|
||||
#
|
||||
# create new ast eval context and parse source file
|
||||
#
|
||||
ast_ctx = AstEval(global_ctx.get_name(), global_ctx)
|
||||
Function.install_ast_funcs(ast_ctx)
|
||||
|
||||
if not ast_ctx.parse(source, filename=file_path):
|
||||
exc = ast_ctx.get_exception_long()
|
||||
ast_ctx.get_logger().error(exc)
|
||||
global_ctx.stop()
|
||||
return False, ast_ctx
|
||||
await ast_ctx.eval()
|
||||
exc = ast_ctx.get_exception_long()
|
||||
if exc is not None:
|
||||
ast_ctx.get_logger().error(exc)
|
||||
global_ctx.stop()
|
||||
return False, ast_ctx
|
||||
global_ctx.source = source
|
||||
global_ctx.file_path = file_path
|
||||
if mtime is not None:
|
||||
global_ctx.mtime = mtime
|
||||
cls.set(global_ctx.get_name(), global_ctx)
|
||||
|
||||
_LOGGER.info("%s %s", "Reloaded" if reload else "Loaded", file_path)
|
||||
|
||||
return True, None
|
||||
921
config/custom_components/pyscript/jupyter_kernel.py
Normal file
921
config/custom_components/pyscript/jupyter_kernel.py
Normal file
@@ -0,0 +1,921 @@
|
||||
"""Pyscript Jupyter kernel."""
|
||||
|
||||
#
|
||||
# Based on simple_kernel.py by Doug Blank <doug.blank@gmail.com>
|
||||
# https://github.com/dsblank/simple_kernel
|
||||
# license: public domain
|
||||
# Thanks Doug!
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import logging.handlers
|
||||
import re
|
||||
from struct import pack, unpack
|
||||
import traceback
|
||||
import uuid
|
||||
|
||||
from .const import LOGGER_PATH
|
||||
from .function import Function
|
||||
from .global_ctx import GlobalContextMgr
|
||||
from .state import State
|
||||
|
||||
_LOGGER = logging.getLogger(LOGGER_PATH + ".jupyter_kernel")
|
||||
|
||||
# Globals:
|
||||
|
||||
DELIM = b"<IDS|MSG>"
|
||||
|
||||
|
||||
def msg_id():
|
||||
"""Return a new uuid for message id."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def str_to_bytes(string):
|
||||
"""Encode a string in bytes."""
|
||||
return string.encode("utf-8")
|
||||
|
||||
|
||||
class KernelBufferingHandler(logging.handlers.BufferingHandler):
|
||||
"""Memory-based handler for logging; send via stdout queue."""
|
||||
|
||||
def __init__(self, housekeep_q):
|
||||
"""Initialize KernelBufferingHandler instance."""
|
||||
super().__init__(0)
|
||||
self.housekeep_q = housekeep_q
|
||||
|
||||
def flush(self):
|
||||
"""Flush is a no-op."""
|
||||
|
||||
def shouldFlush(self, record):
|
||||
"""Write the buffer to the housekeeping queue."""
|
||||
try:
|
||||
self.housekeep_q.put_nowait(["stdout", self.format(record)])
|
||||
except asyncio.QueueFull:
|
||||
_LOGGER.error("housekeep_q unexpectedly full")
|
||||
|
||||
|
||||
################################################################
|
||||
class ZmqSocket:
|
||||
"""Defines a minimal implementation of a small subset of ZMQ."""
|
||||
|
||||
#
|
||||
# This allows pyscript to work with Jupyter without the real zmq
|
||||
# and pyzmq packages, which might not be available or easy to
|
||||
# install on the wide set of HASS platforms.
|
||||
#
|
||||
def __init__(self, reader, writer, sock_type):
|
||||
"""Initialize a ZMQ socket with the given type and reader/writer streams."""
|
||||
self.writer = writer
|
||||
self.reader = reader
|
||||
self.type = sock_type
|
||||
|
||||
async def read_bytes(self, num_bytes):
|
||||
"""Read bytes from ZMQ socket."""
|
||||
data = b""
|
||||
while len(data) < num_bytes:
|
||||
new_data = await self.reader.read(num_bytes - len(data))
|
||||
if len(new_data) == 0:
|
||||
raise EOFError
|
||||
data += new_data
|
||||
return data
|
||||
|
||||
async def write_bytes(self, raw_msg):
|
||||
"""Write bytes to ZMQ socket."""
|
||||
self.writer.write(raw_msg)
|
||||
await self.writer.drain()
|
||||
|
||||
async def handshake(self):
|
||||
"""Do initial greeting handshake on a new ZMQ connection."""
|
||||
await self.write_bytes(b"\xff\x00\x00\x00\x00\x00\x00\x00\x01\x7f")
|
||||
_ = await self.read_bytes(10)
|
||||
# _LOGGER.debug(f"handshake: got initial greeting {greeting}")
|
||||
await self.write_bytes(b"\x03")
|
||||
_ = await self.read_bytes(1)
|
||||
await self.write_bytes(b"\x00" + "NULL".encode() + b"\x00" * 16 + b"\x00" + b"\x00" * 31)
|
||||
_ = await self.read_bytes(53)
|
||||
# _LOGGER.debug(f"handshake: got rest of greeting {greeting}")
|
||||
params = [["Socket-Type", self.type]]
|
||||
if self.type == "ROUTER":
|
||||
params.append(["Identity", ""])
|
||||
await self.send_cmd("READY", params)
|
||||
|
||||
async def recv(self, multipart=False):
|
||||
"""Receive a message from ZMQ socket."""
|
||||
parts = []
|
||||
while 1:
|
||||
cmd = (await self.read_bytes(1))[0]
|
||||
if cmd & 0x2:
|
||||
msg_len = unpack(">Q", await self.read_bytes(8))[0]
|
||||
else:
|
||||
msg_len = (await self.read_bytes(1))[0]
|
||||
msg_body = await self.read_bytes(msg_len)
|
||||
if cmd & 0x4:
|
||||
# _LOGGER.debug(f"recv: got cmd {msg_body}")
|
||||
cmd_len = msg_body[0]
|
||||
cmd = msg_body[1 : cmd_len + 1]
|
||||
msg_body = msg_body[cmd_len + 1 :]
|
||||
params = []
|
||||
while len(msg_body) > 0:
|
||||
param_len = msg_body[0]
|
||||
param = msg_body[1 : param_len + 1]
|
||||
msg_body = msg_body[param_len + 1 :]
|
||||
value_len = unpack(">L", msg_body[0:4])[0]
|
||||
value = msg_body[4 : 4 + value_len]
|
||||
msg_body = msg_body[4 + value_len :]
|
||||
params.append([param, value])
|
||||
# _LOGGER.debug(f"recv: got cmd={cmd}, params={params}")
|
||||
else:
|
||||
parts.append(msg_body)
|
||||
if cmd in (0x0, 0x2):
|
||||
# _LOGGER.debug(f"recv: got msg {parts}")
|
||||
if not multipart:
|
||||
return b"".join(parts)
|
||||
|
||||
return parts
|
||||
|
||||
async def recv_multipart(self):
|
||||
"""Receive a multipart message from ZMQ socket."""
|
||||
return await self.recv(multipart=True)
|
||||
|
||||
async def send_cmd(self, cmd, params):
|
||||
"""Send a command over ZMQ socket."""
|
||||
raw_msg = bytearray([len(cmd)]) + cmd.encode()
|
||||
for param in params:
|
||||
raw_msg += bytearray([len(param[0])]) + param[0].encode()
|
||||
raw_msg += pack(">L", len(param[1])) + param[1].encode()
|
||||
len_msg = len(raw_msg)
|
||||
if len_msg <= 255:
|
||||
raw_msg = bytearray([0x4, len_msg]) + raw_msg
|
||||
else:
|
||||
raw_msg = bytearray([0x6]) + pack(">Q", len_msg) + raw_msg
|
||||
# _LOGGER.debug(f"send_cmd: sending {raw_msg}")
|
||||
await self.write_bytes(raw_msg)
|
||||
|
||||
async def send(self, msg):
|
||||
"""Send a message over ZMQ socket."""
|
||||
len_msg = len(msg)
|
||||
if len_msg <= 255:
|
||||
raw_msg = bytearray([0x1, 0x0, 0x0, len_msg]) + msg
|
||||
else:
|
||||
raw_msg = bytearray([0x1, 0x0, 0x2]) + pack(">Q", len_msg) + msg
|
||||
# _LOGGER.debug(f"send: sending {raw_msg}")
|
||||
await self.write_bytes(raw_msg)
|
||||
|
||||
async def send_multipart(self, parts):
|
||||
"""Send multipart messages over ZMQ socket."""
|
||||
raw_msg = b""
|
||||
for i, part in enumerate(parts):
|
||||
len_part = len(part)
|
||||
cmd = 0x1 if i < len(parts) - 1 else 0x0
|
||||
if len_part <= 255:
|
||||
raw_msg += bytearray([cmd, len_part]) + part
|
||||
else:
|
||||
raw_msg += bytearray([cmd + 2]) + pack(">Q", len_part) + part
|
||||
# _LOGGER.debug(f"send_multipart: sending {raw_msg}")
|
||||
await self.write_bytes(raw_msg)
|
||||
|
||||
def close(self):
|
||||
"""Close the ZMQ socket."""
|
||||
self.writer.close()
|
||||
|
||||
|
||||
##########################################
|
||||
class Kernel:
|
||||
"""Define a Jupyter Kernel class."""
|
||||
|
||||
def __init__(self, config, ast_ctx, global_ctx, global_ctx_name):
|
||||
"""Initialize a Kernel object, one instance per session."""
|
||||
self.config = config.copy()
|
||||
self.global_ctx = global_ctx
|
||||
self.global_ctx_name = global_ctx_name
|
||||
self.ast_ctx = ast_ctx
|
||||
|
||||
self.secure_key = str_to_bytes(self.config["key"])
|
||||
self.no_connect_timeout = self.config.get("no_connect_timeout", 30)
|
||||
self.signature_schemes = {"hmac-sha256": hashlib.sha256}
|
||||
self.auth = hmac.HMAC(
|
||||
self.secure_key,
|
||||
digestmod=self.signature_schemes[self.config["signature_scheme"]],
|
||||
)
|
||||
self.execution_count = 1
|
||||
self.engine_id = str(uuid.uuid4())
|
||||
|
||||
self.heartbeat_server = None
|
||||
self.iopub_server = None
|
||||
self.control_server = None
|
||||
self.stdin_server = None
|
||||
self.shell_server = None
|
||||
|
||||
self.heartbeat_port = None
|
||||
self.iopub_port = None
|
||||
self.control_port = None
|
||||
self.stdin_port = None
|
||||
self.shell_port = None
|
||||
# this should probably be a configuration parameter
|
||||
self.avail_port = 50321
|
||||
|
||||
# there can be multiple iopub subscribers, with corresponding tasks
|
||||
self.iopub_socket = set()
|
||||
|
||||
self.tasks = {}
|
||||
self.task_cnt = 0
|
||||
self.task_cnt_max = 0
|
||||
|
||||
self.session_cleanup_callback = None
|
||||
|
||||
self.housekeep_q = asyncio.Queue(0)
|
||||
|
||||
self.parent_header = None
|
||||
|
||||
#
|
||||
# we create a logging handler so that output from the log functions
|
||||
# gets delivered back to Jupyter as stdout
|
||||
#
|
||||
self.console = KernelBufferingHandler(self.housekeep_q)
|
||||
self.console.setLevel(logging.DEBUG)
|
||||
# set a format which is just the message
|
||||
formatter = logging.Formatter("%(message)s")
|
||||
self.console.setFormatter(formatter)
|
||||
|
||||
# match alphanum or "." at end of line
|
||||
self.completion_re = re.compile(r".*?([\w.]*)$", re.DOTALL)
|
||||
|
||||
# see if line ends in a ":", with optional whitespace and comment
|
||||
# note: this doesn't detect if we are inside a quoted string...
|
||||
self.colon_end_re = re.compile(r".*: *(#.*)?$")
|
||||
|
||||
def msg_sign(self, msg_lst):
|
||||
"""Sign a message with a secure signature."""
|
||||
auth_hmac = self.auth.copy()
|
||||
for msg in msg_lst:
|
||||
auth_hmac.update(msg)
|
||||
return str_to_bytes(auth_hmac.hexdigest())
|
||||
|
||||
def deserialize_wire_msg(self, wire_msg):
|
||||
"""Split the routing prefix and message frames from a message on the wire."""
|
||||
delim_idx = wire_msg.index(DELIM)
|
||||
identities = wire_msg[:delim_idx]
|
||||
m_signature = wire_msg[delim_idx + 1]
|
||||
msg_frames = wire_msg[delim_idx + 2 :]
|
||||
|
||||
def decode(msg):
|
||||
return json.loads(msg.decode("utf-8"))
|
||||
|
||||
msg = {}
|
||||
msg["header"] = decode(msg_frames[0])
|
||||
msg["parent_header"] = decode(msg_frames[1])
|
||||
msg["metadata"] = decode(msg_frames[2])
|
||||
msg["content"] = decode(msg_frames[3])
|
||||
check_sig = self.msg_sign(msg_frames)
|
||||
if check_sig != m_signature:
|
||||
_LOGGER.error(
|
||||
"signature mismatch: check_sig=%s, m_signature=%s, wire_msg=%s",
|
||||
check_sig,
|
||||
m_signature,
|
||||
wire_msg,
|
||||
)
|
||||
raise ValueError("Signatures do not match")
|
||||
|
||||
return identities, msg
|
||||
|
||||
def new_header(self, msg_type):
|
||||
"""Make a new header."""
|
||||
return {
|
||||
"date": datetime.datetime.now().isoformat(),
|
||||
"msg_id": msg_id(),
|
||||
"username": "kernel",
|
||||
"session": self.engine_id,
|
||||
"msg_type": msg_type,
|
||||
"version": "5.3",
|
||||
}
|
||||
|
||||
async def send(
|
||||
self,
|
||||
stream,
|
||||
msg_type,
|
||||
content=None,
|
||||
parent_header=None,
|
||||
metadata=None,
|
||||
identities=None,
|
||||
):
|
||||
"""Send message to the Jupyter client."""
|
||||
header = self.new_header(msg_type)
|
||||
|
||||
def encode(msg):
|
||||
return str_to_bytes(json.dumps(msg))
|
||||
|
||||
msg_lst = [
|
||||
encode(header),
|
||||
encode(parent_header if parent_header else {}),
|
||||
encode(metadata if metadata else {}),
|
||||
encode(content if content else {}),
|
||||
]
|
||||
signature = self.msg_sign(msg_lst)
|
||||
parts = [DELIM, signature, msg_lst[0], msg_lst[1], msg_lst[2], msg_lst[3]]
|
||||
if identities:
|
||||
parts = identities + parts
|
||||
if stream:
|
||||
# _LOGGER.debug("send %s: %s", msg_type, parts)
|
||||
for this_stream in stream if isinstance(stream, set) else {stream}:
|
||||
await this_stream.send_multipart(parts)
|
||||
|
||||
async def shell_handler(self, shell_socket, wire_msg):
|
||||
"""Handle shell messages."""
|
||||
|
||||
identities, msg = self.deserialize_wire_msg(wire_msg)
|
||||
# _LOGGER.debug("shell received %s: %s", msg.get('header', {}).get('msg_type', 'UNKNOWN'), msg)
|
||||
self.parent_header = msg["header"]
|
||||
|
||||
content = {
|
||||
"execution_state": "busy",
|
||||
}
|
||||
await self.send(self.iopub_socket, "status", content, parent_header=msg["header"])
|
||||
|
||||
if msg["header"]["msg_type"] == "execute_request":
|
||||
|
||||
content = {
|
||||
"execution_count": self.execution_count,
|
||||
"code": msg["content"]["code"],
|
||||
}
|
||||
await self.send(self.iopub_socket, "execute_input", content, parent_header=msg["header"])
|
||||
result = None
|
||||
|
||||
code = msg["content"]["code"]
|
||||
#
|
||||
# replace VSCode initialization code, which depend on iPython % extensions
|
||||
#
|
||||
if code.startswith("%config "):
|
||||
code = "None"
|
||||
if code.startswith("_rwho_ls = %who_ls"):
|
||||
code = "print([])"
|
||||
|
||||
self.global_ctx.set_auto_start(False)
|
||||
self.ast_ctx.parse(code)
|
||||
exc = self.ast_ctx.get_exception_obj()
|
||||
if exc is None:
|
||||
result = await self.ast_ctx.eval()
|
||||
exc = self.ast_ctx.get_exception_obj()
|
||||
await Function.waiter_sync()
|
||||
self.global_ctx.set_auto_start(True)
|
||||
self.global_ctx.start()
|
||||
if exc:
|
||||
traceback_mesg = self.ast_ctx.get_exception_long().split("\n")
|
||||
|
||||
metadata = {
|
||||
"dependencies_met": True,
|
||||
"engine": self.engine_id,
|
||||
"status": "error",
|
||||
"started": datetime.datetime.now().isoformat(),
|
||||
}
|
||||
content = {
|
||||
"execution_count": self.execution_count,
|
||||
"status": "error",
|
||||
"ename": type(exc).__name__, # Exception name, as a string
|
||||
"evalue": str(exc), # Exception value, as a string
|
||||
"traceback": traceback_mesg,
|
||||
}
|
||||
_LOGGER.debug("Executing '%s' got exception: %s", code, content)
|
||||
await self.send(
|
||||
shell_socket,
|
||||
"execute_reply",
|
||||
content,
|
||||
metadata=metadata,
|
||||
parent_header=msg["header"],
|
||||
identities=identities,
|
||||
)
|
||||
del content["execution_count"], content["status"]
|
||||
await self.send(self.iopub_socket, "error", content, parent_header=msg["header"])
|
||||
|
||||
content = {
|
||||
"execution_state": "idle",
|
||||
}
|
||||
await self.send(self.iopub_socket, "status", content, parent_header=msg["header"])
|
||||
if msg["content"].get("store_history", True):
|
||||
self.execution_count += 1
|
||||
return
|
||||
|
||||
# if True or isinstance(self.ast_ctx.ast, ast.Expr):
|
||||
_LOGGER.debug("Executing: '%s' got result %s", code, result)
|
||||
if result is not None:
|
||||
content = {
|
||||
"execution_count": self.execution_count,
|
||||
"data": {"text/plain": repr(result)},
|
||||
"metadata": {},
|
||||
}
|
||||
await self.send(
|
||||
self.iopub_socket,
|
||||
"execute_result",
|
||||
content,
|
||||
parent_header=msg["header"],
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"dependencies_met": True,
|
||||
"engine": self.engine_id,
|
||||
"status": "ok",
|
||||
"started": datetime.datetime.now().isoformat(),
|
||||
}
|
||||
content = {
|
||||
"status": "ok",
|
||||
"execution_count": self.execution_count,
|
||||
"user_variables": {},
|
||||
"payload": [],
|
||||
"user_expressions": {},
|
||||
}
|
||||
await self.send(
|
||||
shell_socket,
|
||||
"execute_reply",
|
||||
content,
|
||||
metadata=metadata,
|
||||
parent_header=msg["header"],
|
||||
identities=identities,
|
||||
)
|
||||
if msg["content"].get("store_history", True):
|
||||
self.execution_count += 1
|
||||
|
||||
#
|
||||
# Make sure stdout gets sent before set report execution_state idle on iopub,
|
||||
# otherwise VSCode doesn't display stdout. We do a handshake with the
|
||||
# housekeep task to ensure any queued messages get processed.
|
||||
#
|
||||
handshake_q = asyncio.Queue(0)
|
||||
await self.housekeep_q.put(["handshake", handshake_q, 0])
|
||||
await handshake_q.get()
|
||||
|
||||
elif msg["header"]["msg_type"] == "kernel_info_request":
|
||||
content = {
|
||||
"protocol_version": "5.3",
|
||||
"ipython_version": [1, 1, 0, ""],
|
||||
"language_version": [0, 0, 1],
|
||||
"language": "python",
|
||||
"implementation": "python",
|
||||
"implementation_version": "3.7",
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "1.0",
|
||||
"mimetype": "",
|
||||
"file_extension": ".py",
|
||||
"codemirror_mode": "",
|
||||
"nbconvert_exporter": "",
|
||||
},
|
||||
"banner": "",
|
||||
}
|
||||
await self.send(
|
||||
shell_socket,
|
||||
"kernel_info_reply",
|
||||
content,
|
||||
parent_header=msg["header"],
|
||||
identities=identities,
|
||||
)
|
||||
|
||||
elif msg["header"]["msg_type"] == "complete_request":
|
||||
root = ""
|
||||
words = set()
|
||||
code = msg["content"]["code"]
|
||||
posn = msg["content"]["cursor_pos"]
|
||||
match = self.completion_re.match(code[0:posn].lower())
|
||||
if match:
|
||||
root = match[1].lower()
|
||||
words = State.completions(root)
|
||||
words = words.union(await Function.service_completions(root))
|
||||
words = words.union(await Function.func_completions(root))
|
||||
words = words.union(self.ast_ctx.completions(root))
|
||||
# _LOGGER.debug(f"complete_request code={code}, posn={posn}, root={root}, words={words}")
|
||||
content = {
|
||||
"status": "ok",
|
||||
"matches": sorted(list(words)),
|
||||
"cursor_start": msg["content"]["cursor_pos"] - len(root),
|
||||
"cursor_end": msg["content"]["cursor_pos"],
|
||||
"metadata": {},
|
||||
}
|
||||
await self.send(
|
||||
shell_socket,
|
||||
"complete_reply",
|
||||
content,
|
||||
parent_header=msg["header"],
|
||||
identities=identities,
|
||||
)
|
||||
|
||||
elif msg["header"]["msg_type"] == "is_complete_request":
|
||||
code = msg["content"]["code"]
|
||||
self.ast_ctx.parse(code)
|
||||
exc = self.ast_ctx.get_exception_obj()
|
||||
|
||||
# determine indent of last line
|
||||
indent = 0
|
||||
i = code.rfind("\n")
|
||||
if i >= 0:
|
||||
while i + 1 < len(code) and code[i + 1] == " ":
|
||||
i += 1
|
||||
indent += 1
|
||||
if exc is None:
|
||||
if indent == 0:
|
||||
content = {
|
||||
# One of 'complete', 'incomplete', 'invalid', 'unknown'
|
||||
"status": "complete",
|
||||
# If status is 'incomplete', indent should contain the characters to use
|
||||
# to indent the next line. This is only a hint: frontends may ignore it
|
||||
# and use their own autoindentation rules. For other statuses, this
|
||||
# field does not exist.
|
||||
# "indent": str,
|
||||
}
|
||||
else:
|
||||
content = {
|
||||
"status": "incomplete",
|
||||
"indent": " " * indent,
|
||||
}
|
||||
else:
|
||||
#
|
||||
# if the syntax error is right at the end, then we label it incomplete,
|
||||
# otherwise it's invalid
|
||||
#
|
||||
if "EOF while" in str(exc) or "expected an indented block" in str(exc):
|
||||
# if error is at ":" then increase indent
|
||||
if hasattr(exc, "lineno"):
|
||||
line = code.split("\n")[exc.lineno - 1]
|
||||
if self.colon_end_re.match(line):
|
||||
indent += 4
|
||||
content = {
|
||||
"status": "incomplete",
|
||||
"indent": " " * indent,
|
||||
}
|
||||
else:
|
||||
content = {
|
||||
"status": "invalid",
|
||||
}
|
||||
# _LOGGER.debug(f"is_complete_request code={code}, exc={exc}, content={content}")
|
||||
await self.send(
|
||||
shell_socket,
|
||||
"is_complete_reply",
|
||||
content,
|
||||
parent_header=msg["header"],
|
||||
identities=identities,
|
||||
)
|
||||
|
||||
elif msg["header"]["msg_type"] == "comm_info_request":
|
||||
content = {"comms": {}}
|
||||
await self.send(
|
||||
shell_socket,
|
||||
"comm_info_reply",
|
||||
content,
|
||||
parent_header=msg["header"],
|
||||
identities=identities,
|
||||
)
|
||||
|
||||
elif msg["header"]["msg_type"] == "history_request":
|
||||
content = {"history": []}
|
||||
await self.send(
|
||||
shell_socket,
|
||||
"history_reply",
|
||||
content,
|
||||
parent_header=msg["header"],
|
||||
identities=identities,
|
||||
)
|
||||
|
||||
elif msg["header"]["msg_type"] in {"comm_open", "comm_msg", "comm_close"}:
|
||||
# _LOGGER.debug(f"ignore {msg['header']['msg_type']} message ")
|
||||
...
|
||||
else:
|
||||
_LOGGER.error("unknown msg_type: %s", msg["header"]["msg_type"])
|
||||
|
||||
content = {
|
||||
"execution_state": "idle",
|
||||
}
|
||||
await self.send(self.iopub_socket, "status", content, parent_header=msg["header"])
|
||||
|
||||
async def control_listen(self, reader, writer):
|
||||
"""Task that listens to control messages."""
|
||||
try:
|
||||
_LOGGER.debug("control_listen connected")
|
||||
await self.housekeep_q.put(["register", "control", asyncio.current_task()])
|
||||
control_socket = ZmqSocket(reader, writer, "ROUTER")
|
||||
await control_socket.handshake()
|
||||
while 1:
|
||||
wire_msg = await control_socket.recv_multipart()
|
||||
identities, msg = self.deserialize_wire_msg(wire_msg)
|
||||
# _LOGGER.debug("control received %s: %s", msg.get('header', {}).get('msg_type', 'UNKNOWN'), msg)
|
||||
if msg["header"]["msg_type"] == "shutdown_request":
|
||||
content = {
|
||||
"restart": False,
|
||||
}
|
||||
await self.send(
|
||||
control_socket,
|
||||
"shutdown_reply",
|
||||
content,
|
||||
parent_header=msg["header"],
|
||||
identities=identities,
|
||||
)
|
||||
await self.housekeep_q.put(["shutdown"])
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except (EOFError, ConnectionResetError):
|
||||
_LOGGER.debug("control_listen got eof")
|
||||
await self.housekeep_q.put(["unregister", "control", asyncio.current_task()])
|
||||
control_socket.close()
|
||||
except Exception as err:
|
||||
_LOGGER.error("control_listen exception %s", err)
|
||||
await self.housekeep_q.put(["shutdown"])
|
||||
|
||||
async def stdin_listen(self, reader, writer):
|
||||
"""Task that listens to stdin messages."""
|
||||
try:
|
||||
_LOGGER.debug("stdin_listen connected")
|
||||
await self.housekeep_q.put(["register", "stdin", asyncio.current_task()])
|
||||
stdin_socket = ZmqSocket(reader, writer, "ROUTER")
|
||||
await stdin_socket.handshake()
|
||||
while 1:
|
||||
_ = await stdin_socket.recv_multipart()
|
||||
# _LOGGER.debug("stdin_listen received %s", _)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except (EOFError, ConnectionResetError):
|
||||
_LOGGER.debug("stdin_listen got eof")
|
||||
await self.housekeep_q.put(["unregister", "stdin", asyncio.current_task()])
|
||||
stdin_socket.close()
|
||||
except Exception:
|
||||
_LOGGER.error("stdin_listen exception %s", traceback.format_exc(-1))
|
||||
await self.housekeep_q.put(["shutdown"])
|
||||
|
||||
async def shell_listen(self, reader, writer):
|
||||
"""Task that listens to shell messages."""
|
||||
try:
|
||||
_LOGGER.debug("shell_listen connected")
|
||||
await self.housekeep_q.put(["register", "shell", asyncio.current_task()])
|
||||
shell_socket = ZmqSocket(reader, writer, "ROUTER")
|
||||
await shell_socket.handshake()
|
||||
while 1:
|
||||
msg = await shell_socket.recv_multipart()
|
||||
await self.shell_handler(shell_socket, msg)
|
||||
except asyncio.CancelledError:
|
||||
shell_socket.close()
|
||||
raise
|
||||
except (EOFError, ConnectionResetError):
|
||||
_LOGGER.debug("shell_listen got eof")
|
||||
await self.housekeep_q.put(["unregister", "shell", asyncio.current_task()])
|
||||
shell_socket.close()
|
||||
except Exception:
|
||||
_LOGGER.error("shell_listen exception %s", traceback.format_exc(-1))
|
||||
await self.housekeep_q.put(["shutdown"])
|
||||
|
||||
async def heartbeat_listen(self, reader, writer):
|
||||
"""Task that listens and responds to heart beat messages."""
|
||||
try:
|
||||
_LOGGER.debug("heartbeat_listen connected")
|
||||
await self.housekeep_q.put(["register", "heartbeat", asyncio.current_task()])
|
||||
heartbeat_socket = ZmqSocket(reader, writer, "REP")
|
||||
await heartbeat_socket.handshake()
|
||||
while 1:
|
||||
msg = await heartbeat_socket.recv()
|
||||
# _LOGGER.debug("heartbeat_listen: got %s", msg)
|
||||
await heartbeat_socket.send(msg)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except (EOFError, ConnectionResetError):
|
||||
_LOGGER.debug("heartbeat_listen got eof")
|
||||
await self.housekeep_q.put(["unregister", "heartbeat", asyncio.current_task()])
|
||||
heartbeat_socket.close()
|
||||
except Exception:
|
||||
_LOGGER.error("heartbeat_listen exception: %s", traceback.format_exc(-1))
|
||||
await self.housekeep_q.put(["shutdown"])
|
||||
|
||||
async def iopub_listen(self, reader, writer):
|
||||
"""Task that listens to iopub messages."""
|
||||
try:
|
||||
_LOGGER.debug("iopub_listen connected")
|
||||
await self.housekeep_q.put(["register", "iopub", asyncio.current_task()])
|
||||
iopub_socket = ZmqSocket(reader, writer, "PUB")
|
||||
await iopub_socket.handshake()
|
||||
self.iopub_socket.add(iopub_socket)
|
||||
while 1:
|
||||
_ = await iopub_socket.recv_multipart()
|
||||
# _LOGGER.debug("iopub received %s", _)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except (EOFError, ConnectionResetError):
|
||||
await self.housekeep_q.put(["unregister", "iopub", asyncio.current_task()])
|
||||
iopub_socket.close()
|
||||
self.iopub_socket.discard(iopub_socket)
|
||||
_LOGGER.debug("iopub_listen got eof")
|
||||
except Exception:
|
||||
_LOGGER.error("iopub_listen exception %s", traceback.format_exc(-1))
|
||||
await self.housekeep_q.put(["shutdown"])
|
||||
|
||||
async def housekeep_run(self):
|
||||
"""Housekeeping, including closing servers after startup, and doing orderly shutdown."""
|
||||
while True:
|
||||
try:
|
||||
msg = await self.housekeep_q.get()
|
||||
if msg[0] == "stdout":
|
||||
content = {"name": "stdout", "text": msg[1] + "\n"}
|
||||
if self.iopub_socket:
|
||||
await self.send(
|
||||
self.iopub_socket,
|
||||
"stream",
|
||||
content,
|
||||
parent_header=self.parent_header,
|
||||
identities=[b"stream.stdout"],
|
||||
)
|
||||
elif msg[0] == "handshake":
|
||||
await msg[1].put(msg[2])
|
||||
elif msg[0] == "register":
|
||||
if msg[1] not in self.tasks:
|
||||
self.tasks[msg[1]] = set()
|
||||
self.tasks[msg[1]].add(msg[2])
|
||||
self.task_cnt += 1
|
||||
self.task_cnt_max = max(self.task_cnt_max, self.task_cnt)
|
||||
#
|
||||
# now a couple of things are connected, call the session_cleanup_callback
|
||||
#
|
||||
if self.task_cnt > 1 and self.session_cleanup_callback:
|
||||
self.session_cleanup_callback()
|
||||
self.session_cleanup_callback = None
|
||||
elif msg[0] == "unregister":
|
||||
if msg[1] in self.tasks:
|
||||
self.tasks[msg[1]].discard(msg[2])
|
||||
self.task_cnt -= 1
|
||||
#
|
||||
# if there are no connection tasks left, then shutdown the kernel
|
||||
#
|
||||
if self.task_cnt == 0 and self.task_cnt_max >= 4:
|
||||
asyncio.create_task(self.session_shutdown())
|
||||
await asyncio.sleep(10000)
|
||||
elif msg[0] == "shutdown":
|
||||
asyncio.create_task(self.session_shutdown())
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
_LOGGER.error("housekeep task exception: %s", traceback.format_exc(-1))
|
||||
|
||||
async def startup_timeout(self):
|
||||
"""Shut down the session if nothing connects after 30 seconds."""
|
||||
await self.housekeep_q.put(["register", "startup_timeout", asyncio.current_task()])
|
||||
await asyncio.sleep(self.no_connect_timeout)
|
||||
if self.task_cnt_max <= 1:
|
||||
#
|
||||
# nothing started other than us, so shut down the session
|
||||
#
|
||||
_LOGGER.error("No connections to session %s; shutting down", self.global_ctx_name)
|
||||
if self.session_cleanup_callback:
|
||||
self.session_cleanup_callback()
|
||||
self.session_cleanup_callback = None
|
||||
await self.housekeep_q.put(["shutdown"])
|
||||
await self.housekeep_q.put(["unregister", "startup_timeout", asyncio.current_task()])
|
||||
|
||||
async def start_one_server(self, callback):
|
||||
"""Start a server by finding an available port."""
|
||||
first_port = self.avail_port
|
||||
for _ in range(2048):
|
||||
try:
|
||||
server = await asyncio.start_server(callback, "0.0.0.0", self.avail_port)
|
||||
return server, self.avail_port
|
||||
except OSError:
|
||||
self.avail_port += 1
|
||||
_LOGGER.error(
|
||||
"unable to find an available port from %d to %d",
|
||||
first_port,
|
||||
self.avail_port - 1,
|
||||
)
|
||||
return None, None
|
||||
|
||||
def get_ports(self):
|
||||
"""Return a dict of the port numbers this kernel session is listening to."""
|
||||
return {
|
||||
"iopub_port": self.iopub_port,
|
||||
"hb_port": self.heartbeat_port,
|
||||
"control_port": self.control_port,
|
||||
"stdin_port": self.stdin_port,
|
||||
"shell_port": self.shell_port,
|
||||
}
|
||||
|
||||
def set_session_cleanup_callback(self, callback):
|
||||
"""Set a cleanup callback which is called right after the session has started."""
|
||||
self.session_cleanup_callback = callback
|
||||
|
||||
async def session_start(self):
|
||||
"""Start the kernel session."""
|
||||
self.ast_ctx.add_logger_handler(self.console)
|
||||
_LOGGER.info("Starting session %s", self.global_ctx_name)
|
||||
|
||||
self.tasks["housekeep"] = {asyncio.create_task(self.housekeep_run())}
|
||||
self.tasks["startup_timeout"] = {asyncio.create_task(self.startup_timeout())}
|
||||
|
||||
self.iopub_server, self.iopub_port = await self.start_one_server(self.iopub_listen)
|
||||
self.heartbeat_server, self.heartbeat_port = await self.start_one_server(self.heartbeat_listen)
|
||||
self.control_server, self.control_port = await self.start_one_server(self.control_listen)
|
||||
self.stdin_server, self.stdin_port = await self.start_one_server(self.stdin_listen)
|
||||
self.shell_server, self.shell_port = await self.start_one_server(self.shell_listen)
|
||||
|
||||
#
|
||||
# For debugging, can use the real ZMQ library instead on certain sockets; comment out
|
||||
# the corresponding asyncio.start_server() call above if you enable the ZMQ-based
|
||||
# functions here. You can then turn of verbosity level 4 (-vvvv) in hass_pyscript_kernel.py
|
||||
# to see all the byte data in case you need to debug the simple ZMQ implementation here.
|
||||
# The two most important zmq functions are shown below.
|
||||
#
|
||||
# import zmq
|
||||
# import zmq.asyncio
|
||||
#
|
||||
# def zmq_bind(socket, connection, port):
|
||||
# """Bind a socket."""
|
||||
# if port <= 0:
|
||||
# return socket.bind_to_random_port(connection)
|
||||
# # _LOGGER.debug(f"binding to %s:%s" % (connection, port))
|
||||
# socket.bind("%s:%s" % (connection, port))
|
||||
# return port
|
||||
#
|
||||
# zmq_ctx = zmq.asyncio.Context()
|
||||
#
|
||||
# ##########################################
|
||||
# # Shell using real ZMQ for debugging:
|
||||
# async def shell_listen_zmq():
|
||||
# """Task that listens to shell messages using ZMQ."""
|
||||
# try:
|
||||
# _LOGGER.debug("shell_listen_zmq connected")
|
||||
# connection = self.config["transport"] + "://" + self.config["ip"]
|
||||
# shell_socket = zmq_ctx.socket(zmq.ROUTER)
|
||||
# self.shell_port = zmq_bind(shell_socket, connection, -1)
|
||||
# _LOGGER.debug("shell_listen_zmq connected")
|
||||
# while 1:
|
||||
# msg = await shell_socket.recv_multipart()
|
||||
# await self.shell_handler(shell_socket, msg)
|
||||
# except asyncio.CancelledError:
|
||||
# raise
|
||||
# except Exception:
|
||||
# _LOGGER.error("shell_listen exception %s", traceback.format_exc(-1))
|
||||
# await self.housekeep_q.put(["shutdown"])
|
||||
#
|
||||
# ##########################################
|
||||
# # IOPub using real ZMQ for debugging:
|
||||
# # IOPub/Sub:
|
||||
# async def iopub_listen_zmq():
|
||||
# """Task that listens to iopub messages using ZMQ."""
|
||||
# try:
|
||||
# _LOGGER.debug("iopub_listen_zmq connected")
|
||||
# connection = self.config["transport"] + "://" + self.config["ip"]
|
||||
# iopub_socket = zmq_ctx.socket(zmq.PUB)
|
||||
# self.iopub_port = zmq_bind(self.iopub_socket, connection, -1)
|
||||
# self.iopub_socket.add(iopub_socket)
|
||||
# while 1:
|
||||
# wire_msg = await iopub_socket.recv_multipart()
|
||||
# _LOGGER.debug("iopub received %s", wire_msg)
|
||||
# except asyncio.CancelledError:
|
||||
# raise
|
||||
# except EOFError:
|
||||
# await self.housekeep_q.put(["shutdown"])
|
||||
# _LOGGER.debug("iopub_listen got eof")
|
||||
# except Exception as err:
|
||||
# _LOGGER.error("iopub_listen exception %s", err)
|
||||
# await self.housekeep_q.put(["shutdown"])
|
||||
#
|
||||
# self.tasks["shell"] = {asyncio.create_task(shell_listen_zmq())}
|
||||
# self.tasks["iopub"] = {asyncio.create_task(iopub_listen_zmq())}
|
||||
#
|
||||
|
||||
async def session_shutdown(self):
|
||||
"""Shutdown the kernel session."""
|
||||
if not self.iopub_server:
|
||||
# already shutdown, so quit
|
||||
return
|
||||
GlobalContextMgr.delete(self.global_ctx_name)
|
||||
self.ast_ctx.remove_logger_handler(self.console)
|
||||
# logging.getLogger("homeassistant.components.pyscript.func.").removeHandler(self.console)
|
||||
_LOGGER.info("Shutting down session %s", self.global_ctx_name)
|
||||
|
||||
for server in [
|
||||
self.heartbeat_server,
|
||||
self.control_server,
|
||||
self.stdin_server,
|
||||
self.shell_server,
|
||||
self.iopub_server,
|
||||
]:
|
||||
if server:
|
||||
server.close()
|
||||
self.heartbeat_server = None
|
||||
self.iopub_server = None
|
||||
self.control_server = None
|
||||
self.stdin_server = None
|
||||
self.shell_server = None
|
||||
|
||||
for task_set in self.tasks.values():
|
||||
for task in task_set:
|
||||
try:
|
||||
task.cancel()
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self.tasks = []
|
||||
|
||||
for sock in self.iopub_socket:
|
||||
try:
|
||||
sock.close()
|
||||
except Exception as err:
|
||||
_LOGGER.error("iopub socket close exception: %s", err)
|
||||
|
||||
self.iopub_socket = set()
|
||||
45
config/custom_components/pyscript/logbook.py
Normal file
45
config/custom_components/pyscript/logbook.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Describe logbook events."""
|
||||
import logging
|
||||
|
||||
from homeassistant.core import callback
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
|
||||
def async_describe_events(hass, async_describe_event): # type: ignore
|
||||
"""Describe logbook events."""
|
||||
|
||||
@callback
|
||||
def async_describe_logbook_event(event): # type: ignore
|
||||
"""Describe a logbook event."""
|
||||
data = event.data
|
||||
func_args = data.get("func_args", {})
|
||||
ev_name = data.get("name", "unknown")
|
||||
ev_entity_id = data.get("entity_id", "pyscript.unknown")
|
||||
|
||||
ev_trigger_type = func_args.get("trigger_type", "unknown")
|
||||
if ev_trigger_type == "event":
|
||||
ev_source = f"event {func_args.get('event_type', 'unknown event')}"
|
||||
elif ev_trigger_type == "state":
|
||||
ev_source = f"state change {func_args.get('var_name', 'unknown entity')} == {func_args.get('value', 'unknown value')}"
|
||||
elif ev_trigger_type == "time":
|
||||
ev_trigger_time = func_args.get("trigger_time", "unknown")
|
||||
if ev_trigger_time is None:
|
||||
ev_trigger_time = "startup"
|
||||
ev_source = f"time {ev_trigger_time}"
|
||||
else:
|
||||
ev_source = ev_trigger_type
|
||||
|
||||
message = f"has been triggered by {ev_source}"
|
||||
|
||||
return {
|
||||
"name": ev_name,
|
||||
"message": message,
|
||||
"source": ev_source,
|
||||
"entity_id": ev_entity_id,
|
||||
}
|
||||
|
||||
async_describe_event(DOMAIN, "pyscript_running", async_describe_logbook_event)
|
||||
17
config/custom_components/pyscript/manifest.json
Normal file
17
config/custom_components/pyscript/manifest.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"domain": "pyscript",
|
||||
"name": "Pyscript Python scripting",
|
||||
"codeowners": [
|
||||
"@craigbarratt"
|
||||
],
|
||||
"config_flow": true,
|
||||
"dependencies": [],
|
||||
"documentation": "https://github.com/custom-components/pyscript",
|
||||
"homekit": {},
|
||||
"iot_class": "local_push",
|
||||
"issue_tracker": "https://github.com/custom-components/pyscript/issues",
|
||||
"requirements": ["croniter==2.0.2", "watchdog==2.3.1"],
|
||||
"ssdp": [],
|
||||
"version": "1.6.1",
|
||||
"zeroconf": []
|
||||
}
|
||||
91
config/custom_components/pyscript/mqtt.py
Normal file
91
config/custom_components/pyscript/mqtt.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Handles mqtt messages and notification."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from homeassistant.components import mqtt
|
||||
|
||||
from .const import LOGGER_PATH
|
||||
|
||||
_LOGGER = logging.getLogger(LOGGER_PATH + ".mqtt")
|
||||
|
||||
|
||||
class Mqtt:
|
||||
"""Define mqtt functions."""
|
||||
|
||||
#
|
||||
# Global hass instance
|
||||
#
|
||||
hass = None
|
||||
|
||||
#
|
||||
# notify message queues by mqtt message topic
|
||||
#
|
||||
notify = {}
|
||||
notify_remove = {}
|
||||
|
||||
def __init__(self):
|
||||
"""Warn on Mqtt instantiation."""
|
||||
_LOGGER.error("Mqtt class is not meant to be instantiated")
|
||||
|
||||
@classmethod
|
||||
def init(cls, hass):
|
||||
"""Initialize Mqtt."""
|
||||
|
||||
cls.hass = hass
|
||||
|
||||
@classmethod
|
||||
def mqtt_message_handler_maker(cls, subscribed_topic):
|
||||
"""Closure for mqtt_message_handler."""
|
||||
|
||||
async def mqtt_message_handler(mqttmsg):
|
||||
"""Listen for MQTT messages."""
|
||||
func_args = {
|
||||
"trigger_type": "mqtt",
|
||||
"topic": mqttmsg.topic,
|
||||
"payload": mqttmsg.payload,
|
||||
"qos": mqttmsg.qos,
|
||||
}
|
||||
|
||||
try:
|
||||
func_args["payload_obj"] = json.loads(mqttmsg.payload)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
await cls.update(subscribed_topic, func_args)
|
||||
|
||||
return mqtt_message_handler
|
||||
|
||||
@classmethod
|
||||
async def notify_add(cls, topic, queue):
|
||||
"""Register to notify for mqtt messages of given topic to be sent to queue."""
|
||||
|
||||
if topic not in cls.notify:
|
||||
cls.notify[topic] = set()
|
||||
_LOGGER.debug("mqtt.notify_add(%s) -> adding mqtt subscription", topic)
|
||||
cls.notify_remove[topic] = await mqtt.async_subscribe(
|
||||
cls.hass, topic, cls.mqtt_message_handler_maker(topic), encoding="utf-8", qos=0
|
||||
)
|
||||
cls.notify[topic].add(queue)
|
||||
|
||||
@classmethod
|
||||
def notify_del(cls, topic, queue):
|
||||
"""Unregister to notify for mqtt messages of given topic for given queue."""
|
||||
|
||||
if topic not in cls.notify or queue not in cls.notify[topic]:
|
||||
return
|
||||
cls.notify[topic].discard(queue)
|
||||
if len(cls.notify[topic]) == 0:
|
||||
cls.notify_remove[topic]()
|
||||
_LOGGER.debug("mqtt.notify_del(%s) -> removing mqtt subscription", topic)
|
||||
del cls.notify[topic]
|
||||
del cls.notify_remove[topic]
|
||||
|
||||
@classmethod
|
||||
async def update(cls, topic, func_args):
|
||||
"""Deliver all notifications for an mqtt message on the given topic."""
|
||||
|
||||
_LOGGER.debug("mqtt.update(%s, %s, %s)", topic, vars, func_args)
|
||||
if topic in cls.notify:
|
||||
for queue in cls.notify[topic]:
|
||||
await queue.put(["mqtt", func_args.copy()])
|
||||
323
config/custom_components/pyscript/requirements.py
Normal file
323
config/custom_components/pyscript/requirements.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""Requirements helpers for pyscript."""
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.requirements import async_process_requirements
|
||||
|
||||
from .const import (
|
||||
ATTR_INSTALLED_VERSION,
|
||||
ATTR_SOURCES,
|
||||
ATTR_VERSION,
|
||||
CONF_ALLOW_ALL_IMPORTS,
|
||||
CONF_INSTALLED_PACKAGES,
|
||||
DOMAIN,
|
||||
LOGGER_PATH,
|
||||
REQUIREMENTS_FILE,
|
||||
REQUIREMENTS_PATHS,
|
||||
UNPINNED_VERSION,
|
||||
)
|
||||
|
||||
if sys.version_info[:2] >= (3, 8):
|
||||
from importlib.metadata import ( # pylint: disable=no-name-in-module,import-error
|
||||
PackageNotFoundError,
|
||||
version as installed_version,
|
||||
)
|
||||
else:
|
||||
from importlib_metadata import ( # pylint: disable=import-error
|
||||
PackageNotFoundError,
|
||||
version as installed_version,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(LOGGER_PATH)
|
||||
|
||||
|
||||
def get_installed_version(pkg_name):
|
||||
"""Get installed version of package. Returns None if not found."""
|
||||
try:
|
||||
return installed_version(pkg_name)
|
||||
except PackageNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def update_unpinned_versions(package_dict):
|
||||
"""Check for current installed version of each unpinned package."""
|
||||
requirements_to_pop = []
|
||||
for package in package_dict:
|
||||
if package_dict[package] != UNPINNED_VERSION:
|
||||
continue
|
||||
|
||||
package_dict[package] = get_installed_version(package)
|
||||
if not package_dict[package]:
|
||||
_LOGGER.error("%s wasn't able to be installed", package)
|
||||
requirements_to_pop.append(package)
|
||||
|
||||
for package in requirements_to_pop:
|
||||
package_dict.pop(package)
|
||||
|
||||
return package_dict
|
||||
|
||||
|
||||
@bind_hass
|
||||
def process_all_requirements(pyscript_folder, requirements_paths, requirements_file):
|
||||
"""
|
||||
Load all lines from requirements_file located in requirements_paths.
|
||||
|
||||
Returns files and a list of packages, if any, that need to be installed.
|
||||
"""
|
||||
|
||||
# Re-import Version to avoid dealing with multiple flake and pylint errors
|
||||
from packaging.version import Version # pylint: disable=import-outside-toplevel
|
||||
|
||||
all_requirements_to_process = {}
|
||||
for root in requirements_paths:
|
||||
for requirements_path in glob.glob(os.path.join(pyscript_folder, root, requirements_file)):
|
||||
with open(requirements_path, "r", encoding="utf-8") as requirements_fp:
|
||||
all_requirements_to_process[requirements_path] = requirements_fp.readlines()
|
||||
|
||||
all_requirements_to_install = {}
|
||||
for requirements_path, pkg_lines in all_requirements_to_process.items():
|
||||
for pkg in pkg_lines:
|
||||
# Remove inline comments which are accepted by pip but not by Home
|
||||
# Assistant's installation method.
|
||||
# https://rosettacode.org/wiki/Strip_comments_from_a_string#Python
|
||||
i = pkg.find("#")
|
||||
if i >= 0:
|
||||
pkg = pkg[:i]
|
||||
pkg = pkg.strip()
|
||||
|
||||
if not pkg or len(pkg) == 0:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Attempt to get version of package. Do nothing if it's found since
|
||||
# we want to use the version that's already installed to be safe
|
||||
parts = pkg.split("==")
|
||||
if len(parts) > 2 or "," in pkg or ">" in pkg or "<" in pkg:
|
||||
_LOGGER.error(
|
||||
(
|
||||
"Ignoring invalid requirement '%s' specified in '%s'; if a specific version"
|
||||
"is required, the requirement must use the format 'pkg==version'"
|
||||
),
|
||||
requirements_path,
|
||||
pkg,
|
||||
)
|
||||
continue
|
||||
if len(parts) == 1:
|
||||
new_version = UNPINNED_VERSION
|
||||
else:
|
||||
new_version = parts[1]
|
||||
pkg_name = parts[0]
|
||||
|
||||
current_pinned_version = all_requirements_to_install.get(pkg_name, {}).get(ATTR_VERSION)
|
||||
current_sources = all_requirements_to_install.get(pkg_name, {}).get(ATTR_SOURCES, [])
|
||||
# If a version hasn't already been recorded, record this one
|
||||
if not current_pinned_version:
|
||||
all_requirements_to_install[pkg_name] = {
|
||||
ATTR_VERSION: new_version,
|
||||
ATTR_SOURCES: [requirements_path],
|
||||
ATTR_INSTALLED_VERSION: get_installed_version(pkg_name),
|
||||
}
|
||||
|
||||
# If the new version is unpinned and there is an existing pinned version, use existing
|
||||
# pinned version
|
||||
elif new_version == UNPINNED_VERSION and current_pinned_version != UNPINNED_VERSION:
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"Unpinned requirement for package '%s' detected in '%s' will be ignored in "
|
||||
"favor of the pinned version '%s' detected in '%s'"
|
||||
),
|
||||
pkg_name,
|
||||
requirements_path,
|
||||
current_pinned_version,
|
||||
str(current_sources),
|
||||
)
|
||||
# If the new version is pinned and the existing version is unpinned, use the new pinned
|
||||
# version
|
||||
elif new_version != UNPINNED_VERSION and current_pinned_version == UNPINNED_VERSION:
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"Unpinned requirement for package '%s' detected in '%s will be ignored in "
|
||||
"favor of the pinned version '%s' detected in '%s'"
|
||||
),
|
||||
pkg_name,
|
||||
str(current_sources),
|
||||
new_version,
|
||||
requirements_path,
|
||||
)
|
||||
all_requirements_to_install[pkg_name] = {
|
||||
ATTR_VERSION: new_version,
|
||||
ATTR_SOURCES: [requirements_path],
|
||||
ATTR_INSTALLED_VERSION: get_installed_version(pkg_name),
|
||||
}
|
||||
# If the already recorded version is the same as the new version, append the current
|
||||
# path so we can show sources
|
||||
elif (
|
||||
new_version == UNPINNED_VERSION and current_pinned_version == UNPINNED_VERSION
|
||||
) or Version(current_pinned_version) == Version(new_version):
|
||||
all_requirements_to_install[pkg_name][ATTR_SOURCES].append(requirements_path)
|
||||
# If the already recorded version is lower than the new version, use the new one
|
||||
elif Version(current_pinned_version) < Version(new_version):
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"Version '%s' for package '%s' detected in '%s' will be ignored in "
|
||||
"favor of the higher version '%s' detected in '%s'"
|
||||
),
|
||||
current_pinned_version,
|
||||
pkg_name,
|
||||
str(current_sources),
|
||||
new_version,
|
||||
requirements_path,
|
||||
)
|
||||
all_requirements_to_install[pkg_name].update(
|
||||
{ATTR_VERSION: new_version, ATTR_SOURCES: [requirements_path]}
|
||||
)
|
||||
# If the already recorded version is higher than the new version, ignore the new one
|
||||
elif Version(current_pinned_version) > Version(new_version):
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"Version '%s' for package '%s' detected in '%s' will be ignored in "
|
||||
"favor of the higher version '%s' detected in '%s'"
|
||||
),
|
||||
new_version,
|
||||
pkg_name,
|
||||
requirements_path,
|
||||
current_pinned_version,
|
||||
str(current_sources),
|
||||
)
|
||||
except ValueError:
|
||||
# Not valid requirements line so it can be skipped
|
||||
_LOGGER.debug("Ignoring '%s' because it is not a valid package", pkg)
|
||||
|
||||
return all_requirements_to_install
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def install_requirements(hass, config_entry, pyscript_folder):
|
||||
"""Install missing requirements from requirements.txt."""
|
||||
|
||||
pyscript_installed_packages = config_entry.data.get(CONF_INSTALLED_PACKAGES, {}).copy()
|
||||
|
||||
# Import packaging inside install_requirements so that we can use Home Assistant to install it
|
||||
# if it can't been found
|
||||
try:
|
||||
from packaging.version import Version # pylint: disable=import-outside-toplevel
|
||||
except ModuleNotFoundError:
|
||||
await async_process_requirements(hass, DOMAIN, ["packaging"])
|
||||
from packaging.version import Version # pylint: disable=import-outside-toplevel
|
||||
|
||||
all_requirements = await hass.async_add_executor_job(
|
||||
process_all_requirements, pyscript_folder, REQUIREMENTS_PATHS, REQUIREMENTS_FILE
|
||||
)
|
||||
|
||||
requirements_to_install = {}
|
||||
|
||||
if all_requirements and not config_entry.data.get(CONF_ALLOW_ALL_IMPORTS, False):
|
||||
_LOGGER.error(
|
||||
(
|
||||
"Requirements detected but 'allow_all_imports' is set to False, set "
|
||||
"'allow_all_imports' to True if you want packages to be installed"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
for package in all_requirements:
|
||||
pkg_installed_version = all_requirements[package].get(ATTR_INSTALLED_VERSION)
|
||||
version_to_install = all_requirements[package][ATTR_VERSION]
|
||||
sources = all_requirements[package][ATTR_SOURCES]
|
||||
# If package is already installed, we need to run some checks
|
||||
if pkg_installed_version:
|
||||
# If the version to install is unpinned and there is already something installed,
|
||||
# defer to what is installed
|
||||
if version_to_install == UNPINNED_VERSION:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
"Skipping unpinned version of package '%s' because version '%s' is "
|
||||
"already installed"
|
||||
),
|
||||
package,
|
||||
pkg_installed_version,
|
||||
)
|
||||
# If installed package is not the same version as the one we last installed,
|
||||
# that means that the package is externally managed now so we shouldn't touch it
|
||||
# and should remove it from our internal tracker
|
||||
if (
|
||||
package in pyscript_installed_packages
|
||||
and pyscript_installed_packages[package] != pkg_installed_version
|
||||
):
|
||||
pyscript_installed_packages.pop(package)
|
||||
continue
|
||||
|
||||
# If installed package is not the same version as the one we last installed,
|
||||
# that means that the package is externally managed now so we shouldn't touch it
|
||||
# and should remove it from our internal tracker
|
||||
if package in pyscript_installed_packages and Version(
|
||||
pyscript_installed_packages[package]
|
||||
) != Version(pkg_installed_version):
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"Version '%s' for package '%s' detected in '%s' will be ignored in favor of"
|
||||
" the version '%s' which was installed outside of pyscript"
|
||||
),
|
||||
version_to_install,
|
||||
package,
|
||||
str(sources),
|
||||
pkg_installed_version,
|
||||
)
|
||||
pyscript_installed_packages.pop(package)
|
||||
# If there is a version mismatch between what we want and what is installed, we
|
||||
# can overwrite it since we know it was last installed by us
|
||||
elif package in pyscript_installed_packages and Version(version_to_install) != Version(
|
||||
pkg_installed_version
|
||||
):
|
||||
requirements_to_install[package] = all_requirements[package]
|
||||
# If there is an installed version that we have not previously installed, we
|
||||
# should not install it
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
"Version '%s' for package '%s' detected in '%s' will be ignored because it"
|
||||
" is already installed"
|
||||
),
|
||||
version_to_install,
|
||||
package,
|
||||
str(sources),
|
||||
)
|
||||
# Anything not already installed in the environment can be installed
|
||||
else:
|
||||
requirements_to_install[package] = all_requirements[package]
|
||||
|
||||
if requirements_to_install:
|
||||
_LOGGER.info(
|
||||
"Installing the following packages: %s",
|
||||
str(requirements_to_install),
|
||||
)
|
||||
await async_process_requirements(
|
||||
hass,
|
||||
DOMAIN,
|
||||
[
|
||||
f"{package}=={pkg_info[ATTR_VERSION]}"
|
||||
if pkg_info[ATTR_VERSION] != UNPINNED_VERSION
|
||||
else package
|
||||
for package, pkg_info in requirements_to_install.items()
|
||||
],
|
||||
)
|
||||
else:
|
||||
_LOGGER.debug("No new packages to install")
|
||||
|
||||
# Update package tracker in config entry for next time
|
||||
pyscript_installed_packages.update(
|
||||
{package: pkg_info[ATTR_VERSION] for package, pkg_info in requirements_to_install.items()}
|
||||
)
|
||||
|
||||
# If any requirements were unpinned, get their version now so they can be pinned later
|
||||
if any(version == UNPINNED_VERSION for version in pyscript_installed_packages.values()):
|
||||
pyscript_installed_packages = await hass.async_add_executor_job(
|
||||
update_unpinned_versions, pyscript_installed_packages
|
||||
)
|
||||
if pyscript_installed_packages != config_entry.data.get(CONF_INSTALLED_PACKAGES, {}):
|
||||
new_data = config_entry.data.copy()
|
||||
new_data[CONF_INSTALLED_PACKAGES] = pyscript_installed_packages
|
||||
hass.config_entries.async_update_entry(entry=config_entry, data=new_data)
|
||||
107
config/custom_components/pyscript/services.yaml
Normal file
107
config/custom_components/pyscript/services.yaml
Normal file
@@ -0,0 +1,107 @@
|
||||
# Describes the format for available pyscript services
|
||||
|
||||
reload:
|
||||
name: Reload pyscript
|
||||
description: Reloads all available pyscripts and restart triggers
|
||||
fields:
|
||||
global_ctx:
|
||||
name: Global Context
|
||||
description: Only reload this specific global context (file or app)
|
||||
example: file.example
|
||||
required: false
|
||||
selector:
|
||||
text:
|
||||
|
||||
jupyter_kernel_start:
|
||||
name: Start Jupyter kernel
|
||||
description: Starts a jupyter kernel for interactive use; Called by Jupyter front end and should generally not be used by users
|
||||
fields:
|
||||
shell_port:
|
||||
name: Shell Port Number
|
||||
description: Shell port number
|
||||
example: 63599
|
||||
required: false
|
||||
selector:
|
||||
number:
|
||||
min: 10240
|
||||
max: 65535
|
||||
iopub_port:
|
||||
name: IOPub Port Number
|
||||
description: IOPub port number
|
||||
example: 63598
|
||||
required: false
|
||||
selector:
|
||||
number:
|
||||
min: 10240
|
||||
max: 65535
|
||||
stdin_port:
|
||||
name: Stdin Port Number
|
||||
description: Stdin port number
|
||||
example: 63597
|
||||
required: false
|
||||
selector:
|
||||
number:
|
||||
min: 10240
|
||||
max: 65535
|
||||
control_port:
|
||||
name: Control Port Number
|
||||
description: Control port number
|
||||
example: 63596
|
||||
required: false
|
||||
selector:
|
||||
number:
|
||||
min: 10240
|
||||
max: 65535
|
||||
hb_port:
|
||||
name: Heartbeat Port Number
|
||||
description: Heartbeat port number
|
||||
example: 63595
|
||||
required: false
|
||||
selector:
|
||||
number:
|
||||
min: 10240
|
||||
max: 65535
|
||||
ip:
|
||||
name: IP Address
|
||||
description: IP address to connect to Jupyter front end
|
||||
example: 127.0.0.1
|
||||
default: 127.0.0.1
|
||||
required: false
|
||||
selector:
|
||||
text:
|
||||
key:
|
||||
name: Security Key
|
||||
description: Used for signing
|
||||
example: 012345678-9abcdef023456789abcdef
|
||||
required: true
|
||||
selector:
|
||||
text:
|
||||
transport:
|
||||
name: Transport Type
|
||||
description: Transport type
|
||||
example: tcp
|
||||
default: tcp
|
||||
required: false
|
||||
selector:
|
||||
select:
|
||||
options:
|
||||
- tcp
|
||||
- udp
|
||||
signature_scheme:
|
||||
name: Signing Algorithm
|
||||
description: Signing algorithm
|
||||
example: hmac-sha256
|
||||
required: false
|
||||
default: hmac-sha256
|
||||
selector:
|
||||
select:
|
||||
options:
|
||||
- hmac-sha256
|
||||
kernel_name:
|
||||
name: Name of Kernel
|
||||
description: Kernel name
|
||||
example: pyscript
|
||||
required: true
|
||||
default: pyscript
|
||||
selector:
|
||||
text:
|
||||
438
config/custom_components/pyscript/state.py
Normal file
438
config/custom_components/pyscript/state.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""Handles state variable access and change notification."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.helpers.restore_state import DATA_RESTORE_STATE
|
||||
from homeassistant.helpers.service import async_get_all_descriptions
|
||||
|
||||
from .const import LOGGER_PATH
|
||||
from .entity import PyscriptEntity
|
||||
from .function import Function
|
||||
|
||||
_LOGGER = logging.getLogger(LOGGER_PATH + ".state")
|
||||
|
||||
STATE_VIRTUAL_ATTRS = {"entity_id", "last_changed", "last_updated"}
|
||||
|
||||
|
||||
class StateVal(str):
|
||||
"""Class for representing the value and attributes of a state variable."""
|
||||
|
||||
def __new__(cls, state):
|
||||
"""Create a new instance given a state variable."""
|
||||
new_var = super().__new__(cls, state.state)
|
||||
new_var.__dict__ = state.attributes.copy()
|
||||
new_var.entity_id = state.entity_id
|
||||
new_var.last_updated = state.last_updated
|
||||
new_var.last_changed = state.last_changed
|
||||
return new_var
|
||||
|
||||
|
||||
class State:
|
||||
"""Class for state functions."""
|
||||
|
||||
#
|
||||
# Global hass instance
|
||||
#
|
||||
hass = None
|
||||
|
||||
#
|
||||
# notify message queues by variable
|
||||
#
|
||||
notify = {}
|
||||
|
||||
#
|
||||
# Last value of state variable notifications. We maintain this
|
||||
# so that trigger evaluation can use the last notified value,
|
||||
# rather than fetching the current value, which is subject to
|
||||
# race conditions when multiple state variables are set quickly.
|
||||
#
|
||||
notify_var_last = {}
|
||||
|
||||
#
|
||||
# pyscript yaml configuration
|
||||
#
|
||||
pyscript_config = {}
|
||||
|
||||
#
|
||||
# pyscript vars which have already been registered as persisted
|
||||
#
|
||||
persisted_vars = {}
|
||||
|
||||
#
|
||||
# other parameters of all services that have "entity_id" as a parameter
|
||||
#
|
||||
service2args = {}
|
||||
|
||||
def __init__(self):
|
||||
"""Warn on State instantiation."""
|
||||
_LOGGER.error("State class is not meant to be instantiated")
|
||||
|
||||
@classmethod
|
||||
def init(cls, hass):
|
||||
"""Initialize State."""
|
||||
cls.hass = hass
|
||||
|
||||
@classmethod
|
||||
async def get_service_params(cls):
|
||||
"""Get parameters for all services."""
|
||||
cls.service2args = {}
|
||||
all_services = await async_get_all_descriptions(cls.hass)
|
||||
for domain in all_services:
|
||||
cls.service2args[domain] = {}
|
||||
for service, desc in all_services[domain].items():
|
||||
if "entity_id" not in desc["fields"] and "target" not in desc:
|
||||
continue
|
||||
cls.service2args[domain][service] = set(desc["fields"].keys())
|
||||
cls.service2args[domain][service].discard("entity_id")
|
||||
|
||||
@classmethod
|
||||
async def notify_add(cls, var_names, queue):
|
||||
"""Register to notify state variables changes to be sent to queue."""
|
||||
|
||||
added = False
|
||||
for var_name in var_names if isinstance(var_names, set) else {var_names}:
|
||||
parts = var_name.split(".")
|
||||
if len(parts) != 2 and len(parts) != 3:
|
||||
continue
|
||||
state_var_name = f"{parts[0]}.{parts[1]}"
|
||||
if state_var_name not in cls.notify:
|
||||
cls.notify[state_var_name] = {}
|
||||
cls.notify[state_var_name][queue] = var_names
|
||||
added = True
|
||||
return added
|
||||
|
||||
@classmethod
|
||||
def notify_del(cls, var_names, queue):
|
||||
"""Unregister notify of state variables changes for given queue."""
|
||||
|
||||
for var_name in var_names if isinstance(var_names, set) else {var_names}:
|
||||
parts = var_name.split(".")
|
||||
if len(parts) != 2 and len(parts) != 3:
|
||||
continue
|
||||
state_var_name = f"{parts[0]}.{parts[1]}"
|
||||
if state_var_name not in cls.notify or queue not in cls.notify[state_var_name]:
|
||||
return
|
||||
del cls.notify[state_var_name][queue]
|
||||
|
||||
@classmethod
|
||||
async def update(cls, new_vars, func_args):
|
||||
"""Deliver all notifications for state variable changes."""
|
||||
|
||||
notify = {}
|
||||
for var_name, var_val in new_vars.items():
|
||||
if var_name in cls.notify:
|
||||
cls.notify_var_last[var_name] = var_val
|
||||
notify.update(cls.notify[var_name])
|
||||
|
||||
if notify:
|
||||
_LOGGER.debug("state.update(%s, %s)", new_vars, func_args)
|
||||
for queue, var_names in notify.items():
|
||||
await queue.put(["state", [cls.notify_var_get(var_names, new_vars), func_args.copy()]])
|
||||
|
||||
@classmethod
|
||||
def notify_var_get(cls, var_names, new_vars):
|
||||
"""Add values of var_names to new_vars, or default to None."""
|
||||
notify_vars = new_vars.copy()
|
||||
for var_name in var_names if var_names is not None else []:
|
||||
if var_name in notify_vars:
|
||||
continue
|
||||
parts = var_name.split(".")
|
||||
if var_name in cls.notify_var_last:
|
||||
notify_vars[var_name] = cls.notify_var_last[var_name]
|
||||
elif len(parts) == 3 and f"{parts[0]}.{parts[1]}" in cls.notify_var_last:
|
||||
notify_vars[var_name] = getattr(
|
||||
cls.notify_var_last[f"{parts[0]}.{parts[1]}"], parts[2], None
|
||||
)
|
||||
elif len(parts) == 4 and parts[2] == "old" and f"{parts[0]}.{parts[1]}.old" in notify_vars:
|
||||
notify_vars[var_name] = getattr(notify_vars[f"{parts[0]}.{parts[1]}.old"], parts[3], None)
|
||||
elif 1 <= var_name.count(".") <= 3 and not cls.exist(var_name):
|
||||
notify_vars[var_name] = None
|
||||
return notify_vars
|
||||
|
||||
@classmethod
|
||||
def set(cls, var_name, value=None, new_attributes=None, **kwargs):
|
||||
"""Set a state variable and optional attributes in hass."""
|
||||
if var_name.count(".") != 1:
|
||||
raise NameError(f"invalid name {var_name} (should be 'domain.entity')")
|
||||
|
||||
if isinstance(value, StateVal):
|
||||
if new_attributes is None:
|
||||
#
|
||||
# value is a StateVal, so extract the attributes and value
|
||||
#
|
||||
new_attributes = value.__dict__.copy()
|
||||
for discard in STATE_VIRTUAL_ATTRS:
|
||||
new_attributes.pop(discard, None)
|
||||
value = str(value)
|
||||
|
||||
state_value = None
|
||||
if value is None or new_attributes is None:
|
||||
state_value = cls.hass.states.get(var_name)
|
||||
|
||||
if value is None and state_value:
|
||||
value = state_value.state
|
||||
|
||||
if new_attributes is None:
|
||||
if state_value:
|
||||
new_attributes = state_value.attributes.copy()
|
||||
else:
|
||||
new_attributes = {}
|
||||
|
||||
curr_task = asyncio.current_task()
|
||||
if "context" in kwargs and isinstance(kwargs["context"], Context):
|
||||
context = kwargs["context"]
|
||||
del kwargs["context"]
|
||||
else:
|
||||
context = Function.task2context.get(curr_task, None)
|
||||
|
||||
if kwargs:
|
||||
new_attributes = new_attributes.copy()
|
||||
new_attributes.update(kwargs)
|
||||
|
||||
_LOGGER.debug("setting %s = %s, attr = %s", var_name, value, new_attributes)
|
||||
cls.hass.states.async_set(var_name, value, new_attributes, context=context)
|
||||
if var_name in cls.notify_var_last or var_name in cls.notify:
|
||||
#
|
||||
# immediately update a variable we are monitoring since it could take a while
|
||||
# for the state changed event to propagate
|
||||
#
|
||||
cls.notify_var_last[var_name] = StateVal(cls.hass.states.get(var_name))
|
||||
|
||||
if var_name in cls.persisted_vars:
|
||||
cls.persisted_vars[var_name].set_state(value)
|
||||
cls.persisted_vars[var_name].set_attributes(new_attributes)
|
||||
|
||||
@classmethod
|
||||
def setattr(cls, var_attr_name, value):
|
||||
"""Set a state variable's attribute in hass."""
|
||||
parts = var_attr_name.split(".")
|
||||
if len(parts) != 3:
|
||||
raise NameError(f"invalid name {var_attr_name} (should be 'domain.entity.attr')")
|
||||
if not cls.exist(f"{parts[0]}.{parts[1]}"):
|
||||
raise NameError(f"state {parts[0]}.{parts[1]} doesn't exist")
|
||||
cls.set(f"{parts[0]}.{parts[1]}", **{parts[2]: value})
|
||||
|
||||
@classmethod
|
||||
async def register_persist(cls, var_name):
|
||||
"""Register pyscript state variable to be persisted with RestoreState."""
|
||||
if var_name.startswith("pyscript.") and var_name not in cls.persisted_vars:
|
||||
# this is a hack accessing hass internals; should re-implement using RestoreEntity
|
||||
restore_data = cls.hass.data[DATA_RESTORE_STATE]
|
||||
this_entity = PyscriptEntity()
|
||||
this_entity.entity_id = var_name
|
||||
cls.persisted_vars[var_name] = this_entity
|
||||
try:
|
||||
restore_data.async_restore_entity_added(this_entity)
|
||||
except TypeError:
|
||||
restore_data.async_restore_entity_added(var_name)
|
||||
|
||||
@classmethod
|
||||
async def persist(cls, var_name, default_value=None, default_attributes=None):
|
||||
"""Persist a pyscript domain state variable, and update with optional defaults."""
|
||||
if var_name.count(".") != 1 or not var_name.startswith("pyscript."):
|
||||
raise NameError(f"invalid name {var_name} (should be 'pyscript.entity')")
|
||||
|
||||
await cls.register_persist(var_name)
|
||||
exists = cls.exist(var_name)
|
||||
|
||||
if not exists and default_value is not None:
|
||||
cls.set(var_name, default_value, default_attributes)
|
||||
elif exists and default_attributes is not None:
|
||||
# Patch the attributes with new values if necessary
|
||||
current = cls.hass.states.get(var_name)
|
||||
new_attributes = {k: v for (k, v) in default_attributes.items() if k not in current.attributes}
|
||||
cls.set(var_name, current.state, **new_attributes)
|
||||
|
||||
@classmethod
|
||||
def exist(cls, var_name):
|
||||
"""Check if a state variable value or attribute exists in hass."""
|
||||
parts = var_name.split(".")
|
||||
if len(parts) != 2 and len(parts) != 3:
|
||||
return False
|
||||
value = cls.hass.states.get(f"{parts[0]}.{parts[1]}")
|
||||
if value is None:
|
||||
return False
|
||||
if (
|
||||
len(parts) == 2
|
||||
or (parts[0] in cls.service2args and parts[2] in cls.service2args[parts[0]])
|
||||
or parts[2] in value.attributes
|
||||
or parts[2] in STATE_VIRTUAL_ATTRS
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get(cls, var_name):
|
||||
"""Get a state variable value or attribute from hass."""
|
||||
parts = var_name.split(".")
|
||||
if len(parts) != 2 and len(parts) != 3:
|
||||
raise NameError(f"invalid name '{var_name}' (should be 'domain.entity' or 'domain.entity.attr')")
|
||||
state = cls.hass.states.get(f"{parts[0]}.{parts[1]}")
|
||||
if not state:
|
||||
raise NameError(f"name '{parts[0]}.{parts[1]}' is not defined")
|
||||
#
|
||||
# simplest case is just the state value
|
||||
#
|
||||
state = StateVal(state)
|
||||
if len(parts) == 2:
|
||||
return state
|
||||
#
|
||||
# see if this is a service that has an entity_id parameter
|
||||
#
|
||||
if parts[0] in cls.service2args and parts[2] in cls.service2args[parts[0]]:
|
||||
params = cls.service2args[parts[0]][parts[2]]
|
||||
|
||||
def service_call_factory(domain, service, entity_id, params):
|
||||
async def service_call(*args, **kwargs):
|
||||
curr_task = asyncio.current_task()
|
||||
hass_args = {}
|
||||
for keyword, typ, default in [
|
||||
("context", [Context], Function.task2context.get(curr_task, None)),
|
||||
("blocking", [bool], None),
|
||||
("return_response", [bool], None),
|
||||
("limit", [float, int], None),
|
||||
]:
|
||||
if keyword in kwargs and type(kwargs[keyword]) in typ:
|
||||
hass_args[keyword] = kwargs.pop(keyword)
|
||||
elif default:
|
||||
hass_args[keyword] = default
|
||||
|
||||
kwargs["entity_id"] = entity_id
|
||||
if len(args) == 1 and len(params) == 1:
|
||||
#
|
||||
# with just a single parameter and positional argument, create the keyword setting
|
||||
#
|
||||
[param_name] = params
|
||||
kwargs[param_name] = args[0]
|
||||
elif len(args) != 0:
|
||||
raise TypeError(f"service {domain}.{service} takes no positional arguments")
|
||||
|
||||
# return await Function.hass_services_async_call(domain, service, kwargs, **hass_args)
|
||||
return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
|
||||
|
||||
return service_call
|
||||
|
||||
return service_call_factory(parts[0], parts[2], f"{parts[0]}.{parts[1]}", params)
|
||||
#
|
||||
# finally see if it is an attribute
|
||||
#
|
||||
try:
|
||||
return getattr(state, parts[2])
|
||||
except AttributeError:
|
||||
raise AttributeError( # pylint: disable=raise-missing-from
|
||||
f"state '{parts[0]}.{parts[1]}' has no attribute '{parts[2]}'"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def delete(cls, var_name, context=None):
|
||||
"""Delete a state variable or attribute from hass."""
|
||||
parts = var_name.split(".")
|
||||
if not context:
|
||||
context = Function.task2context.get(asyncio.current_task(), None)
|
||||
context_arg = {"context": context} if context else {}
|
||||
if len(parts) == 2:
|
||||
if var_name in cls.notify_var_last or var_name in cls.notify:
|
||||
#
|
||||
# immediately update a variable we are monitoring since it could take a while
|
||||
# for the state changed event to propagate
|
||||
#
|
||||
cls.notify_var_last[var_name] = None
|
||||
if not cls.hass.states.async_remove(var_name, **context_arg):
|
||||
raise NameError(f"name '{var_name}' not defined")
|
||||
return
|
||||
if len(parts) == 3:
|
||||
var_name = f"{parts[0]}.{parts[1]}"
|
||||
value = cls.hass.states.get(var_name)
|
||||
if value is None:
|
||||
raise NameError(f"state {var_name} doesn't exist")
|
||||
new_attr = value.attributes.copy()
|
||||
if parts[2] not in new_attr:
|
||||
raise AttributeError(f"state '{var_name}' has no attribute '{parts[2]}'")
|
||||
del new_attr[parts[2]]
|
||||
cls.set(f"{var_name}", value.state, new_attributes=new_attr, **context_arg)
|
||||
return
|
||||
raise NameError(f"invalid name '{var_name}' (should be 'domain.entity' or 'domain.entity.attr')")
|
||||
|
||||
@classmethod
|
||||
def getattr(cls, var_name):
|
||||
"""Return a dict of attributes for a state variable."""
|
||||
if isinstance(var_name, StateVal):
|
||||
attrs = var_name.__dict__.copy()
|
||||
for discard in STATE_VIRTUAL_ATTRS:
|
||||
attrs.pop(discard, None)
|
||||
return attrs
|
||||
if var_name.count(".") != 1:
|
||||
raise NameError(f"invalid name {var_name} (should be 'domain.entity')")
|
||||
value = cls.hass.states.get(var_name)
|
||||
if not value:
|
||||
return None
|
||||
return value.attributes.copy()
|
||||
|
||||
@classmethod
|
||||
def get_attr(cls, var_name):
|
||||
"""Return a dict of attributes for a state variable - deprecated."""
|
||||
_LOGGER.warning("state.get_attr() is deprecated: use state.getattr() instead")
|
||||
return cls.getattr(var_name)
|
||||
|
||||
@classmethod
|
||||
def completions(cls, root):
|
||||
"""Return possible completions of state variables."""
|
||||
words = set()
|
||||
parts = root.split(".")
|
||||
num_period = len(parts) - 1
|
||||
if num_period == 2:
|
||||
#
|
||||
# complete state attributes
|
||||
#
|
||||
last_period = root.rfind(".")
|
||||
name = root[0:last_period]
|
||||
value = cls.hass.states.get(name)
|
||||
if value:
|
||||
attr_root = root[last_period + 1 :]
|
||||
attrs = set(value.attributes.keys()).union(STATE_VIRTUAL_ATTRS)
|
||||
if parts[0] in cls.service2args:
|
||||
attrs.update(set(cls.service2args[parts[0]].keys()))
|
||||
for attr_name in attrs:
|
||||
if attr_name.lower().startswith(attr_root):
|
||||
words.add(f"{name}.{attr_name}")
|
||||
elif num_period < 2:
|
||||
#
|
||||
# complete among all state names
|
||||
#
|
||||
for name in cls.hass.states.async_all():
|
||||
if name.entity_id.lower().startswith(root):
|
||||
words.add(name.entity_id)
|
||||
return words
|
||||
|
||||
@classmethod
|
||||
async def names(cls, domain=None):
|
||||
"""Implement names, which returns all entity_ids."""
|
||||
return cls.hass.states.async_entity_ids(domain)
|
||||
|
||||
@classmethod
|
||||
def register_functions(cls):
|
||||
"""Register state functions and config variable."""
|
||||
functions = {
|
||||
"state.get": cls.get,
|
||||
"state.set": cls.set,
|
||||
"state.setattr": cls.setattr,
|
||||
"state.names": cls.names,
|
||||
"state.getattr": cls.getattr,
|
||||
"state.get_attr": cls.get_attr, # deprecated form; to be removed
|
||||
"state.persist": cls.persist,
|
||||
"state.delete": cls.delete,
|
||||
"pyscript.config": cls.pyscript_config,
|
||||
}
|
||||
Function.register(functions)
|
||||
|
||||
@classmethod
|
||||
def set_pyscript_config(cls, config):
|
||||
"""Set pyscript yaml config."""
|
||||
#
|
||||
# have to update inplace, since dest is already used as value
|
||||
#
|
||||
cls.pyscript_config.clear()
|
||||
for name, value in config.items():
|
||||
cls.pyscript_config[name] = value
|
||||
38
config/custom_components/pyscript/strings.json
Normal file
38
config/custom_components/pyscript/strings.json
Normal file
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"config": {
|
||||
"step": {
|
||||
"user": {
|
||||
"title": "pyscript",
|
||||
"description": "Once you have created an entry, refer to the [docs](https://hacs-pyscript.readthedocs.io/en/latest/) to learn how to create scripts and functions.",
|
||||
"data": {
|
||||
"allow_all_imports": "Allow All Imports?",
|
||||
"hass_is_global": "Access hass as a global variable?"
|
||||
}
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"already_configured": "Already configured.",
|
||||
"single_instance_allowed": "Already configured. Only a single configuration possible.",
|
||||
"updated_entry": "This entry has already been setup but the configuration has been updated."
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
"title": "Update pyscript configuration",
|
||||
"data": {
|
||||
"allow_all_imports": "Allow All Imports?",
|
||||
"hass_is_global": "Access hass as a global variable?"
|
||||
}
|
||||
},
|
||||
"no_ui_configuration_allowed": {
|
||||
"title": "No UI configuration allowed",
|
||||
"description": "This entry was created via `configuration.yaml`, so all configuration parameters must be updated there. The [`pyscript.reload`](developer-tools/service) service will allow you to apply the changes you make to `configuration.yaml` without restarting your Home Assistant instance."
|
||||
},
|
||||
"no_update": {
|
||||
"title": "No update needed",
|
||||
"description": "There is nothing to update."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
38
config/custom_components/pyscript/translations/de.json
Normal file
38
config/custom_components/pyscript/translations/de.json
Normal file
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"config": {
|
||||
"step": {
|
||||
"user": {
|
||||
"title": "pyscript",
|
||||
"description": "Wenn Sie einen Eintrag angelegt haben, können Sie sich die [Doku (Englisch)](https://hacs-pyscript.readthedocs.io/en/latest/) ansehen, um zu lernen wie Sie Scripts und Funktionen erstellen können.",
|
||||
"data": {
|
||||
"allow_all_imports": "Alle Importe erlauben?",
|
||||
"hass_is_global": "Home Assistant als globale Variable verwenden?"
|
||||
}
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"already_configured": "Bereits konfiguriert.",
|
||||
"single_instance_allowed": "Bereits konfiguriert. Es ist nur eine Konfiguration gleichzeitig möglich",
|
||||
"updated_entry": "Der Eintrag wurde bereits erstellt, aber die Konfiguration wurde aktualisiert."
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
"title": "Pyscript configuration aktualisieren",
|
||||
"data": {
|
||||
"allow_all_imports": "Alle Importe erlauben??",
|
||||
"hass_is_global": "Home Assistant als globale Variable verwenden?"
|
||||
}
|
||||
},
|
||||
"no_ui_configuration_allowed": {
|
||||
"title": "Die Konfiguartion der graphischen Nutzeroberfläche ist deaktiviert",
|
||||
"description": "Der Eintrag wurde über die Datei `configuration.yaml` erstellt. Alle Konfigurationsparameter müssen desshalb dort eingestellt werden. Der [`pyscript.reload`](developer-tools/service) Service übernimmt alle Änderungen aus `configuration.yaml`, ohne dass Home Assistant neu gestartet werden muss."
|
||||
},
|
||||
"no_update": {
|
||||
"title": "Keine Aktualisierung notwendig",
|
||||
"description": "Es gibt nichts zu aktualisieren."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
38
config/custom_components/pyscript/translations/en.json
Normal file
38
config/custom_components/pyscript/translations/en.json
Normal file
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"config": {
|
||||
"step": {
|
||||
"user": {
|
||||
"title": "pyscript",
|
||||
"description": "Once you have created an entry, refer to the [docs](https://hacs-pyscript.readthedocs.io/en/latest/) to learn how to create scripts and functions.",
|
||||
"data": {
|
||||
"allow_all_imports": "Allow All Imports?",
|
||||
"hass_is_global": "Access hass as a global variable?"
|
||||
}
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"already_configured": "Already configured.",
|
||||
"single_instance_allowed": "Already configured. Only a single configuration possible.",
|
||||
"updated_entry": "This entry has already been setup but the configuration has been updated."
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
"title": "Update pyscript configuration",
|
||||
"data": {
|
||||
"allow_all_imports": "Allow All Imports?",
|
||||
"hass_is_global": "Access hass as a global variable?"
|
||||
}
|
||||
},
|
||||
"no_ui_configuration_allowed": {
|
||||
"title": "No UI configuration allowed",
|
||||
"description": "This entry was created via `configuration.yaml`, so all configuration parameters must be updated there. The [`pyscript.reload`](developer-tools/service) service will allow you to apply the changes you make to `configuration.yaml` without restarting your Home Assistant instance."
|
||||
},
|
||||
"no_update": {
|
||||
"title": "No update needed",
|
||||
"description": "There is nothing to update."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
38
config/custom_components/pyscript/translations/sk.json
Normal file
38
config/custom_components/pyscript/translations/sk.json
Normal file
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"config": {
|
||||
"step": {
|
||||
"user": {
|
||||
"title": "pyscript",
|
||||
"description": "Akonáhle ste vytvorili položku, pozrite si [docs](https://hacs-pyscript.readthedocs.io/en/latest/) naučiť sa, ako vytvárať skripty a funkcie.",
|
||||
"data": {
|
||||
"allow_all_imports": "Povoliť všetky importy?",
|
||||
"hass_is_global": "Prístup k globálnej premennej?"
|
||||
}
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"already_configured": "Už konfigurované.",
|
||||
"single_instance_allowed": "Už nakonfigurované. Iba jedna možná konfigurácia.",
|
||||
"updated_entry": "Táto položka už bola nastavená, ale konfigurácia bola aktualizovaná."
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
"title": "Aktualizovať pyscript konfiguráciu",
|
||||
"data": {
|
||||
"allow_all_imports": "povoliť všetky importy?",
|
||||
"hass_is_global": "Prístup k globálnej premennej?"
|
||||
}
|
||||
},
|
||||
"no_ui_configuration_allowed": {
|
||||
"title": "Nie je povolená konfigurácia používateľského rozhrania",
|
||||
"description": "Tento záznam bol vytvorený cez `configuration.yaml`, Takže všetky konfiguračné parametre sa musia aktualizovať. [`pyscript.reload`](developer-tools/service) Služba vám umožní uplatniť zmeny, ktoré vykonáte `configuration.yaml` bez reštartovania inštancie Home Assistant."
|
||||
},
|
||||
"no_update": {
|
||||
"title": "Nie je potrebná aktualizácia",
|
||||
"description": "Nie je nič na aktualizáciu."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
1396
config/custom_components/pyscript/trigger.py
Normal file
1396
config/custom_components/pyscript/trigger.py
Normal file
File diff suppressed because it is too large
Load Diff
95
config/custom_components/pyscript/webhook.py
Normal file
95
config/custom_components/pyscript/webhook.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Handles webhooks and notification."""
|
||||
|
||||
import logging
|
||||
|
||||
from aiohttp import hdrs
|
||||
|
||||
from homeassistant.components import webhook
|
||||
|
||||
from .const import LOGGER_PATH
|
||||
|
||||
_LOGGER = logging.getLogger(LOGGER_PATH + ".webhook")
|
||||
|
||||
|
||||
class Webhook:
|
||||
"""Define webhook functions."""
|
||||
|
||||
#
|
||||
# Global hass instance
|
||||
#
|
||||
hass = None
|
||||
|
||||
#
|
||||
# notify message queues by webhook type
|
||||
#
|
||||
notify = {}
|
||||
notify_remove = {}
|
||||
|
||||
def __init__(self):
|
||||
"""Warn on Webhook instantiation."""
|
||||
_LOGGER.error("Webhook class is not meant to be instantiated")
|
||||
|
||||
@classmethod
|
||||
def init(cls, hass):
|
||||
"""Initialize Webhook."""
|
||||
|
||||
cls.hass = hass
|
||||
|
||||
@classmethod
|
||||
async def webhook_handler(cls, hass, webhook_id, request):
|
||||
"""Listen callback for given webhook which updates any notifications."""
|
||||
|
||||
func_args = {
|
||||
"trigger_type": "webhook",
|
||||
"webhook_id": webhook_id,
|
||||
}
|
||||
|
||||
if "json" in request.headers.get(hdrs.CONTENT_TYPE, ""):
|
||||
func_args["payload"] = await request.json()
|
||||
else:
|
||||
# Could potentially return multiples of a key - only take the first
|
||||
payload_multidict = await request.post()
|
||||
func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()}
|
||||
|
||||
await cls.update(webhook_id, func_args)
|
||||
|
||||
@classmethod
|
||||
def notify_add(cls, webhook_id, local_only, methods, queue):
|
||||
"""Register to notify for webhooks of given type to be sent to queue."""
|
||||
if webhook_id not in cls.notify:
|
||||
cls.notify[webhook_id] = set()
|
||||
_LOGGER.debug("webhook.notify_add(%s) -> adding webhook listener", webhook_id)
|
||||
webhook.async_register(
|
||||
cls.hass,
|
||||
"pyscript", # DOMAIN
|
||||
"pyscript", # NAME
|
||||
webhook_id,
|
||||
cls.webhook_handler,
|
||||
local_only=local_only,
|
||||
allowed_methods=methods,
|
||||
)
|
||||
cls.notify_remove[webhook_id] = lambda: webhook.async_unregister(cls.hass, webhook_id)
|
||||
|
||||
cls.notify[webhook_id].add(queue)
|
||||
|
||||
@classmethod
|
||||
def notify_del(cls, webhook_id, queue):
|
||||
"""Unregister to notify for webhooks of given type for given queue."""
|
||||
|
||||
if webhook_id not in cls.notify or queue not in cls.notify[webhook_id]:
|
||||
return
|
||||
cls.notify[webhook_id].discard(queue)
|
||||
if len(cls.notify[webhook_id]) == 0:
|
||||
cls.notify_remove[webhook_id]()
|
||||
_LOGGER.debug("webhook.notify_del(%s) -> removing webhook listener", webhook_id)
|
||||
del cls.notify[webhook_id]
|
||||
del cls.notify_remove[webhook_id]
|
||||
|
||||
@classmethod
|
||||
async def update(cls, webhook_id, func_args):
|
||||
"""Deliver all notifications for an webhook of the given type."""
|
||||
|
||||
_LOGGER.debug("webhook.update(%s, %s)", webhook_id, func_args)
|
||||
if webhook_id in cls.notify:
|
||||
for queue in cls.notify[webhook_id]:
|
||||
await queue.put(["webhook", func_args.copy()])
|
||||
Reference in New Issue
Block a user