2314 lines
90 KiB
Python
2314 lines
90 KiB
Python
"""Python interpreter for pyscript."""
|
|
|
|
import ast
|
|
import asyncio
|
|
import builtins
|
|
from collections import OrderedDict
|
|
import functools
|
|
import importlib
|
|
import inspect
|
|
import io
|
|
import keyword
|
|
import logging
|
|
import sys
|
|
import time
|
|
import weakref
|
|
|
|
import yaml
|
|
|
|
from homeassistant.const import SERVICE_RELOAD
|
|
from homeassistant.helpers.service import async_set_service_schema
|
|
|
|
from .const import (
|
|
ALLOWED_IMPORTS,
|
|
CONF_ALLOW_ALL_IMPORTS,
|
|
CONFIG_ENTRY,
|
|
DOMAIN,
|
|
LOGGER_PATH,
|
|
SERVICE_JUPYTER_KERNEL_START,
|
|
SERVICE_RESPONSE_NONE,
|
|
)
|
|
from .function import Function
|
|
from .state import State
|
|
|
|
# module logger: a child ("<LOGGER_PATH>.eval") of the integration's logger
_LOGGER = logging.getLogger(f"{LOGGER_PATH}.eval")
|
|
|
|
#
# Built-ins to exclude to improve security or avoid i/o
#
BUILTIN_EXCLUDE = {"breakpoint", "compile", "input", "memoryview", "open", "print"}

# Decorators whose arguments are evaluated and handed to pyscript's trigger
# machinery instead of being called: the five *_trigger decorators define what
# fires a function; state_active, time_active and task_unique are shared by all
# of that function's triggers.
TRIG_DECORATORS = {
    "event_trigger",
    "mqtt_trigger",
    "state_trigger",
    "time_trigger",
    "webhook_trigger",
    "state_active",
    "time_active",
    "task_unique",
}

# the trigger decorators plus @service, which registers the function as a service
TRIG_SERV_DECORATORS = TRIG_DECORATORS | {"service"}

# compilation-related decorators (their handling is not part of this excerpt)
COMP_DECORATORS = {"pyscript_compile", "pyscript_executor"}

# Keyword arguments that trigger/service invocations may pass. EvalFunc.call()
# tolerates these even when the function does not declare them, instead of
# raising TypeError for unexpected keyword arguments.
TRIGGER_KWARGS = {
    "context",
    "event_type",
    "old_value",
    "payload",
    "payload_obj",
    "qos",
    "topic",
    "trigger_type",
    "trigger_time",
    "var_name",
    "value",
    "webhook_id",
}

# HTTP methods accepted by the methods= keyword of @webhook_trigger
WEBHOOK_METHODS = {"GET", "HEAD", "POST", "PUT"}
|
|
|
|
|
|
def ast_eval_exec_factory(ast_ctx, mode):
    """Generate a function that executes eval() or exec() with given ast_ctx.

    Args:
        ast_ctx: the AstEval context the builtin is called from; its scopes are
            read each time the returned function runs.
        mode: ``"eval"`` or ``"exec"``; passed to ``AstEval.parse`` and used to
            label the parsed code as ``"eval()"`` / ``"exec()"``.

    Returns:
        A coroutine function ``eval_func(arg_str, eval_globals=None,
        eval_locals=None)`` that parses and evaluates ``arg_str`` in a fresh
        AstEval context and returns the result.
    """

    def _caller_locals_snapshot():
        # Copy of the caller's local scope, including names created by earlier
        # eval()/exec() calls (user_locals), with EvalLocalVar cells replaced by
        # their values and undefined cells dropped.
        snapshot = ast_ctx.sym_table.copy()
        snapshot.update(ast_ctx.user_locals)
        return {
            name: (value.get() if isinstance(value, EvalLocalVar) else value)
            for name, value in snapshot.items()
            if not isinstance(value, EvalLocalVar) or value.is_defined()
        }

    async def eval_func(arg_str, eval_globals=None, eval_locals=None):
        evaluator = AstEval(ast_ctx.name, ast_ctx.global_ctx)
        evaluator.parse(arg_str, f"{mode}()", mode=mode)
        if evaluator.exception_obj:
            raise evaluator.exception_obj
        evaluator.local_sym_table = ast_ctx.local_sym_table

        if eval_globals is None:
            # no explicit scopes: run in (a snapshot of) the caller's scopes
            evaluator.sym_table_stack = ast_ctx.sym_table_stack.copy()
            if ast_ctx.sym_table == ast_ctx.global_sym_table:
                evaluator.sym_table = ast_ctx.sym_table
            else:
                evaluator.sym_table = _caller_locals_snapshot()
        else:
            evaluator.global_sym_table = eval_globals
            if eval_locals is None:
                evaluator.sym_table_stack = []
                evaluator.sym_table = eval_globals
            else:
                evaluator.sym_table_stack = [eval_globals]
                evaluator.sym_table = eval_locals

        evaluator.curr_func = None
        try:
            result = await evaluator.aeval(evaluator.ast)
        except Exception as err:
            # report the failure at the eval()/exec() call site, followed by the
            # details from inside the evaluated code
            ast_ctx.exception_obj = err
            ast_ctx.exception = f"Exception in {ast_ctx.filename} line {ast_ctx.lineno} column {ast_ctx.col_offset}: {evaluator.exception}"
            call_site = ast_ctx.format_exc(err, ast_ctx.lineno, ast_ctx.col_offset, short=True)
            ast_ctx.exception_long = call_site + "\n" + evaluator.exception_long
            raise

        #
        # save variables only in the locals scope
        #
        if eval_globals is None and evaluator.sym_table != ast_ctx.sym_table:
            global_table = ast_ctx.global_sym_table
            caller_func = ast_ctx.curr_func
            for name, value in evaluator.sym_table.items():
                if name in global_table and value == global_table[name]:
                    continue
                if name in ast_ctx.sym_table:
                    continue
                if caller_func is not None and name in caller_func.local_names:
                    continue
                ast_ctx.user_locals[name] = value
        return result

    return eval_func
|
|
|
|
|
|
def ast_eval_factory(ast_ctx):
    """Return the async replacement for the builtin ``eval()`` bound to *ast_ctx*."""
    return ast_eval_exec_factory(ast_ctx, mode="eval")
|
|
|
|
|
|
def ast_exec_factory(ast_ctx):
    """Return the async replacement for the builtin ``exec()`` bound to *ast_ctx*."""
    return ast_eval_exec_factory(ast_ctx, mode="exec")
|
|
|
|
|
|
def ast_globals_factory(ast_ctx):
    """Return an async ``globals()`` replacement bound to *ast_ctx*.

    The returned coroutine function reads ``ast_ctx.global_sym_table`` at call
    time and returns that live table itself (not a copy).
    """

    async def globals_func():
        table = ast_ctx.global_sym_table
        return table

    return globals_func
|
|
|
|
|
|
def ast_locals_factory(ast_ctx):
    """Return an async ``locals()`` replacement bound to *ast_ctx*.

    At module scope (current table equal to the global table) the returned
    function yields the live table. Inside a function it returns a new dict: the
    current local table merged with ``ast_ctx.user_locals``, with EvalLocalVar
    cells replaced by their values and undefined cells omitted.
    """

    async def locals_func():
        scope = ast_ctx.sym_table
        if scope == ast_ctx.global_sym_table:
            return scope
        merged = {**scope, **ast_ctx.user_locals}
        return {
            name: (value.get() if isinstance(value, EvalLocalVar) else value)
            for name, value in merged.items()
            if not isinstance(value, EvalLocalVar) or value.is_defined()
        }

    return locals_func
|
|
|
|
|
|
#
# Built-in functions that are also passed the ast context: maps each builtin's
# name to a factory that takes an AstEval context and returns the coroutine
# function presumably installed in place of the Python builtin (installation
# happens outside this excerpt).
#
BUILTIN_AST_FUNCS_FACTORY = dict(
    eval=ast_eval_factory,
    exec=ast_exec_factory,
    globals=ast_globals_factory,
    locals=ast_locals_factory,
)
|
|
|
|
|
|
#
# Objects returned by return, break and continue statements that change execution flow,
# or objects returned that capture particular information
#
class EvalStopFlow:
    """Base class for statement results that interrupt sequential execution.

    Subclasses set ``_name`` to their statement keyword, which ends up in error
    messages such as "break statement outside loop".
    """

    # statement keyword; None on the base class
    _name = None

    def name(self):
        """Return the statement keyword for this flow-control result."""
        keyword = self._name
        return keyword
|
|
|
|
|
|
class EvalReturn(EvalStopFlow):
    """Result of executing a ``return`` statement; ``value`` holds the returned object."""

    _name = "return"

    def __init__(self, value):
        """Store the value produced by the ``return`` statement."""
        self.value = value
|
|
|
|
|
|
class EvalBreak(EvalStopFlow):
    """Result of executing a ``break`` statement."""

    _name = "break"
|
|
|
|
|
|
class EvalContinue(EvalStopFlow):
    """Result of executing a ``continue`` statement."""

    _name = "continue"
|
|
|
|
|
|
class EvalLocalVar:
    """Wrapper for local variable symtable entry.

    A mutable cell holding one local variable of a pyscript function. EvalFunc
    creates these cells when a function body contains an inner function or
    class, so that the inner scope shares the binding with the outer one. A cell
    may exist without a value ("undefined"); reading it then raises NameError.
    """

    def __init__(self, name, **kwargs):
        """Create a cell for *name*; it is defined only if ``value=`` is passed."""
        self.name = name
        self.defined = False
        if "value" in kwargs:
            self.value = kwargs["value"]
            self.defined = True

    def get(self):
        """Return the value, raising NameError if the cell is undefined."""
        if not self.defined:
            raise NameError(f"name '{self.name}' is not defined")
        return self.value

    def get_name(self):
        """Return the variable name."""
        return self.name

    def set(self, value):
        """Assign *value* and mark the cell defined."""
        self.value = value
        self.defined = True

    def is_defined(self):
        """Return whether the cell currently holds a value."""
        return self.defined

    def set_undefined(self):
        """Mark the cell undefined; get() raises NameError until set() is called."""
        self.defined = False

    def __getattr__(self, attr):
        """Forward attribute lookups that fail on the cell to the wrapped value.

        Raises:
            NameError: the cell is undefined.
            AttributeError: the wrapped value lacks *attr*, or the instance was
                created without ``__init__`` (as ``copy``/``pickle`` do before
                restoring state). Checking the instance dict avoids recursing
                into ``__getattr__`` while looking up ``defined``/``name``.
        """
        state = self.__dict__
        if "defined" not in state or "name" not in state:
            raise AttributeError(attr)
        if not state["defined"]:
            raise NameError(f"name '{state['name']}' is not defined")
        return getattr(self.value, attr)

    def __repr__(self):
        """Generate string with address and value."""
        return f"EvalLocalVar @{hex(id(self))} = {self.value if self.defined else 'undefined'}"
|
|
|
|
|
|
class EvalName:
|
|
"""Identifier that hasn't yet been resolved."""
|
|
|
|
def __init__(self, name):
|
|
"""Initialize identifier to name."""
|
|
self.name = name
|
|
|
|
def __getattr__(self, attr):
|
|
"""Get attribute for EvalName."""
|
|
raise NameError(f"name '{self.name}.{attr}' is not defined")
|
|
|
|
|
|
class EvalAttrSet:
    """Reference to ``obj.attr`` appearing on the left-hand side of an assignment."""

    def __init__(self, obj, attr):
        """Remember the target object and the attribute name to assign."""
        self.obj = obj
        self.attr = attr

    def setattr(self, value):
        """Assign *value* to the referenced attribute."""
        builtins.setattr(self.obj, self.attr, value)

    def getattr(self):
        """Read the current value of the referenced attribute."""
        return builtins.getattr(self.obj, self.attr)
|
|
|
|
|
|
class EvalFunc:
    """Class for a callable pyscript function.

    Wraps an ``ast.FunctionDef`` together with the source it was parsed from and
    the global context it was defined in. Implements calling the function
    (argument binding and body evaluation), evaluation of default values and
    decorators, closure/nonlocal resolution, and setup of pyscript trigger and
    ``@service`` decorators.
    """

    def __init__(self, func_def, code_list, code_str, global_ctx):
        """Initialize a function calling context.

        Args:
            func_def: ``ast.FunctionDef`` node of the function.
            code_list: source lines of the defining code; swapped into the
                AstEval context while this function's code is evaluated.
            code_str: source text of the defining code (swapped in likewise).
            global_ctx: global context the function is defined in; must provide
                ``get_name()``.
        """
        self.func_def = func_def
        self.name = func_def.name
        self.global_ctx = global_ctx
        self.global_ctx_name = global_ctx.get_name()
        self.logger = logging.getLogger(LOGGER_PATH + "." + self.global_ctx_name)
        # filled in by eval_defaults()
        self.defaults = []
        self.kw_defaults = []
        # [name, args, kwargs] entries consumed by trigger_init()
        self.decorators = []
        self.global_names = set()
        self.nonlocal_names = set()
        # set of local variable names, computed by resolve_nonlocals()
        self.local_names = None
        # EvalLocalVar cells for closure/nonlocal variables, keyed by name
        self.local_sym_table = {}
        self.doc_string = ast.get_docstring(func_def)
        self.num_posonly_arg = len(self.func_def.args.posonlyargs)
        # positional parameters without a default (recomputed by eval_defaults())
        self.num_posn_arg = self.num_posonly_arg + len(self.func_def.args.args) - len(self.defaults)
        self.code_list = code_list
        self.code_str = code_str
        self.exception = None
        self.exception_obj = None
        self.exception_long = None
        # trigger objects from trig_ctx.get_trig_info() and registered "domain.name" services
        self.trigger = []
        self.trigger_service = set()
        # True once the body is known to contain an inner function or class
        self.has_closure = False

    def get_name(self):
        """Return the function name."""
        return self.name

    def set_name(self, name):
        """Set the function name."""
        self.name = name

    async def eval_defaults(self, ast_ctx):
        """Evaluate the default function arguments.

        Positional defaults go to ``self.defaults``; keyword-only defaults go to
        ``self.kw_defaults`` as ``{"ok": has_default, "val": value}`` entries (the
        AST uses None for a keyword-only argument without a default).
        """
        self.defaults = []
        for val in self.func_def.args.defaults:
            self.defaults.append(await ast_ctx.aeval(val))
        self.num_posn_arg = self.num_posonly_arg + len(self.func_def.args.args) - len(self.defaults)
        self.kw_defaults = []
        for val in self.func_def.args.kw_defaults:
            self.kw_defaults.append({"ok": bool(val), "val": None if not val else await ast_ctx.aeval(val)})

    async def trigger_init(self, trig_ctx, func_name):
        """Initialize decorator triggers for this function.

        Validates every ``[name, args, kwargs]`` entry of ``self.decorators``,
        registers a service handler for each ``@service`` name, and groups the
        trigger decorators into trigger objects from ``trig_ctx.get_trig_info()``.
        The triggers are started if ``trig_ctx.trigger_register(self)`` returns
        a true value.

        Args:
            trig_ctx: global context the triggers and services belong to.
            func_name: function name used for default service names and messages.

        Raises:
            SyntaxError: unknown decorator, a non-repeatable decorator used more
                than once, or a service name clashing with a builtin service.
            TypeError: wrong decorator argument count/types, invalid keyword
                arguments, or invalid webhook methods.
            ValueError: an ``@service`` name that is not ``"domain.name"``.
            Exceptions from parsing a ``yaml`` service doc string are logged and
            re-raised.
        """
        trig_decs = {}
        trig_ctx_name = trig_ctx.get_name()
        self.logger = logging.getLogger(LOGGER_PATH + "." + trig_ctx_name)
        self.global_ctx.set_logger_name(trig_ctx_name)
        self.global_ctx_name = trig_ctx_name
        got_reqd_dec = False
        exc_mesg = f"function '{func_name}' defined in {trig_ctx_name}"
        # a function needs at least one of these to actually be triggered
        trig_decorators_reqd = {
            "event_trigger",
            "mqtt_trigger",
            "state_trigger",
            "time_trigger",
            "webhook_trigger",
        }
        # per decorator: allowed positional-argument counts ("*" = any count),
        # optional container types whose elements must all be strings, and
        # "rep_ok" when the decorator may be used more than once
        arg_check = {
            "event_trigger": {"arg_cnt": {1, 2, 3}, "rep_ok": True},
            "mqtt_trigger": {"arg_cnt": {1, 2, 3}, "rep_ok": True},
            "state_active": {"arg_cnt": {1}},
            "state_trigger": {"arg_cnt": {"*"}, "type": {list, set}, "rep_ok": True},
            "service": {"arg_cnt": {0, "*"}},
            "task_unique": {"arg_cnt": {1, 2}},
            "time_active": {"arg_cnt": {"*"}},
            "time_trigger": {"arg_cnt": {0, "*"}, "rep_ok": True},
            "webhook_trigger": {"arg_cnt": {1, 2}, "rep_ok": True},
        }
        # per decorator: allowed keyword arguments and their permitted types
        kwarg_check = {
            "event_trigger": {"kwargs": {dict}},
            "mqtt_trigger": {"kwargs": {dict}},
            "time_trigger": {"kwargs": {dict}},
            "task_unique": {"kill_me": {bool, int}},
            "time_active": {"hold_off": {int, float}},
            "service": {"supports_response": {str}},
            "state_trigger": {
                "kwargs": {dict},
                "state_hold": {int, float},
                "state_check_now": {bool, int},
                "state_hold_false": {int, float},
                "watch": {set, list},
            },
            "webhook_trigger": {
                "kwargs": {dict},
                "local_only": {bool},
                "methods": {list, set},
            },
        }

        for dec in self.decorators:
            dec_name, dec_args, dec_kwargs = dec[0], dec[1], dec[2]
            if dec_name not in TRIG_SERV_DECORATORS:
                raise SyntaxError(f"{exc_mesg}: unknown decorator @{dec_name}")
            if dec_name in trig_decorators_reqd:
                got_reqd_dec = True
            arg_info = arg_check.get(dec_name, {})
            #
            # check that we have the right number of arguments, and that they are
            # strings
            #
            arg_cnt = arg_info["arg_cnt"]
            if dec_args is None and 0 not in arg_cnt:
                raise TypeError(f"{exc_mesg}: decorator @{dec_name} needs at least one argument")
            if dec_args:
                if "*" not in arg_cnt and len(dec_args) not in arg_cnt:
                    raise TypeError(
                        f"{exc_mesg}: decorator @{dec_name} got {len(dec_args)}"
                        f" argument{'s' if len(dec_args) > 1 else ''}, expected"
                        f" {' or '.join([str(cnt) for cnt in sorted(arg_cnt)])}"
                    )
                for arg_num, arg in enumerate(dec_args):
                    if isinstance(arg, str):
                        continue
                    mesg = "string"
                    if "type" in arg_info:
                        # an allowed container is fine when all its elements are strings
                        if type(arg) in arg_info["type"] and all(isinstance(val, str) for val in arg):
                            continue
                        mesg += ", or " + ", or ".join(
                            sorted([ok_type.__name__ for ok_type in arg_info["type"]])
                        )
                    raise TypeError(
                        f"{exc_mesg}: decorator @{dec_name} argument {arg_num + 1} should be a {mesg}"
                    )
                if arg_cnt == {1}:
                    # single-argument decorators keep the argument itself, not a list
                    dec_args = dec_args[0]

            if dec_name not in kwarg_check and dec_kwargs is not None:
                raise TypeError(f"{exc_mesg}: decorator @{dec_name} doesn't take keyword arguments")
            if dec_kwargs is None:
                dec_kwargs = {}
            if dec_name in kwarg_check:
                allowed = kwarg_check[dec_name]
                for arg, value in dec_kwargs.items():
                    if arg not in allowed:
                        raise TypeError(
                            f"{exc_mesg}: decorator @{dec_name} invalid keyword argument '{arg}'"
                        )
                    if value is None or type(value) in allowed[arg]:
                        continue
                    ok_types = " or ".join(sorted([t.__name__ for t in allowed[arg]]))
                    raise TypeError(
                        f"{exc_mesg}: decorator @{dec_name} keyword '{arg}' should be type {ok_types}"
                    )

            if dec_name == "service":
                # build the service schema from the doc string: either YAML
                # (doc string starting with "yaml") or one field per positional arg
                desc = self.doc_string
                if not desc:
                    desc = f"pyscript function {func_name}()"
                desc = desc.lstrip(" \n\r")
                if desc.startswith("yaml"):
                    try:
                        desc = desc[4:].lstrip(" \n\r")
                        file_desc = io.StringIO(desc)
                        service_desc = yaml.load(file_desc, Loader=yaml.BaseLoader) or OrderedDict()
                        file_desc.close()
                    except Exception as exc:
                        self.logger.error(
                            "Unable to decode yaml doc_string for %s(): %s",
                            func_name,
                            str(exc),
                        )
                        raise
                else:
                    fields = OrderedDict()
                    for arg in self.get_positional_args():
                        fields[arg] = OrderedDict(description=f"argument {arg}")
                    service_desc = {"description": desc, "fields": fields}

                def pyscript_service_factory(func_name, func):
                    async def pyscript_service_handler(call):
                        """Handle python script service calls."""
                        # self.logger.debug("service call to %s", func_name)
                        #
                        # use a new AstEval context so it can run fully independently
                        # of other instances (except for global_ctx which is common)
                        #
                        ast_ctx = AstEval(f"{trig_ctx_name}.{func_name}", self.global_ctx)
                        Function.install_ast_funcs(ast_ctx)
                        func_args = {
                            "trigger_type": "service",
                            "context": call.context,
                        }
                        func_args.update(call.data)

                        async def do_service_call(func, ast_ctx, data):
                            retval = await func.call(ast_ctx, **data)
                            if ast_ctx.get_exception_obj():
                                ast_ctx.get_logger().error(ast_ctx.get_exception_long())
                            return retval

                        task = Function.create_task(do_service_call(func, ast_ctx, func_args))
                        await task
                        return task.result()

                    return pyscript_service_handler

                for srv_name in dec_args if dec_args else [f"{DOMAIN}.{func_name}"]:
                    if type(srv_name) is not str or srv_name.count(".") != 1:
                        raise ValueError(f"{exc_mesg}: @service argument must be a string with one period")
                    domain, name = srv_name.split(".", 1)
                    if name in (SERVICE_RELOAD, SERVICE_JUPYTER_KERNEL_START):
                        raise SyntaxError(f"{exc_mesg}: @service conflicts with builtin service")
                    Function.service_register(
                        trig_ctx_name,
                        domain,
                        name,
                        pyscript_service_factory(func_name, self),
                        dec_kwargs.get("supports_response", SERVICE_RESPONSE_NONE),
                    )
                    async_set_service_schema(Function.hass, domain, name, service_desc)
                    self.trigger_service.add(srv_name)
                # services are not triggers; move on to the next decorator
                continue

            if dec_name == "webhook_trigger" and "methods" in dec_kwargs:
                if len(bad := set(dec_kwargs["methods"]).difference(WEBHOOK_METHODS)) > 0:
                    raise TypeError(f"{exc_mesg}: {bad} aren't valid {dec_name} methods")

            if dec_name not in trig_decs:
                trig_decs[dec_name] = []
            if len(trig_decs[dec_name]) > 0 and "rep_ok" not in arg_info:
                raise SyntaxError(f"{exc_mesg}: decorator @{dec_name} can only be used once")
            trig_decs[dec_name].append({"args": dec_args, "kwargs": dec_kwargs})

        if not got_reqd_dec and len(trig_decs) > 0:
            self.logger.error(
                "%s defined in %s: needs at least one trigger decorator (ie: %s)",
                func_name,
                trig_ctx_name,
                ", ".join(sorted(trig_decorators_reqd)),
            )
            return

        if len(trig_decs) == 0:
            if len(self.trigger_service) > 0:
                trig_ctx.trigger_register(self)
            return

        #
        # start one or more triggers until they are all consumed
        # each trigger task can handle at most one of each type of
        # trigger; all get the same state_active, time_active and
        # task_unique decorators
        #
        while True:
            trig_args = {
                "action": self,
                "global_sym_table": self.global_ctx.global_sym_table,
            }
            got_trig = False
            for trig in trig_decorators_reqd:
                if trig not in trig_decs or len(trig_decs[trig]) == 0:
                    continue
                trig_args[trig] = trig_decs[trig].pop(0)
                got_trig = True
            if not got_trig:
                break
            for dec_name in ["state_active", "time_active", "task_unique"]:
                if dec_name in trig_decs:
                    trig_args[dec_name] = trig_decs[dec_name][0]

            self.trigger.append(trig_ctx.get_trig_info(f"{trig_ctx_name}.{func_name}", trig_args))

        if trig_ctx.trigger_register(self):
            self.trigger_start()

    def trigger_start(self):
        """Start any triggers for this function."""
        for trigger in self.trigger:
            trigger.start()

    def trigger_stop(self):
        """Stop any triggers for this function and remove its registered services."""
        for trigger in self.trigger:
            trigger.stop()
        self.trigger = []
        for srv_name in self.trigger_service:
            domain, name = srv_name.split(".", 1)
            Function.service_remove(self.global_ctx_name, domain, name)
        self.trigger_service = set()

    async def eval_decorators(self, ast_ctx):
        """Evaluate the function decorators arguments.

        Decorators named in TRIG_SERV_DECORATORS are not evaluated as callables:
        their positional and keyword arguments are evaluated and collected as
        ``[name, args, kwargs_or_None]``. Every other decorator expression is
        evaluated to its value. This function's source is swapped into
        ``ast_ctx`` during evaluation and the caller's source is always restored,
        even if a decorator expression raises.

        Returns:
            ``(dec_trig, dec_other)``, where ``dec_other`` is a reversed iterator,
            so the decorator closest to ``def`` comes first.
        """
        code_str, code_list = ast_ctx.code_str, ast_ctx.code_list
        ast_ctx.code_str, ast_ctx.code_list = self.code_str, self.code_list

        dec_other = []
        dec_trig = []
        try:
            for dec in self.func_def.decorator_list:
                if (
                    isinstance(dec, ast.Call)
                    and isinstance(dec.func, ast.Name)
                    and dec.func.id in TRIG_SERV_DECORATORS
                ):
                    args = await ast_ctx.eval_elt_list(dec.args)
                    kwargs = {keyw.arg: await ast_ctx.aeval(keyw.value) for keyw in dec.keywords}
                    dec_trig.append([dec.func.id, args, kwargs if len(kwargs) > 0 else None])
                elif isinstance(dec, ast.Name) and dec.id in TRIG_SERV_DECORATORS:
                    dec_trig.append([dec.id, None, None])
                else:
                    dec_other.append(await ast_ctx.aeval(dec))
        finally:
            ast_ctx.code_str, ast_ctx.code_list = code_str, code_list
        return dec_trig, reversed(dec_other)

    async def resolve_nonlocals(self, ast_ctx):
        """Tag local variables and resolve nonlocals.

        Computes ``self.local_names`` (arguments plus names the body binds) and
        fills ``self.local_sym_table`` with EvalLocalVar cells: fresh cells for
        locals when the body contains a closure, and existing cells found in
        enclosing symbol tables of ``ast_ctx`` for free and nonlocal names.

        Raises:
            SyntaxError: a ``nonlocal`` name has no binding in an enclosing scope.
        """

        #
        # determine the list of local variables, nonlocal and global
        # arguments are local variables too
        #
        args = self.get_positional_args()
        if self.func_def.args.vararg:
            args.append(self.func_def.args.vararg.arg)
        if self.func_def.args.kwarg:
            args.append(self.func_def.args.kwarg.arg)
        for kwonlyarg in self.func_def.args.kwonlyargs:
            args.append(kwonlyarg.arg)
        nonlocal_names = set()
        global_names = set()
        var_names = set(args)
        self.local_names = set(args)
        for stmt in self.func_def.body:
            self.has_closure = self.has_closure or await self.check_for_closure(stmt)
            var_names = var_names.union(
                await ast_ctx.get_names(
                    stmt,
                    nonlocal_names=nonlocal_names,
                    global_names=global_names,
                    local_names=self.local_names,
                )
            )
        for var_name in var_names:
            # for dotted names only the leading identifier is a variable
            got_dot = var_name.find(".")
            if got_dot >= 0:
                var_name = var_name[0:got_dot]

            if var_name in global_names:
                continue

            if var_name in self.local_names and var_name not in nonlocal_names:
                if self.has_closure:
                    self.local_sym_table[var_name] = EvalLocalVar(var_name)
                continue

            # nonlocal names skip the outermost table on the stack
            if var_name in nonlocal_names:
                sym_table_idx = 1
            else:
                sym_table_idx = 0
            for sym_table in reversed(ast_ctx.sym_table_stack[sym_table_idx:] + [ast_ctx.sym_table]):
                if var_name in sym_table and isinstance(sym_table[var_name], EvalLocalVar):
                    self.local_sym_table[var_name] = sym_table[var_name]
                    break
            else:
                if var_name in nonlocal_names:
                    val = await ast_ctx.ast_name(ast.Name(id=var_name, ctx=ast.Load()))
                    if isinstance(val, EvalName) and got_dot < 0:
                        raise SyntaxError(f"no binding for nonlocal '{var_name}' found")

    def get_decorators(self):
        """Return the function decorators."""
        return self.decorators

    def get_doc_string(self):
        """Return the function doc_string."""
        return self.doc_string

    def get_positional_args(self):
        """Return a new list of the positional parameter names (positional-only first)."""
        return [arg.arg for arg in self.func_def.args.posonlyargs + self.func_def.args.args]

    async def try_aeval(self, ast_ctx, arg):
        """Evaluate *arg*, recording (not raising) ordinary exceptions.

        ``asyncio.CancelledError`` is re-raised. For other exceptions
        ``ast_ctx.exception_long`` is filled in if not already set, and None is
        returned.
        """
        try:
            return await ast_ctx.aeval(arg)
        except asyncio.CancelledError:
            raise
        except Exception as err:
            if ast_ctx.exception_long is None:
                ast_ctx.exception_long = ast_ctx.format_exc(err, arg.lineno, arg.col_offset)

    async def call(self, ast_ctx, *args, **kwargs):
        """Call the function with the given context and arguments.

        Binds ``args``/``kwargs`` to the parameters (raising TypeError for
        missing, duplicated or unexpected arguments; keyword arguments listed in
        TRIGGER_KWARGS are tolerated), switches ``ast_ctx`` to a new local symbol
        table and evaluates the body. Ordinary exceptions from body statements
        are recorded on ``ast_ctx`` and stop the body; ``asyncio.CancelledError``
        propagates. Either way the caller's symbol tables, source, current
        function and user locals are restored on ``ast_ctx``.

        Returns:
            The value of the first ``return`` reached, otherwise None.
        """
        sym_table = {}
        kwargs = kwargs.copy() if kwargs else {}
        bad_kwargs = []
        for i, func_def_arg in enumerate(self.func_def.args.posonlyargs + self.func_def.args.args):
            var_name = func_def_arg.arg
            val = None
            if i < len(args):
                val = args[i]
                if var_name in kwargs:
                    raise TypeError(f"{self.name}() got multiple values for argument '{var_name}'")
            elif var_name in kwargs:
                if i < self.num_posonly_arg:
                    bad_kwargs.append(var_name)
                val = kwargs[var_name]
                del kwargs[var_name]
            elif self.num_posn_arg <= i < len(self.defaults) + self.num_posn_arg:
                val = self.defaults[i - self.num_posn_arg]
            else:
                raise TypeError(
                    f"{self.name}() missing {self.num_posn_arg - i} required positional arguments"
                )
            sym_table[var_name] = val
        if len(bad_kwargs) > 0:
            raise TypeError(
                f"{self.name}() got some positional-only arguments passed as keyword arguments: '{', '.join(bad_kwargs)}'"
            )

        for i, kwonlyarg in enumerate(self.func_def.args.kwonlyargs):
            var_name = kwonlyarg.arg
            if var_name in kwargs:
                val = kwargs[var_name]
                del kwargs[var_name]
            elif i < len(self.kw_defaults) and self.kw_defaults[i]["ok"]:
                val = self.kw_defaults[i]["val"]
            else:
                raise TypeError(f"{self.name}() missing required keyword-only arguments")
            sym_table[var_name] = val
        if self.func_def.args.kwarg:
            sym_table[self.func_def.args.kwarg.arg] = kwargs
        elif not set(kwargs.keys()).issubset(TRIGGER_KWARGS):
            # don't raise an exception for extra trigger keyword parameters;
            # it's difficult to apply this exception to just trigger functions
            # since they could have non-trigger decorators too
            unexpected = ", ".join(sorted(set(kwargs.keys()) - TRIGGER_KWARGS))
            raise TypeError(f"{self.name}() called with unexpected keyword arguments: {unexpected}")
        num_posn = self.num_posonly_arg + len(self.func_def.args.args)
        if self.func_def.args.vararg:
            if len(args) > num_posn:
                sym_table[self.func_def.args.vararg.arg] = tuple(args[num_posn:])
            else:
                sym_table[self.func_def.args.vararg.arg] = ()
        elif len(args) > num_posn:
            raise TypeError(f"{self.name}() called with too many positional arguments")
        # closure variables live in EvalLocalVar cells shared with inner scopes
        for name, value in self.local_sym_table.items():
            if name in sym_table:
                sym_table[name] = EvalLocalVar(name, value=sym_table[name])
            elif value.is_defined():
                sym_table[name] = value
            else:
                sym_table[name] = EvalLocalVar(name)

        # save the caller's state so it can be restored however the body exits
        code_str, code_list = ast_ctx.code_str, ast_ctx.code_list
        prev_func = ast_ctx.curr_func
        save_user_locals = ast_ctx.user_locals
        if ast_ctx.global_ctx != self.global_ctx:
            #
            # switch to the global symbol table in the global context
            # where the function was defined
            #
            prev_sym_table = [
                ast_ctx.global_sym_table,
                ast_ctx.sym_table,
                ast_ctx.sym_table_stack,
                ast_ctx.global_ctx,
            ]
            ast_ctx.global_sym_table = self.global_ctx.get_global_sym_table()
            ast_ctx.sym_table_stack = [ast_ctx.global_sym_table]
            ast_ctx.global_ctx = self.global_ctx
        else:
            ast_ctx.sym_table_stack.append(ast_ctx.sym_table)
            prev_sym_table = None
        ast_ctx.sym_table = sym_table
        ast_ctx.code_str, ast_ctx.code_list = self.code_str, self.code_list
        self.exception = None
        self.exception_obj = None
        self.exception_long = None
        ast_ctx.user_locals = {}
        ast_ctx.curr_func = self
        del args, kwargs
        try:
            val = None
            for arg1 in self.func_def.body:
                val = await self.try_aeval(ast_ctx, arg1)
                if isinstance(val, EvalReturn):
                    val = val.value
                    break
                # return None at end if there isn't a return
                val = None
                if ast_ctx.get_exception_obj():
                    break
        finally:
            ast_ctx.curr_func = prev_func
            ast_ctx.user_locals = save_user_locals
            ast_ctx.code_str, ast_ctx.code_list = code_str, code_list
            if prev_sym_table is not None:
                (
                    ast_ctx.global_sym_table,
                    ast_ctx.sym_table,
                    ast_ctx.sym_table_stack,
                    ast_ctx.global_ctx,
                ) = prev_sym_table
            else:
                ast_ctx.sym_table = ast_ctx.sym_table_stack.pop()
        return val

    async def check_for_closure(self, arg):
        """Recursively check ast tree arg and return True if there is an inner function or class."""
        if isinstance(arg, (ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef)):
            return True
        for child in ast.iter_child_nodes(arg):
            if await self.check_for_closure(child):
                return True
        return False
|
|
|
|
|
|
class EvalFuncVar:
    """Callable handle wrapping an EvalFunc.

    Optionally remembers an AstEval context used when the handle itself is
    called. When the handle is finalized while still owning its function, the
    function's triggers are stopped.
    """

    def __init__(self, func):
        """Wrap *func*; no AST context is attached yet."""
        self.func = func
        self.ast_ctx = None

    def get_func(self):
        """Return the wrapped EvalFunc (None after remove_func())."""
        return self.func

    def remove_func(self):
        """Detach the wrapped EvalFunc from this handle and return it."""
        detached, self.func = self.func, None
        return detached

    async def call(self, ast_ctx, *args, **kwargs):
        """Run the wrapped function in the supplied AST context."""
        target = self.func
        return await target.call(ast_ctx, *args, **kwargs)

    def get_name(self):
        """Return the wrapped function's name."""
        return self.func.get_name()

    def set_name(self, name):
        """Rename the wrapped function."""
        self.func.set_name(name)

    def set_ast_ctx(self, ast_ctx):
        """Attach the AST context used by ``__call__``."""
        self.ast_ctx = ast_ctx

    def get_ast_ctx(self):
        """Return the attached AST context (None if never set)."""
        return self.ast_ctx

    def __del__(self):
        """Stop the wrapped function's triggers when this handle is finalized."""
        owned = self.func
        if owned:
            owned.trigger_stop()

    async def __call__(self, *args, **kwargs):
        """Run the wrapped function in the attached AST context."""
        target = self.func
        return await target.call(self.ast_ctx, *args, **kwargs)
|
|
|
|
|
|
class EvalFuncVarClassInst(EvalFuncVar):
    """Callable handle for a function bound to a pyscript class instance.

    ``class_inst_weak`` is called on every invocation to obtain the instance,
    which is passed as the first positional argument (presumably a
    ``weakref.ref`` so the handle does not keep the instance alive; confirm at
    the construction site).
    """

    def __init__(self, func, ast_ctx, class_inst_weak):
        """Wrap *func* with a fixed AST context and an instance accessor."""
        super().__init__(func)
        self.ast_ctx = ast_ctx
        self.class_inst_weak = class_inst_weak

    async def call(self, ast_ctx, *args, **kwargs):
        """Run the function in *ast_ctx* with the bound instance prepended."""
        instance = self.class_inst_weak()
        return await self.func.call(ast_ctx, instance, *args, **kwargs)

    async def __call__(self, *args, **kwargs):
        """Run the function in the saved AST context with the bound instance prepended."""
        instance = self.class_inst_weak()
        return await self.func.call(self.ast_ctx, instance, *args, **kwargs)
|
|
|
|
|
|
class AstEval:
|
|
"""Python interpreter AST object evaluator."""
|
|
|
|
def __init__(self, name, global_ctx, logger_name=None):
    """Initialize an interpreter execution context.

    Args:
        name: name of this context; also the initial ``filename`` shown in
            exception messages.
        global_ctx: global context whose global symbol table is used; if it is
            falsy an empty table is used instead.
        logger_name: name handed to ``set_logger_name()``; defaults to *name*.
    """
    self.name = name
    self.filename = name
    self.str = None
    self.ast = None
    # scopes: a new context starts at module level, so the current table is
    # the global table itself
    self.global_ctx = global_ctx
    self.global_sym_table = global_ctx.get_global_sym_table() if global_ctx else {}
    self.sym_table_stack = []
    self.sym_table = self.global_sym_table
    self.local_sym_table = {}
    self.user_locals = {}
    self.curr_func = None
    self.code_str = None
    self.code_list = None
    # details of the first exception seen while evaluating (set in aeval())
    self.exception = None
    self.exception_obj = None
    self.exception_long = None
    self.exception_curr = None
    # position of the node most recently evaluated
    self.lineno = 1
    self.col_offset = 0
    self.logger_handlers = set()
    self.logger = None
    self.set_logger_name(name if logger_name is None else logger_name)
    self.config_entry = Function.hass.data.get(DOMAIN, {}).get(CONFIG_ENTRY, {})
    self.dec_eval_depth = 0
|
|
|
|
async def ast_not_implemented(self, arg, *args):
|
|
"""Raise NotImplementedError exception for unimplemented AST types."""
|
|
name = "ast_" + arg.__class__.__name__.lower()
|
|
raise NotImplementedError(f"{self.name}: not implemented ast " + name)
|
|
|
|
async def aeval(self, arg, undefined_check=True):
    """Evaluate AST node *arg* via the matching ``ast_<nodetype>`` method.

    Updates ``lineno``/``col_offset`` from nodes that carry positions. When
    *undefined_check* is true, an EvalName result raises NameError. The first
    exception seen by this context is recorded in ``exception_obj``,
    ``exception`` and ``exception_long`` before being re-raised.
    """
    handler_name = "ast_" + arg.__class__.__name__.lower()
    try:
        if hasattr(arg, "lineno"):
            self.lineno = arg.lineno
            self.col_offset = arg.col_offset
        handler = getattr(self, handler_name, self.ast_not_implemented)
        result = await handler(arg)
        if undefined_check and isinstance(result, EvalName):
            raise NameError(f"name '{result.name}' is not defined")
        return result
    except Exception as err:
        if self.exception_obj:
            # an earlier (inner) failure has already been recorded
            raise
        prefix = self.curr_func.get_name() + "(), " if self.curr_func else ""
        self.exception_obj = err
        self.exception = f"Exception in {prefix}{self.filename} line {self.lineno} column {self.col_offset}: {err}"
        self.exception_long = self.format_exc(err, self.lineno, self.col_offset)
        raise
|
|
|
|
# Statements return NONE, EvalBreak, EvalContinue, EvalReturn
|
|
async def ast_module(self, arg):
|
|
"""Execute ast_module - a list of statements."""
|
|
val = None
|
|
for arg1 in arg.body:
|
|
val = await self.aeval(arg1)
|
|
if isinstance(val, EvalReturn):
|
|
raise SyntaxError(f"{val.name()} statement outside function")
|
|
if isinstance(val, EvalStopFlow):
|
|
raise SyntaxError(f"{val.name()} statement outside loop")
|
|
return val
|
|
|
|
async def ast_import(self, arg):
|
|
"""Execute import."""
|
|
for imp in arg.names:
|
|
mod, error_ctx = await self.global_ctx.module_import(imp.name, 0)
|
|
if error_ctx:
|
|
self.exception_obj = error_ctx.exception_obj
|
|
self.exception = error_ctx.exception
|
|
self.exception_long = error_ctx.exception_long
|
|
raise self.exception_obj
|
|
if not mod:
|
|
if (
|
|
not self.config_entry.data.get(CONF_ALLOW_ALL_IMPORTS, False)
|
|
and imp.name not in ALLOWED_IMPORTS
|
|
):
|
|
raise ModuleNotFoundError(f"import of {imp.name} not allowed")
|
|
if imp.name not in sys.modules:
|
|
mod = await Function.hass.async_add_executor_job(importlib.import_module, imp.name)
|
|
else:
|
|
mod = sys.modules[imp.name]
|
|
self.sym_table[imp.name if imp.asname is None else imp.asname] = mod
|
|
|
|
async def ast_importfrom(self, arg):
    """Execute ``from X import Y`` including relative and ``*`` forms.

    ``from . import a, b`` loads each name as a (relative) module through
    ``global_ctx.module_import``.  Otherwise module X is loaded the same way
    as ``ast_import`` does it: pyscript loader first, then, if allowed by
    config or ALLOWED_IMPORTS, a regular Python import.

    For ``from X import *`` the module's ``__all__`` is honored when it
    exists, as in Python.  Without ``__all__``, every name not starting with
    an underscore is bound.

    Raises:
        ModuleNotFoundError: the module is missing or its import is not allowed.
        ImportError: a requested name does not exist in the module.
    """
    if arg.module is None:
        # handle: "from . import xyz"
        for imp in arg.names:
            mod, error_ctx = await self.global_ctx.module_import(imp.name, arg.level)
            if error_ctx:
                self.exception_obj = error_ctx.exception_obj
                self.exception = error_ctx.exception
                self.exception_long = error_ctx.exception_long
                raise self.exception_obj
            if not mod:
                raise ModuleNotFoundError(f"module '{imp.name}' not found")
            self.sym_table[imp.name if imp.asname is None else imp.asname] = mod
        return

    mod, error_ctx = await self.global_ctx.module_import(arg.module, arg.level)
    if error_ctx:
        self.exception_obj = error_ctx.exception_obj
        self.exception = error_ctx.exception
        self.exception_long = error_ctx.exception_long
        raise self.exception_obj
    if not mod:
        if (
            not self.config_entry.data.get(CONF_ALLOW_ALL_IMPORTS, False)
            and arg.module not in ALLOWED_IMPORTS
        ):
            raise ModuleNotFoundError(f"import from {arg.module} not allowed")
        if arg.module not in sys.modules:
            mod = await Function.hass.async_add_executor_job(importlib.import_module, arg.module)
        else:
            mod = sys.modules[arg.module]

    for imp in arg.names:
        if imp.name == "*":
            public_names = getattr(mod, "__all__", None)
            if public_names is None:
                # no __all__: export every name that does not start with "_"
                for name, value in mod.__dict__.items():
                    if name[0] != "_":
                        self.sym_table[name] = value
            else:
                # like Python, __all__ lists exactly the names "import *" binds
                for name in public_names:
                    self.sym_table[name] = getattr(mod, name)
        else:
            try:
                value = getattr(mod, imp.name)
            except AttributeError:
                raise ImportError(f"cannot import name '{imp.name}' from '{arg.module}'") from None
            self.sym_table[imp.name if imp.asname is None else imp.asname] = value
|
|
|
|
async def ast_if(self, arg):
    """Execute an ``if``/``else`` statement.

    The test is evaluated once and selects ``body`` or ``orelse``.  The
    selected statements run in order until one yields an EvalStopFlow marker
    (return/break/continue), which is handed back to the caller.  Otherwise
    the value of the last statement run is returned (None for an empty
    branch).
    """
    branch = arg.body if await self.aeval(arg.test) else arg.orelse
    result = None
    for stmt in branch:
        result = await self.aeval(stmt)
        if isinstance(result, EvalStopFlow):
            break
    return result
|
|
|
|
async def ast_for(self, arg):
    """Execute a ``for`` statement, including its optional ``else`` clause.

    The loop target is assigned from each item of the evaluated iterable
    before the body runs.  ``continue`` ends the current pass.  ``break``
    ends the loop and skips ``else``.  ``return`` is handed back to the
    caller.

    The ``else`` clause runs only when the loop finishes without ``break``.
    Any control-flow marker it produces is handed back to the caller: a
    ``return``, or a ``break``/``continue`` that targets an enclosing loop.

    Returns:
        An EvalStopFlow marker that an enclosing construct must handle, or None.
    """
    for loop_var in await self.aeval(arg.iter):
        await self.recurse_assign(arg.target, loop_var)
        val = None
        for arg1 in arg.body:
            val = await self.aeval(arg1)
            if isinstance(val, EvalStopFlow):
                break
        if isinstance(val, EvalBreak):
            break
        if isinstance(val, EvalReturn):
            return val
    else:
        for arg1 in arg.orelse:
            val = await self.aeval(arg1)
            if isinstance(val, EvalStopFlow):
                # e.g. "break" inside for/else belongs to the enclosing loop
                return val
    return None
|
|
|
|
async def ast_asyncfor(self, arg):
    """Execute ``async for`` by delegating to the plain ``for`` handler.

    The iterable is therefore consumed with ordinary iteration, exactly as
    in ``ast_for``.
    """
    outcome = await self.ast_for(arg)
    return outcome
|
|
|
|
async def ast_while(self, arg):
    """Execute a ``while`` statement, including its optional ``else`` clause.

    The test is re-evaluated before every pass.  ``continue`` ends the
    current pass.  ``break`` ends the loop and skips ``else``.  ``return``
    is handed back to the caller.

    The ``else`` clause runs only when the test becomes false without a
    ``break``.  Any control-flow marker it produces (``return``, or a
    ``break``/``continue`` for an enclosing loop) is handed back to the
    caller.

    Returns:
        An EvalStopFlow marker that an enclosing construct must handle, or None.
    """
    while await self.aeval(arg.test):
        val = None
        for arg1 in arg.body:
            val = await self.aeval(arg1)
            if isinstance(val, EvalStopFlow):
                break
        if isinstance(val, EvalBreak):
            break
        if isinstance(val, EvalReturn):
            return val
    else:
        for arg1 in arg.orelse:
            val = await self.aeval(arg1)
            if isinstance(val, EvalStopFlow):
                # e.g. "break" inside while/else belongs to the enclosing loop
                return val
    return None
|
|
|
|
async def ast_classdef(self, arg):
    """Evaluate a class definition and bind the new class.

    Base-class expressions are evaluated first.  Before the body runs, the
    class name is bound to an EvalLocalVar placeholder.  The placeholder goes
    in the global table when the name is declared ``global`` in the current
    function, otherwise in the current table.

    The body executes in a fresh symbol table that becomes the class
    namespace.  ``return``, ``break`` and ``continue`` in the body are
    SyntaxErrors.  The caller's symbol table is restored even if the body
    raises.

    ``__init__`` is moved to ``__init__evalfunc_wrap__`` (None when absent).
    ``call_func`` invokes that renamed method explicitly when it
    instantiates the class.
    """
    bases = [(await self.aeval(base)) for base in arg.bases]
    if self.curr_func and arg.name in self.curr_func.global_names:
        sym_table_assign = self.global_sym_table
    else:
        sym_table_assign = self.sym_table
    sym_table_assign[arg.name] = EvalLocalVar(arg.name)
    sym_table = {}
    self.sym_table_stack.append(self.sym_table)
    self.sym_table = sym_table
    try:
        for arg1 in arg.body:
            val = await self.aeval(arg1)
            if isinstance(val, EvalReturn):
                raise SyntaxError(f"{val.name()} statement outside function")
            if isinstance(val, EvalStopFlow):
                raise SyntaxError(f"{val.name()} statement outside loop")
    finally:
        # always leave the class namespace, even when the body raised
        self.sym_table = self.sym_table_stack.pop()

    sym_table["__init__evalfunc_wrap__"] = None
    if "__init__" in sym_table:
        sym_table["__init__evalfunc_wrap__"] = sym_table["__init__"]
        del sym_table["__init__"]
    sym_table_assign[arg.name].set(type(arg.name, tuple(bases), sym_table))
|
|
|
|
async def ast_functiondef(self, arg):
    """Evaluate a function definition and bind the resulting callable.

    Compile path: a function decorated with ``@pyscript_compile`` or
    ``@pyscript_executor`` (bare or with empty parentheses) is compiled
    natively by Python.  That decorator is stripped, and the ``def`` is
    compiled and exec'd into the current symbol table.  For
    ``@pyscript_executor`` the function must not be async; it is wrapped so
    each call runs through ``Function.hass.async_add_executor_job``.  If the
    name was an EvalLocalVar placeholder, the placeholder is restored and
    set to the resulting callable (the wrapper for ``@pyscript_executor``).

    Interpreted path: every other function becomes an EvalFunc.  Its
    defaults and nonlocals are resolved, and the non-trigger decorators
    returned by ``eval_decorators`` are applied in order.  ``dec_eval_depth``
    is raised while they run and is restored even if one raises.  At depth 0
    the function's triggers are stopped and re-initialized.  The result is
    bound under the function's name: in the global table when declared
    ``global`` in the current function, otherwise in the current table.  An
    existing EvalLocalVar placeholder is updated in place.

    Raises:
        SyntaxError: more than one compile decorator was given.
        TypeError: a compile decorator got arguments, or
            ``@pyscript_executor`` was applied to an async function.
    """
    other_dec = []
    dec_name = None
    pyscript_compile = None
    for dec in arg.decorator_list:
        if isinstance(dec, ast.Name) and dec.id in COMP_DECORATORS:
            dec_name = dec.id
        elif (
            isinstance(dec, ast.Call)
            and isinstance(dec.func, ast.Name)
            and dec.func.id in COMP_DECORATORS
        ):
            dec_name = dec.func.id
        else:
            other_dec.append(dec)
            continue
        if pyscript_compile:
            raise SyntaxError(
                f"can only specify single decorator of {', '.join(sorted(COMP_DECORATORS))}"
            )
        pyscript_compile = dec

    if pyscript_compile:
        if isinstance(pyscript_compile, ast.Call):
            if len(pyscript_compile.args) > 0:
                raise TypeError(f"@{dec_name}() takes 0 positional arguments")
            if len(pyscript_compile.keywords) > 0:
                raise TypeError(f"@{dec_name}() takes no keyword arguments")
        arg.decorator_list = other_dec
        local_var = None
        if arg.name in self.sym_table and isinstance(self.sym_table[arg.name], EvalLocalVar):
            # exec() below overwrites this entry; remember the placeholder to update it
            local_var = self.sym_table[arg.name]
        code = compile(ast.Module(body=[arg], type_ignores=[]), filename=self.filename, mode="exec")
        exec(code, self.global_sym_table, self.sym_table)  # pylint: disable=exec-used

        func = self.sym_table[arg.name]
        if dec_name == "pyscript_executor":
            if asyncio.iscoroutinefunction(func):
                raise TypeError("@pyscript_executor() needs a regular, not async, function")

            def executor_wrap_factory(func):
                async def executor_wrap(*args, **kwargs):
                    return await Function.hass.async_add_executor_job(
                        functools.partial(func, **kwargs), *args
                    )

                return executor_wrap

            # keep the wrapper in "func" so a placeholder below receives it too
            func = executor_wrap_factory(func)
            self.sym_table[arg.name] = func
        if local_var:
            self.sym_table[arg.name] = local_var
            self.sym_table[arg.name].set(func)
        return

    func = EvalFunc(arg, self.code_list, self.code_str, self.global_ctx)
    await func.eval_defaults(self)
    await func.resolve_nonlocals(self)
    name = func.get_name()
    dec_trig, dec_other = await func.eval_decorators(self)
    self.dec_eval_depth += 1
    try:
        for dec_func in dec_other:
            func = await self.call_func(dec_func, None, func)
            if isinstance(func, EvalFuncVar):
                # set the function name back to its original instead of the decorator function we just called
                func.set_name(name)
                func = func.remove_func()
                dec_trig += func.decorators
            elif isinstance(func, EvalFunc):
                func.set_name(name)
    finally:
        # a failing decorator must not leave the depth raised, or later
        # definitions would never initialize their triggers
        self.dec_eval_depth -= 1

    if isinstance(func, EvalFunc):
        func.decorators = dec_trig
        if self.dec_eval_depth == 0:
            func.trigger_stop()
            await func.trigger_init(self.global_ctx, name)
            func_var = EvalFuncVar(func)
            func_var.set_ast_ctx(self)
        else:
            func_var = EvalFuncVar(func)
            func_var.set_ast_ctx(self)
    else:
        func_var = func

    if self.curr_func and name in self.curr_func.global_names:
        sym_table = self.global_sym_table
    else:
        sym_table = self.sym_table
    if name in sym_table and isinstance(sym_table[name], EvalLocalVar):
        sym_table[name].set(func_var)
    else:
        sym_table[name] = func_var
|
|
|
|
async def ast_lambda(self, arg):
    """Evaluate a lambda by compiling it as a native function.

    The lambda is rewritten as
    ``@pyscript_compile def __lambda_defn_temp__(<args>): return <body>``
    and evaluated, which binds it in the current symbol table.  The
    temporary binding is then removed and the compiled function returned.
    The synthetic nodes carry the lambda's own line and column, so errors
    point at the lambda in the source.
    """
    name = "__lambda_defn_temp__"
    await self.aeval(
        ast.FunctionDef(
            args=arg.args,
            body=[ast.Return(value=arg.body, lineno=arg.body.lineno, col_offset=arg.body.col_offset)],
            name=name,
            decorator_list=[ast.Name(id="pyscript_compile", ctx=ast.Load())],
            lineno=arg.lineno,
            col_offset=arg.col_offset,
        )
    )
    return self.sym_table.pop(name)
|
|
|
|
async def ast_asyncfunctiondef(self, arg):
    """Evaluate ``async def`` with the same handler as a plain ``def``."""
    outcome = await self.ast_functiondef(arg)
    return outcome
|
|
|
|
async def ast_try(self, arg):
|
|
"""Execute a try/except/else/finally statement.

The try body runs first; an EvalStopFlow marker (return/break/continue)
from it is returned immediately.  When the body raises, the exception
becomes ``self.exception_curr`` (used by a bare ``raise``) and the
handlers are tried in order.  A handler matches when it has no type, or
when the exception is an instance of the evaluated type or of any class
in an evaluated tuple of types.  The first matching handler's name (if
any) is bound to the exception while its body runs and is deleted
afterwards.  If the handler body itself raises, the saved exception text
is restored (same exception object) or chained with "During handling of
the above exception, another exception occurred" before re-raising.
If no handler matches, the original exception is re-raised.  The
``else`` clause runs only when the body raised nothing.  ``finally``
always runs; an EvalStopFlow marker from it is returned, which also
discards any exception that was propagating.
"""
|
|
try:
|
|
for arg1 in arg.body:
|
|
val = await self.aeval(arg1)
|
|
if isinstance(val, EvalStopFlow):
|
|
return val
|
|
# Raise an exception recorded on this context (self.exception_obj) even though
# aeval returned normally.
# NOTE(review): indentation is lost in this copy; presumably this check runs
# after every body statement - confirm against the original source.
if self.exception_obj is not None:
|
|
raise self.exception_obj
|
|
except Exception as err:
|
|
# Remember the outer "current exception" so it can be restored on exit from the handler.
curr_exc = self.exception_curr
|
|
self.exception_curr = err
|
|
for handler in arg.handlers:
|
|
match = False
|
|
if handler.type:
|
|
exc_list = await self.aeval(handler.type)
|
|
if not isinstance(exc_list, tuple):
|
|
exc_list = [exc_list]
|
|
for exc in exc_list:
|
|
if isinstance(err, exc):
|
|
match = True
|
|
break
|
|
else:
|
|
# A bare "except:" matches everything.
match = True
|
|
if match:
|
|
# Save and clear the recorded exception so errors raised inside the handler are recorded afresh.
save_obj = self.exception_obj
|
|
save_exc_long = self.exception_long
|
|
save_exc = self.exception
|
|
self.exception_obj = None
|
|
self.exception = None
|
|
self.exception_long = None
|
|
if handler.name is not None:
|
|
if handler.name in self.sym_table and isinstance(
|
|
self.sym_table[handler.name], EvalLocalVar
|
|
):
|
|
self.sym_table[handler.name].set(err)
|
|
else:
|
|
self.sym_table[handler.name] = err
|
|
for arg1 in handler.body:
|
|
try:
|
|
val = await self.aeval(arg1)
|
|
if isinstance(val, EvalStopFlow):
|
|
if handler.name is not None:
|
|
del self.sym_table[handler.name]
|
|
self.exception_curr = curr_exc
|
|
return val
|
|
except Exception:
|
|
if self.exception_obj is not None:
|
|
if handler.name is not None:
|
|
del self.sym_table[handler.name]
|
|
self.exception_curr = curr_exc
|
|
# Same object as before: just restore the saved text; otherwise chain both tracebacks.
if self.exception_obj == save_obj:
|
|
self.exception_long = save_exc_long
|
|
self.exception = save_exc
|
|
else:
|
|
self.exception_long = (
|
|
save_exc_long
|
|
+ "\n\nDuring handling of the above exception, another exception occurred:\n\n"
|
|
+ self.exception_long
|
|
)
|
|
raise self.exception_obj # pylint: disable=raise-missing-from
|
|
# NOTE(review): on this normal-completion path self.exception_curr is not reset to
# curr_exc, so a later bare "raise" outside any handler may re-raise this already
# handled exception - confirm whether that is intended.
if handler.name is not None:
|
|
del self.sym_table[handler.name]
|
|
break
|
|
else:
|
|
# No handler matched: restore the outer current exception and propagate.
self.exception_curr = curr_exc
|
|
raise err
|
|
else:
|
|
for arg1 in arg.orelse:
|
|
val = await self.aeval(arg1)
|
|
if isinstance(val, EvalStopFlow):
|
|
return val
|
|
finally:
|
|
for arg1 in arg.finalbody:
|
|
val = await self.aeval(arg1)
|
|
if isinstance(val, EvalStopFlow):
|
|
return val # pylint: disable=lost-exception,return-in-finally
|
|
return None
|
|
|
|
async def ast_raise(self, arg):
    """Execute a ``raise`` statement.

    Bare ``raise``:
        Re-raises the exception currently being handled
        (``self.exception_curr``) unchanged.  Raises RuntimeError if there
        is none.

    ``raise X``:
        Evaluates X.  An exception class is instantiated with no arguments,
        as Python does.  When raised while another exception is being
        handled, that exception is recorded as the new exception's
        ``__cause__``, unless the new exception is that same object.

    ``raise X from Y``:
        Uses Y as the explicit cause.

    A value that is not an exception is passed to ``raise`` unchanged, so
    Python reports the usual TypeError.
    """
    if not arg.exc:
        if not self.exception_curr:
            raise RuntimeError("No active exception to reraise")
        # re-raise as-is: never make the exception its own cause
        raise self.exception_curr
    exc = await self.aeval(arg.exc)
    if isinstance(exc, type) and issubclass(exc, BaseException):
        # "raise ValueError" raises an instance, never mutates the class
        exc = exc()
    if (
        self.exception_curr
        and isinstance(exc, BaseException)
        and exc is not self.exception_curr
    ):
        exc.__cause__ = self.exception_curr
    if arg.cause:
        cause = await self.aeval(arg.cause)
        raise exc from cause
    raise exc
|
|
|
|
async def ast_with(self, arg, async_attr=""):
    """Execute a ``with`` statement (``async with`` when ``async_attr="a"``).

    Multiple context managers behave like nested ``with`` blocks.  Each
    context expression is evaluated and entered in order, and its target
    (if any) is assigned right after entering.  Only managers that were
    actually entered are exited, in reverse order.

    An exception from the body, from a later context expression or
    ``__enter__``, or from a target assignment is passed to each
    ``__exit__``.  If one returns a true value, the exception is suppressed
    and the outer managers see a normal exit.  An exception raised by an
    ``__exit__`` replaces the pending one for the outer managers.

    Returns:
        The value of the last body statement executed (an EvalStopFlow
        marker when the body was left via break/continue/return), or None.
    """
    enter_attr = f"__{async_attr}enter__"
    exit_attr = f"__{async_attr}exit__"
    entered = []  # (manager, unbound exit method) for each successfully entered manager
    pending = None  # exception that still has to propagate after unwinding
    val = None
    try:
        try:
            for item in arg.items:
                manager = await self.aeval(item.context_expr)
                # like Python, look the special methods up on the type
                enter_func = getattr(type(manager), enter_attr)
                exit_func = getattr(type(manager), exit_attr)
                value = await self.call_func(enter_func, enter_attr, manager)
                entered.append((manager, exit_func))
                if item.optional_vars:
                    await self.recurse_assign(item.optional_vars, value)
            for arg1 in arg.body:
                val = await self.aeval(arg1)
                if isinstance(val, EvalStopFlow):
                    break
        except Exception as err:  # pylint: disable=broad-except
            pending = err
    finally:
        # unwind innermost first; this also runs for BaseException (e.g. cancellation)
        while entered:
            manager, exit_func = entered.pop()
            try:
                if pending is None:
                    await self.call_func(exit_func, exit_attr, manager, None, None, None)
                elif await self.call_func(
                    exit_func, exit_attr, manager, type(pending), pending, pending.__traceback__
                ):
                    pending = None
            except Exception as err:  # pylint: disable=broad-except
                pending = err
    if pending is not None:
        raise pending
    return val
|
|
|
|
async def ast_asyncwith(self, arg):
    """Execute ``async with`` using ``__aenter__``/``__aexit__`` via ast_with."""
    async_prefix = "a"
    return await self.ast_with(arg, async_attr=async_prefix)
|
|
|
|
async def ast_pass(self, arg):
    """Handle ``pass``: nothing happens and None is produced."""
    return None
|
|
|
|
async def ast_expression(self, arg):
    """Evaluate an ``ast.Expression`` wrapper (eval mode) by evaluating its body."""
    inner = arg.body
    return await self.aeval(inner)
|
|
|
|
async def ast_expr(self, arg):
    """Evaluate an expression used as a statement and return its value."""
    inner = arg.value
    return await self.aeval(inner)
|
|
|
|
async def ast_break(self, arg):
    """Handle ``break`` by producing an EvalBreak marker for the enclosing loop."""
    marker = EvalBreak()
    return marker
|
|
|
|
async def ast_continue(self, arg):
    """Handle ``continue`` by producing an EvalContinue marker for the enclosing loop."""
    marker = EvalContinue()
    return marker
|
|
|
|
async def ast_return(self, arg):
    """Handle ``return [expr]``: wrap the evaluated value (or None) in EvalReturn."""
    value = None
    if arg.value:
        value = await self.aeval(arg.value)
    return EvalReturn(value)
|
|
|
|
async def ast_global(self, arg):
    """Record the names of a ``global`` statement on the running function.

    Raises:
        SyntaxError: when no function is currently executing.
    """
    func = self.curr_func
    if not func:
        raise SyntaxError("global statement outside function")
    func.global_names.update(arg.names)
|
|
|
|
async def ast_nonlocal(self, arg):
    """Record the names of a ``nonlocal`` statement on the running function.

    Raises:
        SyntaxError: when no function is currently executing.
    """
    func = self.curr_func
    if not func:
        raise SyntaxError("nonlocal statement outside function")
    func.nonlocal_names.update(arg.names)
|
|
|
|
async def recurse_assign(self, lhs, val):
    """Assign ``val`` to the assignment target ``lhs``, recursing into unpacking.

    Supported targets:

    * Tuple and list targets (``a, b = ...``, ``[a, *b] = ...``): ``val`` is
      unpacked and each element is assigned recursively.  A single starred
      target receives the surplus items as a list; that target may be a
      name, a subscript or an attribute.
    * Subscripts (``x[i] = ...``, ``x[a:b:c] = ...``).
    * Attribute targets.  These are resolved by ``aeval`` either to an
      EvalAttrSet (a real object attribute) or to a dotted string.  A dotted
      string names a state variable ("domain.entity") or a state attribute
      ("domain.entity.attr").
    * Plain names.  They are stored in the global table when declared
      ``global`` in the current function.  Otherwise they go in the current
      symbol table, updating an EvalLocalVar placeholder in place.

    Raises:
        TypeError: unpacking a non-iterable value.
        ValueError: wrong number of values to unpack.
        NameError: a dotted name with more than two dots.
        NotImplementedError: an unsupported target type.
    """
    if isinstance(lhs, (ast.Tuple, ast.List)):
        try:
            vals = [*(iter(val))]
        except Exception:
            raise TypeError("cannot unpack non-iterable object")  # pylint: disable=raise-missing-from
        got_star = 0
        for lhs_elt in lhs.elts:
            if isinstance(lhs_elt, ast.Starred):
                got_star = 1
                break
        if len(lhs.elts) > len(vals) + got_star:
            if got_star:
                err_msg = f"at least {len(lhs.elts) - got_star}"
            else:
                err_msg = f"{len(lhs.elts)}"
            raise ValueError(f"too few values to unpack (expected {err_msg})")
        if len(lhs.elts) < len(vals) and got_star == 0:
            raise ValueError(f"too many values to unpack (expected {len(lhs.elts)})")
        val_idx = 0
        for lhs_elt in lhs.elts:
            if isinstance(lhs_elt, ast.Starred):
                star_len = len(vals) - len(lhs.elts) + 1
                star_target = lhs_elt.value
                if isinstance(star_target, ast.Name):
                    # bind the name itself, whatever ctx the node carries
                    star_target = ast.Name(id=star_target.id, ctx=ast.Store())
                await self.recurse_assign(star_target, vals[val_idx : val_idx + star_len])
                val_idx += star_len
            else:
                await self.recurse_assign(lhs_elt, vals[val_idx])
                val_idx += 1
    elif isinstance(lhs, ast.Subscript):
        var = await self.aeval(lhs.value)
        if isinstance(lhs.slice, ast.Index):
            ind = await self.aeval(lhs.slice.value)
            var[ind] = val
        elif isinstance(lhs.slice, ast.Slice):
            lower = await self.aeval(lhs.slice.lower) if lhs.slice.lower else None
            upper = await self.aeval(lhs.slice.upper) if lhs.slice.upper else None
            step = await self.aeval(lhs.slice.step) if lhs.slice.step else None
            var[slice(lower, upper, step)] = val
        else:
            var[await self.aeval(lhs.slice)] = val
    else:
        var_name = await self.aeval(lhs)
        if isinstance(var_name, EvalAttrSet):
            var_name.setattr(val)
            return
        if not isinstance(var_name, str):
            raise NotImplementedError(f"unknown lhs type {lhs} (got {var_name}) in assign")
        dot_count = var_name.count(".")
        if dot_count == 1:
            State.set(var_name, val)
            return
        if dot_count == 2:
            State.setattr(var_name, val)
            return
        if dot_count > 0:
            raise NameError(
                f"invalid name '{var_name}' (should be 'domain.entity' or 'domain.entity.attr')"
            )
        if self.curr_func and var_name in self.curr_func.global_names:
            self.global_sym_table[var_name] = val
            return
        if var_name in self.sym_table and isinstance(self.sym_table[var_name], EvalLocalVar):
            self.sym_table[var_name].set(val)
        else:
            self.sym_table[var_name] = val
|
|
|
|
async def ast_assign(self, arg):
    """Execute ``t1 = t2 = ... = value``.

    The right-hand side is evaluated once and then assigned to every target,
    left to right.
    """
    value = await self.aeval(arg.value)
    for lhs in arg.targets:
        await self.recurse_assign(lhs, value)
|
|
|
|
async def ast_augassign(self, arg):
    """Execute an augmented assignment ``target <op>= value``.

    The target node is temporarily switched to a Load context so it can be
    evaluated as the left operand of the equivalent binary operation.  The
    result is then assigned back to the target.  The target's Store context
    is restored even if evaluating the operation raises.  Otherwise the
    shared AST node would stay in Load context, and a later execution of the
    same statement (e.g. inside a loop with try/except) could not assign to
    the target.
    """
    arg.target.ctx = ast.Load()
    try:
        new_val = await self.aeval(ast.BinOp(left=arg.target, op=arg.op, right=arg.value))
    finally:
        arg.target.ctx = ast.Store()
    await self.recurse_assign(arg.target, new_val)
|
|
|
|
async def ast_annassign(self, arg):
    """Execute an annotated assignment, ignoring the annotation itself.

    A bare annotation (``x: int``) has no value and does nothing.
    """
    if arg.value is None:
        return
    await self.recurse_assign(arg.target, await self.aeval(arg.value))
|
|
|
|
async def ast_namedexpr(self, arg):
    """Evaluate ``target := value``: bind the value and also return it."""
    result = await self.aeval(arg.value)
    await self.recurse_assign(arg.target, result)
    return result
|
|
|
|
async def ast_delete(self, arg):
    """Execute a ``del`` statement, deleting each target left to right.

    Supported targets:

    * Tuple or list targets (``del (a, b)``, ``del [x[0], y]``): their
      elements are deleted in order, as in Python.
    * Subscripts and slices.  On Python < 3.9, an ``ast.Index`` whose value
      evaluates to a list deletes each listed key.
    * Plain names.  Names declared ``global`` in the current function are
      removed from the global table.  An EvalLocalVar placeholder is marked
      undefined instead of being removed.
    * Dotted attribute names, deleted via ``State.delete``.

    Raises:
        NameError: the name is not defined or is not a valid state name.
        NotImplementedError: an unsupported target type.
    """

    async def delete_target(target):
        """Delete a single ``del`` target, recursing into tuples and lists."""
        if isinstance(target, (ast.Tuple, ast.List)):
            for elt in target.elts:
                await delete_target(elt)
        elif isinstance(target, ast.Subscript):
            var = await self.aeval(target.value)
            if isinstance(target.slice, ast.Index):
                ind = await self.aeval(target.slice.value)
                for elt in ind if isinstance(ind, list) else [ind]:
                    del var[elt]
            elif isinstance(target.slice, ast.Slice):
                lower, upper, step = None, None, None
                if target.slice.lower:
                    lower = await self.aeval(target.slice.lower)
                if target.slice.upper:
                    upper = await self.aeval(target.slice.upper)
                if target.slice.step:
                    step = await self.aeval(target.slice.step)
                del var[slice(lower, upper, step)]
            else:
                del var[await self.aeval(target.slice)]
        elif isinstance(target, ast.Name):
            if self.curr_func and target.id in self.curr_func.global_names:
                if target.id in self.global_sym_table:
                    del self.global_sym_table[target.id]
            elif target.id in self.sym_table:
                if isinstance(self.sym_table[target.id], EvalLocalVar):
                    if self.sym_table[target.id].is_defined():
                        self.sym_table[target.id].set_undefined()
                    else:
                        raise NameError(f"name '{target.id}' is not defined")
                else:
                    del self.sym_table[target.id]
            else:
                raise NameError(f"name '{target.id}' is not defined")
        elif isinstance(target, ast.Attribute):
            var_name = await self.ast_attribute_collapse(target, check_undef=False)
            if not isinstance(var_name, str):
                raise NameError("state name should be 'domain.entity' or 'domain.entity.attr'")
            State.delete(var_name)
        else:
            raise NotImplementedError(f"unknown target type {target} in del")

    for arg1 in arg.targets:
        await delete_target(arg1)
|
|
|
|
async def ast_assert(self, arg):
    """Execute ``assert test[, msg]``.

    The message expression is evaluated only when the test fails.  It
    becomes the AssertionError's single argument.
    """
    if await self.aeval(arg.test):
        return
    if arg.msg:
        raise AssertionError(await self.aeval(arg.msg))
    raise AssertionError
|
|
|
|
async def ast_attribute_collapse(self, arg, check_undef=True):
    """Collapse a chain of attribute accesses into one dotted name.

    For example ``Attribute(Attribute(Name('i'), 'j'), 'k')`` becomes
    ``"i.j.k"``.  This lets pyscript treat names such as
    ``light.kitchen.brightness`` as a single identifier.

    Returns None when the chain does not start at a plain name.  With
    ``check_undef`` it also returns None when that leading name already
    resolves to something (anything other than an EvalName).
    """
    parts = [arg.attr]
    node = arg.value
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if not isinstance(node, ast.Name):
        return None
    parts.append(node.id)
    if check_undef:
        leading = await self.ast_name(ast.Name(id=node.id, ctx=ast.Load()))
        if not isinstance(leading, EvalName):
            return None
    return ".".join(reversed(parts))
|
|
|
|
async def ast_attribute(self, arg):
    """Evaluate ``obj.attr`` in load or store context.

    When the whole expression collapses to an undefined dotted name, the
    following applies.  In store context the dotted string is returned (a
    state variable or attribute name for recurse_assign).  In load context
    the name is resolved with ast_name, and that value is used if it is
    not an EvalName.  Otherwise the object expression is evaluated
    normally: store context yields an EvalAttrSet for a later setattr, and
    load context returns ``getattr(obj, attr)``.
    """
    full_name = await self.ast_attribute_collapse(arg)
    storing = isinstance(arg.ctx, ast.Store)
    if full_name is not None:
        if storing:
            return full_name
        resolved = await self.ast_name(ast.Name(id=full_name, ctx=arg.ctx))
        if not isinstance(resolved, EvalName):
            return resolved
    obj = await self.aeval(arg.value)
    if storing:
        return EvalAttrSet(obj, arg.attr)
    return getattr(obj, arg.attr)
|
|
|
|
async def ast_name(self, arg):
    """Resolve an identifier.

    Outside a load context the bare identifier string is returned, so the
    caller can bind or delete it.  In a load context the value is searched
    for in this order:

    1. the global table, if the name is declared ``global`` in the current
       function (NameError when it is missing there);
    2. the current symbol table (EvalLocalVar placeholders are unwrapped);
    3. the local symbol table;
    4. the global table (UnboundLocalError if the current function treats
       the name as local);
    5. pyscript's AST-aware builtins, then Python builtins that are neither
       excluded nor start with an underscore;
    6. registered pyscript functions;
    7. state: any single-dot name, or a two-dot name whose state attribute
       exists.

    If nothing matches, the name is returned wrapped in EvalName.  This lets
    callers such as ast_attribute tell an unresolved dotted name apart from
    a string value.
    """
    ident = arg.id
    if not isinstance(arg.ctx, ast.Load):
        return ident
    func = self.curr_func
    if func and ident in func.global_names:
        if ident not in self.global_sym_table:
            raise NameError(f"global name '{ident}' is not defined")
        return self.global_sym_table[ident]
    if ident in self.sym_table:
        entry = self.sym_table[ident]
        return entry.get() if isinstance(entry, EvalLocalVar) else entry
    if ident in self.local_sym_table:
        return self.local_sym_table[ident]
    if ident in self.global_sym_table:
        if func and ident in func.local_names:
            raise UnboundLocalError(f"local variable '{ident}' referenced before assignment")
        return self.global_sym_table[ident]
    if ident in BUILTIN_AST_FUNCS_FACTORY:
        return BUILTIN_AST_FUNCS_FACTORY[ident](self)
    if hasattr(builtins, ident) and ident not in BUILTIN_EXCLUDE and ident[0] != "_":
        return getattr(builtins, ident)
    if Function.get(ident):
        return Function.get(ident)
    dots = ident.count(".")
    # any "domain.entity" may be a state variable; "domain.entity.attr" must exist
    if dots == 1 or (dots == 2 and State.exist(ident)):
        return State.get(ident)
    return EvalName(ident)
|
|
|
|
async def ast_binop(self, arg):
    """Dispatch a binary operator to the matching ``ast_binop_<opname>`` method.

    Operators without a handler fall back to ``ast_not_implemented``.
    """
    method_name = f"ast_binop_{type(arg.op).__name__.lower()}"
    handler = getattr(self, method_name, self.ast_not_implemented)
    return await handler(arg.left, arg.right)
|
|
|
|
async def ast_binop_add(self, arg0, arg1):
    """Evaluate ``left + right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs + rhs
|
|
|
|
async def ast_binop_sub(self, arg0, arg1):
    """Evaluate ``left - right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs - rhs
|
|
|
|
async def ast_binop_mult(self, arg0, arg1):
    """Evaluate ``left * right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs * rhs
|
|
|
|
async def ast_binop_div(self, arg0, arg1):
    """Evaluate true division ``left / right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs / rhs
|
|
|
|
async def ast_binop_mod(self, arg0, arg1):
    """Evaluate ``left % right`` (modulo or string formatting), left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs % rhs
|
|
|
|
async def ast_binop_pow(self, arg0, arg1):
    """Evaluate ``left ** right``, left operand first."""
    base = await self.aeval(arg0)
    exponent = await self.aeval(arg1)
    return base ** exponent
|
|
|
|
async def ast_binop_lshift(self, arg0, arg1):
    """Evaluate ``left << right``, left operand first."""
    lhs = await self.aeval(arg0)
    shift = await self.aeval(arg1)
    return lhs << shift
|
|
|
|
async def ast_binop_rshift(self, arg0, arg1):
    """Evaluate ``left >> right``, left operand first."""
    lhs = await self.aeval(arg0)
    shift = await self.aeval(arg1)
    return lhs >> shift
|
|
|
|
async def ast_binop_bitor(self, arg0, arg1):
    """Evaluate ``left | right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs | rhs
|
|
|
|
async def ast_binop_bitxor(self, arg0, arg1):
    """Evaluate ``left ^ right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs ^ rhs
|
|
|
|
async def ast_binop_bitand(self, arg0, arg1):
    """Evaluate ``left & right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs & rhs
|
|
|
|
async def ast_binop_floordiv(self, arg0, arg1):
    """Evaluate floor division ``left // right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs // rhs
|
|
|
|
async def ast_unaryop(self, arg):
    """Dispatch a unary operator to the matching ``ast_unaryop_<opname>`` method.

    Operators without a handler fall back to ``ast_not_implemented``.
    """
    method_name = f"ast_unaryop_{type(arg.op).__name__.lower()}"
    handler = getattr(self, method_name, self.ast_not_implemented)
    return await handler(arg.operand)
|
|
|
|
async def ast_unaryop_not(self, arg0):
    """Evaluate ``not operand`` and return a bool."""
    operand = await self.aeval(arg0)
    return not operand
|
|
|
|
async def ast_unaryop_invert(self, arg0):
    """Evaluate bitwise inversion ``~operand``."""
    operand = await self.aeval(arg0)
    return ~operand
|
|
|
|
async def ast_unaryop_uadd(self, arg0):
    """Evaluate unary plus ``+operand``.

    The operand's ``__pos__`` is applied as in Python.  For example,
    ``+True`` is the int ``1``, and types such as Decimal or Counter may
    return a normalized copy rather than the operand itself.
    """
    return +(await self.aeval(arg0))
|
|
|
|
async def ast_unaryop_usub(self, arg0):
    """Evaluate unary minus ``-operand``."""
    operand = await self.aeval(arg0)
    return -operand
|
|
|
|
async def ast_compare(self, arg):
    """Evaluate a comparison, including chains such as ``a < b <= c``.

    This follows Python's semantics: ``a < b < c`` means
    ``(a < b) and (b < c)`` with ``b`` evaluated only once.  Every operand
    is evaluated exactly once, left to right.  Evaluation stops at the
    first link whose result is falsy, and that falsy result is returned.
    Otherwise the result of the last comparison is returned unchanged and
    not coerced to bool.  A single comparison between objects whose rich
    comparisons return non-bool values therefore behaves as in Python.

    Operators without a known implementation are passed to
    ``ast_not_implemented``.
    """
    left_node = arg.left
    left = await self.aeval(left_node)
    result = True
    last = len(arg.ops) - 1
    for idx, (cmp_op, right_node) in enumerate(zip(arg.ops, arg.comparators)):
        right = await self.aeval(right_node)
        if isinstance(cmp_op, ast.Eq):
            result = left == right
        elif isinstance(cmp_op, ast.NotEq):
            result = left != right
        elif isinstance(cmp_op, ast.Lt):
            result = left < right
        elif isinstance(cmp_op, ast.LtE):
            result = left <= right
        elif isinstance(cmp_op, ast.Gt):
            result = left > right
        elif isinstance(cmp_op, ast.GtE):
            result = left >= right
        elif isinstance(cmp_op, ast.Is):
            result = left is right
        elif isinstance(cmp_op, ast.IsNot):
            result = left is not right
        elif isinstance(cmp_op, ast.In):
            result = left in right
        elif isinstance(cmp_op, ast.NotIn):
            result = left not in right
        else:
            result = await self.ast_not_implemented(left_node, right_node)
        # only intermediate links are truth-tested (like "and"); the last result is returned as-is
        if idx < last and not result:
            return result
        left_node, left = right_node, right
    return result
|
|
|
|
async def ast_cmpop_eq(self, arg0, arg1):
    """Evaluate ``left == right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs == rhs
|
|
|
|
async def ast_cmpop_noteq(self, arg0, arg1):
    """Evaluate ``left != right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs != rhs
|
|
|
|
async def ast_cmpop_lt(self, arg0, arg1):
    """Evaluate ``left < right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs < rhs
|
|
|
|
async def ast_cmpop_lte(self, arg0, arg1):
    """Evaluate ``left <= right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs <= rhs
|
|
|
|
async def ast_cmpop_gt(self, arg0, arg1):
    """Evaluate ``left > right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs > rhs
|
|
|
|
async def ast_cmpop_gte(self, arg0, arg1):
    """Evaluate ``left >= right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs >= rhs
|
|
|
|
async def ast_cmpop_is(self, arg0, arg1):
    """Evaluate the identity test ``left is right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs is rhs
|
|
|
|
async def ast_cmpop_isnot(self, arg0, arg1):
    """Evaluate the identity test ``left is not right``, left operand first."""
    lhs = await self.aeval(arg0)
    rhs = await self.aeval(arg1)
    return lhs is not rhs
|
|
|
|
async def ast_cmpop_in(self, arg0, arg1):
    """Evaluate the membership test ``item in container``, item first."""
    item = await self.aeval(arg0)
    container = await self.aeval(arg1)
    return item in container
|
|
|
|
async def ast_cmpop_notin(self, arg0, arg1):
    """Evaluate the membership test ``item not in container``, item first."""
    item = await self.aeval(arg0)
    container = await self.aeval(arg1)
    return item not in container
|
|
|
|
async def ast_boolop(self, arg):
    """Evaluate ``and``/``or`` with Python's short-circuit semantics.

    Operands are evaluated left to right.  ``and`` stops at the first falsy
    value and ``or`` at the first truthy one.  The deciding operand (or the
    last operand) is returned unchanged, not coerced to bool.
    """
    is_and = isinstance(arg.op, ast.And)
    result = is_and
    for operand in arg.values:
        result = await self.aeval(operand)
        # "and" stops on a falsy operand, "or" on a truthy one
        if bool(result) != is_and:
            break
    return result
|
|
|
|
async def eval_elt_list(self, elts):
    """Evaluate a sequence of element nodes into a flat list.

    ``*iterable`` elements are expanded in place.  Every other element is
    appended as its evaluated value.
    """
    items = []
    for node in elts:
        if isinstance(node, ast.Starred):
            items.extend(await self.aeval(node.value))
        else:
            items.append(await self.aeval(node))
    return items
|
|
|
|
async def ast_list(self, arg):
    """Evaluate a list display in load context.

    Returns None in any other context; list assignment targets are not
    built here.
    """
    if not isinstance(arg.ctx, ast.Load):
        return None
    return await self.eval_elt_list(arg.elts)
|
|
|
|
async def loopvar_scope_save(self, generators):
    """Collect comprehension loop-target names and snapshot their current values.

    Comprehension loop variables live in their own implicit nested scope.
    The caller saves same-named variables of the current scope before the
    comprehension runs and restores them afterwards.

    Returns:
        A pair ``(names, saved)``: the set of loop-target names, and a dict
        of those names that are currently bound in ``self.sym_table``.
    """
    names = set()
    for comp in generators:
        # get_names on a dummy "target = None" collects exactly the target's names
        dummy = ast.Assign(targets=[comp.target], value=ast.Constant(value=None))
        await self.get_names(dummy, local_names=names)
    saved = {}
    for var in names:
        if var in self.sym_table:
            saved[var] = self.sym_table[var]
    return names, saved
|
|
|
|
async def loopvar_scope_restore(self, var_names, save_vars):
    """Undo the comprehension's loop-variable bindings in the current scope.

    Names that had a value before the comprehension get that value back.
    All other loop names are removed.  A name that was never assigned (an
    empty iterable) is silently skipped.
    """
    for var_name in var_names:
        if var_name not in save_vars:
            # may be absent if the iterable was empty and the target never bound
            self.sym_table.pop(var_name, None)
            continue
        self.sym_table[var_name] = save_vars[var_name]
|
|
|
|
async def listcomp_loop(self, generators, elt):
    """Build the list for a comprehension, one ``for`` clause per recursion level.

    The first generator is iterated.  Its target is assigned and its
    ``if`` filters are checked in order, short-circuiting on the first
    falsy one.  Items that pass either produce ``elt`` (innermost clause)
    or recurse into the remaining generators.
    """
    head, rest = generators[0], generators[1:]
    result = []
    for item in await self.aeval(head.iter):
        await self.recurse_assign(head.target, item)
        keep = True
        for cond in head.ifs:
            if not await self.aeval(cond):
                keep = False
                break
        if not keep:
            continue
        if rest:
            result.extend(await self.listcomp_loop(rest, elt))
        else:
            result.append(await self.aeval(elt))
    return result
|
|
|
|
async def ast_listcomp(self, arg):
    """Evaluate a list comprehension.

    Loop variables behave as if in their own scope.  Same-named variables of
    the current scope are saved before the loop and restored afterwards.
    The restore also happens when the comprehension raises, so loop targets
    never leak into or clobber the surrounding scope.
    """
    target_vars, save_values = await self.loopvar_scope_save(arg.generators)
    try:
        return await self.listcomp_loop(arg.generators, arg.elt)
    finally:
        await self.loopvar_scope_restore(target_vars, save_values)
|
|
|
|
async def ast_tuple(self, arg):
    """Evaluate a tuple display, expanding any ``*iterable`` elements."""
    items = await self.eval_elt_list(arg.elts)
    return tuple(items)
|
|
|
|
async def ast_dict(self, arg):
    """Evaluate a dict display such as ``{k: v, **other}``.

    Entries are processed left to right.  For ``k: v`` the key is evaluated
    before the value, matching Python's evaluation order.  A ``**mapping``
    entry (key node None) merges the evaluated mapping in place, so later
    entries override earlier ones.
    """
    result = {}
    for key_node, value_node in zip(arg.keys, arg.values):
        if key_node is None:
            result.update(await self.aeval(value_node))
        else:
            key = await self.aeval(key_node)
            result[key] = await self.aeval(value_node)
    return result
|
|
|
|
async def dictcomp_loop(self, generators, key, value):
    """Build the dict for a comprehension, one ``for`` clause per recursion level.

    The first generator is iterated.  Its target is assigned and its
    ``if`` filters are checked in order.  Items that pass either add a
    ``key: value`` entry (innermost clause) or merge the entries produced
    by the remaining generators.
    """
    head, rest = generators[0], generators[1:]
    result = {}
    for item in await self.aeval(head.iter):
        await self.recurse_assign(head.target, item)
        keep = True
        for cond in head.ifs:
            if not await self.aeval(cond):
                keep = False
                break
        if not keep:
            continue
        if rest:
            result.update(await self.dictcomp_loop(rest, key, value))
        else:
            # the key is evaluated before the value (Python 3.8+ order)
            entry_key = await self.aeval(key)
            result[entry_key] = await self.aeval(value)
    return result
|
|
|
|
async def ast_dictcomp(self, arg):
    """Evaluate a dict comprehension.

    Loop variables behave as if in their own scope.  Same-named variables of
    the current scope are saved before the loop and restored afterwards.
    The restore also happens when the comprehension raises.
    """
    target_vars, save_values = await self.loopvar_scope_save(arg.generators)
    try:
        return await self.dictcomp_loop(arg.generators, arg.key, arg.value)
    finally:
        await self.loopvar_scope_restore(target_vars, save_values)
|
|
|
|
async def ast_set(self, arg):
    """Evaluate a set display, expanding any ``*iterable`` elements."""
    members = await self.eval_elt_list(arg.elts)
    return set(members)
|
|
|
|
async def setcomp_loop(self, generators, elt):
    """Build the set for a comprehension, one ``for`` clause per recursion level.

    The first generator is iterated.  Its target is assigned and its
    ``if`` filters are checked in order.  Items that pass either add
    ``elt`` (innermost clause) or merge the members produced by the
    remaining generators.
    """
    head, rest = generators[0], generators[1:]
    result = set()
    for item in await self.aeval(head.iter):
        await self.recurse_assign(head.target, item)
        keep = True
        for cond in head.ifs:
            if not await self.aeval(cond):
                keep = False
                break
        if not keep:
            continue
        if rest:
            result.update(await self.setcomp_loop(rest, elt))
        else:
            result.add(await self.aeval(elt))
    return result
|
|
|
|
async def ast_setcomp(self, arg):
    """Evaluate a set comprehension.

    Loop variables behave as if in their own scope.  Same-named variables of
    the current scope are saved before the loop and restored afterwards.
    The restore also happens when the comprehension raises.
    """
    target_vars, save_values = await self.loopvar_scope_save(arg.generators)
    try:
        return await self.setcomp_loop(arg.generators, arg.elt)
    finally:
        await self.loopvar_scope_restore(target_vars, save_values)
|
|
|
|
async def ast_subscript(self, arg):
    """Evaluate ``container[index]`` or ``container[lower:upper:step]`` in load context.

    The container expression is always evaluated first.  Outside a load
    context None is returned; assignment and deletion handle subscripts
    themselves.  Omitted slice bounds become None.
    """
    container = await self.aeval(arg.value)
    if not isinstance(arg.ctx, ast.Load):
        return None
    index = arg.slice
    if isinstance(index, ast.Slice):
        lower = (await self.aeval(index.lower)) if index.lower else None
        upper = (await self.aeval(index.upper)) if index.upper else None
        step = (await self.aeval(index.step)) if index.step else None
        return container[slice(lower, upper, step)]
    # plain index: ast.Index on Python < 3.9 (unwrapped by ast_index), any expression afterwards
    return container[await self.aeval(index)]
|
|
|
|
async def ast_index(self, arg):
    """Unwrap a legacy ``ast.Index`` node (Python < 3.9) by evaluating its value."""
    inner = arg.value
    return await self.aeval(inner)
|
|
|
|
async def ast_slice(self, arg):
    """Evaluate a slice expression ``lower:upper:step`` to a ``slice`` object.

    ``ast_subscript`` handles a lone slice itself.  This handler is reached
    when a slice is evaluated as a node of its own, e.g. inside the tuple of
    a multi-dimensional subscript such as ``x[1:2, 0]`` (Python 3.9+).
    ``ast.Slice`` has no ``value`` field; its parts are ``lower``, ``upper``
    and ``step``, and omitted parts become None.
    """
    lower = await self.aeval(arg.lower) if arg.lower is not None else None
    upper = await self.aeval(arg.upper) if arg.upper is not None else None
    step = await self.aeval(arg.step) if arg.step is not None else None
    return slice(lower, upper, step)
|
|
|
|
async def ast_call(self, arg):
    """Evaluate a call expression and invoke the callee via ``call_func``.

    Evaluation order matches CPython:

    1. the callee;
    2. the positional arguments, with ``*iterable`` entries expanded in place;
    3. the keyword arguments, with ``**mapping`` entries merged in order.

    A best-effort function name (for logging in ``call_func``) is taken from
    a plain name or an attribute access.  EvalLocalVar placeholders are
    unwrapped to the callable they hold.

    NOTE(review): a repeated keyword from ``**mapping`` silently overrides the
    earlier value, whereas Python raises TypeError.
    """
    func = await self.aeval(arg.func)
    args = await self.eval_elt_list(arg.args)
    kwargs = {}
    for kw_arg in arg.keywords:
        if kw_arg.arg is None:
            kwargs.update(await self.aeval(kw_arg.value))
        else:
            kwargs[kw_arg.arg] = await self.aeval(kw_arg.value)
    #
    # try to deduce function name, although this only works in simple cases
    #
    func_name = None
    if isinstance(arg.func, ast.Name):
        func_name = arg.func.id
    elif isinstance(arg.func, ast.Attribute):
        func_name = arg.func.attr
    if isinstance(func, EvalLocalVar):
        func_name = func.get_name()
        func = func.get()
    return await self.call_func(func, func_name, *args, **kwargs)
|
|
|
|
async def call_func(self, func, func_name, *args, **kwargs):
|
|
"""Call a function with the given arguments."""
|
|
if func_name is None:
|
|
try:
|
|
if isinstance(func, (EvalFunc, EvalFuncVar)):
|
|
func_name = func.get_name()
|
|
else:
|
|
func_name = func.__name__
|
|
except Exception:
|
|
func_name = "<function>"
|
|
arg_str = ", ".join(['"' + elt + '"' if isinstance(elt, str) else str(elt) for elt in args])
|
|
_LOGGER.debug("%s: calling %s(%s, %s)", self.name, func_name, arg_str, kwargs)
|
|
if isinstance(func, (EvalFunc, EvalFuncVar)):
|
|
return await func.call(self, *args, **kwargs)
|
|
if inspect.isclass(func) and hasattr(func, "__init__evalfunc_wrap__"):
|
|
inst = func()
|
|
#
|
|
# we use weak references when we bind the method calls to the instance inst;
|
|
# otherwise these self references cause the object to not be deleted until
|
|
# it is later garbage collected
|
|
#
|
|
inst_weak = weakref.ref(inst)
|
|
for name in dir(inst):
|
|
value = getattr(inst, name)
|
|
if type(value) is not EvalFuncVar:
|
|
continue
|
|
setattr(inst, name, EvalFuncVarClassInst(value.get_func(), value.get_ast_ctx(), inst_weak))
|
|
if getattr(func, "__init__evalfunc_wrap__") is not None:
|
|
#
|
|
# since our __init__ function is async, call the renamed one
|
|
#
|
|
await inst.__init__evalfunc_wrap__.call(self, *args, **kwargs)
|
|
return inst
|
|
if asyncio.iscoroutinefunction(func):
|
|
return await func(*args, **kwargs)
|
|
if callable(func):
|
|
if func == time.sleep: # pylint: disable=comparison-with-callable
|
|
_LOGGER.warning(
|
|
"%s line %s calls blocking time.sleep(); replaced with asyncio.sleep()",
|
|
self.filename,
|
|
self.lineno,
|
|
)
|
|
return await asyncio.sleep(*args, **kwargs)
|
|
return func(*args, **kwargs)
|
|
raise TypeError(f"'{func_name}' is not callable (got {func})")
|
|
|
|
async def ast_ifexp(self, arg):
|
|
"""Evaluate if expression."""
|
|
return await self.aeval(arg.body) if (await self.aeval(arg.test)) else await self.aeval(arg.orelse)
|
|
|
|
async def ast_num(self, arg):
|
|
"""Evaluate number."""
|
|
return arg.n
|
|
|
|
async def ast_str(self, arg):
|
|
"""Evaluate string."""
|
|
return arg.s
|
|
|
|
async def ast_nameconstant(self, arg):
|
|
"""Evaluate name constant."""
|
|
return arg.value
|
|
|
|
async def ast_constant(self, arg):
|
|
"""Evaluate constant."""
|
|
return arg.value
|
|
|
|
async def ast_joinedstr(self, arg):
|
|
"""Evaluate joined string."""
|
|
val = ""
|
|
for arg1 in arg.values:
|
|
this_val = await self.aeval(arg1)
|
|
val = val + str(this_val)
|
|
return val
|
|
|
|
async def ast_formattedvalue(self, arg):
|
|
"""Evaluate formatted value."""
|
|
val = await self.aeval(arg.value)
|
|
if arg.format_spec is not None:
|
|
fmt = await self.aeval(arg.format_spec)
|
|
return f"{val:{fmt}}"
|
|
return f"{val}"
|
|
|
|
async def ast_await(self, arg):
|
|
"""Evaluate await expr."""
|
|
coro = await self.aeval(arg.value)
|
|
if coro:
|
|
return await coro
|
|
return None
|
|
|
|
async def get_target_names(self, lhs):
|
|
"""Recursively find all the target names mentioned in the AST tree."""
|
|
names = set()
|
|
if isinstance(lhs, ast.Tuple):
|
|
for lhs_elt in lhs.elts:
|
|
if isinstance(lhs_elt, ast.Starred):
|
|
names.add(lhs_elt.value.id)
|
|
else:
|
|
names = names.union(await self.get_target_names(lhs_elt))
|
|
elif isinstance(lhs, ast.Attribute):
|
|
var_name = await self.ast_attribute_collapse(lhs, check_undef=False)
|
|
if isinstance(var_name, str):
|
|
names.add(var_name)
|
|
elif isinstance(lhs, ast.Name):
|
|
names.add(lhs.id)
|
|
return names
|
|
|
|
async def get_names_set(self, arg, names, nonlocal_names, global_names, local_names):
|
|
"""Recursively find all the names mentioned in the AST tree."""
|
|
|
|
cls_name = arg.__class__.__name__
|
|
if cls_name == "Attribute":
|
|
full_name = await self.ast_attribute_collapse(arg)
|
|
if full_name is not None:
|
|
names.add(full_name)
|
|
return
|
|
if cls_name == "Name":
|
|
names.add(arg.id)
|
|
return
|
|
if cls_name == "Nonlocal" and nonlocal_names is not None:
|
|
for var_name in arg.names:
|
|
nonlocal_names.add(var_name)
|
|
names.add(var_name)
|
|
return
|
|
if cls_name == "Global" and global_names is not None:
|
|
for var_name in arg.names:
|
|
global_names.add(var_name)
|
|
names.add(var_name)
|
|
return
|
|
if local_names is not None:
|
|
#
|
|
# find all the local variables by looking for assignments;
|
|
# also, don't recurse into function definitions
|
|
#
|
|
if cls_name == "Assign":
|
|
for target in arg.targets:
|
|
for name in await self.get_target_names(target):
|
|
local_names.add(name)
|
|
names.add(name)
|
|
elif cls_name in {"AugAssign", "For", "AsyncFor", "NamedExpr"}:
|
|
for name in await self.get_target_names(arg.target):
|
|
local_names.add(name)
|
|
names.add(name)
|
|
elif cls_name in {"With", "AsyncWith"}:
|
|
for item in arg.items:
|
|
if item.optional_vars:
|
|
for name in await self.get_target_names(item.optional_vars):
|
|
local_names.add(name)
|
|
names.add(name)
|
|
elif cls_name in {"ListComp", "DictComp", "SetComp"}:
|
|
target_vars, _ = await self.loopvar_scope_save(arg.generators)
|
|
for name in target_vars:
|
|
local_names.add(name)
|
|
elif cls_name == "Try":
|
|
for handler in arg.handlers:
|
|
if handler.name is not None:
|
|
local_names.add(handler.name)
|
|
names.add(handler.name)
|
|
elif cls_name == "Call":
|
|
await self.get_names_set(arg.func, names, nonlocal_names, global_names, local_names)
|
|
for this_arg in arg.args:
|
|
await self.get_names_set(this_arg, names, nonlocal_names, global_names, local_names)
|
|
for this_arg in arg.keywords or []:
|
|
await self.get_names_set(this_arg, names, nonlocal_names, global_names, local_names)
|
|
return
|
|
elif cls_name in {"FunctionDef", "ClassDef", "AsyncFunctionDef"}:
|
|
local_names.add(arg.name)
|
|
names.add(arg.name)
|
|
for dec in arg.decorator_list:
|
|
await self.get_names_set(dec, names, nonlocal_names, global_names, local_names)
|
|
#
|
|
# find unbound names from the body of the function or class
|
|
#
|
|
inner_global, inner_names, inner_local = set(), set(), set()
|
|
for child in arg.body:
|
|
await self.get_names_set(child, inner_names, None, inner_global, inner_local)
|
|
for name in inner_names:
|
|
if name not in inner_local and name not in inner_global:
|
|
names.add(name)
|
|
return
|
|
elif cls_name == "Delete":
|
|
for arg1 in arg.targets:
|
|
if isinstance(arg1, ast.Name):
|
|
local_names.add(arg1.id)
|
|
for child in ast.iter_child_nodes(arg):
|
|
await self.get_names_set(child, names, nonlocal_names, global_names, local_names)
|
|
|
|
async def get_names(self, this_ast=None, nonlocal_names=None, global_names=None, local_names=None):
|
|
"""Return set of all the names mentioned in our AST tree."""
|
|
names = set()
|
|
this_ast = this_ast or self.ast
|
|
if this_ast:
|
|
await self.get_names_set(this_ast, names, nonlocal_names, global_names, local_names)
|
|
return names
|
|
|
|
def parse(self, code_str, filename=None, mode="exec"):
|
|
"""Parse the code_str source code into an AST tree."""
|
|
self.exception = None
|
|
self.exception_obj = None
|
|
self.exception_long = None
|
|
self.ast = None
|
|
if filename is not None:
|
|
self.filename = filename
|
|
try:
|
|
if isinstance(code_str, list):
|
|
self.code_list = code_str
|
|
self.code_str = "\n".join(code_str)
|
|
elif isinstance(code_str, str):
|
|
self.code_str = code_str
|
|
self.code_list = code_str.split("\n")
|
|
else:
|
|
self.code_str = code_str
|
|
self.code_list = []
|
|
self.ast = ast.parse(self.code_str, filename=self.filename, mode=mode)
|
|
return True
|
|
except SyntaxError as err:
|
|
self.exception_obj = err
|
|
self.lineno = err.lineno
|
|
self.col_offset = err.offset - 1
|
|
self.exception = f"syntax error {err}"
|
|
if err.filename == self.filename:
|
|
self.exception_long = self.format_exc(err, self.lineno, self.col_offset)
|
|
else:
|
|
self.exception_long = self.format_exc(err, 1, self.col_offset, code_list=[err.text])
|
|
return False
|
|
except asyncio.CancelledError:
|
|
raise
|
|
except Exception as err:
|
|
self.exception_obj = err
|
|
self.lineno = 1
|
|
self.col_offset = 0
|
|
self.exception = f"parsing error {err}"
|
|
self.exception_long = self.format_exc(err)
|
|
return False
|
|
|
|
def format_exc(self, exc, lineno=None, col_offset=None, short=False, code_list=None):
|
|
"""Format an multi-line exception message using lineno if available."""
|
|
if code_list is None:
|
|
code_list = self.code_list
|
|
if lineno is not None and lineno <= len(code_list):
|
|
if short:
|
|
mesg = f"In <{self.filename}> line {lineno}:\n"
|
|
mesg += " " + code_list[lineno - 1]
|
|
else:
|
|
mesg = f"Exception in <{self.filename}> line {lineno}:\n"
|
|
mesg += " " + code_list[lineno - 1] + "\n"
|
|
if col_offset is not None:
|
|
mesg += " " + " " * col_offset + "^\n"
|
|
mesg += f"{type(exc).__name__}: {exc}"
|
|
else:
|
|
mesg = f"Exception in <{self.filename}>:\n"
|
|
mesg += f"{type(exc).__name__}: {exc}"
|
|
#
|
|
# to get a more detailed traceback on exception (eg, when chasing an internal
|
|
# error), add an "import traceback" above, and uncomment this next line
|
|
#
|
|
# return mesg + "\n" + traceback.format_exc(-1)
|
|
return mesg
|
|
|
|
def get_exception(self):
|
|
"""Return the last exception str."""
|
|
return self.exception
|
|
|
|
def get_exception_obj(self):
|
|
"""Return the last exception object."""
|
|
return self.exception_obj
|
|
|
|
def get_exception_long(self):
|
|
"""Return the last exception in a longer str form."""
|
|
return self.exception_long
|
|
|
|
def set_local_sym_table(self, sym_table):
|
|
"""Set the local symbol table."""
|
|
self.local_sym_table = sym_table
|
|
|
|
def set_global_ctx(self, global_ctx):
|
|
"""Set the global context."""
|
|
self.global_ctx = global_ctx
|
|
if self.sym_table == self.global_sym_table:
|
|
self.global_sym_table = global_ctx.get_global_sym_table()
|
|
self.sym_table = self.global_sym_table
|
|
else:
|
|
self.global_sym_table = global_ctx.get_global_sym_table()
|
|
if len(self.sym_table_stack) > 0:
|
|
self.sym_table_stack[0] = self.global_sym_table
|
|
|
|
def get_global_ctx(self):
|
|
"""Return the global context."""
|
|
return self.global_ctx
|
|
|
|
def get_global_ctx_name(self):
|
|
"""Return the global context name."""
|
|
return self.global_ctx.get_name()
|
|
|
|
def set_logger_name(self, name):
|
|
"""Set the context's logger name."""
|
|
if self.logger:
|
|
for handler in self.logger_handlers:
|
|
self.logger.removeHandler(handler)
|
|
self.logger_name = name
|
|
self.logger = logging.getLogger(LOGGER_PATH + "." + name)
|
|
for handler in self.logger_handlers:
|
|
self.logger.addHandler(handler)
|
|
|
|
def get_logger_name(self):
|
|
"""Get the context's logger name."""
|
|
return self.logger_name
|
|
|
|
def get_logger(self):
|
|
"""Get the context's logger."""
|
|
return self.logger
|
|
|
|
def add_logger_handler(self, handler):
|
|
"""Add logger handler to this context."""
|
|
self.logger.addHandler(handler)
|
|
self.logger_handlers.add(handler)
|
|
|
|
def remove_logger_handler(self, handler):
|
|
"""Remove logger handler to this context."""
|
|
self.logger.removeHandler(handler)
|
|
self.logger_handlers.discard(handler)
|
|
|
|
def completions(self, root):
|
|
"""Return potential variable, function or attribute matches."""
|
|
words = set()
|
|
num_period = root.count(".")
|
|
if num_period >= 1:
|
|
last_period = root.rfind(".")
|
|
name = root[0:last_period]
|
|
attr_root = root[last_period + 1 :]
|
|
if name in self.global_sym_table:
|
|
var = self.global_sym_table[name]
|
|
try:
|
|
for attr in var.__dict__:
|
|
if attr.lower().startswith(attr_root) and (attr_root != "" or attr[0:1] != "_"):
|
|
words.add(f"{name}.{attr}")
|
|
except Exception:
|
|
pass
|
|
for keyw in set(keyword.kwlist) - {"yield"}:
|
|
if keyw.lower().startswith(root):
|
|
words.add(keyw)
|
|
sym_table = BUILTIN_AST_FUNCS_FACTORY.copy()
|
|
for name, value in builtins.__dict__.items():
|
|
if name[0] != "_" and name not in BUILTIN_EXCLUDE:
|
|
sym_table[name] = value
|
|
sym_table.update(self.global_sym_table.items())
|
|
for name, value in sym_table.items():
|
|
if name.lower().startswith(root):
|
|
words.add(name)
|
|
return words
|
|
|
|
async def eval(self, new_state_vars=None, merge_local=False):
|
|
"""Execute parsed code, with the optional state variables added to the scope."""
|
|
self.exception = None
|
|
self.exception_obj = None
|
|
self.exception_long = None
|
|
if new_state_vars:
|
|
if not merge_local:
|
|
self.local_sym_table = {}
|
|
self.local_sym_table.update(new_state_vars)
|
|
if self.ast:
|
|
try:
|
|
val = await self.aeval(self.ast)
|
|
if isinstance(val, EvalStopFlow):
|
|
return None
|
|
return val
|
|
except asyncio.CancelledError:
|
|
raise
|
|
except Exception as err:
|
|
if self.exception_long is None:
|
|
self.exception_long = self.format_exc(err, self.lineno, self.col_offset)
|
|
return None
|
|
|
|
def dump(self, this_ast=None):
|
|
"""Dump the AST tree for debugging."""
|
|
return ast.dump(this_ast if this_ast else self.ast)
|