diff --git a/codecs/av1/obu/leb128.go b/codecs/av1/obu/leb128.go index 38ce090..f5fcbf6 100644 --- a/codecs/av1/obu/leb128.go +++ b/codecs/av1/obu/leb128.go @@ -67,3 +67,19 @@ func ReadLeb128(in []byte) (uint, uint, error) { return 0, 0, ErrFailedToReadLEB128 } + +// WriteToLeb128 writes a uint to a LEB128 encoded byte slice. +func WriteToLeb128(in uint) []byte { + b := make([]byte, 10) + + for i := 0; i < len(b); i++ { + b[i] = byte(in & 0x7f) + in >>= 7 + if in == 0 { + return b[:i+1] + } + b[i] |= 0x80 + } + + return b // unreachable +} diff --git a/codecs/av1/obu/leb128_test.go b/codecs/av1/obu/leb128_test.go index f92fff4..2b2336a 100644 --- a/codecs/av1/obu/leb128_test.go +++ b/codecs/av1/obu/leb128_test.go @@ -4,7 +4,10 @@ package obu import ( + "encoding/hex" "errors" + "fmt" + "math" "testing" ) @@ -40,3 +43,33 @@ func TestReadLeb128(t *testing.T) { t.Fatal("ReadLeb128 on a buffer with all MSB set should fail") } } + +func TestWriteToLeb128(t *testing.T) { + type testVector struct { + value uint + leb128 string + } + testVectors := []testVector{ + {150, "9601"}, + {240, "f001"}, + {400, "9003"}, + {720, "d005"}, + {1200, "b009"}, + {999999, "bf843d"}, + {0, "00"}, + {math.MaxUint32, "ffffffff0f"}, + } + + runTest := func(t *testing.T, v testVector) { + b := WriteToLeb128(v.value) + if v.leb128 != hex.EncodeToString(b) { + t.Errorf("Expected %s, got %s", v.leb128, hex.EncodeToString(b)) + } + } + + for _, v := range testVectors { + t.Run(fmt.Sprintf("encode %d", v.value), func(t *testing.T) { + runTest(t, v) + }) + } +} diff --git a/vlaextension.go b/vlaextension.go new file mode 100644 index 0000000..e10820a --- /dev/null +++ b/vlaextension.go @@ -0,0 +1,360 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtp + +import ( + "encoding/binary" + "errors" + "fmt" + "strings" + + "github.com/pion/rtp/codecs/av1/obu" +) + +var ( + ErrVLATooShort = errors.New("VLA payload too short") // ErrVLATooShort is returned when payload is too short + ErrVLAInvalidStreamCount = errors.New("invalid RTP stream count in VLA") // ErrVLAInvalidStreamCount is returned when RTP stream count is invalid + ErrVLAInvalidStreamID = errors.New("invalid RTP stream ID in VLA") // ErrVLAInvalidStreamID is returned when RTP stream ID is invalid + ErrVLAInvalidSpatialID = errors.New("invalid spatial ID in VLA") // ErrVLAInvalidSpatialID is returned when spatial ID is invalid + ErrVLADuplicateSpatialID = errors.New("duplicate spatial ID in VLA") // ErrVLADuplicateSpatialID is returned when spatial ID is invalid + ErrVLAInvalidTemporalLayer = errors.New("invalid temporal layer in VLA") // ErrVLAInvalidTemporalLayer is returned when temporal layer is invalid +) + +// SpatialLayer is a spatial layer in VLA. +type SpatialLayer struct { + RTPStreamID int + SpatialID int + TargetBitrates []int // target bitrates per temporal layer + + // Following members are valid only when HasResolutionAndFramerate is true + Width int + Height int + Framerate int +} + +// VLA is a Video Layer Allocation (VLA) extension. +// See https://webrtc.googlesource.com/src/+/refs/heads/main/docs/native-code/rtp-hdrext/video-layers-allocation00 +type VLA struct { + RTPStreamID int // 0-origin RTP stream ID (RID) this allocation is sent on (0..3) + RTPStreamCount int // Number of RTP streams (1..4) + ActiveSpatialLayer []SpatialLayer + HasResolutionAndFramerate bool +} + +type vlaMarshalingContext struct { + slMBs [4]uint8 + sls [4][4]*SpatialLayer + commonSLBM uint8 + encodedTargetBitrates [][]byte + requiredLen int +} + +func (v VLA) preprocessForMashaling(ctx *vlaMarshalingContext) error { + for i := 0; i < len(v.ActiveSpatialLayer); i++ { + sl := v.ActiveSpatialLayer[i] + if sl.RTPStreamID < 0 || sl.RTPStreamID >= v.RTPStreamCount { + return fmt.Errorf("invalid RTP streamID %d:%w", sl.RTPStreamID, ErrVLAInvalidStreamID) + } + if sl.SpatialID < 0 || sl.SpatialID >= 4 { + return fmt.Errorf("invalid spatial ID %d: %w", sl.SpatialID, ErrVLAInvalidSpatialID) + } + if len(sl.TargetBitrates) == 0 || len(sl.TargetBitrates) > 4 { + return fmt.Errorf("invalid temporal layer count %d: %w", len(sl.TargetBitrates), ErrVLAInvalidTemporalLayer) + } + ctx.slMBs[sl.RTPStreamID] |= 1 << sl.SpatialID + if ctx.sls[sl.RTPStreamID][sl.SpatialID] != nil { + return fmt.Errorf("duplicate spatial layer: %w", ErrVLADuplicateSpatialID) + } + ctx.sls[sl.RTPStreamID][sl.SpatialID] = &sl + } + return nil +} + +func (v VLA) encodeTargetBitrates(ctx *vlaMarshalingContext) { + for rtpStreamID := 0; rtpStreamID < v.RTPStreamCount; rtpStreamID++ { + for spatialID := 0; spatialID < 4; spatialID++ { + if sl := ctx.sls[rtpStreamID][spatialID]; sl != nil { + for _, kbps := range sl.TargetBitrates { + leb128 := obu.WriteToLeb128(uint(kbps)) + ctx.encodedTargetBitrates = append(ctx.encodedTargetBitrates, leb128) + ctx.requiredLen += len(leb128) + } + } + } + } +} + +func (v VLA) analyzeVLAForMarshaling() (*vlaMarshalingContext, error) { + // Validate RTPStreamCount + if v.RTPStreamCount <= 0 || v.RTPStreamCount > 4 { + return nil, ErrVLAInvalidStreamCount + } + // Validate RTPStreamID + if v.RTPStreamID < 0 || v.RTPStreamID >= v.RTPStreamCount { + return nil, ErrVLAInvalidStreamID + } + + ctx := &vlaMarshalingContext{} + err := v.preprocessForMashaling(ctx) + if err != nil { + return nil, err + } + + ctx.commonSLBM = commonSLBMValues(ctx.slMBs[:]) + + // RID, NS, sl_bm fields + if ctx.commonSLBM != 0 { + ctx.requiredLen = 1 + } else { + ctx.requiredLen = 3 + } + + // #tl fields + ctx.requiredLen += (len(v.ActiveSpatialLayer)-1)/4 + 1 + + v.encodeTargetBitrates(ctx) + + if v.HasResolutionAndFramerate { + ctx.requiredLen += len(v.ActiveSpatialLayer) * 5 + } + + return ctx, nil +} + +// Marshal encodes VLA into a byte slice. +func (v VLA) Marshal() ([]byte, error) { + ctx, err := v.analyzeVLAForMarshaling() + if err != nil { + return nil, err + } + + payload := make([]byte, ctx.requiredLen) + offset := 0 + + // RID, NS, sl_bm fields + payload[offset] = byte(v.RTPStreamID<<6) | byte(v.RTPStreamCount-1)<<4 | ctx.commonSLBM + + if ctx.commonSLBM == 0 { + offset++ + for streamID := 0; streamID < v.RTPStreamCount; streamID++ { + if streamID%2 == 0 { + payload[offset+streamID/2] |= ctx.slMBs[streamID] << 4 + } else { + payload[offset+streamID/2] |= ctx.slMBs[streamID] + } + } + offset += (v.RTPStreamCount - 1) / 2 + } + + // #tl fields + offset++ + var temporalLayerIndex int + for rtpStreamID := 0; rtpStreamID < v.RTPStreamCount; rtpStreamID++ { + for spatialID := 0; spatialID < 4; spatialID++ { + if sl := ctx.sls[rtpStreamID][spatialID]; sl != nil { + if temporalLayerIndex >= 4 { + temporalLayerIndex = 0 + offset++ + } + payload[offset] |= byte(len(sl.TargetBitrates)-1) << (2 * (3 - temporalLayerIndex)) + temporalLayerIndex++ + } + } + } + + // Target bitrate fields + offset++ + for _, encodedKbps := range ctx.encodedTargetBitrates { + encodedSize := len(encodedKbps) + copy(payload[offset:], encodedKbps) + offset += encodedSize + } + + // Resolution & framerate fields + if v.HasResolutionAndFramerate { + for _, sl := range v.ActiveSpatialLayer { + binary.BigEndian.PutUint16(payload[offset+0:], uint16(sl.Width-1)) + binary.BigEndian.PutUint16(payload[offset+2:], uint16(sl.Height-1)) + payload[offset+4] = byte(sl.Framerate) + offset += 5 + } + } + + return payload, nil +} + +func commonSLBMValues(slMBs []uint8) uint8 { + var common uint8 + for i := 0; i < len(slMBs); i++ { + if slMBs[i] == 0 { + continue + } + if common == 0 { + common = slMBs[i] + continue + } + if slMBs[i] != common { + return 0 + } + } + return common +} + +type vlaUnmarshalingContext struct { + payload []byte + offset int + slBMField uint8 + slBMs [4]uint8 +} + +func (ctx *vlaUnmarshalingContext) checkRemainingLen(requiredLen int) bool { + return len(ctx.payload)-ctx.offset >= requiredLen +} + +func (v *VLA) unmarshalSpatialLayers(ctx *vlaUnmarshalingContext) error { + if !ctx.checkRemainingLen(1) { + return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort) + } + v.RTPStreamID = int(ctx.payload[ctx.offset] >> 6 & 0b11) + v.RTPStreamCount = int(ctx.payload[ctx.offset]>>4&0b11) + 1 + + // sl_bm fields + ctx.slBMField = ctx.payload[ctx.offset] & 0b1111 + ctx.offset++ + + if ctx.slBMField != 0 { + for streamID := 0; streamID < v.RTPStreamCount; streamID++ { + ctx.slBMs[streamID] = ctx.slBMField + } + } else { + if !ctx.checkRemainingLen((v.RTPStreamCount-1)/2 + 1) { + return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort) + } + // slX_bm fields + for streamID := 0; streamID < v.RTPStreamCount; streamID++ { + var bm uint8 + if streamID%2 == 0 { + bm = ctx.payload[ctx.offset+streamID/2] >> 4 & 0b1111 + } else { + bm = ctx.payload[ctx.offset+streamID/2] & 0b1111 + } + ctx.slBMs[streamID] = bm + } + ctx.offset += 1 + (v.RTPStreamCount-1)/2 + } + + return nil +} + +func (v *VLA) unmarshalTemporalLayers(ctx *vlaUnmarshalingContext) error { + if !ctx.checkRemainingLen(1) { + return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort) + } + + var temporalLayerIndex int + for streamID := 0; streamID < v.RTPStreamCount; streamID++ { + for spatialID := 0; spatialID < 4; spatialID++ { + if ctx.slBMs[streamID]&(1<= 4 { + temporalLayerIndex = 0 + ctx.offset++ + if !ctx.checkRemainingLen(1) { + return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort) + } + } + tlCount := int(ctx.payload[ctx.offset]>>(2*(3-temporalLayerIndex))&0b11) + 1 + temporalLayerIndex++ + sl := SpatialLayer{ + RTPStreamID: streamID, + SpatialID: spatialID, + TargetBitrates: make([]int, tlCount), + } + v.ActiveSpatialLayer = append(v.ActiveSpatialLayer, sl) + } + } + ctx.offset++ + + // target bitrates + for i, sl := range v.ActiveSpatialLayer { + for j := range sl.TargetBitrates { + kbps, n, err := obu.ReadLeb128(ctx.payload[ctx.offset:]) + if err != nil { + return err + } + if !ctx.checkRemainingLen(int(n)) { + return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort) + } + v.ActiveSpatialLayer[i].TargetBitrates[j] = int(kbps) + ctx.offset += int(n) + } + } + + return nil +} + +func (v *VLA) unmarshalResolutionAndFramerate(ctx *vlaUnmarshalingContext) error { + if !ctx.checkRemainingLen(len(v.ActiveSpatialLayer) * 5) { + return fmt.Errorf("failed to unmarshal VLA (offset=%d): %w", ctx.offset, ErrVLATooShort) + } + + v.HasResolutionAndFramerate = true + + for i := range v.ActiveSpatialLayer { + v.ActiveSpatialLayer[i].Width = int(binary.BigEndian.Uint16(ctx.payload[ctx.offset+0:])) + 1 + v.ActiveSpatialLayer[i].Height = int(binary.BigEndian.Uint16(ctx.payload[ctx.offset+2:])) + 1 + v.ActiveSpatialLayer[i].Framerate = int(ctx.payload[ctx.offset+4]) + ctx.offset += 5 + } + + return nil +} + +// Unmarshal decodes VLA from a byte slice. +func (v *VLA) Unmarshal(payload []byte) (int, error) { + ctx := &vlaUnmarshalingContext{ + payload: payload, + } + + err := v.unmarshalSpatialLayers(ctx) + if err != nil { + return ctx.offset, err + } + + // #tl fields (build the list ActiveSpatialLayer at the same time) + err = v.unmarshalTemporalLayers(ctx) + if err != nil { + return ctx.offset, err + } + + if len(ctx.payload) == ctx.offset { + return ctx.offset, nil + } + + // resolution & framerate (optional) + err = v.unmarshalResolutionAndFramerate(ctx) + if err != nil { + return ctx.offset, err + } + + return ctx.offset, nil +} + +// String makes VLA printable. +func (v VLA) String() string { + out := fmt.Sprintf("RID:%d,RTPStreamCount:%d", v.RTPStreamID, v.RTPStreamCount) + var slOut []string + for _, sl := range v.ActiveSpatialLayer { + out2 := fmt.Sprintf("RTPStreamID:%d", sl.RTPStreamID) + out2 += fmt.Sprintf(",TargetBitrates:%v", sl.TargetBitrates) + if v.HasResolutionAndFramerate { + out2 += fmt.Sprintf(",Resolution:(%d,%d)", sl.Width, sl.Height) + out2 += fmt.Sprintf(",Framerate:%d", sl.Framerate) + } + slOut = append(slOut, out2) + } + out += fmt.Sprintf(",ActiveSpatialLayers:{%s}", strings.Join(slOut, ",")) + return out +} diff --git a/vlaextension_test.go b/vlaextension_test.go new file mode 100644 index 0000000..b9b8066 --- /dev/null +++ b/vlaextension_test.go @@ -0,0 +1,532 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtp + +import ( + "bytes" + "encoding/hex" + "errors" + "reflect" + "testing" +) + +func TestVLAMarshal(t *testing.T) { + requireNoError := func(t *testing.T, err error) { + if err != nil { + t.Fatal(err) + } + } + + t.Run("3 streams no resolution and framerate", func(t *testing.T) { + vla := &VLA{ + RTPStreamID: 0, + RTPStreamCount: 3, + ActiveSpatialLayer: []SpatialLayer{ + { + RTPStreamID: 0, + SpatialID: 0, + TargetBitrates: []int{150}, + }, + { + RTPStreamID: 1, + SpatialID: 0, + TargetBitrates: []int{240, 400}, + }, + { + RTPStreamID: 2, + SpatialID: 0, + TargetBitrates: []int{720, 1200}, + }, + }, + } + + bytesActual, err := vla.Marshal() + requireNoError(t, err) + bytesExpected, err := hex.DecodeString("21149601f0019003d005b009") + requireNoError(t, err) + if !bytes.Equal(bytesExpected, bytesActual) { + t.Fatalf("expected %s, actual %s", hex.EncodeToString(bytesExpected), hex.EncodeToString(bytesActual)) + } + }) + + t.Run("3 streams with resolution and framerate", func(t *testing.T) { + vla := &VLA{ + RTPStreamID: 2, + RTPStreamCount: 3, + ActiveSpatialLayer: []SpatialLayer{ + { + RTPStreamID: 0, + SpatialID: 0, + TargetBitrates: []int{150}, + Width: 320, + Height: 180, + Framerate: 30, + }, + { + RTPStreamID: 1, + SpatialID: 0, + TargetBitrates: []int{240, 400}, + Width: 640, + Height: 360, + Framerate: 30, + }, + { + RTPStreamID: 2, + SpatialID: 0, + TargetBitrates: []int{720, 1200}, + Width: 1280, + Height: 720, + Framerate: 30, + }, + }, + HasResolutionAndFramerate: true, + } + + bytesActual, err := vla.Marshal() + requireNoError(t, err) + bytesExpected, err := hex.DecodeString("a1149601f0019003d005b009013f00b31e027f01671e04ff02cf1e") + requireNoError(t, err) + if !bytes.Equal(bytesExpected, bytesActual) { + t.Fatalf("expected %s, actual %s", hex.EncodeToString(bytesExpected), hex.EncodeToString(bytesActual)) + } + }) + + t.Run("Negative RTPStreamCount", func(t *testing.T) { + vla := &VLA{ + RTPStreamID: 0, + RTPStreamCount: -1, + ActiveSpatialLayer: []SpatialLayer{}, + } + _, err := vla.Marshal() + if !errors.Is(err, ErrVLAInvalidStreamCount) { + t.Fatal("expected ErrVLAInvalidRTPStreamCount") + } + }) + + t.Run("RTPStreamCount too large", func(t *testing.T) { + vla := &VLA{ + RTPStreamID: 0, + RTPStreamCount: 5, + ActiveSpatialLayer: []SpatialLayer{{}, {}, {}, {}, {}}, + } + _, err := vla.Marshal() + if !errors.Is(err, ErrVLAInvalidStreamCount) { + t.Fatal("expected ErrVLAInvalidRTPStreamCount") + } + }) + + t.Run("Negative RTPStreamID", func(t *testing.T) { + vla := &VLA{ + RTPStreamID: -1, + RTPStreamCount: 1, + ActiveSpatialLayer: []SpatialLayer{{}}, + } + _, err := vla.Marshal() + if !errors.Is(err, ErrVLAInvalidStreamID) { + t.Fatalf("expected ErrVLAInvalidRTPStreamID, actual %v", err) + } + }) + + t.Run("RTPStreamID to large", func(t *testing.T) { + vla := &VLA{ + RTPStreamID: 1, + RTPStreamCount: 1, + ActiveSpatialLayer: []SpatialLayer{{}}, + } + _, err := vla.Marshal() + if !errors.Is(err, ErrVLAInvalidStreamID) { + t.Fatalf("expected ErrVLAInvalidRTPStreamID: %v", err) + } + }) + + t.Run("Invalid stream ID in the spatial layer", func(t *testing.T) { + vla := &VLA{ + RTPStreamID: 0, + RTPStreamCount: 1, + ActiveSpatialLayer: []SpatialLayer{{ + RTPStreamID: -1, + }}, + } + _, err := vla.Marshal() + if !errors.Is(err, ErrVLAInvalidStreamID) { + t.Fatalf("expected ErrVLAInvalidStreamID: %v", err) + } + vla = &VLA{ + RTPStreamID: 0, + RTPStreamCount: 1, + ActiveSpatialLayer: []SpatialLayer{{ + RTPStreamID: 1, + }}, + } + _, err = vla.Marshal() + if !errors.Is(err, ErrVLAInvalidStreamID) { + t.Fatalf("expected ErrVLAInvalidStreamID: %v", err) + } + }) + + t.Run("Invalid spatial ID in the spatial layer", func(t *testing.T) { + vla := &VLA{ + RTPStreamID: 0, + RTPStreamCount: 1, + ActiveSpatialLayer: []SpatialLayer{{ + RTPStreamID: 0, + SpatialID: -1, + }}, + } + _, err := vla.Marshal() + if !errors.Is(err, ErrVLAInvalidSpatialID) { + t.Fatalf("expected ErrVLAInvalidSpatialID: %v", err) + } + vla = &VLA{ + RTPStreamID: 0, + RTPStreamCount: 1, + ActiveSpatialLayer: []SpatialLayer{{ + RTPStreamID: 0, + SpatialID: 5, + }}, + } + _, err = vla.Marshal() + if !errors.Is(err, ErrVLAInvalidSpatialID) { + t.Fatalf("expected ErrVLAInvalidSpatialID: %v", err) + } + }) + + t.Run("Invalid temporal layer in the spatial layer", func(t *testing.T) { + vla := &VLA{ + RTPStreamID: 0, + RTPStreamCount: 1, + ActiveSpatialLayer: []SpatialLayer{{ + RTPStreamID: 0, + SpatialID: 0, + TargetBitrates: []int{}, + }}, + } + _, err := vla.Marshal() + if !errors.Is(err, ErrVLAInvalidTemporalLayer) { + t.Fatalf("expected ErrVLAInvalidTemporalLayer: %v", err) + } + vla = &VLA{ + RTPStreamID: 0, + RTPStreamCount: 1, + ActiveSpatialLayer: []SpatialLayer{{ + RTPStreamID: 0, + SpatialID: 0, + TargetBitrates: []int{100, 200, 300, 400, 500}, + }}, + } + _, err = vla.Marshal() + if !errors.Is(err, ErrVLAInvalidTemporalLayer) { + t.Fatalf("expected ErrVLAInvalidTemporalLayer: %v", err) + } + }) + + t.Run("Duplicate spatial ID in the spatial layer", func(t *testing.T) { + vla := &VLA{ + RTPStreamID: 0, + RTPStreamCount: 1, + ActiveSpatialLayer: []SpatialLayer{{ + RTPStreamID: 0, + SpatialID: 0, + TargetBitrates: []int{100}, + }, { + RTPStreamID: 0, + SpatialID: 0, + TargetBitrates: []int{200}, + }}, + } + _, err := vla.Marshal() + if !errors.Is(err, ErrVLADuplicateSpatialID) { + t.Fatalf("expected ErrVLADuplicateSpatialID: %v", err) + } + }) +} + +func TestVLAUnmarshal(t *testing.T) { + requireEqualInt := func(t *testing.T, expected, actual int) { + if expected != actual { + t.Fatalf("expected %d, actual %d", expected, actual) + } + } + requireNoError := func(t *testing.T, err error) { + if err != nil { + t.Fatal(err) + } + } + requireTrue := func(t *testing.T, val bool) { + if !val { + t.Fatal("expected true") + } + } + requireFalse := func(t *testing.T, val bool) { + if val { + t.Fatal("expected false") + } + } + + t.Run("3 streams no resolution and framerate", func(t *testing.T) { + // two layer ("low", "high") + b, err := hex.DecodeString("21149601f0019003d005b009") + requireNoError(t, err) + if err != nil { + t.Fatal("failed to decode input data") + } + + vla := &VLA{} + n, err := vla.Unmarshal(b) + requireNoError(t, err) + requireEqualInt(t, len(b), n) + + requireEqualInt(t, 0, vla.RTPStreamID) + requireEqualInt(t, 3, vla.RTPStreamCount) + requireEqualInt(t, 3, len(vla.ActiveSpatialLayer)) + + requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID) + requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].SpatialID) + requireEqualInt(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates)) + requireEqualInt(t, 150, vla.ActiveSpatialLayer[0].TargetBitrates[0]) + + requireEqualInt(t, 1, vla.ActiveSpatialLayer[1].RTPStreamID) + requireEqualInt(t, 0, vla.ActiveSpatialLayer[1].SpatialID) + requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates)) + requireEqualInt(t, 240, vla.ActiveSpatialLayer[1].TargetBitrates[0]) + requireEqualInt(t, 400, vla.ActiveSpatialLayer[1].TargetBitrates[1]) + + requireFalse(t, vla.HasResolutionAndFramerate) + + requireEqualInt(t, 2, vla.ActiveSpatialLayer[2].RTPStreamID) + requireEqualInt(t, 0, vla.ActiveSpatialLayer[2].SpatialID) + requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[2].TargetBitrates)) + requireEqualInt(t, 720, vla.ActiveSpatialLayer[2].TargetBitrates[0]) + requireEqualInt(t, 1200, vla.ActiveSpatialLayer[2].TargetBitrates[1]) + }) + + t.Run("3 streams with resolution and framerate", func(t *testing.T) { + b, err := hex.DecodeString("a1149601f0019003d005b009013f00b31e027f01671e04ff02cf1e") + requireNoError(t, err) + + vla := &VLA{} + n, err := vla.Unmarshal(b) + requireNoError(t, err) + requireEqualInt(t, len(b), n) + + requireEqualInt(t, 2, vla.RTPStreamID) + requireEqualInt(t, 3, vla.RTPStreamCount) + + requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID) + requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].SpatialID) + requireEqualInt(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates)) + requireEqualInt(t, 150, vla.ActiveSpatialLayer[0].TargetBitrates[0]) + + requireEqualInt(t, 1, vla.ActiveSpatialLayer[1].RTPStreamID) + requireEqualInt(t, 0, vla.ActiveSpatialLayer[1].SpatialID) + requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates)) + requireEqualInt(t, 240, vla.ActiveSpatialLayer[1].TargetBitrates[0]) + requireEqualInt(t, 400, vla.ActiveSpatialLayer[1].TargetBitrates[1]) + + requireEqualInt(t, 2, vla.ActiveSpatialLayer[2].RTPStreamID) + requireEqualInt(t, 0, vla.ActiveSpatialLayer[2].SpatialID) + requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[2].TargetBitrates)) + requireEqualInt(t, 720, vla.ActiveSpatialLayer[2].TargetBitrates[0]) + requireEqualInt(t, 1200, vla.ActiveSpatialLayer[2].TargetBitrates[1]) + + requireTrue(t, vla.HasResolutionAndFramerate) + + requireEqualInt(t, 320, vla.ActiveSpatialLayer[0].Width) + requireEqualInt(t, 180, vla.ActiveSpatialLayer[0].Height) + requireEqualInt(t, 30, vla.ActiveSpatialLayer[0].Framerate) + requireEqualInt(t, 640, vla.ActiveSpatialLayer[1].Width) + requireEqualInt(t, 360, vla.ActiveSpatialLayer[1].Height) + requireEqualInt(t, 30, vla.ActiveSpatialLayer[1].Framerate) + requireEqualInt(t, 1280, vla.ActiveSpatialLayer[2].Width) + requireEqualInt(t, 720, vla.ActiveSpatialLayer[2].Height) + requireEqualInt(t, 30, vla.ActiveSpatialLayer[2].Framerate) + }) + + t.Run("2 streams", func(t *testing.T) { + // two layer ("low", "high") + b, err := hex.DecodeString("1110c801d005b009") + requireNoError(t, err) + + vla := &VLA{} + n, err := vla.Unmarshal(b) + requireNoError(t, err) + requireEqualInt(t, len(b), n) + + requireEqualInt(t, 0, vla.RTPStreamID) + requireEqualInt(t, 2, vla.RTPStreamCount) + requireEqualInt(t, 2, len(vla.ActiveSpatialLayer)) + + requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID) + requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].SpatialID) + requireEqualInt(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates)) + requireEqualInt(t, 200, vla.ActiveSpatialLayer[0].TargetBitrates[0]) + + requireEqualInt(t, 1, vla.ActiveSpatialLayer[1].RTPStreamID) + requireEqualInt(t, 0, vla.ActiveSpatialLayer[1].SpatialID) + requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates)) + requireEqualInt(t, 720, vla.ActiveSpatialLayer[1].TargetBitrates[0]) + requireEqualInt(t, 1200, vla.ActiveSpatialLayer[1].TargetBitrates[1]) + + requireFalse(t, vla.HasResolutionAndFramerate) + }) + + t.Run("3 streams mid paused with resolution and framerate", func(t *testing.T) { + b, err := hex.DecodeString("601010109601d005b009013f00b31e04ff02cf1e") + requireNoError(t, err) + + vla := &VLA{} + n, err := vla.Unmarshal(b) + requireNoError(t, err) + requireEqualInt(t, len(b), n) + + requireEqualInt(t, 1, vla.RTPStreamID) + requireEqualInt(t, 3, vla.RTPStreamCount) + + requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID) + requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].SpatialID) + requireEqualInt(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates)) + requireEqualInt(t, 150, vla.ActiveSpatialLayer[0].TargetBitrates[0]) + + requireEqualInt(t, 2, vla.ActiveSpatialLayer[1].RTPStreamID) + requireEqualInt(t, 0, vla.ActiveSpatialLayer[1].SpatialID) + requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates)) + requireEqualInt(t, 720, vla.ActiveSpatialLayer[1].TargetBitrates[0]) + requireEqualInt(t, 1200, vla.ActiveSpatialLayer[1].TargetBitrates[1]) + + requireTrue(t, vla.HasResolutionAndFramerate) + + requireEqualInt(t, 320, vla.ActiveSpatialLayer[0].Width) + requireEqualInt(t, 180, vla.ActiveSpatialLayer[0].Height) + requireEqualInt(t, 30, vla.ActiveSpatialLayer[0].Framerate) + requireEqualInt(t, 1280, vla.ActiveSpatialLayer[1].Width) + requireEqualInt(t, 720, vla.ActiveSpatialLayer[1].Height) + requireEqualInt(t, 30, vla.ActiveSpatialLayer[1].Framerate) + }) + + t.Run("extra 1", func(t *testing.T) { + b, err := hex.DecodeString("a0001040ac02f403") + requireNoError(t, err) + + vla := &VLA{} + n, err := vla.Unmarshal(b) + requireNoError(t, err) + requireEqualInt(t, len(b), n) + }) + + t.Run("extra 2", func(t *testing.T) { + b, err := hex.DecodeString("a00010409405cc08") + requireNoError(t, err) + + vla := &VLA{} + n, err := vla.Unmarshal(b) + requireNoError(t, err) + requireEqualInt(t, len(b), n) + }) +} + +func TestVLAMarshalThenUnmarshal(t *testing.T) { + requireEqualInt := func(t *testing.T, expected, actual int) { + if expected != actual { + t.Fatalf("expected %d, actual %d", expected, actual) + } + } + requireNoError := func(t *testing.T, err error) { + if err != nil { + t.Fatal(err) + } + } + + t.Run("multiple spatial layers", func(t *testing.T) { + var spatialLayers []SpatialLayer + for streamID := 0; streamID < 3; streamID++ { + for spatialID := 0; spatialID < 4; spatialID++ { + spatialLayers = append(spatialLayers, SpatialLayer{ + RTPStreamID: streamID, + SpatialID: spatialID, + TargetBitrates: []int{150, 200}, + Width: 320, + Height: 180, + Framerate: 30, + }) + } + } + + vla0 := &VLA{ + RTPStreamID: 2, + RTPStreamCount: 3, + ActiveSpatialLayer: spatialLayers, + HasResolutionAndFramerate: true, + } + + b, err := vla0.Marshal() + requireNoError(t, err) + + vla1 := &VLA{} + n, err := vla1.Unmarshal(b) + requireNoError(t, err) + requireEqualInt(t, len(b), n) + + if !reflect.DeepEqual(vla0, vla1) { + t.Fatalf("expected %v, actual %v", vla0, vla1) + } + }) + + t.Run("different spatial layer bitmasks", func(t *testing.T) { + var spatialLayers []SpatialLayer + for streamID := 0; streamID < 4; streamID++ { + for spatialID := 0; spatialID < streamID+1; spatialID++ { + spatialLayers = append(spatialLayers, SpatialLayer{ + RTPStreamID: streamID, + SpatialID: spatialID, + TargetBitrates: []int{150, 200}, + Width: 320, + Height: 180, + Framerate: 30, + }) + } + } + + vla0 := &VLA{ + RTPStreamID: 0, + RTPStreamCount: 4, + ActiveSpatialLayer: spatialLayers, + HasResolutionAndFramerate: true, + } + + b, err := vla0.Marshal() + requireNoError(t, err) + if b[0]&0x0f != 0 { + t.Error("expects sl_bm to be 0") + } + if b[1] != 0x13 { + t.Error("expects sl0_bm,sl1_bm to be b0001,b0011") + } + if b[2] != 0x7f { + t.Error("expects sl1_bm,sl2_bm to be b0111,b1111") + } + t.Logf("b: %s", hex.EncodeToString(b)) + + vla1 := &VLA{} + n, err := vla1.Unmarshal(b) + requireNoError(t, err) + requireEqualInt(t, len(b), n) + + if !reflect.DeepEqual(vla0, vla1) { + t.Fatalf("expected %v, actual %v", vla0, vla1) + } + }) +} + +func FuzzVLAUnmarshal(f *testing.F) { + f.Add([]byte{0}) + f.Add([]byte("70")) + + f.Fuzz(func(t *testing.T, data []byte) { + vla := &VLA{} + _, err := vla.Unmarshal(data) + if err != nil { + t.Skip() // If the function returns an error, we skip the test case + } + }) +}