Rewrite mDNS processing
This commit is contained in:
+22
-17
@@ -8,7 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/AlexxIT/go2rtc/pkg/core"
|
||||
"github.com/AlexxIT/go2rtc/pkg/hap/mdns"
|
||||
"github.com/AlexxIT/go2rtc/pkg/mdns"
|
||||
"github.com/brutella/hap"
|
||||
"github.com/brutella/hap/chacha20poly1305"
|
||||
"github.com/brutella/hap/curve25519"
|
||||
@@ -61,28 +61,29 @@ func NewConn(rawURL string) (*Conn, error) {
|
||||
}
|
||||
|
||||
func Pair(deviceID, pin string) (*Conn, error) {
|
||||
entry := mdns.GetEntry(deviceID)
|
||||
if entry == nil {
|
||||
var addr string
|
||||
var mfi bool
|
||||
|
||||
_ = mdns.Discovery(mdns.ServiceHAP, func(entry *mdns.ServiceEntry) bool {
|
||||
if entry.Complete() && entry.Info["id"] == deviceID {
|
||||
addr = entry.Addr()
|
||||
mfi = entry.Info["ff"] == "1"
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
if addr == "" {
|
||||
return nil, errors.New("can't find device via mDNS")
|
||||
}
|
||||
|
||||
c := &Conn{
|
||||
DeviceAddress: fmt.Sprintf("%s:%d", entry.AddrV4.String(), entry.Port),
|
||||
DeviceAddress: addr,
|
||||
DeviceID: deviceID,
|
||||
ClientID: GenerateUUID(),
|
||||
ClientPrivate: GenerateKey(),
|
||||
}
|
||||
|
||||
var mfi bool
|
||||
for _, field := range entry.InfoFields {
|
||||
if field[:2] == "ff" {
|
||||
if field[3] == '1' {
|
||||
mfi = true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return c, c.Pair(mfi, pin)
|
||||
}
|
||||
|
||||
@@ -106,9 +107,13 @@ func (c *Conn) DialAndServe() error {
|
||||
|
||||
func (c *Conn) Dial() error {
|
||||
// update device host before dial
|
||||
if host := mdns.GetAddress(c.DeviceID); host != "" {
|
||||
c.DeviceAddress = host
|
||||
}
|
||||
_ = mdns.Discovery(mdns.ServiceHAP, func(entry *mdns.ServiceEntry) bool {
|
||||
if entry.Complete() && entry.Info["id"] == c.DeviceID {
|
||||
c.DeviceAddress = entry.Addr()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
var err error
|
||||
c.conn, err = net.DialTimeout("tcp", c.DeviceAddress, time.Second*5)
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
package mdns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/hashicorp/mdns"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const Suffix = "._hap._tcp.local."
|
||||
|
||||
func GetAll() chan *mdns.ServiceEntry {
|
||||
entries := make(chan *mdns.ServiceEntry)
|
||||
params := &mdns.QueryParam{
|
||||
Service: "_hap._tcp", Entries: entries, DisableIPv6: true,
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = mdns.Query(params)
|
||||
close(entries)
|
||||
}()
|
||||
|
||||
return entries
|
||||
}
|
||||
|
||||
func GetAddress(deviceID string) string {
|
||||
for entry := range GetAll() {
|
||||
if strings.Contains(entry.Info, deviceID) {
|
||||
return fmt.Sprintf("%s:%d", entry.AddrV4.String(), entry.Port)
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func GetEntry(deviceID string) *mdns.ServiceEntry {
|
||||
for entry := range GetAll() {
|
||||
if strings.Contains(entry.Info, deviceID) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
package mdns
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/mdns"
|
||||
"net"
|
||||
)
|
||||
|
||||
const HostHeaderTail = "._hap._tcp.local"
|
||||
|
||||
func NewServer(name string, port int, ips []net.IP, txt []string) (*mdns.Server, error) {
|
||||
if ips == nil || ips[0] == nil {
|
||||
ips = LocalIPs()
|
||||
}
|
||||
|
||||
// important to set hostName manually with any value and `.local.` tail
|
||||
// important to set ips manually
|
||||
service, _ := mdns.NewMDNSService(
|
||||
name, "_hap._tcp", "", name+".local.", port, ips, txt,
|
||||
)
|
||||
|
||||
return mdns.NewServer(&mdns.Config{Zone: service})
|
||||
}
|
||||
|
||||
func LocalIPs() []net.IP {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ips []net.IP
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue // interface down
|
||||
}
|
||||
if iface.Flags&net.FlagLoopback != 0 {
|
||||
continue // loopback interface
|
||||
}
|
||||
|
||||
var addrs []net.Addr
|
||||
if addrs, err = iface.Addrs(); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
switch addr := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ips = append(ips, addr.IP)
|
||||
case *net.IPAddr:
|
||||
ips = append(ips, addr.IP)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ips
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
package mdns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/miekg/dns"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const ServiceHAP = "_hap._tcp.local." // HomeKit Accessory Protocol
|
||||
|
||||
const requestTimeout = time.Millisecond * 505
|
||||
const responseTimeout = time.Second * 2
|
||||
|
||||
type ServiceEntry struct {
|
||||
Name string
|
||||
IP net.IP
|
||||
Port uint16
|
||||
Info map[string]string
|
||||
}
|
||||
|
||||
func (e *ServiceEntry) Complete() bool {
|
||||
return e.IP != nil && e.Port > 0 && e.Info != nil
|
||||
}
|
||||
|
||||
func (e *ServiceEntry) Addr() string {
|
||||
return fmt.Sprintf("%s:%d", e.IP, e.Port)
|
||||
}
|
||||
|
||||
func Discovery(service string, onentry func(*ServiceEntry) bool) error {
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.IP{224, 0, 0, 251},
|
||||
Port: 5353,
|
||||
}
|
||||
|
||||
conn, err := net.ListenMulticastUDP("udp4", nil, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
if err = conn.SetDeadline(time.Now().Add(responseTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg := &dns.Msg{
|
||||
Question: []dns.Question{
|
||||
{service, dns.TypePTR, dns.ClassINET},
|
||||
},
|
||||
}
|
||||
|
||||
b1, err := msg.Pack()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
if _, err := conn.WriteToUDP(b1, addr); err != nil {
|
||||
return
|
||||
}
|
||||
time.Sleep(requestTimeout)
|
||||
}
|
||||
}()
|
||||
|
||||
var skipIPs []net.IP
|
||||
|
||||
b2 := make([]byte, 1500)
|
||||
loop:
|
||||
for {
|
||||
n, addr, err := conn.ReadFromUDP(b2)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
for _, ip := range skipIPs {
|
||||
if ip.Equal(addr.IP) {
|
||||
continue loop
|
||||
}
|
||||
}
|
||||
|
||||
if err = msg.Unpack(b2[:n]); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if !EqualService(msg, service) {
|
||||
continue
|
||||
}
|
||||
|
||||
if entry := NewServiceEntry(msg); onentry(entry) {
|
||||
break
|
||||
}
|
||||
|
||||
skipIPs = append(skipIPs, addr.IP)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func EqualService(msg *dns.Msg, service string) bool {
|
||||
for _, rr := range msg.Answer {
|
||||
if rr, ok := rr.(*dns.PTR); ok {
|
||||
return strings.HasSuffix(rr.Ptr, service)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func NewServiceEntry(msg *dns.Msg) *ServiceEntry {
|
||||
entry := &ServiceEntry{}
|
||||
|
||||
records := make([]dns.RR, 0, len(msg.Answer)+len(msg.Ns)+len(msg.Extra))
|
||||
records = append(records, msg.Answer...)
|
||||
records = append(records, msg.Ns...)
|
||||
records = append(records, msg.Extra...)
|
||||
for _, record := range records {
|
||||
switch record := record.(type) {
|
||||
case *dns.PTR:
|
||||
if i := strings.IndexByte(record.Ptr, '.'); i > 0 {
|
||||
entry.Name = record.Ptr[:i]
|
||||
}
|
||||
case *dns.A:
|
||||
entry.IP = record.A
|
||||
case *dns.SRV:
|
||||
entry.Port = record.Port
|
||||
case *dns.TXT:
|
||||
entry.Info = make(map[string]string, len(record.Txt))
|
||||
for _, txt := range record.Txt {
|
||||
k, v, _ := strings.Cut(txt, "=")
|
||||
entry.Info[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return entry
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package mdns
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDiscovery(t *testing.T) {
|
||||
onentry := func(entry *ServiceEntry) bool {
|
||||
return true
|
||||
}
|
||||
err := Discovery(ServiceHAP, onentry)
|
||||
//err := Discovery("_ewelink._tcp.local.", time.Second, onentry)
|
||||
// err := Discovery("_googlecast._tcp.local.", time.Second, onentry)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
Reference in New Issue
Block a user