diff --git a/pkg/tutk/conn_dtls.go b/pkg/tutk/conn_dtls.go index 61e716ea..bdeb4dbd 100644 --- a/pkg/tutk/conn_dtls.go +++ b/pkg/tutk/conn_dtls.go @@ -226,8 +226,7 @@ func (c *DTLSConn) AVClientStart(timeout time.Duration) error { } func (c *DTLSConn) AVServStart() error { - adapter := NewChannelAdapter(c.ctx, iotcChannelBack, c.addr, c.WriteDTLS, c.serverBuf) - conn, err := NewDTLSServer(adapter, c.addr, c.psk) + conn, err := NewDTLSServer(c.ctx, iotcChannelBack, c.addr, c.WriteDTLS, c.serverBuf, c.psk) if err != nil { return fmt.Errorf("dtls: server handshake failed: %w", err) } @@ -564,10 +563,9 @@ func (c *DTLSConn) discoDoneCC51() error { } func (c *DTLSConn) connect() error { - adapter := NewChannelAdapter(c.ctx, iotcChannelMain, c.addr, c.WriteDTLS, c.clientBuf) - conn, err := NewDTLSClient(adapter, c.addr, c.psk) + conn, err := NewDTLSClient(c.ctx, iotcChannelMain, c.addr, c.WriteDTLS, c.clientBuf, c.psk) if err != nil { - return fmt.Errorf("dtls: client create failed: %w", err) + return fmt.Errorf("dtls: client handshake failed: %w", err) } c.mu.Lock() @@ -575,7 +573,7 @@ func (c *DTLSConn) connect() error { c.mu.Unlock() if c.verbose { - fmt.Printf("[DTLS] Client created for channel %d\n", iotcChannelMain) + fmt.Printf("[DTLS] Client handshake complete on channel %d\n", iotcChannelMain) } return nil diff --git a/pkg/tutk/dtls.go b/pkg/tutk/dtls.go index e807e96f..9088a664 100644 --- a/pkg/tutk/dtls.go +++ b/pkg/tutk/dtls.go @@ -9,18 +9,47 @@ import ( "github.com/pion/dtls/v3" ) -type DTLSConfig struct { - PSK []byte - Identity string - IsServer bool +func NewDTLSClient(ctx context.Context, channel uint8, addr net.Addr, writeFn func([]byte, uint8) error, readChan chan []byte, psk []byte) (*dtls.Conn, error) { + return dialDTLS(ctx, channel, addr, writeFn, readChan, psk, false) } -func NewDTLSClient(adapter net.PacketConn, addr net.Addr, psk []byte) (*dtls.Conn, error) { - return dtls.Client(adapter, addr, buildDTLSConfig(psk, false)) +func NewDTLSServer(ctx context.Context, channel uint8, addr net.Addr, writeFn func([]byte, uint8) error, readChan chan []byte, psk []byte) (*dtls.Conn, error) { + return dialDTLS(ctx, channel, addr, writeFn, readChan, psk, true) } -func NewDTLSServer(adapter net.PacketConn, addr net.Addr, psk []byte) (*dtls.Conn, error) { - return dtls.Server(adapter, addr, buildDTLSConfig(psk, true)) +func dialDTLS(ctx context.Context, channel uint8, addr net.Addr, writeFn func([]byte, uint8) error, readChan chan []byte, psk []byte, isServer bool) (*dtls.Conn, error) { + adapter := &channelAdapter{ + ctx: ctx, + channel: channel, + addr: addr, + writeFn: writeFn, + readChan: readChan, + } + + var conn *dtls.Conn + var err error + + if isServer { + conn, err = dtls.Server(adapter, addr, buildDTLSConfig(psk, true)) + } else { + conn, err = dtls.Client(adapter, addr, buildDTLSConfig(psk, false)) + } + if err != nil { + return nil, err + } + + timeout := 5 * time.Second + adapter.SetReadDeadline(time.Now().Add(timeout)) + hsCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + if err := conn.HandshakeContext(hsCtx); err != nil { + go conn.Close() + return nil, err + } + + adapter.SetReadDeadline(time.Time{}) + return conn, nil } func buildDTLSConfig(psk []byte, isServer bool) *dtls.Config { @@ -45,7 +74,7 @@ func buildDTLSConfig(psk []byte, isServer bool) *dtls.Config { return config } -type ChannelAdapter struct { +type channelAdapter struct { ctx context.Context channel uint8 writeFn func([]byte, uint8) error @@ -55,17 +84,7 @@ type ChannelAdapter struct { readDeadline time.Time } -func NewChannelAdapter(ctx context.Context, channel uint8, addr net.Addr, writeFn func([]byte, uint8) error, readChan chan []byte) *ChannelAdapter { - return &ChannelAdapter{ - ctx: ctx, - channel: channel, - addr: addr, - writeFn: writeFn, - readChan: readChan, - } -} - -func (a *ChannelAdapter) ReadFrom(p []byte) (n int, addr net.Addr, err error) { +func (a *channelAdapter) ReadFrom(p []byte) (n int, addr net.Addr, err error) { a.mu.Lock() deadline := a.readDeadline a.mu.Unlock() @@ -97,28 +116,28 @@ func (a *ChannelAdapter) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } } -func (a *ChannelAdapter) WriteTo(p []byte, _ net.Addr) (int, error) { +func (a *channelAdapter) WriteTo(p []byte, _ net.Addr) (int, error) { if err := a.writeFn(p, a.channel); err != nil { return 0, err } return len(p), nil } -func (a *ChannelAdapter) Close() error { return nil } -func (a *ChannelAdapter) LocalAddr() net.Addr { return &net.UDPAddr{} } -func (a *ChannelAdapter) SetDeadline(t time.Time) error { +func (a *channelAdapter) Close() error { return nil } +func (a *channelAdapter) LocalAddr() net.Addr { return &net.UDPAddr{} } +func (a *channelAdapter) SetDeadline(t time.Time) error { a.mu.Lock() a.readDeadline = t a.mu.Unlock() return nil } -func (a *ChannelAdapter) SetReadDeadline(t time.Time) error { +func (a *channelAdapter) SetReadDeadline(t time.Time) error { a.mu.Lock() a.readDeadline = t a.mu.Unlock() return nil } -func (a *ChannelAdapter) SetWriteDeadline(time.Time) error { return nil } +func (a *channelAdapter) SetWriteDeadline(time.Time) error { return nil } type timeoutError struct{}