diff --git a/job.go b/job.go index 9b6f2002..0239bfbc 100644 --- a/job.go +++ b/job.go @@ -243,6 +243,49 @@ func (j *Job) updateConnections() { conn = strings.Replace(conn, "AUTHTOKEN", url.QueryEscape(token), 1) } + if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "pg://") { + u, err := url.Parse(conn) + var filteredDBs []string + if err != nil { + level.Error(j.log).Log("msg", "Failed to parse URL", "url", conn, "err", err) + continue + } + if strings.Contains(u.Path, "include") || strings.Contains(u.Path, "exclude") { + if strings.Contains(u.Path, "include") && strings.Contains(u.Path, "exclude") { + level.Error(j.log).Log("msg", "You cannot use exclude and include:", "url", conn, "err", err) + return + } else { + extractedPath := u.Path //save pattern + u.Path = "/postgres" + dsn := u.String() + databases, err := listDatabases(dsn) + if err != nil { + level.Error(j.log).Log("msg", "Error listing databases", "url", conn, "err", err) + continue + } + filteredDBs, err = filterDatabases(databases, extractedPath) + if err != nil { + level.Error(j.log).Log("msg", "Error filtering databases", "url", conn, "err", err) + continue + } + + for _, db := range filteredDBs { + u.Path = "/" + db // Set the path to the filtered database name + newUserDSN := u.String() + j.conns = append(j.conns, &connection{ + conn: nil, + url: newUserDSN, + driver: u.Scheme, + host: u.Host, + database: db, + user: u.User.Username(), + }) + } + continue + } + } + } + u, err := url.Parse(conn) if err != nil { level.Error(j.log).Log("msg", "Failed to parse URL", "url", conn, "err", err) diff --git a/postgresql.go b/postgresql.go new file mode 100644 index 00000000..f116a7bc --- /dev/null +++ b/postgresql.go @@ -0,0 +1,80 @@ +package main + +import ( + "database/sql" + "fmt" + "regexp" + "strings" + + _ "github.com/lib/pq" +) + +const ( + INCLUDE_DBS = "/include:" + EXCLUDE_DBS = "/exclude:" +) + +func listDatabases(connStr string) ([]string, error) { + + db, err := sql.Open("postgres", connStr) + if err != nil { + return nil, err + } + defer db.Close() + + rows, err := db.Query("SELECT datname FROM pg_database WHERE datistemplate = false;") + if err != nil { + return nil, err + } + defer rows.Close() + + var databases []string + for rows.Next() { + var dbname string + if err := rows.Scan(&dbname); err != nil { + return nil, err + } + databases = append(databases, dbname) + } + + return databases, nil +} + +func filterDatabases(databases []string, pattern string) ([]string, error) { + var filtered []string + mode, dbs := parsePattern(pattern) + + // Split the dbs string into individual patterns + dbPatterns := strings.Split(dbs, ",") + + // Process each database name against the patterns + for _, dbname := range databases { + include := false + + for _, dbPattern := range dbPatterns { + matched, err := regexp.MatchString(dbPattern, dbname) + if err != nil { + return nil, fmt.Errorf("invalid pattern: %s", dbPattern) + } + if matched { + include = true + break + } + } + + if (mode == INCLUDE_DBS && include) || (mode == EXCLUDE_DBS && !include) { + filtered = append(filtered, dbname) + } + } + + return filtered, nil +} + +func parsePattern(pattern string) (mode string, dbs string) { + if strings.HasPrefix(pattern, INCLUDE_DBS) { + return INCLUDE_DBS, pattern[len(INCLUDE_DBS):] + } else if strings.HasPrefix(pattern, EXCLUDE_DBS) { + return EXCLUDE_DBS, pattern[len(EXCLUDE_DBS):] + } + return "", "" +}