refactor dtls
This commit is contained in:
@@ -226,8 +226,7 @@ func (c *DTLSConn) AVClientStart(timeout time.Duration) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *DTLSConn) AVServStart() error {
|
func (c *DTLSConn) AVServStart() error {
|
||||||
adapter := NewChannelAdapter(c.ctx, iotcChannelBack, c.addr, c.WriteDTLS, c.serverBuf)
|
conn, err := NewDTLSServer(c.ctx, iotcChannelBack, c.addr, c.WriteDTLS, c.serverBuf, c.psk)
|
||||||
conn, err := NewDTLSServer(adapter, c.addr, c.psk)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("dtls: server handshake failed: %w", err)
|
return fmt.Errorf("dtls: server handshake failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -564,10 +563,9 @@ func (c *DTLSConn) discoDoneCC51() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *DTLSConn) connect() error {
|
func (c *DTLSConn) connect() error {
|
||||||
adapter := NewChannelAdapter(c.ctx, iotcChannelMain, c.addr, c.WriteDTLS, c.clientBuf)
|
conn, err := NewDTLSClient(c.ctx, iotcChannelMain, c.addr, c.WriteDTLS, c.clientBuf, c.psk)
|
||||||
conn, err := NewDTLSClient(adapter, c.addr, c.psk)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("dtls: client create failed: %w", err)
|
return fmt.Errorf("dtls: client handshake failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
@@ -575,7 +573,7 @@ func (c *DTLSConn) connect() error {
|
|||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
if c.verbose {
|
if c.verbose {
|
||||||
fmt.Printf("[DTLS] Client created for channel %d\n", iotcChannelMain)
|
fmt.Printf("[DTLS] Client handshake complete on channel %d\n", iotcChannelMain)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
+45
-26
@@ -9,18 +9,47 @@ import (
|
|||||||
"github.com/pion/dtls/v3"
|
"github.com/pion/dtls/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DTLSConfig struct {
|
func NewDTLSClient(ctx context.Context, channel uint8, addr net.Addr, writeFn func([]byte, uint8) error, readChan chan []byte, psk []byte) (*dtls.Conn, error) {
|
||||||
PSK []byte
|
return dialDTLS(ctx, channel, addr, writeFn, readChan, psk, false)
|
||||||
Identity string
|
|
||||||
IsServer bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDTLSClient(adapter net.PacketConn, addr net.Addr, psk []byte) (*dtls.Conn, error) {
|
func NewDTLSServer(ctx context.Context, channel uint8, addr net.Addr, writeFn func([]byte, uint8) error, readChan chan []byte, psk []byte) (*dtls.Conn, error) {
|
||||||
return dtls.Client(adapter, addr, buildDTLSConfig(psk, false))
|
return dialDTLS(ctx, channel, addr, writeFn, readChan, psk, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDTLSServer(adapter net.PacketConn, addr net.Addr, psk []byte) (*dtls.Conn, error) {
|
func dialDTLS(ctx context.Context, channel uint8, addr net.Addr, writeFn func([]byte, uint8) error, readChan chan []byte, psk []byte, isServer bool) (*dtls.Conn, error) {
|
||||||
return dtls.Server(adapter, addr, buildDTLSConfig(psk, true))
|
adapter := &channelAdapter{
|
||||||
|
ctx: ctx,
|
||||||
|
channel: channel,
|
||||||
|
addr: addr,
|
||||||
|
writeFn: writeFn,
|
||||||
|
readChan: readChan,
|
||||||
|
}
|
||||||
|
|
||||||
|
var conn *dtls.Conn
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if isServer {
|
||||||
|
conn, err = dtls.Server(adapter, addr, buildDTLSConfig(psk, true))
|
||||||
|
} else {
|
||||||
|
conn, err = dtls.Client(adapter, addr, buildDTLSConfig(psk, false))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := 5 * time.Second
|
||||||
|
adapter.SetReadDeadline(time.Now().Add(timeout))
|
||||||
|
hsCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := conn.HandshakeContext(hsCtx); err != nil {
|
||||||
|
go conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter.SetReadDeadline(time.Time{})
|
||||||
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildDTLSConfig(psk []byte, isServer bool) *dtls.Config {
|
func buildDTLSConfig(psk []byte, isServer bool) *dtls.Config {
|
||||||
@@ -45,7 +74,7 @@ func buildDTLSConfig(psk []byte, isServer bool) *dtls.Config {
|
|||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChannelAdapter struct {
|
type channelAdapter struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
channel uint8
|
channel uint8
|
||||||
writeFn func([]byte, uint8) error
|
writeFn func([]byte, uint8) error
|
||||||
@@ -55,17 +84,7 @@ type ChannelAdapter struct {
|
|||||||
readDeadline time.Time
|
readDeadline time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChannelAdapter(ctx context.Context, channel uint8, addr net.Addr, writeFn func([]byte, uint8) error, readChan chan []byte) *ChannelAdapter {
|
func (a *channelAdapter) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||||
return &ChannelAdapter{
|
|
||||||
ctx: ctx,
|
|
||||||
channel: channel,
|
|
||||||
addr: addr,
|
|
||||||
writeFn: writeFn,
|
|
||||||
readChan: readChan,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *ChannelAdapter) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
deadline := a.readDeadline
|
deadline := a.readDeadline
|
||||||
a.mu.Unlock()
|
a.mu.Unlock()
|
||||||
@@ -97,28 +116,28 @@ func (a *ChannelAdapter) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *ChannelAdapter) WriteTo(p []byte, _ net.Addr) (int, error) {
|
func (a *channelAdapter) WriteTo(p []byte, _ net.Addr) (int, error) {
|
||||||
if err := a.writeFn(p, a.channel); err != nil {
|
if err := a.writeFn(p, a.channel); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *ChannelAdapter) Close() error { return nil }
|
func (a *channelAdapter) Close() error { return nil }
|
||||||
func (a *ChannelAdapter) LocalAddr() net.Addr { return &net.UDPAddr{} }
|
func (a *channelAdapter) LocalAddr() net.Addr { return &net.UDPAddr{} }
|
||||||
func (a *ChannelAdapter) SetDeadline(t time.Time) error {
|
func (a *channelAdapter) SetDeadline(t time.Time) error {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
a.readDeadline = t
|
a.readDeadline = t
|
||||||
a.mu.Unlock()
|
a.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (a *ChannelAdapter) SetReadDeadline(t time.Time) error {
|
func (a *channelAdapter) SetReadDeadline(t time.Time) error {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
a.readDeadline = t
|
a.readDeadline = t
|
||||||
a.mu.Unlock()
|
a.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (a *ChannelAdapter) SetWriteDeadline(time.Time) error { return nil }
|
func (a *channelAdapter) SetWriteDeadline(time.Time) error { return nil }
|
||||||
|
|
||||||
type timeoutError struct{}
|
type timeoutError struct{}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user