From f0bc5e012c094805cffdd536117431426d33244e Mon Sep 17 00:00:00 2001 From: "Eduardo R. Golding" Date: Mon, 18 Mar 2024 15:21:49 +0100 Subject: [PATCH] Allow using goImport as source for moqs --- internal/registry/registry.go | 12 +++ main.go | 144 ++++++++++++++++++++++++++++------ 2 files changed, 132 insertions(+), 24 deletions(-) diff --git a/internal/registry/registry.go b/internal/registry/registry.go index a237cdc..1f843b7 100644 --- a/internal/registry/registry.go +++ b/internal/registry/registry.go @@ -92,6 +92,10 @@ func (r *Registry) AddImport(pkg *types.Package) *Package { return nil } + if strings.Contains(pkg.Path(), "moq_force_local_interface_mode_") { + path = pkg.Imports()[0].Path() // moq_force_local_interface_mode only have 1 import dependency and is the one we are looking for + } + if imprt, ok := r.imports[path]; ok { return imprt } @@ -111,6 +115,14 @@ func (r *Registry) AddImport(pkg *types.Package) *Package { func (r Registry) Imports() []*Package { imports := make([]*Package, 0, len(r.imports)) for _, imprt := range r.imports { + if strings.Contains(imprt.Path(), "moq_force_local_interface_mode_") { + imports = append(imports, &Package{ + pkg: imprt.pkg.Imports()[0], // moq_force_local_interface_mode only have 1 dependency + Alias: imprt.pkg.Imports()[0].Name(), + }) + + continue + } imports = append(imports, imprt) } sort.Slice(imports, func(i, j int) bool { diff --git a/main.go b/main.go index 89adb3d..9a614bb 100644 --- a/main.go +++ b/main.go @@ -6,14 +6,14 @@ import ( "flag" "fmt" "io" - "io/ioutil" "os" + "os/exec" "path/filepath" + "text/template" "github.com/matryer/moq/pkg/moq" ) -// Version is the command version, injected at build time. var Version string = "dev" type userFlags struct { @@ -64,38 +64,38 @@ func main() { } func run(flags userFlags) error { - if len(flags.args) < 2 { + var ( + buf bytes.Buffer + out io.Writer = os.Stdout + srcDir string + ) + + if NotEnoughArguments(flags) { return errors.New("not enough arguments") } - if flags.remove && flags.outFile != "" { - if err := os.Remove(flags.outFile); err != nil { - if !errors.Is(err, os.ErrNotExist) { - return err - } + if ShouldRemoveFile(flags) { + if err := RemoveFile(flags.outFile); err != nil && !os.IsNotExist(err) { + return err } } - var buf bytes.Buffer - var out io.Writer = os.Stdout - if flags.outFile != "" { + if ShouldWriteToFile(flags) { out = &buf } - srcDir, args := flags.args[0], flags.args[1:] - m, err := moq.New(moq.Config{ - SrcDir: srcDir, - PkgName: flags.pkgName, - Formatter: flags.formatter, - StubImpl: flags.stubImpl, - SkipEnsure: flags.skipEnsure, - WithResets: flags.withResets, - }) + srcDir, cleanUp, err := getSourceDirectory(flags) if err != nil { return err } + defer cleanUp() - if err = m.Mock(out, args...); err != nil { + m, err := CreateMoq(flags, srcDir) + if err != nil { + return err + } + + if err := m.Mock(out, flags.args[1:]...); err != nil { return err } @@ -103,11 +103,107 @@ func run(flags userFlags) error { return nil } - // create the file - err = os.MkdirAll(filepath.Dir(flags.outFile), 0o750) + if err := CreateOutputFile(flags.outFile, buf.Bytes()); err != nil { + return err + } + + return nil +} + +func getSourceDirectory(flags userFlags) (string, func(), error) { + if DirectoryExists(flags.args[0]) { + return flags.args[0], func() {}, nil + } + + cmd := exec.Command("go", "list", flag.Args()[0]) + output, err := cmd.CombinedOutput() + if err != nil { + return "", func() {}, fmt.Errorf("%s", output) + } + + pwd, err := os.Getwd() + if err != nil { + return "", func() {}, err + } + + tempDir, err := os.MkdirTemp(pwd, "moq_force_local_interface_mode_") + if err != nil { + return "", func() {}, err + } + + if err := GenerateMoqForceLocalInterface(flags, tempDir); err != nil { + return "", func() {}, err + } + + return tempDir, func() { os.RemoveAll(tempDir) }, nil + +} + +func NotEnoughArguments(flags userFlags) bool { + return len(flags.args) < 2 +} + +func ShouldRemoveFile(flags userFlags) bool { + return flags.remove && flags.outFile != "" +} + +func RemoveFile(filePath string) error { + return os.Remove(filePath) +} + +func ShouldWriteToFile(flags userFlags) bool { + return flags.outFile != "" +} + +func DirectoryExists(directoryPath string) bool { + _, err := os.Stat(directoryPath) + return err == nil +} + +func GenerateMoqForceLocalInterface(flags userFlags, tempDir string) error { + tmpl, err := template.New("force_local_interface").Parse(moqForceLocalInterface) + if err != nil { + return err + } + + var buf bytes.Buffer + + err = tmpl.Execute(&buf, map[string]interface{}{ + "SrcPkgQualifier": filepath.Base(flags.args[0]), + "Import": flags.args[0], + "InterfaceName": flags.args[1], + }) if err != nil { return err } - return ioutil.WriteFile(flags.outFile, buf.Bytes(), 0o600) + return os.WriteFile(filepath.Join(tempDir, "moq_force_local_interface.go"), buf.Bytes(), 0600) } + +func CreateMoq(flags userFlags, srcDir string) (*moq.Mocker, error) { + return moq.New(moq.Config{ + SrcDir: srcDir, + PkgName: flags.pkgName, + Formatter: flags.formatter, + StubImpl: flags.stubImpl, + SkipEnsure: flags.skipEnsure, + WithResets: flags.withResets, + }) +} + +func CreateOutputFile(filePath string, data []byte) error { + if err := os.MkdirAll(filepath.Dir(filePath), 0750); err != nil { + return err + } + return os.WriteFile(filePath, data, 0600) +} + +const moqForceLocalInterface = `// Code generated by moq; DO NOT EDIT +// github.com/matryer/moq + +package {{.SrcPkgQualifier}} + +import "{{.Import}}" + +type {{.InterfaceName}} {{$.SrcPkgQualifier}}.{{.InterfaceName}} +`