Skip to content

Commit

Permalink
support recursive collection of source files for provider using fs.FS
Browse files Browse the repository at this point in the history
  • Loading branch information
fmarmol committed Feb 27, 2024
1 parent 76946cc commit cd26a9e
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 62 deletions.
2 changes: 1 addition & 1 deletion provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func newProvider(
// feat(mf): we could add a flag to parse SQL migrations eagerly. This would allow us to return
// an error if there are any SQL parsing errors. This adds a bit overhead to startup though, so
// we should make it optional.
filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions)
filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions, cfg.recursive)
if err != nil {
return nil, err
}
Expand Down
144 changes: 97 additions & 47 deletions provider_collect.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,54 @@ type fileSources struct {
goSources []Source
}

func checkFile(fullpath string, strict bool, excludePaths map[string]bool, excludeVersions map[int64]bool, versionToBaseLookup map[int64]string) (Source, bool, error) {
base := filepath.Base(fullpath)
if strings.HasSuffix(base, "_test.go") {
return Source{}, false, nil
}
if excludePaths[base] {
// TODO(mf): log this?
return Source{}, false, nil
}
// If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
// that as the version. Otherwise, ignore it. This allows users to have arbitrary
// filenames, but still have versioned migrations within the same directory. For
// example, a user could have a helpers.go file which contains unexported helper
// functions for migrations.
version, err := NumericComponent(base)
if err != nil {
if strict {
return Source{}, false, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
}
return Source{}, false, nil
}
if excludeVersions[version] {
// TODO: log this?
return Source{}, false, nil
}
// Ensure there are no duplicate versions.
if existing, ok := versionToBaseLookup[version]; ok {
return Source{}, false, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
version,
existing,
base,
)
}
source := Source{Path: fullpath, Version: version}
switch filepath.Ext(base) {
case ".sql":
source.Type = TypeSQL
case ".go":
source.Type = TypeGo
default:
// Should never happen since we already filtered out all other file types.
return Source{}, false, fmt.Errorf("invalid file extension: %q", base)
}
// Add the version to the lookup map.
versionToBaseLookup[version] = base
return source, true, nil
}

// collectFilesystemSources scans the file system for migration files that have a numeric prefix
// (greater than one) followed by an underscore and a file extension of either .go or .sql. fsys may
// be nil, in which case an empty fileSources is returned.
Expand All @@ -29,6 +77,7 @@ func collectFilesystemSources(
strict bool,
excludePaths map[string]bool,
excludeVersions map[int64]bool,
recursive bool,
) (*fileSources, error) {
if fsys == nil {
return new(fileSources), nil
Expand All @@ -39,65 +88,66 @@ func collectFilesystemSources(
"*.sql",
"*.go",
} {
files, err := fs.Glob(fsys, pattern)
files, err := func() ([]string, error) {
if recursive {
var files []string
err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
subFs, err := fs.Sub(fsys, path)
if err != nil {
return err
}
dirFiles, err := fs.Glob(subFs, pattern)
for _, file := range dirFiles {
files = append(files, filepath.Join(path, file))
}
}
return nil
})
if err != nil {
return nil, err
}
return files, nil
} else {
files, err := fs.Glob(fsys, pattern)
if err != nil {
return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err)
}
return files, nil
}
}()
if err != nil {
return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err)
return nil, err
}
for _, fullpath := range files {
base := filepath.Base(fullpath)
if strings.HasSuffix(base, "_test.go") {
continue
}
if excludePaths[base] {
// TODO(mf): log this?
continue
}
// If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
// that as the version. Otherwise, ignore it. This allows users to have arbitrary
// filenames, but still have versioned migrations within the same directory. For
// example, a user could have a helpers.go file which contains unexported helper
// functions for migrations.
version, err := NumericComponent(base)
source, isValid, err := checkFile(
fullpath,
strict,
excludePaths,
excludeVersions,
versionToBaseLookup,
)
if err != nil {
if strict {
return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
}
continue
return nil, err
}
if excludeVersions[version] {
// TODO: log this?
if !isValid {
continue
}
// Ensure there are no duplicate versions.
if existing, ok := versionToBaseLookup[version]; ok {
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
version,
existing,
base,
)
}
switch filepath.Ext(base) {
case ".sql":
sources.sqlSources = append(sources.sqlSources, Source{
Type: TypeSQL,
Path: fullpath,
Version: version,
})
case ".go":
sources.goSources = append(sources.goSources, Source{
Type: TypeGo,
Path: fullpath,
Version: version,
})
switch source.Type {
case TypeSQL:
sources.sqlSources = append(sources.sqlSources, source)
case TypeGo:
sources.goSources = append(sources.goSources, source)
default:
// Should never happen since we already filtered out all other file types.
return nil, fmt.Errorf("invalid file extension: %q", base)
return nil, errors.New("unreachable")
}
// Add the version to the lookup map.
versionToBaseLookup[version] = base
}
}
return sources, nil

}

func newSQLMigration(source Source) *Migration {
Expand Down
48 changes: 34 additions & 14 deletions provider_collect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@ import (
func TestCollectFileSources(t *testing.T) {
t.Parallel()
t.Run("nil_fsys", func(t *testing.T) {
sources, err := collectFilesystemSources(nil, false, nil, nil)
sources, err := collectFilesystemSources(nil, false, nil, nil, false)
check.NoError(t, err)
check.Bool(t, sources != nil, true)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
})
t.Run("noop_fsys", func(t *testing.T) {
sources, err := collectFilesystemSources(noopFS{}, false, nil, nil)
sources, err := collectFilesystemSources(noopFS{}, false, nil, nil, false)
check.NoError(t, err)
check.Bool(t, sources != nil, true)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
})
t.Run("empty_fsys", func(t *testing.T) {
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil)
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil, false)
check.NoError(t, err)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
Expand All @@ -37,19 +37,19 @@ func TestCollectFileSources(t *testing.T) {
"00000_foo.sql": sqlMapFile,
}
// strict disable - should not error
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
sources, err := collectFilesystemSources(mapFS, false, nil, nil, false)
check.NoError(t, err)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
// strict enabled - should error
_, err = collectFilesystemSources(mapFS, true, nil, nil)
_, err = collectFilesystemSources(mapFS, true, nil, nil, false)
check.HasError(t, err)
check.Contains(t, err.Error(), "migration version must be greater than zero")
})
t.Run("collect", func(t *testing.T) {
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
check.NoError(t, err)
sources, err := collectFilesystemSources(fsys, false, nil, nil)
sources, err := collectFilesystemSources(fsys, false, nil, nil, false)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 4)
check.Number(t, len(sources.goSources), 0)
Expand Down Expand Up @@ -77,6 +77,7 @@ func TestCollectFileSources(t *testing.T) {
"00110_qux.sql": true,
},
nil,
false,
)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 2)
Expand All @@ -97,7 +98,7 @@ func TestCollectFileSources(t *testing.T) {
mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
_, err = collectFilesystemSources(fsys, true, nil, nil)
_, err = collectFilesystemSources(fsys, true, nil, nil, false)
check.HasError(t, err)
check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`)
})
Expand All @@ -109,7 +110,7 @@ func TestCollectFileSources(t *testing.T) {
"4_qux.sql": sqlMapFile,
"5_foo_test.go": {Data: []byte(`package goose_test`)},
}
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
sources, err := collectFilesystemSources(mapFS, false, nil, nil, false)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 4)
check.Number(t, len(sources.goSources), 0)
Expand All @@ -124,7 +125,7 @@ func TestCollectFileSources(t *testing.T) {
"no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)},
"some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)},
}
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
sources, err := collectFilesystemSources(mapFS, false, nil, nil, false)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 2)
check.Number(t, len(sources.goSources), 1)
Expand All @@ -143,7 +144,8 @@ func TestCollectFileSources(t *testing.T) {
"001_foo.sql": sqlMapFile,
"01_bar.sql": sqlMapFile,
}
_, err := collectFilesystemSources(mapFS, false, nil, nil)

_, err := collectFilesystemSources(mapFS, false, nil, nil, false)
check.HasError(t, err)
check.Contains(t, err.Error(), "found duplicate migration version 1")
})
Expand All @@ -159,7 +161,7 @@ func TestCollectFileSources(t *testing.T) {
t.Helper()
f, err := fs.Sub(mapFS, dirpath)
check.NoError(t, err)
got, err := collectFilesystemSources(f, false, nil, nil)
got, err := collectFilesystemSources(f, false, nil, nil, false)
check.NoError(t, err)
check.Number(t, len(got.sqlSources), len(sqlSources))
check.Number(t, len(got.goSources), 0)
Expand All @@ -180,6 +182,24 @@ func TestCollectFileSources(t *testing.T) {
})
assertDirpath("dir3", nil)
})
t.Run("recursive", func(t *testing.T) {
mapFS := fstest.MapFS{
"876_a.sql": sqlMapFile,
"dir1/101_a.sql": sqlMapFile,
"dir1/102_b.sql": sqlMapFile,
"dir1/103_c.sql": sqlMapFile,
"dir2/201_a.sql": sqlMapFile,
"dir2/dir3/301_a.sql": sqlMapFile,
}
sources, err := collectFilesystemSources(mapFS, false, nil, nil, true)
check.NoError(t, err)
check.Equal(t, len(sources.sqlSources), 6)
check.Equal(t, sources.sqlSources[0].Path, "876_a.sql")
check.Equal(t, sources.sqlSources[1].Path, "dir1/101_a.sql")
check.Equal(t, sources.sqlSources[2].Path, "dir1/102_b.sql")
check.Equal(t, sources.sqlSources[3].Path, "dir1/103_c.sql")
check.Equal(t, sources.sqlSources[4].Path, "dir2/201_a.sql")
})
}

func TestMerge(t *testing.T) {
Expand All @@ -195,7 +215,7 @@ func TestMerge(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFilesystemSources(fsys, false, nil, nil)
sources, err := collectFilesystemSources(fsys, false, nil, nil, false)
check.NoError(t, err)
check.Equal(t, len(sources.sqlSources), 1)
check.Equal(t, len(sources.goSources), 2)
Expand Down Expand Up @@ -243,7 +263,7 @@ func TestMerge(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFilesystemSources(fsys, false, nil, nil)
sources, err := collectFilesystemSources(fsys, false, nil, nil, false)
check.NoError(t, err)
t.Run("unregistered_all", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*Migration{
Expand All @@ -267,7 +287,7 @@ func TestMerge(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFilesystemSources(fsys, false, nil, nil)
sources, err := collectFilesystemSources(fsys, false, nil, nil, false)
check.NoError(t, err)
t.Run("unregistered_all", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*Migration{
Expand Down
8 changes: 8 additions & 0 deletions provider_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ func WithDisableVersioning(b bool) ProviderOption {
})
}

func WithRecursive(b bool) ProviderOption {
return configFunc(func(c *config) error {
c.recursive = b
return nil
})
}

type config struct {
store database.Store

Expand All @@ -184,6 +191,7 @@ type config struct {
disableVersioning bool
allowMissing bool
disableGlobalRegistry bool
recursive bool

// Let's not expose the Logger just yet. Ideally we consolidate on the std lib slog package
// added in go1.21 and then expose that (if that's even necessary). For now, just use the std
Expand Down

0 comments on commit cd26a9e

Please sign in to comment.