922 lines
36 KiB
Python
922 lines
36 KiB
Python
"""Pyscript Jupyter kernel."""
|
|
|
|
#
|
|
# Based on simple_kernel.py by Doug Blank <doug.blank@gmail.com>
|
|
# https://github.com/dsblank/simple_kernel
|
|
# license: public domain
|
|
# Thanks Doug!
|
|
#
|
|
|
|
import asyncio
|
|
import datetime
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import logging
|
|
import logging.handlers
|
|
import re
|
|
from struct import pack, unpack
|
|
import traceback
|
|
import uuid
|
|
|
|
from .const import LOGGER_PATH
|
|
from .function import Function
|
|
from .global_ctx import GlobalContextMgr
|
|
from .state import State
|
|
|
|
_LOGGER = logging.getLogger(LOGGER_PATH + ".jupyter_kernel")
|
|
|
|
# Globals:
|
|
|
|
DELIM = b"<IDS|MSG>"
|
|
|
|
|
|
def msg_id():
|
|
"""Return a new uuid for message id."""
|
|
return str(uuid.uuid4())
|
|
|
|
|
|
def str_to_bytes(string):
|
|
"""Encode a string in bytes."""
|
|
return string.encode("utf-8")
|
|
|
|
|
|
class KernelBufferingHandler(logging.handlers.BufferingHandler):
|
|
"""Memory-based handler for logging; send via stdout queue."""
|
|
|
|
def __init__(self, housekeep_q):
|
|
"""Initialize KernelBufferingHandler instance."""
|
|
super().__init__(0)
|
|
self.housekeep_q = housekeep_q
|
|
|
|
def flush(self):
|
|
"""Flush is a no-op."""
|
|
|
|
def shouldFlush(self, record):
|
|
"""Write the buffer to the housekeeping queue."""
|
|
try:
|
|
self.housekeep_q.put_nowait(["stdout", self.format(record)])
|
|
except asyncio.QueueFull:
|
|
_LOGGER.error("housekeep_q unexpectedly full")
|
|
|
|
|
|
################################################################
|
|
class ZmqSocket:
|
|
"""Defines a minimal implementation of a small subset of ZMQ."""
|
|
|
|
#
|
|
# This allows pyscript to work with Jupyter without the real zmq
|
|
# and pyzmq packages, which might not be available or easy to
|
|
# install on the wide set of HASS platforms.
|
|
#
|
|
def __init__(self, reader, writer, sock_type):
|
|
"""Initialize a ZMQ socket with the given type and reader/writer streams."""
|
|
self.writer = writer
|
|
self.reader = reader
|
|
self.type = sock_type
|
|
|
|
async def read_bytes(self, num_bytes):
|
|
"""Read bytes from ZMQ socket."""
|
|
data = b""
|
|
while len(data) < num_bytes:
|
|
new_data = await self.reader.read(num_bytes - len(data))
|
|
if len(new_data) == 0:
|
|
raise EOFError
|
|
data += new_data
|
|
return data
|
|
|
|
async def write_bytes(self, raw_msg):
|
|
"""Write bytes to ZMQ socket."""
|
|
self.writer.write(raw_msg)
|
|
await self.writer.drain()
|
|
|
|
async def handshake(self):
|
|
"""Do initial greeting handshake on a new ZMQ connection."""
|
|
await self.write_bytes(b"\xff\x00\x00\x00\x00\x00\x00\x00\x01\x7f")
|
|
_ = await self.read_bytes(10)
|
|
# _LOGGER.debug(f"handshake: got initial greeting {greeting}")
|
|
await self.write_bytes(b"\x03")
|
|
_ = await self.read_bytes(1)
|
|
await self.write_bytes(b"\x00" + "NULL".encode() + b"\x00" * 16 + b"\x00" + b"\x00" * 31)
|
|
_ = await self.read_bytes(53)
|
|
# _LOGGER.debug(f"handshake: got rest of greeting {greeting}")
|
|
params = [["Socket-Type", self.type]]
|
|
if self.type == "ROUTER":
|
|
params.append(["Identity", ""])
|
|
await self.send_cmd("READY", params)
|
|
|
|
async def recv(self, multipart=False):
|
|
"""Receive a message from ZMQ socket."""
|
|
parts = []
|
|
while 1:
|
|
cmd = (await self.read_bytes(1))[0]
|
|
if cmd & 0x2:
|
|
msg_len = unpack(">Q", await self.read_bytes(8))[0]
|
|
else:
|
|
msg_len = (await self.read_bytes(1))[0]
|
|
msg_body = await self.read_bytes(msg_len)
|
|
if cmd & 0x4:
|
|
# _LOGGER.debug(f"recv: got cmd {msg_body}")
|
|
cmd_len = msg_body[0]
|
|
cmd = msg_body[1 : cmd_len + 1]
|
|
msg_body = msg_body[cmd_len + 1 :]
|
|
params = []
|
|
while len(msg_body) > 0:
|
|
param_len = msg_body[0]
|
|
param = msg_body[1 : param_len + 1]
|
|
msg_body = msg_body[param_len + 1 :]
|
|
value_len = unpack(">L", msg_body[0:4])[0]
|
|
value = msg_body[4 : 4 + value_len]
|
|
msg_body = msg_body[4 + value_len :]
|
|
params.append([param, value])
|
|
# _LOGGER.debug(f"recv: got cmd={cmd}, params={params}")
|
|
else:
|
|
parts.append(msg_body)
|
|
if cmd in (0x0, 0x2):
|
|
# _LOGGER.debug(f"recv: got msg {parts}")
|
|
if not multipart:
|
|
return b"".join(parts)
|
|
|
|
return parts
|
|
|
|
async def recv_multipart(self):
|
|
"""Receive a multipart message from ZMQ socket."""
|
|
return await self.recv(multipart=True)
|
|
|
|
async def send_cmd(self, cmd, params):
|
|
"""Send a command over ZMQ socket."""
|
|
raw_msg = bytearray([len(cmd)]) + cmd.encode()
|
|
for param in params:
|
|
raw_msg += bytearray([len(param[0])]) + param[0].encode()
|
|
raw_msg += pack(">L", len(param[1])) + param[1].encode()
|
|
len_msg = len(raw_msg)
|
|
if len_msg <= 255:
|
|
raw_msg = bytearray([0x4, len_msg]) + raw_msg
|
|
else:
|
|
raw_msg = bytearray([0x6]) + pack(">Q", len_msg) + raw_msg
|
|
# _LOGGER.debug(f"send_cmd: sending {raw_msg}")
|
|
await self.write_bytes(raw_msg)
|
|
|
|
async def send(self, msg):
|
|
"""Send a message over ZMQ socket."""
|
|
len_msg = len(msg)
|
|
if len_msg <= 255:
|
|
raw_msg = bytearray([0x1, 0x0, 0x0, len_msg]) + msg
|
|
else:
|
|
raw_msg = bytearray([0x1, 0x0, 0x2]) + pack(">Q", len_msg) + msg
|
|
# _LOGGER.debug(f"send: sending {raw_msg}")
|
|
await self.write_bytes(raw_msg)
|
|
|
|
async def send_multipart(self, parts):
|
|
"""Send multipart messages over ZMQ socket."""
|
|
raw_msg = b""
|
|
for i, part in enumerate(parts):
|
|
len_part = len(part)
|
|
cmd = 0x1 if i < len(parts) - 1 else 0x0
|
|
if len_part <= 255:
|
|
raw_msg += bytearray([cmd, len_part]) + part
|
|
else:
|
|
raw_msg += bytearray([cmd + 2]) + pack(">Q", len_part) + part
|
|
# _LOGGER.debug(f"send_multipart: sending {raw_msg}")
|
|
await self.write_bytes(raw_msg)
|
|
|
|
def close(self):
|
|
"""Close the ZMQ socket."""
|
|
self.writer.close()
|
|
|
|
|
|
##########################################
|
|
class Kernel:
|
|
"""Define a Jupyter Kernel class."""
|
|
|
|
def __init__(self, config, ast_ctx, global_ctx, global_ctx_name):
|
|
"""Initialize a Kernel object, one instance per session."""
|
|
self.config = config.copy()
|
|
self.global_ctx = global_ctx
|
|
self.global_ctx_name = global_ctx_name
|
|
self.ast_ctx = ast_ctx
|
|
|
|
self.secure_key = str_to_bytes(self.config["key"])
|
|
self.no_connect_timeout = self.config.get("no_connect_timeout", 30)
|
|
self.signature_schemes = {"hmac-sha256": hashlib.sha256}
|
|
self.auth = hmac.HMAC(
|
|
self.secure_key,
|
|
digestmod=self.signature_schemes[self.config["signature_scheme"]],
|
|
)
|
|
self.execution_count = 1
|
|
self.engine_id = str(uuid.uuid4())
|
|
|
|
self.heartbeat_server = None
|
|
self.iopub_server = None
|
|
self.control_server = None
|
|
self.stdin_server = None
|
|
self.shell_server = None
|
|
|
|
self.heartbeat_port = None
|
|
self.iopub_port = None
|
|
self.control_port = None
|
|
self.stdin_port = None
|
|
self.shell_port = None
|
|
# this should probably be a configuration parameter
|
|
self.avail_port = 50321
|
|
|
|
# there can be multiple iopub subscribers, with corresponding tasks
|
|
self.iopub_socket = set()
|
|
|
|
self.tasks = {}
|
|
self.task_cnt = 0
|
|
self.task_cnt_max = 0
|
|
|
|
self.session_cleanup_callback = None
|
|
|
|
self.housekeep_q = asyncio.Queue(0)
|
|
|
|
self.parent_header = None
|
|
|
|
#
|
|
# we create a logging handler so that output from the log functions
|
|
# gets delivered back to Jupyter as stdout
|
|
#
|
|
self.console = KernelBufferingHandler(self.housekeep_q)
|
|
self.console.setLevel(logging.DEBUG)
|
|
# set a format which is just the message
|
|
formatter = logging.Formatter("%(message)s")
|
|
self.console.setFormatter(formatter)
|
|
|
|
# match alphanum or "." at end of line
|
|
self.completion_re = re.compile(r".*?([\w.]*)$", re.DOTALL)
|
|
|
|
# see if line ends in a ":", with optional whitespace and comment
|
|
# note: this doesn't detect if we are inside a quoted string...
|
|
self.colon_end_re = re.compile(r".*: *(#.*)?$")
|
|
|
|
def msg_sign(self, msg_lst):
|
|
"""Sign a message with a secure signature."""
|
|
auth_hmac = self.auth.copy()
|
|
for msg in msg_lst:
|
|
auth_hmac.update(msg)
|
|
return str_to_bytes(auth_hmac.hexdigest())
|
|
|
|
def deserialize_wire_msg(self, wire_msg):
|
|
"""Split the routing prefix and message frames from a message on the wire."""
|
|
delim_idx = wire_msg.index(DELIM)
|
|
identities = wire_msg[:delim_idx]
|
|
m_signature = wire_msg[delim_idx + 1]
|
|
msg_frames = wire_msg[delim_idx + 2 :]
|
|
|
|
def decode(msg):
|
|
return json.loads(msg.decode("utf-8"))
|
|
|
|
msg = {}
|
|
msg["header"] = decode(msg_frames[0])
|
|
msg["parent_header"] = decode(msg_frames[1])
|
|
msg["metadata"] = decode(msg_frames[2])
|
|
msg["content"] = decode(msg_frames[3])
|
|
check_sig = self.msg_sign(msg_frames)
|
|
if check_sig != m_signature:
|
|
_LOGGER.error(
|
|
"signature mismatch: check_sig=%s, m_signature=%s, wire_msg=%s",
|
|
check_sig,
|
|
m_signature,
|
|
wire_msg,
|
|
)
|
|
raise ValueError("Signatures do not match")
|
|
|
|
return identities, msg
|
|
|
|
def new_header(self, msg_type):
|
|
"""Make a new header."""
|
|
return {
|
|
"date": datetime.datetime.now().isoformat(),
|
|
"msg_id": msg_id(),
|
|
"username": "kernel",
|
|
"session": self.engine_id,
|
|
"msg_type": msg_type,
|
|
"version": "5.3",
|
|
}
|
|
|
|
async def send(
|
|
self,
|
|
stream,
|
|
msg_type,
|
|
content=None,
|
|
parent_header=None,
|
|
metadata=None,
|
|
identities=None,
|
|
):
|
|
"""Send message to the Jupyter client."""
|
|
header = self.new_header(msg_type)
|
|
|
|
def encode(msg):
|
|
return str_to_bytes(json.dumps(msg))
|
|
|
|
msg_lst = [
|
|
encode(header),
|
|
encode(parent_header if parent_header else {}),
|
|
encode(metadata if metadata else {}),
|
|
encode(content if content else {}),
|
|
]
|
|
signature = self.msg_sign(msg_lst)
|
|
parts = [DELIM, signature, msg_lst[0], msg_lst[1], msg_lst[2], msg_lst[3]]
|
|
if identities:
|
|
parts = identities + parts
|
|
if stream:
|
|
# _LOGGER.debug("send %s: %s", msg_type, parts)
|
|
for this_stream in stream if isinstance(stream, set) else {stream}:
|
|
await this_stream.send_multipart(parts)
|
|
|
|
async def shell_handler(self, shell_socket, wire_msg):
|
|
"""Handle shell messages."""
|
|
|
|
identities, msg = self.deserialize_wire_msg(wire_msg)
|
|
# _LOGGER.debug("shell received %s: %s", msg.get('header', {}).get('msg_type', 'UNKNOWN'), msg)
|
|
self.parent_header = msg["header"]
|
|
|
|
content = {
|
|
"execution_state": "busy",
|
|
}
|
|
await self.send(self.iopub_socket, "status", content, parent_header=msg["header"])
|
|
|
|
if msg["header"]["msg_type"] == "execute_request":
|
|
|
|
content = {
|
|
"execution_count": self.execution_count,
|
|
"code": msg["content"]["code"],
|
|
}
|
|
await self.send(self.iopub_socket, "execute_input", content, parent_header=msg["header"])
|
|
result = None
|
|
|
|
code = msg["content"]["code"]
|
|
#
|
|
# replace VSCode initialization code, which depend on iPython % extensions
|
|
#
|
|
if code.startswith("%config "):
|
|
code = "None"
|
|
if code.startswith("_rwho_ls = %who_ls"):
|
|
code = "print([])"
|
|
|
|
self.global_ctx.set_auto_start(False)
|
|
self.ast_ctx.parse(code)
|
|
exc = self.ast_ctx.get_exception_obj()
|
|
if exc is None:
|
|
result = await self.ast_ctx.eval()
|
|
exc = self.ast_ctx.get_exception_obj()
|
|
await Function.waiter_sync()
|
|
self.global_ctx.set_auto_start(True)
|
|
self.global_ctx.start()
|
|
if exc:
|
|
traceback_mesg = self.ast_ctx.get_exception_long().split("\n")
|
|
|
|
metadata = {
|
|
"dependencies_met": True,
|
|
"engine": self.engine_id,
|
|
"status": "error",
|
|
"started": datetime.datetime.now().isoformat(),
|
|
}
|
|
content = {
|
|
"execution_count": self.execution_count,
|
|
"status": "error",
|
|
"ename": type(exc).__name__, # Exception name, as a string
|
|
"evalue": str(exc), # Exception value, as a string
|
|
"traceback": traceback_mesg,
|
|
}
|
|
_LOGGER.debug("Executing '%s' got exception: %s", code, content)
|
|
await self.send(
|
|
shell_socket,
|
|
"execute_reply",
|
|
content,
|
|
metadata=metadata,
|
|
parent_header=msg["header"],
|
|
identities=identities,
|
|
)
|
|
del content["execution_count"], content["status"]
|
|
await self.send(self.iopub_socket, "error", content, parent_header=msg["header"])
|
|
|
|
content = {
|
|
"execution_state": "idle",
|
|
}
|
|
await self.send(self.iopub_socket, "status", content, parent_header=msg["header"])
|
|
if msg["content"].get("store_history", True):
|
|
self.execution_count += 1
|
|
return
|
|
|
|
# if True or isinstance(self.ast_ctx.ast, ast.Expr):
|
|
_LOGGER.debug("Executing: '%s' got result %s", code, result)
|
|
if result is not None:
|
|
content = {
|
|
"execution_count": self.execution_count,
|
|
"data": {"text/plain": repr(result)},
|
|
"metadata": {},
|
|
}
|
|
await self.send(
|
|
self.iopub_socket,
|
|
"execute_result",
|
|
content,
|
|
parent_header=msg["header"],
|
|
)
|
|
|
|
metadata = {
|
|
"dependencies_met": True,
|
|
"engine": self.engine_id,
|
|
"status": "ok",
|
|
"started": datetime.datetime.now().isoformat(),
|
|
}
|
|
content = {
|
|
"status": "ok",
|
|
"execution_count": self.execution_count,
|
|
"user_variables": {},
|
|
"payload": [],
|
|
"user_expressions": {},
|
|
}
|
|
await self.send(
|
|
shell_socket,
|
|
"execute_reply",
|
|
content,
|
|
metadata=metadata,
|
|
parent_header=msg["header"],
|
|
identities=identities,
|
|
)
|
|
if msg["content"].get("store_history", True):
|
|
self.execution_count += 1
|
|
|
|
#
|
|
# Make sure stdout gets sent before set report execution_state idle on iopub,
|
|
# otherwise VSCode doesn't display stdout. We do a handshake with the
|
|
# housekeep task to ensure any queued messages get processed.
|
|
#
|
|
handshake_q = asyncio.Queue(0)
|
|
await self.housekeep_q.put(["handshake", handshake_q, 0])
|
|
await handshake_q.get()
|
|
|
|
elif msg["header"]["msg_type"] == "kernel_info_request":
|
|
content = {
|
|
"protocol_version": "5.3",
|
|
"ipython_version": [1, 1, 0, ""],
|
|
"language_version": [0, 0, 1],
|
|
"language": "python",
|
|
"implementation": "python",
|
|
"implementation_version": "3.7",
|
|
"language_info": {
|
|
"name": "python",
|
|
"version": "1.0",
|
|
"mimetype": "",
|
|
"file_extension": ".py",
|
|
"codemirror_mode": "",
|
|
"nbconvert_exporter": "",
|
|
},
|
|
"banner": "",
|
|
}
|
|
await self.send(
|
|
shell_socket,
|
|
"kernel_info_reply",
|
|
content,
|
|
parent_header=msg["header"],
|
|
identities=identities,
|
|
)
|
|
|
|
elif msg["header"]["msg_type"] == "complete_request":
|
|
root = ""
|
|
words = set()
|
|
code = msg["content"]["code"]
|
|
posn = msg["content"]["cursor_pos"]
|
|
match = self.completion_re.match(code[0:posn].lower())
|
|
if match:
|
|
root = match[1].lower()
|
|
words = State.completions(root)
|
|
words = words.union(await Function.service_completions(root))
|
|
words = words.union(await Function.func_completions(root))
|
|
words = words.union(self.ast_ctx.completions(root))
|
|
# _LOGGER.debug(f"complete_request code={code}, posn={posn}, root={root}, words={words}")
|
|
content = {
|
|
"status": "ok",
|
|
"matches": sorted(list(words)),
|
|
"cursor_start": msg["content"]["cursor_pos"] - len(root),
|
|
"cursor_end": msg["content"]["cursor_pos"],
|
|
"metadata": {},
|
|
}
|
|
await self.send(
|
|
shell_socket,
|
|
"complete_reply",
|
|
content,
|
|
parent_header=msg["header"],
|
|
identities=identities,
|
|
)
|
|
|
|
elif msg["header"]["msg_type"] == "is_complete_request":
|
|
code = msg["content"]["code"]
|
|
self.ast_ctx.parse(code)
|
|
exc = self.ast_ctx.get_exception_obj()
|
|
|
|
# determine indent of last line
|
|
indent = 0
|
|
i = code.rfind("\n")
|
|
if i >= 0:
|
|
while i + 1 < len(code) and code[i + 1] == " ":
|
|
i += 1
|
|
indent += 1
|
|
if exc is None:
|
|
if indent == 0:
|
|
content = {
|
|
# One of 'complete', 'incomplete', 'invalid', 'unknown'
|
|
"status": "complete",
|
|
# If status is 'incomplete', indent should contain the characters to use
|
|
# to indent the next line. This is only a hint: frontends may ignore it
|
|
# and use their own autoindentation rules. For other statuses, this
|
|
# field does not exist.
|
|
# "indent": str,
|
|
}
|
|
else:
|
|
content = {
|
|
"status": "incomplete",
|
|
"indent": " " * indent,
|
|
}
|
|
else:
|
|
#
|
|
# if the syntax error is right at the end, then we label it incomplete,
|
|
# otherwise it's invalid
|
|
#
|
|
if "EOF while" in str(exc) or "expected an indented block" in str(exc):
|
|
# if error is at ":" then increase indent
|
|
if hasattr(exc, "lineno"):
|
|
line = code.split("\n")[exc.lineno - 1]
|
|
if self.colon_end_re.match(line):
|
|
indent += 4
|
|
content = {
|
|
"status": "incomplete",
|
|
"indent": " " * indent,
|
|
}
|
|
else:
|
|
content = {
|
|
"status": "invalid",
|
|
}
|
|
# _LOGGER.debug(f"is_complete_request code={code}, exc={exc}, content={content}")
|
|
await self.send(
|
|
shell_socket,
|
|
"is_complete_reply",
|
|
content,
|
|
parent_header=msg["header"],
|
|
identities=identities,
|
|
)
|
|
|
|
elif msg["header"]["msg_type"] == "comm_info_request":
|
|
content = {"comms": {}}
|
|
await self.send(
|
|
shell_socket,
|
|
"comm_info_reply",
|
|
content,
|
|
parent_header=msg["header"],
|
|
identities=identities,
|
|
)
|
|
|
|
elif msg["header"]["msg_type"] == "history_request":
|
|
content = {"history": []}
|
|
await self.send(
|
|
shell_socket,
|
|
"history_reply",
|
|
content,
|
|
parent_header=msg["header"],
|
|
identities=identities,
|
|
)
|
|
|
|
elif msg["header"]["msg_type"] in {"comm_open", "comm_msg", "comm_close"}:
|
|
# _LOGGER.debug(f"ignore {msg['header']['msg_type']} message ")
|
|
...
|
|
else:
|
|
_LOGGER.error("unknown msg_type: %s", msg["header"]["msg_type"])
|
|
|
|
content = {
|
|
"execution_state": "idle",
|
|
}
|
|
await self.send(self.iopub_socket, "status", content, parent_header=msg["header"])
|
|
|
|
async def control_listen(self, reader, writer):
|
|
"""Task that listens to control messages."""
|
|
try:
|
|
_LOGGER.debug("control_listen connected")
|
|
await self.housekeep_q.put(["register", "control", asyncio.current_task()])
|
|
control_socket = ZmqSocket(reader, writer, "ROUTER")
|
|
await control_socket.handshake()
|
|
while 1:
|
|
wire_msg = await control_socket.recv_multipart()
|
|
identities, msg = self.deserialize_wire_msg(wire_msg)
|
|
# _LOGGER.debug("control received %s: %s", msg.get('header', {}).get('msg_type', 'UNKNOWN'), msg)
|
|
if msg["header"]["msg_type"] == "shutdown_request":
|
|
content = {
|
|
"restart": False,
|
|
}
|
|
await self.send(
|
|
control_socket,
|
|
"shutdown_reply",
|
|
content,
|
|
parent_header=msg["header"],
|
|
identities=identities,
|
|
)
|
|
await self.housekeep_q.put(["shutdown"])
|
|
except asyncio.CancelledError:
|
|
raise
|
|
except (EOFError, ConnectionResetError):
|
|
_LOGGER.debug("control_listen got eof")
|
|
await self.housekeep_q.put(["unregister", "control", asyncio.current_task()])
|
|
control_socket.close()
|
|
except Exception as err:
|
|
_LOGGER.error("control_listen exception %s", err)
|
|
await self.housekeep_q.put(["shutdown"])
|
|
|
|
async def stdin_listen(self, reader, writer):
|
|
"""Task that listens to stdin messages."""
|
|
try:
|
|
_LOGGER.debug("stdin_listen connected")
|
|
await self.housekeep_q.put(["register", "stdin", asyncio.current_task()])
|
|
stdin_socket = ZmqSocket(reader, writer, "ROUTER")
|
|
await stdin_socket.handshake()
|
|
while 1:
|
|
_ = await stdin_socket.recv_multipart()
|
|
# _LOGGER.debug("stdin_listen received %s", _)
|
|
except asyncio.CancelledError:
|
|
raise
|
|
except (EOFError, ConnectionResetError):
|
|
_LOGGER.debug("stdin_listen got eof")
|
|
await self.housekeep_q.put(["unregister", "stdin", asyncio.current_task()])
|
|
stdin_socket.close()
|
|
except Exception:
|
|
_LOGGER.error("stdin_listen exception %s", traceback.format_exc(-1))
|
|
await self.housekeep_q.put(["shutdown"])
|
|
|
|
async def shell_listen(self, reader, writer):
|
|
"""Task that listens to shell messages."""
|
|
try:
|
|
_LOGGER.debug("shell_listen connected")
|
|
await self.housekeep_q.put(["register", "shell", asyncio.current_task()])
|
|
shell_socket = ZmqSocket(reader, writer, "ROUTER")
|
|
await shell_socket.handshake()
|
|
while 1:
|
|
msg = await shell_socket.recv_multipart()
|
|
await self.shell_handler(shell_socket, msg)
|
|
except asyncio.CancelledError:
|
|
shell_socket.close()
|
|
raise
|
|
except (EOFError, ConnectionResetError):
|
|
_LOGGER.debug("shell_listen got eof")
|
|
await self.housekeep_q.put(["unregister", "shell", asyncio.current_task()])
|
|
shell_socket.close()
|
|
except Exception:
|
|
_LOGGER.error("shell_listen exception %s", traceback.format_exc(-1))
|
|
await self.housekeep_q.put(["shutdown"])
|
|
|
|
async def heartbeat_listen(self, reader, writer):
|
|
"""Task that listens and responds to heart beat messages."""
|
|
try:
|
|
_LOGGER.debug("heartbeat_listen connected")
|
|
await self.housekeep_q.put(["register", "heartbeat", asyncio.current_task()])
|
|
heartbeat_socket = ZmqSocket(reader, writer, "REP")
|
|
await heartbeat_socket.handshake()
|
|
while 1:
|
|
msg = await heartbeat_socket.recv()
|
|
# _LOGGER.debug("heartbeat_listen: got %s", msg)
|
|
await heartbeat_socket.send(msg)
|
|
except asyncio.CancelledError:
|
|
raise
|
|
except (EOFError, ConnectionResetError):
|
|
_LOGGER.debug("heartbeat_listen got eof")
|
|
await self.housekeep_q.put(["unregister", "heartbeat", asyncio.current_task()])
|
|
heartbeat_socket.close()
|
|
except Exception:
|
|
_LOGGER.error("heartbeat_listen exception: %s", traceback.format_exc(-1))
|
|
await self.housekeep_q.put(["shutdown"])
|
|
|
|
async def iopub_listen(self, reader, writer):
|
|
"""Task that listens to iopub messages."""
|
|
try:
|
|
_LOGGER.debug("iopub_listen connected")
|
|
await self.housekeep_q.put(["register", "iopub", asyncio.current_task()])
|
|
iopub_socket = ZmqSocket(reader, writer, "PUB")
|
|
await iopub_socket.handshake()
|
|
self.iopub_socket.add(iopub_socket)
|
|
while 1:
|
|
_ = await iopub_socket.recv_multipart()
|
|
# _LOGGER.debug("iopub received %s", _)
|
|
except asyncio.CancelledError:
|
|
raise
|
|
except (EOFError, ConnectionResetError):
|
|
await self.housekeep_q.put(["unregister", "iopub", asyncio.current_task()])
|
|
iopub_socket.close()
|
|
self.iopub_socket.discard(iopub_socket)
|
|
_LOGGER.debug("iopub_listen got eof")
|
|
except Exception:
|
|
_LOGGER.error("iopub_listen exception %s", traceback.format_exc(-1))
|
|
await self.housekeep_q.put(["shutdown"])
|
|
|
|
async def housekeep_run(self):
|
|
"""Housekeeping, including closing servers after startup, and doing orderly shutdown."""
|
|
while True:
|
|
try:
|
|
msg = await self.housekeep_q.get()
|
|
if msg[0] == "stdout":
|
|
content = {"name": "stdout", "text": msg[1] + "\n"}
|
|
if self.iopub_socket:
|
|
await self.send(
|
|
self.iopub_socket,
|
|
"stream",
|
|
content,
|
|
parent_header=self.parent_header,
|
|
identities=[b"stream.stdout"],
|
|
)
|
|
elif msg[0] == "handshake":
|
|
await msg[1].put(msg[2])
|
|
elif msg[0] == "register":
|
|
if msg[1] not in self.tasks:
|
|
self.tasks[msg[1]] = set()
|
|
self.tasks[msg[1]].add(msg[2])
|
|
self.task_cnt += 1
|
|
self.task_cnt_max = max(self.task_cnt_max, self.task_cnt)
|
|
#
|
|
# now a couple of things are connected, call the session_cleanup_callback
|
|
#
|
|
if self.task_cnt > 1 and self.session_cleanup_callback:
|
|
self.session_cleanup_callback()
|
|
self.session_cleanup_callback = None
|
|
elif msg[0] == "unregister":
|
|
if msg[1] in self.tasks:
|
|
self.tasks[msg[1]].discard(msg[2])
|
|
self.task_cnt -= 1
|
|
#
|
|
# if there are no connection tasks left, then shutdown the kernel
|
|
#
|
|
if self.task_cnt == 0 and self.task_cnt_max >= 4:
|
|
asyncio.create_task(self.session_shutdown())
|
|
await asyncio.sleep(10000)
|
|
elif msg[0] == "shutdown":
|
|
asyncio.create_task(self.session_shutdown())
|
|
return
|
|
except asyncio.CancelledError:
|
|
raise
|
|
except Exception:
|
|
_LOGGER.error("housekeep task exception: %s", traceback.format_exc(-1))
|
|
|
|
async def startup_timeout(self):
|
|
"""Shut down the session if nothing connects after 30 seconds."""
|
|
await self.housekeep_q.put(["register", "startup_timeout", asyncio.current_task()])
|
|
await asyncio.sleep(self.no_connect_timeout)
|
|
if self.task_cnt_max <= 1:
|
|
#
|
|
# nothing started other than us, so shut down the session
|
|
#
|
|
_LOGGER.error("No connections to session %s; shutting down", self.global_ctx_name)
|
|
if self.session_cleanup_callback:
|
|
self.session_cleanup_callback()
|
|
self.session_cleanup_callback = None
|
|
await self.housekeep_q.put(["shutdown"])
|
|
await self.housekeep_q.put(["unregister", "startup_timeout", asyncio.current_task()])
|
|
|
|
async def start_one_server(self, callback):
|
|
"""Start a server by finding an available port."""
|
|
first_port = self.avail_port
|
|
for _ in range(2048):
|
|
try:
|
|
server = await asyncio.start_server(callback, "0.0.0.0", self.avail_port)
|
|
return server, self.avail_port
|
|
except OSError:
|
|
self.avail_port += 1
|
|
_LOGGER.error(
|
|
"unable to find an available port from %d to %d",
|
|
first_port,
|
|
self.avail_port - 1,
|
|
)
|
|
return None, None
|
|
|
|
def get_ports(self):
|
|
"""Return a dict of the port numbers this kernel session is listening to."""
|
|
return {
|
|
"iopub_port": self.iopub_port,
|
|
"hb_port": self.heartbeat_port,
|
|
"control_port": self.control_port,
|
|
"stdin_port": self.stdin_port,
|
|
"shell_port": self.shell_port,
|
|
}
|
|
|
|
def set_session_cleanup_callback(self, callback):
|
|
"""Set a cleanup callback which is called right after the session has started."""
|
|
self.session_cleanup_callback = callback
|
|
|
|
async def session_start(self):
|
|
"""Start the kernel session."""
|
|
self.ast_ctx.add_logger_handler(self.console)
|
|
_LOGGER.info("Starting session %s", self.global_ctx_name)
|
|
|
|
self.tasks["housekeep"] = {asyncio.create_task(self.housekeep_run())}
|
|
self.tasks["startup_timeout"] = {asyncio.create_task(self.startup_timeout())}
|
|
|
|
self.iopub_server, self.iopub_port = await self.start_one_server(self.iopub_listen)
|
|
self.heartbeat_server, self.heartbeat_port = await self.start_one_server(self.heartbeat_listen)
|
|
self.control_server, self.control_port = await self.start_one_server(self.control_listen)
|
|
self.stdin_server, self.stdin_port = await self.start_one_server(self.stdin_listen)
|
|
self.shell_server, self.shell_port = await self.start_one_server(self.shell_listen)
|
|
|
|
#
|
|
# For debugging, can use the real ZMQ library instead on certain sockets; comment out
|
|
# the corresponding asyncio.start_server() call above if you enable the ZMQ-based
|
|
# functions here. You can then turn of verbosity level 4 (-vvvv) in hass_pyscript_kernel.py
|
|
# to see all the byte data in case you need to debug the simple ZMQ implementation here.
|
|
# The two most important zmq functions are shown below.
|
|
#
|
|
# import zmq
|
|
# import zmq.asyncio
|
|
#
|
|
# def zmq_bind(socket, connection, port):
|
|
# """Bind a socket."""
|
|
# if port <= 0:
|
|
# return socket.bind_to_random_port(connection)
|
|
# # _LOGGER.debug(f"binding to %s:%s" % (connection, port))
|
|
# socket.bind("%s:%s" % (connection, port))
|
|
# return port
|
|
#
|
|
# zmq_ctx = zmq.asyncio.Context()
|
|
#
|
|
# ##########################################
|
|
# # Shell using real ZMQ for debugging:
|
|
# async def shell_listen_zmq():
|
|
# """Task that listens to shell messages using ZMQ."""
|
|
# try:
|
|
# _LOGGER.debug("shell_listen_zmq connected")
|
|
# connection = self.config["transport"] + "://" + self.config["ip"]
|
|
# shell_socket = zmq_ctx.socket(zmq.ROUTER)
|
|
# self.shell_port = zmq_bind(shell_socket, connection, -1)
|
|
# _LOGGER.debug("shell_listen_zmq connected")
|
|
# while 1:
|
|
# msg = await shell_socket.recv_multipart()
|
|
# await self.shell_handler(shell_socket, msg)
|
|
# except asyncio.CancelledError:
|
|
# raise
|
|
# except Exception:
|
|
# _LOGGER.error("shell_listen exception %s", traceback.format_exc(-1))
|
|
# await self.housekeep_q.put(["shutdown"])
|
|
#
|
|
# ##########################################
|
|
# # IOPub using real ZMQ for debugging:
|
|
# # IOPub/Sub:
|
|
# async def iopub_listen_zmq():
|
|
# """Task that listens to iopub messages using ZMQ."""
|
|
# try:
|
|
# _LOGGER.debug("iopub_listen_zmq connected")
|
|
# connection = self.config["transport"] + "://" + self.config["ip"]
|
|
# iopub_socket = zmq_ctx.socket(zmq.PUB)
|
|
# self.iopub_port = zmq_bind(self.iopub_socket, connection, -1)
|
|
# self.iopub_socket.add(iopub_socket)
|
|
# while 1:
|
|
# wire_msg = await iopub_socket.recv_multipart()
|
|
# _LOGGER.debug("iopub received %s", wire_msg)
|
|
# except asyncio.CancelledError:
|
|
# raise
|
|
# except EOFError:
|
|
# await self.housekeep_q.put(["shutdown"])
|
|
# _LOGGER.debug("iopub_listen got eof")
|
|
# except Exception as err:
|
|
# _LOGGER.error("iopub_listen exception %s", err)
|
|
# await self.housekeep_q.put(["shutdown"])
|
|
#
|
|
# self.tasks["shell"] = {asyncio.create_task(shell_listen_zmq())}
|
|
# self.tasks["iopub"] = {asyncio.create_task(iopub_listen_zmq())}
|
|
#
|
|
|
|
async def session_shutdown(self):
|
|
"""Shutdown the kernel session."""
|
|
if not self.iopub_server:
|
|
# already shutdown, so quit
|
|
return
|
|
GlobalContextMgr.delete(self.global_ctx_name)
|
|
self.ast_ctx.remove_logger_handler(self.console)
|
|
# logging.getLogger("homeassistant.components.pyscript.func.").removeHandler(self.console)
|
|
_LOGGER.info("Shutting down session %s", self.global_ctx_name)
|
|
|
|
for server in [
|
|
self.heartbeat_server,
|
|
self.control_server,
|
|
self.stdin_server,
|
|
self.shell_server,
|
|
self.iopub_server,
|
|
]:
|
|
if server:
|
|
server.close()
|
|
self.heartbeat_server = None
|
|
self.iopub_server = None
|
|
self.control_server = None
|
|
self.stdin_server = None
|
|
self.shell_server = None
|
|
|
|
for task_set in self.tasks.values():
|
|
for task in task_set:
|
|
try:
|
|
task.cancel()
|
|
await task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
self.tasks = []
|
|
|
|
for sock in self.iopub_socket:
|
|
try:
|
|
sock.close()
|
|
except Exception as err:
|
|
_LOGGER.error("iopub socket close exception: %s", err)
|
|
|
|
self.iopub_socket = set()
|