Improve mDNS handler

This commit is contained in:
Alexey Khit
2023-07-23 17:07:12 +03:00
parent e94f338b77
commit 7005cd08f2
+78 -40
View File
@@ -2,6 +2,7 @@ package mdns
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@@ -15,10 +16,18 @@ import (
const ServiceHAP = "_hap._tcp.local." // HomeKit Accessory Protocol const ServiceHAP = "_hap._tcp.local." // HomeKit Accessory Protocol
type ServiceEntry struct { type ServiceEntry struct {
Name string Name string `json:"name,omitempty"`
IP net.IP IP net.IP `json:"ip,omitempty"`
Port uint16 Port uint16 `json:"port,omitempty"`
Info map[string]string Info map[string]string `json:"info,omitempty"`
}
func (e *ServiceEntry) String() string {
b, err := json.Marshal(e)
if err != nil {
return err.Error()
}
return string(b)
} }
func (e *ServiceEntry) Complete() bool { func (e *ServiceEntry) Complete() bool {
@@ -195,7 +204,7 @@ func (b *Browser) ListenMulticastUDP() error {
func (b *Browser) Browse(onentry func(*ServiceEntry) bool) error { func (b *Browser) Browse(onentry func(*ServiceEntry) bool) error {
msg := &dns.Msg{ msg := &dns.Msg{
Question: []dns.Question{ Question: []dns.Question{
{b.Service, dns.TypePTR, dns.ClassINET}, {Name: b.Service, Qtype: dns.TypePTR, Qclass: dns.ClassINET},
}, },
} }
@@ -219,13 +228,12 @@ func (b *Browser) Browse(onentry func(*ServiceEntry) bool) error {
} }
}() }()
var skipPTR []string processed := map[string]struct{}{"": {}}
b2 := make([]byte, 1500) b2 := make([]byte, 1500)
loop:
for { for {
// in the Hass docker network can receive same msg from different address // in the Hass docker network can receive same msg from different address
n, _, err := b.Recv.ReadFrom(b2) n, addr, err := b.Recv.ReadFrom(b2)
if err != nil { if err != nil {
break break
} }
@@ -234,23 +242,21 @@ loop:
continue continue
} }
ptr := GetPTR(msg) ptr := GetPTR(msg, b.Service)
if !strings.HasSuffix(ptr, b.Service) { if _, ok := processed[ptr]; ok {
continue continue
} }
for _, s := range skipPTR { ip := addr.(*net.UDPAddr).IP
if s == ptr {
continue loop for _, entry := range NewServiceEntries(msg, ip) {
if onentry(entry) {
return nil
} }
} }
if entry := NewServiceEntry(msg); onentry(entry) { processed[ptr] = struct{}{}
break
}
skipPTR = append(skipPTR, ptr)
} }
return nil return nil
@@ -266,42 +272,74 @@ func (b *Browser) Close() error {
return nil return nil
} }
func GetPTR(msg *dns.Msg) string { func GetPTR(msg *dns.Msg, service string) string {
for _, rr := range msg.Answer { for _, record := range msg.Answer {
if rr, ok := rr.(*dns.PTR); ok { if ptr, ok := record.(*dns.PTR); ok && ptr.Hdr.Name == service {
return rr.Ptr return ptr.Ptr
} }
} }
return "" return ""
} }
func NewServiceEntry(msg *dns.Msg) *ServiceEntry { func NewServiceEntries(msg *dns.Msg, ip net.IP) (entries []*ServiceEntry) {
entry := &ServiceEntry{}
records := make([]dns.RR, 0, len(msg.Answer)+len(msg.Ns)+len(msg.Extra)) records := make([]dns.RR, 0, len(msg.Answer)+len(msg.Ns)+len(msg.Extra))
records = append(records, msg.Answer...) records = append(records, msg.Answer...)
records = append(records, msg.Ns...) records = append(records, msg.Ns...)
records = append(records, msg.Extra...) records = append(records, msg.Extra...)
// PTR ptr=SomeName._hap._tcp.local. hdr=_hap._tcp.local.
// TXT txt=... hdr=SomeName._hap._tcp.local.
// SRV target=SomeName.local. hdr=SomeName._hap._tcp.local.
// A a=192.168.1.123 hdr=SomeName.local.
for _, record := range records { for _, record := range records {
switch record := record.(type) { ptr, ok := record.(*dns.PTR)
case *dns.PTR: if !ok {
if i := strings.IndexByte(record.Ptr, '.'); i > 0 { continue
entry.Name = record.Ptr[:i] }
}
case *dns.A: entry := &ServiceEntry{}
entry.IP = record.A
case *dns.SRV: if i := strings.IndexByte(ptr.Ptr, '.'); i > 0 {
entry.Port = record.Port entry.Name = strings.ReplaceAll(ptr.Ptr[:i], `\ `, " ")
case *dns.TXT: }
entry.Info = make(map[string]string, len(record.Txt))
for _, txt := range record.Txt { var txt *dns.TXT
k, v, _ := strings.Cut(txt, "=") var srv *dns.SRV
entry.Info[k] = v var a *dns.A
for _, record = range records {
if txt, ok = record.(*dns.TXT); ok && txt.Hdr.Name == ptr.Ptr {
entry.Info = make(map[string]string, len(txt.Txt))
for _, s := range txt.Txt {
k, v, _ := strings.Cut(s, "=")
entry.Info[k] = v
}
break
} }
} }
for _, record = range records {
if srv, ok = record.(*dns.SRV); ok && srv.Hdr.Name == ptr.Ptr {
entry.Port = srv.Port
for _, record = range records {
if a, ok = record.(*dns.A); ok && a.Hdr.Name == srv.Target {
// device can send multiple IP addresses (ex. Homebridge)
// use first IP from the list or same IP from sender
if entry.IP == nil || ip.Equal(a.A) {
entry.IP = a.A
}
}
}
break
}
}
entries = append(entries, entry)
} }
return entry return
} }
func InterfacesIP4() ([]net.IP, error) { func InterfacesIP4() ([]net.IP, error) {