diff --git a/pkg/pcm/pcm.go b/pkg/pcm/pcm.go new file mode 100644 index 00000000..717a1450 --- /dev/null +++ b/pkg/pcm/pcm.go @@ -0,0 +1,116 @@ +package pcm + +import ( + "github.com/AlexxIT/go2rtc/pkg/core" + "github.com/pion/rtp" +) + +func Resample(codec *core.Codec, sampleRate uint32, handler core.HandlerFunc) core.HandlerFunc { + n := float32(codec.ClockRate) / float32(sampleRate) + + switch codec.Name { + case core.CodecPCMA: + return DownsampleByte(PCMAtoPCM, PCMtoPCMA, n, handler) + case core.CodecPCMU: + return DownsampleByte(PCMUtoPCM, PCMtoPCMU, n, handler) + case core.CodecPCM: + if n == 1 { + return ResamplePCM(PCMtoPCMA, handler) + } + return DownsamplePCM(PCMtoPCMA, n, handler) + } + + panic(core.Caller()) +} + +func DownsampleByte( + toPCM func(byte) int16, fromPCM func(int16) byte, n float32, handler core.HandlerFunc, +) core.HandlerFunc { + var sampleN, sampleSum float32 + var ts uint32 + + return func(packet *rtp.Packet) { + samples := len(packet.Payload) + newLen := uint32((float32(samples) + sampleN) / n) + + oldSamples := packet.Payload + newSamples := make([]byte, newLen) + + var i int + for _, sample := range oldSamples { + sampleSum += float32(toPCM(sample)) + if sampleN++; sampleN >= n { + newSamples[i] = fromPCM(int16(sampleSum / n)) + i++ + + sampleSum = 0 + sampleN -= n + } + } + + ts += newLen + + clone := *packet + clone.Payload = newSamples + clone.Timestamp = ts + handler(&clone) + } +} + +func ResamplePCM(fromPCM func(int16) byte, handler core.HandlerFunc) core.HandlerFunc { + var ts uint32 + + return func(packet *rtp.Packet) { + len1 := len(packet.Payload) + len2 := len1 / 2 + + oldSamples := packet.Payload + newSamples := make([]byte, len2) + + var i2 int + for i1 := 0; i1 < len1; i1 += 2 { + sample := int16(uint16(oldSamples[i1])<<8 | uint16(oldSamples[i1+1])) + newSamples[i2] = fromPCM(sample) + i2++ + } + + ts += uint32(len2) + + clone := *packet + clone.Payload = newSamples + clone.Timestamp = ts + handler(&clone) + } +} + +func DownsamplePCM(fromPCM func(int16) byte, n float32, handler core.HandlerFunc) core.HandlerFunc { + var sampleN, sampleSum float32 + var ts uint32 + + return func(packet *rtp.Packet) { + samples := len(packet.Payload) / 2 + newLen := uint32((float32(samples) + sampleN) / n) + + oldSamples := packet.Payload + newSamples := make([]byte, newLen) + + var i2 int + for i1 := 0; i1 < len(packet.Payload); i1 += 2 { + sampleSum += float32(int16(uint16(oldSamples[i1])<<8 | uint16(oldSamples[i1+1]))) + if sampleN++; sampleN >= n { + newSamples[i2] = fromPCM(int16(sampleSum / n)) + i2++ + + sampleSum = 0 + sampleN -= n + } + } + + ts += newLen + + clone := *packet + clone.Payload = newSamples + clone.Timestamp = ts + handler(&clone) + } +} diff --git a/pkg/webrtc/consumer.go b/pkg/webrtc/consumer.go index 0a278924..b25cb7e3 100644 --- a/pkg/webrtc/consumer.go +++ b/pkg/webrtc/consumer.go @@ -5,11 +5,12 @@ import ( "github.com/AlexxIT/go2rtc/pkg/core" "github.com/AlexxIT/go2rtc/pkg/h264" "github.com/AlexxIT/go2rtc/pkg/h265" + "github.com/AlexxIT/go2rtc/pkg/pcm" "github.com/pion/rtp" ) func (c *Conn) GetMedias() []*core.Media { - return c.medias + return WithResampling(c.medias) } func (c *Conn) AddTrack(media *core.Media, codec *core.Codec, track *core.Receiver) error { @@ -31,15 +32,16 @@ func (c *Conn) AddTrack(media *core.Media, codec *core.Codec, track *core.Receiv } localTrack := c.getTranseiver(media.ID).Sender().Track().(*Track) + payloadType := codec.PayloadType - sender := core.NewSender(media, track.Codec) + sender := core.NewSender(media, codec) sender.Handler = func(packet *rtp.Packet) { c.send += packet.MarshalSize() //important to send with remote PayloadType - _ = localTrack.WriteRTP(codec.PayloadType, packet) + _ = localTrack.WriteRTP(payloadType, packet) } - switch codec.Name { + switch track.Codec.Name { case core.CodecH264: sender.Handler = h264.RTPPay(1200, sender.Handler) if track.Codec.IsRTP() { @@ -55,6 +57,15 @@ func (c *Conn) AddTrack(media *core.Media, codec *core.Codec, track *core.Receiv if track.Codec.IsRTP() { sender.Handler = h265.RTPDepay(track.Codec, sender.Handler) } + + case core.CodecPCMA, core.CodecPCMU, core.CodecPCM: + if codec.ClockRate == 0 { + if codec.Name == core.CodecPCM { + codec.Name = core.CodecPCMA + } + codec.ClockRate = 8000 + sender.Handler = pcm.Resample(track.Codec, 8000, sender.Handler) + } } sender.HandleRTP(track) diff --git a/pkg/webrtc/helpers.go b/pkg/webrtc/helpers.go index b6e36ee6..b92e72ee 100644 --- a/pkg/webrtc/helpers.go +++ b/pkg/webrtc/helpers.go @@ -52,6 +52,53 @@ func UnmarshalMedias(descriptions []*sdp.MediaDescription) (medias []*core.Media return } +func WithResampling(medias []*core.Media) []*core.Media { + for _, media := range medias { + if media.Kind != core.KindAudio || media.Direction != core.DirectionSendonly { + continue + } + + var pcma, pcmu, pcm *core.Codec + + for _, codec := range media.Codecs { + switch codec.Name { + case core.CodecPCMA: + if codec.ClockRate != 0 { + pcma = codec + } else { + pcma = nil + } + case core.CodecPCMU: + if codec.ClockRate != 0 { + pcmu = codec + } else { + pcmu = nil + } + case core.CodecPCM: + pcm = codec + } + } + + if pcma != nil { + pcma = pcma.Clone() + pcma.ClockRate = 0 // reset clock rate so will match any + media.Codecs = append(media.Codecs, pcma) + } + if pcmu != nil { + pcmu = pcmu.Clone() + pcmu.ClockRate = 0 + media.Codecs = append(media.Codecs, pcmu) + } + if pcma != nil && pcm == nil { + pcm = pcma.Clone() + pcm.Name = core.CodecPCM + media.Codecs = append(media.Codecs, pcm) + } + } + + return medias +} + func NewCandidate(network, address string) (string, error) { i := strings.LastIndexByte(address, ':') if i < 0 {