diff --git a/internal/expr/expr.go b/internal/expr/expr.go index a6d1f972..8fd6c9c2 100644 --- a/internal/expr/expr.go +++ b/internal/expr/expr.go @@ -12,7 +12,7 @@ func Init() { log := app.GetLogger("expr") streams.RedirectFunc("expr", func(url string) (string, error) { - v, err := expr.Run(url[5:]) + v, err := expr.Eval(url[5:], nil) if err != nil { return "", err } diff --git a/internal/wyoming/wyoming.go b/internal/wyoming/wyoming.go index aa76eab7..1849da3a 100644 --- a/internal/wyoming/wyoming.go +++ b/internal/wyoming/wyoming.go @@ -16,11 +16,12 @@ func Init() { // server var cfg struct { Mod map[string]struct { - Listen string `yaml:"listen"` - Name string `yaml:"name"` - Mode string `yaml:"mode"` - WakeURI string `yaml:"wake_uri"` - VADThreshold float32 `yaml:"vad_threshold"` + Listen string `yaml:"listen"` + Name string `yaml:"name"` + Mode string `yaml:"mode"` + Event map[string]string `yaml:"event"` + WakeURI string `yaml:"wake_uri"` + VADThreshold float32 `yaml:"vad_threshold"` } `yaml:"wyoming"` } app.LoadConfig(&cfg) @@ -40,6 +41,7 @@ func Init() { srv := &wyoming.Server{ Name: cfg.Name, + Event: cfg.Event, VADThreshold: int16(1000 * cfg.VADThreshold), // 1.0 => 1000 WakeURI: cfg.WakeURI, MicHandler: func(cons core.Consumer) error { @@ -60,6 +62,9 @@ func Init() { Trace: func(format string, v ...any) { log.Trace().Msgf("[wyoming] "+format, v...) }, + Error: func(format string, v ...any) { + log.Error().Msgf("[wyoming] "+format, v...) + }, } go serve(srv, cfg.Mode, cfg.Listen) } @@ -70,7 +75,7 @@ var log zerolog.Logger func serve(srv *wyoming.Server, mode, address string) { ln, err := net.Listen("tcp", address) if err != nil { - log.Warn().Msgf("[wyoming] listen error: %s", err) + log.Warn().Err(err).Msgf("[wyoming] listen") } for { diff --git a/pkg/expr/expr.go b/pkg/expr/expr.go index e2ed0ca6..4a8a663c 100644 --- a/pkg/expr/expr.go +++ b/pkg/expr/expr.go @@ -10,6 +10,7 @@ import ( "github.com/AlexxIT/go2rtc/pkg/tcp" "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" ) func newRequest(method, url string, headers map[string]any, body string) (*http.Request, error) { @@ -112,11 +113,19 @@ var Options = []expr.Option{ ), } -func Run(input string) (any, error) { - program, err := expr.Compile(input, Options...) +func Compile(input string) (*vm.Program, error) { + return expr.Compile(input, Options...) +} + +func Eval(input string, env any) (any, error) { + program, err := Compile(input) if err != nil { return nil, err } - return expr.Run(program, nil) + return expr.Run(program, env) +} + +func Run(program *vm.Program, env any) (any, error) { + return vm.Run(program, env) } diff --git a/pkg/expr/expr_test.go b/pkg/expr/expr_test.go index 14e75b2a..096afcdc 100644 --- a/pkg/expr/expr_test.go +++ b/pkg/expr/expr_test.go @@ -7,11 +7,11 @@ import ( ) func TestMatchHost(t *testing.T) { - v, err := Run(` + v, err := Eval(` let url = "rtsp://user:pass@192.168.1.123/cam/realmonitor?..."; let host = match(url, "//[^/]+")[0][2:]; host -`) +`, nil) require.Nil(t, err) require.Equal(t, "user:pass@192.168.1.123", v) } diff --git a/pkg/wyoming/expr.go b/pkg/wyoming/expr.go new file mode 100644 index 00000000..1b184cc3 --- /dev/null +++ b/pkg/wyoming/expr.go @@ -0,0 +1,131 @@ +package wyoming + +import ( + "fmt" + "time" + + "github.com/AlexxIT/go2rtc/pkg/expr" + "golang.org/x/net/context" +) + +type env struct { + *satellite + Type string + Data string +} + +func (s *satellite) handleEvent(evt *Event) { + switch evt.Type { + case "describe": + // {"asr": [], "tts": [], "handle": [], "intent": [], "wake": [], "satellite": {"name": "my satellite", "attribution": {"name": "", "url": ""}, "installed": true, "description": "my satellite", "version": "1.4.1", "area": null, "snd_format": null}} + data := fmt.Sprintf(`{"satellite":{"name":%q,"attribution":{"name":"go2rtc","url":"https://github.com/AlexxIT/go2rtc"},"installed":true}}`, s.srv.Name) + s.WriteEvent("info", data) + case "run-satellite": + s.Detect() + case "pause-satellite": + s.Stop() + case "detect": // WAKE_WORD_START {"names": null} + case "detection": // WAKE_WORD_END {"name": "ok_nabu_v0.1", "timestamp": 17580, "speaker": null} + case "transcribe": // STT_START {"language": "en"} + case "voice-started": // STT_VAD_START {"timestamp": 1160} + case "voice-stopped": // STT_VAD_END {"timestamp": 2470} + s.Pause() + case "transcript": // STT_END {"text": "how are you"} + case "synthesize": // TTS_START {"text": "Sorry, I couldn't understand that", "voice": {"language": "en"}} + case "audio-start": // TTS_END {"rate": 22050, "width": 2, "channels": 1, "timestamp": 0} + case "audio-chunk": // {"rate": 22050, "width": 2, "channels": 1, "timestamp": 0} + case "audio-stop": // {"timestamp": 2.880000000000002} + // run async because PlayAudio takes some time + go func() { + s.PlayAudio() + s.WriteEvent("played") + s.Detect() + }() + case "error": + s.Detect() + case "internal-run": + s.WriteEvent("run-pipeline", `{"start_stage":"wake","end_stage":"tts"}`) + s.Stream() + case "internal-detection": + s.WriteEvent("run-pipeline", `{"start_stage":"asr","end_stage":"tts"}`) + s.Stream() + } +} + +func (s *satellite) handleScript(evt *Event) { + var script string + if s.srv.Event != nil { + script = s.srv.Event[evt.Type] + } + + s.srv.Trace("event=%s data=%s payload size=%d", evt.Type, evt.Data, len(evt.Payload)) + + if script == "" { + s.handleEvent(evt) + return + } + + // run async because script can have sleeps + go func() { + e := &env{satellite: s, Type: evt.Type, Data: evt.Data} + if res, err := expr.Eval(script, e); err != nil { + s.srv.Trace("event=%s expr error=%s", evt.Type, err) + s.handleEvent(evt) + } else { + s.srv.Trace("event=%s expr result=%v", evt.Type, res) + } + }() +} + +func (s *satellite) Detect() bool { + return s.setMicState(stateWaitVAD) +} + +func (s *satellite) Stream() bool { + return s.setMicState(stateActive) +} + +func (s *satellite) Pause() bool { + return s.setMicState(stateIdle) +} + +func (s *satellite) Stop() bool { + s.micStop() + return true +} + +func (s *satellite) WriteEvent(args ...string) bool { + if len(args) == 0 { + return false + } + evt := &Event{Type: args[0]} + if len(args) > 1 { + evt.Data = args[1] + } + if err := s.api.WriteEvent(evt); err != nil { + return false + } + return true +} + +func (s *satellite) PlayAudio() bool { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + prod := newSndProducer(s.sndAudio, cancel) + if err := s.srv.SndHandler(prod); err != nil { + return false + } else { + <-ctx.Done() + return true + } +} + +func (e *env) Sleep(s string) bool { + d, err := time.ParseDuration(s) + if err != nil { + return false + } + time.Sleep(d) + return true +} diff --git a/pkg/wyoming/satellite.go b/pkg/wyoming/satellite.go index 0d1ea3e8..7bc990d0 100644 --- a/pkg/wyoming/satellite.go +++ b/pkg/wyoming/satellite.go @@ -1,7 +1,6 @@ package wyoming import ( - "errors" "fmt" "net" "sync" @@ -14,7 +13,8 @@ import ( ) type Server struct { - Name string + Name string + Event map[string]string VADThreshold int16 WakeURI string @@ -23,6 +23,7 @@ type Server struct { SndHandler func(prod core.Producer) error Trace func(format string, v ...any) + Error func(format string, v ...any) } func (s *Server) Serve(l net.Listener) error { @@ -41,66 +42,49 @@ func (s *Server) Handle(conn net.Conn) error { sat := newSatellite(api, s) defer sat.Close() - var snd []byte - for { evt, err := api.ReadEvent() if err != nil { return err } - s.Trace("event: %s data: %s payload: %d", evt.Type, evt.Data, len(evt.Payload)) - switch evt.Type { case "ping": // {"text": null} _ = api.WriteEvent(&Event{Type: "pong", Data: evt.Data}) - case "describe": - // {"asr": [], "tts": [], "handle": [], "intent": [], "wake": [], "satellite": {"name": "my satellite", "attribution": {"name": "", "url": ""}, "installed": true, "description": "my satellite", "version": "1.4.1", "area": null, "snd_format": null}} - data := fmt.Sprintf(`{"satellite":{"name":%q,"attribution":{"name":"go2rtc","url":"https://github.com/AlexxIT/go2rtc"},"installed":true}}`, s.Name) - _ = api.WriteEvent(&Event{Type: "info", Data: data}) - case "run-satellite": - if err = sat.run(); err != nil { - return err - } - case "pause-satellite": - sat.pause() - case "detect": // WAKE_WORD_START {"names": null} - case "detection": // WAKE_WORD_END {"name": "ok_nabu_v0.1", "timestamp": 17580, "speaker": null} - case "transcribe": // STT_START {"language": "en"} - case "voice-started": // STT_VAD_START {"timestamp": 1160} - case "voice-stopped": // STT_VAD_END {"timestamp": 2470} - sat.idle() - case "transcript": // STT_END {"text": "how are you"} - case "synthesize": // TTS_START {"text": "Sorry, I couldn't understand that", "voice": {"language": "en"}} case "audio-start": // TTS_END {"rate": 22050, "width": 2, "channels": 1, "timestamp": 0} - snd = snd[:0] + sat.sndAudio = sat.sndAudio[:0] case "audio-chunk": // {"rate": 22050, "width": 2, "channels": 1, "timestamp": 0} - snd = append(snd, evt.Payload...) - case "audio-stop": // {"timestamp": 2.880000000000002} - sat.respond(snd) - case "error": - sat.start() + sat.sndAudio = append(sat.sndAudio, evt.Payload...) + } + + if s.Event == nil || s.Event[evt.Type] == "" { + sat.handleEvent(evt) + } else { + // run async because there may be sleeps + go sat.handleScript(evt) } } } -// states like Home Assistant +// states like http.ConnState const ( - stateUnavailable = iota - stateIdle - stateWaitVAD // aka wait VAD - stateWaitWakeWord - stateStreaming + stateError = -2 + stateClosed = -1 + stateNew = 0 + stateIdle = 1 + stateWaitVAD = 2 // aka wait VAD + stateWaitWakeWord = 3 + stateActive = 4 ) type satellite struct { api *API srv *Server - state uint8 - mu sync.Mutex - - timestamp int + micState int8 + micTS int + micMu sync.Mutex + sndAudio []byte mic *micConsumer wake *WakeWord @@ -112,35 +96,41 @@ func newSatellite(api *API, srv *Server) *satellite { } func (s *satellite) Close() error { - s.pause() + s.Stop() return s.api.Close() } -func (s *satellite) run() error { - s.mu.Lock() - defer s.mu.Unlock() +const wakeTimeout = 5 * 2 * 16000 // 5 seconds - if s.state != stateUnavailable { - return errors.New("wyoming: wrong satellite state") +func (s *satellite) setMicState(state int8) bool { + s.micMu.Lock() + defer s.micMu.Unlock() + + if s.micState == stateNew { + s.mic = newMicConsumer(s.onMicChunk) + s.mic.RemoteAddr = s.api.conn.RemoteAddr().String() + if err := s.srv.MicHandler(s.mic); err != nil { + s.micState = stateError + s.srv.Error("can't get mic: %w", err) + _ = s.api.Close() + } else { + s.micState = stateIdle + } } - s.mic = newMicConsumer(s.onMicChunk) - s.mic.RemoteAddr = s.api.conn.RemoteAddr().String() - - if err := s.srv.MicHandler(s.mic); err != nil { - return err + if s.micState < stateIdle { + return false } - s.state = stateIdle - go s.start() - - return nil + s.micState = state + s.micTS = 0 + return true } -func (s *satellite) pause() { - s.mu.Lock() +func (s *satellite) micStop() { + s.micMu.Lock() - s.state = stateUnavailable + s.micState = stateClosed if s.mic != nil { _ = s.mic.Stop() s.mic = nil @@ -150,40 +140,18 @@ func (s *satellite) pause() { s.wake = nil } - s.mu.Unlock() + s.micMu.Unlock() } -func (s *satellite) start() { - s.mu.Lock() - - if s.state != stateUnavailable { - s.state = stateWaitVAD - } - - s.mu.Unlock() -} - -func (s *satellite) idle() { - s.mu.Lock() - - if s.state != stateUnavailable { - s.state = stateIdle - } - - s.mu.Unlock() -} - -const wakeTimeout = 5 * 2 * 16000 // 5 seconds - func (s *satellite) onMicChunk(chunk []byte) { - s.mu.Lock() - defer s.mu.Unlock() + s.micMu.Lock() + defer s.micMu.Unlock() - if s.state == stateIdle { + if s.micState == stateIdle { return } - if s.state == stateWaitVAD { + if s.micState == stateWaitVAD { // tests show that values over 1000 are most likely speech if s.srv.VADThreshold == 0 || s16le.PeaksRMS(chunk) > s.srv.VADThreshold { if s.wake == nil && s.srv.WakeURI != "" { @@ -191,62 +159,41 @@ func (s *satellite) onMicChunk(chunk []byte) { } if s.wake == nil { // some problems with wake word - redirect to HA - evt := &Event{ - Type: "run-pipeline", - Data: `{"start_stage":"wake","end_stage":"tts","restart_on_end":false}`, - } - if err := s.api.WriteEvent(evt); err != nil { - return - } - s.state = stateStreaming + s.micState = stateIdle + go s.handleScript(&Event{Type: "internal-run"}) } else { - s.state = stateWaitWakeWord + s.micState = stateWaitWakeWord } - s.timestamp = 0 + s.micTS = 0 } } - if s.state == stateWaitWakeWord { + if s.micState == stateWaitWakeWord { if s.wake.Detection != "" { // check if wake word detected - evt := &Event{ - Type: "run-pipeline", - Data: `{"start_stage":"asr","end_stage":"tts","restart_on_end":false}`, - } - _ = s.api.WriteEvent(evt) - s.state = stateStreaming - s.timestamp = 0 + s.micState = stateIdle + go s.handleScript(&Event{Type: "internal-detection", Data: `{"name":"` + s.wake.Detection + `"}`}) } else if err := s.wake.WriteChunk(chunk); err != nil { // wake word service failed - s.state = stateWaitVAD + s.micState = stateWaitVAD _ = s.wake.Close() s.wake = nil - } else if s.timestamp > wakeTimeout { + } else if s.micTS > wakeTimeout { // wake word detection timeout - s.state = stateWaitVAD + s.micState = stateWaitVAD } } else if s.wake != nil { _ = s.wake.Close() s.wake = nil } - if s.state == stateStreaming { - data := fmt.Sprintf(`{"rate":16000,"width":2,"channels":1,"timestamp":%d}`, s.timestamp) + if s.micState == stateActive { + data := fmt.Sprintf(`{"rate":16000,"width":2,"channels":1,"timestamp":%d}`, s.micTS) evt := &Event{Type: "audio-chunk", Data: data, Payload: chunk} _ = s.api.WriteEvent(evt) } - s.timestamp += len(chunk) / 2 -} - -func (s *satellite) respond(data []byte) { - prod := newSndProducer(data, func() { - _ = s.api.WriteEvent(&Event{Type: "played"}) - s.start() - }) - if err := s.srv.SndHandler(prod); err != nil { - prod.onClose() - } + s.micTS += len(chunk) / 2 } type micConsumer struct { @@ -373,7 +320,7 @@ func (s *sndProducer) Start() error { s.Recv += chunkBytes s.data = s.data[chunkBytes:] - pts += 10 * time.Millisecond + pts += 20 * time.Millisecond seq++ }