Skip to content

Commit

Permalink
fixed base64 decoding issue in mysql and postgres
Browse files Browse the repository at this point in the history
Signed-off-by: adarsh-jaiss <[email protected]>
  • Loading branch information
Adarsh-jaiss committed May 22, 2024
1 parent 341fd99 commit 1e94bce
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 57 deletions.
95 changes: 80 additions & 15 deletions cli/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,61 @@ var (
query string
)

// QueryResult represents the result of a database query.
type QueryResultInterface interface {
GetColumns() []string
GetRows() interface{}
GetTime() float64
GetError() string
}

type QueryResult struct {
Columns []string `json:"columns"`
Rows [][]interface{} `json:"rows"`
Time float64 `json:"time"`
Error string `json:"error"`
}

func (q QueryResult) GetColumns() []string {
return q.Columns
}

func (q QueryResult) GetRows() interface{} {
return q.Rows
}

func (q QueryResult) GetTime() float64 {
return q.Time
}

func (q QueryResult) GetError() string {
return q.Error
}

type BigQueryResult struct {
Columns []string `json:"columns"`
Rows []map[string]interface{} `json:"rows"`
Time int64 `json:"time"`
Error string `json:"error"`
}

func (b BigQueryResult) GetColumns() []string {
return b.Columns
}

func (b BigQueryResult) GetRows() interface{} {
return b.Rows
}

func (b BigQueryResult) GetTime() float64 {
return float64(b.Time)
}

func (b BigQueryResult) GetError() string {
return b.Error
}

// QueryResult represents the result of a database query.

// Command for interacting with databases
var shellCmd = &cobra.Command{
Use: "shell",
Expand Down Expand Up @@ -133,26 +180,44 @@ func queryExecute(query string, db xrayTypes.ISQL) error {
return fmt.Errorf("error executing query result: %s", err)
}

var result QueryResult
err = json.Unmarshal(b, &result)
if err != nil {
return fmt.Errorf("error parsing query result: %s", err)
var result QueryResultInterface
if dbType == "bigquery" {
result = &BigQueryResult{}
} else {
result = &QueryResult{}
}

if len(result.Rows) == 0 {

return fmt.Errorf("no results found")
err = json.Unmarshal(b, result)
if err != nil {
return fmt.Errorf("error parsing query result: %s", err)
}

table := tablewriter.NewWriter(os.Stdout)
table.SetHeader(result.Columns)
for _, row := range result.Rows {
stringRow := make([]string, len(row))
for i, v := range row {
stringRow[i] = fmt.Sprintf("%v", v)
table.SetHeader(result.GetColumns()) // Assert the type of result and call GetColumns() instead of Columns
switch rows := result.GetRows().(type) {
case [][]interface{}:
for _, row := range rows {
stringRow := make([]string, len(row))
for i, v := range row {
stringRow[i] = fmt.Sprintf("%v", v)
}

table.Append(stringRow)
}
case []map[string]interface{}:
for _, rowMap := range rows {
var stringRow []string
for _, v := range rowMap {
stringRow = append(stringRow, fmt.Sprintf("%v", v))
}
table.Append(stringRow)
}
default:
return fmt.Errorf("unexpected type of rows: %T", rows)
}

table.Append(stringRow)
if table.NumLines() == 0 {
return fmt.Errorf("no results found")
}

// Print the table
Expand Down Expand Up @@ -195,4 +260,4 @@ func parseDbType(s string) xrayTypes.DbType {
default:
return xrayTypes.MySQL
}
}
}
38 changes: 12 additions & 26 deletions databases/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"os"
"strings"

"encoding/base64"

_ "github.com/go-sql-driver/mysql"
"github.com/thesaas-company/xray/config"
"github.com/thesaas-company/xray/types"
Expand Down Expand Up @@ -137,20 +135,21 @@ func (m *MySQL) Execute(query string) ([]byte, error) {
return nil, fmt.Errorf("error scanning row: %v", err)
}

// Decode base64 data
for i, val := range values {
strVal, ok := val.(string)
if ok && isBase64(strVal) {
// Redecode the value to get the decoded result
decoded, err := base64.StdEncoding.DecodeString(strVal)
if err != nil {
return nil, fmt.Errorf("error decoding base64 data: %v", err)
}
values[i] = string(decoded)
// Convert the values to the appropriate types
stringRow := make([]interface{}, len(values))
for i, v := range values {
switch value := v.(type) {
case []byte:
stringRow[i] = string(value)
case string:
stringRow[i] = value
default:
stringRow[i] = fmt.Sprintf("%v", value)
}
}

results = append(results, values)
// Append the modified row to the results
results = append(results, stringRow)
}

// Check for errors from iterating over rows
Expand All @@ -171,19 +170,6 @@ func (m *MySQL) Execute(query string) ([]byte, error) {
return jsonData, nil
}

// isBase64 checks if a string is a valid base64 string.
func isBase64(s string) bool {
if len(s)%4 != 0 {
return false
}
// Try to decode the string
_, err := base64.StdEncoding.DecodeString(s)
// If decoding succeeds, err will be nil, and the function will return true
// If decoding fails, err will not be nil, and the function will return false
// Also we do not have access to decoded value, so we are not using it
return err == nil
}

// Tables retrieves the list of tables in the given database.
// It takes the database name as an argument and returns a list of table names.
func (m *MySQL) Tables(databaseName string) ([]string, error) {
Expand Down
40 changes: 24 additions & 16 deletions databases/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,29 @@ func (p *Postgres) Execute(query string) ([]byte, error) {
}

// Decode base64 data
for _, val := range values {
strVal, ok := val.(*string)
if ok && strVal != nil && isBase64(*strVal) {
decoded, err := base64.StdEncoding.DecodeString(*strVal)
if err != nil {
return nil, fmt.Errorf("error decoding base64 data: %v", err)
stringRow := make([]interface{}, len(values))
for i, val := range values {
switch v := val.(type) {
case []byte:
strVal := string(v)
if isBase64(strVal) {
decoded, err := base64.StdEncoding.DecodeString(strVal)
if err != nil {
return nil, fmt.Errorf("error decoding base64 data: %v", err)
}
stringRow[i] = string(decoded)
} else {
stringRow[i] = strVal
}
*strVal = string(decoded)
case string:
stringRow[i] = v
case nil:
stringRow[i] = nil
default:
stringRow[i] = fmt.Sprintf("%v", v)
}
}

results = append(results, values)
results = append(results, stringRow)
}

// Check for errors from iterating over rows
Expand All @@ -203,15 +214,12 @@ func (p *Postgres) Execute(query string) ([]byte, error) {
}

func isBase64(s string) bool {
if len(s)%4 != 0 {
decoded, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return false
}
// Try to decode the string
_, err := base64.StdEncoding.DecodeString(s)
// If decoding succeeds, err will be nil, and the function will return true
// If decoding fails, err will not be nil, and the function will return false
// Also we do not have access to decoded value, so we are not using it
return err == nil
encoded := base64.StdEncoding.EncodeToString(decoded)
return s == encoded
}

// Tables returns a list of all tables in the given database.
Expand Down

0 comments on commit 1e94bce

Please sign in to comment.