diff --git a/NOTICE b/NOTICE index 664188d5..1e939d21 100644 --- a/NOTICE +++ b/NOTICE @@ -8,3 +8,7 @@ https://github.com/dropbox/godropbox/tree/master/database/sqlbuilder (BSD-3) This product contains a modified portion of 'snaker' which can be obtained at: https://github.com/serenize/snaker (MIT) + + +This product contains `FormatTimestamp` function from 'pq' which can be obtained at: +https://github.com/lib/pq (MIT) diff --git a/README.md b/README.md index 82acaded..56964300 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ convert database query result into desired arbitrary object structure. Jet currently supports `PostgreSQL`, `MySQL` and `MariaDB`. Future releases will add support for additional databases. ![jet](https://github.com/go-jet/jet/wiki/image/jet.png) -Jet is the easiest and fastest way to write complex SQL queries and map database query result +Jet is the easiest and the fastest way to write complex SQL queries and map database query result into complex object composition. __It is not an ORM.__ ## Motivation @@ -46,7 +46,7 @@ https://medium.com/@go.jet/jet-5f3667efa0cc * UPDATE `(SET, WHERE)`, * DELETE `(WHERE, ORDER_BY, LIMIT)`, * LOCK `(READ, WRITE)` - 2) Auto-generated Data Model types - Go types mapped to database type (table or enum), used to store + 2) Auto-generated Data Model types - Go types mapped to database type (table, view or enum), used to store result of database queries. Can be combined to create desired query result destination. 3) Query execution with result mapping to arbitrary destination structure. @@ -88,12 +88,13 @@ jet -source=PostgreSQL -host=localhost -port=5432 -user=jetuser -password=jetpas ```sh Connecting to postgres database: host=localhost port=5432 user=jetuser password=jetpass dbname=jetdb sslmode=disable Retrieving schema information... - FOUND 15 table(s), 1 enum(s) -Destination directory: ./gen/jetdb/dvds -Cleaning up schema destination directory... + FOUND 15 table(s), 7 view(s), 1 enum(s) +Cleaning up destination directory... Generating table sql builder files... -Generating table model files... +Generating view sql builder files... Generating enum sql builder files... +Generating table model files... +Generating view model files... Generating enum model files... Done ``` @@ -102,9 +103,9 @@ be omitted (both databases doesn't have schema support). _*User has to have a permission to read information schema tables._ As command output suggest, Jet will: -- connect to postgres database and retrieve information about the _tables_ and _enums_ of `dvds` schema +- connect to postgres database and retrieve information about the _tables_, _views_ and _enums_ of `dvds` schema - delete everything in schema destination folder - `./gen/jetdb/dvds`, -- and finally generate SQL Builder and Model files for each schema table and enum. +- and finally generate SQL Builder and Model files for each schema table, view and enum. Generated files folder structure will look like this: @@ -112,20 +113,24 @@ Generated files folder structure will look like this: |-- gen # -path | `-- jetdb # database name | `-- dvds # schema name -| |-- enum # sql builder folder for enums +| |-- enum # sql builder package for enums | | |-- mpaa_rating.go -| |-- table # sql builder folder for tables +| |-- table # sql builder package for tables | |-- actor.go | |-- address.go | |-- category.go | ... -| |-- model # model files for each table and enum +| |-- view # sql builder package for views +| |-- actor_info.go +| |-- film_list.go +| ... +| |-- model # data model types for each table, view and enum | | |-- actor.go | | |-- address.go | | |-- mpaa_rating.go | | ... ``` -Types from `table` and `enum` are used to write type safe SQL in Go, and `model` types can be combined to store +Types from `table`, `view` and `enum` are used to write type safe SQL in Go, and `model` types can be combined to store results of the SQL queries. @@ -167,7 +172,8 @@ stmt := SELECT( Film.FilmID.ASC(), ) ``` -Package(dot) import is used so that statement would resemble as much as possible as native SQL. Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with +_Package(dot) import is used so that statement would resemble as much as possible as native SQL._ +Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with string columns and expressions. `Actor.ActorID`, `FilmActor.ActorID`, `Film.Length` are integer columns and can be compared only with integer columns and expressions. @@ -268,11 +274,12 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; #### Execute query and store result -Well formed SQL is just a first half the job. Lets see how can we make some sense of result set returned executing +Well formed SQL is just a first half of the job. Lets see how can we make some sense of result set returned executing above statement. Usually this is the most complex and tedious work, but with Jet it is the easiest. -First we have to create desired structure to store query result set. -This is done be combining autogenerated model types or it can be done manually(see [wiki](https://github.com/go-jet/jet/wiki/Scan-to-arbitrary-destination) for more information). +First we have to create desired structure to store query result. +This is done be combining autogenerated model types or it can be done +manually(see [wiki](https://github.com/go-jet/jet/wiki/Query-Result-Mapping-(QRM)) for more information). Let's say this is our desired structure: ```go @@ -287,8 +294,8 @@ var dest []struct { } } ``` -Because one actor can act in multiple films, `Films` field is a slice, and because each film belongs to one language -`Langauge` field is just a single model struct. +`Films` field is a slice because one actor can act in multiple films, and because each film belongs to one language +`Langauge` field is just a single model struct. `Film` can belong to multiple categories. _*There is no limitation of how big or nested destination can be._ Now lets execute a above statement on open database connection (or transaction) db and store result into `dest`. @@ -504,12 +511,14 @@ The biggest benefit is speed. Speed is improved in 3 major areas: ##### Speed of development -Writing SQL queries is much easier, because programmer has the help of SQL code completion and SQL type safety directly in Go. -Writing code is much faster and code is more robust. Automatic scan to arbitrary structure removes a lot of headache and -boilerplate code needed to structure database query result. +Writing SQL queries is faster and easier, because the developers have help of SQL code completion and SQL type safety directly from Go. +Automatic scan to arbitrary structure removes a lot of headache and boilerplate code needed to structure database query result. ##### Speed of execution +While ORM libraries can introduce significant performance penalties due to number of round-trips to the database, +Jet will always perform much better, because of the single database call. + Common web and database server usually are not on the same physical machine, and there is some latency between them. Latency can vary from 5ms to 50+ms. In majority of cases query executed on database is simple query lasting no more than 1ms. In those cases web server handler execution time is directly proportional to latency between server and database. @@ -521,14 +530,14 @@ With Jet, handler time lost on latency between server and database is constant. return result in one database call. Handler execution will be only proportional to the number of rows returned from database. ORM example replaced with jet will take just 30ms + 'result scan time' = 31ms (rough estimate). -With Jet you can even join the whole database and store the whole structured result in in one query call. +With Jet you can even join the whole database and store the whole structured result in one database call. This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/postgres/chinook_db_test.go#L40). The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s. ##### How quickly bugs are found The most expensive bugs are the one on the production and the least expensive are those found during development. -With automatically generated type safe SQL not only queries are written faster but bugs are found sooner. +With automatically generated type safe SQL, not only queries are written faster but bugs are found sooner. Lets return to quick start example, and take closer look at a line: ```go AND(Film.Length.GT(Int(180))), diff --git a/examples/quick-start/quick-start.go b/examples/quick-start/quick-start.go index b10906b3..0fe031ac 100644 --- a/examples/quick-start/quick-start.go +++ b/examples/quick-start/quick-start.go @@ -91,9 +91,7 @@ func jsonSave(path string, v interface{}) { err := ioutil.WriteFile(path, jsonText, 0644) - if err != nil { - panic(err) - } + panicOnError(err) } func printStatementInfo(stmt SelectStatement) { diff --git a/execution/execution.go b/execution/execution.go index 363b0f48..3281b033 100644 --- a/execution/execution.go +++ b/execution/execution.go @@ -771,7 +771,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl if len(subType.indexes) != 0 || len(subType.subTypes) != 0 { ret.subTypes = append(ret.subTypes, subType) } - } else if isPrimaryKey(field) { + } else if isPrimaryKey(field, parentField) { index := s.typeToColumnIndex(newTypeName, fieldName) if index < 0 { @@ -813,9 +813,7 @@ func (s *scanContext) rowElem(index int) interface{} { value, err := valuer.Value() - if err != nil { - panic(err) - } + utils.PanicOnError(err) return value } @@ -837,13 +835,45 @@ func (s *scanContext) rowElemValuePtr(index int) reflect.Value { return newElem } -func isPrimaryKey(field reflect.StructField) bool { +func isPrimaryKey(field reflect.StructField, parentField *reflect.StructField) bool { + + if hasOverwrite, isPrimaryKey := primaryKeyOvewrite(field.Name, parentField); hasOverwrite { + return isPrimaryKey + } sqlTag := field.Tag.Get("sql") return sqlTag == "primary_key" } +func primaryKeyOvewrite(columnName string, parentField *reflect.StructField) (hasOverwrite, primaryKey bool) { + if parentField == nil { + return + } + + sqlTag := parentField.Tag.Get("sql") + + if !strings.HasPrefix(sqlTag, "primary_key") { + return + } + + parts := strings.Split(sqlTag, "=") + + if len(parts) < 2 { + return + } + + primaryKeyColumns := strings.Split(parts[1], ",") + + for _, primaryKeyCol := range primaryKeyColumns { + if toCommonIdentifier(columnName) == toCommonIdentifier(primaryKeyCol) { + return true, true + } + } + + return true, false +} + func indirectType(reflectType reflect.Type) reflect.Type { if reflectType.Kind() != reflect.Ptr { return reflectType diff --git a/execution/internal/null_types.go b/execution/internal/null_types.go index c9fb45ed..5a39094a 100644 --- a/execution/internal/null_types.go +++ b/execution/internal/null_types.go @@ -62,7 +62,7 @@ func (nt *NullTime) Scan(value interface{}) (err error) { nt.Time, nt.Valid = parseTime(v) return default: - return fmt.Errorf("can't scan time from %v", value) + return fmt.Errorf("can't scan time.Time from %v", value) } } diff --git a/execution/internal/null_types_test.go b/execution/internal/null_types_test.go new file mode 100644 index 00000000..70eb42ff --- /dev/null +++ b/execution/internal/null_types_test.go @@ -0,0 +1,147 @@ +package internal + +import ( + "fmt" + "gotest.tools/assert" + "testing" + "time" +) + +func TestNullByteArray(t *testing.T) { + var array NullByteArray + + assert.NilError(t, array.Scan(nil)) + assert.Equal(t, array.Valid, false) + + assert.NilError(t, array.Scan([]byte("bytea"))) + assert.Equal(t, array.Valid, true) + assert.Equal(t, string(array.ByteArray), string([]byte("bytea"))) + + assert.Error(t, array.Scan(12), "can't scan []byte from 12") +} + +func TestNullTime(t *testing.T) { + var array NullTime + + assert.NilError(t, array.Scan(nil)) + assert.Equal(t, array.Valid, false) + + time := time.Now() + assert.NilError(t, array.Scan(time)) + assert.Equal(t, array.Valid, true) + value, _ := array.Value() + assert.Equal(t, value, time) + + assert.NilError(t, array.Scan([]byte("13:10:11"))) + assert.Equal(t, array.Valid, true) + value, _ = array.Value() + assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") + + assert.NilError(t, array.Scan("13:10:11")) + assert.Equal(t, array.Valid, true) + value, _ = array.Value() + assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") + + assert.Error(t, array.Scan(12), "can't scan time.Time from 12") +} + +func TestNullInt8(t *testing.T) { + var array NullInt8 + + assert.NilError(t, array.Scan(nil)) + assert.Equal(t, array.Valid, false) + + assert.NilError(t, array.Scan(int64(11))) + assert.Equal(t, array.Valid, true) + value, _ := array.Value() + assert.Equal(t, value, int8(11)) + + assert.Error(t, array.Scan("text"), "can't scan int8 from text") +} + +func TestNullInt16(t *testing.T) { + var array NullInt16 + + assert.NilError(t, array.Scan(nil)) + assert.Equal(t, array.Valid, false) + + assert.NilError(t, array.Scan(int64(11))) + assert.Equal(t, array.Valid, true) + value, _ := array.Value() + assert.Equal(t, value, int16(11)) + + assert.NilError(t, array.Scan(int16(20))) + assert.Equal(t, array.Valid, true) + value, _ = array.Value() + assert.Equal(t, value, int16(20)) + + assert.NilError(t, array.Scan(int8(30))) + assert.Equal(t, array.Valid, true) + value, _ = array.Value() + assert.Equal(t, value, int16(30)) + + assert.NilError(t, array.Scan(uint8(30))) + assert.Equal(t, array.Valid, true) + value, _ = array.Value() + assert.Equal(t, value, int16(30)) + + assert.Error(t, array.Scan("text"), "can't scan int16 from text") +} + +func TestNullInt32(t *testing.T) { + var array NullInt32 + + assert.NilError(t, array.Scan(nil)) + assert.Equal(t, array.Valid, false) + + assert.NilError(t, array.Scan(int64(11))) + assert.Equal(t, array.Valid, true) + value, _ := array.Value() + assert.Equal(t, value, int32(11)) + + assert.NilError(t, array.Scan(int32(32))) + assert.Equal(t, array.Valid, true) + value, _ = array.Value() + assert.Equal(t, value, int32(32)) + + assert.NilError(t, array.Scan(int16(20))) + assert.Equal(t, array.Valid, true) + value, _ = array.Value() + assert.Equal(t, value, int32(20)) + + assert.NilError(t, array.Scan(uint16(16))) + assert.Equal(t, array.Valid, true) + value, _ = array.Value() + assert.Equal(t, value, int32(16)) + + assert.NilError(t, array.Scan(int8(30))) + assert.Equal(t, array.Valid, true) + value, _ = array.Value() + assert.Equal(t, value, int32(30)) + + assert.NilError(t, array.Scan(uint8(30))) + assert.Equal(t, array.Valid, true) + value, _ = array.Value() + assert.Equal(t, value, int32(30)) + + assert.Error(t, array.Scan("text"), "can't scan int32 from text") +} + +func TestNullFloat32(t *testing.T) { + var array NullFloat32 + + assert.NilError(t, array.Scan(nil)) + assert.Equal(t, array.Valid, false) + + assert.NilError(t, array.Scan(float64(64))) + assert.Equal(t, array.Valid, true) + value, _ := array.Value() + assert.Equal(t, value, float32(64)) + + assert.NilError(t, array.Scan(float32(32))) + assert.Equal(t, array.Valid, true) + value, _ = array.Value() + assert.Equal(t, value, float32(32)) + + assert.Error(t, array.Scan(12), "can't scan float32 from 12") +} diff --git a/generator/internal/metadata/column_meta_data.go b/generator/internal/metadata/column_meta_data.go index c1fdd100..69a16f73 100644 --- a/generator/internal/metadata/column_meta_data.go +++ b/generator/internal/metadata/column_meta_data.go @@ -142,13 +142,10 @@ func (c ColumnMetaData) GoModelTag(isPrimaryKey bool) string { return "" } -func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) ([]ColumnMetaData, error) { +func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) []ColumnMetaData { rows, err := db.Query(querySet.ListOfColumnsQuery(), schemaName, tableName) - - if err != nil { - return nil, err - } + utils.PanicOnError(err) defer rows.Close() ret := []ColumnMetaData{} @@ -157,19 +154,13 @@ func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableN var name, isNullable, dataType, enumName string var isUnsigned bool err := rows.Scan(&name, &isNullable, &dataType, &enumName, &isUnsigned) - - if err != nil { - return nil, err - } + utils.PanicOnError(err) ret = append(ret, NewColumnMetaData(name, isNullable == "YES", dataType, enumName, isUnsigned)) } err = rows.Err() + utils.PanicOnError(err) - if err != nil { - return nil, err - } - - return ret, nil + return ret } diff --git a/generator/internal/metadata/dialect_query_set.go b/generator/internal/metadata/dialect_query_set.go index 6cc9834e..6c918257 100644 --- a/generator/internal/metadata/dialect_query_set.go +++ b/generator/internal/metadata/dialect_query_set.go @@ -11,5 +11,5 @@ type DialectQuerySet interface { ListOfColumnsQuery() string ListOfEnumsQuery() string - GetEnumsMetaData(db *sql.DB, schemaName string) ([]MetaData, error) + GetEnumsMetaData(db *sql.DB, schemaName string) []MetaData } diff --git a/generator/internal/metadata/schema_meta_data.go b/generator/internal/metadata/schema_meta_data.go index 0a05d6a4..836745bc 100644 --- a/generator/internal/metadata/schema_meta_data.go +++ b/generator/internal/metadata/schema_meta_data.go @@ -3,41 +3,43 @@ package metadata import ( "database/sql" "fmt" + "github.com/go-jet/jet/internal/utils" ) // SchemaMetaData struct type SchemaMetaData struct { - TableInfos []MetaData - EnumInfos []MetaData + TablesMetaData []MetaData + ViewsMetaData []MetaData + EnumsMetaData []MetaData } -// GetSchemaInfo returns schema information from db connection. -func GetSchemaInfo(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData, err error) { - - schemaInfo.TableInfos, err = getTableInfos(db, querySet, schemaName) +// IsEmpty returns true if schema info does not contain any table, views or enums metadata +func (s SchemaMetaData) IsEmpty() bool { + return len(s.TablesMetaData) == 0 && len(s.ViewsMetaData) == 0 && len(s.EnumsMetaData) == 0 +} - if err != nil { - return - } +const ( + baseTable = "BASE TABLE" + view = "VIEW" +) - schemaInfo.EnumInfos, err = querySet.GetEnumsMetaData(db, schemaName) +// GetSchemaMetaData returns schema information from db connection. +func GetSchemaMetaData(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData) { - if err != nil { - return - } + schemaInfo.TablesMetaData = getTablesMetaData(db, querySet, schemaName, baseTable) + schemaInfo.ViewsMetaData = getTablesMetaData(db, querySet, schemaName, view) + schemaInfo.EnumsMetaData = querySet.GetEnumsMetaData(db, schemaName) - fmt.Println(" FOUND", len(schemaInfo.TableInfos), "table(s), ", len(schemaInfo.EnumInfos), "enum(s)") + fmt.Println(" FOUND", len(schemaInfo.TablesMetaData), "table(s),", len(schemaInfo.ViewsMetaData), "view(s),", + len(schemaInfo.EnumsMetaData), "enum(s)") return } -func getTableInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]MetaData, error) { +func getTablesMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableType string) []MetaData { - rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName) - - if err != nil { - return nil, err - } + rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName, tableType) + utils.PanicOnError(err) defer rows.Close() ret := []MetaData{} @@ -45,24 +47,15 @@ func getTableInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]M var tableName string err = rows.Scan(&tableName) - if err != nil { - return nil, err - } - - tableInfo, err := GetTableInfo(db, querySet, schemaName, tableName) + utils.PanicOnError(err) - if err != nil { - return nil, err - } + tableInfo := GetTableMetaData(db, querySet, schemaName, tableName) ret = append(ret, tableInfo) } err = rows.Err() + utils.PanicOnError(err) - if err != nil { - return nil, err - } - - return ret, nil + return ret } diff --git a/generator/internal/metadata/table_meta_data.go b/generator/internal/metadata/table_meta_data.go index c2f4b23b..cb738fa8 100644 --- a/generator/internal/metadata/table_meta_data.go +++ b/generator/internal/metadata/table_meta_data.go @@ -67,46 +67,32 @@ func (t TableMetaData) GoStructName() string { return utils.ToGoIdentifier(t.name) + "Table" } -// GetTableInfo returns table info metadata -func GetTableInfo(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData, err error) { +// GetTableMetaData returns table info metadata +func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData) { tableInfo.SchemaName = schemaName tableInfo.name = tableName - tableInfo.PrimaryKeys, err = getPrimaryKeys(db, querySet, schemaName, tableName) - if err != nil { - return - } - - tableInfo.Columns, err = getColumnsMetaData(db, querySet, schemaName, tableName) - - if err != nil { - return - } + tableInfo.PrimaryKeys = getPrimaryKeys(db, querySet, schemaName, tableName) + tableInfo.Columns = getColumnsMetaData(db, querySet, schemaName, tableName) return } -func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (map[string]bool, error) { +func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) map[string]bool { rows, err := db.Query(querySet.PrimaryKeysQuery(), schemaName, tableName) - - if err != nil { - return nil, err - } + utils.PanicOnError(err) primaryKeyMap := map[string]bool{} for rows.Next() { primaryKey := "" err := rows.Scan(&primaryKey) - - if err != nil { - return nil, err - } + utils.PanicOnError(err) primaryKeyMap[primaryKey] = true } - return primaryKeyMap, nil + return primaryKeyMap } diff --git a/generator/internal/template/generate.go b/generator/internal/template/generate.go index a16c1332..a076bb10 100644 --- a/generator/internal/template/generate.go +++ b/generator/internal/template/generate.go @@ -12,89 +12,65 @@ import ( ) // GenerateFiles generates Go files from tables and enums metadata -func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect jet.Dialect) error { - if len(tables) == 0 && len(enums) == 0 { - return nil +func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect jet.Dialect) { + if schemaInfo.IsEmpty() { + return } fmt.Println("Destination directory:", destDir) fmt.Println("Cleaning up destination directory...") err := utils.CleanUpGeneratedFiles(destDir) + utils.PanicOnError(err) - if err != nil { - return err - } + generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect) + generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect) + generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect) - fmt.Println("Generating table sql builder files...") - err = generate(destDir, "table", tableSQLBuilderTemplate, tables, dialect) + generateModelFiles(destDir, "table", tableModelTemplate, schemaInfo.TablesMetaData, dialect) + generateModelFiles(destDir, "view", tableModelTemplate, schemaInfo.ViewsMetaData, dialect) + generateModelFiles(destDir, "enum", enumModelTemplate, schemaInfo.EnumsMetaData, dialect) - if err != nil { - return err - } - - fmt.Println("Generating table model files...") - err = generate(destDir, "model", tableModelTemplate, tables, dialect) + fmt.Println("Done") +} - if err != nil { - return err +func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) { + if len(metaData) == 0 { + return } + fmt.Printf("Generating %s sql builder files...\n", fileTypes) + generateGoFiles(destDir, fileTypes, sqlBuilderTemplate, metaData, dialect) +} - if len(enums) > 0 { - fmt.Println("Generating enum sql builder files...") - err = generate(destDir, "enum", enumSQLBuilderTemplate, enums, dialect) - - if err != nil { - return err - } - - fmt.Println("Generating enum model files...") - err = generate(destDir, "model", enumModelTemplate, enums, dialect) - - if err != nil { - return err - } +func generateModelFiles(destDir, fileTypes, modelTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) { + if len(metaData) == 0 { + return } - - fmt.Println("Done") - - return nil - + fmt.Printf("Generating %s model files...\n", fileTypes) + generateGoFiles(destDir, "model", modelTemplate, metaData, dialect) } -func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) error { +func generateGoFiles(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) { modelDirPath := filepath.Join(dirPath, packageName) err := utils.EnsureDirPath(modelDirPath) - - if err != nil { - return err - } + utils.PanicOnError(err) autoGenWarning, err := GenerateTemplate(autoGenWarningTemplate, nil, dialect) - - if err != nil { - return err - } + utils.PanicOnError(err) for _, metaData := range metaDataList { - text, err := GenerateTemplate(template, metaData, dialect) - - if err != nil { - return err - } + text, err := GenerateTemplate(template, metaData, dialect, map[string]interface{}{"package": packageName}) + utils.PanicOnError(err) err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...)) - - if err != nil { - return err - } + utils.PanicOnError(err) } - return nil + return } // GenerateTemplate generates template with template text and template data. -func GenerateTemplate(templateText string, templateData interface{}, dialect1 jet.Dialect) ([]byte, error) { +func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect, params ...map[string]interface{}) ([]byte, error) { t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ "ToGoIdentifier": utils.ToGoIdentifier, @@ -102,7 +78,13 @@ func GenerateTemplate(templateText string, templateData interface{}, dialect1 je return time.Now().Format(time.RFC850) }, "dialect": func() jet.Dialect { - return dialect1 + return dialect + }, + "param": func(name string) interface{} { + if len(params) > 0 { + return params[0][name] + } + return "" }, }).Parse(templateText) diff --git a/generator/internal/template/templates.go b/generator/internal/template/templates.go index c7c064f2..0a2a9c72 100644 --- a/generator/internal/template/templates.go +++ b/generator/internal/template/templates.go @@ -18,7 +18,7 @@ var tableSQLBuilderTemplate = ` {{- end}} {{- end}} -package table +package {{param "package"}} import ( "github.com/go-jet/jet/{{dialect.PackageName}}" diff --git a/generator/mysql/mysql_generator.go b/generator/mysql/mysql_generator.go index c8b01ee5..75405ea7 100644 --- a/generator/mysql/mysql_generator.go +++ b/generator/mysql/mysql_generator.go @@ -22,50 +22,34 @@ type DBConnection struct { } // Generate generates jet files at destination dir from database connection details -func Generate(destDir string, dbConn DBConnection) error { - db, err := openConnection(dbConn) - if err != nil { - return err - } +func Generate(destDir string, dbConn DBConnection) (err error) { + defer utils.ErrorCatch(&err) + + db := openConnection(dbConn) defer utils.DBClose(db) fmt.Println("Retrieving database information...") // No schemas in MySQL - dbInfo, err := metadata.GetSchemaInfo(db, dbConn.DBName, &mySqlQuerySet{}) - - if err != nil { - return err - } + dbInfo := metadata.GetSchemaMetaData(db, dbConn.DBName, &mySqlQuerySet{}) genPath := path.Join(destDir, dbConn.DBName) - err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, mysql.Dialect) - - if err != nil { - return err - } + template.GenerateFiles(genPath, dbInfo, mysql.Dialect) return nil } -func openConnection(dbConn DBConnection) (*sql.DB, error) { +func openConnection(dbConn DBConnection) *sql.DB { var connectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName) if dbConn.Params != "" { connectionString += "?" + dbConn.Params } - db, err := sql.Open("mysql", connectionString) - fmt.Println("Connecting to MySQL database: " + connectionString) - - if err != nil { - return nil, err - } + db, err := sql.Open("mysql", connectionString) + utils.PanicOnError(err) err = db.Ping() + utils.PanicOnError(err) - if err != nil { - return nil, err - } - - return db, nil + return db } diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index 20a01ac7..a1ad8ec8 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -3,6 +3,7 @@ package mysql import ( "database/sql" "github.com/go-jet/jet/generator/internal/metadata" + "github.com/go-jet/jet/internal/utils" "strings" ) @@ -13,7 +14,7 @@ func (m *mySqlQuerySet) ListOfTablesQuery() string { return ` SELECT table_name FROM INFORMATION_SCHEMA.tables -WHERE table_schema = ? and table_type = 'BASE TABLE'; +WHERE table_schema = ? and table_type = ?; ` } @@ -46,17 +47,14 @@ func (m *mySqlQuerySet) ListOfEnumsQuery() string { SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ), SUBSTRING(c.COLUMN_TYPE,5) FROM information_schema.columns as c INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name) -WHERE c.table_schema = ? AND DATA_TYPE = 'enum' AND t.TABLE_TYPE = 'BASE TABLE'; +WHERE c.table_schema = ? AND DATA_TYPE = 'enum'; ` } -func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.MetaData, error) { +func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData { rows, err := db.Query(m.ListOfEnumsQuery(), schemaName) - - if err != nil { - return nil, err - } + utils.PanicOnError(err) defer rows.Close() ret := []metadata.MetaData{} @@ -65,9 +63,7 @@ func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metad var enumName string var enumValues string err = rows.Scan(&enumName, &enumValues) - if err != nil { - return nil, err - } + utils.PanicOnError(err) enumValues = strings.Replace(enumValues[1:len(enumValues)-1], "'", "", -1) @@ -78,11 +74,8 @@ func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metad } err = rows.Err() + utils.PanicOnError(err) - if err != nil { - return nil, err - } - - return ret, nil + return ret } diff --git a/generator/postgres/postgres_generator.go b/generator/postgres/postgres_generator.go index ce2ea340..392a00be 100644 --- a/generator/postgres/postgres_generator.go +++ b/generator/postgres/postgres_generator.go @@ -25,31 +25,20 @@ type DBConnection struct { } // Generate generates jet files at destination dir from database connection details -func Generate(destDir string, dbConn DBConnection) error { +func Generate(destDir string, dbConn DBConnection) (err error) { + defer utils.ErrorCatch(&err) db, err := openConnection(dbConn) + utils.PanicOnError(err) defer utils.DBClose(db) - if err != nil { - return err - } - fmt.Println("Retrieving schema information...") - schemaInfo, err := metadata.GetSchemaInfo(db, dbConn.SchemaName, &postgresQuerySet{}) - - if err != nil { - return err - } + schemaInfo := metadata.GetSchemaMetaData(db, dbConn.SchemaName, &postgresQuerySet{}) genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName) + template.GenerateFiles(genPath, schemaInfo, postgres.Dialect) - err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, postgres.Dialect) - - if err != nil { - return err - } - - return nil + return } func openConnection(dbConn DBConnection) (*sql.DB, error) { diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index a4f1fdd4..ce4a0839 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -3,6 +3,7 @@ package postgres import ( "database/sql" "github.com/go-jet/jet/generator/internal/metadata" + "github.com/go-jet/jet/internal/utils" ) // postgresQuerySet is dialect query set for PostgreSQL @@ -12,7 +13,7 @@ func (p *postgresQuerySet) ListOfTablesQuery() string { return ` SELECT table_name FROM information_schema.tables -where table_schema = $1 and table_type = 'BASE TABLE'; +where table_schema = $1 and table_type = $2; ` } @@ -45,12 +46,9 @@ WHERE n.nspname = $1 ORDER BY n.nspname, t.typname, e.enumsortorder;` } -func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.MetaData, error) { +func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData { rows, err := db.Query(p.ListOfEnumsQuery(), schemaName) - - if err != nil { - return nil, err - } + utils.PanicOnError(err) defer rows.Close() enumsInfosMap := map[string][]string{} @@ -58,9 +56,7 @@ func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]me var enumName string var enumValue string err = rows.Scan(&enumName, &enumValue) - if err != nil { - return nil, err - } + utils.PanicOnError(err) enumValues := enumsInfosMap[enumName] @@ -70,10 +66,7 @@ func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]me } err = rows.Err() - - if err != nil { - return nil, err - } + utils.PanicOnError(err) ret := []metadata.MetaData{} @@ -84,5 +77,5 @@ func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]me }) } - return ret, nil + return ret } diff --git a/internal/3rdparty/pq/format_timestamp.go b/internal/3rdparty/pq/format_timestamp.go new file mode 100644 index 00000000..9dcf541c --- /dev/null +++ b/internal/3rdparty/pq/format_timestamp.go @@ -0,0 +1,42 @@ +package pq + +// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany + +import ( + "strconv" + "time" +) + +// FormatTimestamp formats t into Postgres' text format for timestamps. From: github.com/lib/pq +func FormatTimestamp(t time.Time) []byte { + // Need to send dates before 0001 A.D. with " BC" suffix, instead of the + // minus sign preferred by Go. + // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on + bc := false + if t.Year() <= 0 { + // flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11" + t = t.AddDate((-t.Year())*2+1, 0, 0) + bc = true + } + b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00")) + + _, offset := t.Zone() + offset = offset % 60 + if offset != 0 { + // RFC3339Nano already printed the minus sign + if offset < 0 { + offset = -offset + } + + b = append(b, ':') + if offset < 10 { + b = append(b, '0') + } + b = strconv.AppendInt(b, int64(offset), 10) + } + + if bc { + b = append(b, " BC"...) + } + return b +} diff --git a/internal/3rdparty/pq/format_timestamp_test.go b/internal/3rdparty/pq/format_timestamp_test.go new file mode 100644 index 00000000..9cceba22 --- /dev/null +++ b/internal/3rdparty/pq/format_timestamp_test.go @@ -0,0 +1,39 @@ +package pq + +// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany + +import ( + "testing" + "time" +) + +var formatTimeTests = []struct { + time time.Time + expected string +}{ + {time.Time{}, "0001-01-01 00:00:00Z"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03 04:05:06.123456789Z"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03 04:05:06.123456789+02:00"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03 04:05:06.123456789-06:00"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03 04:05:06-07:30:09"}, + + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z"}, + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00"}, + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00"}, + + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z BC"}, + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00 BC"}, + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00 BC"}, + + {time.Date(1, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09"}, + {time.Date(0, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09 BC"}, +} + +func TestFormatTs(t *testing.T) { + for i, tt := range formatTimeTests { + val := string(FormatTimestamp(tt.time)) + if val != tt.expected { + t.Errorf("%d: incorrect time format %q, want %q", i, val, tt.expected) + } + } +} diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 5a3f1e93..738074bb 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -134,7 +134,8 @@ func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) { // ClauseOrderBy struct type ClauseOrderBy struct { - List []OrderByClause + List []OrderByClause + SkipNewLine bool } // Serialize serializes clause into SQLBuilder @@ -143,7 +144,9 @@ func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder) return } - out.NewLine() + if !o.SkipNewLine { + out.NewLine() + } out.WriteString("ORDER BY") out.IncreaseIdent() @@ -469,3 +472,37 @@ func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder) { out.WriteString(string(i.LockMode)) out.WriteString("MODE") } + +// WindowDefinition struct +type WindowDefinition struct { + Name string + Window Window +} + +// ClauseWindow struct +type ClauseWindow struct { + Definitions []WindowDefinition +} + +// Serialize serializes clause into SQLBuilder +func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder) { + if len(i.Definitions) == 0 { + return + } + + out.NewLine() + out.WriteString("WINDOW") + + for i, def := range i.Definitions { + if i > 0 { + out.WriteString(", ") + } + out.WriteString(def.Name) + out.WriteString("AS") + if def.Window == nil { + out.WriteString("()") + continue + } + def.Window.serialize(statementType, out) + } +} diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 2489b85c..91f200a0 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -81,68 +81,154 @@ func LOG(floatExpression FloatExpression) FloatExpression { // ----------------- Aggregate functions -------------------// // AVG is aggregate function used to calculate avg value from numeric expression -func AVG(numericExpression NumericExpression) FloatExpression { - return NewFloatFunc("AVG", numericExpression) +func AVG(numericExpression NumericExpression) floatWindowExpression { + return NewFloatWindowFunc("AVG", numericExpression) } // BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none. -func BIT_AND(integerExpression IntegerExpression) IntegerExpression { - return newIntegerFunc("BIT_AND", integerExpression) +func BIT_AND(integerExpression IntegerExpression) integerWindowExpression { + return newIntegerWindowFunc("BIT_AND", integerExpression) } // BIT_OR is aggregate function used to calculates the bitwise OR of all non-null input values, or null if none. -func BIT_OR(integerExpression IntegerExpression) IntegerExpression { - return newIntegerFunc("BIT_OR", integerExpression) +func BIT_OR(integerExpression IntegerExpression) integerWindowExpression { + return newIntegerWindowFunc("BIT_OR", integerExpression) } // BOOL_AND is aggregate function. Returns true if all input values are true, otherwise false -func BOOL_AND(boolExpression BoolExpression) BoolExpression { - return newBoolFunc("BOOL_AND", boolExpression) +func BOOL_AND(boolExpression BoolExpression) boolWindowExpression { + return newBoolWindowFunc("BOOL_AND", boolExpression) } // BOOL_OR is aggregate function. Returns true if at least one input value is true, otherwise false -func BOOL_OR(boolExpression BoolExpression) BoolExpression { - return newBoolFunc("BOOL_OR", boolExpression) +func BOOL_OR(boolExpression BoolExpression) boolWindowExpression { + return newBoolWindowFunc("BOOL_OR", boolExpression) } // COUNT is aggregate function. Returns number of input rows for which the value of expression is not null. -func COUNT(expression Expression) IntegerExpression { - return newIntegerFunc("COUNT", expression) +func COUNT(expression Expression) integerWindowExpression { + return newIntegerWindowFunc("COUNT", expression) } // EVERY is aggregate function. Returns true if all input values are true, otherwise false -func EVERY(boolExpression BoolExpression) BoolExpression { - return newBoolFunc("EVERY", boolExpression) +func EVERY(boolExpression BoolExpression) boolWindowExpression { + return newBoolWindowFunc("EVERY", boolExpression) } // MAXf is aggregate function. Returns maximum value of float expression across all input values -func MAXf(floatExpression FloatExpression) FloatExpression { - return NewFloatFunc("MAX", floatExpression) +func MAXf(floatExpression FloatExpression) floatWindowExpression { + return NewFloatWindowFunc("MAX", floatExpression) } // MAXi is aggregate function. Returns maximum value of int expression across all input values -func MAXi(integerExpression IntegerExpression) IntegerExpression { - return newIntegerFunc("MAX", integerExpression) +func MAXi(integerExpression IntegerExpression) integerWindowExpression { + return newIntegerWindowFunc("MAX", integerExpression) } // MINf is aggregate function. Returns minimum value of float expression across all input values -func MINf(floatExpression FloatExpression) FloatExpression { - return NewFloatFunc("MIN", floatExpression) +func MINf(floatExpression FloatExpression) floatWindowExpression { + return NewFloatWindowFunc("MIN", floatExpression) } // MINi is aggregate function. Returns minimum value of int expression across all input values -func MINi(integerExpression IntegerExpression) IntegerExpression { - return newIntegerFunc("MIN", integerExpression) +func MINi(integerExpression IntegerExpression) integerWindowExpression { + return newIntegerWindowFunc("MIN", integerExpression) } // SUMf is aggregate function. Returns sum of expression across all float expressions -func SUMf(floatExpression FloatExpression) FloatExpression { - return NewFloatFunc("SUM", floatExpression) +func SUMf(floatExpression FloatExpression) floatWindowExpression { + return NewFloatWindowFunc("SUM", floatExpression) } // SUMi is aggregate function. Returns sum of expression across all integer expression. -func SUMi(integerExpression IntegerExpression) IntegerExpression { - return newIntegerFunc("SUM", integerExpression) +func SUMi(integerExpression IntegerExpression) integerWindowExpression { + return newIntegerWindowFunc("SUM", integerExpression) +} + +// ----------------- Window functions -------------------// + +// ROW_NUMBER returns number of the current row within its partition, counting from 1 +func ROW_NUMBER() integerWindowExpression { + return newIntegerWindowFunc("ROW_NUMBER") +} + +// RANK of the current row with gaps; same as row_number of its first peer +func RANK() integerWindowExpression { + return newIntegerWindowFunc("RANK") +} + +// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups +func DENSE_RANK() integerWindowExpression { + return newIntegerWindowFunc("DENSE_RANK") +} + +// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1) +func PERCENT_RANK() floatWindowExpression { + return NewFloatWindowFunc("PERCENT_RANK") +} + +// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows +func CUME_DIST() floatWindowExpression { + return NewFloatWindowFunc("CUME_DIST") +} + +// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible +func NTILE(numOfBuckets int64) integerWindowExpression { + return newIntegerWindowFunc("NTILE", FixedLiteral(numOfBuckets)) +} + +// LAG returns value evaluated at the row that is offset rows before the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +func LAG(expr Expression, offsetAndDefault ...interface{}) windowExpression { + return leadLagImpl("LAG", expr, offsetAndDefault...) +} + +// LEAD returns value evaluated at the row that is offset rows after the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +func LEAD(expr Expression, offsetAndDefault ...interface{}) windowExpression { + return leadLagImpl("LEAD", expr, offsetAndDefault...) +} + +// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame +func FIRST_VALUE(value Expression) windowExpression { + return newWindowFunc("FIRST_VALUE", value) +} + +// LAST_VALUE returns value evaluated at the row that is the last row of the window frame +func LAST_VALUE(value Expression) windowExpression { + return newWindowFunc("LAST_VALUE", value) +} + +// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row +func NTH_VALUE(value Expression, nth int64) windowExpression { + return newWindowFunc("NTH_VALUE", value, FixedLiteral(nth)) +} + +func leadLagImpl(name string, expr Expression, offsetAndDefault ...interface{}) windowExpression { + params := []Expression{expr} + + if len(offsetAndDefault) >= 2 { + offset, ok := offsetAndDefault[0].(int) + if !ok { + panic("jet: LAG offset should be an integer") + } + + var defaultValue Expression + + defaultValue, ok = offsetAndDefault[1].(Expression) + + if !ok { + defaultValue = literal(offsetAndDefault[1]) + } + + params = append(params, FixedLiteral(offset), defaultValue) + } + + return newWindowFunc(name, params...) } //------------ String functions ------------------// @@ -349,7 +435,7 @@ func TO_HEX(number IntegerExpression) StringExpression { // REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise. func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType ...string) BoolExpression { if len(matchType) > 0 { - return newBoolFunc("REGEXP_LIKE", stringExp, pattern, ConstLiteral(matchType[0])) + return newBoolFunc("REGEXP_LIKE", stringExp, pattern, FixedLiteral(matchType[0])) } return newBoolFunc("REGEXP_LIKE", stringExp, pattern) @@ -391,7 +477,7 @@ func CURRENT_TIME(precision ...int) TimezExpression { var timezFunc *timezFunc if len(precision) > 0 { - timezFunc = newTimezFunc("CURRENT_TIME", ConstLiteral(precision[0])) + timezFunc = newTimezFunc("CURRENT_TIME", FixedLiteral(precision[0])) } else { timezFunc = newTimezFunc("CURRENT_TIME") } @@ -406,7 +492,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression { var timestampzFunc *timestampzFunc if len(precision) > 0 { - timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", ConstLiteral(precision[0])) + timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", FixedLiteral(precision[0])) } else { timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP") } @@ -421,7 +507,7 @@ func LOCALTIME(precision ...int) TimeExpression { var timeFunc *timeFunc if len(precision) > 0 { - timeFunc = newTimeFunc("LOCALTIME", ConstLiteral(precision[0])) + timeFunc = newTimeFunc("LOCALTIME", FixedLiteral(precision[0])) } else { timeFunc = newTimeFunc("LOCALTIME") } @@ -436,7 +522,7 @@ func LOCALTIMESTAMP(precision ...int) TimestampExpression { var timestampFunc *timestampFunc if len(precision) > 0 { - timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", ConstLiteral(precision[0])) + timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", FixedLiteral(precision[0])) } else { timestampFunc = NewTimestampFunc("LOCALTIMESTAMP") } @@ -504,6 +590,16 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr return funcExp } +// NewFloatWindowFunc creates new float function with name and expressions +func newWindowFunc(name string, expressions ...Expression) windowExpression { + + newFun := newFunc(name, expressions, nil) + windowExpr := newWindowExpression(newFun) + newFun.expressionInterfaceImpl.Parent = windowExpr + + return windowExpr +} + func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil { serializeOverrideFunc := serializeOverride(f.expressions...) @@ -536,10 +632,23 @@ func newBoolFunc(name string, expressions ...Expression) BoolExpression { boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) boolFunc.boolInterfaceImpl.parent = boolFunc + boolFunc.expressionInterfaceImpl.Parent = boolFunc return boolFunc } +// NewFloatWindowFunc creates new float function with name and expressions +func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpression { + boolFunc := &boolFunc{} + + boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) + intWindowFunc := newBoolWindowExpression(boolFunc) + boolFunc.boolInterfaceImpl.parent = intWindowFunc + boolFunc.expressionInterfaceImpl.Parent = intWindowFunc + + return intWindowFunc +} + type floatFunc struct { funcExpressionImpl floatInterfaceImpl @@ -555,6 +664,18 @@ func NewFloatFunc(name string, expressions ...Expression) FloatExpression { return floatFunc } +// NewFloatWindowFunc creates new float function with name and expressions +func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpression { + floatFunc := &floatFunc{} + + floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) + floatWindowFunc := newFloatWindowExpression(floatFunc) + floatFunc.floatInterfaceImpl.parent = floatWindowFunc + floatFunc.expressionInterfaceImpl.Parent = floatWindowFunc + + return floatWindowFunc +} + type integerFunc struct { funcExpressionImpl integerInterfaceImpl @@ -569,6 +690,18 @@ func newIntegerFunc(name string, expressions ...Expression) IntegerExpression { return floatFunc } +// NewFloatWindowFunc creates new float function with name and expressions +func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowExpression { + integerFunc := &integerFunc{} + + integerFunc.funcExpressionImpl = *newFunc(name, expressions, integerFunc) + intWindowFunc := newIntegerWindowExpression(integerFunc) + integerFunc.integerInterfaceImpl.parent = intWindowFunc + integerFunc.expressionInterfaceImpl.Parent = intWindowFunc + + return intWindowFunc +} + type stringFunc struct { funcExpressionImpl stringInterfaceImpl diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 4851fa67..68fb4298 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -32,8 +32,8 @@ func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl return &exp } -// ConstLiteral is injected directly to SQL query, and does not appear in argument list. -func ConstLiteral(value interface{}) *literalExpressionImpl { +// FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list. +func FixedLiteral(value interface{}) *literalExpressionImpl { exp := literal(value) exp.constant = true diff --git a/internal/jet/serializer_test.go b/internal/jet/serializer_test.go index ad84b96d..6d2fd4a5 100644 --- a/internal/jet/serializer_test.go +++ b/internal/jet/serializer_test.go @@ -11,16 +11,9 @@ func TestArgToString(t *testing.T) { assert.Equal(t, argToString(true), "TRUE") assert.Equal(t, argToString(false), "FALSE") - assert.Equal(t, argToString(int8(-8)), "-8") - assert.Equal(t, argToString(int16(-16)), "-16") assert.Equal(t, argToString(int(-32)), "-32") assert.Equal(t, argToString(int32(-32)), "-32") assert.Equal(t, argToString(int64(-64)), "-64") - assert.Equal(t, argToString(uint8(8)), "8") - assert.Equal(t, argToString(uint16(16)), "16") - assert.Equal(t, argToString(uint(32)), "32") - assert.Equal(t, argToString(uint32(32)), "32") - assert.Equal(t, argToString(uint64(64)), "64") assert.Equal(t, argToString(float64(1.11)), "1.11") assert.Equal(t, argToString("john"), "'john'") @@ -31,5 +24,12 @@ func TestArgToString(t *testing.T) { time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006") assert.NilError(t, err) assert.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'") - assert.Equal(t, argToString(map[string]bool{}), "[Unsupported type]") + + func() { + defer func() { + assert.Equal(t, recover().(string), "jet: map[string]bool type can not be used as SQL query parameter") + }() + + argToString(map[string]bool{}) + }() } diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index d16aec4b..4eaf6263 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -2,8 +2,11 @@ package jet import ( "bytes" + "fmt" + "github.com/go-jet/jet/internal/3rdparty/pq" "github.com/go-jet/jet/internal/utils" "github.com/google/uuid" + "reflect" "strconv" "strings" "time" @@ -139,28 +142,13 @@ func argToString(value interface{}) string { return "TRUE" } return "FALSE" - case int8: - return strconv.FormatInt(int64(bindVal), 10) case int: return strconv.FormatInt(int64(bindVal), 10) - case int16: - return strconv.FormatInt(int64(bindVal), 10) case int32: return strconv.FormatInt(int64(bindVal), 10) case int64: return strconv.FormatInt(bindVal, 10) - case uint8: - return strconv.FormatUint(uint64(bindVal), 10) - case uint: - return strconv.FormatUint(uint64(bindVal), 10) - case uint16: - return strconv.FormatUint(uint64(bindVal), 10) - case uint32: - return strconv.FormatUint(uint64(bindVal), 10) - case uint64: - return strconv.FormatUint(uint64(bindVal), 10) - case float32: return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) case float64: @@ -173,9 +161,9 @@ func argToString(value interface{}) string { case uuid.UUID: return stringQuote(bindVal.String()) case time.Time: - return stringQuote(string(utils.FormatTimestamp(bindVal))) + return stringQuote(string(pq.FormatTimestamp(bindVal))) default: - return "[Unsupported type]" + panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String())) } } diff --git a/internal/jet/table.go b/internal/jet/table.go index bf8285ff..43790007 100644 --- a/internal/jet/table.go +++ b/internal/jet/table.go @@ -19,15 +19,17 @@ type Table interface { } // NewTable creates new table with schema Name, table Name and list of columns -func NewTable(schemaName, name string, columns ...ColumnExpression) SerializerTable { +func NewTable(schemaName, name string, column ColumnExpression, columns ...ColumnExpression) SerializerTable { + + columnList := append([]ColumnExpression{column}, columns...) t := tableImpl{ schemaName: schemaName, name: name, - columnList: columns, + columnList: columnList, } - for _, c := range columns { + for _, c := range columnList { c.setTableName(name) } diff --git a/internal/jet/table_test.go b/internal/jet/table_test.go new file mode 100644 index 00000000..30182bc1 --- /dev/null +++ b/internal/jet/table_test.go @@ -0,0 +1,33 @@ +package jet + +import ( + "gotest.tools/assert" + "testing" +) + +func TestNewTable(t *testing.T) { + newTable := NewTable("schema", "table", IntegerColumn("intCol")) + + assert.Equal(t, newTable.SchemaName(), "schema") + assert.Equal(t, newTable.TableName(), "table") + + assert.Equal(t, len(newTable.columns()), 1) + assert.Equal(t, newTable.columns()[0].Name(), "intCol") +} + +func TestNewJoinTable(t *testing.T) { + newTable1 := NewTable("schema", "table", IntegerColumn("intCol1")) + newTable2 := NewTable("schema", "table2", IntegerColumn("intCol2")) + + joinTable := NewJoinTable(newTable1, newTable2, InnerJoin, IntegerColumn("intCol1").EQ(IntegerColumn("intCol2"))) + + assertClauseSerialize(t, joinTable, `schema.table +INNER JOIN schema.table2 ON ("intCol1" = "intCol2")`) + + assert.Equal(t, joinTable.SchemaName(), "schema") + assert.Equal(t, joinTable.TableName(), "") + + assert.Equal(t, len(joinTable.columns()), 2) + assert.Equal(t, joinTable.columns()[0].Name(), "intCol1") + assert.Equal(t, joinTable.columns()[1].Name(), "intCol2") +} diff --git a/internal/jet/window_expression.go b/internal/jet/window_expression.go new file mode 100644 index 00000000..3e7f1c71 --- /dev/null +++ b/internal/jet/window_expression.go @@ -0,0 +1,146 @@ +package jet + +type commonWindowImpl struct { + expression Expression + window Window +} + +func (w *commonWindowImpl) over(window ...Window) { + if len(window) > 0 { + w.window = window[0] + } else { + w.window = newWindowImpl(nil) + } +} + +func (w *commonWindowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + w.expression.serialize(statement, out) + if w.window != nil { + out.WriteString("OVER") + w.window.serialize(statement, out) + } +} + +// -------------------------------------- + +type windowExpression interface { + Expression + OVER(window ...Window) Expression +} + +func newWindowExpression(Exp Expression) windowExpression { + newExp := &windowExpressionImpl{ + Expression: Exp, + } + + newExp.commonWindowImpl.expression = Exp + + return newExp +} + +type windowExpressionImpl struct { + Expression + commonWindowImpl +} + +func (f *windowExpressionImpl) OVER(window ...Window) Expression { + f.commonWindowImpl.over(window...) + return f +} + +func (f *windowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + f.commonWindowImpl.serialize(statement, out) +} + +// ----------------------------------------------------- + +type floatWindowExpression interface { + FloatExpression + OVER(window ...Window) FloatExpression +} + +func newFloatWindowExpression(floatExp FloatExpression) floatWindowExpression { + newExp := &floatWindowExpressionImpl{ + FloatExpression: floatExp, + } + + newExp.commonWindowImpl.expression = floatExp + + return newExp +} + +type floatWindowExpressionImpl struct { + FloatExpression + commonWindowImpl +} + +func (f *floatWindowExpressionImpl) OVER(window ...Window) FloatExpression { + f.commonWindowImpl.over(window...) + return f +} + +func (f *floatWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + f.commonWindowImpl.serialize(statement, out) +} + +// ------------------------------------------------ + +type integerWindowExpression interface { + IntegerExpression + OVER(window ...Window) IntegerExpression +} + +func newIntegerWindowExpression(intExp IntegerExpression) integerWindowExpression { + newExp := &integerWindowExpressionImpl{ + IntegerExpression: intExp, + } + + newExp.commonWindowImpl.expression = intExp + + return newExp +} + +type integerWindowExpressionImpl struct { + IntegerExpression + commonWindowImpl +} + +func (f *integerWindowExpressionImpl) OVER(window ...Window) IntegerExpression { + f.commonWindowImpl.over(window...) + return f +} + +func (f *integerWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + f.commonWindowImpl.serialize(statement, out) +} + +// ------------------------------------------------ + +type boolWindowExpression interface { + BoolExpression + OVER(window ...Window) BoolExpression +} + +func newBoolWindowExpression(boolExp BoolExpression) boolWindowExpression { + newExp := &boolWindowExpressionImpl{ + BoolExpression: boolExp, + } + + newExp.commonWindowImpl.expression = boolExp + + return newExp +} + +type boolWindowExpressionImpl struct { + BoolExpression + commonWindowImpl +} + +func (f *boolWindowExpressionImpl) OVER(window ...Window) BoolExpression { + f.commonWindowImpl.over(window...) + return f +} + +func (f *boolWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + f.commonWindowImpl.serialize(statement, out) +} diff --git a/internal/jet/window_func.go b/internal/jet/window_func.go new file mode 100644 index 00000000..7f4d1b72 --- /dev/null +++ b/internal/jet/window_func.go @@ -0,0 +1,186 @@ +package jet + +// Window interface +type Window interface { + Serializer + ORDER_BY(expr ...OrderByClause) Window + ROWS(start FrameExtent, end ...FrameExtent) Window + RANGE(start FrameExtent, end ...FrameExtent) Window + GROUPS(start FrameExtent, end ...FrameExtent) Window +} + +type windowImpl struct { + partitionBy []Expression + orderBy ClauseOrderBy + frameUnits string + start, end FrameExtent + + parent Window +} + +func newWindowImpl(parent Window) *windowImpl { + newWindow := &windowImpl{} + if parent == nil { + newWindow.parent = newWindow + } else { + newWindow.parent = parent + } + + return newWindow +} + +func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + if !contains(options, noWrap) { + out.WriteByte('(') + } + + if w.partitionBy != nil { + out.WriteString("PARTITION BY") + + serializeExpressionList(statement, w.partitionBy, ", ", out) + } + w.orderBy.SkipNewLine = true + w.orderBy.Serialize(statement, out) + + if w.frameUnits != "" { + out.WriteString(w.frameUnits) + + if w.end == nil { + w.start.serialize(statement, out) + } else { + out.WriteString("BETWEEN") + w.start.serialize(statement, out) + out.WriteString("AND") + w.end.serialize(statement, out) + } + } + + if !contains(options, noWrap) { + out.WriteByte(')') + } +} + +func (w *windowImpl) ORDER_BY(exprs ...OrderByClause) Window { + w.orderBy.List = exprs + return w.parent +} + +func (w *windowImpl) ROWS(start FrameExtent, end ...FrameExtent) Window { + w.frameUnits = "ROWS" + w.setFrameRange(start, end...) + return w.parent +} + +func (w *windowImpl) RANGE(start FrameExtent, end ...FrameExtent) Window { + w.frameUnits = "RANGE" + w.setFrameRange(start, end...) + return w.parent +} + +func (w *windowImpl) GROUPS(start FrameExtent, end ...FrameExtent) Window { + w.frameUnits = "GROUPS" + w.setFrameRange(start, end...) + return w.parent +} + +func (w *windowImpl) setFrameRange(start FrameExtent, end ...FrameExtent) { + w.start = start + if len(end) > 0 { + w.end = end[0] + } +} + +// PARTITION_BY window function constructor +func PARTITION_BY(exp Expression, exprs ...Expression) Window { + funImpl := newWindowImpl(nil) + funImpl.partitionBy = append([]Expression{exp}, exprs...) + return funImpl +} + +// ORDER_BY window function constructor +func ORDER_BY(expr ...OrderByClause) Window { + funImpl := newWindowImpl(nil) + funImpl.orderBy.List = expr + return funImpl +} + +// ----------------------------------------------- + +// FrameExtent interface +type FrameExtent interface { + Serializer + isFrameExtent() +} + +// PRECEDING window frame clause +func PRECEDING(offset Serializer) FrameExtent { + return &frameExtentImpl{ + preceding: true, + offset: offset, + } +} + +// FOLLOWING window frame clause +func FOLLOWING(offset Serializer) FrameExtent { + return &frameExtentImpl{ + preceding: false, + offset: offset, + } +} + +type frameExtentImpl struct { + preceding bool + offset Serializer +} + +func (f *frameExtentImpl) isFrameExtent() {} + +func (f *frameExtentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + if f == nil { + return + } + f.offset.serialize(statement, out) + + if f.preceding { + out.WriteString("PRECEDING") + } else { + out.WriteString("FOLLOWING") + } +} + +// ----------------------------------------------- + +// Window function keywords +var ( + UNBOUNDED = keywordClause("UNBOUNDED") + CURRENT_ROW = frameExtentKeyword{"CURRENT ROW"} +) + +type frameExtentKeyword struct { + keywordClause +} + +func (f frameExtentKeyword) isFrameExtent() {} + +// ----------------------------------------------- + +// WindowName is used to specify window reference from WINDOW clause +func WindowName(name string) Window { + newWindow := &windowName{name: name} + newWindow.parent = newWindow + return newWindow +} + +type windowName struct { + windowImpl + name string +} + +func (w windowName) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteByte('(') + + out.WriteString(w.name) + w.windowImpl.serialize(statement, out, noWrap) + + out.WriteByte(')') +} diff --git a/internal/jet/window_func_test.go b/internal/jet/window_func_test.go new file mode 100644 index 00000000..74ae9e97 --- /dev/null +++ b/internal/jet/window_func_test.go @@ -0,0 +1,21 @@ +package jet + +import "testing" + +func TestFrameExtent(t *testing.T) { + assertClauseSerialize(t, PRECEDING(Int(2)), "$1 PRECEDING", int64(2)) + assertClauseSerialize(t, FOLLOWING(Int(4)), "$1 FOLLOWING", int64(4)) +} + +func TestWindowFunctions(t *testing.T) { + assertClauseSerialize(t, PARTITION_BY(table1Col1), "(PARTITION BY table1.col1)") + assertClauseSerialize(t, PARTITION_BY(table1Col3).ORDER_BY(table1Col1), "(PARTITION BY table1.col3 ORDER BY table1.col1)") + assertClauseSerialize(t, ORDER_BY(table1Col1), "(ORDER BY table1.col1)") + assertClauseSerialize(t, ORDER_BY(table1Col1).ROWS(PRECEDING(Int(1))), "(ORDER BY table1.col1 ROWS $1 PRECEDING)", int64(1)) + assertClauseSerialize(t, ORDER_BY(table1Col1).ROWS(PRECEDING(Int(1)), FOLLOWING(Int(33))), + "(ORDER BY table1.col1 ROWS BETWEEN $1 PRECEDING AND $2 FOLLOWING)", int64(1), int64(33)) + assertClauseSerialize(t, ORDER_BY(table1Col1).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)), + "(ORDER BY table1.col1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)") + assertClauseSerialize(t, ORDER_BY(table1Col1).RANGE(PRECEDING(UNBOUNDED), CURRENT_ROW), + "(ORDER BY table1.col1 RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)") +} diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index d5dbbd66..714442c3 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/go-jet/jet/execution" "github.com/go-jet/jet/internal/jet" + "github.com/go-jet/jet/internal/utils" "gotest.tools/assert" "io/ioutil" "os" @@ -60,9 +61,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) { filePath := getFullPath(testRelativePath) err := ioutil.WriteFile(filePath, jsonText, 0644) - if err != nil { - panic(err) - } + utils.PanicOnError(err) } // AssertJSONFile check if data json representation is the same as json at testRelativePath @@ -159,3 +158,31 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db execution.DB, dest stmt.Query(db, dest) } + +// AssertFileContent check if file content at filePath contains expectedContent text. +func AssertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) { + enumFileData, err := ioutil.ReadFile(filePath) + + assert.NilError(t, err) + + beginIndex := bytes.Index(enumFileData, []byte(contentBegin)) + + //fmt.Println("-"+string(enumFileData[beginIndex:])+"-") + + assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) +} + +// AssertFileNamesEqual check if all filesInfos are contained in fileNames +func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) { + assert.Equal(t, len(fileInfos), len(fileNames)) + + fileNamesMap := map[string]bool{} + + for _, fileInfo := range fileInfos { + fileNamesMap[fileInfo.Name()] = true + } + + for _, fileName := range fileNames { + assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.") + } +} diff --git a/internal/testutils/time_utils.go b/internal/testutils/time_utils.go index 8348a6cf..5c628028 100644 --- a/internal/testutils/time_utils.go +++ b/internal/testutils/time_utils.go @@ -1,6 +1,7 @@ package testutils import ( + "github.com/go-jet/jet/internal/utils" "strings" "time" ) @@ -9,9 +10,7 @@ import ( func Date(t string) *time.Time { newTime, err := time.Parse("2006-01-02", t) - if err != nil { - panic(err) - } + utils.PanicOnError(err) return &newTime } @@ -27,9 +26,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time { newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000") - if err != nil { - panic(err) - } + utils.PanicOnError(err) return &newTime } @@ -38,9 +35,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time { func TimeWithoutTimeZone(t string) *time.Time { newTime, err := time.Parse("15:04:05", t) - if err != nil { - panic(err) - } + utils.PanicOnError(err) return &newTime } @@ -49,9 +44,7 @@ func TimeWithoutTimeZone(t string) *time.Time { func TimeWithTimeZone(t string) *time.Time { newTimez, err := time.Parse("15:04:05 -0700", t) - if err != nil { - panic(err) - } + utils.PanicOnError(err) return &newTimez } @@ -67,9 +60,7 @@ func TimestampWithTimeZone(t string, precision int) *time.Time { newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t) - if err != nil { - panic(err) - } + utils.PanicOnError(err) return &newTime } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 9694cbb2..a091973e 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -2,14 +2,13 @@ package utils import ( "database/sql" + "fmt" "github.com/go-jet/jet/internal/3rdparty/snaker" "go/format" "os" "path/filepath" "reflect" - "strconv" "strings" - "time" ) // ToGoIdentifier converts database to Go identifier. @@ -104,44 +103,11 @@ func DirExists(path string) (bool, error) { func replaceInvalidChars(str string) string { str = strings.Replace(str, " ", "_", -1) str = strings.Replace(str, "-", "_", -1) + str = strings.Replace(str, ".", "_", -1) return str } -// FormatTimestamp formats t into Postgres' text format for timestamps. From: github.com/lib/pq -func FormatTimestamp(t time.Time) []byte { - // Need to send dates before 0001 A.D. with " BC" suffix, instead of the - // minus sign preferred by Go. - // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on - bc := false - if t.Year() <= 0 { - // flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11" - t = t.AddDate((-t.Year())*2+1, 0, 0) - bc = true - } - b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00")) - - _, offset := t.Zone() - offset = offset % 60 - if offset != 0 { - // RFC3339Nano already printed the minus sign - if offset < 0 { - offset = -offset - } - - b = append(b, ':') - if offset < 10 { - b = append(b, '0') - } - b = strconv.AppendInt(b, int64(offset), 10) - } - - if bc { - b = append(b, " BC"...) - } - return b -} - // IsNil check if v is nil func IsNil(v interface{}) bool { return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) @@ -174,3 +140,27 @@ func MustBeInitializedPtr(val interface{}, errorStr string) { panic(errorStr) } } + +// PanicOnError panics if err is not nil +func PanicOnError(err error) { + if err != nil { + panic(err) + } +} + +// ErrorCatch is used in defer to recover from panics and to set err +func ErrorCatch(err *error) { + recovered := recover() + + if recovered == nil { + return + } + + recoveredErr, isError := recovered.(error) + + if isError { + *err = recoveredErr + } else { + *err = fmt.Errorf("%v", recovered) + } +} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index a41a2250..8ee2d497 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -1,6 +1,7 @@ package utils import ( + "fmt" "gotest.tools/assert" "testing" ) @@ -23,3 +24,27 @@ func TestToGoIdentifier(t *testing.T) { assert.Equal(t, ToGoIdentifier("My Table"), "MyTable") assert.Equal(t, ToGoIdentifier("My-Table"), "MyTable") } + +func TestErrorCatchErr(t *testing.T) { + var err error + + func() { + defer ErrorCatch(&err) + + panic(fmt.Errorf("newError")) + }() + + assert.Error(t, err, "newError") +} + +func TestErrorCatchNonErr(t *testing.T) { + var err error + + func() { + defer ErrorCatch(&err) + + panic(11) + }() + + assert.Error(t, err, "11") +} diff --git a/mysql/dialect.go b/mysql/dialect.go index d93081cd..45509a7a 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -10,13 +10,13 @@ var Dialect = newDialect() func newDialect() jet.Dialect { operatorSerializeOverrides := map[string]jet.SerializeOverride{} - operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysql_REGEXP_LIKE_operator - operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysql_NOT_REGEXP_LIKE_operator - operatorSerializeOverrides["IS DISTINCT FROM"] = mysql_IS_DISTINCT_FROM - operatorSerializeOverrides["IS NOT DISTINCT FROM"] = mysql_IS_NOT_DISTINCT_FROM - operatorSerializeOverrides["/"] = mysql_DIVISION - operatorSerializeOverrides["#"] = mysql_BIT_XOR - operatorSerializeOverrides[jet.StringConcatOperator] = mysql_CONCAT_operator + operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysqlREGEXPLIKEoperator + operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysqlNOTREGEXPLIKEoperator + operatorSerializeOverrides["IS DISTINCT FROM"] = mysqlISDISTINCTFROM + operatorSerializeOverrides["IS NOT DISTINCT FROM"] = mysqlISNOTDISTINCTFROM + operatorSerializeOverrides["/"] = mysqlDivision + operatorSerializeOverrides["#"] = mysqlBitXor + operatorSerializeOverrides[jet.StringConcatOperator] = mysqlCONCAToperator mySQLDialectParams := jet.DialectParams{ Name: "MySQL", @@ -32,7 +32,7 @@ func newDialect() jet.Dialect { return jet.NewDialect(mySQLDialectParams) } -func mysql_BIT_XOR(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlBitXor(expressions ...jet.Expression) jet.SerializeFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator XOR") @@ -49,7 +49,7 @@ func mysql_BIT_XOR(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysql_CONCAT_operator(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlCONCAToperator(expressions ...jet.Expression) jet.SerializeFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator CONCAT") @@ -66,7 +66,7 @@ func mysql_CONCAT_operator(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysql_DIVISION(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlDivision(expressions ...jet.Expression) jet.SerializeFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator DIV") @@ -90,7 +90,7 @@ func mysql_DIVISION(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysql_IS_NOT_DISTINCT_FROM(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlISNOTDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") @@ -102,15 +102,15 @@ func mysql_IS_NOT_DISTINCT_FROM(expressions ...jet.Expression) jet.SerializeFunc } } -func mysql_IS_DISTINCT_FROM(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlISDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { out.WriteString("NOT(") - mysql_IS_NOT_DISTINCT_FROM(expressions...)(statement, out, options...) + mysqlISNOTDISTINCTFROM(expressions...)(statement, out, options...) out.WriteString(")") } } -func mysql_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") @@ -136,7 +136,7 @@ func mysql_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc } } -func mysql_NOT_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") diff --git a/mysql/functions.go b/mysql/functions.go index 2c911eae..0064d9d3 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -85,6 +85,47 @@ var SUMi = jet.SUMi // SUMf is aggregate function. Returns sum of float expression. var SUMf = jet.SUMf +// -------------------- Window functions -----------------------// + +// ROW_NUMBER returns number of the current row within its partition, counting from 1 +var ROW_NUMBER = jet.ROW_NUMBER + +// RANK of the current row with gaps; same as row_number of its first peer +var RANK = jet.RANK + +// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups +var DENSE_RANK = jet.DENSE_RANK + +// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1) +var PERCENT_RANK = jet.PERCENT_RANK + +// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows +var CUME_DIST = jet.CUME_DIST + +// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible +var NTILE = jet.NTILE + +// LAG returns value evaluated at the row that is offset rows before the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +var LAG = jet.LAG + +// LEAD returns value evaluated at the row that is offset rows after the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +var LEAD = jet.LEAD + +// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame +var FIRST_VALUE = jet.FIRST_VALUE + +// LAST_VALUE returns value evaluated at the row that is the last row of the window frame +var LAST_VALUE = jet.LAST_VALUE + +// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row +var NTH_VALUE = jet.NTH_VALUE + //--------------------- String functions ------------------// // BIT_LENGTH returns number of bits in string expression @@ -181,7 +222,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampExpression { // NOW returns current datetime func NOW(fsp ...int) DateTimeExpression { if len(fsp) > 0 { - return jet.NewTimestampFunc("NOW", jet.ConstLiteral(int64(fsp[0]))) + return jet.NewTimestampFunc("NOW", jet.FixedLiteral(int64(fsp[0]))) } return jet.NewTimestampFunc("NOW") } diff --git a/mysql/select_statement.go b/mysql/select_statement.go index 5622f3e8..33478881 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -1,6 +1,8 @@ package mysql -import "github.com/go-jet/jet/internal/jet" +import ( + "github.com/go-jet/jet/internal/jet" +) // RowLock is interface for SELECT statement row lock types type RowLock = jet.RowLock @@ -11,6 +13,27 @@ var ( SHARE = jet.NewRowLock("SHARE") ) +// Window function clauses +var ( + PARTITION_BY = jet.PARTITION_BY + ORDER_BY = jet.ORDER_BY + UNBOUNDED = jet.UNBOUNDED + CURRENT_ROW = jet.CURRENT_ROW +) + +// PRECEDING window frame clause +func PRECEDING(offset interface{}) jet.FrameExtent { + return jet.PRECEDING(toJetFrameOffset(offset)) +} + +// FOLLOWING window frame clause +func FOLLOWING(offset interface{}) jet.FrameExtent { + return jet.FOLLOWING(toJetFrameOffset(offset)) +} + +// Window is used to specify window reference from WINDOW clause +var Window = jet.WindowName + // SelectStatement is interface for MySQL SELECT statement type SelectStatement interface { Statement @@ -22,6 +45,7 @@ type SelectStatement interface { WHERE(expression BoolExpression) SelectStatement GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement + WINDOW(name string) windowExpand ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement LIMIT(limit int64) SelectStatement OFFSET(offset int64) SelectStatement @@ -42,7 +66,7 @@ func SELECT(projection Projection, projections ...Projection) SelectStatement { func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { newSelect := &selectStatementImpl{} newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, - &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy, + &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy, &newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock) newSelect.Select.Projections = toJetProjectionList(projections) @@ -66,6 +90,7 @@ type selectStatementImpl struct { Where jet.ClauseWhere GroupBy jet.ClauseGroupBy Having jet.ClauseHaving + Window jet.ClauseWindow OrderBy jet.ClauseOrderBy Limit jet.ClauseLimit Offset jet.ClauseOffset @@ -98,6 +123,11 @@ func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatem return s } +func (s *selectStatementImpl) WINDOW(name string) windowExpand { + s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name}) + return windowExpand{selectStatement: s} +} + func (s *selectStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement { s.OrderBy.List = orderByClauses return s @@ -126,3 +156,31 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement { func (s *selectStatementImpl) AsTable(alias string) SelectTable { return newSelectTable(s, alias) } + +//----------------------------------------------------- + +type windowExpand struct { + selectStatement *selectStatementImpl +} + +func (w windowExpand) AS(window ...jet.Window) SelectStatement { + if len(window) == 0 { + return w.selectStatement + } + windowsDefinition := w.selectStatement.Window.Definitions + windowsDefinition[len(windowsDefinition)-1].Window = window[0] + return w.selectStatement +} + +func toJetFrameOffset(offset interface{}) jet.Serializer { + if offset == UNBOUNDED { + return jet.UNBOUNDED + } + + // check for interval expression + //if exp, ok := offset.(Expression); ok { + // return exp + //} + + return jet.FixedLiteral(offset) +} diff --git a/mysql/table.go b/mysql/table.go index a4cf0421..6d414a23 100644 --- a/mysql/table.go +++ b/mysql/table.go @@ -77,9 +77,9 @@ func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectU } // NewTable creates new table with schema Name, table Name and list of columns -func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table { +func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table { t := &tableImpl{ - SerializerTable: jet.NewTable(schemaName, name, columns...), + SerializerTable: jet.NewTable(schemaName, name, column, columns...), } t.readableTableInterfaceImpl.parent = t diff --git a/postgres/dialect.go b/postgres/dialect.go index 0fbfd9ea..114e5a64 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -11,8 +11,8 @@ var Dialect = newDialect() func newDialect() jet.Dialect { operatorSerializeOverrides := map[string]jet.SerializeOverride{} - operatorSerializeOverrides[jet.StringRegexpLikeOperator] = postgres_REGEXP_LIKE_operator - operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = postgres_NOT_REGEXP_LIKE_operator + operatorSerializeOverrides[jet.StringRegexpLikeOperator] = postgresREGEXPLIKEoperator + operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = postgresNOTREGEXPLIKEoperator operatorSerializeOverrides["CAST"] = postgresCAST dialectParams := jet.DialectParams{ @@ -54,7 +54,7 @@ func postgresCAST(expressions ...jet.Expression) jet.SerializeFunc { } } -func postgres_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc { +func postgresREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") @@ -80,7 +80,7 @@ func postgres_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeF } } -func postgres_NOT_REGEXP_LIKE_operator(expressions ...jet.Expression) jet.SerializeFunc { +func postgresNOTREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") diff --git a/postgres/functions.go b/postgres/functions.go index 18a637f6..4f657ca7 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -87,6 +87,47 @@ var SUMf = jet.SUMf // SUMi is aggregate function. Returns sum of expression across all integer expression. var SUMi = jet.SUMi +// -------------------- Window functions -----------------------// + +// ROW_NUMBER returns number of the current row within its partition, counting from 1 +var ROW_NUMBER = jet.ROW_NUMBER + +// RANK of the current row with gaps; same as row_number of its first peer +var RANK = jet.RANK + +// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups +var DENSE_RANK = jet.DENSE_RANK + +// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1) +var PERCENT_RANK = jet.PERCENT_RANK + +// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows +var CUME_DIST = jet.CUME_DIST + +// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible +var NTILE = jet.NTILE + +// LAG returns value evaluated at the row that is offset rows before the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +var LAG = jet.LAG + +// LEAD returns value evaluated at the row that is offset rows after the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +var LEAD = jet.LEAD + +// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame +var FIRST_VALUE = jet.FIRST_VALUE + +// LAST_VALUE returns value evaluated at the row that is the last row of the window frame +var LAST_VALUE = jet.LAST_VALUE + +// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row +var NTH_VALUE = jet.NTH_VALUE + //--------------------- String functions ------------------// // BIT_LENGTH returns number of bits in string expression diff --git a/postgres/select_statement.go b/postgres/select_statement.go index 34225e6a..e4aeb371 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -1,6 +1,9 @@ package postgres -import "github.com/go-jet/jet/internal/jet" +import ( + "github.com/go-jet/jet/internal/jet" + "math" +) // RowLock is interface for SELECT statement row lock types type RowLock = jet.RowLock @@ -13,6 +16,27 @@ var ( KEY_SHARE = jet.NewRowLock("KEY SHARE") ) +// Window function clauses +var ( + PARTITION_BY = jet.PARTITION_BY + ORDER_BY = jet.ORDER_BY + UNBOUNDED = int64(math.MaxInt64) + CURRENT_ROW = jet.CURRENT_ROW +) + +// PRECEDING window frame clause +func PRECEDING(offset int64) jet.FrameExtent { + return jet.PRECEDING(toJetFrameOffset(offset)) +} + +// FOLLOWING window frame clause +func FOLLOWING(offset int64) jet.FrameExtent { + return jet.FOLLOWING(toJetFrameOffset(offset)) +} + +// Window definition reference +var Window = jet.WindowName + // SelectStatement is interface for PostgreSQL SELECT statement type SelectStatement interface { Statement @@ -24,6 +48,7 @@ type SelectStatement interface { WHERE(expression BoolExpression) SelectStatement GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement + WINDOW(name string) windowExpand ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement LIMIT(limit int64) SelectStatement OFFSET(offset int64) SelectStatement @@ -47,15 +72,9 @@ func SELECT(projection Projection, projections ...Projection) SelectStatement { func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { newSelect := &selectStatementImpl{} newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, - &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy, + &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy, &newSelect.Limit, &newSelect.Offset, &newSelect.For) - // statementImpl = jet.NewStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, - // &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy, - // &newSelect.Limit, &newSelect.Offset, &newSelect.For) - // - //newSelect.expressionStatementImpl.expressionInterfaceImpl.Parent = newSelect - newSelect.Select.Projections = toJetProjectionList(projections) newSelect.From.Table = table newSelect.Limit.Count = -1 @@ -75,6 +94,7 @@ type selectStatementImpl struct { Where jet.ClauseWhere GroupBy jet.ClauseGroupBy Having jet.ClauseHaving + Window jet.ClauseWindow OrderBy jet.ClauseOrderBy Limit jet.ClauseLimit Offset jet.ClauseOffset @@ -106,6 +126,11 @@ func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatem return s } +func (s *selectStatementImpl) WINDOW(name string) windowExpand { + s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name}) + return windowExpand{selectStatement: s} +} + func (s *selectStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement { s.OrderBy.List = orderByClauses return s @@ -129,3 +154,25 @@ func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement { func (s *selectStatementImpl) AsTable(alias string) SelectTable { return newSelectTable(s, alias) } + +//----------------------------------------------------- + +type windowExpand struct { + selectStatement *selectStatementImpl +} + +func (w windowExpand) AS(window ...jet.Window) SelectStatement { + if len(window) == 0 { + return w.selectStatement + } + windowsDefinition := w.selectStatement.Window.Definitions + windowsDefinition[len(windowsDefinition)-1].Window = window[0] + return w.selectStatement +} + +func toJetFrameOffset(offset int64) jet.Serializer { + if offset == UNBOUNDED { + return jet.UNBOUNDED + } + return jet.FixedLiteral(offset) +} diff --git a/postgres/table.go b/postgres/table.go index bc2f5c23..dc0b266d 100644 --- a/postgres/table.go +++ b/postgres/table.go @@ -109,10 +109,10 @@ type tableImpl struct { } // NewTable creates new table with schema Name, table Name and list of columns -func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table { +func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table { t := &tableImpl{ - SerializerTable: jet.NewTable(schemaName, name, columns...), + SerializerTable: jet.NewTable(schemaName, name, column, columns...), } t.readableTableInterfaceImpl.parent = t diff --git a/tests/init/init.go b/tests/init/init.go index a054e314..22698b4c 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/go-jet/jet/generator/mysql" "github.com/go-jet/jet/generator/postgres" + "github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/tests/dbconfig" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" @@ -60,7 +61,7 @@ func initMySQLDB() { cmd.Stdout = os.Stdout err := cmd.Run() - panicOnError(err) + utils.PanicOnError(err) err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{ Host: dbconfig.MySqLHost, @@ -70,7 +71,7 @@ func initMySQLDB() { DBName: dbName, }) - panicOnError(err) + utils.PanicOnError(err) } } @@ -104,22 +105,16 @@ func initPostgresDB() { SchemaName: schemaName, SslMode: "disable", }) - panicOnError(err) + utils.PanicOnError(err) } } func execFile(db *sql.DB, sqlFilePath string) { testSampleSql, err := ioutil.ReadFile(sqlFilePath) - panicOnError(err) + utils.PanicOnError(err) _, err = db.Exec(string(testSampleSql)) - panicOnError(err) -} - -func panicOnError(err error) { - if err != nil { - panic(err) - } + utils.PanicOnError(err) } func printOnError(err error) { diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 1c93609b..952ea90d 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -5,6 +5,7 @@ import ( "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" + "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/view" "github.com/go-jet/jet/tests/testdata/results/common" "github.com/google/uuid" "time" @@ -36,6 +37,23 @@ func TestAllTypes(t *testing.T) { testutils.AssertJSON(t, dest, allTypesJson) } +func TestAllTypesViewSelect(t *testing.T) { + + type AllTypesView model.AllTypes + + dest := []AllTypesView{} + + err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) + assert.NilError(t, err) + assert.Equal(t, len(dest), 2) + + if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert + return + } + + testutils.AssertJSON(t, dest, allTypesJson) +} + func TestUUID(t *testing.T) { query := AllTypes. diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index a1912141..e3dab153 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -1,8 +1,8 @@ package mysql import ( - "bytes" "github.com/go-jet/jet/generator/mysql" + "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/dbconfig" "gotest.tools/assert" "io/ioutil" @@ -15,22 +15,22 @@ const genTestDirRoot = "./.gentestdata3" const genTestDir3 = "./.gentestdata3/mysql" func TestGenerator(t *testing.T) { - err := os.RemoveAll(genTestDir3) - assert.NilError(t, err) - err = mysql.Generate(genTestDir3, mysql.DBConnection{ - Host: dbconfig.MySqLHost, - Port: dbconfig.MySQLPort, - User: dbconfig.MySQLUser, - Password: dbconfig.MySQLPassword, - DBName: "dvds", - }) + for i := 0; i < 3; i++ { + err := mysql.Generate(genTestDir3, mysql.DBConnection{ + Host: dbconfig.MySqLHost, + Port: dbconfig.MySQLPort, + User: dbconfig.MySQLUser, + Password: dbconfig.MySQLPassword, + DBName: "dvds", + }) - assert.NilError(t, err) + assert.NilError(t, err) - assertGeneratedFiles(t) + assertGeneratedFiles(t) + } - err = os.RemoveAll(genTestDirRoot) + err := os.RemoveAll(genTestDirRoot) assert.NilError(t, err) } @@ -63,53 +63,40 @@ func assertGeneratedFiles(t *testing.T) { tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table") assert.NilError(t, err) - assertFileNameEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", - "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", + testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", + "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", "payment.go", "rental.go", "staff.go", "store.go") - assertFileContent(t, genTestDir3+"/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile) + testutils.AssertFileContent(t, genTestDir3+"/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile) + + // View SQL Builder files + viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view") + assert.NilError(t, err) + + testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", + "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") + + testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilerFile) // Enums SQL Builder files enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum") assert.NilError(t, err) - assertFileNameEqual(t, enumFiles, "film_rating.go") - assertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile) + testutils.AssertFileNamesEqual(t, enumFiles, "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go") + testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile) // Model files modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model") assert.NilError(t, err) - assertFileNameEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", - "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", - "payment.go", "rental.go", "staff.go", "store.go", "film_rating.go") + testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", + "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", + "payment.go", "rental.go", "staff.go", "store.go", + "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go", + "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", + "customer_list.go", "sales_by_store.go", "staff_list.go") - assertFileContent(t, genTestDir3+"/dvds/model/actor.go", "\npackage model", actorModelFile) -} - -func assertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) { - enumFileData, err := ioutil.ReadFile(filePath) - - assert.NilError(t, err) - - beginIndex := bytes.Index(enumFileData, []byte(contentBegin)) - - //fmt.Println("-"+string(enumFileData[beginIndex:])+"-") - - assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) -} - -func assertFileNameEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) { - - fileNamesMap := map[string]bool{} - - for _, fileInfo := range fileInfos { - fileNamesMap[fileInfo.Name()] = true - } - - for _, fileName := range fileNames { - assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.") - } + testutils.AssertFileContent(t, genTestDir3+"/dvds/model/actor.go", "\npackage model", actorModelFile) } var mpaaRatingEnumFile = ` @@ -200,3 +187,57 @@ type Actor struct { LastUpdate time.Time } ` + +var actorInfoSQLBuilerFile = ` +package view + +import ( + "github.com/go-jet/jet/mysql" +) + +var ActorInfo = newActorInfoTable() + +type ActorInfoTable struct { + mysql.Table + + //Columns + ActorID mysql.ColumnInteger + FirstName mysql.ColumnString + LastName mysql.ColumnString + FilmInfo mysql.ColumnString + + AllColumns mysql.IColumnList + MutableColumns mysql.IColumnList +} + +// creates new ActorInfoTable with assigned alias +func (a *ActorInfoTable) AS(alias string) *ActorInfoTable { + aliasTable := newActorInfoTable() + + aliasTable.Table.AS(alias) + + return aliasTable +} + +func newActorInfoTable() *ActorInfoTable { + var ( + ActorIDColumn = mysql.IntegerColumn("actor_id") + FirstNameColumn = mysql.StringColumn("first_name") + LastNameColumn = mysql.StringColumn("last_name") + FilmInfoColumn = mysql.StringColumn("film_info") + ) + + return &ActorInfoTable{ + Table: mysql.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + + //Columns + ActorID: ActorIDColumn, + FirstName: FirstNameColumn, + LastName: LastNameColumn, + FilmInfo: FilmInfoColumn, + + AllColumns: mysql.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + MutableColumns: mysql.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + } +} +` diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 1e9b73b2..0e20ae2e 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -1,11 +1,13 @@ package mysql import ( + "fmt" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/mysql" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/enum" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" + "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/view" "gotest.tools/assert" "testing" @@ -15,7 +17,7 @@ func TestSelect_ScanToStruct(t *testing.T) { query := Actor. SELECT(Actor.AllColumns). DISTINCT(). - WHERE(Actor.ActorID.EQ(Int(1))) + WHERE(Actor.ActorID.EQ(Int(2))) testutils.AssertStatementSql(t, query, ` SELECT DISTINCT actor.actor_id AS "actor.actor_id", @@ -24,20 +26,20 @@ SELECT DISTINCT actor.actor_id AS "actor.actor_id", actor.last_update AS "actor.last_update" FROM dvds.actor WHERE actor.actor_id = ?; -`, int64(1)) +`, int64(2)) actor := model.Actor{} err := query.Query(db, &actor) assert.NilError(t, err) - assert.DeepEqual(t, actor, actor1) + assert.DeepEqual(t, actor, actor2) } -var actor1 = model.Actor{ - ActorID: 1, - FirstName: "PENELOPE", - LastName: "GUINESS", +var actor2 = model.Actor{ + ActorID: 2, + FirstName: "NICK", + LastName: "WAHLBERG", LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 04:34:33", 2), } @@ -61,7 +63,7 @@ ORDER BY actor.actor_id; assert.NilError(t, err) assert.Equal(t, len(dest), 200) - assert.DeepEqual(t, dest[0], actor1) + assert.DeepEqual(t, dest[1], actor2) //testutils.PrintJson(dest) //testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json") @@ -527,3 +529,172 @@ LOCK IN SHARE MODE; err := query.Query(db, &struct{}{}) assert.NilError(t, err) } + +func TestWindowFunction(t *testing.T) { + + if sourceIsMariaDB() { + return + } + + var expectedSQL = ` +SELECT AVG(payment.amount) OVER (), + AVG(payment.amount) OVER (PARTITION BY payment.customer_id), + MAX(payment.amount) OVER (ORDER BY payment.payment_date DESC), + MIN(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC ROWS BETWEEN 1 PRECEDING AND 6 FOLLOWING), + SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + MAX(payment.customer_id) OVER (ORDER BY payment.payment_date DESC ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING), + MIN(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + SUM(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + ROW_NUMBER() OVER (ORDER BY payment.payment_date), + RANK() OVER (ORDER BY payment.payment_date), + DENSE_RANK() OVER (ORDER BY payment.payment_date), + CUME_DIST() OVER (ORDER BY payment.payment_date), + NTILE(11) OVER (ORDER BY payment.payment_date), + LAG(payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date), + FIRST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date), + LAST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date), + NTH_VALUE(payment.amount, 3) OVER (ORDER BY payment.payment_date) +FROM dvds.payment +WHERE payment.payment_id < ? +GROUP BY payment.amount, payment.customer_id, payment.payment_date; +` + query := Payment. + SELECT( + AVG(Payment.Amount).OVER(), + AVG(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID)), + MAXf(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate.DESC())), + MINf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID). + ORDER_BY(Payment.PaymentDate.DESC()).ROWS(PRECEDING(1), FOLLOWING(6))), + SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID). + ORDER_BY(Payment.PaymentDate.DESC()).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), + MAXi(Payment.CustomerID).OVER(ORDER_BY(Payment.PaymentDate.DESC()).ROWS(CURRENT_ROW, FOLLOWING(UNBOUNDED))), + MINi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + SUMi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + ROW_NUMBER().OVER(ORDER_BY(Payment.PaymentDate)), + RANK().OVER(ORDER_BY(Payment.PaymentDate)), + DENSE_RANK().OVER(ORDER_BY(Payment.PaymentDate)), + CUME_DIST().OVER(ORDER_BY(Payment.PaymentDate)), + NTILE(11).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)), + FIRST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + NTH_VALUE(Payment.Amount, 3).OVER(ORDER_BY(Payment.PaymentDate)), + ).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate). + WHERE(Payment.PaymentID.LT(Int(10))) + + fmt.Println(query.Sql()) + + testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) +} + +func TestWindowClause(t *testing.T) { + var expectedSQL = ` +SELECT AVG(payment.amount) OVER (), + AVG(payment.amount) OVER (w1), + AVG(payment.amount) OVER (w2 ORDER BY payment.customer_id RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + AVG(payment.amount) OVER (w3 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +FROM dvds.payment +WHERE payment.payment_id < ? +WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id) +ORDER BY payment.customer_id; +` + query := Payment.SELECT( + AVG(Payment.Amount).OVER(), + AVG(Payment.Amount).OVER(Window("w1")), + AVG(Payment.Amount).OVER( + Window("w2"). + ORDER_BY(Payment.CustomerID). + RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)), + ), + AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), + ). + WHERE(Payment.PaymentID.LT(Int(10))). + WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)). + WINDOW("w2").AS(Window("w1")). + WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)). + ORDER_BY(Payment.CustomerID) + + fmt.Println(query.Sql()) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) + + err := query.Query(db, &struct{}{}) + + assert.NilError(t, err) +} + +func TestSimpleView(t *testing.T) { + query := SELECT( + view.ActorInfo.AllColumns, + ). + FROM(view.ActorInfo). + ORDER_BY(view.ActorInfo.ActorID). + LIMIT(10) + + type ActorInfo struct { + ActorID int + FirstName string + LastName string + FilmInfo string + } + + var dest []ActorInfo + + err := query.Query(db, &dest) + assert.NilError(t, err) + + assert.Equal(t, len(dest), 10) + testutils.AssertJSON(t, dest[1:2], ` +[ + { + "ActorID": 2, + "FirstName": "NICK", + "LastName": "WAHLBERG", + "FilmInfo": "Action: BULL SHAWSHANK; Animation: FIGHT JAWBREAKER; Children: JERSEY SASSY; Classics: DRACULA CRYSTAL, GILBERT PELICAN; Comedy: MALLRATS UNITED, RUSHMORE MERMAID; Documentary: ADAPTATION HOLES; Drama: WARDROBE PHANTOM; Family: APACHE DIVINE, CHISUM BEHAVIOR, INDIAN LOVE, MAGUIRE APACHE; Foreign: BABY HALL, HAPPINESS UNITED; Games: ROOF CHAMPION; Music: LUCKY FLYING; New: DESTINY SATURDAY, FLASH WARS, JEKYLL FROGMEN, MASK PEACH; Sci-Fi: CHAINSAW UPTOWN, GOODFELLAS SALUTE; Travel: LIAISONS SWEET, SMILE EARRING" + } +] +`) +} + +func TestJoinViewWithTable(t *testing.T) { + query := SELECT( + view.CustomerList.AllColumns, + Rental.AllColumns, + ). + FROM(view.CustomerList. + INNER_JOIN(Rental, view.CustomerList.ID.EQ(Rental.CustomerID)), + ). + ORDER_BY(view.CustomerList.ID). + WHERE(view.CustomerList.ID.LT_EQ(Int(2))) + + var dest []struct { + model.CustomerList `sql:"primary_key=ID"` + Rentals []model.Rental + } + + err := query.Query(db, &dest) + assert.NilError(t, err) + + assert.Equal(t, len(dest), 2) + assert.Equal(t, len(dest[0].Rentals), 32) + assert.Equal(t, len(dest[1].Rentals), 27) +} diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 2b91c7ad..a6131732 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -6,6 +6,7 @@ import ( . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" + "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/view" "github.com/go-jet/jet/tests/testdata/results/common" "github.com/google/uuid" "gotest.tools/assert" @@ -23,6 +24,19 @@ func TestAllTypesSelect(t *testing.T) { assert.DeepEqual(t, dest[1], allTypesRow1) } +func TestAllTypesViewSelect(t *testing.T) { + + type AllTypesView model.AllTypes + + dest := []AllTypesView{} + + err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) + assert.NilError(t, err) + + assert.DeepEqual(t, dest[0], AllTypesView(allTypesRow0)) + assert.DeepEqual(t, dest[1], AllTypesView(allTypesRow1)) +} + func TestAllTypesInsertModel(t *testing.T) { query := AllTypes.INSERT(AllTypes.AllColumns). MODEL(allTypesRow0). @@ -31,8 +45,8 @@ func TestAllTypesInsertModel(t *testing.T) { dest := []model.AllTypes{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.Equal(t, len(dest), 2) assert.DeepEqual(t, dest[0], allTypesRow0) assert.DeepEqual(t, dest[1], allTypesRow1) diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index d73a850a..5eecba06 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -1,8 +1,8 @@ package postgres import ( - "bytes" "github.com/go-jet/jet/generator/postgres" + "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/dbconfig" "gotest.tools/assert" "io/ioutil" @@ -71,26 +71,25 @@ func TestCmdGenerator(t *testing.T) { func TestGenerator(t *testing.T) { - err := os.RemoveAll(genTestDir2) - assert.NilError(t, err) + for i := 0; i < 3; i++ { + err := postgres.Generate(genTestDir2, postgres.DBConnection{ + Host: dbconfig.Host, + Port: dbconfig.Port, + User: dbconfig.User, + Password: dbconfig.Password, + SslMode: "disable", + Params: "", - err = postgres.Generate(genTestDir2, postgres.DBConnection{ - Host: dbconfig.Host, - Port: dbconfig.Port, - User: dbconfig.User, - Password: dbconfig.Password, - SslMode: "disable", - Params: "", + DBName: dbconfig.DBName, + SchemaName: "dvds", + }) - DBName: dbconfig.DBName, - SchemaName: "dvds", - }) - - assert.NilError(t, err) + assert.NilError(t, err) - assertGeneratedFiles(t) + assertGeneratedFiles(t) + } - err = os.RemoveAll(genTestDir2) + err := os.RemoveAll(genTestDir2) assert.NilError(t, err) } @@ -99,53 +98,39 @@ func assertGeneratedFiles(t *testing.T) { tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table") assert.NilError(t, err) - assertFileNameEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", + testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", "payment.go", "rental.go", "staff.go", "store.go") - assertFileContent(t, "./.gentestdata2/jetdb/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile) + testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile) + + // View SQL Builder files + viewSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/view") + assert.NilError(t, err) + + testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", + "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") + + testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilderFile) // Enums SQL Builder files enumFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum") assert.NilError(t, err) - assertFileNameEqual(t, enumFiles, "mpaa_rating.go") - assertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile) + testutils.AssertFileNamesEqual(t, enumFiles, "mpaa_rating.go") + testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile) // Model files modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model") assert.NilError(t, err) - assertFileNameEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", + testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", - "payment.go", "rental.go", "staff.go", "store.go", "mpaa_rating.go") - - assertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", "\npackage model", actorModelFile) -} + "payment.go", "rental.go", "staff.go", "store.go", "mpaa_rating.go", + "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", + "customer_list.go", "sales_by_store.go", "staff_list.go") -func assertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) { - enumFileData, err := ioutil.ReadFile(filePath) - - assert.NilError(t, err) - - beginIndex := bytes.Index(enumFileData, []byte(contentBegin)) - - //fmt.Println("-"+string(enumFileData[beginIndex:])+"-") - - assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) -} - -func assertFileNameEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) { - - fileNamesMap := map[string]bool{} - - for _, fileInfo := range fileInfos { - fileNamesMap[fileInfo.Name()] = true - } - - for _, fileName := range fileNames { - assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.") - } + testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", "\npackage model", actorModelFile) } var mpaaRatingEnumFile = ` @@ -236,3 +221,57 @@ type Actor struct { LastUpdate time.Time } ` + +var actorInfoSQLBuilderFile = ` +package view + +import ( + "github.com/go-jet/jet/postgres" +) + +var ActorInfo = newActorInfoTable() + +type ActorInfoTable struct { + postgres.Table + + //Columns + ActorID postgres.ColumnInteger + FirstName postgres.ColumnString + LastName postgres.ColumnString + FilmInfo postgres.ColumnString + + AllColumns postgres.IColumnList + MutableColumns postgres.IColumnList +} + +// creates new ActorInfoTable with assigned alias +func (a *ActorInfoTable) AS(alias string) *ActorInfoTable { + aliasTable := newActorInfoTable() + + aliasTable.Table.AS(alias) + + return aliasTable +} + +func newActorInfoTable() *ActorInfoTable { + var ( + ActorIDColumn = postgres.IntegerColumn("actor_id") + FirstNameColumn = postgres.StringColumn("first_name") + LastNameColumn = postgres.StringColumn("last_name") + FilmInfoColumn = postgres.StringColumn("film_info") + ) + + return &ActorInfoTable{ + Table: postgres.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + + //Columns + ActorID: ActorIDColumn, + FirstName: FirstNameColumn, + LastName: LastNameColumn, + FilmInfo: FilmInfoColumn, + + AllColumns: postgres.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + MutableColumns: postgres.ColumnList(ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + } +} +` diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 03e56728..15ec4fd1 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -7,6 +7,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/enum" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table" + "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/view" "gotest.tools/assert" "testing" "time" @@ -19,15 +20,15 @@ SELECT DISTINCT actor.actor_id AS "actor.actor_id", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor -WHERE actor.actor_id = 1; +WHERE actor.actor_id = 2; ` query := Actor. SELECT(Actor.AllColumns). DISTINCT(). - WHERE(Actor.ActorID.EQ(Int(1))) + WHERE(Actor.ActorID.EQ(Int(2))) - testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1)) + testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(2)) actor := model.Actor{} err := query.Query(db, &actor) @@ -35,9 +36,9 @@ WHERE actor.actor_id = 1; assert.NilError(t, err) expectedActor := model.Actor{ - ActorID: 1, - FirstName: "Penelope", - LastName: "Guiness", + ActorID: 2, + FirstName: "Nick", + LastName: "Wahlberg", LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2), } @@ -1615,3 +1616,169 @@ SELECT true, err := query.Query(db, &struct{}{}) assert.NilError(t, err) } + +func TestWindowFunction(t *testing.T) { + var expectedSQL = ` +SELECT AVG(payment.amount) OVER (), + AVG(payment.amount) OVER (PARTITION BY payment.customer_id), + MAX(payment.amount) OVER (ORDER BY payment.payment_date DESC), + MIN(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC ROWS BETWEEN 1 PRECEDING AND 6 FOLLOWING), + SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + MAX(payment.customer_id) OVER (ORDER BY payment.payment_date DESC ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING), + MIN(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + SUM(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + ROW_NUMBER() OVER (ORDER BY payment.payment_date), + RANK() OVER (ORDER BY payment.payment_date), + DENSE_RANK() OVER (ORDER BY payment.payment_date), + CUME_DIST() OVER (ORDER BY payment.payment_date), + NTILE(11) OVER (ORDER BY payment.payment_date), + LAG(payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount, 2, $1) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount, 2, $2) OVER (ORDER BY payment.payment_date), + FIRST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date), + LAST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date), + NTH_VALUE(payment.amount, 3) OVER (ORDER BY payment.payment_date) +FROM dvds.payment +WHERE payment.payment_id < $3 +GROUP BY payment.amount, payment.customer_id, payment.payment_date; +` + query := Payment. + SELECT( + AVG(Payment.Amount).OVER(), + AVG(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID)), + MAXf(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate.DESC())), + MINf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID). + ORDER_BY(Payment.PaymentDate.DESC()).ROWS(PRECEDING(1), FOLLOWING(6))), + SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID). + ORDER_BY(Payment.PaymentDate.DESC()).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), + MAXi(Payment.CustomerID).OVER(ORDER_BY(Payment.PaymentDate.DESC()).ROWS(CURRENT_ROW, FOLLOWING(UNBOUNDED))), + MINi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + SUMi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + ROW_NUMBER().OVER(ORDER_BY(Payment.PaymentDate)), + RANK().OVER(ORDER_BY(Payment.PaymentDate)), + DENSE_RANK().OVER(ORDER_BY(Payment.PaymentDate)), + CUME_DIST().OVER(ORDER_BY(Payment.PaymentDate)), + NTILE(11).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)), + FIRST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + NTH_VALUE(Payment.Amount, 3).OVER(ORDER_BY(Payment.PaymentDate)), + ).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate). + WHERE(Payment.PaymentID.LT(Int(10))) + + fmt.Println(query.Sql()) + + testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) +} + +func TestWindowClause(t *testing.T) { + var expectedSQL = ` +SELECT AVG(payment.amount) OVER (), + AVG(payment.amount) OVER (w1), + AVG(payment.amount) OVER (w2 ORDER BY payment.customer_id RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + AVG(payment.amount) OVER (w3 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +FROM dvds.payment +WHERE payment.payment_id < $1 +WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id) +ORDER BY payment.customer_id; +` + query := Payment.SELECT( + AVG(Payment.Amount).OVER(), + AVG(Payment.Amount).OVER(Window("w1")), + AVG(Payment.Amount).OVER( + Window("w2"). + ORDER_BY(Payment.CustomerID). + RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)), + ), + AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), + ). + WHERE(Payment.PaymentID.LT(Int(10))). + WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)). + WINDOW("w2").AS(Window("w1")). + WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)). + ORDER_BY(Payment.CustomerID) + + fmt.Println(query.Sql()) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) + + err := query.Query(db, &struct{}{}) + + assert.NilError(t, err) +} + +func TestSimpleView(t *testing.T) { + query := SELECT( + view.ActorInfo.AllColumns, + ). + FROM(view.ActorInfo). + ORDER_BY(view.ActorInfo.ActorID). + LIMIT(10) + + type ActorInfo struct { + ActorID int + FirstName string + LastName string + FilmInfo string + } + + var dest []ActorInfo + + err := query.Query(db, &dest) + assert.NilError(t, err) + + testutils.AssertJSON(t, dest[1:2], ` +[ + { + "ActorID": 2, + "FirstName": "Nick", + "LastName": "Wahlberg", + "FilmInfo": "Action: Bull Shawshank, Animation: Fight Jawbreaker, Children: Jersey Sassy, Classics: Dracula Crystal, Gilbert Pelican, Comedy: Mallrats United, Rushmore Mermaid, Documentary: Adaptation Holes, Drama: Wardrobe Phantom, Family: Apache Divine, Chisum Behavior, Indian Love, Maguire Apache, Foreign: Baby Hall, Happiness United, Games: Roof Champion, Music: Lucky Flying, New: Destiny Saturday, Flash Wars, Jekyll Frogmen, Mask Peach, Sci-Fi: Chainsaw Uptown, Goodfellas Salute, Travel: Liaisons Sweet, Smile Earring" + } +] +`) + +} + +func TestJoinViewWithTable(t *testing.T) { + query := SELECT( + view.CustomerList.AllColumns, + Rental.AllColumns, + ). + FROM(view.CustomerList. + INNER_JOIN(Rental, view.CustomerList.ID.EQ(Rental.CustomerID)), + ). + ORDER_BY(view.CustomerList.ID). + WHERE(view.CustomerList.ID.LT_EQ(Int(2))) + + var dest []struct { + model.CustomerList `sql:"primary_key=ID"` + Rentals []model.Rental + } + + fmt.Println(query.DebugSql()) + + err := query.Query(db, &dest) + assert.NilError(t, err) + + assert.Equal(t, len(dest), 2) + assert.Equal(t, len(dest[0].Rentals), 32) + assert.Equal(t, len(dest[1].Rentals), 27) +} diff --git a/tests/testdata b/tests/testdata index 7f3f3cc2..1f6bd8bb 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 7f3f3cc26ce34324f3699d6b422376671b827490 +Subproject commit 1f6bd8bb86458019fa43b1e2cd7ae9488a7ac9a4