refactor dtls

This commit is contained in:
seydx
2026-01-17 23:28:03 +01:00
parent 9365fef7b3
commit d40f6064d9
2 changed files with 49 additions and 32 deletions
+4 -6
View File
@@ -226,8 +226,7 @@ func (c *DTLSConn) AVClientStart(timeout time.Duration) error {
} }
func (c *DTLSConn) AVServStart() error { func (c *DTLSConn) AVServStart() error {
adapter := NewChannelAdapter(c.ctx, iotcChannelBack, c.addr, c.WriteDTLS, c.serverBuf) conn, err := NewDTLSServer(c.ctx, iotcChannelBack, c.addr, c.WriteDTLS, c.serverBuf, c.psk)
conn, err := NewDTLSServer(adapter, c.addr, c.psk)
if err != nil { if err != nil {
return fmt.Errorf("dtls: server handshake failed: %w", err) return fmt.Errorf("dtls: server handshake failed: %w", err)
} }
@@ -564,10 +563,9 @@ func (c *DTLSConn) discoDoneCC51() error {
} }
func (c *DTLSConn) connect() error { func (c *DTLSConn) connect() error {
adapter := NewChannelAdapter(c.ctx, iotcChannelMain, c.addr, c.WriteDTLS, c.clientBuf) conn, err := NewDTLSClient(c.ctx, iotcChannelMain, c.addr, c.WriteDTLS, c.clientBuf, c.psk)
conn, err := NewDTLSClient(adapter, c.addr, c.psk)
if err != nil { if err != nil {
return fmt.Errorf("dtls: client create failed: %w", err) return fmt.Errorf("dtls: client handshake failed: %w", err)
} }
c.mu.Lock() c.mu.Lock()
@@ -575,7 +573,7 @@ func (c *DTLSConn) connect() error {
c.mu.Unlock() c.mu.Unlock()
if c.verbose { if c.verbose {
fmt.Printf("[DTLS] Client created for channel %d\n", iotcChannelMain) fmt.Printf("[DTLS] Client handshake complete on channel %d\n", iotcChannelMain)
} }
return nil return nil
+45 -26
View File
@@ -9,18 +9,47 @@ import (
"github.com/pion/dtls/v3" "github.com/pion/dtls/v3"
) )
type DTLSConfig struct { func NewDTLSClient(ctx context.Context, channel uint8, addr net.Addr, writeFn func([]byte, uint8) error, readChan chan []byte, psk []byte) (*dtls.Conn, error) {
PSK []byte return dialDTLS(ctx, channel, addr, writeFn, readChan, psk, false)
Identity string
IsServer bool
} }
func NewDTLSClient(adapter net.PacketConn, addr net.Addr, psk []byte) (*dtls.Conn, error) { func NewDTLSServer(ctx context.Context, channel uint8, addr net.Addr, writeFn func([]byte, uint8) error, readChan chan []byte, psk []byte) (*dtls.Conn, error) {
return dtls.Client(adapter, addr, buildDTLSConfig(psk, false)) return dialDTLS(ctx, channel, addr, writeFn, readChan, psk, true)
} }
func NewDTLSServer(adapter net.PacketConn, addr net.Addr, psk []byte) (*dtls.Conn, error) { func dialDTLS(ctx context.Context, channel uint8, addr net.Addr, writeFn func([]byte, uint8) error, readChan chan []byte, psk []byte, isServer bool) (*dtls.Conn, error) {
return dtls.Server(adapter, addr, buildDTLSConfig(psk, true)) adapter := &channelAdapter{
ctx: ctx,
channel: channel,
addr: addr,
writeFn: writeFn,
readChan: readChan,
}
var conn *dtls.Conn
var err error
if isServer {
conn, err = dtls.Server(adapter, addr, buildDTLSConfig(psk, true))
} else {
conn, err = dtls.Client(adapter, addr, buildDTLSConfig(psk, false))
}
if err != nil {
return nil, err
}
timeout := 5 * time.Second
adapter.SetReadDeadline(time.Now().Add(timeout))
hsCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
if err := conn.HandshakeContext(hsCtx); err != nil {
go conn.Close()
return nil, err
}
adapter.SetReadDeadline(time.Time{})
return conn, nil
} }
func buildDTLSConfig(psk []byte, isServer bool) *dtls.Config { func buildDTLSConfig(psk []byte, isServer bool) *dtls.Config {
@@ -45,7 +74,7 @@ func buildDTLSConfig(psk []byte, isServer bool) *dtls.Config {
return config return config
} }
type ChannelAdapter struct { type channelAdapter struct {
ctx context.Context ctx context.Context
channel uint8 channel uint8
writeFn func([]byte, uint8) error writeFn func([]byte, uint8) error
@@ -55,17 +84,7 @@ type ChannelAdapter struct {
readDeadline time.Time readDeadline time.Time
} }
func NewChannelAdapter(ctx context.Context, channel uint8, addr net.Addr, writeFn func([]byte, uint8) error, readChan chan []byte) *ChannelAdapter { func (a *channelAdapter) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return &ChannelAdapter{
ctx: ctx,
channel: channel,
addr: addr,
writeFn: writeFn,
readChan: readChan,
}
}
func (a *ChannelAdapter) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
a.mu.Lock() a.mu.Lock()
deadline := a.readDeadline deadline := a.readDeadline
a.mu.Unlock() a.mu.Unlock()
@@ -97,28 +116,28 @@ func (a *ChannelAdapter) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
} }
} }
func (a *ChannelAdapter) WriteTo(p []byte, _ net.Addr) (int, error) { func (a *channelAdapter) WriteTo(p []byte, _ net.Addr) (int, error) {
if err := a.writeFn(p, a.channel); err != nil { if err := a.writeFn(p, a.channel); err != nil {
return 0, err return 0, err
} }
return len(p), nil return len(p), nil
} }
func (a *ChannelAdapter) Close() error { return nil } func (a *channelAdapter) Close() error { return nil }
func (a *ChannelAdapter) LocalAddr() net.Addr { return &net.UDPAddr{} } func (a *channelAdapter) LocalAddr() net.Addr { return &net.UDPAddr{} }
func (a *ChannelAdapter) SetDeadline(t time.Time) error { func (a *channelAdapter) SetDeadline(t time.Time) error {
a.mu.Lock() a.mu.Lock()
a.readDeadline = t a.readDeadline = t
a.mu.Unlock() a.mu.Unlock()
return nil return nil
} }
func (a *ChannelAdapter) SetReadDeadline(t time.Time) error { func (a *channelAdapter) SetReadDeadline(t time.Time) error {
a.mu.Lock() a.mu.Lock()
a.readDeadline = t a.readDeadline = t
a.mu.Unlock() a.mu.Unlock()
return nil return nil
} }
func (a *ChannelAdapter) SetWriteDeadline(time.Time) error { return nil } func (a *channelAdapter) SetWriteDeadline(time.Time) error { return nil }
type timeoutError struct{} type timeoutError struct{}