From 711cbe2a6312b3ccefd1390aa95686e29a132bd0 Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Fri, 20 Dec 2024 12:05:46 +0400 Subject: [PATCH 01/10] refactor: add set Signed-off-by: knqyf263 --- pkg/set/set.go | 37 +++ pkg/set/unsafe.go | 100 +++++++ pkg/set/unsafe_test.go | 583 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 720 insertions(+) create mode 100644 pkg/set/set.go create mode 100644 pkg/set/unsafe.go create mode 100644 pkg/set/unsafe_test.go diff --git a/pkg/set/set.go b/pkg/set/set.go new file mode 100644 index 000000000000..e1e985c27217 --- /dev/null +++ b/pkg/set/set.go @@ -0,0 +1,37 @@ +package set + +// Set defines the interface for set operations +type Set[T comparable] interface { + // Add adds an item to the set + Add(item T) + + // Append adds multiple items to the set and returns the new size + Append(val ...T) int + + // Remove removes an item from the set + Remove(item T) + + // Contains checks if an item exists in the set + Contains(item T) bool + + // Size returns the number of items in the set + Size() int + + // Clear removes all items from the set + Clear() + + // Clone returns a new set with a copy of all items + Clone() Set[T] + + // Items returns all items in the set as a slice + Items() []T + + // Union returns a new set containing all items from both sets + Union(other Set[T]) Set[T] + + // Intersection returns a new set containing items present in both sets + Intersection(other Set[T]) Set[T] + + // Difference returns a new set containing items present in this set but not in the other + Difference(other Set[T]) Set[T] +} diff --git a/pkg/set/unsafe.go b/pkg/set/unsafe.go new file mode 100644 index 000000000000..e538c81fac71 --- /dev/null +++ b/pkg/set/unsafe.go @@ -0,0 +1,100 @@ +package set + +import "maps" + +// unsafeSet represents a non-thread-safe set implementation +// WARNING: This implementation is not thread-safe +type unsafeSet[T comparable] map[T]struct{} + +// New creates a new empty non-thread-safe set with optional initial values +func New[T comparable](values ...T) Set[T] { + s := make(unsafeSet[T]) + for _, v := range values { + s[v] = struct{}{} + } + return s +} + +// Add adds an item to the set +func (s unsafeSet[T]) Add(item T) { + s[item] = struct{}{} +} + +// Append adds multiple items to the set and returns the new size +func (s unsafeSet[T]) Append(val ...T) int { + for _, item := range val { + s[item] = struct{}{} + } + return len(s) +} + +// Remove removes an item from the set +func (s unsafeSet[T]) Remove(item T) { + delete(s, item) +} + +// Contains checks if an item exists in the set +func (s unsafeSet[T]) Contains(item T) bool { + _, exists := s[item] + return exists +} + +// Size returns the number of items in the set +func (s unsafeSet[T]) Size() int { + return len(s) +} + +// Clear removes all items from the set +func (s unsafeSet[T]) Clear() { + for k := range s { + delete(s, k) + } +} + +// Clone returns a new set with a copy of all items +func (s unsafeSet[T]) Clone() Set[T] { + return maps.Clone(s) +} + +// Items returns all items in the set as a slice +func (s unsafeSet[T]) Items() []T { + items := make([]T, 0, len(s)) + for item := range s { + items = append(items, item) + } + return items +} + +// Union returns a new set containing all items from both sets +func (s unsafeSet[T]) Union(other Set[T]) Set[T] { + result := make(unsafeSet[T]) + for k := range s { + result[k] = struct{}{} + } + for _, item := range other.Items() { + result[item] = struct{}{} + } + return result +} + +// Intersection returns a new set containing items present in both sets +func (s unsafeSet[T]) Intersection(other Set[T]) Set[T] { + result := make(unsafeSet[T]) + for k := range s { + if other.Contains(k) { + result[k] = struct{}{} + } + } + return result +} + +// Difference returns a new set containing items present in this set but not in the other +func (s unsafeSet[T]) Difference(other Set[T]) Set[T] { + result := make(unsafeSet[T]) + for k := range s { + if !other.Contains(k) { + result[k] = struct{}{} + } + } + return result +} diff --git a/pkg/set/unsafe_test.go b/pkg/set/unsafe_test.go new file mode 100644 index 000000000000..4c19e75c0ce6 --- /dev/null +++ b/pkg/set/unsafe_test.go @@ -0,0 +1,583 @@ +package set_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/aquasecurity/trivy/pkg/set" +) + +func Test_New(t *testing.T) { + tests := []struct { + name string + values []int + wantSize int + wantAll bool + desc string + }{ + { + name: "new empty set", + values: []int{}, + wantSize: 0, + wantAll: true, + desc: "should create empty set when no values provided", + }, + { + name: "new set with single value", + values: []int{1}, + wantSize: 1, + wantAll: true, + desc: "should create set with single value", + }, + { + name: "new set with multiple values", + values: []int{ + 1, + 2, + 3, + }, + wantSize: 3, + wantAll: true, + desc: "should create set with multiple values", + }, + { + name: "new set with duplicate values", + values: []int{ + 1, + 2, + 2, + 3, + 3, + 3, + }, + wantSize: 3, + wantAll: true, + desc: "should create set with unique values only", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := set.New(tt.values...) + assert.Equal(t, tt.wantSize, s.Size(), "unexpected set size") + }) + } +} + +func Test_unsafeSet_Add(t *testing.T) { + // Define custom type for struct test cases + type custom struct { + id int + name string + } + + tests := []struct { + name string + prepare func(s set.Set[any]) + input any + wantSize int + }{ + { + name: "add integer", + prepare: nil, + input: 1, + wantSize: 1, + }, + { + name: "add duplicate integer", + prepare: func(s set.Set[any]) { + s.Add(1) + }, + input: 1, + wantSize: 1, + }, + { + name: "add string", + prepare: nil, + input: "test", + wantSize: 1, + }, + { + name: "add empty string", + prepare: nil, + input: "", + wantSize: 1, + }, + { + name: "add custom struct", + prepare: nil, + input: custom{ + id: 1, + name: "test1", + }, + wantSize: 1, + }, + { + name: "add nil pointer", + prepare: nil, + input: (*int)(nil), + wantSize: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := set.New[any]() + if tt.prepare != nil { + tt.prepare(s) + } + s.Add(tt.input) + + got := s.Size() + assert.Equal(t, tt.wantSize, got, "unexpected set size") + assert.True(t, s.Contains(tt.input), "unexpected contains result for value: %v", tt.input) + }) + } +} + +func Test_unsafeSet_Append(t *testing.T) { + tests := []struct { + name string + prepare func(s set.Set[int]) + input []int + wantSize int + }{ + { + name: "append to empty set", + prepare: nil, + input: []int{ + 1, + 2, + 3, + }, + wantSize: 3, + }, + { + name: "append with duplicates", + prepare: func(s set.Set[int]) { + s.Add(1) + }, + input: []int{ + 1, + 2, + 1, + 3, + 2, + }, + wantSize: 3, + }, + { + name: "append empty slice", + prepare: func(s set.Set[int]) { + s.Add(1) + }, + input: []int{}, + wantSize: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := set.New[int]() + if tt.prepare != nil { + tt.prepare(s) + } + got := s.Append(tt.input...) + + assert.Equal(t, tt.wantSize, got, "unexpected returned size") + assert.Equal(t, tt.wantSize, s.Size(), "unexpected actual size") + + for _, item := range tt.input { + assert.True(t, s.Contains(item), "set should contain appended item: %v", item) + } + }) + } +} + +func Test_unsafeSet_Remove(t *testing.T) { + tests := []struct { + name string + prepare func(s set.Set[int]) + input int + wantSize int + }{ + { + name: "remove existing element", + prepare: func(s set.Set[int]) { + s.Add(1) + }, + input: 1, + wantSize: 0, + }, + { + name: "remove non-existing element", + prepare: func(s set.Set[int]) { + s.Add(1) + }, + input: 2, + wantSize: 1, + }, + { + name: "remove from empty set", + prepare: nil, + input: 1, + wantSize: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := set.New[int]() + if tt.prepare != nil { + tt.prepare(s) + } + s.Remove(tt.input) + + got := s.Size() + assert.Equal(t, tt.wantSize, got, "unexpected set size") + assert.False(t, s.Contains(tt.input), "unexpected contains result for value: %v", tt.input) + }) + } +} + +func Test_unsafeSet_Clear(t *testing.T) { + tests := []struct { + name string + prepare func(s set.Set[int]) + }{ + { + name: "clear non-empty set", + prepare: func(s set.Set[int]) { + s.Add(1) + s.Add(2) + s.Add(3) + }, + }, + { + name: "clear empty set", + prepare: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := set.New[int]() + if tt.prepare != nil { + tt.prepare(s) + } + s.Clear() + + got := s.Size() + assert.Zero(t, got, "unexpected set size") + assert.Equal(t, 0, len(s.Items()), "items should be empty") + }) + } +} + +func Test_unsafeSet_Clone(t *testing.T) { + t.Run("empty set", func(t *testing.T) { + original := set.New[string]() + cloned := original.Clone() + + assert.Equal(t, 0, cloned.Size(), "cloned set should be empty") + + // Verify independence + original.Add("test") + assert.False(t, cloned.Contains("test"), "cloned set should not be affected by original") + }) + + t.Run("basic types", func(t *testing.T) { + original := set.New[any](1, "test", true) + cloned := original.Clone() + + assert.Equal(t, original.Size(), cloned.Size(), "sizes should match") + assert.True(t, cloned.Contains(1), "should contain integer") + assert.True(t, cloned.Contains("test"), "should contain string") + assert.True(t, cloned.Contains(true), "should contain boolean") + + // Verify independence + original.Add("new") + assert.False(t, cloned.Contains("new"), "cloned set should not be affected by original") + cloned.Add("another") + assert.False(t, original.Contains("another"), "original set should not be affected by clone") + }) + + // Test nil pointer + t.Run("nil pointer", func(t *testing.T) { + original := set.New[*int]() + original.Add(nil) + + cloned := original.Clone() + + assert.Equal(t, original.Size(), cloned.Size(), "sizes should match") + assert.True(t, cloned.Contains((*int)(nil)), "should contain nil pointer") + }) +} + +func Test_unsafeSet_Items(t *testing.T) { + tests := []struct { + name string + prepare func(s set.Set[int]) + want []int + }{ + { + name: "get items from non-empty set", + prepare: func(s set.Set[int]) { + s.Add(1) + s.Add(2) + s.Add(3) + }, + want: []int{ + 1, + 2, + 3, + }, + }, + { + name: "get items from empty set", + prepare: nil, + want: []int{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := set.New[int]() + if tt.prepare != nil { + tt.prepare(s) + } + got := s.Items() + + assert.ElementsMatch(t, tt.want, got, "unexpected items in set") + }) + } +} + +func Test_unsafeSet_Union(t *testing.T) { + tests := []struct { + name string + prepare1 func(s set.Set[int]) + prepare2 func(s set.Set[int]) + want []int + }{ + { + name: "union of non-overlapping sets", + prepare1: func(s set.Set[int]) { + s.Add(1) + s.Add(2) + }, + prepare2: func(s set.Set[int]) { + s.Add(3) + s.Add(4) + }, + want: []int{ + 1, + 2, + 3, + 4, + }, + }, + { + name: "union of overlapping sets", + prepare1: func(s set.Set[int]) { + s.Add(1) + s.Add(2) + s.Add(3) + }, + prepare2: func(s set.Set[int]) { + s.Add(2) + s.Add(3) + s.Add(4) + }, + want: []int{ + 1, + 2, + 3, + 4, + }, + }, + { + name: "union with empty set", + prepare1: func(s set.Set[int]) { + s.Add(1) + s.Add(2) + }, + prepare2: nil, + want: []int{ + 1, + 2, + }, + }, + { + name: "union of empty sets", + prepare1: nil, + prepare2: nil, + want: []int{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s1 := set.New[int]() + s2 := set.New[int]() + + if tt.prepare1 != nil { + tt.prepare1(s1) + } + if tt.prepare2 != nil { + tt.prepare2(s2) + } + + result := s1.Union(s2) + got := result.Items() + + assert.ElementsMatch(t, tt.want, got, "unexpected union result") + }) + } +} + +func Test_unsafeSet_Intersection(t *testing.T) { + tests := []struct { + name string + prepare1 func(s set.Set[int]) + prepare2 func(s set.Set[int]) + want []int + }{ + { + name: "intersection of overlapping sets", + prepare1: func(s set.Set[int]) { + s.Add(1) + s.Add(2) + s.Add(3) + }, + prepare2: func(s set.Set[int]) { + s.Add(2) + s.Add(3) + s.Add(4) + }, + want: []int{ + 2, + 3, + }, + }, + { + name: "intersection of non-overlapping sets", + prepare1: func(s set.Set[int]) { + s.Add(1) + s.Add(2) + }, + prepare2: func(s set.Set[int]) { + s.Add(3) + s.Add(4) + }, + want: []int{}, + }, + { + name: "intersection with empty set", + prepare1: func(s set.Set[int]) { + s.Add(1) + s.Add(2) + }, + prepare2: nil, + want: []int{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s1 := set.New[int]() + s2 := set.New[int]() + + if tt.prepare1 != nil { + tt.prepare1(s1) + } + if tt.prepare2 != nil { + tt.prepare2(s2) + } + + result := s1.Intersection(s2) + got := result.Items() + + assert.ElementsMatch(t, tt.want, got, "unexpected intersection result") + }) + } +} + +func Test_unsafeSet_Difference(t *testing.T) { + tests := []struct { + name string + prepare1 func(s set.Set[int]) + prepare2 func(s set.Set[int]) + want []int + }{ + { + name: "difference of overlapping sets", + prepare1: func(s set.Set[int]) { + s.Add(1) + s.Add(2) + s.Add(3) + }, + prepare2: func(s set.Set[int]) { + s.Add(2) + s.Add(3) + s.Add(4) + }, + want: []int{1}, + }, + { + name: "difference with non-overlapping set", + prepare1: func(s set.Set[int]) { + s.Add(1) + s.Add(2) + }, + prepare2: func(s set.Set[int]) { + s.Add(3) + s.Add(4) + }, + want: []int{ + 1, + 2, + }, + }, + { + name: "difference with empty set", + prepare1: func(s set.Set[int]) { + s.Add(1) + s.Add(2) + }, + prepare2: nil, + want: []int{ + 1, + 2, + }, + }, + { + name: "difference of empty set", + prepare1: nil, + prepare2: func(s set.Set[int]) { + s.Add(1) + s.Add(2) + }, + want: []int{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s1 := set.New[int]() + s2 := set.New[int]() + if tt.prepare1 != nil { + tt.prepare1(s1) + } + if tt.prepare2 != nil { + tt.prepare2(s2) + } + + result := s1.Difference(s2) + got := result.Items() + + assert.ElementsMatch(t, tt.want, got, "unexpected difference result") + }) + } +} From 06809900a186336d1b86887657eb99e9c1c99742 Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Fri, 20 Dec 2024 12:23:00 +0400 Subject: [PATCH 02/10] refactor(iac/rego): use set Signed-off-by: knqyf263 --- pkg/iac/rego/load.go | 9 +++------ pkg/iac/rego/options.go | 4 +--- pkg/iac/rego/scanner.go | 20 +++++++++----------- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/pkg/iac/rego/load.go b/pkg/iac/rego/load.go index 7356384a3327..29df31b2fa60 100644 --- a/pkg/iac/rego/load.go +++ b/pkg/iac/rego/load.go @@ -12,16 +12,13 @@ import ( "github.com/samber/lo" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" ) -var builtinNamespaces = map[string]struct{}{ - "builtin": {}, - "defsec": {}, - "appshield": {}, -} +var builtinNamespaces = set.New("builtin", "defsec", "appshield") func BuiltinNamespaces() []string { - return lo.Keys(builtinNamespaces) + return builtinNamespaces.Items() } func IsBuiltinNamespace(namespace string) bool { diff --git a/pkg/iac/rego/options.go b/pkg/iac/rego/options.go index 79a1b951746d..5b8df5f4affd 100644 --- a/pkg/iac/rego/options.go +++ b/pkg/iac/rego/options.go @@ -69,9 +69,7 @@ func WithDataDirs(paths ...string) options.ScannerOption { func WithPolicyNamespaces(namespaces ...string) options.ScannerOption { return func(s options.ConfigurableScanner) { if ss, ok := s.(*Scanner); ok { - for _, namespace := range namespaces { - ss.ruleNamespaces[namespace] = struct{}{} - } + ss.ruleNamespaces.Append(namespaces...) } } } diff --git a/pkg/iac/rego/scanner.go b/pkg/iac/rego/scanner.go index c46e580ac3bd..fe942507300c 100644 --- a/pkg/iac/rego/scanner.go +++ b/pkg/iac/rego/scanner.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "io/fs" - "maps" "strings" "github.com/open-policy-agent/opa/ast" @@ -22,6 +21,7 @@ import ( "github.com/aquasecurity/trivy/pkg/iac/scanners/options" "github.com/aquasecurity/trivy/pkg/iac/types" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" ) var checkTypesWithSubtype = map[types.Source]struct{}{ @@ -32,19 +32,19 @@ var checkTypesWithSubtype = map[types.Source]struct{}{ var supportedProviders = makeSupportedProviders() -func makeSupportedProviders() map[string]struct{} { - m := make(map[string]struct{}) +func makeSupportedProviders() set.Set[string] { + m := set.New[string]() for _, p := range providers.AllProviders() { - m[string(p)] = struct{}{} + m.Add(string(p)) } - m["kind"] = struct{}{} // kubernetes + m.Add("kind") // kubernetes return m } var _ options.ConfigurableScanner = (*Scanner)(nil) type Scanner struct { - ruleNamespaces map[string]struct{} + ruleNamespaces set.Set[string] policies map[string]*ast.Module store storage.Store runtimeValues *ast.Term @@ -103,15 +103,13 @@ func NewScanner(source types.Source, opts ...options.ScannerOption) *Scanner { s := &Scanner{ regoErrorLimit: ast.CompileErrorLimitDefault, sourceType: source, - ruleNamespaces: make(map[string]struct{}), + ruleNamespaces: builtinNamespaces.Clone(), runtimeValues: addRuntimeValues(), logger: log.WithPrefix("rego"), customSchemas: make(map[string][]byte), disabledCheckIDs: make(map[string]struct{}), } - maps.Copy(s.ruleNamespaces, builtinNamespaces) - for _, opt := range opts { opt(s) } @@ -198,7 +196,7 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results, namespace := getModuleNamespace(module) topLevel := strings.Split(namespace, ".")[0] - if _, ok := s.ruleNamespaces[topLevel]; !ok { + if !s.ruleNamespaces.Contains(topLevel) { continue } @@ -290,7 +288,7 @@ func isPolicyApplicable(staticMetadata *StaticMetadata, inputs ...Input) bool { for _, input := range inputs { if ii, ok := input.Contents.(map[string]any); ok { for provider := range ii { - if _, exists := supportedProviders[provider]; !exists { + if !supportedProviders.Contains(provider) { continue } From 81c8539d124ec630b60d3d409bca7aef9f39cab8 Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Fri, 20 Dec 2024 21:45:23 +0400 Subject: [PATCH 03/10] refactor(pom): use set Signed-off-by: knqyf263 --- pkg/dependency/parser/java/pom/artifact.go | 3 ++- pkg/dependency/parser/java/pom/parse.go | 24 +++++++++++++--------- pkg/dependency/parser/java/pom/pom.go | 8 ++++---- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/pkg/dependency/parser/java/pom/artifact.go b/pkg/dependency/parser/java/pom/artifact.go index 8fc20b236b51..00f4843f0b9d 100644 --- a/pkg/dependency/parser/java/pom/artifact.go +++ b/pkg/dependency/parser/java/pom/artifact.go @@ -12,6 +12,7 @@ import ( ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" "github.com/aquasecurity/trivy/pkg/version/doc" ) @@ -30,7 +31,7 @@ type artifact struct { Version version Licenses []string - Exclusions map[string]struct{} + Exclusions set.Set[string] Module bool Relationship ftypes.Relationship diff --git a/pkg/dependency/parser/java/pom/parse.go b/pkg/dependency/parser/java/pom/parse.go index 9f9afc35d99a..4d76ca3e86ae 100644 --- a/pkg/dependency/parser/java/pom/parse.go +++ b/pkg/dependency/parser/java/pom/parse.go @@ -22,6 +22,7 @@ import ( "github.com/aquasecurity/trivy/pkg/dependency/parser/utils" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" xio "github.com/aquasecurity/trivy/pkg/x/io" ) @@ -118,11 +119,11 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc rootArt := root.artifact() rootArt.Relationship = ftypes.RelationshipRoot - return p.parseRoot(rootArt, make(map[string]struct{})) + return p.parseRoot(rootArt, set.New[string]()) } // nolint: gocyclo -func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ftypes.Package, []ftypes.Dependency, error) { +func (p *Parser) parseRoot(root artifact, uniqModules set.Set[string]) ([]ftypes.Package, []ftypes.Dependency, error) { // Prepare a queue for dependencies queue := newArtifactQueue() @@ -145,10 +146,10 @@ func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ft // Modules should be handled separately so that they can have independent dependencies. // It means multi-module allows for duplicate dependencies. if art.Module { - if _, ok := uniqModules[art.String()]; ok { + if uniqModules.Contains(art.String()) { continue } - uniqModules[art.String()] = struct{}{} + uniqModules.Append(art.String()) modulePkgs, moduleDeps, err := p.parseRoot(art, uniqModules) if err != nil { @@ -251,7 +252,7 @@ func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ft // `mvn` shows modules separately from the root package and does not show module nesting. // So we can add all modules as dependencies of root package. if art.Relationship == ftypes.RelationshipRoot { - dependsOn = append(dependsOn, lo.Keys(uniqModules)...) + dependsOn = append(dependsOn, uniqModules.Items()...) } sort.Strings(dependsOn) @@ -340,7 +341,7 @@ type analysisResult struct { } type analysisOptions struct { - exclusions map[string]struct{} + exclusions set.Set[string] depManagement []pomDependency // from the root POM } @@ -348,6 +349,9 @@ func (p *Parser) analyze(pom *pom, opts analysisOptions) (analysisResult, error) if pom.nil() { return analysisResult{}, nil } + if opts.exclusions == nil { + opts.exclusions = set.New[string]() + } // Update remoteRepositories pomReleaseRemoteRepos, pomSnapshotRemoteRepos := pom.repositories(p.servers) p.releaseRemoteRepos = lo.Uniq(append(pomReleaseRemoteRepos, p.releaseRemoteRepos...)) @@ -492,19 +496,19 @@ func (p *Parser) mergeDependencies(child, parent []pomDependency) []pomDependenc }) } -func (p *Parser) filterDependencies(artifacts []artifact, exclusions map[string]struct{}) []artifact { +func (p *Parser) filterDependencies(artifacts []artifact, exclusions set.Set[string]) []artifact { return lo.Filter(artifacts, func(art artifact, _ int) bool { return !excludeDep(exclusions, art) }) } -func excludeDep(exclusions map[string]struct{}, art artifact) bool { - if _, ok := exclusions[art.Name()]; ok { +func excludeDep(exclusions set.Set[string], art artifact) bool { + if exclusions.Contains(art.Name()) { return true } // Maven can use "*" in GroupID and ArtifactID fields to exclude dependencies // https://maven.apache.org/pom.html#exclusions - for exlusion := range exclusions { + for exlusion := range exclusions.Iter() { // exclusion format - ":" e := strings.Split(exlusion, ":") if (e[0] == art.GroupID || e[0] == "*") && (e[1] == art.ArtifactID || e[1] == "*") { diff --git a/pkg/dependency/parser/java/pom/pom.go b/pkg/dependency/parser/java/pom/pom.go index 853dd2beb281..695a9feb950a 100644 --- a/pkg/dependency/parser/java/pom/pom.go +++ b/pkg/dependency/parser/java/pom/pom.go @@ -4,7 +4,6 @@ import ( "encoding/xml" "fmt" "io" - "maps" "net/url" "reflect" "strings" @@ -15,6 +14,7 @@ import ( "github.com/aquasecurity/trivy/pkg/dependency/parser/utils" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" "github.com/aquasecurity/trivy/pkg/x/slices" ) @@ -287,12 +287,12 @@ func (d pomDependency) ToArtifact(opts analysisOptions) artifact { // To avoid shadow adding exclusions to top pom's, // we need to initialize a new map for each new artifact // See `exclusions in child` test for more information - exclusions := make(map[string]struct{}) + exclusions := set.New[string]() if opts.exclusions != nil { - exclusions = maps.Clone(opts.exclusions) + exclusions = opts.exclusions.Clone() } for _, e := range d.Exclusions.Exclusion { - exclusions[fmt.Sprintf("%s:%s", e.GroupID, e.ArtifactID)] = struct{}{} + exclusions.Append(fmt.Sprintf("%s:%s", e.GroupID, e.ArtifactID)) } var locations ftypes.Locations From c317b52043852ca4fcc26f1749d23792590aad4e Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Fri, 20 Dec 2024 21:45:35 +0400 Subject: [PATCH 04/10] refactor(iac): use set Signed-off-by: knqyf263 --- pkg/iac/rego/embed.go | 5 +++-- pkg/iac/rego/load.go | 11 ++++------- pkg/iac/rego/options.go | 4 +--- pkg/iac/rego/scanner.go | 10 +++++----- pkg/iac/rules/register.go | 7 ++++--- pkg/iac/scanners/terraform/scanner.go | 5 +++-- pkg/iac/types/fskey_test.go | 12 ++++++------ 7 files changed, 26 insertions(+), 28 deletions(-) diff --git a/pkg/iac/rego/embed.go b/pkg/iac/rego/embed.go index c5416f50c0cc..71110d7fd7b7 100644 --- a/pkg/iac/rego/embed.go +++ b/pkg/iac/rego/embed.go @@ -13,6 +13,7 @@ import ( checks "github.com/aquasecurity/trivy-checks" "github.com/aquasecurity/trivy/pkg/iac/rules" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" ) var LoadAndRegister = sync.OnceFunc(func() { @@ -49,7 +50,7 @@ func RegisterRegoRules(modules map[string]*ast.Module) { } retriever := NewMetadataRetriever(compiler) - regoCheckIDs := make(map[string]struct{}) + regoCheckIDs := set.New[string]() for _, module := range modules { metadata, err := retriever.RetrieveMetadata(ctx, module) @@ -66,7 +67,7 @@ func RegisterRegoRules(modules map[string]*ast.Module) { } if !metadata.Deprecated { - regoCheckIDs[metadata.AVDID] = struct{}{} + regoCheckIDs.Append(metadata.AVDID) } rules.Register(metadata.ToRule()) diff --git a/pkg/iac/rego/load.go b/pkg/iac/rego/load.go index 29df31b2fa60..fc9d5cda58d8 100644 --- a/pkg/iac/rego/load.go +++ b/pkg/iac/rego/load.go @@ -119,15 +119,12 @@ func (s *Scanner) LoadPolicies(srcFS fs.FS) error { } // gather namespaces - uniq := make(map[string]struct{}) + uniq := set.New[string]() for _, module := range s.policies { namespace := getModuleNamespace(module) - uniq[namespace] = struct{}{} - } - var namespaces []string - for namespace := range uniq { - namespaces = append(namespaces, namespace) + uniq.Add(namespace) } + namespaces := uniq.Items() dataFS := srcFS if s.dataFS != nil { @@ -293,7 +290,7 @@ func (s *Scanner) filterModules(retriever *MetadataRetriever) error { } if IsBuiltinNamespace(getModuleNamespace(module)) { - if _, disabled := s.disabledCheckIDs[meta.ID]; disabled { // ignore builtin disabled checks + if s.disabledCheckIDs.Contains(meta.ID) { // ignore builtin disabled checks continue } } diff --git a/pkg/iac/rego/options.go b/pkg/iac/rego/options.go index 5b8df5f4affd..31026b2f3b57 100644 --- a/pkg/iac/rego/options.go +++ b/pkg/iac/rego/options.go @@ -110,9 +110,7 @@ func WithCustomSchemas(schemas map[string][]byte) options.ScannerOption { func WithDisabledCheckIDs(ids ...string) options.ScannerOption { return func(s options.ConfigurableScanner) { if ss, ok := s.(*Scanner); ok { - for _, id := range ids { - ss.disabledCheckIDs[id] = struct{}{} - } + ss.disabledCheckIDs.Append(ids...) } } } diff --git a/pkg/iac/rego/scanner.go b/pkg/iac/rego/scanner.go index fe942507300c..8f62b6096a26 100644 --- a/pkg/iac/rego/scanner.go +++ b/pkg/iac/rego/scanner.go @@ -70,7 +70,7 @@ type Scanner struct { embeddedChecks map[string]*ast.Module customSchemas map[string][]byte - disabledCheckIDs map[string]struct{} + disabledCheckIDs set.Set[string] } func (s *Scanner) trace(heading string, input any) { @@ -107,7 +107,7 @@ func NewScanner(source types.Source, opts ...options.ScannerOption) *Scanner { runtimeValues: addRuntimeValues(), logger: log.WithPrefix("rego"), customSchemas: make(map[string][]byte), - disabledCheckIDs: make(map[string]struct{}), + disabledCheckIDs: set.New[string](), } for _, opt := range opts { @@ -225,15 +225,15 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results, continue } - usedRules := make(map[string]struct{}) + usedRules := set.New[string]() // all rules for _, rule := range module.Rules { ruleName := rule.Head.Name.String() - if _, ok := usedRules[ruleName]; ok { + if usedRules.Contains(ruleName) { continue } - usedRules[ruleName] = struct{}{} + usedRules.Add(ruleName) if isEnforcedRule(ruleName) { ruleResults, err := s.applyRule(ctx, namespace, ruleName, inputs) if err != nil { diff --git a/pkg/iac/rules/register.go b/pkg/iac/rules/register.go index e07268255417..207502672c51 100755 --- a/pkg/iac/rules/register.go +++ b/pkg/iac/rules/register.go @@ -10,6 +10,7 @@ import ( "github.com/aquasecurity/trivy/pkg/iac/scan" dftypes "github.com/aquasecurity/trivy/pkg/iac/types" ruleTypes "github.com/aquasecurity/trivy/pkg/iac/types/rules" + "github.com/aquasecurity/trivy/pkg/set" ) type registry struct { @@ -74,14 +75,14 @@ func (r *registry) getFrameworkRules(fw ...framework.Framework) []ruleTypes.Regi if len(fw) == 0 { fw = []framework.Framework{framework.Default} } - unique := make(map[int]struct{}) + unique := set.New[int]() for _, f := range fw { for _, rule := range r.frameworks[f] { - if _, ok := unique[rule.Number]; ok { + if unique.Contains(rule.Number) { continue } registered = append(registered, rule) - unique[rule.Number] = struct{}{} + unique.Append(rule.Number) } } return registered diff --git a/pkg/iac/scanners/terraform/scanner.go b/pkg/iac/scanners/terraform/scanner.go index 9ddb2f3ef861..c9aed6f76b73 100644 --- a/pkg/iac/scanners/terraform/scanner.go +++ b/pkg/iac/scanners/terraform/scanner.go @@ -19,6 +19,7 @@ import ( "github.com/aquasecurity/trivy/pkg/iac/terraform" "github.com/aquasecurity/trivy/pkg/iac/types" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" ) var _ scanners.FSScanner = (*Scanner)(nil) @@ -31,7 +32,7 @@ type Scanner struct { options []options.ScannerOption parserOpt []parser.Option executorOpt []executor.Option - dirs map[string]struct{} + dirs set.Set[string] forceAllDirs bool regoScanner *rego.Scanner execLock sync.RWMutex @@ -55,7 +56,7 @@ func (s *Scanner) AddExecutorOptions(opts ...executor.Option) { func New(opts ...options.ScannerOption) *Scanner { s := &Scanner{ - dirs: make(map[string]struct{}), + dirs: set.New[string](), options: opts, logger: log.WithPrefix("terraform scanner"), } diff --git a/pkg/iac/types/fskey_test.go b/pkg/iac/types/fskey_test.go index 37be8fce4f0d..de91b32a0405 100644 --- a/pkg/iac/types/fskey_test.go +++ b/pkg/iac/types/fskey_test.go @@ -7,6 +7,8 @@ import ( "github.com/liamg/memoryfs" "github.com/stretchr/testify/assert" + + "github.com/aquasecurity/trivy/pkg/set" ) func Test_FSKey(t *testing.T) { @@ -18,22 +20,20 @@ func Test_FSKey(t *testing.T) { memoryfs.New(), } - keys := make(map[string]struct{}) + keys := set.New[string]() t.Run("uniqueness", func(t *testing.T) { for _, system := range systems { key := CreateFSKey(system) - _, ok := keys[key] - assert.False(t, ok, "filesystem keys should be unique") - keys[key] = struct{}{} + assert.False(t, keys.Contains(key), "filesystem keys should be unique") + keys.Append(key) } }) t.Run("reproducible", func(t *testing.T) { for _, system := range systems { key := CreateFSKey(system) - _, ok := keys[key] - assert.True(t, ok, "filesystem keys should be reproducible") + assert.True(t, keys.Contains(key), "filesystem keys should be reproducible") } }) } From c5d972ce9ba0865228754ea99b5dcf8689be03db Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Fri, 20 Dec 2024 21:45:53 +0400 Subject: [PATCH 05/10] feat: add Iter Signed-off-by: knqyf263 --- pkg/set/set.go | 5 +++++ pkg/set/unsafe.go | 17 +++++++++++------ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/pkg/set/set.go b/pkg/set/set.go index e1e985c27217..878cc750a8f6 100644 --- a/pkg/set/set.go +++ b/pkg/set/set.go @@ -1,5 +1,7 @@ package set +import "iter" + // Set defines the interface for set operations type Set[T comparable] interface { // Add adds an item to the set @@ -26,6 +28,9 @@ type Set[T comparable] interface { // Items returns all items in the set as a slice Items() []T + // Iter returns an iterator over the set + Iter() iter.Seq[T] + // Union returns a new set containing all items from both sets Union(other Set[T]) Set[T] diff --git a/pkg/set/unsafe.go b/pkg/set/unsafe.go index e538c81fac71..7bfb3d9499c1 100644 --- a/pkg/set/unsafe.go +++ b/pkg/set/unsafe.go @@ -1,6 +1,10 @@ package set -import "maps" +import ( + "iter" + "maps" + "slices" +) // unsafeSet represents a non-thread-safe set implementation // WARNING: This implementation is not thread-safe @@ -58,11 +62,12 @@ func (s unsafeSet[T]) Clone() Set[T] { // Items returns all items in the set as a slice func (s unsafeSet[T]) Items() []T { - items := make([]T, 0, len(s)) - for item := range s { - items = append(items, item) - } - return items + return slices.Collect(s.Iter()) +} + +// Iter returns an iterator over the set +func (s unsafeSet[T]) Iter() iter.Seq[T] { + return maps.Keys(s) } // Union returns a new set containing all items from both sets From 425ff1601e12b6d73c0b9c792de9080f9b3c39ea Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Fri, 20 Dec 2024 22:13:33 +0400 Subject: [PATCH 06/10] refactor(parser): use set Signed-off-by: knqyf263 --- pkg/dependency/parser/java/pom/parse.go | 6 +++--- pkg/dependency/parser/nodejs/npm/parse.go | 15 ++++++++------- pkg/dependency/parser/nodejs/pnpm/parse.go | 9 +++++---- pkg/dependency/parser/nuget/lock/parse.go | 2 +- pkg/dependency/parser/python/uv/parse.go | 21 +++++++++++---------- pkg/dependency/parser/utils/utils.go | 13 ------------- 6 files changed, 28 insertions(+), 38 deletions(-) diff --git a/pkg/dependency/parser/java/pom/parse.go b/pkg/dependency/parser/java/pom/parse.go index 4d76ca3e86ae..2ce4fb4e3936 100644 --- a/pkg/dependency/parser/java/pom/parse.go +++ b/pkg/dependency/parser/java/pom/parse.go @@ -412,16 +412,16 @@ func (p *Parser) resolveParent(pom *pom) error { } func (p *Parser) mergeDependencyManagements(depManagements ...[]pomDependency) []pomDependency { - uniq := make(map[string]struct{}) + uniq := set.New[string]() var depManagement []pomDependency // The preceding argument takes precedence. for _, dm := range depManagements { for _, dep := range dm { - if _, ok := uniq[dep.Name()]; ok { + if uniq.Contains(dep.Name()) { continue } depManagement = append(depManagement, dep) - uniq[dep.Name()] = struct{}{} + uniq.Append(dep.Name()) } } return depManagement diff --git a/pkg/dependency/parser/nodejs/npm/parse.go b/pkg/dependency/parser/nodejs/npm/parse.go index 4956f9b1cd10..eae16752ce97 100644 --- a/pkg/dependency/parser/nodejs/npm/parse.go +++ b/pkg/dependency/parser/nodejs/npm/parse.go @@ -17,6 +17,7 @@ import ( "github.com/aquasecurity/trivy/pkg/dependency/parser/utils" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" xio "github.com/aquasecurity/trivy/pkg/x/io" ) @@ -91,7 +92,7 @@ func (p *Parser) parseV2(packages map[string]Package) ([]ftypes.Package, []ftype // https://docs.npmjs.com/cli/v9/configuring-npm/package-lock-json#packages p.resolveLinks(packages) - directDeps := make(map[string]struct{}) + directDeps := set.New[string]() for name, version := range lo.Assign(packages[""].Dependencies, packages[""].OptionalDependencies, packages[""].DevDependencies, packages[""].PeerDependencies) { pkgPath := joinPaths(nodeModulesDir, name) if _, ok := packages[pkgPath]; !ok { @@ -101,7 +102,7 @@ func (p *Parser) parseV2(packages map[string]Package) ([]ftypes.Package, []ftype } // Store the package paths of direct dependencies // e.g. node_modules/body-parser - directDeps[pkgPath] = struct{}{} + directDeps.Append(pkgPath) } for pkgPath, pkg := range packages { @@ -366,13 +367,13 @@ func (p *Parser) pkgNameFromPath(pkgPath string) string { func uniqueDeps(deps []ftypes.Dependency) []ftypes.Dependency { var uniqDeps ftypes.Dependencies - unique := make(map[string]struct{}) + unique := set.New[string]() for _, dep := range deps { sort.Strings(dep.DependsOn) depKey := fmt.Sprintf("%s:%s", dep.ID, strings.Join(dep.DependsOn, ",")) - if _, ok := unique[depKey]; !ok { - unique[depKey] = struct{}{} + if !unique.Contains(depKey) { + unique.Append(depKey) uniqDeps = append(uniqDeps, dep) } } @@ -381,11 +382,11 @@ func uniqueDeps(deps []ftypes.Dependency) []ftypes.Dependency { return uniqDeps } -func isIndirectPkg(pkgPath string, directDeps map[string]struct{}) bool { +func isIndirectPkg(pkgPath string, directDeps set.Set[string]) bool { // A project can contain 2 different versions of the same dependency. // e.g. `node_modules/string-width/node_modules/strip-ansi` and `node_modules/string-ansi` // direct dependencies always have root path (`node_modules/`) - if _, ok := directDeps[pkgPath]; ok { + if directDeps.Contains(pkgPath) { return false } return true diff --git a/pkg/dependency/parser/nodejs/pnpm/parse.go b/pkg/dependency/parser/nodejs/pnpm/parse.go index 6f85411f40d0..0817fd48899f 100644 --- a/pkg/dependency/parser/nodejs/pnpm/parse.go +++ b/pkg/dependency/parser/nodejs/pnpm/parse.go @@ -14,6 +14,7 @@ import ( "github.com/aquasecurity/trivy/pkg/dependency" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" xio "github.com/aquasecurity/trivy/pkg/x/io" ) @@ -215,7 +216,7 @@ func (p *Parser) parseV9(lockFile LockFile) ([]ftypes.Package, []ftypes.Dependen } } - visited := make(map[string]struct{}) + visited := set.New[string]() // Overwrite the `Dev` field for dev deps and their child dependencies. for _, pkg := range resolvedPkgs { if !pkg.Dev { @@ -227,8 +228,8 @@ func (p *Parser) parseV9(lockFile LockFile) ([]ftypes.Package, []ftypes.Dependen } // markRootPkgs sets `Dev` to false for non dev dependency. -func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps map[string]ftypes.Dependency, visited map[string]struct{}) { - if _, ok := visited[id]; ok { +func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps map[string]ftypes.Dependency, visited set.Set[string]) { + if visited.Contains(id) { return } pkg, ok := pkgs[id] @@ -238,7 +239,7 @@ func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps ma pkg.Dev = false pkgs[id] = pkg - visited[id] = struct{}{} + visited.Append(id) // Update child deps for _, depID := range deps[id].DependsOn { diff --git a/pkg/dependency/parser/nuget/lock/parse.go b/pkg/dependency/parser/nuget/lock/parse.go index 7852680f5749..812a04515850 100644 --- a/pkg/dependency/parser/nuget/lock/parse.go +++ b/pkg/dependency/parser/nuget/lock/parse.go @@ -76,7 +76,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc } if savedDependsOn, ok := depsMap[depId]; ok { - dependsOn = utils.UniqueStrings(append(dependsOn, savedDependsOn...)) + dependsOn = lo.Uniq(append(dependsOn, savedDependsOn...)) } if len(dependsOn) > 0 { diff --git a/pkg/dependency/parser/python/uv/parse.go b/pkg/dependency/parser/python/uv/parse.go index 6c200c313cee..c50a9d4ebded 100644 --- a/pkg/dependency/parser/python/uv/parse.go +++ b/pkg/dependency/parser/python/uv/parse.go @@ -9,6 +9,7 @@ import ( "github.com/aquasecurity/trivy/pkg/dependency" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" + "github.com/aquasecurity/trivy/pkg/set" xio "github.com/aquasecurity/trivy/pkg/x/io" ) @@ -22,25 +23,25 @@ func (l Lock) packages() map[string]Package { }) } -func (l Lock) directDeps(root Package) map[string]struct{} { - deps := make(map[string]struct{}) +func (l Lock) directDeps(root Package) set.Set[string] { + deps := set.New[string]() for _, dep := range root.Dependencies { - deps[dep.Name] = struct{}{} + deps.Append(dep.Name) } return deps } -func prodDeps(root Package, packages map[string]Package) map[string]struct{} { - visited := make(map[string]struct{}) +func prodDeps(root Package, packages map[string]Package) set.Set[string] { + visited := set.New[string]() walkPackageDeps(root, packages, visited) return visited } -func walkPackageDeps(pkg Package, packages map[string]Package, visited map[string]struct{}) { - if _, ok := visited[pkg.Name]; ok { +func walkPackageDeps(pkg Package, packages map[string]Package, visited set.Set[string]) { + if visited.Contains(pkg.Name) { return } - visited[pkg.Name] = struct{}{} + visited.Append(pkg.Name) for _, dep := range pkg.Dependencies { depPkg, exists := packages[dep.Name] if !exists { @@ -119,7 +120,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc ) for _, pkg := range lock.Packages { - if _, ok := prodDeps[pkg.Name]; !ok { + if !prodDeps.Contains(pkg.Name) { continue } @@ -127,7 +128,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc relationship := ftypes.RelationshipIndirect if pkg.isRoot() { relationship = ftypes.RelationshipRoot - } else if _, ok := directDeps[pkg.Name]; ok { + } else if directDeps.Contains(pkg.Name) { relationship = ftypes.RelationshipDirect } diff --git a/pkg/dependency/parser/utils/utils.go b/pkg/dependency/parser/utils/utils.go index ce2aff36976b..36afd9025310 100644 --- a/pkg/dependency/parser/utils/utils.go +++ b/pkg/dependency/parser/utils/utils.go @@ -10,19 +10,6 @@ import ( ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ) -func UniqueStrings(ss []string) []string { - var results []string - uniq := make(map[string]struct{}) - for _, s := range ss { - if _, ok := uniq[s]; ok { - continue - } - results = append(results, s) - uniq[s] = struct{}{} - } - return results -} - func UniquePackages(pkgs []ftypes.Package) []ftypes.Package { if len(pkgs) == 0 { return nil From b70ee8769755c3960564d71d4144a8839243eac4 Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Fri, 20 Dec 2024 22:13:49 +0400 Subject: [PATCH 07/10] refactor: use set Signed-off-by: knqyf263 --- pkg/compliance/spec/compliance.go | 8 +- pkg/fanal/analyzer/imgconf/apk/apk.go | 20 +++-- pkg/fanal/analyzer/imgconf/apk/apk_test.go | 87 +++++----------------- pkg/fanal/analyzer/pkg/apk/apk.go | 7 +- pkg/fanal/image/daemon/image.go | 4 +- pkg/fanal/utils/utils.go | 8 -- pkg/iac/rego/result.go | 4 +- pkg/iac/rego/scanner.go | 19 ++--- pkg/licensing/classifier.go | 7 +- pkg/remote/remote_test.go | 9 ++- pkg/report/table/vulnerability.go | 36 ++++----- pkg/scanner/langpkg/scan.go | 9 ++- pkg/scanner/local/scan.go | 7 +- 13 files changed, 83 insertions(+), 142 deletions(-) diff --git a/pkg/compliance/spec/compliance.go b/pkg/compliance/spec/compliance.go index 70355eaa926f..0d9b4cde810c 100644 --- a/pkg/compliance/spec/compliance.go +++ b/pkg/compliance/spec/compliance.go @@ -6,13 +6,13 @@ import ( "path/filepath" "strings" - "github.com/samber/lo" "golang.org/x/xerrors" "gopkg.in/yaml.v3" sp "github.com/aquasecurity/trivy-checks/pkg/spec" iacTypes "github.com/aquasecurity/trivy/pkg/iac/types" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" "github.com/aquasecurity/trivy/pkg/types" ) @@ -31,17 +31,17 @@ const ( // Scanners reads spec control and determines the scanners by check ID prefix func (cs *ComplianceSpec) Scanners() (types.Scanners, error) { - scannerTypes := make(map[types.Scanner]struct{}) + scannerTypes := set.New[types.Scanner]() for _, control := range cs.Spec.Controls { for _, check := range control.Checks { scannerType := scannerByCheckID(check.ID) if scannerType == types.UnknownScanner { return nil, xerrors.Errorf("unsupported check ID: %s", check.ID) } - scannerTypes[scannerType] = struct{}{} + scannerTypes.Append(scannerType) } } - return lo.Keys(scannerTypes), nil + return scannerTypes.Items(), nil } // CheckIDs return list of compliance check IDs diff --git a/pkg/fanal/analyzer/imgconf/apk/apk.go b/pkg/fanal/analyzer/imgconf/apk/apk.go index 04aa244313c1..a43e4838221e 100644 --- a/pkg/fanal/analyzer/imgconf/apk/apk.go +++ b/pkg/fanal/analyzer/imgconf/apk/apk.go @@ -19,6 +19,7 @@ import ( "github.com/aquasecurity/trivy/pkg/dependency" "github.com/aquasecurity/trivy/pkg/fanal/analyzer" "github.com/aquasecurity/trivy/pkg/fanal/types" + "github.com/aquasecurity/trivy/pkg/set" ) const ( @@ -179,33 +180,30 @@ func (a alpineCmdAnalyzer) parseCommand(command string, envs map[string]string) return pkgs } func (a alpineCmdAnalyzer) resolveDependencies(apkIndexArchive *apkIndex, originalPkgs []string) (pkgs []string) { - uniqPkgs := make(map[string]struct{}) + uniqPkgs := set.New[string]() for _, pkgName := range originalPkgs { - if _, ok := uniqPkgs[pkgName]; ok { + if uniqPkgs.Contains(pkgName) { continue } - seenPkgs := make(map[string]struct{}) + seenPkgs := set.New[string]() for _, p := range a.resolveDependency(apkIndexArchive, pkgName, seenPkgs) { - uniqPkgs[p] = struct{}{} + uniqPkgs.Append(p) } } - for pkg := range uniqPkgs { - pkgs = append(pkgs, pkg) - } - return pkgs + return uniqPkgs.Items() } func (a alpineCmdAnalyzer) resolveDependency(apkIndexArchive *apkIndex, pkgName string, - seenPkgs map[string]struct{}) (pkgNames []string) { + seenPkgs set.Set[string]) (pkgNames []string) { pkg, ok := apkIndexArchive.Package[pkgName] if !ok { return nil } - if _, ok = seenPkgs[pkgName]; ok { + if seenPkgs.Contains(pkgName) { return nil } - seenPkgs[pkgName] = struct{}{} + seenPkgs.Append(pkgName) pkgNames = append(pkgNames, pkgName) for _, dependency := range pkg.Dependencies { diff --git a/pkg/fanal/analyzer/imgconf/apk/apk_test.go b/pkg/fanal/analyzer/imgconf/apk/apk_test.go index 93da80f87e0b..78fce4cba09d 100644 --- a/pkg/fanal/analyzer/imgconf/apk/apk_test.go +++ b/pkg/fanal/analyzer/imgconf/apk/apk_test.go @@ -19,6 +19,7 @@ import ( "github.com/aquasecurity/trivy/pkg/fanal/analyzer" "github.com/aquasecurity/trivy/pkg/fanal/types" + "github.com/aquasecurity/trivy/pkg/set" ) var ( @@ -1508,86 +1509,41 @@ func TestResolveDependency(t *testing.T) { var tests = map[string]struct { pkgName string apkIndexArchivePath string - expected map[string]struct{} + want set.Set[string] }{ "low": { pkgName: "libblkid", apkIndexArchivePath: "testdata/history_v3.9.json", - expected: map[string]struct{}{ - "libblkid": {}, - "libuuid": {}, - "musl": {}, - }, + want: set.New("libblkid", "libuuid", "musl"), }, "medium": { pkgName: "libgcab", apkIndexArchivePath: "testdata/history_v3.9.json", - expected: map[string]struct{}{ - "busybox": {}, - "libblkid": {}, - "libuuid": {}, - "musl": {}, - "libmount": {}, - "pcre": {}, - "glib": {}, - "libgcab": {}, - "libintl": {}, - "zlib": {}, - "libffi": {}, - }, + want: set.New("busybox", "libblkid", "libuuid", "musl", "libmount", "pcre", "glib", + "libgcab", "libintl", "zlib", "libffi"), }, "high": { pkgName: "postgresql", apkIndexArchivePath: "testdata/history_v3.9.json", - expected: map[string]struct{}{ - "busybox": {}, - "ncurses-terminfo-base": {}, - "ncurses-terminfo": {}, - "libedit": {}, - "db": {}, - "libsasl": {}, - "libldap": {}, - "libpq": {}, - "postgresql-client": {}, - "tzdata": {}, - "libxml2": {}, - "postgresql": {}, - "musl": {}, - "libcrypto1.1": {}, - "libssl1.1": {}, - "ncurses-libs": {}, - "zlib": {}, - }, + want: set.New("busybox", "ncurses-terminfo-base", "ncurses-terminfo", "libedit", "db", "libsasl", + "libldap", "libpq", "postgresql-client", "tzdata", "libxml2", "postgresql", "musl", "libcrypto1.1", + "libssl1.1", "ncurses-libs", "zlib"), }, "package alias": { pkgName: "sqlite-dev", apkIndexArchivePath: "testdata/history_v3.9.json", - expected: map[string]struct{}{ - "sqlite-dev": {}, - "sqlite-libs": {}, - "pkgconf": {}, // pkgconfig => pkgconf - "musl": {}, - }, + want: set.New( + "sqlite-dev", + "sqlite-libs", + "pkgconf", // pkgconfig => pkgconf + "musl", + ), }, "circular dependencies": { pkgName: "nodejs", apkIndexArchivePath: "testdata/history_v3.7.json", - expected: map[string]struct{}{ - "busybox": {}, - "c-ares": {}, - "ca-certificates": {}, - "http-parser": {}, - "libcrypto1.0": {}, - "libgcc": {}, - "libressl2.6-libcrypto": {}, - "libssl1.0": {}, - "libstdc++": {}, - "libuv": {}, - "musl": {}, - "nodejs": {}, - "nodejs-npm": {}, - "zlib": {}, - }, + want: set.New("busybox", "c-ares", "ca-certificates", "http-parser", "libcrypto1.0", "libgcc", + "libressl2.6-libcrypto", "libssl1.0", "libstdc++", "libuv", "musl", "nodejs", "nodejs-npm", "zlib"), }, } analyzer := alpineCmdAnalyzer{} @@ -1600,15 +1556,10 @@ func TestResolveDependency(t *testing.T) { if err = json.NewDecoder(f).Decode(&apkIndexArchive); err != nil { t.Fatalf("unexpected error: %s", err) } - circularDependencyCheck := make(map[string]struct{}) + circularDependencyCheck := set.New[string]() pkgs := analyzer.resolveDependency(apkIndexArchive, v.pkgName, circularDependencyCheck) - actual := make(map[string]struct{}) - for _, pkg := range pkgs { - actual[pkg] = struct{}{} - } - if !reflect.DeepEqual(v.expected, actual) { - t.Errorf("[%s]\n%s", testName, pretty.Compare(v.expected, actual)) - } + got := set.New(pkgs...) + assert.Equal(t, v.want, got, testName) } } diff --git a/pkg/fanal/analyzer/pkg/apk/apk.go b/pkg/fanal/analyzer/pkg/apk/apk.go index 962398600fc5..69cc72eb5d3f 100644 --- a/pkg/fanal/analyzer/pkg/apk/apk.go +++ b/pkg/fanal/analyzer/pkg/apk/apk.go @@ -20,6 +20,7 @@ import ( "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/licensing" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" ) func init() { @@ -185,13 +186,13 @@ func (a alpinePkgAnalyzer) consolidateDependencies(pkgs []types.Package, provide } func (a alpinePkgAnalyzer) uniquePkgs(pkgs []types.Package) (uniqPkgs []types.Package) { - uniq := make(map[string]struct{}) + uniq := set.New[string]() for _, pkg := range pkgs { - if _, ok := uniq[pkg.Name]; ok { + if uniq.Contains(pkg.Name) { continue } uniqPkgs = append(uniqPkgs, pkg) - uniq[pkg.Name] = struct{}{} + uniq.Append(pkg.Name) } return uniqPkgs } diff --git a/pkg/fanal/image/daemon/image.go b/pkg/fanal/image/daemon/image.go index d3787cc0abc7..3409e278d1c7 100644 --- a/pkg/fanal/image/daemon/image.go +++ b/pkg/fanal/image/daemon/image.go @@ -227,8 +227,8 @@ func (img *image) imageConfig(config *container.Config) v1.Config { if len(config.ExposedPorts) > 0 { c.ExposedPorts = make(map[string]struct{}) - for port := range c.ExposedPorts { - c.ExposedPorts[port] = struct{}{} + for port := range config.ExposedPorts { + c.ExposedPorts[port.Port()] = struct{}{} } } diff --git a/pkg/fanal/utils/utils.go b/pkg/fanal/utils/utils.go index 463c8dd1f255..4e82a9d01ab7 100644 --- a/pkg/fanal/utils/utils.go +++ b/pkg/fanal/utils/utils.go @@ -56,14 +56,6 @@ func IsGzip(f *bufio.Reader) bool { return buf[0] == 0x1F && buf[1] == 0x8B && buf[2] == 0x8 } -func Keys(m map[string]struct{}) []string { - var keys []string - for k := range m { - keys = append(keys, k) - } - return keys -} - func IsExecutable(fileInfo os.FileInfo) bool { // For Windows if filepath.Ext(fileInfo.Name()) == ".exe" { diff --git a/pkg/iac/rego/result.go b/pkg/iac/rego/result.go index c4045705249c..87723a11ee3d 100644 --- a/pkg/iac/rego/result.go +++ b/pkg/iac/rego/result.go @@ -121,7 +121,7 @@ func parseLineNumber(raw any) int { return n } -func (s *Scanner) convertResults(set rego.ResultSet, input Input, namespace, rule string, traces []string) scan.Results { +func (s *Scanner) convertResults(resultSet rego.ResultSet, input Input, namespace, rule string, traces []string) scan.Results { var results scan.Results offset := 0 @@ -136,7 +136,7 @@ func (s *Scanner) convertResults(set rego.ResultSet, input Input, namespace, rul } } } - for _, result := range set { + for _, result := range resultSet { for _, expression := range result.Expressions { values, ok := expression.Value.([]any) if !ok { diff --git a/pkg/iac/rego/scanner.go b/pkg/iac/rego/scanner.go index 8f62b6096a26..1bf3a550de93 100644 --- a/pkg/iac/rego/scanner.go +++ b/pkg/iac/rego/scanner.go @@ -24,11 +24,7 @@ import ( "github.com/aquasecurity/trivy/pkg/set" ) -var checkTypesWithSubtype = map[types.Source]struct{}{ - types.SourceCloud: {}, - types.SourceDefsec: {}, - types.SourceKubernetes: {}, -} +var checkTypesWithSubtype = set.New[types.Source](types.SourceCloud, types.SourceDefsec, types.SourceKubernetes) var supportedProviders = makeSupportedProviders() @@ -145,7 +141,7 @@ func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, d } instance := rego.New(regoOptions...) - set, err := instance.Eval(ctx) + resultSet, err := instance.Eval(ctx) if err != nil { return nil, nil, err } @@ -163,7 +159,7 @@ func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, d traces = strings.Split(traceBuffer.String(), "\n") } } - return set, traces, nil + return resultSet, traces, nil } type Input struct { @@ -255,8 +251,7 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results, } func isPolicyWithSubtype(sourceType types.Source) bool { - _, exists := checkTypesWithSubtype[sourceType] - return exists + return checkTypesWithSubtype.Contains(sourceType) } func checkSubtype(ii map[string]any, provider string, subTypes []SubType) bool { @@ -327,12 +322,12 @@ func (s *Scanner) applyRule(ctx context.Context, namespace, rule string, inputs continue } - set, traces, err := s.runQuery(ctx, qualified, parsedInput, false) + resultSet, traces, err := s.runQuery(ctx, qualified, parsedInput, false) if err != nil { return nil, err } - s.trace("RESULTSET", set) - ruleResults := s.convertResults(set, input, namespace, rule, traces) + s.trace("RESULTSET", resultSet) + ruleResults := s.convertResults(resultSet, input, namespace, rule, traces) if len(ruleResults) == 0 { // It passed because we didn't find anything wrong (NOT because it didn't exist) var result regoResult result.FS = input.FS diff --git a/pkg/licensing/classifier.go b/pkg/licensing/classifier.go index 74f825f303a7..5f230624bcc9 100644 --- a/pkg/licensing/classifier.go +++ b/pkg/licensing/classifier.go @@ -12,6 +12,7 @@ import ( "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" ) var ( @@ -43,7 +44,7 @@ func Classify(filePath string, r io.Reader, confidenceLevel float64) (*types.Lic var findings types.LicenseFindings var matchType types.LicenseType - seen := make(map[string]struct{}) + seen := set.New[string]() // cf.Match is not thread safe m.Lock() @@ -57,11 +58,11 @@ func Classify(filePath string, r io.Reader, confidenceLevel float64) (*types.Lic if match.Confidence <= confidenceLevel { continue } - if _, ok := seen[match.Name]; ok { + if seen.Contains(match.Name) { continue } - seen[match.Name] = struct{}{} + seen.Append(match.Name) switch match.MatchType { case "Header": diff --git a/pkg/remote/remote_test.go b/pkg/remote/remote_test.go index fb64deb5d4c4..27ea8079153b 100644 --- a/pkg/remote/remote_test.go +++ b/pkg/remote/remote_test.go @@ -20,6 +20,7 @@ import ( "github.com/aquasecurity/testdocker/registry" "github.com/aquasecurity/testdocker/tarfile" "github.com/aquasecurity/trivy/pkg/fanal/types" + "github.com/aquasecurity/trivy/pkg/set" "github.com/aquasecurity/trivy/pkg/version/app" ) @@ -216,13 +217,13 @@ type userAgentsTrackingHandler struct { hr http.Handler mu sync.Mutex - agents map[string]struct{} + agents set.Set[string] } func newUserAgentsTrackingHandler(hr http.Handler) *userAgentsTrackingHandler { return &userAgentsTrackingHandler{ hr: hr, - agents: make(map[string]struct{}), + agents: set.New[string](), } } @@ -230,7 +231,7 @@ func (uh *userAgentsTrackingHandler) ServeHTTP(rw http.ResponseWriter, r *http.R for _, agent := range r.Header["User-Agent"] { // Skip test framework user agent if agent != "Go-http-client/1.1" { - uh.agents[agent] = struct{}{} + uh.agents.Append(agent) } } uh.hr.ServeHTTP(rw, r) @@ -271,7 +272,7 @@ func TestUserAgents(t *testing.T) { require.NoError(t, err) require.Len(t, tracker.agents, 1) - _, ok := tracker.agents[fmt.Sprintf("trivy/%s go-containerregistry", app.Version())] + ok := tracker.agents.Contains(fmt.Sprintf("trivy/%s go-containerregistry", app.Version())) require.True(t, ok, `user-agent header equals to "trivy/dev go-containerregistry"`) } diff --git a/pkg/report/table/vulnerability.go b/pkg/report/table/vulnerability.go index 450349ed5fcb..489004033a73 100644 --- a/pkg/report/table/vulnerability.go +++ b/pkg/report/table/vulnerability.go @@ -19,6 +19,7 @@ import ( dbTypes "github.com/aquasecurity/trivy-db/pkg/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" "github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/version/doc" ) @@ -279,7 +280,7 @@ Dependency Origin Tree (Reversed) topLvlID := tml.Sprintf("%s, (%s)", vulnPkg.ID, strings.Join(summaries, ", ")) branch := root.AddBranch(topLvlID) - addParents(branch, vulnPkg, parents, ancestors, map[string]struct{}{vulnPkg.ID: {}}, 1) + addParents(branch, vulnPkg, parents, ancestors, set.New(vulnPkg.ID), 1) } r.printf(root.String()) @@ -291,17 +292,17 @@ func (r *vulnerabilityRenderer) printf(format string, args ...any) { } func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string]ftypes.Packages, ancestors map[string][]string, - seen map[string]struct{}, depth int) { + seen set.Set[string], depth int) { if pkg.Relationship == ftypes.RelationshipDirect { return } - roots := make(map[string]struct{}) + roots := set.New[string]() for _, parent := range parentMap[pkg.ID] { - if _, ok := seen[parent.ID]; ok { + if seen.Contains(parent.ID) { continue } - seen[parent.ID] = struct{}{} // to avoid infinite loops + seen.Append(parent.ID) // to avoid infinite loops if depth == 1 && parent.Relationship == ftypes.RelationshipDirect { topItem.AddBranch(parent.ID) @@ -309,15 +310,14 @@ func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string // We omit intermediate dependencies and show only direct dependencies // as this could make the dependency tree huge. for _, ancestor := range ancestors[parent.ID] { - roots[ancestor] = struct{}{} + roots.Append(ancestor) } } } // Omitted - rootIDs := lo.Filter(lo.Keys(roots), func(pkgID string, _ int) bool { - _, ok := seen[pkgID] - return !ok + rootIDs := lo.Filter(roots.Items(), func(pkgID string, _ int) bool { + return !seen.Contains(pkgID) }) sort.Strings(rootIDs) if len(rootIDs) > 0 { @@ -331,21 +331,21 @@ func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string func traverseAncestors(pkgs []ftypes.Package, parentMap map[string]ftypes.Packages) map[string][]string { ancestors := make(map[string][]string) for _, pkg := range pkgs { - ancestors[pkg.ID] = findAncestor(pkg.ID, parentMap, make(map[string]struct{})) + ancestors[pkg.ID] = findAncestor(pkg.ID, parentMap, set.New[string]()) } return ancestors } -func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen map[string]struct{}) []string { - ancestors := make(map[string]struct{}) - seen[pkgID] = struct{}{} +func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen set.Set[string]) []string { + ancestors := set.New[string]() + seen.Append(pkgID) for _, parent := range parentMap[pkgID] { - if _, ok := seen[parent.ID]; ok { + if seen.Contains(parent.ID) { continue } switch { case parent.Relationship == ftypes.RelationshipDirect: - ancestors[parent.ID] = struct{}{} + ancestors.Append(parent.ID) case len(parentMap[parent.ID]) == 0: // Some package managers, such as "package-lock.json" v1, can retrieve package dependencies but not relationships. // We try to guess direct dependencies in this case. A dependency with no parents must be a direct dependency. @@ -358,14 +358,14 @@ func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen map[s // // Even if `styled-components` is not marked as a direct dependency, it must be a direct dependency // as it has no parents. Note that it doesn't mean `fbjs` is an indirect dependency. - ancestors[parent.ID] = struct{}{} + ancestors.Append(parent.ID) default: for _, a := range findAncestor(parent.ID, parentMap, seen) { - ancestors[a] = struct{}{} + ancestors.Append(a) } } } - return lo.Keys(ancestors) + return ancestors.Items() } var jarExtensions = []string{ diff --git a/pkg/scanner/langpkg/scan.go b/pkg/scanner/langpkg/scan.go index df6068b3ddd3..565c3b2c07c1 100644 --- a/pkg/scanner/langpkg/scan.go +++ b/pkg/scanner/langpkg/scan.go @@ -9,6 +9,7 @@ import ( "github.com/aquasecurity/trivy/pkg/detector/library" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/set" "github.com/aquasecurity/trivy/pkg/types" ) @@ -41,7 +42,7 @@ func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, opts types. } var results types.Results - printedTypes := make(map[ftypes.LangType]struct{}) + printedTypes := set.New[ftypes.LangType]() for _, app := range apps { if len(app.Packages) == 0 { continue @@ -76,13 +77,13 @@ func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, opts types. return results, nil } -func (s *scanner) scanVulnerabilities(ctx context.Context, app ftypes.Application, printedTypes map[ftypes.LangType]struct{}) ( +func (s *scanner) scanVulnerabilities(ctx context.Context, app ftypes.Application, printedTypes set.Set[ftypes.LangType]) ( []types.DetectedVulnerability, error) { // Prevent the same log messages from being displayed many times for the same type. - if _, ok := printedTypes[app.Type]; !ok { + if !printedTypes.Contains(app.Type) { log.InfoContext(ctx, "Detecting vulnerabilities...") - printedTypes[app.Type] = struct{}{} + printedTypes.Append(app.Type) } log.DebugContext(ctx, "Scanning packages for vulnerabilities", log.FilePath(app.FilePath)) diff --git a/pkg/scanner/local/scan.go b/pkg/scanner/local/scan.go index 7fd36fbb2643..1d117999ede1 100644 --- a/pkg/scanner/local/scan.go +++ b/pkg/scanner/local/scan.go @@ -24,6 +24,7 @@ import ( "github.com/aquasecurity/trivy/pkg/scanner/langpkg" "github.com/aquasecurity/trivy/pkg/scanner/ospkg" "github.com/aquasecurity/trivy/pkg/scanner/post" + "github.com/aquasecurity/trivy/pkg/set" "github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/vulnerability" @@ -458,12 +459,12 @@ func mergePkgs(pkgs, pkgsFromCommands []ftypes.Package, options types.ScanOption } // pkg has priority over pkgsFromCommands - uniqPkgs := make(map[string]struct{}) + uniqPkgs := set.New[string]() for _, pkg := range pkgs { - uniqPkgs[pkg.Name] = struct{}{} + uniqPkgs.Append(pkg.Name) } for _, pkg := range pkgsFromCommands { - if _, ok := uniqPkgs[pkg.Name]; ok { + if uniqPkgs.Contains(pkg.Name) { continue } pkgs = append(pkgs, pkg) From 360e6e0e0dc7b799a2d335325f1fb0536023fb71 Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Fri, 20 Dec 2024 22:16:02 +0400 Subject: [PATCH 08/10] refactor(set): remove Add() Signed-off-by: knqyf263 --- pkg/iac/rego/load.go | 2 +- pkg/iac/rego/scanner.go | 6 +-- pkg/set/set.go | 3 -- pkg/set/unsafe.go | 5 -- pkg/set/unsafe_test.go | 108 ++++++++++++++++++++-------------------- 5 files changed, 58 insertions(+), 66 deletions(-) diff --git a/pkg/iac/rego/load.go b/pkg/iac/rego/load.go index fc9d5cda58d8..dba0eda4917d 100644 --- a/pkg/iac/rego/load.go +++ b/pkg/iac/rego/load.go @@ -122,7 +122,7 @@ func (s *Scanner) LoadPolicies(srcFS fs.FS) error { uniq := set.New[string]() for _, module := range s.policies { namespace := getModuleNamespace(module) - uniq.Add(namespace) + uniq.Append(namespace) } namespaces := uniq.Items() diff --git a/pkg/iac/rego/scanner.go b/pkg/iac/rego/scanner.go index 1bf3a550de93..70bb98dd6bf4 100644 --- a/pkg/iac/rego/scanner.go +++ b/pkg/iac/rego/scanner.go @@ -31,9 +31,9 @@ var supportedProviders = makeSupportedProviders() func makeSupportedProviders() set.Set[string] { m := set.New[string]() for _, p := range providers.AllProviders() { - m.Add(string(p)) + m.Append(string(p)) } - m.Add("kind") // kubernetes + m.Append("kind") // kubernetes return m } @@ -229,7 +229,7 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results, if usedRules.Contains(ruleName) { continue } - usedRules.Add(ruleName) + usedRules.Append(ruleName) if isEnforcedRule(ruleName) { ruleResults, err := s.applyRule(ctx, namespace, ruleName, inputs) if err != nil { diff --git a/pkg/set/set.go b/pkg/set/set.go index 878cc750a8f6..7dadcd3b414f 100644 --- a/pkg/set/set.go +++ b/pkg/set/set.go @@ -4,9 +4,6 @@ import "iter" // Set defines the interface for set operations type Set[T comparable] interface { - // Add adds an item to the set - Add(item T) - // Append adds multiple items to the set and returns the new size Append(val ...T) int diff --git a/pkg/set/unsafe.go b/pkg/set/unsafe.go index 7bfb3d9499c1..2dd8bd3bbb35 100644 --- a/pkg/set/unsafe.go +++ b/pkg/set/unsafe.go @@ -19,11 +19,6 @@ func New[T comparable](values ...T) Set[T] { return s } -// Add adds an item to the set -func (s unsafeSet[T]) Add(item T) { - s[item] = struct{}{} -} - // Append adds multiple items to the set and returns the new size func (s unsafeSet[T]) Append(val ...T) int { for _, item := range val { diff --git a/pkg/set/unsafe_test.go b/pkg/set/unsafe_test.go index 4c19e75c0ce6..b3bdfa353fe7 100644 --- a/pkg/set/unsafe_test.go +++ b/pkg/set/unsafe_test.go @@ -87,7 +87,7 @@ func Test_unsafeSet_Add(t *testing.T) { { name: "add duplicate integer", prepare: func(s set.Set[any]) { - s.Add(1) + s.Append(1) }, input: 1, wantSize: 1, @@ -127,7 +127,7 @@ func Test_unsafeSet_Add(t *testing.T) { if tt.prepare != nil { tt.prepare(s) } - s.Add(tt.input) + s.Append(tt.input) got := s.Size() assert.Equal(t, tt.wantSize, got, "unexpected set size") @@ -156,7 +156,7 @@ func Test_unsafeSet_Append(t *testing.T) { { name: "append with duplicates", prepare: func(s set.Set[int]) { - s.Add(1) + s.Append(1) }, input: []int{ 1, @@ -170,7 +170,7 @@ func Test_unsafeSet_Append(t *testing.T) { { name: "append empty slice", prepare: func(s set.Set[int]) { - s.Add(1) + s.Append(1) }, input: []int{}, wantSize: 1, @@ -205,7 +205,7 @@ func Test_unsafeSet_Remove(t *testing.T) { { name: "remove existing element", prepare: func(s set.Set[int]) { - s.Add(1) + s.Append(1) }, input: 1, wantSize: 0, @@ -213,7 +213,7 @@ func Test_unsafeSet_Remove(t *testing.T) { { name: "remove non-existing element", prepare: func(s set.Set[int]) { - s.Add(1) + s.Append(1) }, input: 2, wantSize: 1, @@ -249,9 +249,9 @@ func Test_unsafeSet_Clear(t *testing.T) { { name: "clear non-empty set", prepare: func(s set.Set[int]) { - s.Add(1) - s.Add(2) - s.Add(3) + s.Append(1) + s.Append(2) + s.Append(3) }, }, { @@ -283,7 +283,7 @@ func Test_unsafeSet_Clone(t *testing.T) { assert.Equal(t, 0, cloned.Size(), "cloned set should be empty") // Verify independence - original.Add("test") + original.Append("test") assert.False(t, cloned.Contains("test"), "cloned set should not be affected by original") }) @@ -297,16 +297,16 @@ func Test_unsafeSet_Clone(t *testing.T) { assert.True(t, cloned.Contains(true), "should contain boolean") // Verify independence - original.Add("new") + original.Append("new") assert.False(t, cloned.Contains("new"), "cloned set should not be affected by original") - cloned.Add("another") + cloned.Append("another") assert.False(t, original.Contains("another"), "original set should not be affected by clone") }) // Test nil pointer t.Run("nil pointer", func(t *testing.T) { original := set.New[*int]() - original.Add(nil) + original.Append(nil) cloned := original.Clone() @@ -324,9 +324,9 @@ func Test_unsafeSet_Items(t *testing.T) { { name: "get items from non-empty set", prepare: func(s set.Set[int]) { - s.Add(1) - s.Add(2) - s.Add(3) + s.Append(1) + s.Append(2) + s.Append(3) }, want: []int{ 1, @@ -364,12 +364,12 @@ func Test_unsafeSet_Union(t *testing.T) { { name: "union of non-overlapping sets", prepare1: func(s set.Set[int]) { - s.Add(1) - s.Add(2) + s.Append(1) + s.Append(2) }, prepare2: func(s set.Set[int]) { - s.Add(3) - s.Add(4) + s.Append(3) + s.Append(4) }, want: []int{ 1, @@ -381,14 +381,14 @@ func Test_unsafeSet_Union(t *testing.T) { { name: "union of overlapping sets", prepare1: func(s set.Set[int]) { - s.Add(1) - s.Add(2) - s.Add(3) + s.Append(1) + s.Append(2) + s.Append(3) }, prepare2: func(s set.Set[int]) { - s.Add(2) - s.Add(3) - s.Add(4) + s.Append(2) + s.Append(3) + s.Append(4) }, want: []int{ 1, @@ -400,8 +400,8 @@ func Test_unsafeSet_Union(t *testing.T) { { name: "union with empty set", prepare1: func(s set.Set[int]) { - s.Add(1) - s.Add(2) + s.Append(1) + s.Append(2) }, prepare2: nil, want: []int{ @@ -447,14 +447,14 @@ func Test_unsafeSet_Intersection(t *testing.T) { { name: "intersection of overlapping sets", prepare1: func(s set.Set[int]) { - s.Add(1) - s.Add(2) - s.Add(3) + s.Append(1) + s.Append(2) + s.Append(3) }, prepare2: func(s set.Set[int]) { - s.Add(2) - s.Add(3) - s.Add(4) + s.Append(2) + s.Append(3) + s.Append(4) }, want: []int{ 2, @@ -464,20 +464,20 @@ func Test_unsafeSet_Intersection(t *testing.T) { { name: "intersection of non-overlapping sets", prepare1: func(s set.Set[int]) { - s.Add(1) - s.Add(2) + s.Append(1) + s.Append(2) }, prepare2: func(s set.Set[int]) { - s.Add(3) - s.Add(4) + s.Append(3) + s.Append(4) }, want: []int{}, }, { name: "intersection with empty set", prepare1: func(s set.Set[int]) { - s.Add(1) - s.Add(2) + s.Append(1) + s.Append(2) }, prepare2: nil, want: []int{}, @@ -514,26 +514,26 @@ func Test_unsafeSet_Difference(t *testing.T) { { name: "difference of overlapping sets", prepare1: func(s set.Set[int]) { - s.Add(1) - s.Add(2) - s.Add(3) + s.Append(1) + s.Append(2) + s.Append(3) }, prepare2: func(s set.Set[int]) { - s.Add(2) - s.Add(3) - s.Add(4) + s.Append(2) + s.Append(3) + s.Append(4) }, want: []int{1}, }, { name: "difference with non-overlapping set", prepare1: func(s set.Set[int]) { - s.Add(1) - s.Add(2) + s.Append(1) + s.Append(2) }, prepare2: func(s set.Set[int]) { - s.Add(3) - s.Add(4) + s.Append(3) + s.Append(4) }, want: []int{ 1, @@ -543,8 +543,8 @@ func Test_unsafeSet_Difference(t *testing.T) { { name: "difference with empty set", prepare1: func(s set.Set[int]) { - s.Add(1) - s.Add(2) + s.Append(1) + s.Append(2) }, prepare2: nil, want: []int{ @@ -556,8 +556,8 @@ func Test_unsafeSet_Difference(t *testing.T) { name: "difference of empty set", prepare1: nil, prepare2: func(s set.Set[int]) { - s.Add(1) - s.Add(2) + s.Append(1) + s.Append(2) }, want: []int{}, }, From 91834d9a4269700386425e1c650124619278c9fa Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Fri, 20 Dec 2024 22:37:10 +0400 Subject: [PATCH 09/10] chore(lint): add new rule for set Signed-off-by: knqyf263 --- misc/lint/rules.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/misc/lint/rules.go b/misc/lint/rules.go index 4f1b098535ec..591b31d1ef82 100644 --- a/misc/lint/rules.go +++ b/misc/lint/rules.go @@ -30,3 +30,8 @@ func errorsJoin(m dsl.Matcher) { m.Match(`errors.Join($*args)`). Report("use github.com/hashicorp/go-multierror.Append instead of errors.Join.") } + +func mapSet(m dsl.Matcher) { + m.Match(`map[$x]struct{}`). + Report("use github.com/aquasecurity/trivy/pkg/set.Set instead of map.") +} From 1885abcfc85e9e797ab6bf1a1d4fbbb66a23e7d0 Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Fri, 20 Dec 2024 22:37:29 +0400 Subject: [PATCH 10/10] fix: lint errors Signed-off-by: knqyf263 --- pkg/fanal/image/daemon/image.go | 2 +- pkg/set/unsafe.go | 2 +- pkg/set/unsafe_test.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/fanal/image/daemon/image.go b/pkg/fanal/image/daemon/image.go index 3409e278d1c7..d650851ad242 100644 --- a/pkg/fanal/image/daemon/image.go +++ b/pkg/fanal/image/daemon/image.go @@ -226,7 +226,7 @@ func (img *image) imageConfig(config *container.Config) v1.Config { } if len(config.ExposedPorts) > 0 { - c.ExposedPorts = make(map[string]struct{}) + c.ExposedPorts = make(map[string]struct{}) //nolint: gocritic for port := range config.ExposedPorts { c.ExposedPorts[port.Port()] = struct{}{} } diff --git a/pkg/set/unsafe.go b/pkg/set/unsafe.go index 2dd8bd3bbb35..261492045715 100644 --- a/pkg/set/unsafe.go +++ b/pkg/set/unsafe.go @@ -8,7 +8,7 @@ import ( // unsafeSet represents a non-thread-safe set implementation // WARNING: This implementation is not thread-safe -type unsafeSet[T comparable] map[T]struct{} +type unsafeSet[T comparable] map[T]struct{} //nolint: gocritic // New creates a new empty non-thread-safe set with optional initial values func New[T comparable](values ...T) Set[T] { diff --git a/pkg/set/unsafe_test.go b/pkg/set/unsafe_test.go index b3bdfa353fe7..8f42c01666ea 100644 --- a/pkg/set/unsafe_test.go +++ b/pkg/set/unsafe_test.go @@ -270,7 +270,7 @@ func Test_unsafeSet_Clear(t *testing.T) { got := s.Size() assert.Zero(t, got, "unexpected set size") - assert.Equal(t, 0, len(s.Items()), "items should be empty") + assert.Empty(t, s.Items(), "items should be empty") }) } }