Rewrite mDNS processing

This commit is contained in:
Alexey Khit
2023-07-11 00:44:27 +03:00
parent 4ea5a22eda
commit 73bf96e123
8 changed files with 194 additions and 141 deletions
+22 -17
View File
@@ -8,7 +8,7 @@ import (
"errors"
"fmt"
"github.com/AlexxIT/go2rtc/pkg/core"
"github.com/AlexxIT/go2rtc/pkg/hap/mdns"
"github.com/AlexxIT/go2rtc/pkg/mdns"
"github.com/brutella/hap"
"github.com/brutella/hap/chacha20poly1305"
"github.com/brutella/hap/curve25519"
@@ -61,28 +61,29 @@ func NewConn(rawURL string) (*Conn, error) {
}
func Pair(deviceID, pin string) (*Conn, error) {
entry := mdns.GetEntry(deviceID)
if entry == nil {
var addr string
var mfi bool
_ = mdns.Discovery(mdns.ServiceHAP, func(entry *mdns.ServiceEntry) bool {
if entry.Complete() && entry.Info["id"] == deviceID {
addr = entry.Addr()
mfi = entry.Info["ff"] == "1"
return true
}
return false
})
if addr == "" {
return nil, errors.New("can't find device via mDNS")
}
c := &Conn{
DeviceAddress: fmt.Sprintf("%s:%d", entry.AddrV4.String(), entry.Port),
DeviceAddress: addr,
DeviceID: deviceID,
ClientID: GenerateUUID(),
ClientPrivate: GenerateKey(),
}
var mfi bool
for _, field := range entry.InfoFields {
if field[:2] == "ff" {
if field[3] == '1' {
mfi = true
}
break
}
}
return c, c.Pair(mfi, pin)
}
@@ -106,9 +107,13 @@ func (c *Conn) DialAndServe() error {
func (c *Conn) Dial() error {
// update device host before dial
if host := mdns.GetAddress(c.DeviceID); host != "" {
c.DeviceAddress = host
}
_ = mdns.Discovery(mdns.ServiceHAP, func(entry *mdns.ServiceEntry) bool {
if entry.Complete() && entry.Info["id"] == c.DeviceID {
c.DeviceAddress = entry.Addr()
return true
}
return false
})
var err error
c.conn, err = net.DialTimeout("tcp", c.DeviceAddress, time.Second*5)
-42
View File
@@ -1,42 +0,0 @@
package mdns
import (
"fmt"
"github.com/hashicorp/mdns"
"strings"
)
const Suffix = "._hap._tcp.local."
func GetAll() chan *mdns.ServiceEntry {
entries := make(chan *mdns.ServiceEntry)
params := &mdns.QueryParam{
Service: "_hap._tcp", Entries: entries, DisableIPv6: true,
}
go func() {
_ = mdns.Query(params)
close(entries)
}()
return entries
}
func GetAddress(deviceID string) string {
for entry := range GetAll() {
if strings.Contains(entry.Info, deviceID) {
return fmt.Sprintf("%s:%d", entry.AddrV4.String(), entry.Port)
}
}
return ""
}
func GetEntry(deviceID string) *mdns.ServiceEntry {
for entry := range GetAll() {
if strings.Contains(entry.Info, deviceID) {
return entry
}
}
return nil
}
-53
View File
@@ -1,53 +0,0 @@
package mdns
import (
"github.com/hashicorp/mdns"
"net"
)
const HostHeaderTail = "._hap._tcp.local"
func NewServer(name string, port int, ips []net.IP, txt []string) (*mdns.Server, error) {
if ips == nil || ips[0] == nil {
ips = LocalIPs()
}
// important to set hostName manually with any value and `.local.` tail
// important to set ips manually
service, _ := mdns.NewMDNSService(
name, "_hap._tcp", "", name+".local.", port, ips, txt,
)
return mdns.NewServer(&mdns.Config{Zone: service})
}
func LocalIPs() []net.IP {
ifaces, err := net.Interfaces()
if err != nil {
return nil
}
var ips []net.IP
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if iface.Flags&net.FlagLoopback != 0 {
continue // loopback interface
}
var addrs []net.Addr
if addrs, err = iface.Addrs(); err != nil {
continue
}
for _, addr := range addrs {
switch addr := addr.(type) {
case *net.IPNet:
ips = append(ips, addr.IP)
case *net.IPAddr:
ips = append(ips, addr.IP)
}
}
}
return ips
}
+139
View File
@@ -0,0 +1,139 @@
package mdns
import (
"fmt"
"github.com/miekg/dns"
"net"
"strings"
"time"
)
const ServiceHAP = "_hap._tcp.local." // HomeKit Accessory Protocol
const requestTimeout = time.Millisecond * 505
const responseTimeout = time.Second * 2
type ServiceEntry struct {
Name string
IP net.IP
Port uint16
Info map[string]string
}
func (e *ServiceEntry) Complete() bool {
return e.IP != nil && e.Port > 0 && e.Info != nil
}
func (e *ServiceEntry) Addr() string {
return fmt.Sprintf("%s:%d", e.IP, e.Port)
}
func Discovery(service string, onentry func(*ServiceEntry) bool) error {
addr := &net.UDPAddr{
IP: net.IP{224, 0, 0, 251},
Port: 5353,
}
conn, err := net.ListenMulticastUDP("udp4", nil, addr)
if err != nil {
return err
}
defer conn.Close()
if err = conn.SetDeadline(time.Now().Add(responseTimeout)); err != nil {
return err
}
msg := &dns.Msg{
Question: []dns.Question{
{service, dns.TypePTR, dns.ClassINET},
},
}
b1, err := msg.Pack()
if err != nil {
return err
}
go func() {
for {
if _, err := conn.WriteToUDP(b1, addr); err != nil {
return
}
time.Sleep(requestTimeout)
}
}()
var skipIPs []net.IP
b2 := make([]byte, 1500)
loop:
for {
n, addr, err := conn.ReadFromUDP(b2)
if err != nil {
break
}
for _, ip := range skipIPs {
if ip.Equal(addr.IP) {
continue loop
}
}
if err = msg.Unpack(b2[:n]); err != nil {
continue
}
if !EqualService(msg, service) {
continue
}
if entry := NewServiceEntry(msg); onentry(entry) {
break
}
skipIPs = append(skipIPs, addr.IP)
}
return nil
}
func EqualService(msg *dns.Msg, service string) bool {
for _, rr := range msg.Answer {
if rr, ok := rr.(*dns.PTR); ok {
return strings.HasSuffix(rr.Ptr, service)
}
}
return false
}
func NewServiceEntry(msg *dns.Msg) *ServiceEntry {
entry := &ServiceEntry{}
records := make([]dns.RR, 0, len(msg.Answer)+len(msg.Ns)+len(msg.Extra))
records = append(records, msg.Answer...)
records = append(records, msg.Ns...)
records = append(records, msg.Extra...)
for _, record := range records {
switch record := record.(type) {
case *dns.PTR:
if i := strings.IndexByte(record.Ptr, '.'); i > 0 {
entry.Name = record.Ptr[:i]
}
case *dns.A:
entry.IP = record.A
case *dns.SRV:
entry.Port = record.Port
case *dns.TXT:
entry.Info = make(map[string]string, len(record.Txt))
for _, txt := range record.Txt {
k, v, _ := strings.Cut(txt, "=")
entry.Info[k] = v
}
}
}
return entry
}
+16
View File
@@ -0,0 +1,16 @@
package mdns
import (
"github.com/stretchr/testify/require"
"testing"
)
func TestDiscovery(t *testing.T) {
onentry := func(entry *ServiceEntry) bool {
return true
}
err := Discovery(ServiceHAP, onentry)
//err := Discovery("_ewelink._tcp.local.", time.Second, onentry)
// err := Discovery("_googlecast._tcp.local.", time.Second, onentry)
require.Nil(t, err)
}