Skip to content

Commit

Permalink
Allow using goImport as source for moqs
Browse files Browse the repository at this point in the history
  • Loading branch information
K4L1Ma committed Mar 19, 2024
1 parent 0bf2e8a commit f0bc5e0
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 24 deletions.
12 changes: 12 additions & 0 deletions internal/registry/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ func (r *Registry) AddImport(pkg *types.Package) *Package {
return nil
}

if strings.Contains(pkg.Path(), "moq_force_local_interface_mode_") {
path = pkg.Imports()[0].Path() // moq_force_local_interface_mode only have 1 import dependency and is the one we are looking for
}

if imprt, ok := r.imports[path]; ok {
return imprt
}
Expand All @@ -111,6 +115,14 @@ func (r *Registry) AddImport(pkg *types.Package) *Package {
func (r Registry) Imports() []*Package {
imports := make([]*Package, 0, len(r.imports))
for _, imprt := range r.imports {
if strings.Contains(imprt.Path(), "moq_force_local_interface_mode_") {
imports = append(imports, &Package{
pkg: imprt.pkg.Imports()[0], // moq_force_local_interface_mode only have 1 dependency
Alias: imprt.pkg.Imports()[0].Name(),
})

continue
}
imports = append(imports, imprt)
}
sort.Slice(imports, func(i, j int) bool {
Expand Down
144 changes: 120 additions & 24 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import (
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"text/template"

"github.com/matryer/moq/pkg/moq"
)

// Version is the command version, injected at build time.
var Version string = "dev"

type userFlags struct {
Expand Down Expand Up @@ -64,50 +64,146 @@ func main() {
}

func run(flags userFlags) error {
if len(flags.args) < 2 {
var (
buf bytes.Buffer
out io.Writer = os.Stdout
srcDir string
)

if NotEnoughArguments(flags) {
return errors.New("not enough arguments")
}

if flags.remove && flags.outFile != "" {
if err := os.Remove(flags.outFile); err != nil {
if !errors.Is(err, os.ErrNotExist) {
return err
}
if ShouldRemoveFile(flags) {
if err := RemoveFile(flags.outFile); err != nil && !os.IsNotExist(err) {
return err
}
}

var buf bytes.Buffer
var out io.Writer = os.Stdout
if flags.outFile != "" {
if ShouldWriteToFile(flags) {
out = &buf
}

srcDir, args := flags.args[0], flags.args[1:]
m, err := moq.New(moq.Config{
SrcDir: srcDir,
PkgName: flags.pkgName,
Formatter: flags.formatter,
StubImpl: flags.stubImpl,
SkipEnsure: flags.skipEnsure,
WithResets: flags.withResets,
})
srcDir, cleanUp, err := getSourceDirectory(flags)
if err != nil {
return err
}
defer cleanUp()

if err = m.Mock(out, args...); err != nil {
m, err := CreateMoq(flags, srcDir)
if err != nil {
return err
}

if err := m.Mock(out, flags.args[1:]...); err != nil {
return err
}

if flags.outFile == "" {
return nil
}

// create the file
err = os.MkdirAll(filepath.Dir(flags.outFile), 0o750)
if err := CreateOutputFile(flags.outFile, buf.Bytes()); err != nil {
return err
}

return nil
}

func getSourceDirectory(flags userFlags) (string, func(), error) {
if DirectoryExists(flags.args[0]) {
return flags.args[0], func() {}, nil
}

cmd := exec.Command("go", "list", flag.Args()[0])
output, err := cmd.CombinedOutput()
if err != nil {
return "", func() {}, fmt.Errorf("%s", output)
}

pwd, err := os.Getwd()
if err != nil {
return "", func() {}, err
}

tempDir, err := os.MkdirTemp(pwd, "moq_force_local_interface_mode_")
if err != nil {
return "", func() {}, err
}

if err := GenerateMoqForceLocalInterface(flags, tempDir); err != nil {
return "", func() {}, err
}

return tempDir, func() { os.RemoveAll(tempDir) }, nil

}

func NotEnoughArguments(flags userFlags) bool {
return len(flags.args) < 2
}

func ShouldRemoveFile(flags userFlags) bool {
return flags.remove && flags.outFile != ""
}

func RemoveFile(filePath string) error {
return os.Remove(filePath)
}

func ShouldWriteToFile(flags userFlags) bool {
return flags.outFile != ""
}

func DirectoryExists(directoryPath string) bool {
_, err := os.Stat(directoryPath)
return err == nil
}

func GenerateMoqForceLocalInterface(flags userFlags, tempDir string) error {
tmpl, err := template.New("force_local_interface").Parse(moqForceLocalInterface)
if err != nil {
return err
}

var buf bytes.Buffer

err = tmpl.Execute(&buf, map[string]interface{}{
"SrcPkgQualifier": filepath.Base(flags.args[0]),
"Import": flags.args[0],
"InterfaceName": flags.args[1],
})
if err != nil {
return err
}

return ioutil.WriteFile(flags.outFile, buf.Bytes(), 0o600)
return os.WriteFile(filepath.Join(tempDir, "moq_force_local_interface.go"), buf.Bytes(), 0600)
}

func CreateMoq(flags userFlags, srcDir string) (*moq.Mocker, error) {
return moq.New(moq.Config{
SrcDir: srcDir,
PkgName: flags.pkgName,
Formatter: flags.formatter,
StubImpl: flags.stubImpl,
SkipEnsure: flags.skipEnsure,
WithResets: flags.withResets,
})
}

func CreateOutputFile(filePath string, data []byte) error {
if err := os.MkdirAll(filepath.Dir(filePath), 0750); err != nil {
return err
}
return os.WriteFile(filePath, data, 0600)
}

const moqForceLocalInterface = `// Code generated by moq; DO NOT EDIT
// github.com/matryer/moq
package {{.SrcPkgQualifier}}
import "{{.Import}}"
type {{.InterfaceName}} {{$.SrcPkgQualifier}}.{{.InterfaceName}}
`

0 comments on commit f0bc5e0

Please sign in to comment.