Skip to content

Commit

Permalink
Feat/pid filter tool (#19)
Browse files Browse the repository at this point in the history
* feat: add pid filter
* fix: read packet by packet
* fix: rename parameters
  • Loading branch information
Wkkkkk authored Feb 2, 2024
1 parent ac336db commit 1b2c5c3
Show file tree
Hide file tree
Showing 6 changed files with 395 additions and 26 deletions.
72 changes: 72 additions & 0 deletions cmd/mp2ts-pidfilter/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package main

import (
"context"
"flag"
"fmt"
"io"
"log"
"os"
"strings"

"github.com/Eyevinn/mp2ts-tools/internal"
)

var usg = `Usage of %s:
%s filters out some chosen pids from the ts packet.
Drop nothing and list all PIDs if empty pids list is specified (by default).
However, PAT(0) and PMT must not be dropped.
`

func parseOptions() internal.Options {
opts := internal.Options{ShowStreamInfo: true, Indent: true, FilterPids: true}
flag.StringVar(&opts.PidsToDrop, "drop", "", "pids to drop in the PMT (split by space), e.g. \"256 257\"")
flag.StringVar(&opts.OutPutTo, "output", "", "save the TS packets into the given file (filepath) or stdout (-)")
flag.BoolVar(&opts.Indent, "indent", true, "indent JSON output")
flag.BoolVar(&opts.Version, "version", false, "print version")

flag.Usage = func() {
parts := strings.Split(os.Args[0], "/")
name := parts[len(parts)-1]
fmt.Fprintf(os.Stderr, usg, name, name)
fmt.Fprintf(os.Stderr, "\nRun as: %s [options] file.ts (- for stdin) with options:\n\n", name)
flag.PrintDefaults()
}

flag.Parse()
return opts
}

func filter(ctx context.Context, w io.Writer, f io.Reader, o internal.Options) error {
outPutToFile := o.OutPutTo != "-"
var textOutput io.Writer
var tsOutput io.Writer
// If we output to ts files, print analysis to stdout
if outPutToFile {
// Remove existing output file
if err := internal.RemoveFileIfExists(o.OutPutTo); err != nil {
return err
}
file, err := internal.OpenFileAndAppend(o.OutPutTo)
if err != nil {
return err
}
tsOutput = file
textOutput = w
defer file.Close()
} else { // If we output to stdout, print analysis to stderr
tsOutput = w
textOutput = os.Stderr
}

return internal.FilterPids(ctx, textOutput, tsOutput, f, o)
}

func main() {
o, inFile := internal.ParseParams(parseOptions)
err := internal.Execute(os.Stdout, o, inFile, filter)
if err != nil {
log.Fatal(err)
}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/Eyevinn/mp4ff v0.42.0
github.com/asticode/go-astits v1.13.0
github.com/stretchr/testify v1.8.4
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a
)

require (
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA=
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
Expand Down
149 changes: 123 additions & 26 deletions internal/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/Comcast/gots/v2/psi"
"github.com/Comcast/gots/v2/scte35"
"github.com/asticode/go-astits"
slices "golang.org/x/exp/slices"
)

func ParseAll(ctx context.Context, w io.Writer, f io.Reader, o Options) error {
Expand Down Expand Up @@ -150,18 +151,7 @@ dataLoop:
if pmtPID < 0 && d.PMT != nil {
// Loop through elementary streams
for _, es := range d.PMT.ElementaryStreams {
var streamInfo *ElementaryStreamInfo
switch es.StreamType {
case astits.StreamTypeH264Video:
streamInfo = &ElementaryStreamInfo{PID: es.ElementaryPID, Codec: "AVC", Type: "video"}
case astits.StreamTypeAACAudio:
streamInfo = &ElementaryStreamInfo{PID: es.ElementaryPID, Codec: "AAC", Type: "audio"}
case astits.StreamTypeH265Video:
streamInfo = &ElementaryStreamInfo{PID: es.ElementaryPID, Codec: "HEVC", Type: "video"}
case astits.StreamTypeSCTE35:
streamInfo = &ElementaryStreamInfo{PID: es.ElementaryPID, Codec: "SCTE35", Type: "cue"}
}

streamInfo := ParseAstitsElementaryStreamInfo(es)
if streamInfo != nil {
jp.Print(streamInfo, o.ShowStreamInfo)
}
Expand Down Expand Up @@ -214,21 +204,12 @@ func ParseSCTE35(ctx context.Context, w io.Writer, f io.Reader, o Options) error
scte35PIDs := make(map[int]bool)
for _, pmt := range pmts {
for _, es := range pmt.ElementaryStreams() {
pid := uint16(es.ElementaryPid())
var streamInfo *ElementaryStreamInfo
switch es.StreamType() {
case psi.PmtStreamTypeMpeg4VideoH264:
streamInfo = &ElementaryStreamInfo{PID: pid, Codec: "AVC", Type: "video"}
case psi.PmtStreamTypeAac:
streamInfo = &ElementaryStreamInfo{PID: pid, Codec: "AAC", Type: "audio"}
case psi.PmtStreamTypeMpeg4VideoH265:
streamInfo = &ElementaryStreamInfo{PID: pid, Codec: "HEVC", Type: "video"}
case psi.PmtStreamTypeScte35:
streamInfo = &ElementaryStreamInfo{PID: pid, Codec: "SCTE35", Type: "cue"}
scte35PIDs[es.ElementaryPid()] = true
}

streamInfo := ParseElementaryStreamInfo(es)
if streamInfo != nil {
if streamInfo.Codec == "SCTE35" {
scte35PIDs[es.ElementaryPid()] = true
}

jp.Print(streamInfo, o.ShowStreamInfo)
}
}
Expand Down Expand Up @@ -261,3 +242,119 @@ func ParseSCTE35(ctx context.Context, w io.Writer, f io.Reader, o Options) error

return jp.Error()
}

func FilterPids(ctx context.Context, textWriter io.Writer, tsWriter io.Writer, f io.Reader, o Options) error {
pidsToDrop := ParsePidsFromString(o.PidsToDrop)
if slices.Contains(pidsToDrop, 0) {
return fmt.Errorf("filtering out PAT is not allowed")
}

reader := bufio.NewReader(f)
_, err := packet.Sync(reader)
if err != nil {
return fmt.Errorf("syncing with reader %w", err)
}

jp := &JsonPrinter{W: textWriter, Indent: o.Indent}
statistics := PidFilterStatistics{PidsToDrop: pidsToDrop, TotalPackets: 0, FilteredPackets: 0, PacketsBeforePAT: 0}

var pkt packet.Packet
var pat psi.PAT
foundPAT := false
hasShownStreamInfo := false
// Skip packets until PAT
for {
// Read packet
if _, err := io.ReadFull(reader, pkt[:]); err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
break
}
return fmt.Errorf("reading Packet %w", err)
}
if packet.IsPat(&pkt) {
// Found first PAT packet
foundPAT = true
}

// Count PAT packet and non-PMT packets
statistics.TotalPackets = statistics.TotalPackets + 1
if !foundPAT {
// packets before PAT
statistics.PacketsBeforePAT = statistics.PacketsBeforePAT + 1
continue
}

if packet.IsPat(&pkt) {
// Parse PAT packet
pat, err = ParsePacketToPAT(&pkt)
if err != nil {
return err
}

// Save PAT packet
if err = WritePacket(&pkt, tsWriter); err != nil {
return err
}

// Handle PMT packet(s)
pm := pat.ProgramMap()
for _, pid := range pm {
if slices.Contains(pidsToDrop, pid) {
return fmt.Errorf("filtering out PMT is not allowed")
}

packets, pmt, err := ReadPMTPackets(reader, pid)
if err != nil {
return err
}
// Count PMT packets
statistics.TotalPackets = statistics.TotalPackets + uint32(len(packets))

// 1. Print stream info only once
if o.ShowStreamInfo && !hasShownStreamInfo {
for _, es := range pmt.ElementaryStreams() {
streamInfo := ParseElementaryStreamInfo(es)
if streamInfo != nil {
jp.Print(streamInfo, true)
}
}
hasShownStreamInfo = true
}

// 2. Drop pids if exist
isFilteringOutPids := IsTwoSlicesOverlapping(pmt.Pids(), pidsToDrop)
pkts := []*packet.Packet{}
for i := range packets {
pkts = append(pkts, &packets[i])
}
if isFilteringOutPids {
pidsToKeep := GetDifferenceOfTwoSlices(pmt.Pids(), pidsToDrop)
pkts, err = psi.FilterPMTPacketsToPids(pkts, pidsToKeep)
if err != nil {
return fmt.Errorf("filtering pids %w", err)
}

statistics.FilteredPackets = statistics.FilteredPackets + uint32(len(pkts))
}

// 3. Save PMT packets
for _, p := range pkts {
if err = WritePacket(p, tsWriter); err != nil {
return err
}
}
}

// Move on to next packet
continue
}

// Save non-PAT/PMT packets
if err = WritePacket(&pkt, tsWriter); err != nil {
return err
}
}

jp.PrintFilter(statistics, true)
return nil
}
19 changes: 19 additions & 0 deletions internal/statistics.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
package internal

type PidFilterStatistics struct {
PidsToDrop []int `json:"pidsToDrop"`
TotalPackets uint32 `json:"total"`
FilteredPackets uint32 `json:"filtered"`
PacketsBeforePAT uint32 `json:"packetBeforePAT"`
Percentage float32 `json:"percentage"`
}

type StreamStatistics struct {
Type string `json:"streamType"`
Pid uint16 `json:"pid"`
Expand All @@ -17,6 +25,17 @@ type StreamStatistics struct {
Errors []string `json:"errors,omitempty"`
}

func (p *JsonPrinter) PrintFilter(s PidFilterStatistics, show bool) {
if s.TotalPackets == 0 {
s.Percentage = 0
} else {
s.Percentage = (float32(s.PacketsBeforePAT) + float32(s.FilteredPackets)) / float32(s.TotalPackets)
}

// print statistics
p.Print(s, show)
}

func (p *JsonPrinter) PrintStatistics(s StreamStatistics, show bool) {
// fmt.Fprintf(p.w, "Print statistics for PID: %d\n", s.Pid)
s.calculateFrameRate(TimeScale)
Expand Down
Loading

0 comments on commit 1b2c5c3

Please sign in to comment.