"""Python interpreter for pyscript.""" import ast import asyncio import builtins from collections import OrderedDict import functools import importlib import inspect import io import keyword import logging import sys import time import weakref import yaml from homeassistant.const import SERVICE_RELOAD from homeassistant.helpers.service import async_set_service_schema from .const import ( ALLOWED_IMPORTS, CONF_ALLOW_ALL_IMPORTS, CONFIG_ENTRY, DOMAIN, LOGGER_PATH, SERVICE_JUPYTER_KERNEL_START, SERVICE_RESPONSE_NONE, ) from .function import Function from .state import State _LOGGER = logging.getLogger(LOGGER_PATH + ".eval") # # Built-ins to exclude to improve security or avoid i/o # BUILTIN_EXCLUDE = { "breakpoint", "compile", "input", "memoryview", "open", "print", } TRIG_DECORATORS = { "time_trigger", "state_trigger", "event_trigger", "mqtt_trigger", "webhook_trigger", "state_active", "time_active", "task_unique", } TRIG_SERV_DECORATORS = TRIG_DECORATORS.union({"service"}) COMP_DECORATORS = { "pyscript_compile", "pyscript_executor", } TRIGGER_KWARGS = { "context", "event_type", "old_value", "payload", "payload_obj", "qos", "topic", "trigger_type", "trigger_time", "var_name", "value", "webhook_id", } WEBHOOK_METHODS = { "GET", "HEAD", "POST", "PUT", } def ast_eval_exec_factory(ast_ctx, mode): """Generate a function that executes eval() or exec() with given ast_ctx.""" async def eval_func(arg_str, eval_globals=None, eval_locals=None): eval_ast = AstEval(ast_ctx.name, ast_ctx.global_ctx) eval_ast.parse(arg_str, f"{mode}()", mode=mode) if eval_ast.exception_obj: raise eval_ast.exception_obj eval_ast.local_sym_table = ast_ctx.local_sym_table if eval_globals is not None: eval_ast.global_sym_table = eval_globals if eval_locals is not None: eval_ast.sym_table_stack = [eval_globals] eval_ast.sym_table = eval_locals else: eval_ast.sym_table_stack = [] eval_ast.sym_table = eval_globals else: eval_ast.sym_table_stack = ast_ctx.sym_table_stack.copy() if ast_ctx.sym_table == ast_ctx.global_sym_table: eval_ast.sym_table = ast_ctx.sym_table else: eval_ast.sym_table = ast_ctx.sym_table.copy() eval_ast.sym_table.update(ast_ctx.user_locals) to_delete = set() for var, value in eval_ast.sym_table.items(): if isinstance(value, EvalLocalVar): if value.is_defined(): eval_ast.sym_table[var] = value.get() else: to_delete.add(var) for var in to_delete: del eval_ast.sym_table[var] eval_ast.curr_func = None try: eval_result = await eval_ast.aeval(eval_ast.ast) except Exception as err: ast_ctx.exception_obj = err ast_ctx.exception = f"Exception in {ast_ctx.filename} line {ast_ctx.lineno} column {ast_ctx.col_offset}: {eval_ast.exception}" ast_ctx.exception_long = ( ast_ctx.format_exc(err, ast_ctx.lineno, ast_ctx.col_offset, short=True) + "\n" + eval_ast.exception_long ) raise # # save variables only in the locals scope # if eval_globals is None and eval_ast.sym_table != ast_ctx.sym_table: for var, value in eval_ast.sym_table.items(): if var in ast_ctx.global_sym_table and value == ast_ctx.global_sym_table[var]: continue if var not in ast_ctx.sym_table and ( ast_ctx.curr_func is None or var not in ast_ctx.curr_func.local_names ): ast_ctx.user_locals[var] = value return eval_result return eval_func def ast_eval_factory(ast_ctx): """Generate a function that executes eval() with given ast_ctx.""" return ast_eval_exec_factory(ast_ctx, "eval") def ast_exec_factory(ast_ctx): """Generate a function that executes exec() with given ast_ctx.""" return ast_eval_exec_factory(ast_ctx, "exec") def ast_globals_factory(ast_ctx): """Generate a globals() function with given ast_ctx.""" async def globals_func(): return ast_ctx.global_sym_table return globals_func def ast_locals_factory(ast_ctx): """Generate a locals() function with given ast_ctx.""" async def locals_func(): if ast_ctx.sym_table == ast_ctx.global_sym_table: return ast_ctx.sym_table local_sym_table = ast_ctx.sym_table.copy() local_sym_table.update(ast_ctx.user_locals) to_delete = set() for var, value in local_sym_table.items(): if isinstance(value, EvalLocalVar): if value.is_defined(): local_sym_table[var] = value.get() else: to_delete.add(var) for var in to_delete: del local_sym_table[var] return local_sym_table return locals_func # # Built-in functions that are also passed the ast context # BUILTIN_AST_FUNCS_FACTORY = { "eval": ast_eval_factory, "exec": ast_exec_factory, "globals": ast_globals_factory, "locals": ast_locals_factory, } # # Objects returned by return, break and continue statements that change execution flow, # or objects returned that capture particular information # class EvalStopFlow: """Denotes a statement or action that stops execution flow, eg: return, break etc.""" _name = None def name(self): """Return short name.""" return self._name class EvalReturn(EvalStopFlow): """Return statement.""" _name = "return" def __init__(self, value): """Initialize return statement value.""" self.value = value class EvalBreak(EvalStopFlow): """Break statement.""" _name = "break" class EvalContinue(EvalStopFlow): """Continue statement.""" _name = "continue" class EvalLocalVar: """Wrapper for local variable symtable entry.""" def __init__(self, name, **kwargs): """Initialize value of local symbol.""" self.name = name self.defined = False if "value" in kwargs: self.value = kwargs["value"] self.defined = True def get(self): """Get value of local symbol.""" if not self.defined: raise NameError(f"name '{self.name}' is not defined") return self.value def get_name(self): """Get name of local symbol.""" return self.name def set(self, value): """Set value of local symbol.""" self.value = value self.defined = True def is_defined(self): """Return whether value is defined.""" return self.defined def set_undefined(self): """Set local symbol to undefined.""" self.defined = False def __getattr__(self, attr): """Get attribute of local variable.""" if not self.defined: raise NameError(f"name '{self.name}' is not defined") return getattr(self.value, attr) def __repr__(self): """Generate string with address and value.""" return f"EvalLocalVar @{hex(id(self))} = {self.value if self.defined else 'undefined'}" class EvalName: """Identifier that hasn't yet been resolved.""" def __init__(self, name): """Initialize identifier to name.""" self.name = name def __getattr__(self, attr): """Get attribute for EvalName.""" raise NameError(f"name '{self.name}.{attr}' is not defined") class EvalAttrSet: """Class for object and attribute on lhs of assignment.""" def __init__(self, obj, attr): """Initialize identifier to name.""" self.obj = obj self.attr = attr def setattr(self, value): """Set the attribute value.""" setattr(self.obj, self.attr, value) def getattr(self): """Get the attribute value.""" return getattr(self.obj, self.attr) class EvalFunc: """Class for a callable pyscript function.""" def __init__(self, func_def, code_list, code_str, global_ctx): """Initialize a function calling context.""" self.func_def = func_def self.name = func_def.name self.global_ctx = global_ctx self.global_ctx_name = global_ctx.get_name() self.logger = logging.getLogger(LOGGER_PATH + "." + self.global_ctx_name) self.defaults = [] self.kw_defaults = [] self.decorators = [] self.global_names = set() self.nonlocal_names = set() self.local_names = None self.local_sym_table = {} self.doc_string = ast.get_docstring(func_def) self.num_posonly_arg = len(self.func_def.args.posonlyargs) self.num_posn_arg = self.num_posonly_arg + len(self.func_def.args.args) - len(self.defaults) self.code_list = code_list self.code_str = code_str self.exception = None self.exception_obj = None self.exception_long = None self.trigger = [] self.trigger_service = set() self.has_closure = False def get_name(self): """Return the function name.""" return self.name def set_name(self, name): """Set the function name.""" self.name = name async def eval_defaults(self, ast_ctx): """Evaluate the default function arguments.""" self.defaults = [] for val in self.func_def.args.defaults: self.defaults.append(await ast_ctx.aeval(val)) self.num_posn_arg = self.num_posonly_arg + len(self.func_def.args.args) - len(self.defaults) self.kw_defaults = [] for val in self.func_def.args.kw_defaults: self.kw_defaults.append({"ok": bool(val), "val": None if not val else await ast_ctx.aeval(val)}) async def trigger_init(self, trig_ctx, func_name): """Initialize decorator triggers for this function.""" trig_args = {} trig_decs = {} trig_ctx_name = trig_ctx.get_name() self.logger = logging.getLogger(LOGGER_PATH + "." + trig_ctx_name) self.global_ctx.set_logger_name(trig_ctx_name) self.global_ctx_name = trig_ctx_name got_reqd_dec = False exc_mesg = f"function '{func_name}' defined in {trig_ctx_name}" trig_decorators_reqd = { "event_trigger", "mqtt_trigger", "state_trigger", "time_trigger", "webhook_trigger", } arg_check = { "event_trigger": {"arg_cnt": {1, 2, 3}, "rep_ok": True}, "mqtt_trigger": {"arg_cnt": {1, 2, 3}, "rep_ok": True}, "state_active": {"arg_cnt": {1}}, "state_trigger": {"arg_cnt": {"*"}, "type": {list, set}, "rep_ok": True}, "service": {"arg_cnt": {0, "*"}}, "task_unique": {"arg_cnt": {1, 2}}, "time_active": {"arg_cnt": {"*"}}, "time_trigger": {"arg_cnt": {0, "*"}, "rep_ok": True}, "webhook_trigger": {"arg_cnt": {1, 2}, "rep_ok": True}, } kwarg_check = { "event_trigger": {"kwargs": {dict}}, "mqtt_trigger": {"kwargs": {dict}}, "time_trigger": {"kwargs": {dict}}, "task_unique": {"kill_me": {bool, int}}, "time_active": {"hold_off": {int, float}}, "service": {"supports_response": {str}}, "state_trigger": { "kwargs": {dict}, "state_hold": {int, float}, "state_check_now": {bool, int}, "state_hold_false": {int, float}, "watch": {set, list}, }, "webhook_trigger": { "kwargs": {dict}, "local_only": {bool}, "methods": {list, set}, }, } for dec in self.decorators: dec_name, dec_args, dec_kwargs = dec[0], dec[1], dec[2] if dec_name not in TRIG_SERV_DECORATORS: raise SyntaxError(f"{exc_mesg}: unknown decorator @{dec_name}") if dec_name in trig_decorators_reqd: got_reqd_dec = True arg_info = arg_check.get(dec_name, {}) # # check that we have the right number of arguments, and that they are # strings # arg_cnt = arg_info["arg_cnt"] if dec_args is None and 0 not in arg_cnt: raise TypeError(f"{exc_mesg}: decorator @{dec_name} needs at least one argument") if dec_args: if "*" not in arg_cnt and len(dec_args) not in arg_cnt: raise TypeError( f"{exc_mesg}: decorator @{dec_name} got {len(dec_args)}" f" argument{'s' if len(dec_args) > 1 else ''}, expected" f" {' or '.join([str(cnt) for cnt in sorted(arg_cnt)])}" ) for arg_num, arg in enumerate(dec_args): if isinstance(arg, str): continue mesg = "string" if "type" in arg_info: if type(arg) in arg_info["type"]: for val in arg: if not isinstance(val, str): break else: continue mesg += ", or " + ", or ".join( sorted([ok_type.__name__ for ok_type in arg_info["type"]]) ) raise TypeError( f"{exc_mesg}: decorator @{dec_name} argument {arg_num + 1} should be a {mesg}" ) if arg_cnt == {1}: dec_args = dec_args[0] if dec_name not in kwarg_check and dec_kwargs is not None: raise TypeError(f"{exc_mesg}: decorator @{dec_name} doesn't take keyword arguments") if dec_kwargs is None: dec_kwargs = {} if dec_name in kwarg_check: allowed = kwarg_check[dec_name] for arg, value in dec_kwargs.items(): if arg not in allowed: raise TypeError( f"{exc_mesg}: decorator @{dec_name} invalid keyword argument '{arg}'" ) if value is None or type(value) in allowed[arg]: continue ok_types = " or ".join(sorted([t.__name__ for t in allowed[arg]])) raise TypeError( f"{exc_mesg}: decorator @{dec_name} keyword '{arg}' should be type {ok_types}" ) if dec_name == "service": desc = self.doc_string if desc is None or desc == "": desc = f"pyscript function {func_name}()" desc = desc.lstrip(" \n\r") if desc.startswith("yaml"): try: desc = desc[4:].lstrip(" \n\r") file_desc = io.StringIO(desc) service_desc = yaml.load(file_desc, Loader=yaml.BaseLoader) or OrderedDict() file_desc.close() except Exception as exc: self.logger.error( "Unable to decode yaml doc_string for %s(): %s", func_name, str(exc), ) raise exc else: fields = OrderedDict() for arg in self.get_positional_args(): fields[arg] = OrderedDict(description=f"argument {arg}") service_desc = {"description": desc, "fields": fields} def pyscript_service_factory(func_name, func): async def pyscript_service_handler(call): """Handle python script service calls.""" # self.logger.debug("service call to %s", func_name) # # use a new AstEval context so it can run fully independently # of other instances (except for global_ctx which is common) # ast_ctx = AstEval(f"{trig_ctx_name}.{func_name}", self.global_ctx) Function.install_ast_funcs(ast_ctx) func_args = { "trigger_type": "service", "context": call.context, } func_args.update(call.data) async def do_service_call(func, ast_ctx, data): retval = await func.call(ast_ctx, **data) if ast_ctx.get_exception_obj(): ast_ctx.get_logger().error(ast_ctx.get_exception_long()) return retval task = Function.create_task(do_service_call(func, ast_ctx, func_args)) await task return task.result() return pyscript_service_handler for srv_name in dec_args if dec_args else [f"{DOMAIN}.{func_name}"]: if type(srv_name) is not str or srv_name.count(".") != 1: raise ValueError(f"{exc_mesg}: @service argument must be a string with one period") domain, name = srv_name.split(".", 1) if name in (SERVICE_RELOAD, SERVICE_JUPYTER_KERNEL_START): raise SyntaxError(f"{exc_mesg}: @service conflicts with builtin service") Function.service_register( trig_ctx_name, domain, name, pyscript_service_factory(func_name, self), dec_kwargs.get("supports_response", SERVICE_RESPONSE_NONE), ) async_set_service_schema(Function.hass, domain, name, service_desc) self.trigger_service.add(srv_name) continue if dec_name == "webhook_trigger" and "methods" in dec_kwargs: if len(bad := set(dec_kwargs["methods"]).difference(WEBHOOK_METHODS)) > 0: raise TypeError(f"{exc_mesg}: {bad} aren't valid {dec_name} methods") if dec_name not in trig_decs: trig_decs[dec_name] = [] if len(trig_decs[dec_name]) > 0 and "rep_ok" not in arg_info: raise SyntaxError(f"{exc_mesg}: decorator @{dec_name} can only be used once") trig_decs[dec_name].append({"args": dec_args, "kwargs": dec_kwargs}) if not got_reqd_dec and len(trig_decs) > 0: self.logger.error( "%s defined in %s: needs at least one trigger decorator (ie: %s)", func_name, trig_ctx_name, ", ".join(sorted(trig_decorators_reqd)), ) return if len(trig_decs) == 0: if len(self.trigger_service) > 0: trig_ctx.trigger_register(self) return # # start one or more triggers until they are all consumed # each trigger task can handle at most one of each type of # trigger; all get the same state_active, time_active and # task_unique decorators # while True: trig_args = { "action": self, "global_sym_table": self.global_ctx.global_sym_table, } got_trig = False for trig in trig_decorators_reqd: if trig not in trig_decs or len(trig_decs[trig]) == 0: continue trig_args[trig] = trig_decs[trig].pop(0) got_trig = True if not got_trig: break for dec_name in ["state_active", "time_active", "task_unique"]: if dec_name in trig_decs: trig_args[dec_name] = trig_decs[dec_name][0] self.trigger.append(trig_ctx.get_trig_info(f"{trig_ctx_name}.{func_name}", trig_args)) if trig_ctx.trigger_register(self): self.trigger_start() def trigger_start(self): """Start any triggers for this function.""" for trigger in self.trigger: trigger.start() def trigger_stop(self): """Stop any triggers for this function.""" for trigger in self.trigger: trigger.stop() self.trigger = [] for srv_name in self.trigger_service: domain, name = srv_name.split(".", 1) Function.service_remove(self.global_ctx_name, domain, name) self.trigger_service = set() async def eval_decorators(self, ast_ctx): """Evaluate the function decorators arguments.""" code_str, code_list = ast_ctx.code_str, ast_ctx.code_list ast_ctx.code_str, ast_ctx.code_list = self.code_str, self.code_list dec_other = [] dec_trig = [] for dec in self.func_def.decorator_list: if ( isinstance(dec, ast.Call) and isinstance(dec.func, ast.Name) and dec.func.id in TRIG_SERV_DECORATORS ): args = await ast_ctx.eval_elt_list(dec.args) kwargs = {keyw.arg: await ast_ctx.aeval(keyw.value) for keyw in dec.keywords} dec_trig.append([dec.func.id, args, kwargs if len(kwargs) > 0 else None]) elif isinstance(dec, ast.Name) and dec.id in TRIG_SERV_DECORATORS: dec_trig.append([dec.id, None, None]) else: dec_other.append(await ast_ctx.aeval(dec)) ast_ctx.code_str, ast_ctx.code_list = code_str, code_list return dec_trig, reversed(dec_other) async def resolve_nonlocals(self, ast_ctx): """Tag local variables and resolve nonlocals.""" # # determine the list of local variables, nonlocal and global # arguments are local variables too # args = self.get_positional_args() if self.func_def.args.vararg: args.append(self.func_def.args.vararg.arg) if self.func_def.args.kwarg: args.append(self.func_def.args.kwarg.arg) for kwonlyarg in self.func_def.args.kwonlyargs: args.append(kwonlyarg.arg) nonlocal_names = set() global_names = set() var_names = set(args) self.local_names = set(args) for stmt in self.func_def.body: self.has_closure = self.has_closure or await self.check_for_closure(stmt) var_names = var_names.union( await ast_ctx.get_names( stmt, nonlocal_names=nonlocal_names, global_names=global_names, local_names=self.local_names, ) ) for var_name in var_names: got_dot = var_name.find(".") if got_dot >= 0: var_name = var_name[0:got_dot] if var_name in global_names: continue if var_name in self.local_names and var_name not in nonlocal_names: if self.has_closure: self.local_sym_table[var_name] = EvalLocalVar(var_name) continue if var_name in nonlocal_names: sym_table_idx = 1 else: sym_table_idx = 0 for sym_table in reversed(ast_ctx.sym_table_stack[sym_table_idx:] + [ast_ctx.sym_table]): if var_name in sym_table and isinstance(sym_table[var_name], EvalLocalVar): self.local_sym_table[var_name] = sym_table[var_name] break else: if var_name in nonlocal_names: val = await ast_ctx.ast_name(ast.Name(id=var_name, ctx=ast.Load())) if isinstance(val, EvalName) and got_dot < 0: raise SyntaxError(f"no binding for nonlocal '{var_name}' found") def get_decorators(self): """Return the function decorators.""" return self.decorators def get_doc_string(self): """Return the function doc_string.""" return self.doc_string def get_positional_args(self): """Return the function positional arguments.""" args = [] for arg in self.func_def.args.posonlyargs + self.func_def.args.args: args.append(arg.arg) return args async def try_aeval(self, ast_ctx, arg): """Call self.aeval and capture exceptions.""" try: return await ast_ctx.aeval(arg) except asyncio.CancelledError: raise except Exception as err: if ast_ctx.exception_long is None: ast_ctx.exception_long = ast_ctx.format_exc(err, arg.lineno, arg.col_offset) async def call(self, ast_ctx, *args, **kwargs): """Call the function with the given context and arguments.""" sym_table = {} if args is None: args = [] kwargs = kwargs.copy() if kwargs else {} bad_kwargs = [] for i, func_def_arg in enumerate(self.func_def.args.posonlyargs + self.func_def.args.args): var_name = func_def_arg.arg val = None if i < len(args): val = args[i] if var_name in kwargs: raise TypeError(f"{self.name}() got multiple values for argument '{var_name}'") elif var_name in kwargs: if i < self.num_posonly_arg: bad_kwargs.append(var_name) val = kwargs[var_name] del kwargs[var_name] elif self.num_posn_arg <= i < len(self.defaults) + self.num_posn_arg: val = self.defaults[i - self.num_posn_arg] else: raise TypeError( f"{self.name}() missing {self.num_posn_arg - i} required positional arguments" ) sym_table[var_name] = val if len(bad_kwargs) > 0: raise TypeError( f"{self.name}() got some positional-only arguments passed as keyword arguments: '{', '.join(bad_kwargs)}'" ) for i, kwonlyarg in enumerate(self.func_def.args.kwonlyargs): var_name = kwonlyarg.arg if var_name in kwargs: val = kwargs[var_name] del kwargs[var_name] elif i < len(self.kw_defaults) and self.kw_defaults[i]["ok"]: val = self.kw_defaults[i]["val"] else: raise TypeError(f"{self.name}() missing required keyword-only arguments") sym_table[var_name] = val if self.func_def.args.kwarg: sym_table[self.func_def.args.kwarg.arg] = kwargs elif not set(kwargs.keys()).issubset(TRIGGER_KWARGS): # don't raise an exception for extra trigger keyword parameters; # it's difficult to apply this exception to just trigger functions # since they could have non-trigger decorators too unexpected = ", ".join(sorted(set(kwargs.keys()) - TRIGGER_KWARGS)) raise TypeError(f"{self.name}() called with unexpected keyword arguments: {unexpected}") num_posn = self.num_posonly_arg + len(self.func_def.args.args) if self.func_def.args.vararg: if len(args) > num_posn: sym_table[self.func_def.args.vararg.arg] = tuple(args[num_posn:]) else: sym_table[self.func_def.args.vararg.arg] = () elif len(args) > num_posn: raise TypeError(f"{self.name}() called with too many positional arguments") for name, value in self.local_sym_table.items(): if name in sym_table: sym_table[name] = EvalLocalVar(name, value=sym_table[name]) elif value.is_defined(): sym_table[name] = value else: sym_table[name] = EvalLocalVar(name) if ast_ctx.global_ctx != self.global_ctx: # # switch to the global symbol table in the global context # where the function was defined # prev_sym_table = [ ast_ctx.global_sym_table, ast_ctx.sym_table, ast_ctx.sym_table_stack, ast_ctx.global_ctx, ] ast_ctx.global_sym_table = self.global_ctx.get_global_sym_table() ast_ctx.sym_table_stack = [ast_ctx.global_sym_table] ast_ctx.global_ctx = self.global_ctx else: ast_ctx.sym_table_stack.append(ast_ctx.sym_table) prev_sym_table = None ast_ctx.sym_table = sym_table code_str, code_list = ast_ctx.code_str, ast_ctx.code_list ast_ctx.code_str, ast_ctx.code_list = self.code_str, self.code_list self.exception = None self.exception_obj = None self.exception_long = None prev_func = ast_ctx.curr_func save_user_locals = ast_ctx.user_locals ast_ctx.user_locals = {} ast_ctx.curr_func = self del args, kwargs for arg1 in self.func_def.body: val = await self.try_aeval(ast_ctx, arg1) if isinstance(val, EvalReturn): val = val.value break # return None at end if there isn't a return val = None if ast_ctx.get_exception_obj(): break ast_ctx.curr_func = prev_func ast_ctx.user_locals = save_user_locals ast_ctx.code_str, ast_ctx.code_list = code_str, code_list if prev_sym_table is not None: ( ast_ctx.global_sym_table, ast_ctx.sym_table, ast_ctx.sym_table_stack, ast_ctx.global_ctx, ) = prev_sym_table else: ast_ctx.sym_table = ast_ctx.sym_table_stack.pop() return val async def check_for_closure(self, arg): """Recursively check ast tree arg and return True if there is an inner function or class.""" if isinstance(arg, (ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef)): return True for child in ast.iter_child_nodes(arg): if await self.check_for_closure(child): return True return False class EvalFuncVar: """Class for a callable pyscript function.""" def __init__(self, func): """Initialize instance with given EvalFunc function.""" self.func = func self.ast_ctx = None def get_func(self): """Return the EvalFunc function.""" return self.func def remove_func(self): """Remove and return the EvalFunc function.""" func = self.func self.func = None return func async def call(self, ast_ctx, *args, **kwargs): """Call the EvalFunc function.""" return await self.func.call(ast_ctx, *args, **kwargs) def get_name(self): """Return the function name.""" return self.func.get_name() def set_name(self, name): """Set the function name.""" self.func.set_name(name) def set_ast_ctx(self, ast_ctx): """Set the ast context.""" self.ast_ctx = ast_ctx def get_ast_ctx(self): """Return the ast context.""" return self.ast_ctx def __del__(self): """On deletion, stop any triggers for this function.""" if self.func: self.func.trigger_stop() async def __call__(self, *args, **kwargs): """Call the EvalFunc function using our saved ast ctx.""" return await self.func.call(self.ast_ctx, *args, **kwargs) class EvalFuncVarClassInst(EvalFuncVar): """Class for a callable pyscript class instance function.""" def __init__(self, func, ast_ctx, class_inst_weak): """Initialize instance with given EvalFunc function.""" super().__init__(func) self.ast_ctx = ast_ctx self.class_inst_weak = class_inst_weak async def call(self, ast_ctx, *args, **kwargs): """Call the EvalFunc function.""" return await self.func.call(ast_ctx, self.class_inst_weak(), *args, **kwargs) async def __call__(self, *args, **kwargs): """Call the function using our saved ast ctx and class instance.""" return await self.func.call(self.ast_ctx, self.class_inst_weak(), *args, **kwargs) class AstEval: """Python interpreter AST object evaluator.""" def __init__(self, name, global_ctx, logger_name=None): """Initialize an interpreter execution context.""" self.name = name self.str = None self.ast = None self.global_ctx = global_ctx self.global_sym_table = global_ctx.get_global_sym_table() if global_ctx else {} self.sym_table_stack = [] self.sym_table = self.global_sym_table self.local_sym_table = {} self.user_locals = {} self.curr_func = None self.filename = name self.code_str = None self.code_list = None self.exception = None self.exception_obj = None self.exception_long = None self.exception_curr = None self.lineno = 1 self.col_offset = 0 self.logger_handlers = set() self.logger = None self.set_logger_name(logger_name if logger_name is not None else self.name) self.config_entry = Function.hass.data.get(DOMAIN, {}).get(CONFIG_ENTRY, {}) self.dec_eval_depth = 0 async def ast_not_implemented(self, arg, *args): """Raise NotImplementedError exception for unimplemented AST types.""" name = "ast_" + arg.__class__.__name__.lower() raise NotImplementedError(f"{self.name}: not implemented ast " + name) async def aeval(self, arg, undefined_check=True): """Vector to specific function based on ast class type.""" name = "ast_" + arg.__class__.__name__.lower() try: if hasattr(arg, "lineno"): self.lineno = arg.lineno self.col_offset = arg.col_offset val = await getattr(self, name, self.ast_not_implemented)(arg) if undefined_check and isinstance(val, EvalName): raise NameError(f"name '{val.name}' is not defined") return val except Exception as err: if not self.exception_obj: func_name = self.curr_func.get_name() + "(), " if self.curr_func else "" self.exception_obj = err self.exception = f"Exception in {func_name}{self.filename} line {self.lineno} column {self.col_offset}: {err}" self.exception_long = self.format_exc(err, self.lineno, self.col_offset) raise # Statements return NONE, EvalBreak, EvalContinue, EvalReturn async def ast_module(self, arg): """Execute ast_module - a list of statements.""" val = None for arg1 in arg.body: val = await self.aeval(arg1) if isinstance(val, EvalReturn): raise SyntaxError(f"{val.name()} statement outside function") if isinstance(val, EvalStopFlow): raise SyntaxError(f"{val.name()} statement outside loop") return val async def ast_import(self, arg): """Execute import.""" for imp in arg.names: mod, error_ctx = await self.global_ctx.module_import(imp.name, 0) if error_ctx: self.exception_obj = error_ctx.exception_obj self.exception = error_ctx.exception self.exception_long = error_ctx.exception_long raise self.exception_obj if not mod: if ( not self.config_entry.data.get(CONF_ALLOW_ALL_IMPORTS, False) and imp.name not in ALLOWED_IMPORTS ): raise ModuleNotFoundError(f"import of {imp.name} not allowed") if imp.name not in sys.modules: mod = await Function.hass.async_add_executor_job(importlib.import_module, imp.name) else: mod = sys.modules[imp.name] self.sym_table[imp.name if imp.asname is None else imp.asname] = mod async def ast_importfrom(self, arg): """Execute from X import Y.""" if arg.module is None: # handle: "from . import xyz" for imp in arg.names: mod, error_ctx = await self.global_ctx.module_import(imp.name, arg.level) if error_ctx: self.exception_obj = error_ctx.exception_obj self.exception = error_ctx.exception self.exception_long = error_ctx.exception_long raise self.exception_obj if not mod: raise ModuleNotFoundError(f"module '{imp.name}' not found") self.sym_table[imp.name if imp.asname is None else imp.asname] = mod return mod, error_ctx = await self.global_ctx.module_import(arg.module, arg.level) if error_ctx: self.exception_obj = error_ctx.exception_obj self.exception = error_ctx.exception self.exception_long = error_ctx.exception_long raise self.exception_obj if not mod: if ( not self.config_entry.data.get(CONF_ALLOW_ALL_IMPORTS, False) and arg.module not in ALLOWED_IMPORTS ): raise ModuleNotFoundError(f"import from {arg.module} not allowed") if arg.module not in sys.modules: mod = await Function.hass.async_add_executor_job(importlib.import_module, arg.module) else: mod = sys.modules[arg.module] for imp in arg.names: if imp.name == "*": for name, value in mod.__dict__.items(): if name[0] != "_": self.sym_table[name] = value else: self.sym_table[imp.name if imp.asname is None else imp.asname] = getattr(mod, imp.name) async def ast_if(self, arg): """Execute if statement.""" val = None if await self.aeval(arg.test): for arg1 in arg.body: val = await self.aeval(arg1) if isinstance(val, EvalStopFlow): return val else: for arg1 in arg.orelse: val = await self.aeval(arg1) if isinstance(val, EvalStopFlow): return val return val async def ast_for(self, arg): """Execute for statement.""" for loop_var in await self.aeval(arg.iter): await self.recurse_assign(arg.target, loop_var) for arg1 in arg.body: val = await self.aeval(arg1) if isinstance(val, EvalStopFlow): break if isinstance(val, EvalBreak): break if isinstance(val, EvalReturn): return val else: for arg1 in arg.orelse: val = await self.aeval(arg1) if isinstance(val, EvalReturn): return val return None async def ast_asyncfor(self, arg): """Execute async for statement.""" return await self.ast_for(arg) async def ast_while(self, arg): """Execute while statement.""" while await self.aeval(arg.test): for arg1 in arg.body: val = await self.aeval(arg1) if isinstance(val, EvalStopFlow): break if isinstance(val, EvalBreak): break if isinstance(val, EvalReturn): return val else: for arg1 in arg.orelse: val = await self.aeval(arg1) if isinstance(val, EvalReturn): return val return None async def ast_classdef(self, arg): """Evaluate class definition.""" bases = [(await self.aeval(base)) for base in arg.bases] if self.curr_func and arg.name in self.curr_func.global_names: sym_table_assign = self.global_sym_table else: sym_table_assign = self.sym_table sym_table_assign[arg.name] = EvalLocalVar(arg.name) sym_table = {} self.sym_table_stack.append(self.sym_table) self.sym_table = sym_table for arg1 in arg.body: val = await self.aeval(arg1) if isinstance(val, EvalReturn): raise SyntaxError(f"{val.name()} statement outside function") if isinstance(val, EvalStopFlow): raise SyntaxError(f"{val.name()} statement outside loop") self.sym_table = self.sym_table_stack.pop() sym_table["__init__evalfunc_wrap__"] = None if "__init__" in sym_table: sym_table["__init__evalfunc_wrap__"] = sym_table["__init__"] del sym_table["__init__"] sym_table_assign[arg.name].set(type(arg.name, tuple(bases), sym_table)) async def ast_functiondef(self, arg): """Evaluate function definition.""" other_dec = [] dec_name = None pyscript_compile = None for dec in arg.decorator_list: if isinstance(dec, ast.Name) and dec.id in COMP_DECORATORS: dec_name = dec.id elif ( isinstance(dec, ast.Call) and isinstance(dec.func, ast.Name) and dec.func.id in COMP_DECORATORS ): dec_name = dec.func.id else: other_dec.append(dec) continue if pyscript_compile: raise SyntaxError( f"can only specify single decorator of {', '.join(sorted(COMP_DECORATORS))}" ) pyscript_compile = dec if pyscript_compile: if isinstance(pyscript_compile, ast.Call): if len(pyscript_compile.args) > 0: raise TypeError(f"@{dec_name}() takes 0 positional arguments") if len(pyscript_compile.keywords) > 0: raise TypeError(f"@{dec_name}() takes no keyword arguments") arg.decorator_list = other_dec local_var = None if arg.name in self.sym_table and isinstance(self.sym_table[arg.name], EvalLocalVar): local_var = self.sym_table[arg.name] code = compile(ast.Module(body=[arg], type_ignores=[]), filename=self.filename, mode="exec") exec(code, self.global_sym_table, self.sym_table) # pylint: disable=exec-used func = self.sym_table[arg.name] if dec_name == "pyscript_executor": if not asyncio.iscoroutinefunction(func): def executor_wrap_factory(func): async def executor_wrap(*args, **kwargs): return await Function.hass.async_add_executor_job( functools.partial(func, **kwargs), *args ) return executor_wrap self.sym_table[arg.name] = executor_wrap_factory(func) else: raise TypeError("@pyscript_executor() needs a regular, not async, function") if local_var: self.sym_table[arg.name] = local_var self.sym_table[arg.name].set(func) return func = EvalFunc(arg, self.code_list, self.code_str, self.global_ctx) await func.eval_defaults(self) await func.resolve_nonlocals(self) name = func.get_name() dec_trig, dec_other = await func.eval_decorators(self) self.dec_eval_depth += 1 for dec_func in dec_other: func = await self.call_func(dec_func, None, func) if isinstance(func, EvalFuncVar): # set the function name back to its original instead of the decorator function we just called func.set_name(name) func = func.remove_func() dec_trig += func.decorators elif isinstance(func, EvalFunc): func.set_name(name) self.dec_eval_depth -= 1 if isinstance(func, EvalFunc): func.decorators = dec_trig if self.dec_eval_depth == 0: func.trigger_stop() await func.trigger_init(self.global_ctx, name) func_var = EvalFuncVar(func) func_var.set_ast_ctx(self) else: func_var = EvalFuncVar(func) func_var.set_ast_ctx(self) else: func_var = func if self.curr_func and name in self.curr_func.global_names: sym_table = self.global_sym_table else: sym_table = self.sym_table if name in sym_table and isinstance(sym_table[name], EvalLocalVar): sym_table[name].set(func_var) else: sym_table[name] = func_var async def ast_lambda(self, arg): """Evaluate lambda definition by compiling a regular function.""" name = "__lambda_defn_temp__" await self.aeval( ast.FunctionDef( args=arg.args, body=[ast.Return(value=arg.body, lineno=arg.body.lineno, col_offset=arg.body.col_offset)], name=name, decorator_list=[ast.Name(id="pyscript_compile", ctx=ast.Load())], lineno=arg.col_offset, col_offset=arg.col_offset, ) ) func = self.sym_table[name] del self.sym_table[name] return func async def ast_asyncfunctiondef(self, arg): """Evaluate async function definition.""" return await self.ast_functiondef(arg) async def ast_try(self, arg): """Execute try...except statement.""" try: for arg1 in arg.body: val = await self.aeval(arg1) if isinstance(val, EvalStopFlow): return val if self.exception_obj is not None: raise self.exception_obj except Exception as err: curr_exc = self.exception_curr self.exception_curr = err for handler in arg.handlers: match = False if handler.type: exc_list = await self.aeval(handler.type) if not isinstance(exc_list, tuple): exc_list = [exc_list] for exc in exc_list: if isinstance(err, exc): match = True break else: match = True if match: save_obj = self.exception_obj save_exc_long = self.exception_long save_exc = self.exception self.exception_obj = None self.exception = None self.exception_long = None if handler.name is not None: if handler.name in self.sym_table and isinstance( self.sym_table[handler.name], EvalLocalVar ): self.sym_table[handler.name].set(err) else: self.sym_table[handler.name] = err for arg1 in handler.body: try: val = await self.aeval(arg1) if isinstance(val, EvalStopFlow): if handler.name is not None: del self.sym_table[handler.name] self.exception_curr = curr_exc return val except Exception: if self.exception_obj is not None: if handler.name is not None: del self.sym_table[handler.name] self.exception_curr = curr_exc if self.exception_obj == save_obj: self.exception_long = save_exc_long self.exception = save_exc else: self.exception_long = ( save_exc_long + "\n\nDuring handling of the above exception, another exception occurred:\n\n" + self.exception_long ) raise self.exception_obj # pylint: disable=raise-missing-from if handler.name is not None: del self.sym_table[handler.name] break else: self.exception_curr = curr_exc raise err else: for arg1 in arg.orelse: val = await self.aeval(arg1) if isinstance(val, EvalStopFlow): return val finally: for arg1 in arg.finalbody: val = await self.aeval(arg1) if isinstance(val, EvalStopFlow): return val # pylint: disable=lost-exception,return-in-finally return None async def ast_raise(self, arg): """Execute raise statement.""" if not arg.exc: if not self.exception_curr: raise RuntimeError("No active exception to reraise") exc = self.exception_curr else: exc = await self.aeval(arg.exc) if self.exception_curr: exc.__cause__ = self.exception_curr if arg.cause: cause = await self.aeval(arg.cause) raise exc from cause raise exc async def ast_with(self, arg, async_attr=""): """Execute with statement.""" hit_except = False ctx_list = [] val = None enter_attr = f"__{async_attr}enter__" exit_attr = f"__{async_attr}exit__" try: for item in arg.items: manager = await self.aeval(item.context_expr) ctx_list.append( { "manager": manager, "enter": getattr(type(manager), enter_attr), "exit": getattr(type(manager), exit_attr), "target": item.optional_vars, } ) for ctx in ctx_list: value = await self.call_func(ctx["enter"], enter_attr, ctx["manager"]) if ctx["target"]: await self.recurse_assign(ctx["target"], value) for arg1 in arg.body: val = await self.aeval(arg1) if isinstance(val, EvalStopFlow): break except Exception: hit_except = True exit_ok = True for ctx in reversed(ctx_list): ret = await self.call_func(ctx["exit"], exit_attr, ctx["manager"], *sys.exc_info()) exit_ok = exit_ok and ret if not exit_ok: raise finally: if not hit_except: for ctx in reversed(ctx_list): await self.call_func(ctx["exit"], exit_attr, ctx["manager"], None, None, None) return val async def ast_asyncwith(self, arg): """Execute async with statement.""" return await self.ast_with(arg, async_attr="a") async def ast_pass(self, arg): """Execute pass statement.""" async def ast_expression(self, arg): """Execute expression statement.""" return await self.aeval(arg.body) async def ast_expr(self, arg): """Execute expression statement.""" return await self.aeval(arg.value) async def ast_break(self, arg): """Execute break statement - return special class.""" return EvalBreak() async def ast_continue(self, arg): """Execute continue statement - return special class.""" return EvalContinue() async def ast_return(self, arg): """Execute return statement - return special class.""" return EvalReturn(await self.aeval(arg.value) if arg.value else None) async def ast_global(self, arg): """Execute global statement.""" if not self.curr_func: raise SyntaxError("global statement outside function") for var_name in arg.names: self.curr_func.global_names.add(var_name) async def ast_nonlocal(self, arg): """Execute nonlocal statement.""" if not self.curr_func: raise SyntaxError("nonlocal statement outside function") for var_name in arg.names: self.curr_func.nonlocal_names.add(var_name) async def recurse_assign(self, lhs, val): """Recursive assignment.""" if isinstance(lhs, ast.Tuple): try: vals = [*(iter(val))] except Exception: raise TypeError("cannot unpack non-iterable object") # pylint: disable=raise-missing-from got_star = 0 for lhs_elt in lhs.elts: if isinstance(lhs_elt, ast.Starred): got_star = 1 break if len(lhs.elts) > len(vals) + got_star: if got_star: err_msg = f"at least {len(lhs.elts) - got_star}" else: err_msg = f"{len(lhs.elts)}" raise ValueError(f"too few values to unpack (expected {err_msg})") if len(lhs.elts) < len(vals) and got_star == 0: raise ValueError(f"too many values to unpack (expected {len(lhs.elts)})") val_idx = 0 for lhs_elt in lhs.elts: if isinstance(lhs_elt, ast.Starred): star_len = len(vals) - len(lhs.elts) + 1 star_name = lhs_elt.value.id await self.recurse_assign( ast.Name(id=star_name, ctx=ast.Store()), vals[val_idx : val_idx + star_len], ) val_idx += star_len else: await self.recurse_assign(lhs_elt, vals[val_idx]) val_idx += 1 elif isinstance(lhs, ast.Subscript): var = await self.aeval(lhs.value) if isinstance(lhs.slice, ast.Index): ind = await self.aeval(lhs.slice.value) var[ind] = val elif isinstance(lhs.slice, ast.Slice): lower = await self.aeval(lhs.slice.lower) if lhs.slice.lower else None upper = await self.aeval(lhs.slice.upper) if lhs.slice.upper else None step = await self.aeval(lhs.slice.step) if lhs.slice.step else None var[slice(lower, upper, step)] = val else: var[await self.aeval(lhs.slice)] = val else: var_name = await self.aeval(lhs) if isinstance(var_name, EvalAttrSet): var_name.setattr(val) return if not isinstance(var_name, str): raise NotImplementedError(f"unknown lhs type {lhs} (got {var_name}) in assign") dot_count = var_name.count(".") if dot_count == 1: State.set(var_name, val) return if dot_count == 2: State.setattr(var_name, val) return if dot_count > 0: raise NameError( f"invalid name '{var_name}' (should be 'domain.entity' or 'domain.entity.attr')" ) if self.curr_func and var_name in self.curr_func.global_names: self.global_sym_table[var_name] = val return if var_name in self.sym_table and isinstance(self.sym_table[var_name], EvalLocalVar): self.sym_table[var_name].set(val) else: self.sym_table[var_name] = val async def ast_assign(self, arg): """Execute assignment statement.""" rhs = await self.aeval(arg.value) for target in arg.targets: await self.recurse_assign(target, rhs) async def ast_augassign(self, arg): """Execute augmented assignment statement (lhs = value).""" arg.target.ctx = ast.Load() new_val = await self.aeval(ast.BinOp(left=arg.target, op=arg.op, right=arg.value)) arg.target.ctx = ast.Store() await self.recurse_assign(arg.target, new_val) async def ast_annassign(self, arg): """Execute type hint assignment statement (just ignore the type hint).""" if arg.value is not None: rhs = await self.aeval(arg.value) await self.recurse_assign(arg.target, rhs) async def ast_namedexpr(self, arg): """Execute named expression.""" val = await self.aeval(arg.value) await self.recurse_assign(arg.target, val) return val async def ast_delete(self, arg): """Execute del statement.""" for arg1 in arg.targets: if isinstance(arg1, ast.Subscript): var = await self.aeval(arg1.value) if isinstance(arg1.slice, ast.Index): ind = await self.aeval(arg1.slice.value) for elt in ind if isinstance(ind, list) else [ind]: del var[elt] elif isinstance(arg1.slice, ast.Slice): lower, upper, step = None, None, None if arg1.slice.lower: lower = await self.aeval(arg1.slice.lower) if arg1.slice.upper: upper = await self.aeval(arg1.slice.upper) if arg1.slice.step: step = await self.aeval(arg1.slice.step) del var[slice(lower, upper, step)] else: del var[await self.aeval(arg1.slice)] elif isinstance(arg1, ast.Name): if self.curr_func and arg1.id in self.curr_func.global_names: if arg1.id in self.global_sym_table: del self.global_sym_table[arg1.id] elif arg1.id in self.sym_table: if isinstance(self.sym_table[arg1.id], EvalLocalVar): if self.sym_table[arg1.id].is_defined(): self.sym_table[arg1.id].set_undefined() else: raise NameError(f"name '{arg1.id}' is not defined") else: del self.sym_table[arg1.id] else: raise NameError(f"name '{arg1.id}' is not defined") elif isinstance(arg1, ast.Attribute): var_name = await self.ast_attribute_collapse(arg1, check_undef=False) if not isinstance(var_name, str): raise NameError("state name should be 'domain.entity' or 'domain.entity.attr'") State.delete(var_name) else: raise NotImplementedError(f"unknown target type {arg1} in del") async def ast_assert(self, arg): """Execute assert statement.""" if not await self.aeval(arg.test): if arg.msg: raise AssertionError(await self.aeval(arg.msg)) raise AssertionError async def ast_attribute_collapse(self, arg, check_undef=True): """Combine dotted attributes to allow variable names to have dots.""" # collapse dotted names, eg: # Attribute(value=Attribute(value=Name(id='i', ctx=Load()), attr='j', ctx=Load()), attr='k', ctx=Store()) name = arg.attr val = arg.value while isinstance(val, ast.Attribute): name = val.attr + "." + name val = val.value if isinstance(val, ast.Name): name = val.id + "." + name # ensure the first portion of name is undefined if check_undef and not isinstance( await self.ast_name(ast.Name(id=val.id, ctx=ast.Load())), EvalName ): return None return name return None async def ast_attribute(self, arg): """Apply attributes.""" full_name = await self.ast_attribute_collapse(arg) if full_name is not None: if isinstance(arg.ctx, ast.Store): return full_name val = await self.ast_name(ast.Name(id=full_name, ctx=arg.ctx)) if not isinstance(val, EvalName): return val val = await self.aeval(arg.value) if isinstance(arg.ctx, ast.Store): return EvalAttrSet(val, arg.attr) return getattr(val, arg.attr) async def ast_name(self, arg): """Look up value of identifier on load, or returns name on set.""" if isinstance(arg.ctx, ast.Load): # # check other scopes if required by global declarations # if self.curr_func and arg.id in self.curr_func.global_names: if arg.id in self.global_sym_table: return self.global_sym_table[arg.id] raise NameError(f"global name '{arg.id}' is not defined") # # now check in our current symbol table, and then some other places # if arg.id in self.sym_table: if isinstance(self.sym_table[arg.id], EvalLocalVar): return self.sym_table[arg.id].get() return self.sym_table[arg.id] if arg.id in self.local_sym_table: return self.local_sym_table[arg.id] if arg.id in self.global_sym_table: if self.curr_func and arg.id in self.curr_func.local_names: raise UnboundLocalError(f"local variable '{arg.id}' referenced before assignment") return self.global_sym_table[arg.id] if arg.id in BUILTIN_AST_FUNCS_FACTORY: return BUILTIN_AST_FUNCS_FACTORY[arg.id](self) if hasattr(builtins, arg.id) and arg.id not in BUILTIN_EXCLUDE and arg.id[0] != "_": return getattr(builtins, arg.id) if Function.get(arg.id): return Function.get(arg.id) num_dots = arg.id.count(".") # # any single-dot name could be a state variable # a two-dot name for state.attr needs to exist # if num_dots == 1 or (num_dots == 2 and State.exist(arg.id)): return State.get(arg.id) # # Couldn't find it, so return just the name wrapped in EvalName to # distinguish from a string variable value. This is to support # names with ".", which are joined by ast_attribute # return EvalName(arg.id) return arg.id async def ast_binop(self, arg): """Evaluate binary operators by calling function based on class.""" name = "ast_binop_" + arg.op.__class__.__name__.lower() return await getattr(self, name, self.ast_not_implemented)(arg.left, arg.right) async def ast_binop_add(self, arg0, arg1): """Evaluate binary operator: +.""" return (await self.aeval(arg0)) + (await self.aeval(arg1)) async def ast_binop_sub(self, arg0, arg1): """Evaluate binary operator: -.""" return (await self.aeval(arg0)) - (await self.aeval(arg1)) async def ast_binop_mult(self, arg0, arg1): """Evaluate binary operator: *.""" return (await self.aeval(arg0)) * (await self.aeval(arg1)) async def ast_binop_div(self, arg0, arg1): """Evaluate binary operator: /.""" return (await self.aeval(arg0)) / (await self.aeval(arg1)) async def ast_binop_mod(self, arg0, arg1): """Evaluate binary operator: %.""" return (await self.aeval(arg0)) % (await self.aeval(arg1)) async def ast_binop_pow(self, arg0, arg1): """Evaluate binary operator: **.""" return (await self.aeval(arg0)) ** (await self.aeval(arg1)) async def ast_binop_lshift(self, arg0, arg1): """Evaluate binary operator: <<.""" return (await self.aeval(arg0)) << (await self.aeval(arg1)) async def ast_binop_rshift(self, arg0, arg1): """Evaluate binary operator: >>.""" return (await self.aeval(arg0)) >> (await self.aeval(arg1)) async def ast_binop_bitor(self, arg0, arg1): """Evaluate binary operator: |.""" return (await self.aeval(arg0)) | (await self.aeval(arg1)) async def ast_binop_bitxor(self, arg0, arg1): """Evaluate binary operator: ^.""" return (await self.aeval(arg0)) ^ (await self.aeval(arg1)) async def ast_binop_bitand(self, arg0, arg1): """Evaluate binary operator: &.""" return (await self.aeval(arg0)) & (await self.aeval(arg1)) async def ast_binop_floordiv(self, arg0, arg1): """Evaluate binary operator: //.""" return (await self.aeval(arg0)) // (await self.aeval(arg1)) async def ast_unaryop(self, arg): """Evaluate unary operators by calling function based on class.""" name = "ast_unaryop_" + arg.op.__class__.__name__.lower() return await getattr(self, name, self.ast_not_implemented)(arg.operand) async def ast_unaryop_not(self, arg0): """Evaluate unary operator: not.""" return not (await self.aeval(arg0)) async def ast_unaryop_invert(self, arg0): """Evaluate unary operator: ~.""" return ~(await self.aeval(arg0)) async def ast_unaryop_uadd(self, arg0): """Evaluate unary operator: +.""" return await self.aeval(arg0) async def ast_unaryop_usub(self, arg0): """Evaluate unary operator: -.""" return -(await self.aeval(arg0)) async def ast_compare(self, arg): """Evaluate comparison operators by calling function based on class.""" left = arg.left for cmp_op, right in zip(arg.ops, arg.comparators): name = "ast_cmpop_" + cmp_op.__class__.__name__.lower() val = await getattr(self, name, self.ast_not_implemented)(left, right) if not val: return False left = right return True async def ast_cmpop_eq(self, arg0, arg1): """Evaluate comparison operator: ==.""" return (await self.aeval(arg0)) == (await self.aeval(arg1)) async def ast_cmpop_noteq(self, arg0, arg1): """Evaluate comparison operator: !=.""" return (await self.aeval(arg0)) != (await self.aeval(arg1)) async def ast_cmpop_lt(self, arg0, arg1): """Evaluate comparison operator: <.""" return (await self.aeval(arg0)) < (await self.aeval(arg1)) async def ast_cmpop_lte(self, arg0, arg1): """Evaluate comparison operator: <=.""" return (await self.aeval(arg0)) <= (await self.aeval(arg1)) async def ast_cmpop_gt(self, arg0, arg1): """Evaluate comparison operator: >.""" return (await self.aeval(arg0)) > (await self.aeval(arg1)) async def ast_cmpop_gte(self, arg0, arg1): """Evaluate comparison operator: >=.""" return (await self.aeval(arg0)) >= (await self.aeval(arg1)) async def ast_cmpop_is(self, arg0, arg1): """Evaluate comparison operator: is.""" return (await self.aeval(arg0)) is (await self.aeval(arg1)) async def ast_cmpop_isnot(self, arg0, arg1): """Evaluate comparison operator: is not.""" return (await self.aeval(arg0)) is not (await self.aeval(arg1)) async def ast_cmpop_in(self, arg0, arg1): """Evaluate comparison operator: in.""" return (await self.aeval(arg0)) in (await self.aeval(arg1)) async def ast_cmpop_notin(self, arg0, arg1): """Evaluate comparison operator: not in.""" return (await self.aeval(arg0)) not in (await self.aeval(arg1)) async def ast_boolop(self, arg): """Evaluate boolean operators and and or.""" if isinstance(arg.op, ast.And): val = True for arg1 in arg.values: val = await self.aeval(arg1) if not val: return val return val val = False for arg1 in arg.values: val = await self.aeval(arg1) if val: return val return val async def eval_elt_list(self, elts): """Evaluate and star list elements.""" val = [] for arg in elts: if isinstance(arg, ast.Starred): val += await self.aeval(arg.value) else: val.append(await self.aeval(arg)) return val async def ast_list(self, arg): """Evaluate list.""" if isinstance(arg.ctx, ast.Load): return await self.eval_elt_list(arg.elts) async def loopvar_scope_save(self, generators): """Return current scope variables that match looping target vars.""" # # looping variables are in their own implicit nested scope, so save/restore # variables in the current scope with the same names # lvars = set() for gen in generators: await self.get_names( ast.Assign(targets=[gen.target], value=ast.Constant(value=None)), local_names=lvars ) return lvars, {var: self.sym_table[var] for var in lvars if var in self.sym_table} async def loopvar_scope_restore(self, var_names, save_vars): """Restore current scope variables that match looping target vars.""" for var_name in var_names: if var_name in save_vars: self.sym_table[var_name] = save_vars[var_name] else: try: del self.sym_table[var_name] except KeyError: # If the iterator was empty, the loop variables were never # assigned to, so deleting them will fail. pass async def listcomp_loop(self, generators, elt): """Recursive list comprehension.""" out = [] gen = generators[0] for loop_var in await self.aeval(gen.iter): await self.recurse_assign(gen.target, loop_var) for cond in gen.ifs: if not await self.aeval(cond): break else: if len(generators) == 1: out.append(await self.aeval(elt)) else: out += await self.listcomp_loop(generators[1:], elt) return out async def ast_listcomp(self, arg): """Evaluate list comprehension.""" target_vars, save_values = await self.loopvar_scope_save(arg.generators) result = await self.listcomp_loop(arg.generators, arg.elt) await self.loopvar_scope_restore(target_vars, save_values) return result async def ast_tuple(self, arg): """Evaluate Tuple.""" return tuple(await self.eval_elt_list(arg.elts)) async def ast_dict(self, arg): """Evaluate dict.""" val = {} for key_ast, val_ast in zip(arg.keys, arg.values): this_val = await self.aeval(val_ast) if key_ast is None: val.update(this_val) else: val[await self.aeval(key_ast)] = this_val return val async def dictcomp_loop(self, generators, key, value): """Recursive dict comprehension.""" out = {} gen = generators[0] for loop_var in await self.aeval(gen.iter): await self.recurse_assign(gen.target, loop_var) for cond in gen.ifs: if not await self.aeval(cond): break else: if len(generators) == 1: # # key is evaluated before value starting in 3.8 # key_val = await self.aeval(key) out[key_val] = await self.aeval(value) else: out.update(await self.dictcomp_loop(generators[1:], key, value)) return out async def ast_dictcomp(self, arg): """Evaluate dict comprehension.""" target_vars, save_values = await self.loopvar_scope_save(arg.generators) result = await self.dictcomp_loop(arg.generators, arg.key, arg.value) await self.loopvar_scope_restore(target_vars, save_values) return result async def ast_set(self, arg): """Evaluate set.""" ret = set() for elt in await self.eval_elt_list(arg.elts): ret.add(elt) return ret async def setcomp_loop(self, generators, elt): """Recursive list comprehension.""" out = set() gen = generators[0] for loop_var in await self.aeval(gen.iter): await self.recurse_assign(gen.target, loop_var) for cond in gen.ifs: if not await self.aeval(cond): break else: if len(generators) == 1: out.add(await self.aeval(elt)) else: out.update(await self.setcomp_loop(generators[1:], elt)) return out async def ast_setcomp(self, arg): """Evaluate set comprehension.""" target_vars, save_values = await self.loopvar_scope_save(arg.generators) result = await self.setcomp_loop(arg.generators, arg.elt) await self.loopvar_scope_restore(target_vars, save_values) return result async def ast_subscript(self, arg): """Evaluate subscript.""" var = await self.aeval(arg.value) if isinstance(arg.ctx, ast.Load): if isinstance(arg.slice, ast.Index): return var[await self.aeval(arg.slice)] if isinstance(arg.slice, ast.Slice): lower = (await self.aeval(arg.slice.lower)) if arg.slice.lower else None upper = (await self.aeval(arg.slice.upper)) if arg.slice.upper else None step = (await self.aeval(arg.slice.step)) if arg.slice.step else None return var[slice(lower, upper, step)] return var[await self.aeval(arg.slice)] return None async def ast_index(self, arg): """Evaluate index.""" return await self.aeval(arg.value) async def ast_slice(self, arg): """Evaluate slice.""" return await self.aeval(arg.value) async def ast_call(self, arg): """Evaluate function call.""" func = await self.aeval(arg.func) kwargs = {} for kw_arg in arg.keywords: if kw_arg.arg is None: kwargs.update(await self.aeval(kw_arg.value)) else: kwargs[kw_arg.arg] = await self.aeval(kw_arg.value) args = await self.eval_elt_list(arg.args) # # try to deduce function name, although this only works in simple cases # func_name = None if isinstance(arg.func, ast.Name): func_name = arg.func.id elif isinstance(arg.func, ast.Attribute): func_name = arg.func.attr if isinstance(func, EvalLocalVar): func_name = func.get_name() func = func.get() return await self.call_func(func, func_name, *args, **kwargs) async def call_func(self, func, func_name, *args, **kwargs): """Call a function with the given arguments.""" if func_name is None: try: if isinstance(func, (EvalFunc, EvalFuncVar)): func_name = func.get_name() else: func_name = func.__name__ except Exception: func_name = "" arg_str = ", ".join(['"' + elt + '"' if isinstance(elt, str) else str(elt) for elt in args]) _LOGGER.debug("%s: calling %s(%s, %s)", self.name, func_name, arg_str, kwargs) if isinstance(func, (EvalFunc, EvalFuncVar)): return await func.call(self, *args, **kwargs) if inspect.isclass(func) and hasattr(func, "__init__evalfunc_wrap__"): inst = func() # # we use weak references when we bind the method calls to the instance inst; # otherwise these self references cause the object to not be deleted until # it is later garbage collected # inst_weak = weakref.ref(inst) for name in dir(inst): value = getattr(inst, name) if type(value) is not EvalFuncVar: continue setattr(inst, name, EvalFuncVarClassInst(value.get_func(), value.get_ast_ctx(), inst_weak)) if getattr(func, "__init__evalfunc_wrap__") is not None: # # since our __init__ function is async, call the renamed one # await inst.__init__evalfunc_wrap__.call(self, *args, **kwargs) return inst if asyncio.iscoroutinefunction(func): return await func(*args, **kwargs) if callable(func): if func == time.sleep: # pylint: disable=comparison-with-callable _LOGGER.warning( "%s line %s calls blocking time.sleep(); replaced with asyncio.sleep()", self.filename, self.lineno, ) return await asyncio.sleep(*args, **kwargs) return func(*args, **kwargs) raise TypeError(f"'{func_name}' is not callable (got {func})") async def ast_ifexp(self, arg): """Evaluate if expression.""" return await self.aeval(arg.body) if (await self.aeval(arg.test)) else await self.aeval(arg.orelse) async def ast_num(self, arg): """Evaluate number.""" return arg.n async def ast_str(self, arg): """Evaluate string.""" return arg.s async def ast_nameconstant(self, arg): """Evaluate name constant.""" return arg.value async def ast_constant(self, arg): """Evaluate constant.""" return arg.value async def ast_joinedstr(self, arg): """Evaluate joined string.""" val = "" for arg1 in arg.values: this_val = await self.aeval(arg1) val = val + str(this_val) return val async def ast_formattedvalue(self, arg): """Evaluate formatted value.""" val = await self.aeval(arg.value) if arg.format_spec is not None: fmt = await self.aeval(arg.format_spec) return f"{val:{fmt}}" return f"{val}" async def ast_await(self, arg): """Evaluate await expr.""" coro = await self.aeval(arg.value) if coro: return await coro return None async def get_target_names(self, lhs): """Recursively find all the target names mentioned in the AST tree.""" names = set() if isinstance(lhs, ast.Tuple): for lhs_elt in lhs.elts: if isinstance(lhs_elt, ast.Starred): names.add(lhs_elt.value.id) else: names = names.union(await self.get_target_names(lhs_elt)) elif isinstance(lhs, ast.Attribute): var_name = await self.ast_attribute_collapse(lhs, check_undef=False) if isinstance(var_name, str): names.add(var_name) elif isinstance(lhs, ast.Name): names.add(lhs.id) return names async def get_names_set(self, arg, names, nonlocal_names, global_names, local_names): """Recursively find all the names mentioned in the AST tree.""" cls_name = arg.__class__.__name__ if cls_name == "Attribute": full_name = await self.ast_attribute_collapse(arg) if full_name is not None: names.add(full_name) return if cls_name == "Name": names.add(arg.id) return if cls_name == "Nonlocal" and nonlocal_names is not None: for var_name in arg.names: nonlocal_names.add(var_name) names.add(var_name) return if cls_name == "Global" and global_names is not None: for var_name in arg.names: global_names.add(var_name) names.add(var_name) return if local_names is not None: # # find all the local variables by looking for assignments; # also, don't recurse into function definitions # if cls_name == "Assign": for target in arg.targets: for name in await self.get_target_names(target): local_names.add(name) names.add(name) elif cls_name in {"AugAssign", "For", "AsyncFor", "NamedExpr"}: for name in await self.get_target_names(arg.target): local_names.add(name) names.add(name) elif cls_name in {"With", "AsyncWith"}: for item in arg.items: if item.optional_vars: for name in await self.get_target_names(item.optional_vars): local_names.add(name) names.add(name) elif cls_name in {"ListComp", "DictComp", "SetComp"}: target_vars, _ = await self.loopvar_scope_save(arg.generators) for name in target_vars: local_names.add(name) elif cls_name == "Try": for handler in arg.handlers: if handler.name is not None: local_names.add(handler.name) names.add(handler.name) elif cls_name == "Call": await self.get_names_set(arg.func, names, nonlocal_names, global_names, local_names) for this_arg in arg.args: await self.get_names_set(this_arg, names, nonlocal_names, global_names, local_names) for this_arg in arg.keywords or []: await self.get_names_set(this_arg, names, nonlocal_names, global_names, local_names) return elif cls_name in {"FunctionDef", "ClassDef", "AsyncFunctionDef"}: local_names.add(arg.name) names.add(arg.name) for dec in arg.decorator_list: await self.get_names_set(dec, names, nonlocal_names, global_names, local_names) # # find unbound names from the body of the function or class # inner_global, inner_names, inner_local = set(), set(), set() for child in arg.body: await self.get_names_set(child, inner_names, None, inner_global, inner_local) for name in inner_names: if name not in inner_local and name not in inner_global: names.add(name) return elif cls_name == "Delete": for arg1 in arg.targets: if isinstance(arg1, ast.Name): local_names.add(arg1.id) for child in ast.iter_child_nodes(arg): await self.get_names_set(child, names, nonlocal_names, global_names, local_names) async def get_names(self, this_ast=None, nonlocal_names=None, global_names=None, local_names=None): """Return set of all the names mentioned in our AST tree.""" names = set() this_ast = this_ast or self.ast if this_ast: await self.get_names_set(this_ast, names, nonlocal_names, global_names, local_names) return names def parse(self, code_str, filename=None, mode="exec"): """Parse the code_str source code into an AST tree.""" self.exception = None self.exception_obj = None self.exception_long = None self.ast = None if filename is not None: self.filename = filename try: if isinstance(code_str, list): self.code_list = code_str self.code_str = "\n".join(code_str) elif isinstance(code_str, str): self.code_str = code_str self.code_list = code_str.split("\n") else: self.code_str = code_str self.code_list = [] self.ast = ast.parse(self.code_str, filename=self.filename, mode=mode) return True except SyntaxError as err: self.exception_obj = err self.lineno = err.lineno self.col_offset = err.offset - 1 self.exception = f"syntax error {err}" if err.filename == self.filename: self.exception_long = self.format_exc(err, self.lineno, self.col_offset) else: self.exception_long = self.format_exc(err, 1, self.col_offset, code_list=[err.text]) return False except asyncio.CancelledError: raise except Exception as err: self.exception_obj = err self.lineno = 1 self.col_offset = 0 self.exception = f"parsing error {err}" self.exception_long = self.format_exc(err) return False def format_exc(self, exc, lineno=None, col_offset=None, short=False, code_list=None): """Format an multi-line exception message using lineno if available.""" if code_list is None: code_list = self.code_list if lineno is not None and lineno <= len(code_list): if short: mesg = f"In <{self.filename}> line {lineno}:\n" mesg += " " + code_list[lineno - 1] else: mesg = f"Exception in <{self.filename}> line {lineno}:\n" mesg += " " + code_list[lineno - 1] + "\n" if col_offset is not None: mesg += " " + " " * col_offset + "^\n" mesg += f"{type(exc).__name__}: {exc}" else: mesg = f"Exception in <{self.filename}>:\n" mesg += f"{type(exc).__name__}: {exc}" # # to get a more detailed traceback on exception (eg, when chasing an internal # error), add an "import traceback" above, and uncomment this next line # # return mesg + "\n" + traceback.format_exc(-1) return mesg def get_exception(self): """Return the last exception str.""" return self.exception def get_exception_obj(self): """Return the last exception object.""" return self.exception_obj def get_exception_long(self): """Return the last exception in a longer str form.""" return self.exception_long def set_local_sym_table(self, sym_table): """Set the local symbol table.""" self.local_sym_table = sym_table def set_global_ctx(self, global_ctx): """Set the global context.""" self.global_ctx = global_ctx if self.sym_table == self.global_sym_table: self.global_sym_table = global_ctx.get_global_sym_table() self.sym_table = self.global_sym_table else: self.global_sym_table = global_ctx.get_global_sym_table() if len(self.sym_table_stack) > 0: self.sym_table_stack[0] = self.global_sym_table def get_global_ctx(self): """Return the global context.""" return self.global_ctx def get_global_ctx_name(self): """Return the global context name.""" return self.global_ctx.get_name() def set_logger_name(self, name): """Set the context's logger name.""" if self.logger: for handler in self.logger_handlers: self.logger.removeHandler(handler) self.logger_name = name self.logger = logging.getLogger(LOGGER_PATH + "." + name) for handler in self.logger_handlers: self.logger.addHandler(handler) def get_logger_name(self): """Get the context's logger name.""" return self.logger_name def get_logger(self): """Get the context's logger.""" return self.logger def add_logger_handler(self, handler): """Add logger handler to this context.""" self.logger.addHandler(handler) self.logger_handlers.add(handler) def remove_logger_handler(self, handler): """Remove logger handler to this context.""" self.logger.removeHandler(handler) self.logger_handlers.discard(handler) def completions(self, root): """Return potential variable, function or attribute matches.""" words = set() num_period = root.count(".") if num_period >= 1: last_period = root.rfind(".") name = root[0:last_period] attr_root = root[last_period + 1 :] if name in self.global_sym_table: var = self.global_sym_table[name] try: for attr in var.__dict__: if attr.lower().startswith(attr_root) and (attr_root != "" or attr[0:1] != "_"): words.add(f"{name}.{attr}") except Exception: pass for keyw in set(keyword.kwlist) - {"yield"}: if keyw.lower().startswith(root): words.add(keyw) sym_table = BUILTIN_AST_FUNCS_FACTORY.copy() for name, value in builtins.__dict__.items(): if name[0] != "_" and name not in BUILTIN_EXCLUDE: sym_table[name] = value sym_table.update(self.global_sym_table.items()) for name, value in sym_table.items(): if name.lower().startswith(root): words.add(name) return words async def eval(self, new_state_vars=None, merge_local=False): """Execute parsed code, with the optional state variables added to the scope.""" self.exception = None self.exception_obj = None self.exception_long = None if new_state_vars: if not merge_local: self.local_sym_table = {} self.local_sym_table.update(new_state_vars) if self.ast: try: val = await self.aeval(self.ast) if isinstance(val, EvalStopFlow): return None return val except asyncio.CancelledError: raise except Exception as err: if self.exception_long is None: self.exception_long = self.format_exc(err, self.lineno, self.col_offset) return None def dump(self, this_ast=None): """Dump the AST tree for debugging.""" return ast.dump(this_ast if this_ast else self.ast)