Skip to content

Commit

Permalink
Fire OnTrack before reading first RTP
Browse files Browse the repository at this point in the history
Prior to this, we would wait for a single RTP packet to
figure out the codec which is not to spec.
  • Loading branch information
edaniels committed Jul 9, 2024
1 parent d9e2ce5 commit 12cb8ab
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 101 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ cover.out
examples/sfu-ws/cert.pem
examples/sfu-ws/key.pem
wasm_exec.js
*.DS_Store
16 changes: 1 addition & 15 deletions peerconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -1228,21 +1228,7 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece
return
}

go func(track *TrackRemote) {
b := make([]byte, pc.api.settingEngine.getReceiveMTU())
n, _, err := track.peek(b)
if err != nil {
pc.log.Warnf("Could not determine PayloadType for SSRC %d (%s)", track.SSRC(), err)
return
}

if err = track.checkAndUpdateTrack(b[:n]); err != nil {
pc.log.Warnf("Failed to set codec settings for track SSRC %d (%s)", track.SSRC(), err)
return
}

pc.onTrack(track, receiver)
}(t)
pc.onTrack(t, receiver)
}
}

Expand Down
47 changes: 16 additions & 31 deletions peerconnection_go_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"time"

"github.com/pion/ice/v2"
"github.com/pion/rtp"
"github.com/pion/transport/v2/test"
"github.com/pion/transport/v2/vnet"
"github.com/viamrobotics/webrtc/v3/internal/util"
Expand Down Expand Up @@ -1017,9 +1016,11 @@ func TestICERestart_Error_Handling(t *testing.T) {
}

type trackRecords struct {
mu sync.Mutex
trackIDs map[string]struct{}
receivedTrackIDs map[string]struct{}
mu sync.Mutex
trackIDs map[string]struct{}
receivedTrackIDs map[string]struct{}
onAllTracksReceived chan struct{}
onAllTracksReceivedOnce sync.Once
}

func (r *trackRecords) newTrack() (*TrackLocalStaticRTP, error) {
Expand All @@ -1036,6 +1037,11 @@ func (r *trackRecords) handleTrack(t *TrackRemote, _ *RTPReceiver) {
if _, exist := r.trackIDs[tID]; exist {
r.receivedTrackIDs[tID] = struct{}{}
}
if len(r.receivedTrackIDs) == len(r.trackIDs) {
r.onAllTracksReceivedOnce.Do(func() {
close(r.onAllTracksReceived)
})
}
}

func (r *trackRecords) remains() int {
Expand All @@ -1049,32 +1055,16 @@ func TestPeerConnection_MassiveTracks(t *testing.T) {
var (
api = NewAPI()
tRecs = &trackRecords{
trackIDs: make(map[string]struct{}),
receivedTrackIDs: make(map[string]struct{}),
trackIDs: make(map[string]struct{}),
receivedTrackIDs: make(map[string]struct{}),
onAllTracksReceived: make(chan struct{}),
}
tracks = []*TrackLocalStaticRTP{}
trackCount = 256
pingInterval = 1 * time.Second
noiseInterval = 100 * time.Microsecond
timeoutDuration = 20 * time.Second
rawPkt = []byte{
0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64,
0x27, 0x82, 0x00, 0x01, 0x00, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0x98, 0x36, 0xbe, 0x88, 0x9e,
}
samplePkt = &rtp.Packet{
Header: rtp.Header{
Marker: true,
Extension: false,
ExtensionProfile: 1,
Version: 2,
SequenceNumber: 27023,
Timestamp: 3653407706,
CSRC: []uint32{},
},
Payload: rawPkt[20:],
}
connected = make(chan struct{})
stopped = make(chan struct{})
connected = make(chan struct{})
stopped = make(chan struct{})
)
assert.NoError(t, api.mediaEngine.RegisterDefaultCodecs())
offerPC, answerPC, err := api.newPair(Configuration{})
Expand All @@ -1085,7 +1075,6 @@ func TestPeerConnection_MassiveTracks(t *testing.T) {
assert.NoError(t, err)
_, err = offerPC.AddTrack(track)
assert.NoError(t, err)
tracks = append(tracks, track)
}
answerPC.OnTrack(tRecs.handleTrack)
offerPC.OnICEConnectionStateChange(func(s ICEConnectionState) {
Expand All @@ -1107,12 +1096,8 @@ func TestPeerConnection_MassiveTracks(t *testing.T) {
}
}()
assert.NoError(t, signalPair(offerPC, answerPC))
// Send a RTP packets to each track to trigger track event after connected.
<-connected
time.Sleep(1 * time.Second)
for _, track := range tracks {
assert.NoError(t, track.WriteRTP(samplePkt))
}

// Ping trackRecords to see if any track event not received yet.
tooLong := time.After(timeoutDuration)
for {
Expand Down
52 changes: 38 additions & 14 deletions peerconnection_media_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"time"

"github.com/pion/interceptor"
mock_interceptor "github.com/pion/interceptor/pkg/mock"
"github.com/pion/logging"
"github.com/pion/randutil"
"github.com/pion/rtcp"
Expand Down Expand Up @@ -1055,10 +1056,34 @@ func TestPeerConnection_Simulcast_Probe(t *testing.T) {
panic(err)
}
registerSimulcastHeaderExtensions(m, RTPCodecTypeVideo)
assert.NoError(t, m.RegisterDefaultCodecs())
ir := &interceptor.Registry{}

trackReadDone := make(chan struct{})
ir.Add(&mock_interceptor.Factory{
NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) {
return &mock_interceptor.Interceptor{
BindRemoteStreamFn: func(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
count := int64(0)
return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) {
if a == nil {
a = interceptor.Attributes{}
}
if atomic.AddInt64(&count, 1) > 5 {
// confirm read before sending any more packets for probing
<-trackReadDone
}
return reader.Read(b, a)
})
},
}, nil
},
})
assert.NoError(t, ConfigureSimulcastExtensionHeaders(m))

pcOffer, pcAnswer, err := NewAPI(WithSettingEngine(SettingEngine{
LoggerFactory: &undeclaredSsrcLoggerFactory{unhandledSimulcastError},
}), WithMediaEngine(m)).newPair(Configuration{})
}), WithMediaEngine(m), WithInterceptorRegistry(ir)).newPair(Configuration{})
assert.NoError(t, err)

firstTrack, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "firstTrack", "firstTrack")
Expand Down Expand Up @@ -1103,26 +1128,24 @@ func TestPeerConnection_Simulcast_Probe(t *testing.T) {
time.Sleep(20 * time.Millisecond)
}

// establish undeclared SSRC (half number of probes)
for ; sequenceNumber <= 5; sequenceNumber++ {
sendRTPPacket()
}

assert.NoError(t, signalPair(pcOffer, pcAnswer))

trackRemoteChan := make(chan *TrackRemote, 1)
pcAnswer.OnTrack(func(trackRemote *TrackRemote, _ *RTPReceiver) {
trackRemoteChan <- trackRemote
})

trackRemote := func() *TrackRemote {
for {
select {
case t := <-trackRemoteChan:
return t
default:
sendRTPPacket()
}
}
assert.NoError(t, signalPair(pcOffer, pcAnswer))

trackRemote := <-trackRemoteChan

go func() {
_, _, err = trackRemote.Read(make([]byte, 1500))
assert.NoError(t, err)
close(trackReadDone)
}()

func() {
Expand All @@ -1136,8 +1159,7 @@ func TestPeerConnection_Simulcast_Probe(t *testing.T) {
}
}()

_, _, err = trackRemote.Read(make([]byte, 1500))
assert.NoError(t, err)
<-trackReadDone

closePairNow(t, pcOffer, pcAnswer)
})
Expand Down Expand Up @@ -1754,6 +1776,8 @@ func TestPeerConnection_Zero_PayloadType(t *testing.T) {
trackFired := make(chan struct{})

pcAnswer.OnTrack(func(track *TrackRemote, _ *RTPReceiver) {
_, _, err = track.Read(make([]byte, 1500))
assert.NoError(t, err)
require.Equal(t, track.Codec().MimeType, MimeTypePCMU)
close(trackFired)
})
Expand Down
6 changes: 5 additions & 1 deletion track_local_static_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ func Test_TrackLocalStatic_PayloadType(t *testing.T) {
assert.NoError(t, err)

onTrackFired, onTrackFiredFunc := context.WithCancel(context.Background())
offerer.OnTrack(func(track *TrackRemote, r *RTPReceiver) {
offerer.OnTrack(func(track *TrackRemote, _ *RTPReceiver) {
_, _, err = track.Read(make([]byte, 1500))
assert.NoError(t, err)
assert.Equal(t, track.PayloadType(), PayloadType(100))
assert.Equal(t, track.Codec().RTPCodecCapability.MimeType, "video/VP8")

Expand Down Expand Up @@ -284,6 +286,8 @@ func Test_TrackLocalStatic_Padding(t *testing.T) {
onTrackFired, onTrackFiredFunc := context.WithCancel(context.Background())

offerer.OnTrack(func(track *TrackRemote, _ *RTPReceiver) {
_, _, err = track.Read(make([]byte, 1500))
assert.NoError(t, err)
assert.Equal(t, track.PayloadType(), PayloadType(100))
assert.Equal(t, track.Codec().RTPCodecCapability.MimeType, "video/VP8")

Expand Down
41 changes: 1 addition & 40 deletions track_remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ type TrackRemote struct {
params RTPParameters
rid string

receiver *RTPReceiver
peeked []byte
peekedAttributes interceptor.Attributes
receiver *RTPReceiver
}

func newTrackRemote(kind RTPCodecType, ssrc, rtxSsrc SSRC, rid string, receiver *RTPReceiver) *TrackRemote {
Expand Down Expand Up @@ -107,26 +105,8 @@ func (t *TrackRemote) Codec() RTPCodecParameters {
func (t *TrackRemote) Read(b []byte) (n int, attributes interceptor.Attributes, err error) {
t.mu.RLock()
r := t.receiver
peeked := t.peeked != nil
t.mu.RUnlock()

if peeked {
t.mu.Lock()
data := t.peeked
attributes = t.peekedAttributes

t.peeked = nil
t.peekedAttributes = nil
t.mu.Unlock()
// someone else may have stolen our packet when we
// released the lock. Deal with it.
if data != nil {
n = copy(b, data)
err = t.checkAndUpdateTrack(b)
return
}
}

// If there's a separate RTX track and an RTX packet is available, return that
if rtxPacketReceived := r.readRTX(t); rtxPacketReceived != nil {
n = copy(b, rtxPacketReceived.pkt)
Expand Down Expand Up @@ -187,25 +167,6 @@ func (t *TrackRemote) ReadRTP() (*rtp.Packet, interceptor.Attributes, error) {
return r, attributes, nil
}

// peek is like Read, but it doesn't discard the packet read
func (t *TrackRemote) peek(b []byte) (n int, a interceptor.Attributes, err error) {
n, a, err = t.Read(b)
if err != nil {
return
}

t.mu.Lock()
// this might overwrite data if somebody peeked between the Read
// and us getting the lock. Oh well, we'll just drop a packet in
// that case.
data := make([]byte, n)
n = copy(data, b[:n])
t.peeked = data
t.peekedAttributes = a
t.mu.Unlock()
return
}

// SetReadDeadline sets the max amount of time the RTP stream will block before returning. 0 is forever.
func (t *TrackRemote) SetReadDeadline(deadline time.Time) error {
return t.receiver.setRTPReadDeadline(deadline, t)
Expand Down

0 comments on commit 12cb8ab

Please sign in to comment.