Skip to content

Commit

Permalink
Add flag to retrieve N most recent migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
cmpickle authored and Cameron Pickle committed Jun 22, 2021
1 parent f4a495f commit ee61e25
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Options:
migrations table name (default "goose_db_version")
-h print help
-v enable verbose mode
-n retrieve the N most recent migrations
-version
print version
Expand Down
4 changes: 4 additions & 0 deletions cmd/goose/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ var (
version = flags.Bool("version", false, "print version")
certfile = flags.String("certfile", "", "file path to root CA's certificates in pem format (only support on mysql)")
sequential = flags.Bool("s", false, "use sequential numbering for new migrations")
recent = flags.Int("n", 0, "retrieve the N most recent migrations")
)

func main() {
Expand All @@ -34,6 +35,9 @@ func main() {
if *sequential {
goose.SetSequential(true)
}
if *recent > 0 {
goose.SetRecentLimit(*recent)
}
goose.SetTableName(*table)

args := flags.Args()
Expand Down
28 changes: 28 additions & 0 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ var (
MaxVersion int64 = 9223372036854775807 // max(int64)

registeredGoMigrations = map[int64]*Migration{}

// recent with value of 0 is no limit
recent = 0
)

// Migrations slice.
Expand Down Expand Up @@ -78,6 +81,20 @@ func (ms Migrations) Last() (*Migration, error) {
return ms[len(ms)-1], nil
}

// Last gets the last migration.
func (ms Migrations) LastN(n int) (Migrations, error) {
if len(ms) == 0 {
return nil, ErrNoNextVersion
}

start := len(ms) - n - 1
if start < 0 {
start = 0
}

return ms[start : len(ms)-1], nil
}

// Versioned gets versioned migrations.
func (ms Migrations) versioned() (Migrations, error) {
var migrations Migrations
Expand Down Expand Up @@ -199,6 +216,13 @@ func CollectMigrations(dirpath string, current, target int64) (Migrations, error
}
}

if recent > 0 {
migrations, err = migrations.LastN(recent)
if err != nil {
return nil, err
}
}

migrations = sortAndConnectMigrations(migrations)

return migrations, nil
Expand Down Expand Up @@ -317,3 +341,7 @@ func GetDBVersion(db *sql.DB) (int64, error) {

return version, nil
}

func SetRecentLimit(r int) {
recent = r
}

0 comments on commit ee61e25

Please sign in to comment.