Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v2.8]Fix controller generation to add test session to generated controllers #237

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
//go:generate go run pkg/codegen/generator/cleanup/main.go -mod vendor
//go:generate go run pkg/codegen/main.go -mod vendor
//go:generate rm -rf vendor
//go:generate go fmt ./...

package main
252 changes: 243 additions & 9 deletions pkg/codegen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@ package main

import (
"bytes"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"

"go/ast"
"go/parser"
"go/printer"
"go/token"

fleet "github.com/rancher/fleet/pkg/apis/fleet.cattle.io/v1alpha1"
"github.com/rancher/norman/types"
catalogv1 "github.com/rancher/rancher/pkg/apis/catalog.cattle.io/v1"
Expand All @@ -24,12 +32,23 @@ import (
capi "sigs.k8s.io/cluster-api/api/v1beta1"
)

// main initializes the code generation for controllers and clients.
// It generates clients for various API groups using controllergen.Run.
// It also generates clients for specific schemas using generator.GenerateClient.
// Then, it calls replaceClientBasePackages to replace imports in the generated clients.
// Finally, it replaces imports and adds controller test session for generated
func main() {
err := os.Unsetenv("GOPATH")
if err != nil {
return
}

generatedControllerPaths := map[string]string{
"AppsControllerPath": "./pkg/generated/controllers/apps",
"CoreControllerPath": "./pkg/generated/controllers/core",
"ManagementControllerPath": "./pkg/generated/controllers/management.cattle.io",
}

controllergen.Run(args.Options{
OutputPackage: "github.com/rancher/shepherd/pkg/generated",
Boilerplate: "pkg/codegen/boilerplate.go.txt",
Expand Down Expand Up @@ -135,9 +154,16 @@ func main() {
panic(err)
}

// Comment out this function to avoid replacing the imports in the management controllers
if err := replaceManagementControllerImports(); err != nil {
panic(err)
// Loop through all generated controller paths and replace imports
// and add test session
for _, path := range generatedControllerPaths {
if err := replaceImports(path); err != nil {
panic(err)
}

if err := addControllerTestSession(path); err != nil {
panic(err)
}
}
}

Expand Down Expand Up @@ -167,9 +193,29 @@ func replaceClientBasePackages() error {
})
}

// NOTE: Comment out this function to avoid replacing the imports in the management controllers
func replaceManagementControllerImports() error {
return filepath.Walk("./pkg/generated/controllers/management.cattle.io", func(path string, info os.FileInfo, err error) error {
// Walk through the generated controllers and add test session
// to necessary functions and structs
func addControllerTestSession(root string) error {
err := filepath.Walk(root, processInterfaceFile)
if err != nil {
return err
}
return nil
}

// replaceImports walks through the specified directory and replaces certain imports in Go files.
// It replaces the import "github.com/rancher/wrangler/v3/pkg/generic" with "github.com/rancher/shepherd/pkg/wrangler/pkg/generic"
// in all files ending with ".go".
// It also replaces specific function calls in files starting with "factory", and "interface".
// The replaced function calls are:
// - "New(c.ControllerFactory())" with "New(c.ControllerFactory(), c.Opts.TS)"
// - "controller.NewSharedControllerFactoryWithAgent(userAgent, c.ControllerFactory())" with "controller.NewSharedControllerFactoryWithAgent(userAgent, c.ControllerFactory()), c.Opts.TS"
// - "controller.SharedControllerFactory)" with "controller.SharedControllerFactory, ts *session.Session)"
// - "g.controllerFactory)" with "g.controllerFactory, g.ts)"
// - "v.controllerFactory)" with "v.controllerFactory, v.ts)"
// The function returns an error if there was a problem reading or writing files.
func replaceImports(dir string) error {
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
Expand All @@ -186,19 +232,39 @@ func replaceManagementControllerImports() error {
if err = os.WriteFile(path, replacement, 0666); err != nil {
return err
}
}

if strings.HasPrefix(info.Name(), "factory") {
input, err := os.ReadFile(path)
if err != nil {
return err
}
replacement = bytes.Replace(input, []byte("New(c.ControllerFactory())"), []byte("New(c.ControllerFactory(), c.Opts.TS)"), -1)
if err = os.WriteFile(path, replacement, 0666); err != nil {
return err
}
}

if strings.HasPrefix(info.Name(), "factory") {
input, err := os.ReadFile(path)
if err != nil {
return err
}
replacement = bytes.Replace(input, []byte("c.ControllerFactory())"), []byte("c.ControllerFactory(), c.Opts.TS)"), -1)
replacement = bytes.Replace(input, []byte("New(c.ControllerFactory())"), []byte("New(c.ControllerFactory(), c.Opts.TS)"), -1)
if err = os.WriteFile(path, replacement, 0666); err != nil {
return err
}
}

if strings.HasPrefix(info.Name(), "factory") {
input, err := os.ReadFile(path)
if err != nil {
return err
}
replacement = bytes.Replace(input, []byte("controller.NewSharedControllerFactoryWithAgent(userAgent, c.ControllerFactory())"), []byte("controller.NewSharedControllerFactoryWithAgent(userAgent, c.ControllerFactory()), c.Opts.TS"), -1)
if err = os.WriteFile(path, replacement, 0666); err != nil {
return err
}
}

if strings.HasPrefix(info.Name(), "interface") {
Expand All @@ -210,17 +276,185 @@ func replaceManagementControllerImports() error {
if err = os.WriteFile(path, replacement, 0666); err != nil {
return err
}
input, err = os.ReadFile(path)
if err != nil {
return err
}
replacement = bytes.Replace(input, []byte("g.controllerFactory)"), []byte("g.controllerFactory, g.ts)"), -1)
if err = os.WriteFile(path, replacement, 0666); err != nil {
return err
}

input, err = os.ReadFile(path)
if err != nil {
return err
}
replacement = bytes.Replace(input, []byte("v.controllerFactory)"), []byte("v.controllerFactory, v.ts)"), -1)
if err = os.WriteFile(path, replacement, 0666); err != nil {
return err
}
}

return nil
})
}

// Check if import already exists
func addImport(fset *token.FileSet, filename string, importPath string) error {
if importPath == "" {
return errors.New("empty import path")
}

node, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
if err != nil {
return err
}

// Check if session import is there and do nothing
for _, i := range node.Imports {
if i.Path.Value == importPath {
println("Import already included in file:", filename)
return nil
}
}

// Create a new import spec
newImport := &ast.ImportSpec{
Path: &ast.BasicLit{
Kind: token.STRING,
Value: importPath,
},
}

// Insert the new import spec in the right place
found := false
for _, decl := range node.Decls {
if genDecl, ok := decl.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
genDecl.Specs = append(genDecl.Specs, newImport)
found = true
break
}
}

// If no import declaration was found, create a new one
if !found {
node.Decls = append([]ast.Decl{
&ast.GenDecl{
Tok: token.IMPORT,
Specs: []ast.Spec{
newImport,
},
},
}, node.Decls...)
}

var buf bytes.Buffer
err = printer.Fprint(&buf, fset, node)
if err != nil {
return err
}

err = os.WriteFile(filename, buf.Bytes(), 0666)
if err != nil {
return err
}
return nil
}

// processInterfaceFile processes the specified file and adds import, new struct line,
// and new function line to the specified blocks within the file.
func processInterfaceFile(path string, info os.FileInfo, err error) error {
const importPath = `"github.com/rancher/shepherd/pkg/session"`
const newStructLine = "\tts *session.Session"
const newFuncLine = "\t\tts: ts,"

if !info.IsDir() && strings.HasSuffix(info.Name(), "interface.go") {
fset := token.NewFileSet()
err := addImport(fset, path, importPath)
if err != nil {
return err
}

err = appendNewlineToBlockInFile(path, "group", newStructLine)
if err != nil {
return err
}

err = appendNewlineToBlockInFile(path, "&group", newFuncLine)
if err != nil {
return err
}

err = appendNewlineToBlockInFile(path, "version", newStructLine)
if err != nil {
return err
}

err = appendNewlineToBlockInFile(path, "&version", newFuncLine)
if err != nil {
return err
}
}
return nil
}

// appendNewlineToBlockInFile takes a path to a file, a code block(struct,return) in a file to update
// and the string to insert in a new line within the block
func appendNewlineToBlockInFile(filePath, blockName, newLine string) error {
// Read the file contents
content, err := os.ReadFile(filePath)
if err != nil {
return fmt.Errorf("error reading file: %v", err)
}

// Convert content to string and split into lines
input := string(content)
lines := strings.Split(input, "\n")
blockStart := -1
blockEnd := -1
braceCount := 0

// Create a regex for matching function declaration
var re *regexp.Regexp
re = regexp.MustCompile(fmt.Sprintf(`(?m)(type\s+%s\s+struct|\s*return\s*%s)\s*{\s*`, blockName, blockName))

// Find the start and end of the specified function
for i, line := range lines {
if blockStart == -1 {
if re.MatchString(line) {
blockStart = i
braceCount = 1
}
} else {
braceCount += strings.Count(line, "{") - strings.Count(line, "}")
if braceCount == 0 {
blockEnd = i
break
}
}
}

// If the target is found, insert the new line before the closing brace
if blockStart != -1 && blockEnd != -1 {
// Find the last non-empty line before the closing brace
insertPos := blockEnd
for i := blockEnd - 1; i > blockStart; i-- {
if strings.TrimSpace(lines[i]) != "" {
insertPos = i + 1
break
}
}

// Insert the new line
lines = append(lines[:insertPos], append([]string{newLine}, lines[insertPos:]...)...)

// Join the lines back together
modifiedContent := strings.Join(lines, "\n")

// Write the modified content back to the file
err = os.WriteFile(filePath, []byte(modifiedContent), 0644)
if err != nil {
return fmt.Errorf("error writing to file: %v", err)
}
}

return nil
}
2 changes: 1 addition & 1 deletion pkg/generated/clientset/versioned/clientset.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/generated/clientset/versioned/fake/doc.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/generated/clientset/versioned/fake/register.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/generated/clientset/versioned/scheme/doc.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/generated/clientset/versioned/scheme/register.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading