diff --git a/.circleci/config.yml b/.circleci/config.yml index 1571843a..f83283cc 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -33,6 +33,13 @@ jobs: MYSQL_USER: jet MYSQL_PASSWORD: jet + - image: cockroachdb/cockroach-unstable:v22.1.0-beta.4 + command: ['start-single-node', '--insecure'] + environment: + COCKROACH_USER: jet + COCKROACH_PASSWORD: jet + COCKROACH_DATABASE: jetdb + environment: # environment variables for the build itself TEST_RESULTS: /tmp/test-results # path to where test results will be saved @@ -82,7 +89,18 @@ jobs: echo -n . sleep 1 done - echo Failed waiting for MySQL && exit 1 + echo Failed waiting for MySQL && exit 1 + + - run: + name: Waiting for Cockroach to be ready + command: | + for i in `seq 1 10`; + do + nc -z localhost 26257 && echo Success && exit 0 + echo -n . + sleep 1 + done + echo Failed waiting for Cockroach && exit 1 - run: name: Install MySQL CLI; @@ -122,8 +140,9 @@ jobs: -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... \ -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml - # run mariaDB tests. No need to collect coverage, because coverage is already included with mysql tests + # run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests - run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/ + - run: PG_SOURCE=COCKROACH_DB go test -v ./tests/postgres/ - save_cache: key: go-mod-v4-{{ checksum "go.sum" }} diff --git a/README.md b/README.md index a6288f30..e1f808f0 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ Jet is a complete solution for efficient and high performance database access, consisting of type-safe SQL builder with code generation and automatic query result data mapping. -Jet currently supports `PostgreSQL`, `MySQL`, `MariaDB` and `SQLite`. Future releases will add support for additional databases. +Jet currently supports `PostgreSQL`, `MySQL`, `CockroachDB`, `MariaDB` and `SQLite`. Future releases will add support for additional databases. ![jet](https://github.com/go-jet/jet/wiki/image/jet.png) Jet is the easiest, and the fastest way to write complex type-safe SQL queries as a Go code and map database query result @@ -62,40 +62,37 @@ $ go get -u github.com/go-jet/jet/v2 Jet generator can be installed in one of the following ways: -1) (Go1.16+) Install jet generator using go install: - ```sh - go install github.com/go-jet/jet/v2/cmd/jet@latest - ``` - -2) Install jet generator to GOPATH/bin folder: - ```sh - cd $GOPATH/src/ && GO111MODULE=off go get -u github.com/go-jet/jet/cmd/jet - ``` +- (Go1.16+) Install jet generator using go install: +```sh +go install github.com/go-jet/jet/v2/cmd/jet@latest +``` +*Jet generator is installed to the directory named by the GOBIN environment variable, +which defaults to $GOPATH/bin or $HOME/go/bin if the GOPATH environment variable is not set.* -3) Install jet generator into specific folder: - ```sh - git clone https://github.com/go-jet/jet.git - cd jet && go build -o dir_path ./cmd/jet - ``` -*Make sure that the destination folder is added to the PATH environment variable.* +- Install jet generator to specific folder: +```sh +git clone https://github.com/go-jet/jet.git +cd jet && go build -o dir_path ./cmd/jet +``` +*Make sure `dir_path` folder is added to the PATH environment variable.* ### Quick Start For this quick start example we will use PostgreSQL sample _'dvd rental'_ database. Full database dump can be found in [./tests/testdata/init/postgres/dvds.sql](https://github.com/go-jet/jet-test-data/blob/master/init/postgres/dvds.sql). -Schema diagram of interest for example can be found [here](./examples/quick-start/diagram.png). +Schema diagram of interest can be found [here](./examples/quick-start/diagram.png). #### Generate SQL Builder and Model types -To generate jet SQL Builder and Data Model types from postgres database, we need to call `jet` generator with postgres -connection parameters and root destination folder path for generated files. +To generate jet SQL Builder and Data Model types from running postgres database, we need to call `jet` generator with postgres +connection parameters and destination folder path. Assuming we are running local postgres database, with user `user`, user password `pass`, database `jetdb` and schema `dvds` we will use this command: ```sh -jet -dsn=postgresql://user:pass@localhost:5432/jetdb -schema=dvds -path=./.gen +jet -dsn=postgresql://user:pass@localhost:5432/jetdb?sslmode=disable -schema=dvds -path=./.gen ``` ```sh -Connecting to postgres database: postgresql://user:pass@localhost:5432/jetdb +Connecting to postgres database: postgresql://user:pass@localhost:5432/jetdb?sslmode=disable Retrieving schema information... FOUND 15 table(s), 7 view(s), 1 enum(s) Cleaning up destination directory... @@ -107,9 +104,10 @@ Generating view model files... Generating enum model files... Done ``` -Procedure is similar for MySQL, MariaDB and SQLite. For instance: +Procedure is similar for MySQL, CockroachDB, MariaDB and SQLite. For example: ```sh jet -source=mysql -dsn="user:pass@tcp(localhost:3306)/dbname" -path=./gen +jet -dsn=postgres://user:pass@localhost:26257/jetdb?sslmode=disable -schema=dvds -path=./.gen #cockroachdb jet -dsn="mariadb://user:pass@tcp(localhost:3306)/dvds" -path=./gen # source flag can be omitted if data source appears in dsn jet -source=sqlite -dsn="/path/to/sqlite/database/file" -schema=dvds -path=./gen jet -dsn="file:///path/to/sqlite/database/file" -schema=dvds -path=./gen # sqlite database assumed for 'file' data sources @@ -168,7 +166,7 @@ and _film category_ is not 'Action'. stmt := SELECT( Actor.ActorID, Actor.FirstName, Actor.LastName, Actor.LastUpdate, // or just Actor.AllColumns Film.AllColumns, - Language.AllColumns.Except(Language.LastUpdate), + Language.AllColumns.Except(Language.LastUpdate), // all language columns except last_update Category.AllColumns, ).FROM( Actor. @@ -186,7 +184,7 @@ stmt := SELECT( Film.FilmID.ASC(), ) ``` -_Package(dot) import is used, so the statements would resemble as much as possible as native SQL._ +_Package(dot) import is used, so the statements look as close as possible to the native SQL._ Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with string columns and expressions. `Actor.ActorID`, `FilmActor.ActorID`, `Film.Length` are integer columns and can be compared only with integer columns and expressions. @@ -245,7 +243,7 @@ __How to get debug SQL from statement?__ ```go debugSql := stmt.DebugSql() ``` -debugSql - this query string can be copy-pasted to sql editor and executed. __It is not intended to be used in production, only for the purpose of debugging!!!__ +debugSql - this query string can be copy-pasted to sql editor and executed. __It is not intended to be used in production. For debug purposes only!!!__
Click to see debug sql @@ -295,8 +293,8 @@ First we have to create desired structure to store query result. This is done be combining autogenerated model types, or it can be done by combining custom model types(see [wiki](https://github.com/go-jet/jet/wiki/Query-Result-Mapping-(QRM)#custom-model-types) for more information). -It's possible to overwrite default jet generator behavior, and all the aspects of generated model and SQLBuilder types can be -tailor-made([wiki](https://github.com/go-jet/jet/wiki/Generator#generator-customization)). +_Note that it's possible to overwrite default jet generator behavior. All the aspects of generated model and SQLBuilder types can be +tailor-made([wiki](https://github.com/go-jet/jet/wiki/Generator#generator-customization))._ Let's say this is our desired structure made of autogenerated types: ```go @@ -315,14 +313,14 @@ var dest []struct { `Langauge` field is just a single model struct. `Film` can belong to multiple categories. _*There is no limitation of how big or nested destination can be._ -Now lets execute above statement on open database connection (or transaction) db and store result into `dest`. +Now let's execute above statement on open database connection (or transaction) db and store result into `dest`. ```go err := stmt.Query(db, &dest) handleError(err) ``` -__And thats it.__ +__And that's it.__ `dest` now contains the list of all actors(with list of films acted, where each film has information about language and list of belonging categories) that acted in films longer than 180 minutes, film language is 'English' and film category is not 'Action'. @@ -528,7 +526,7 @@ The biggest benefit is speed. Speed is being improved in 3 major areas: ##### Speed of development -Writing SQL queries is faster and easier as the developers have help of SQL code completion and SQL type safety directly from Go. +Writing SQL queries is faster and easier, as developers will have help of SQL code completion and SQL type safety directly from Go code. Automatic scan to arbitrary structure removes a lot of headache and boilerplate code needed to structure database query result. ##### Speed of execution @@ -539,14 +537,14 @@ Thus handler time lost on latency between server and database can be constant. H only to the query complexity and the number of rows returned from database. With Jet, it is even possible to join the whole database and store the whole structured result in one database call. -This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/postgres/chinook_db_test.go#L40). -The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s. +This is exactly what is being done in one of the tests: [TestJoinEverything](https://github.com/go-jet/jet/blob/6706f4b228f51cf810129f57ba90bbdb60b85fe7/tests/postgres/chinook_db_test.go#L187). +The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.5s. ##### How quickly bugs are found The most expensive bugs are the one discovered on the production, and the least expensive are those found during development. With automatically generated type safe SQL, not only queries are written faster but bugs are found sooner. -Lets return to quick start example, and take closer look at a line: +Let's return to quick start example, and take closer look at a line: ```go AND(Film.Length.GT(Int(180))), ``` @@ -573,6 +571,8 @@ To run the tests, additional dependencies are required: - `github.com/stretchr/testify` - `github.com/google/go-cmp` - `github.com/jackc/pgx/v4` +- `github.com/shopspring/decimal` +- `github.com/volatiletech/null/v8` ## Versioning diff --git a/cmd/jet/main.go b/cmd/jet/main.go index 7e8e2999..530d02a5 100644 --- a/cmd/jet/main.go +++ b/cmd/jet/main.go @@ -3,6 +3,9 @@ package main import ( "flag" "fmt" + "os" + "strings" + "github.com/go-jet/jet/v2/generator/metadata" sqlitegen "github.com/go-jet/jet/v2/generator/sqlite" "github.com/go-jet/jet/v2/generator/template" @@ -11,8 +14,6 @@ import ( "github.com/go-jet/jet/v2/mysql" postgres2 "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/sqlite" - "os" - "strings" mysqlgen "github.com/go-jet/jet/v2/generator/mysql" postgresgen "github.com/go-jet/jet/v2/generator/postgres" @@ -42,7 +43,7 @@ var ( ) func init() { - flag.StringVar(&source, "source", "", "Database system name (postgres, mysql, mariadb or sqlite)") + flag.StringVar(&source, "source", "", "Database system name (postgres, mysql, cockroachdb, mariadb or sqlite)") flag.StringVar(&dsn, "dsn", "", `Data source name. Unified format for connecting to database. PostgreSQL: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING @@ -59,7 +60,7 @@ func init() { flag.StringVar(&user, "user", "", "Database user. Used only if dsn is not set.") flag.StringVar(&password, "password", "", "The user’s password. Used only if dsn is not set.") flag.StringVar(&dbName, "dbname", "", "Database name. Used only if dsn is not set.") - flag.StringVar(&schemaName, "schema", "public", `Database schema name. Used only if dsn is not set. (default "public")(PostgreSQL only)`) + flag.StringVar(&schemaName, "schema", "public", `Database schema name. (default "public")(PostgreSQL only)`) flag.StringVar(¶ms, "params", "", "Additional connection string parameters(optional). Used only if dsn is not set.") flag.StringVar(&sslmode, "sslmode", "disable", `Whether or not to use SSL. Used only if dsn is not set. (optional)(default "disable")(PostgreSQL only)`) flag.StringVar(&ignoreTables, "ignore-tables", "", `Comma-separated list of tables to ignore`) @@ -70,33 +71,7 @@ func init() { } func main() { - - flag.Usage = func() { - fmt.Println("Jet generator 2.7.0") - fmt.Println() - fmt.Println("Usage:") - - order := []string{ - "source", "dsn", "host", "port", "user", "password", "dbname", "schema", "params", "sslmode", - "path", - "ignore-tables", "ignore-views", "ignore-enums", - } - for _, name := range order { - flagEntry := flag.CommandLine.Lookup(name) - fmt.Printf(" -%s\n", flagEntry.Name) - fmt.Printf("\t%s\n", flagEntry.Usage) - } - - fmt.Println() - fmt.Println(`Example command: - - $ jet -dsn=postgresql://jet:jet@localhost:5432/jetdb -schema=dvds -path=./gen - $ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=dvds -path=./gen - $ jet -source=mysql -host=localhost -port=3306 -user=jet -password=jet -dbname=jetdb -path=./gen - $ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -path=./gen - `) - } - + flag.Usage = usage flag.Parse() if dsn == "" && (source == "" || host == "" || port == 0 || user == "" || dbName == "") { @@ -111,11 +86,14 @@ func main() { var err error switch source { - case "postgresql", "postgres": + case "postgresql", "postgres", "cockroachdb", "cockroach": + generatorTemplate := genTemplate(postgres2.Dialect, ignoreTablesList, ignoreViewsList, ignoreEnumsList) + if dsn != "" { - err = postgresgen.GenerateDSN(dsn, schemaName, destDir) + err = postgresgen.GenerateDSN(dsn, schemaName, destDir, generatorTemplate) break } + dbConn := postgresgen.DBConnection{ Host: host, Port: port, @@ -131,14 +109,17 @@ func main() { err = postgresgen.Generate( destDir, dbConn, - genTemplate(postgres2.Dialect, ignoreTablesList, ignoreViewsList, ignoreEnumsList), + generatorTemplate, ) case "mysql", "mysqlx", "mariadb": + generatorTemplate := genTemplate(mysql.Dialect, ignoreTablesList, ignoreViewsList, ignoreEnumsList) + if dsn != "" { - err = mysqlgen.GenerateDSN(dsn, destDir) + err = mysqlgen.GenerateDSN(dsn, destDir, generatorTemplate) break } + dbConn := mysqlgen.DBConnection{ Host: host, Port: port, @@ -151,12 +132,13 @@ func main() { err = mysqlgen.Generate( destDir, dbConn, - genTemplate(mysql.Dialect, ignoreTablesList, ignoreViewsList, ignoreEnumsList), + generatorTemplate, ) case "sqlite": if dsn == "" { printErrorAndExit("ERROR: required -dsn flag missing.") } + err = sqlitegen.GenerateDSN( dsn, destDir, @@ -176,6 +158,34 @@ func main() { } } +func usage() { + fmt.Println("Jet generator 2.8.0") + fmt.Println() + fmt.Println("Usage:") + + order := []string{ + "source", "dsn", "host", "port", "user", "password", "dbname", "schema", "params", "sslmode", + "path", + "ignore-tables", "ignore-views", "ignore-enums", + } + + for _, name := range order { + flagEntry := flag.CommandLine.Lookup(name) + fmt.Printf(" -%s\n", flagEntry.Name) + fmt.Printf("\t%s\n", flagEntry.Usage) + } + + fmt.Println() + fmt.Println(`Example command: + + $ jet -dsn=postgresql://jet:jet@localhost:5432/jetdb?sslmode=disable -schema=dvds -path=./gen + $ jet -dsn=postgres://jet:jet@localhost:26257/jetdb?sslmode=disable -schema=dvds -path=./gen #cockroachdb + $ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=dvds -path=./gen + $ jet -source=mysql -host=localhost -port=3306 -user=jet -password=jet -dbname=jetdb -path=./gen + $ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -path=./gen + `) +} + func printErrorAndExit(error string) { fmt.Println("\n", error) fmt.Println() diff --git a/examples/quick-start/.gen/jetdb/dvds/table/actor.go b/examples/quick-start/.gen/jetdb/dvds/table/actor.go index bbeb7e8f..37e0f85f 100644 --- a/examples/quick-start/.gen/jetdb/dvds/table/actor.go +++ b/examples/quick-start/.gen/jetdb/dvds/table/actor.go @@ -42,6 +42,16 @@ func (a ActorTable) FromSchema(schemaName string) *ActorTable { return newActorTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new ActorTable with assigned table prefix +func (a ActorTable) WithPrefix(prefix string) *ActorTable { + return newActorTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new ActorTable with assigned table suffix +func (a ActorTable) WithSuffix(suffix string) *ActorTable { + return newActorTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newActorTable(schemaName, tableName, alias string) *ActorTable { return &ActorTable{ actorTable: newActorTableImpl(schemaName, tableName, alias), diff --git a/examples/quick-start/.gen/jetdb/dvds/table/category.go b/examples/quick-start/.gen/jetdb/dvds/table/category.go index 563938d9..87beb46d 100644 --- a/examples/quick-start/.gen/jetdb/dvds/table/category.go +++ b/examples/quick-start/.gen/jetdb/dvds/table/category.go @@ -41,6 +41,16 @@ func (a CategoryTable) FromSchema(schemaName string) *CategoryTable { return newCategoryTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new CategoryTable with assigned table prefix +func (a CategoryTable) WithPrefix(prefix string) *CategoryTable { + return newCategoryTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new CategoryTable with assigned table suffix +func (a CategoryTable) WithSuffix(suffix string) *CategoryTable { + return newCategoryTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newCategoryTable(schemaName, tableName, alias string) *CategoryTable { return &CategoryTable{ categoryTable: newCategoryTableImpl(schemaName, tableName, alias), diff --git a/examples/quick-start/.gen/jetdb/dvds/table/film.go b/examples/quick-start/.gen/jetdb/dvds/table/film.go index 65db900a..46364083 100644 --- a/examples/quick-start/.gen/jetdb/dvds/table/film.go +++ b/examples/quick-start/.gen/jetdb/dvds/table/film.go @@ -51,6 +51,16 @@ func (a FilmTable) FromSchema(schemaName string) *FilmTable { return newFilmTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new FilmTable with assigned table prefix +func (a FilmTable) WithPrefix(prefix string) *FilmTable { + return newFilmTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new FilmTable with assigned table suffix +func (a FilmTable) WithSuffix(suffix string) *FilmTable { + return newFilmTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newFilmTable(schemaName, tableName, alias string) *FilmTable { return &FilmTable{ filmTable: newFilmTableImpl(schemaName, tableName, alias), diff --git a/examples/quick-start/.gen/jetdb/dvds/table/film_actor.go b/examples/quick-start/.gen/jetdb/dvds/table/film_actor.go index 30c3ad35..7fe60c38 100644 --- a/examples/quick-start/.gen/jetdb/dvds/table/film_actor.go +++ b/examples/quick-start/.gen/jetdb/dvds/table/film_actor.go @@ -41,6 +41,16 @@ func (a FilmActorTable) FromSchema(schemaName string) *FilmActorTable { return newFilmActorTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new FilmActorTable with assigned table prefix +func (a FilmActorTable) WithPrefix(prefix string) *FilmActorTable { + return newFilmActorTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new FilmActorTable with assigned table suffix +func (a FilmActorTable) WithSuffix(suffix string) *FilmActorTable { + return newFilmActorTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newFilmActorTable(schemaName, tableName, alias string) *FilmActorTable { return &FilmActorTable{ filmActorTable: newFilmActorTableImpl(schemaName, tableName, alias), diff --git a/examples/quick-start/.gen/jetdb/dvds/table/film_category.go b/examples/quick-start/.gen/jetdb/dvds/table/film_category.go index 83681777..fb3bfb65 100644 --- a/examples/quick-start/.gen/jetdb/dvds/table/film_category.go +++ b/examples/quick-start/.gen/jetdb/dvds/table/film_category.go @@ -41,6 +41,16 @@ func (a FilmCategoryTable) FromSchema(schemaName string) *FilmCategoryTable { return newFilmCategoryTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new FilmCategoryTable with assigned table prefix +func (a FilmCategoryTable) WithPrefix(prefix string) *FilmCategoryTable { + return newFilmCategoryTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new FilmCategoryTable with assigned table suffix +func (a FilmCategoryTable) WithSuffix(suffix string) *FilmCategoryTable { + return newFilmCategoryTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newFilmCategoryTable(schemaName, tableName, alias string) *FilmCategoryTable { return &FilmCategoryTable{ filmCategoryTable: newFilmCategoryTableImpl(schemaName, tableName, alias), diff --git a/examples/quick-start/.gen/jetdb/dvds/table/language.go b/examples/quick-start/.gen/jetdb/dvds/table/language.go index 5bddeebb..24f037d7 100644 --- a/examples/quick-start/.gen/jetdb/dvds/table/language.go +++ b/examples/quick-start/.gen/jetdb/dvds/table/language.go @@ -41,6 +41,16 @@ func (a LanguageTable) FromSchema(schemaName string) *LanguageTable { return newLanguageTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new LanguageTable with assigned table prefix +func (a LanguageTable) WithPrefix(prefix string) *LanguageTable { + return newLanguageTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new LanguageTable with assigned table suffix +func (a LanguageTable) WithSuffix(suffix string) *LanguageTable { + return newLanguageTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newLanguageTable(schemaName, tableName, alias string) *LanguageTable { return &LanguageTable{ languageTable: newLanguageTableImpl(schemaName, tableName, alias), diff --git a/examples/quick-start/.gen/jetdb/dvds/view/actor_info.go b/examples/quick-start/.gen/jetdb/dvds/view/actor_info.go index 5bfa25d3..23f81aa2 100644 --- a/examples/quick-start/.gen/jetdb/dvds/view/actor_info.go +++ b/examples/quick-start/.gen/jetdb/dvds/view/actor_info.go @@ -42,6 +42,16 @@ func (a ActorInfoTable) FromSchema(schemaName string) *ActorInfoTable { return newActorInfoTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new ActorInfoTable with assigned table prefix +func (a ActorInfoTable) WithPrefix(prefix string) *ActorInfoTable { + return newActorInfoTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new ActorInfoTable with assigned table suffix +func (a ActorInfoTable) WithSuffix(suffix string) *ActorInfoTable { + return newActorInfoTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newActorInfoTable(schemaName, tableName, alias string) *ActorInfoTable { return &ActorInfoTable{ actorInfoTable: newActorInfoTableImpl(schemaName, tableName, alias), diff --git a/examples/quick-start/.gen/jetdb/dvds/view/customer_list.go b/examples/quick-start/.gen/jetdb/dvds/view/customer_list.go index c03a5ef5..cdf14cae 100644 --- a/examples/quick-start/.gen/jetdb/dvds/view/customer_list.go +++ b/examples/quick-start/.gen/jetdb/dvds/view/customer_list.go @@ -47,6 +47,16 @@ func (a CustomerListTable) FromSchema(schemaName string) *CustomerListTable { return newCustomerListTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new CustomerListTable with assigned table prefix +func (a CustomerListTable) WithPrefix(prefix string) *CustomerListTable { + return newCustomerListTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new CustomerListTable with assigned table suffix +func (a CustomerListTable) WithSuffix(suffix string) *CustomerListTable { + return newCustomerListTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newCustomerListTable(schemaName, tableName, alias string) *CustomerListTable { return &CustomerListTable{ customerListTable: newCustomerListTableImpl(schemaName, tableName, alias), diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index 93e6ffb5..da485055 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -35,7 +35,9 @@ WITH primaryKeys AS ( SELECT column_name FROM information_schema.key_column_usage AS c LEFT JOIN information_schema.table_constraints AS t - ON t.constraint_name = c.constraint_name + ON t.constraint_name = c.constraint_name AND + c.table_schema = t.table_schema AND + c.table_name = t.table_name WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY' ) SELECT column_name as "column.Name", diff --git a/generator/template/file_templates.go b/generator/template/file_templates.go index e3020cea..d1f12603 100644 --- a/generator/template/file_templates.go +++ b/generator/template/file_templates.go @@ -49,6 +49,16 @@ func (a {{tableTemplate.TypeName}}) FromSchema(schemaName string) {{tableTemplat return new{{tableTemplate.TypeName}}(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new {{tableTemplate.TypeName}} with assigned table prefix +func (a {{tableTemplate.TypeName}}) WithPrefix(prefix string) {{tableTemplate.TypeName}} { + return new{{tableTemplate.TypeName}}(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new {{tableTemplate.TypeName}} with assigned table suffix +func (a {{tableTemplate.TypeName}}) WithSuffix(suffix string) {{tableTemplate.TypeName}} { + return new{{tableTemplate.TypeName}}(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func new{{tableTemplate.TypeName}}(schemaName, tableName, alias string) {{tableTemplate.TypeName}} { var ( {{- range $i, $c := .Columns}} @@ -119,6 +129,16 @@ func (a {{tableTemplate.TypeName}}) FromSchema(schemaName string) *{{tableTempla return new{{tableTemplate.TypeName}}(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new {{tableTemplate.TypeName}} with assigned table prefix +func (a {{tableTemplate.TypeName}}) WithPrefix(prefix string) *{{tableTemplate.TypeName}} { + return new{{tableTemplate.TypeName}}(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new {{tableTemplate.TypeName}} with assigned table suffix +func (a {{tableTemplate.TypeName}}) WithSuffix(suffix string) *{{tableTemplate.TypeName}} { + return new{{tableTemplate.TypeName}}(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func new{{tableTemplate.TypeName}}(schemaName, tableName, alias string) *{{tableTemplate.TypeName}} { return &{{tableTemplate.TypeName}}{ {{structImplName}}: new{{tableTemplate.TypeName}}Impl(schemaName, tableName, alias), diff --git a/go.mod b/go.mod index caa7fea1..f15e634b 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,18 @@ go 1.11 require ( github.com/go-sql-driver/mysql v1.5.0 - github.com/google/go-cmp v0.5.7 //tests github.com/google/uuid v1.1.1 github.com/jackc/pgconn v1.12.0 - github.com/jackc/pgx/v4 v4.16.0 //tests github.com/lib/pq v1.10.5 github.com/mattn/go-sqlite3 v1.14.8 - github.com/pkg/profile v1.5.0 //tests - github.com/shopspring/decimal v1.2.0 // tests - github.com/stretchr/testify v1.7.0 // tests +) + +// test dependencies +require ( + github.com/google/go-cmp v0.5.8 + github.com/jackc/pgx/v4 v4.16.0 + github.com/pkg/profile v1.6.0 + github.com/shopspring/decimal v1.3.1 + github.com/stretchr/testify v1.7.0 + github.com/volatiletech/null/v8 v8.1.2 ) diff --git a/go.sum b/go.sum index c681d6de..cb1e8e55 100644 --- a/go.sum +++ b/go.sum @@ -9,15 +9,18 @@ github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7Do github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/friendsofgo/errors v0.9.2 h1:X6NYxef4efCBdwI7BgS820zFaN7Cphrmb+Pljdzjtgk= +github.com/friendsofgo/errors v0.9.2/go.mod h1:yCvFW5AkDIL9qn7suHVLiI/gH228n7PC4Pn44IGoTOI= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= -github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -94,8 +97,8 @@ github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxz github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/profile v1.5.0 h1:042Buzk+NhDI+DeSAA62RwJL8VAuZUMQZUjCsRz1Mug= -github.com/pkg/profile v1.5.0/go.mod h1:qBsxPvzyUincmltOk6iyRVxHYg4adc0OFOv72ZdLa18= +github.com/pkg/profile v1.6.0 h1:hUDfIISABYI59DyeB3OTay/HxSRwTQ8rB/H83k6r5dM= +github.com/pkg/profile v1.6.0/go.mod h1:qBsxPvzyUincmltOk6iyRVxHYg4adc0OFOv72ZdLa18= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -104,8 +107,9 @@ github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OK github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= -github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -117,6 +121,14 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/volatiletech/inflect v0.0.1 h1:2a6FcMQyhmPZcLa+uet3VJ8gLn/9svWhJxJYwvE8KsU= +github.com/volatiletech/inflect v0.0.1/go.mod h1:IBti31tG6phkHitLlr5j7shC5SOo//x0AjDzaJU1PLA= +github.com/volatiletech/null/v8 v8.1.2 h1:kiTiX1PpwvuugKwfvUNX/SU/5A2KGZMXfGD0DUHdKEI= +github.com/volatiletech/null/v8 v8.1.2/go.mod h1:98DbwNoKEpRrYtGjWFctievIfm4n4MxG0A6EBUcoS5g= +github.com/volatiletech/randomize v0.0.1 h1:eE5yajattWqTB2/eN8df4dw+8jwAzBtbdo5sbWC4nMk= +github.com/volatiletech/randomize v0.0.1/go.mod h1:GN3U0QYqfZ9FOJ67bzax1cqZ5q2xuj2mXrXBjWaRTlY= +github.com/volatiletech/strmangle v0.0.1 h1:UKQoHmY6be/R3tSvD2nQYrH41k43OJkidwEiC74KIzk= +github.com/volatiletech/strmangle v0.0.1/go.mod h1:F6RA6IkB5vq0yTG4GQ0UsbbRcl3ni9P76i+JrTBKFFg= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= @@ -181,7 +193,6 @@ golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/jet/column_list.go b/internal/jet/column_list.go index 3ff829c5..2fdb3588 100644 --- a/internal/jet/column_list.go +++ b/internal/jet/column_list.go @@ -4,6 +4,10 @@ package jet type ColumnList []ColumnExpression // SET creates column assigment for each column in column list. expression should be created by ROW function +// Link.UPDATE(). +// SET(Link.MutableColumns.SET(ROW(String("github.com"), Bool(false))). +// WHERE(Link.ID.EQ(Int(0))) +// func (cl ColumnList) SET(expression Expression) ColumnAssigment { return columnAssigmentImpl{ column: cl, @@ -11,7 +15,9 @@ func (cl ColumnList) SET(expression Expression) ColumnAssigment { } } -// Except will create new column list in which columns contained in excluded column names are removed +// Except will create new column list in which columns contained in list of excluded column names are removed +// Address.AllColumns.Except(Address.PostalCode, Address.Phone) +// func (cl ColumnList) Except(excludedColumns ...Column) ColumnList { excludedColumnList := UnwidColumnList(excludedColumns) excludedColumnNames := map[string]bool{} diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 0fe78df0..436b9d62 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -263,6 +263,25 @@ func (p *betweenOperatorExpression) serialize(statement StatementType, out *SQLB p.max.serialize(statement, out, FallTrough(options)...) } +type customExpression struct { + ExpressionInterfaceImpl + parts []Serializer +} + +func newCustomExpression(parts ...Serializer) Expression { + ret := customExpression{ + parts: parts, + } + ret.ExpressionInterfaceImpl.Parent = &ret + return &ret +} + +func (c *customExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + for _, expression := range c.parts { + expression.serialize(statement, out, options...) + } +} + type complexExpression struct { ExpressionInterfaceImpl expressions Expression diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index cfac71f8..6900457a 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -492,6 +492,11 @@ func TO_TIMESTAMP(timestampzStr, format StringExpression) TimestampzExpression { //----------------- Date/Time Functions and Operators ---------------// +// EXTRACT extracts time component from time expression +func EXTRACT(field string, from Expression) Expression { + return newCustomExpression(Token("EXTRACT("), Token(field), Token("FROM"), from, Token(")")) +} + // CURRENT_DATE returns current date func CURRENT_DATE() DateExpression { dateFunc := NewDateFunc("CURRENT_DATE") diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index 866d60e9..93a1d3ba 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -96,3 +96,10 @@ func (s serializerImpl) serialize(statement StatementType, out *SQLBuilder, opti clause.Serialize(statement, out, FallTrough(options)...) } } + +// Token can be used to construct complex custom expressions +type Token string + +func (t Token) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteString(string(t)) +} diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index e3fb61b2..96ef2e6f 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -2,6 +2,7 @@ package jet import ( "bytes" + "database/sql/driver" "fmt" "github.com/go-jet/jet/v2/internal/3rdparty/pq" "github.com/go-jet/jet/v2/internal/utils" @@ -232,17 +233,26 @@ func argToString(value interface{}) string { case time.Time: return stringQuote(string(pq.FormatTimestamp(bindVal))) default: - if strBindValue, ok := bindVal.(toStringInterface); ok { + if strBindValue, ok := bindVal.(fmt.Stringer); ok { return stringQuote(strBindValue.String()) } + + if valuer, ok := bindVal.(driver.Valuer); ok { + val, err := valuer.Value() + + if err != nil { + // If valuer for some reason returns an error, we return error string representation. + // This is fine because argToString is called only from DebugSQL, and DebugSQL shouldn't be used in production. + return err.Error() + } + + return argToString(val) + } + panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String())) } } -type toStringInterface interface { - String() string -} - func integerTypesToString(value interface{}) string { switch bindVal := value.(type) { case int: diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 183aaae8..11d8c95f 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -11,23 +11,23 @@ import ( type Statement interface { // Sql returns parametrized sql query with list of arguments. Sql() (query string, args []interface{}) - // DebugSql returns debug query where every parametrized placeholder is replaced with its argument. + // DebugSql returns debug query where every parametrized placeholder is replaced with its argument string representation. // Do not use it in production. Use it only for debug purposes. DebugSql() (query string) - // Query executes statement over database connection/transaction db and stores row result in destination. + // Query executes statement over database connection/transaction db and stores row results in destination. // Destination can be either pointer to struct or pointer to a slice. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. - Query(db qrm.DB, destination interface{}) error + Query(db qrm.Queryable, destination interface{}) error // QueryContext executes statement with a context over database connection/transaction db and stores row result in destination. // Destination can be either pointer to struct or pointer to a slice. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. - QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error + QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error // Exec executes statement over db connection/transaction without returning any rows. - Exec(db qrm.DB) (sql.Result, error) + Exec(db qrm.Executable) (sql.Result, error) // ExecContext executes statement with context over db connection/transaction without returning any rows. - ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error) + ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error) // Rows executes statements over db connection/transaction and returns rows - Rows(ctx context.Context, db qrm.DB) (*Rows, error) + Rows(ctx context.Context, db qrm.Queryable) (*Rows, error) } // Rows wraps sql.Rows type to add query result mapping for Scan method @@ -86,11 +86,11 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { return } -func (s *serializerStatementInterfaceImpl) Query(db qrm.DB, destination interface{}) error { +func (s *serializerStatementInterfaceImpl) Query(db qrm.Queryable, destination interface{}) error { return s.QueryContext(context.Background(), db, destination) } -func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error { +func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error { query, args := s.Sql() callLogger(ctx, s) @@ -112,11 +112,11 @@ func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db return err } -func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) { +func (s *serializerStatementInterfaceImpl) Exec(db qrm.Executable) (res sql.Result, err error) { return s.ExecContext(context.Background(), db) } -func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.DB) (res sql.Result, err error) { +func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.Executable) (res sql.Result, err error) { query, args := s.Sql() callLogger(ctx, s) @@ -141,7 +141,7 @@ func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db q return res, err } -func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) (*Rows, error) { +func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.Queryable) (*Rows, error) { query, args := s.Sql() callLogger(ctx, s) diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 37c1665a..7d231e09 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -2,6 +2,8 @@ package testutils import ( "bytes" + "context" + "database/sql" "encoding/json" "fmt" "github.com/go-jet/jet/v2/internal/jet" @@ -25,6 +27,18 @@ var UnixTimeComparer = cmp.Comparer(func(t1, t2 time.Time) bool { return t1.Unix() == t2.Unix() }) +// AssertExecAndRollback will execute and rollback statement in sql transaction +func AssertExecAndRollback(t *testing.T, stmt jet.Statement, db *sql.DB, rowsAffected ...int64) { + tx, err := db.Begin() + require.NoError(t, err) + defer func() { + err := tx.Rollback() + require.NoError(t, err) + }() + + AssertExec(t, stmt, tx, rowsAffected...) +} + // AssertExec assert statement execution for successful execution and number of rows affected func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) { res, err := stmt.Exec(db) @@ -38,6 +52,18 @@ func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int } } +// ExecuteInTxAndRollback will execute function in sql transaction and then rollback transaction +func ExecuteInTxAndRollback(t *testing.T, db *sql.DB, f func(tx *sql.Tx)) { + tx, err := db.Begin() + require.NoError(t, err) + defer func() { + err := tx.Rollback() + require.NoError(t, err) + }() + + f(tx) +} + // AssertExecErr assert statement execution for failed execution with error string errorStr func AssertExecErr(t *testing.T, stmt jet.Statement, db qrm.DB, errorStr string) { _, err := stmt.Exec(db) @@ -45,6 +71,13 @@ func AssertExecErr(t *testing.T, stmt jet.Statement, db qrm.DB, errorStr string) require.Error(t, err, errorStr) } +// AssertExecContextErr assert statement execution for failed execution with error string errorStr +func AssertExecContextErr(ctx context.Context, t *testing.T, stmt jet.Statement, db qrm.DB, errorStr string) { + _, err := stmt.ExecContext(ctx, db) + + require.Error(t, err, errorStr) +} + func getFullPath(relativePath string) string { path, _ := os.Getwd() return filepath.Join(path, "../", relativePath) diff --git a/mysql/functions.go b/mysql/functions.go index b794ef70..2a8e2278 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -224,6 +224,12 @@ var REGEXP_LIKE = jet.REGEXP_LIKE //----------------- Date/Time Functions and Operators ------------// +// EXTRACT function retrieves subfields such as year or hour from date/time values +// EXTRACT(DAY, User.CreatedAt) +func EXTRACT(field unitType, from Expression) IntegerExpression { + return IntExp(jet.EXTRACT(string(field), from)) +} + // CURRENT_DATE returns current date var CURRENT_DATE = jet.CURRENT_DATE diff --git a/mysql/interval.go b/mysql/interval.go index c563855e..ea246329 100644 --- a/mysql/interval.go +++ b/mysql/interval.go @@ -39,10 +39,11 @@ const ( type Interval = jet.Interval // INTERVAL creates new temporal interval. -// In a case of MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR unit type -// value parameter should be number. For example: INTERVAL(1, DAY) -// In a case of other unit types, value should be string with appropriate format. -// For example: INTERVAL("10:08:50", HOUR_SECOND) +// In a case of MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR unit type +// value parameter has to be a number. +// INTERVAL(1, DAY) +// In a case of other unit types, value should be string with appropriate format. +// INTERVAL("10:08:50", HOUR_SECOND) func INTERVAL(value interface{}, unitType unitType) Interval { switch unitType { case MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR: diff --git a/mysql/literal.go b/mysql/literal.go index 0a66eb37..1c69c31b 100644 --- a/mysql/literal.go +++ b/mysql/literal.go @@ -56,41 +56,41 @@ var String = jet.String var UUID = jet.UUID // Date creates new date literal -var Date = func(year int, month time.Month, day int) DateExpression { +func Date(year int, month time.Month, day int) DateExpression { return CAST(jet.Date(year, month, day)).AS_DATE() } // DateT creates new date literal from time.Time -var DateT = func(t time.Time) DateExpression { +func DateT(t time.Time) DateExpression { return CAST(jet.DateT(t)).AS_DATE() } // Time creates new time literal -var Time = func(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression { +func Time(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression { return CAST(jet.Time(hour, minute, second, nanoseconds...)).AS_TIME() } // TimeT creates new time literal from time.Time -var TimeT = func(t time.Time) TimeExpression { +func TimeT(t time.Time) TimeExpression { return CAST(jet.TimeT(t)).AS_TIME() } // DateTime creates new datetime literal -var DateTime = func(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) DateTimeExpression { +func DateTime(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) DateTimeExpression { return CAST(jet.Timestamp(year, month, day, hour, minute, second, nanoseconds...)).AS_DATETIME() } // DateTimeT creates new datetime literal from time.Time -var DateTimeT = func(t time.Time) DateTimeExpression { +func DateTimeT(t time.Time) DateTimeExpression { return CAST(jet.TimestampT(t)).AS_DATETIME() } // Timestamp creates new timestamp literal -var Timestamp = func(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) TimestampExpression { +func Timestamp(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) TimestampExpression { return TIMESTAMP(StringExp(jet.Timestamp(year, month, day, hour, minute, second, nanoseconds...))) } // TimestampT creates new timestamp literal from time.Time -var TimestampT = func(t time.Time) TimestampExpression { +func TimestampT(t time.Time) TimestampExpression { return TIMESTAMP(StringExp(jet.TimestampT(t))) } diff --git a/postgres/cast_test.go b/postgres/cast_test.go index e02336a1..c0586d31 100644 --- a/postgres/cast_test.go +++ b/postgres/cast_test.go @@ -5,7 +5,7 @@ import ( ) func TestExpressionCAST_AS(t *testing.T) { - assertSerialize(t, CAST(String("test")).AS("text"), `$1::text`, "test") + assertSerialize(t, CAST(Int(11)).AS("text"), `$1::text`, int64(11)) } func TestExpressionCAST_AS_BOOL(t *testing.T) { diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index 45ed7396..9aadbc92 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -4,16 +4,16 @@ import "testing" func TestString_REGEXP_LIKE_operator(t *testing.T) { assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 ~* table2.col_str)") - assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 ~* $1)", "JOHN") - assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 ~* $1)", "JOHN") - assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 ~ $1)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 ~* $1::text)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 ~* $1::text)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 ~ $1::text)", "JOHN") } func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) { assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 !~* table2.col_str)") - assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 !~* $1)", "JOHN") - assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 !~* $1)", "JOHN") - assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 !~ $1)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 !~* $1::text)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 !~* $1::text)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 !~ $1::text)", "JOHN") } func TestExists(t *testing.T) { diff --git a/postgres/expressions_test.go b/postgres/expressions_test.go index 77c3dee4..76403fbe 100644 --- a/postgres/expressions_test.go +++ b/postgres/expressions_test.go @@ -60,7 +60,7 @@ func TestRawHelperMethods(t *testing.T) { assertSerialize(t, RawFloat("table.colInt + :float", RawArgs{":float": 11.22}).EQ(Float(3.14)), "((table.colInt + $1) = $2)", 11.22, 3.14) assertSerialize(t, RawString("table.colStr || str", RawArgs{"str": "doe"}).EQ(String("john doe")), - "((table.colStr || $1) = $2)", "doe", "john doe") + "((table.colStr || $1) = $2::text)", "doe", "john doe") now := time.Now() assertSerialize(t, RawTime("table.colTime").EQ(TimeT(now)), diff --git a/postgres/functions.go b/postgres/functions.go index cd2c130a..7cc03353 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -1,6 +1,8 @@ package postgres -import "github.com/go-jet/jet/v2/internal/jet" +import ( + "github.com/go-jet/jet/v2/internal/jet" +) // This functions can be used, instead of its method counterparts, to have a better indentation of a complex condition // in the Go code and in the generated SQL. @@ -279,6 +281,26 @@ var TO_TIMESTAMP = jet.TO_TIMESTAMP //----------------- Date/Time Functions and Operators ------------// +// Additional time unit types for EXTRACT function +const ( + DOW unit = MILLENNIUM + 1 + iota + DOY + EPOCH + ISODOW + ISOYEAR + JULIAN + QUARTER + TIMEZONE + TIMEZONE_HOUR + TIMEZONE_MINUTE +) + +// EXTRACT function retrieves subfields such as year or hour from date/time values +// EXTRACT(DAY, User.CreatedAt) +func EXTRACT(field unit, from Expression) FloatExpression { + return FloatExp(jet.EXTRACT(unitToString(field), from)) +} + // CURRENT_DATE returns current date var CURRENT_DATE = jet.CURRENT_DATE diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index 3ec333e0..25300c27 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -167,7 +167,7 @@ VALUES ('one', 'two'), ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE SET col_bool = TRUE::boolean, col_int = 1, - (col1, col_bool) = ROW(2, 'two') + (col1, col_bool) = ROW(2, 'two'::text) WHERE table1.col1 > 2 RETURNING table1.col1 AS "table1.col1", table1.col_bool AS "table1.col_bool"; @@ -193,7 +193,7 @@ VALUES ('one', 'two'), ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE SET col_bool = FALSE::boolean, col_int = 1, - (col1, col_bool) = ROW(2, 'two') + (col1, col_bool) = ROW(2, 'two'::text) WHERE table1.col1 > 2 RETURNING table1.col1 AS "table1.col1", table1.col_bool AS "table1.col_bool"; diff --git a/postgres/interval_expression.go b/postgres/interval_expression.go index 68d33bc1..d1c47887 100644 --- a/postgres/interval_expression.go +++ b/postgres/interval_expression.go @@ -10,10 +10,11 @@ import ( ) type quantityAndUnit = float64 +type unit = float64 // Interval unit types const ( - YEAR quantityAndUnit = 123456789 + iota + YEAR unit = 123456789 + iota MONTH WEEK DAY @@ -119,7 +120,7 @@ type intervalExpression struct { } // INTERVAL creates new interval expression from the list of quantity-unit pairs. -// For example: INTERVAL(1, DAY, 3, MINUTE) +// INTERVAL(1, DAY, 3, MINUTE) func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { quantityAndUnitLen := len(quantityAndUnit) if quantityAndUnitLen == 0 || quantityAndUnitLen%2 != 0 { @@ -208,6 +209,27 @@ func unitToString(unit quantityAndUnit) string { return "CENTURY" case MILLENNIUM: return "MILLENNIUM" + // additional field units for EXTRACT function + case DOW: + return "DOW" + case DOY: + return "DOY" + case EPOCH: + return "EPOCH" + case ISODOW: + return "ISODOW" + case ISOYEAR: + return "ISOYEAR" + case JULIAN: + return "JULIAN" + case QUARTER: + return "QUARTER" + case TIMEZONE: + return "TIMEZONE" + case TIMEZONE_HOUR: + return "TIMEZONE_HOUR" + case TIMEZONE_MINUTE: + return "TIMEZONE_MINUTE" default: panic("jet: invalid INTERVAL unit type") } diff --git a/postgres/literal.go b/postgres/literal.go index e46b874d..7b1bd197 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -61,14 +61,16 @@ var Float = jet.Float var Decimal = jet.Decimal // String creates new string literal expression -var String = jet.String +func String(value string) StringExpression { + return CAST(jet.String(value)).AS_TEXT() +} // UUID is a helper function to create string literal expression from uuid object // value can be any uuid type with a String method var UUID = jet.UUID // Bytea creates new bytea literal expression -var Bytea = func(value interface{}) StringExpression { +func Bytea(value interface{}) StringExpression { switch value.(type) { case string, []byte: default: @@ -78,51 +80,51 @@ var Bytea = func(value interface{}) StringExpression { } // Date creates new date literal expression -var Date = func(year int, month time.Month, day int) DateExpression { +func Date(year int, month time.Month, day int) DateExpression { return CAST(jet.Date(year, month, day)).AS_DATE() } // DateT creates new date literal expression from time.Time object -var DateT = func(t time.Time) DateExpression { +func DateT(t time.Time) DateExpression { return CAST(jet.DateT(t)).AS_DATE() } // Time creates new time literal expression -var Time = func(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression { +func Time(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression { return CAST(jet.Time(hour, minute, second, nanoseconds...)).AS_TIME() } // TimeT creates new time literal expression from time.Time object -var TimeT = func(t time.Time) TimeExpression { +func TimeT(t time.Time) TimeExpression { return CAST(jet.TimeT(t)).AS_TIME() } // Timez creates new time with time zone literal expression -var Timez = func(hour, minute, second int, milliseconds time.Duration, timezone string) TimezExpression { +func Timez(hour, minute, second int, milliseconds time.Duration, timezone string) TimezExpression { return CAST(jet.Timez(hour, minute, second, milliseconds, timezone)).AS_TIMEZ() } // TimezT creates new time with time zone literal expression from time.Time object -var TimezT = func(t time.Time) TimezExpression { +func TimezT(t time.Time) TimezExpression { return CAST(jet.TimezT(t)).AS_TIMEZ() } // Timestamp creates new timestamp literal expression -var Timestamp = func(year int, month time.Month, day, hour, minute, second int, milliseconds ...time.Duration) TimestampExpression { +func Timestamp(year int, month time.Month, day, hour, minute, second int, milliseconds ...time.Duration) TimestampExpression { return CAST(jet.Timestamp(year, month, day, hour, minute, second, milliseconds...)).AS_TIMESTAMP() } // TimestampT creates new timestamp literal expression from time.Time object -var TimestampT = func(t time.Time) TimestampExpression { +func TimestampT(t time.Time) TimestampExpression { return CAST(jet.TimestampT(t)).AS_TIMESTAMP() } // Timestampz creates new timestamp with time zone literal expression -var Timestampz = func(year int, month time.Month, day, hour, minute, second int, milliseconds time.Duration, timezone string) TimestampzExpression { +func Timestampz(year int, month time.Month, day, hour, minute, second int, milliseconds time.Duration, timezone string) TimestampzExpression { return CAST(jet.Timestampz(year, month, day, hour, minute, second, milliseconds, timezone)).AS_TIMESTAMPZ() } // TimestampzT creates new timestamp literal expression from time.Time object -var TimestampzT = func(t time.Time) TimestampzExpression { +func TimestampzT(t time.Time) TimestampzExpression { return CAST(jet.TimestampzT(t)).AS_TIMESTAMPZ() } diff --git a/postgres/literal_test.go b/postgres/literal_test.go index f95e4867..5c5160ed 100644 --- a/postgres/literal_test.go +++ b/postgres/literal_test.go @@ -59,7 +59,7 @@ func TestFloat(t *testing.T) { } func TestString(t *testing.T) { - assertSerialize(t, String("Some text"), `$1`, "Some text") + assertSerialize(t, String("Some text"), `$1::text`, "Some text") } func TestBytea(t *testing.T) { diff --git a/qrm/db.go b/qrm/db.go index 6b319eb3..1efefb11 100644 --- a/qrm/db.go +++ b/qrm/db.go @@ -13,3 +13,13 @@ type DB interface { Query(query string, args ...interface{}) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } + +// Queryable interface for sql QueryContext method +type Queryable interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) +} + +// Executable interface for sql ExecContext method +type Executable interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} diff --git a/qrm/qrm.go b/qrm/qrm.go index 50597cd9..1c559f62 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -17,7 +17,7 @@ var ErrNoRows = errors.New("qrm: no rows in result set") // using context `ctx` into destination `destPtr`. // Destination can be either pointer to struct or pointer to slice of structs. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. -func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) { +func Query(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) { utils.MustBeInitializedPtr(db, "jet: db is nil") utils.MustBeInitializedPtr(destPtr, "jet: destination is nil") @@ -88,7 +88,7 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac return nil } -func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) (rowsProcessed int64, err error) { +func queryToSlice(ctx context.Context, db Queryable, query string, args []interface{}, slicePtr interface{}) (rowsProcessed int64, err error) { if ctx == nil { ctx = context.Background() } diff --git a/tests/Makefile b/tests/Makefile index 632c3d21..1a847786 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -32,13 +32,20 @@ init-sqlite: # jet-gen will call generator on each of the test databases to generate sql builder and model files need to run the tests. jet-gen-all: install-jet-gen jet-gen-postgres jet-gen-mysql jet-gen-mariadb jet-gen-sqlite +ifeq ($(OS),Windows_NT) + target := jet.exe +else + target := jet +endif + install-jet-gen: - go build -o ${GOPATH}/bin/jet ../cmd/jet/ + go build -o ${GOPATH}/bin/${target} ../cmd/jet/ jet-gen-postgres: jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=dvds -path=./.gentestdata/ jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=chinook -path=./.gentestdata/ jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=chinook2 -path=./.gentestdata/ + jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=northwind -path=./.gentestdata/ jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=test_sample -path=./.gentestdata/ jet-gen-mysql: @@ -56,6 +63,12 @@ jet-gen-sqlite: jet -source=sqlite -dsn="./testdata/init/sqlite/sakila.db" -schema=dvds -path=./.gentestdata/sqlite/sakila jet -source=sqlite -dsn="./testdata/init/sqlite/test_sample.db" -schema=dvds -path=./.gentestdata/sqlite/test_sample +jet-gen-cockroach: + jet -dsn=postgres://jet:jet@127.0.0.1:26257/jetdb?sslmode=disable -schema=dvds -path=./.gentestdata/ + jet -dsn=postgres://jet:jet@localhost:26257/jetdb?sslmode=disable -schema=chinook -path=./.gentestdata/ + jet -dsn=postgres://jet:jet@localhost:26257/jetdb?sslmode=disable -schema=chinook2 -path=./.gentestdata/ + jet -dsn=postgres://jet:jet@localhost:26257/jetdb?sslmode=disable -schema=northwind -path=./.gentestdata/ + jet -dsn=postgres://jet:jet@localhost:26257/jetdb?sslmode=disable -schema=test_sample -path=./.gentestdata/ # docker-compose-cleanup will stop and remove test containers, volumes, and images. cleanup: diff --git a/tests/dbconfig/dbconfig.go b/tests/dbconfig/dbconfig.go index bbf73f99..59ff402a 100644 --- a/tests/dbconfig/dbconfig.go +++ b/tests/dbconfig/dbconfig.go @@ -15,7 +15,24 @@ const ( ) // PostgresConnectString is PostgreSQL test database connection string -var PostgresConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", PgHost, PgPort, PgUser, PgPassword, PgDBName) +var PostgresConnectString = pgConnectionString(PgHost, PgPort, PgUser, PgPassword, PgDBName) + +// Postgres test database connection parameters +const ( + CockroachHost = "localhost" + CockroachPort = 26257 + CockroachUser = "jet" + CockroachPassword = "jet" + CockroachDBName = "jetdb" +) + +// CockroachConnectString is Cockroach test database connection string +var CockroachConnectString = pgConnectionString(CockroachHost, CockroachPort, CockroachUser, CockroachPassword, CockroachDBName) + +func pgConnectionString(host string, port int, user, password, dbName string) string { + return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + host, port, user, password, dbName) +} // MySQL test database connection parameters const ( diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 2e913f13..9c562fbe 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -37,3 +37,16 @@ services: - '50903:3306' volumes: - ./testdata/init/mysql:/docker-entrypoint-initdb.d + + cockroach: + image: cockroachdb/cockroach-unstable:v22.1.0-beta.4 + environment: + - COCKROACH_USER=jet + - COCKROACH_PASSWORD=jet + - COCKROACH_DATABASE=jetdb + ports: + - "26257:26257" + command: start-single-node --insecure +# volumes: +# - ./testdata/init/cockroach:/docker-entrypoint-initdb.d + diff --git a/tests/init/init.go b/tests/init/init.go index c1c842ab..633457e8 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -1,10 +1,12 @@ package main import ( + "context" "database/sql" "flag" "fmt" "github.com/go-jet/jet/v2/generator/mysql" + "github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/generator/sqlite" "github.com/go-jet/jet/v2/tests/internal/utils/repo" "io/ioutil" @@ -12,46 +14,53 @@ import ( "os/exec" "strings" - "github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/tests/dbconfig" _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" + _ "github.com/jackc/pgx/v4/stdlib" + _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) var testSuite string func init() { - flag.StringVar(&testSuite, "testsuite", "all", "Test suite name (postgres or mysql)") - + flag.StringVar(&testSuite, "testsuite", "all", "Test suite name (postgres, mysql, mariadb, cockroach, sqlite or all)") flag.Parse() } -func main() { - - testSuite = strings.ToLower(testSuite) - - if testSuite == "postgres" { - initPostgresDB() - return - } +// Database names +const ( + Postgres = "postgres" + MySql = "mysql" + MariaDB = "mariadb" + Sqlite = "sqlite" + Cockroach = "cockroach" +) - if testSuite == "mysql" || testSuite == "mariadb" { - initMySQLDB(testSuite == "mariadb") - return - } +func main() { - if testSuite == "sqlite" { + switch strings.ToLower(testSuite) { + case Postgres: + initPostgresDB(Postgres, dbconfig.PostgresConnectString) + case Cockroach: + initPostgresDB(Cockroach, dbconfig.CockroachConnectString) + case MySql: + initMySQLDB(false) + case MariaDB: + initMySQLDB(true) + case Sqlite: + initSQLiteDB() + case "all": + initPostgresDB(Cockroach, dbconfig.CockroachConnectString) + initPostgresDB(Postgres, dbconfig.PostgresConnectString) + initMySQLDB(false) + initMySQLDB(true) initSQLiteDB() - return + default: + panic("invalid testsuite flag. Test suite name (postgres, mysql, mariadb, cockroach, sqlite or all)") } - - initPostgresDB() - initMySQLDB(false) - initMySQLDB(true) - initSQLiteDB() } func initSQLiteDB() { @@ -109,8 +118,8 @@ func initMySQLDB(isMariaDB bool) { } } -func initPostgresDB() { - db, err := sql.Open("postgres", dbconfig.PostgresConnectString) +func initPostgresDB(dbType string, connectionString string) { + db, err := sql.Open("postgres", connectionString) if err != nil { panic("Failed to connect to test db: " + err.Error()) } @@ -120,26 +129,19 @@ func initPostgresDB() { }() schemaNames := []string{ + "northwind", "dvds", "test_sample", "chinook", "chinook2", - "northwind", } for _, schemaName := range schemaNames { + fmt.Println("\nInitializing", schemaName, "schema...") - execFile(db, "./testdata/init/postgres/"+schemaName+".sql") + execFile(db, fmt.Sprintf("./testdata/init/%s/%s.sql", dbType, schemaName)) - err = postgres.Generate("./.gentestdata", postgres.DBConnection{ - Host: dbconfig.PgHost, - Port: dbconfig.PgPort, - User: dbconfig.PgUser, - Password: dbconfig.PgPassword, - DBName: dbconfig.PgDBName, - SchemaName: schemaName, - SslMode: "disable", - }) + err = postgres.GenerateDSN(connectionString, schemaName, "./.gentestdata") throw.OnError(err) } } @@ -148,10 +150,32 @@ func execFile(db *sql.DB, sqlFilePath string) { testSampleSql, err := ioutil.ReadFile(sqlFilePath) throw.OnError(err) - _, err = db.Exec(string(testSampleSql)) + err = execInTx(db, func(tx *sql.Tx) error { + _, err := tx.Exec(string(testSampleSql)) + return err + }) throw.OnError(err) } +func execInTx(db *sql.DB, f func(tx *sql.Tx) error) error { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ + Isolation: sql.LevelReadUncommitted, // to speed up initialization of test database + }) + + if err != nil { + return err + } + + err = f(tx) + + if err != nil { + tx.Rollback() + return err + } + + return tx.Commit() +} + func printOnError(err error) { if err != nil { fmt.Println(err.Error()) diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 428a0e64..bf257511 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -20,7 +20,7 @@ import ( func TestAllTypes(t *testing.T) { - dest := []model.AllTypes{} + var dest []model.AllTypes err := AllTypes. SELECT(AllTypes.AllColumns). @@ -39,7 +39,7 @@ func TestAllTypesViewSelect(t *testing.T) { type AllTypesView model.AllTypes - dest := []AllTypesView{} + var dest []AllTypesView err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) require.NoError(t, err) @@ -539,6 +539,8 @@ func TestTimeExpressions(t *testing.T) { AllTypes.Time.ADD(INTERVAL(20, MINUTE)).SUB(INTERVAL(11, HOUR)), + EXTRACT(DAY_HOUR, AllTypes.Time), + CURRENT_TIME(), CURRENT_TIME(3), ) @@ -574,6 +576,7 @@ SELECT CAST('20:34:58' AS TIME), all_types.time - INTERVAL all_types.small_int MINUTE, all_types.time - INTERVAL 3 MINUTE, (all_types.time + INTERVAL 20 MINUTE) - INTERVAL 11 HOUR, + EXTRACT(DAY_HOUR FROM all_types.time), CURRENT_TIME, CURRENT_TIME(3) FROM test_sample.all_types; @@ -936,6 +939,62 @@ func TestINTERVAL(t *testing.T) { require.NoError(t, err) } +func TestTimeEXTRACT(t *testing.T) { + stmt := SELECT( + EXTRACT(MICROSECOND, TimeT(time.Now())), + EXTRACT(SECOND, AllTypes.Time), + EXTRACT(MINUTE, AllTypes.Timestamp), + EXTRACT(HOUR, AllTypes.Timestamp), + EXTRACT(DAY, AllTypes.Date), + EXTRACT(WEEK, AllTypes.Timestamp), + EXTRACT(MONTH, AllTypes.Timestamp.ADD(INTERVAL(1, DAY))), + EXTRACT(QUARTER, AllTypes.Timestamp), + EXTRACT(YEAR, AllTypes.Timestamp).EQ(Int(1189654)), + EXTRACT(SECOND_MICROSECOND, AllTypes.Time), + EXTRACT(MINUTE_MICROSECOND, AllTypes.DateTime), + EXTRACT(MINUTE_SECOND, AllTypes.Timestamp), + EXTRACT(HOUR_MICROSECOND, AllTypes.Timestamp), + EXTRACT(HOUR_SECOND, AllTypes.Timestamp), + EXTRACT(HOUR_MINUTE, AllTypes.Timestamp), + EXTRACT(DAY_MICROSECOND, AllTypes.Timestamp), + EXTRACT(DAY_SECOND, AllTypes.Timestamp), + EXTRACT(DAY_MINUTE, AllTypes.Timestamp), + EXTRACT(DAY_HOUR, AllTypes.Timestamp), + EXTRACT(YEAR_MONTH, AllTypes.Timestamp), + ).FROM( + AllTypes, + ) + + //fmt.Println(stmt.Sql()) + + testutils.AssertStatementSql(t, stmt, ` +SELECT EXTRACT(MICROSECOND FROM CAST(? AS TIME)), + EXTRACT(SECOND FROM all_types.time), + EXTRACT(MINUTE FROM all_types.timestamp), + EXTRACT(HOUR FROM all_types.timestamp), + EXTRACT(DAY FROM all_types.date), + EXTRACT(WEEK FROM all_types.timestamp), + EXTRACT(MONTH FROM all_types.timestamp + INTERVAL 1 DAY), + EXTRACT(QUARTER FROM all_types.timestamp), + EXTRACT(YEAR FROM all_types.timestamp) = ?, + EXTRACT(SECOND_MICROSECOND FROM all_types.time), + EXTRACT(MINUTE_MICROSECOND FROM all_types.date_time), + EXTRACT(MINUTE_SECOND FROM all_types.timestamp), + EXTRACT(HOUR_MICROSECOND FROM all_types.timestamp), + EXTRACT(HOUR_SECOND FROM all_types.timestamp), + EXTRACT(HOUR_MINUTE FROM all_types.timestamp), + EXTRACT(DAY_MICROSECOND FROM all_types.timestamp), + EXTRACT(DAY_SECOND FROM all_types.timestamp), + EXTRACT(DAY_MINUTE FROM all_types.timestamp), + EXTRACT(DAY_HOUR FROM all_types.timestamp), + EXTRACT(YEAR_MONTH FROM all_types.timestamp) +FROM test_sample.all_types; +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} + func TestAllTypesInsert(t *testing.T) { tx, err := db.Begin() require.NoError(t, err) diff --git a/tests/mysql/delete_test.go b/tests/mysql/delete_test.go index 709ce1a2..2c92367f 100644 --- a/tests/mysql/delete_test.go +++ b/tests/mysql/delete_test.go @@ -13,24 +13,20 @@ import ( ) func TestDeleteWithWhere(t *testing.T) { - initForDeleteTest(t) - - var expectedSQL = ` -DELETE FROM test_sample.link -WHERE link.name IN ('Gmail', 'Outlook'); -` deleteStmt := Link. DELETE(). WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) - testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") - testutils.AssertExec(t, deleteStmt, db, 2) + testutils.AssertDebugStatementSql(t, deleteStmt, ` +DELETE FROM test_sample.link +WHERE link.name IN ('Gmail', 'Outlook'); +`, "Gmail", "Outlook") + + testutils.AssertExecAndRollback(t, deleteStmt, db, 2) requireLogged(t, deleteStmt) } func TestDeleteWithWhereOrderByLimit(t *testing.T) { - initForDeleteTest(t) - var expectedSQL = ` DELETE FROM test_sample.link WHERE link.name IN ('Gmail', 'Outlook') @@ -44,13 +40,11 @@ LIMIT 1; LIMIT(1) testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook", int64(1)) - testutils.AssertExec(t, deleteStmt, db, 1) + testutils.AssertExecAndRollback(t, deleteStmt, db, 1) requireLogged(t, deleteStmt) } func TestDeleteQueryContext(t *testing.T) { - initForDeleteTest(t) - deleteStmt := Link. DELETE(). WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) @@ -60,7 +54,7 @@ func TestDeleteQueryContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} + var dest []model.Link err := deleteStmt.QueryContext(ctx, db, &dest) require.Error(t, err, "context deadline exceeded") @@ -68,8 +62,6 @@ func TestDeleteQueryContext(t *testing.T) { } func TestDeleteExecContext(t *testing.T) { - initForDeleteTest(t) - deleteStmt := Link. DELETE(). WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) @@ -84,19 +76,7 @@ func TestDeleteExecContext(t *testing.T) { require.Error(t, err, "context deadline exceeded") } -func initForDeleteTest(t *testing.T) { - cleanUpLinkTable(t) - stmt := Link.INSERT(Link.URL, Link.Name, Link.Description). - VALUES("www.gmail.com", "Gmail", "Email service developed by Google"). - VALUES("www.outlook.live.com", "Outlook", "Email service developed by Microsoft") - - testutils.AssertExec(t, stmt, db, 2) -} - func TestDeleteWithUsing(t *testing.T) { - tx := beginTx(t) - defer tx.Rollback() - stmt := table.Rental.DELETE(). USING( table.Rental. @@ -116,5 +96,5 @@ USING dvds.rental WHERE (staff.staff_id != ?) AND (rental.rental_id < ?); `) - testutils.AssertExec(t, stmt, tx) + testutils.AssertExecAndRollback(t, stmt, db) } diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index a414df32..e8f8d8f7 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -88,45 +88,70 @@ func TestCmdGenerator(t *testing.T) { } func TestIgnoreTablesViewsEnums(t *testing.T) { - cmd := exec.Command("jet", - "-source=MySQL", - "-dbname=dvds", - "-host="+dbconfig.MySqLHost, - "-port="+strconv.Itoa(dbconfig.MySQLPort), - "-user="+dbconfig.MySQLUser, - "-password="+dbconfig.MySQLPassword, - "-ignore-tables=actor,ADDRESS,Category, city ,country,staff,store,rental", - "-ignore-views=actor_info,CUSTomER_LIST, film_list", - "-ignore-enums=film_list_rating,film_rating", - "-path="+genTestDir3) + tests := []struct { + name string + args []string + }{ + { + name: "with dsn", + args: []string{ + "-dsn=mysql://" + dbconfig.MySQLConnectionString(sourceIsMariaDB(), "dvds"), + "-ignore-tables=actor,ADDRESS,Category, city ,country,staff,store,rental", + "-ignore-views=actor_info,CUSTomER_LIST, film_list", + "-ignore-enums=film_list_rating,film_rating", + "-path=" + genTestDir3, + }, + }, + { + name: "without dsn", + args: []string{ + "-source=MySQL", + "-dbname=dvds", + "-host=" + dbconfig.MySqLHost, + "-port=" + strconv.Itoa(dbconfig.MySQLPort), + "-user=" + dbconfig.MySQLUser, + "-password=" + dbconfig.MySQLPassword, + "-ignore-tables=actor,ADDRESS,Category, city ,country,staff,store,rental", + "-ignore-views=actor_info,CUSTomER_LIST, film_list", + "-ignore-enums=film_list_rating,film_rating", + "-path=" + genTestDir3, + }, + }, + } - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stdout + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := exec.Command("jet", tt.args...) - err := cmd.Run() - require.NoError(t, err) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout - tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table") - require.NoError(t, err) - testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "customer.go", "film.go", "film_actor.go", - "film_category.go", "film_text.go", "inventory.go", "language.go", "payment.go") + err := cmd.Run() + require.NoError(t, err) - viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view") - require.NoError(t, err) - testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "nicer_but_slower_film_list.go", - "sales_by_film_category.go", "sales_by_store.go", "staff_list.go") + tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table") + require.NoError(t, err) + testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "customer.go", "film.go", "film_actor.go", + "film_category.go", "film_text.go", "inventory.go", "language.go", "payment.go") - enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum") - require.NoError(t, err) - testutils.AssertFileNamesEqual(t, enumFiles, "nicer_but_slower_film_list_rating.go") + viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view") + require.NoError(t, err) + testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "nicer_but_slower_film_list.go", + "sales_by_film_category.go", "sales_by_store.go", "staff_list.go") - modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model") - require.NoError(t, err) + enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum") + require.NoError(t, err) + testutils.AssertFileNamesEqual(t, enumFiles, "nicer_but_slower_film_list_rating.go") - testutils.AssertFileNamesEqual(t, modelFiles, - "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", - "payment.go", "nicer_but_slower_film_list_rating.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", - "sales_by_store.go", "staff_list.go") + modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model") + require.NoError(t, err) + + testutils.AssertFileNamesEqual(t, modelFiles, + "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", + "payment.go", "nicer_but_slower_film_list_rating.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", + "sales_by_store.go", "staff_list.go") + }) + } } func assertGeneratedFiles(t *testing.T) { @@ -236,6 +261,16 @@ func (a ActorTable) FromSchema(schemaName string) ActorTable { return newActorTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new ActorTable with assigned table prefix +func (a ActorTable) WithPrefix(prefix string) ActorTable { + return newActorTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new ActorTable with assigned table suffix +func (a ActorTable) WithSuffix(suffix string) ActorTable { + return newActorTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newActorTable(schemaName, tableName, alias string) ActorTable { var ( ActorIDColumn = mysql.IntegerColumn("actor_id") @@ -322,6 +357,16 @@ func (a ActorInfoTable) FromSchema(schemaName string) ActorInfoTable { return newActorInfoTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new ActorInfoTable with assigned table prefix +func (a ActorInfoTable) WithPrefix(prefix string) ActorInfoTable { + return newActorInfoTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new ActorInfoTable with assigned table suffix +func (a ActorInfoTable) WithSuffix(suffix string) ActorInfoTable { + return newActorInfoTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newActorInfoTable(schemaName, tableName, alias string) ActorInfoTable { var ( ActorIDColumn = mysql.IntegerColumn("actor_id") diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index 55fc706b..10887f5a 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -2,6 +2,7 @@ package mysql import ( "context" + "database/sql" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model" @@ -13,51 +14,47 @@ import ( ) func TestInsertValues(t *testing.T) { - cleanUpLinkTable(t) - - var expectedSQL = ` -INSERT INTO test_sample.link (id, url, name, description) -VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - (101, 'http://www.google.com', 'Google', DEFAULT), - (102, 'http://www.yahoo.com', 'Yahoo', NULL); -` - insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(101, "http://www.google.com", "Google", DEFAULT). VALUES(102, "http://www.yahoo.com", "Yahoo", nil) - testutils.AssertDebugStatementSql(t, insertQuery, expectedSQL, + testutils.AssertDebugStatementSql(t, insertQuery, ` +INSERT INTO test_sample.link (id, url, name, description) +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + (101, 'http://www.google.com', 'Google', DEFAULT), + (102, 'http://www.yahoo.com', 'Yahoo', NULL); +`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", 101, "http://www.google.com", "Google", 102, "http://www.yahoo.com", "Yahoo", nil) - _, err := insertQuery.Exec(db) - require.NoError(t, err) - requireLogged(t, insertQuery) - - insertedLinks := []model.Link{} - - err = Link.SELECT(Link.AllColumns). - WHERE(Link.ID.GT_EQ(Int(100))). - ORDER_BY(Link.ID). - Query(db, &insertedLinks) - - require.NoError(t, err) - require.Equal(t, len(insertedLinks), 3) - - testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) - - testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ - ID: 101, - URL: "http://www.google.com", - Name: "Google", - }) - - testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ - ID: 102, - URL: "http://www.yahoo.com", - Name: "Yahoo", + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := insertQuery.Exec(tx) + require.NoError(t, err) + requireLogged(t, insertQuery) + + var insertedLinks []model.Link + + err = Link.SELECT(Link.AllColumns). + WHERE(Link.ID.BETWEEN(Int(100), Int(199))). + ORDER_BY(Link.ID). + Query(tx, &insertedLinks) + + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 3) + + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ + ID: 101, + URL: "http://www.google.com", + Name: "Google", + }) + testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ + ID: 102, + URL: "http://www.yahoo.com", + Name: "Yahoo", + }) }) } @@ -68,42 +65,34 @@ var postgreTutorial = model.Link{ } func TestInsertEmptyColumnList(t *testing.T) { - cleanUpLinkTable(t) - - expectedSQL := ` -INSERT INTO test_sample.link -VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); -` - stmt := Link.INSERT(). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO test_sample.link +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); +`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") - _, err := stmt.Exec(db) - require.NoError(t, err) - requireLogged(t, stmt) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := stmt.Exec(tx) + require.NoError(t, err) + requireLogged(t, stmt) - insertedLinks := []model.Link{} + var insertedLinks []model.Link - err = Link.SELECT(Link.AllColumns). - WHERE(Link.ID.GT_EQ(Int(100))). - ORDER_BY(Link.ID). - Query(db, &insertedLinks) + err = Link.SELECT(Link.AllColumns). + WHERE(Link.ID.BETWEEN(Int(100), Int(199))). + ORDER_BY(Link.ID). + Query(tx, &insertedLinks) - require.NoError(t, err) - require.Equal(t, len(insertedLinks), 1) - testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 1) + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + }) } func TestInsertModelObject(t *testing.T) { - cleanUpLinkTable(t) - var expectedSQL = ` -INSERT INTO test_sample.link (url, name) -VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); -` - linkData := model.Link{ URL: "http://www.duckduckgo.com", Name: "Duck Duck go", @@ -113,19 +102,19 @@ VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); INSERT(Link.URL, Link.Name). MODEL(linkData) - testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go") + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); +`, + "http://www.duckduckgo.com", "Duck Duck go") - _, err := query.Exec(db) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) + }) } func TestInsertModelObjectEmptyColumnList(t *testing.T) { - cleanUpLinkTable(t) - var expectedSQL = ` -INSERT INTO test_sample.link -VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); -` - linkData := model.Link{ ID: 1000, URL: "http://www.duckduckgo.com", @@ -136,20 +125,18 @@ VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); INSERT(). MODEL(linkData) - testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO test_sample.link +VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); +`, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) - _, err := query.Exec(db) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) + }) } func TestInsertModelsObject(t *testing.T) { - expectedSQL := ` -INSERT INTO test_sample.link (url, name) -VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), - ('http://www.google.com', 'Google'), - ('http://www.yahoo.com', 'Yahoo'); -` - tutorial := model.Link{ URL: "http://www.postgresqltutorial.com", Name: "PostgreSQL Tutorial", @@ -169,24 +156,23 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), INSERT(Link.URL, Link.Name). MODELS([]model.Link{tutorial, google, yahoo}) - testutils.AssertDebugStatementSql(t, query, expectedSQL, + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), + ('http://www.google.com', 'Google'), + ('http://www.yahoo.com', 'Yahoo'); +`, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.google.com", "Google", "http://www.yahoo.com", "Yahoo") - _, err := query.Exec(db) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) + }) } func TestInsertUsingMutableColumns(t *testing.T) { - var expectedSQL = ` -INSERT INTO test_sample.link (url, name, description) -VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - ('http://www.google.com', 'Google', NULL), - ('http://www.google.com', 'Google', NULL), - ('http://www.yahoo.com', 'Yahoo', NULL); -` - google := model.Link{ URL: "http://www.google.com", Name: "Google", @@ -203,31 +189,25 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), MODEL(google). MODELS([]model.Link{google, yahoo}) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO test_sample.link (url, name, description) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + ('http://www.google.com', 'Google', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.yahoo.com', 'Yahoo', NULL); +`, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil, "http://www.yahoo.com", "Yahoo", nil) - _, err := stmt.Exec(db) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := stmt.Exec(tx) + require.NoError(t, err) + }) } func TestInsertQuery(t *testing.T) { - _, err := Link.DELETE(). - WHERE(Link.ID.NOT_EQ(Int(1)).AND(Link.Name.EQ(String("Youtube")))). - Exec(db) - require.NoError(t, err) - - var expectedSQL = ` -INSERT INTO test_sample.link (url, name) ( - SELECT link.url AS "link.url", - link.name AS "link.name" - FROM test_sample.link - WHERE link.id = 1 -); -` - query := Link. INSERT(Link.URL, Link.Name). QUERY( @@ -236,19 +216,28 @@ INSERT INTO test_sample.link (url, name) ( WHERE(Link.ID.EQ(Int(1))), ) - testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1)) + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO test_sample.link (url, name) ( + SELECT link.url AS "link.url", + link.name AS "link.name" + FROM test_sample.link + WHERE link.id = 1 +); +`, int64(1)) - _, err = query.Exec(db) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) - youtubeLinks := []model.Link{} - err = Link. - SELECT(Link.AllColumns). - WHERE(Link.Name.EQ(String("Youtube"))). - Query(db, &youtubeLinks) + var youtubeLinks []model.Link + err = Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Youtube"))). + Query(tx, &youtubeLinks) - require.NoError(t, err) - require.Equal(t, len(youtubeLinks), 2) + require.NoError(t, err) + require.Equal(t, len(youtubeLinks), 2) + }) } func TestInsertOnDuplicateKey(t *testing.T) { @@ -272,28 +261,29 @@ ON DUPLICATE KEY UPDATE id = (id + ?), randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", int64(11), "PostgreSQL Tutorial 2") - testutils.AssertExec(t, stmt, db, 3) - - newLinks := []model.Link{} - - err := SELECT(Link.AllColumns). - FROM(Link). - WHERE(Link.ID.EQ(Int32(randId).ADD(Int(11)))). - Query(db, &newLinks) - - require.NoError(t, err) - require.Len(t, newLinks, 1) - require.Equal(t, newLinks[0], model.Link{ - ID: randId + 11, - URL: "http://www.postgresqltutorial.com", - Name: "PostgreSQL Tutorial 2", - Description: nil, + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := stmt.Exec(tx) + require.NoError(t, err) + + var newLinks []model.Link + + err = SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.ID.EQ(Int32(randId).ADD(Int(11)))). + Query(tx, &newLinks) + + require.NoError(t, err) + require.Len(t, newLinks, 1) + require.Equal(t, newLinks[0], model.Link{ + ID: randId + 11, + URL: "http://www.postgresqltutorial.com", + Name: "PostgreSQL Tutorial 2", + Description: nil, + }) }) } func TestInsertWithQueryContext(t *testing.T) { - cleanUpLinkTable(t) - stmt := Link.INSERT(). VALUES(1100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) @@ -302,15 +292,13 @@ func TestInsertWithQueryContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} + var dest []model.Link err := stmt.QueryContext(ctx, db, &dest) require.Error(t, err, "context deadline exceeded") } func TestInsertWithExecContext(t *testing.T) { - cleanUpLinkTable(t) - stmt := Link.INSERT(). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) @@ -323,8 +311,3 @@ func TestInsertWithExecContext(t *testing.T) { require.Error(t, err, "context deadline exceeded") } - -func cleanUpLinkTable(t *testing.T) { - _, err := Link.DELETE().WHERE(Link.ID.GT(Int(1))).Exec(db) - require.NoError(t, err) -} diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index e04580d4..f6ce57d8 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -96,9 +96,3 @@ func skipForMariaDB(t *testing.T) { t.SkipNow() } } - -func beginTx(t *testing.T) *sql.Tx { - tx, err := db.Begin() - require.NoError(t, err) - return tx -} diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index ba628a1b..c03d4240 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -2,6 +2,7 @@ package mysql import ( "context" + "database/sql" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table" @@ -13,8 +14,6 @@ import ( ) func TestUpdateValues(t *testing.T) { - setupLinkTableForUpdateTest(t) - var expectedSQL = ` UPDATE test_sample.link SET name = 'Bong', @@ -28,8 +27,26 @@ WHERE link.name = 'Bing'; WHERE(Link.Name.EQ(String("Bing"))) testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") - testutils.AssertExec(t, query, db) - requireLogged(t, query) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + testutils.AssertExec(t, query, tx) + requireLogged(t, query) + + var links []model.Link + + err := Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Bong"))). + Query(tx, &links) + + require.NoError(t, err) + require.Equal(t, len(links), 1) + testutils.AssertDeepEqual(t, links[0], model.Link{ + ID: 204, + URL: "http://bong.com", + Name: "Bong", + }) + }) + }) t.Run("new version", func(t *testing.T) { @@ -41,29 +58,29 @@ WHERE link.name = 'Bing'; WHERE(Link.Name.EQ(String("Bing"))) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing") - testutils.AssertExec(t, stmt, db) - requireLogged(t, stmt) - }) - - links := []model.Link{} - - err := Link. - SELECT(Link.AllColumns). - WHERE(Link.Name.EQ(String("Bong"))). - Query(db, &links) - - require.NoError(t, err) - require.Equal(t, len(links), 1) - testutils.AssertDeepEqual(t, links[0], model.Link{ - ID: 204, - URL: "http://bong.com", - Name: "Bong", + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + testutils.AssertExec(t, stmt, tx) + requireLogged(t, stmt) + + var links []model.Link + + err := Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Bong"))). + Query(tx, &links) + + require.NoError(t, err) + require.Equal(t, len(links), 1) + testutils.AssertDeepEqual(t, links[0], model.Link{ + ID: 204, + URL: "http://bong.com", + Name: "Bong", + }) + }) }) } func TestUpdateWithSubQueries(t *testing.T) { - setupLinkTableForUpdateTest(t) - expectedSQL := ` UPDATE test_sample.link SET name = ?, @@ -86,7 +103,7 @@ WHERE link.name = ?; WHERE(Link.Name.EQ(String("Bing"))) testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") - testutils.AssertExec(t, query, db) + testutils.AssertExecAndRollback(t, query, db) requireLogged(t, query) }) @@ -104,14 +121,12 @@ WHERE link.name = ?; WHERE(Link.Name.EQ(String("Bing"))) testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") - testutils.AssertExec(t, query, db) + testutils.AssertExecAndRollback(t, query, db) requireLogged(t, query) }) } func TestUpdateWithModelData(t *testing.T) { - setupLinkTableForUpdateTest(t) - link := model.Link{ ID: 201, URL: "http://www.duckduckgo.com", @@ -123,24 +138,20 @@ func TestUpdateWithModelData(t *testing.T) { MODEL(link). WHERE(Link.ID.EQ(Int32(link.ID))) - expectedSQL := ` + testutils.AssertStatementSql(t, stmt, ` UPDATE test_sample.link SET id = ?, url = ?, name = ?, description = ? WHERE link.id = ?; -` - testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) +`, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) - testutils.AssertExec(t, stmt, db) + testutils.AssertExecAndRollback(t, stmt, db) requireLogged(t, stmt) } func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { - - setupLinkTableForUpdateTest(t) - link := model.Link{ ID: 201, URL: "http://www.duckduckgo.com", @@ -154,23 +165,19 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { MODEL(link). WHERE(Link.ID.EQ(Int32(link.ID))) - var expectedSQL = ` + testutils.AssertDebugStatementSql(t, stmt, ` UPDATE test_sample.link SET description = NULL, name = 'DuckDuckGo', url = 'http://www.duckduckgo.com' WHERE link.id = 201; -` - - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) +`, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) - testutils.AssertExec(t, stmt, db) + testutils.AssertExecAndRollback(t, stmt, db) requireLogged(t, stmt) } func TestUpdateWithModelDataAndMutableColumns(t *testing.T) { - setupLinkTableForUpdateTest(t) - link := model.Link{ ID: 201, URL: "http://www.duckduckgo.com", @@ -192,7 +199,7 @@ WHERE link.id = 201; //fmt.Println(stmt.DebugSql()) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) - testutils.AssertExec(t, stmt, db) + testutils.AssertExecAndRollback(t, stmt, db) } func TestUpdateWithInvalidModelData(t *testing.T) { @@ -201,8 +208,6 @@ func TestUpdateWithInvalidModelData(t *testing.T) { require.Equal(t, r, "missing struct field for column : id") }() - setupLinkTableForUpdateTest(t) - link := struct { Ident int URL string @@ -215,17 +220,13 @@ func TestUpdateWithInvalidModelData(t *testing.T) { Name: "DuckDuckGo", } - stmt := Link. + _ = Link. UPDATE(Link.AllColumns). MODEL(link). WHERE(Link.ID.EQ(Int(int64(link.Ident)))) - - stmt.Sql() } func TestUpdateQueryContext(t *testing.T) { - setupLinkTableForUpdateTest(t) - updateStmt := Link. UPDATE(Link.Name, Link.URL). SET("Bong", "http://bong.com"). @@ -243,8 +244,6 @@ func TestUpdateQueryContext(t *testing.T) { } func TestUpdateExecContext(t *testing.T) { - setupLinkTableForUpdateTest(t) - updateStmt := Link. UPDATE(Link.Name, Link.URL). SET("Bong", "http://bong.com"). @@ -261,9 +260,6 @@ func TestUpdateExecContext(t *testing.T) { } func TestUpdateWithJoin(t *testing.T) { - tx := beginTx(t) - defer tx.Rollback() - statement := table.Staff.INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)). UPDATE(table.Staff.LastName). SET(String("New staff name")). @@ -276,21 +272,5 @@ SET last_name = ? WHERE staff.staff_id = ?; `, "New staff name", int64(1)) - _, err := statement.Exec(tx) - require.NoError(t, err) -} - -func setupLinkTableForUpdateTest(t *testing.T) { - - cleanUpLinkTable(t) - - _, err := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). - VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). - VALUES(201, "http://www.ask.com", "Ask", DEFAULT). - VALUES(202, "http://www.ask.com", "Ask", DEFAULT). - VALUES(203, "http://www.yahoo.com", "Yahoo", DEFAULT). - VALUES(204, "http://www.bing.com", "Bing", DEFAULT). - Exec(db) - - require.NoError(t, err) + testutils.AssertExecAndRollback(t, statement, db) } diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 405ec9ec..2a1e0e2d 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -1,6 +1,7 @@ package postgres import ( + "database/sql" "testing" "time" @@ -17,10 +18,10 @@ import ( ) func TestAllTypesSelect(t *testing.T) { - dest := []model.AllTypes{} + var dest []model.AllTypes err := AllTypes.SELECT( - AllTypes.AllColumns, + AllTypesAllColumns, ).LIMIT(2). Query(db, &dest) require.NoError(t, err) @@ -32,7 +33,7 @@ func TestAllTypesSelect(t *testing.T) { func TestAllTypesViewSelect(t *testing.T) { type AllTypesView model.AllTypes - dest := []AllTypesView{} + var dest []AllTypesView err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) require.NoError(t, err) @@ -44,40 +45,123 @@ func TestAllTypesViewSelect(t *testing.T) { func TestAllTypesInsertModel(t *testing.T) { skipForPgxDriver(t) // pgx driver bug ERROR: date/time field value out of range: "0000-01-01 12:05:06Z" (SQLSTATE 22008) - query := AllTypes.INSERT(AllTypes.AllColumns). + query := AllTypes.INSERT(AllTypesAllColumns). MODEL(allTypesRow0). MODEL(&allTypesRow1). RETURNING(AllTypes.AllColumns) - dest := []model.AllTypes{} - err := query.Query(db, &dest) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []model.AllTypes + err := query.Query(tx, &dest) + require.NoError(t, err) - require.Equal(t, len(dest), 2) - testutils.AssertDeepEqual(t, dest[0], allTypesRow0) - testutils.AssertDeepEqual(t, dest[1], allTypesRow1) + if sourceIsCockroachDB() { + return + } + require.Equal(t, len(dest), 2) + testutils.AssertDeepEqual(t, dest[0], allTypesRow0) + testutils.AssertDeepEqual(t, dest[1], allTypesRow1) + }) } +var AllTypesAllColumns = AllTypes.AllColumns.Except(IntegerColumn("rowid")) + func TestAllTypesInsertQuery(t *testing.T) { - query := AllTypes.INSERT(AllTypes.AllColumns). + query := AllTypes.INSERT(AllTypesAllColumns). QUERY( AllTypes. - SELECT(AllTypes.AllColumns). + SELECT(AllTypesAllColumns). LIMIT(2), ). - RETURNING(AllTypes.AllColumns) + RETURNING(AllTypesAllColumns) - dest := []model.AllTypes{} - err := query.Query(db, &dest) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []model.AllTypes + err := query.Query(tx, &dest) + + require.NoError(t, err) + require.Equal(t, len(dest), 2) + testutils.AssertDeepEqual(t, dest[0], allTypesRow0) + testutils.AssertDeepEqual(t, dest[1], allTypesRow1) + }) +} +func TestUUIDType(t *testing.T) { + id := uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") + + query := AllTypes. + SELECT(AllTypes.UUID, AllTypes.UUIDPtr). + WHERE(AllTypes.UUID.EQ(UUID(id))) + + testutils.AssertDebugStatementSql(t, query, ` +SELECT all_types.uuid AS "all_types.uuid", + all_types.uuid_ptr AS "all_types.uuid_ptr" +FROM test_sample.all_types +WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; +`, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") + + result := model.AllTypes{} + + err := query.Query(db, &result) require.NoError(t, err) - require.Equal(t, len(dest), 2) - testutils.AssertDeepEqual(t, dest[0], allTypesRow0) - testutils.AssertDeepEqual(t, dest[1], allTypesRow1) + require.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + testutils.AssertDeepEqual(t, result.UUIDPtr, testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + requireLogged(t, query) +} + +func TestBytea(t *testing.T) { + byteArrHex := "\\x48656c6c6f20476f7068657221" + byteArrBin := []byte("\x48\x65\x6c\x6c\x6f\x20\x47\x6f\x70\x68\x65\x72\x21") + + insertStmt := AllTypes.INSERT(AllTypes.Bytea, AllTypes.ByteaPtr). + VALUES(byteArrHex, byteArrBin). + RETURNING(AllTypes.Bytea, AllTypes.ByteaPtr) + + testutils.AssertStatementSql(t, insertStmt, ` +INSERT INTO test_sample.all_types (bytea, bytea_ptr) +VALUES ($1, $2) +RETURNING all_types.bytea AS "all_types.bytea", + all_types.bytea_ptr AS "all_types.bytea_ptr"; +`, byteArrHex, byteArrBin) + + var inserted model.AllTypes + err := insertStmt.Query(db, &inserted) + require.NoError(t, err) + + require.Equal(t, string(*inserted.ByteaPtr), "Hello Gopher!") + // It is not possible to initiate bytea column using hex format '\xDEADBEEF' with pq driver. + // pq driver always encodes parameter string if destination column is of type bytea. + // Probably pq driver error. + // require.Equal(t, string(inserted.Bytea), "Hello Gopher!") + + stmt := SELECT( + AllTypes.Bytea, + AllTypes.ByteaPtr, + ).FROM( + AllTypes, + ).WHERE( + AllTypes.ByteaPtr.EQ(Bytea(byteArrBin)), + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT all_types.bytea AS "all_types.bytea", + all_types.bytea_ptr AS "all_types.bytea_ptr" +FROM test_sample.all_types +WHERE all_types.bytea_ptr = $1::bytea; +`, byteArrBin) + + var dest model.AllTypes + + err = stmt.Query(db, &dest) + require.NoError(t, err) + + require.Equal(t, string(*dest.ByteaPtr), "Hello Gopher!") + // Probably pq driver error. + // require.Equal(t, string(dest.Bytea), "Hello Gopher!") } func TestAllTypesFromSubQuery(t *testing.T) { - subQuery := SELECT(AllTypes.AllColumns). + subQuery := SELECT(AllTypesAllColumns). FROM(AllTypes). AsTable("allTypesSubQuery") @@ -214,7 +298,7 @@ FROM ( LIMIT 2; `) - dest := []model.AllTypes{} + var dest []model.AllTypes err := mainQuery.Query(db, &dest) require.NoError(t, err) @@ -298,7 +382,6 @@ LIMIT $11; } func TestExpressionCast(t *testing.T) { - skipForPgxDriver(t) // pgx driver bug 'cannot convert 151 to Text' query := AllTypes.SELECT( @@ -315,19 +398,28 @@ func TestExpressionCast(t *testing.T) { CAST(Int(234)).AS_TEXT(), CAST(String("1/8/1999")).AS_DATE(), CAST(String("04:05:06.789")).AS_TIME(), - CAST(String("04:05:06 PST")).AS_TIMEZ(), + CAST(String("04:05:06+01:00")).AS_TIMEZ(), CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(), - CAST(String("January 8 04:05:06 1999 PST")).AS_TIMESTAMPZ(), + CAST(String("1999-01-08 04:05:06+01:00")).AS_TIMESTAMPZ(), CAST(String("04:05:06")).AS_INTERVAL(), - TO_CHAR(AllTypes.Timestamp, String("HH12:MI:SS")), - TO_CHAR(AllTypes.Integer, String("999")), - TO_CHAR(AllTypes.DoublePrecision, String("999D9")), - TO_CHAR(AllTypes.Numeric, String("999D99S")), + func() ProjectionList { + if sourceIsCockroachDB() { + return ProjectionList{NULL} + } + + // cockroach doesn't support currently + return ProjectionList{ + TO_CHAR(AllTypes.Timestamp, String("HH12:MI:SS")), + TO_CHAR(AllTypes.Integer, String("999")), + TO_CHAR(AllTypes.DoublePrecision, String("999D9")), + TO_CHAR(AllTypes.Numeric, String("999D99S")), - TO_DATE(String("05 Dec 2000"), String("DD Mon YYYY")), - TO_NUMBER(String("12,454"), String("99G999D9S")), - TO_TIMESTAMP(String("05 Dec 2000"), String("DD Mon YYYY")), + TO_DATE(String("05 Dec 2000"), String("DD Mon YYYY")), + TO_NUMBER(String("12,454"), String("99G999D9S")), + TO_TIMESTAMP(String("05 Dec 2000"), String("DD Mon YYYY")), + } + }(), COALESCE(AllTypes.IntegerPtr, AllTypes.SmallIntPtr, NULL, Int(11)), NULLIF(AllTypes.Text, String("(none)")), @@ -337,16 +429,15 @@ func TestExpressionCast(t *testing.T) { Raw("current_database()"), ) - //fmt.Println(query.DebugSql()) - - dest := []struct{}{} + var dest []struct{} err := query.Query(db, &dest) require.NoError(t, err) } func TestStringOperators(t *testing.T) { - skipForPgxDriver(t) // pgx driver bug 'cannot convert 11 to Text' + skipForCockroachDB(t) // some string functions are still unimplemented + skipForPgxDriver(t) // pgx driver bug 'cannot convert 11 to Text' query := AllTypes.SELECT( AllTypes.Text.EQ(AllTypes.Char), @@ -395,18 +486,18 @@ func TestStringOperators(t *testing.T) { CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, String("aaa"), Int(1)), CONCAT(Bool(false), Int(1), Float(22.2), String("test test")), CONCAT_WS(String("string1"), Int(1), Float(11.22), String("bytea"), Bool(false)), //Float(11.12)), - CONVERT(String("bytea"), String("UTF8"), String("LATIN1")), + CONVERT(Bytea("bytea"), String("UTF8"), String("LATIN1")), CONVERT(AllTypes.Bytea, String("UTF8"), String("LATIN1")), - CONVERT_FROM(String("text_in_utf8"), String("UTF8")), + CONVERT_FROM(Bytea("text_in_utf8"), String("UTF8")), CONVERT_TO(String("text_in_utf8"), String("UTF8")), - ENCODE(String("123\000\001"), String("base64")), + ENCODE(Bytea("123\000\001"), String("base64")), DECODE(String("MTIzAAE="), String("base64")), FORMAT(String("Hello %s, %1$s"), String("World")), INITCAP(String("hi THOMAS")), LEFT(String("abcde"), Int(2)), RIGHT(String("abcde"), Int(2)), - LENGTH(String("jose")), - LENGTH(String("jose"), String("UTF8")), + LENGTH(Bytea("jose")), + LENGTH(Bytea("jose"), String("UTF8")), LPAD(String("Hi"), Int(5)), LPAD(String("Hi"), Int(5), String("xy")), RPAD(String("Hi"), Int(5)), @@ -421,8 +512,6 @@ func TestStringOperators(t *testing.T) { TO_HEX(AllTypes.IntegerPtr), ) - //fmt.Println(query.DebugSql()) - dest := []struct{}{} err := query.Query(db, &dest) @@ -501,6 +590,8 @@ LIMIT $5; } func TestFloatOperators(t *testing.T) { + skipForCockroachDB(t) // some functions are still unimplemented + query := AllTypes.SELECT( AllTypes.Numeric.EQ(AllTypes.Numeric).AS("eq1"), AllTypes.Decimal.EQ(Float(12.22)).AS("eq2"), @@ -604,6 +695,8 @@ LIMIT $38; } func TestIntegerOperators(t *testing.T) { + skipForCockroachDB(t) // some functions are still unimplemented + query := AllTypes.SELECT( AllTypes.BigInt, AllTypes.BigIntPtr, @@ -733,6 +826,8 @@ LIMIT $27; } func TestTimeExpression(t *testing.T) { + skipForCockroachDB(t) + query := AllTypes.SELECT( AllTypes.Time.EQ(AllTypes.Time), AllTypes.Time.EQ(Time(23, 6, 6, 1)), @@ -804,15 +899,17 @@ func TestTimeExpression(t *testing.T) { NOW(), ) - //fmt.Println(query.DebugSql()) + // fmt.Println(query.DebugSql()) - dest := []struct{}{} + var dest []struct{} err := query.Query(db, &dest) require.NoError(t, err) } func TestInterval(t *testing.T) { + skipForCockroachDB(t) + stmt := SELECT( INTERVAL(1, YEAR), INTERVAL(1, MONTH), @@ -866,6 +963,66 @@ func TestInterval(t *testing.T) { requireLogged(t, stmt) } +func TestTimeEXTRACT(t *testing.T) { + stmt := SELECT( + EXTRACT(CENTURY, AllTypes.Timestampz), + EXTRACT(DAY, AllTypes.Timestamp), + EXTRACT(DECADE, AllTypes.Date), + EXTRACT(DOW, AllTypes.TimestampzPtr), + EXTRACT(DOY, DateT(time.Now())), + EXTRACT(EPOCH, TimestampT(time.Now())), + EXTRACT(HOUR, AllTypes.Time.ADD(INTERVAL(1, HOUR))), + EXTRACT(ISODOW, AllTypes.Timestampz), + EXTRACT(ISOYEAR, AllTypes.Timestampz), + EXTRACT(JULIAN, AllTypes.Timestampz).EQ(Float(3456.123)), + EXTRACT(MICROSECOND, AllTypes.Timestampz), + EXTRACT(MILLENNIUM, AllTypes.Timestampz), + EXTRACT(MILLISECOND, AllTypes.Timez), + EXTRACT(MINUTE, INTERVAL(1, HOUR, 2, MINUTE)), + EXTRACT(MONTH, AllTypes.Timestampz), + EXTRACT(QUARTER, AllTypes.Timestampz), + EXTRACT(SECOND, AllTypes.Timestampz), + EXTRACT(TIMEZONE, AllTypes.Timestampz), + EXTRACT(TIMEZONE_HOUR, AllTypes.Timestampz), + EXTRACT(TIMEZONE_MINUTE, AllTypes.Timestampz), + EXTRACT(WEEK, AllTypes.Timestampz), + EXTRACT(YEAR, AllTypes.Timestampz), + ).FROM( + AllTypes, + ) + + // fmt.Println(stmt.Sql()) + + testutils.AssertStatementSql(t, stmt, ` +SELECT EXTRACT(CENTURY FROM all_types.timestampz), + EXTRACT(DAY FROM all_types.timestamp), + EXTRACT(DECADE FROM all_types.date), + EXTRACT(DOW FROM all_types.timestampz_ptr), + EXTRACT(DOY FROM $1::date), + EXTRACT(EPOCH FROM $2::timestamp without time zone), + EXTRACT(HOUR FROM all_types.time + INTERVAL '1 HOUR'), + EXTRACT(ISODOW FROM all_types.timestampz), + EXTRACT(ISOYEAR FROM all_types.timestampz), + EXTRACT(JULIAN FROM all_types.timestampz) = $3, + EXTRACT(MICROSECOND FROM all_types.timestampz), + EXTRACT(MILLENNIUM FROM all_types.timestampz), + EXTRACT(MILLISECOND FROM all_types.timez), + EXTRACT(MINUTE FROM INTERVAL '1 HOUR 2 MINUTE'), + EXTRACT(MONTH FROM all_types.timestampz), + EXTRACT(QUARTER FROM all_types.timestampz), + EXTRACT(SECOND FROM all_types.timestampz), + EXTRACT(TIMEZONE FROM all_types.timestampz), + EXTRACT(TIMEZONE_HOUR FROM all_types.timestampz), + EXTRACT(TIMEZONE_MINUTE FROM all_types.timestampz), + EXTRACT(WEEK FROM all_types.timestampz), + EXTRACT(YEAR FROM all_types.timestampz) +FROM test_sample.all_types; +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} + func TestSubQueryColumnReference(t *testing.T) { type expected struct { sql string @@ -1084,6 +1241,10 @@ LIMIT $6; dest.Timez = dest.Timez.UTC() dest.Timestampz = dest.Timestampz.UTC() + if sourceIsCockroachDB() { + return // rounding differences + } + testutils.AssertJSON(t, dest, ` { "Date": "2009-11-17T00:00:00Z", diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 2d9821cb..dba07ba1 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -6,6 +6,7 @@ import ( . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/table" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook2/table" "github.com/stretchr/testify/require" "testing" "time" @@ -188,34 +189,36 @@ func TestJoinEverything(t *testing.T) { manager := Employee.AS("Manager") - stmt := Artist. - LEFT_JOIN(Album, Artist.ArtistId.EQ(Album.ArtistId)). - LEFT_JOIN(Track, Track.AlbumId.EQ(Album.AlbumId)). - LEFT_JOIN(Genre, Genre.GenreId.EQ(Track.GenreId)). - LEFT_JOIN(MediaType, MediaType.MediaTypeId.EQ(Track.MediaTypeId)). - LEFT_JOIN(PlaylistTrack, PlaylistTrack.TrackId.EQ(Track.TrackId)). - LEFT_JOIN(Playlist, Playlist.PlaylistId.EQ(PlaylistTrack.PlaylistId)). - LEFT_JOIN(InvoiceLine, InvoiceLine.TrackId.EQ(Track.TrackId)). - LEFT_JOIN(Invoice, Invoice.InvoiceId.EQ(InvoiceLine.InvoiceId)). - LEFT_JOIN(Customer, Customer.CustomerId.EQ(Invoice.CustomerId)). - LEFT_JOIN(Employee, Employee.EmployeeId.EQ(Customer.SupportRepId)). - LEFT_JOIN(manager, manager.EmployeeId.EQ(Employee.ReportsTo)). - SELECT( - Artist.AllColumns, - Album.AllColumns, - Track.AllColumns, - Genre.AllColumns, - MediaType.AllColumns, - PlaylistTrack.AllColumns, - Playlist.AllColumns, - Invoice.AllColumns, - Customer.AllColumns, - Employee.AllColumns, - manager.AllColumns, - ). - ORDER_BY(Artist.ArtistId, Album.AlbumId, Track.TrackId, - Genre.GenreId, MediaType.MediaTypeId, Playlist.PlaylistId, - Invoice.InvoiceId, Customer.CustomerId) + stmt := SELECT( + Artist.AllColumns, + Album.AllColumns, + Track.AllColumns, + Genre.AllColumns, + MediaType.AllColumns, + PlaylistTrack.AllColumns, + Playlist.AllColumns, + Invoice.AllColumns, + Customer.AllColumns, + Employee.AllColumns, + manager.AllColumns, + ).FROM( + Artist. + LEFT_JOIN(Album, Artist.ArtistId.EQ(Album.ArtistId)). + LEFT_JOIN(Track, Track.AlbumId.EQ(Album.AlbumId)). + LEFT_JOIN(Genre, Genre.GenreId.EQ(Track.GenreId)). + LEFT_JOIN(MediaType, MediaType.MediaTypeId.EQ(Track.MediaTypeId)). + LEFT_JOIN(PlaylistTrack, PlaylistTrack.TrackId.EQ(Track.TrackId)). + LEFT_JOIN(Playlist, Playlist.PlaylistId.EQ(PlaylistTrack.PlaylistId)). + LEFT_JOIN(InvoiceLine, InvoiceLine.TrackId.EQ(Track.TrackId)). + LEFT_JOIN(Invoice, Invoice.InvoiceId.EQ(InvoiceLine.InvoiceId)). + LEFT_JOIN(Customer, Customer.CustomerId.EQ(Invoice.CustomerId)). + LEFT_JOIN(Employee, Employee.EmployeeId.EQ(Customer.SupportRepId)). + LEFT_JOIN(manager, manager.EmployeeId.EQ(Employee.ReportsTo)), + ).ORDER_BY( + Artist.ArtistId, Album.AlbumId, Track.TrackId, + Genre.GenreId, MediaType.MediaTypeId, Playlist.PlaylistId, + Invoice.InvoiceId, Customer.CustomerId, + ) var dest []struct { //list of all artist model.Artist @@ -398,11 +401,11 @@ FROM ( SELECT "subQuery1"."Artist.ArtistId" AS "Artist.ArtistId", "subQuery1"."Artist.Name" AS "Artist.Name", "subQuery1".custom_column_1 AS "custom_column_1", - $1 AS "custom_column_2" + $1::text AS "custom_column_2" FROM ( SELECT "Artist"."ArtistId" AS "Artist.ArtistId", "Artist"."Name" AS "Artist.Name", - $2 AS "custom_column_1" + $2::text AS "custom_column_1" FROM chinook."Artist" ORDER BY "Artist"."ArtistId" ASC ) AS "subQuery1" @@ -721,11 +724,14 @@ ORDER BY "Album.AlbumId"; } func TestQueryWithContext(t *testing.T) { + if sourceIsCockroachDB() && !isPgxDriver() { + return // context cancellation doesn't work for pq driver + } ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - dest := []model.Album{} + var dest []model.Album err := Album. CROSS_JOIN(Track). @@ -737,6 +743,9 @@ func TestQueryWithContext(t *testing.T) { } func TestExecWithContext(t *testing.T) { + if sourceIsCockroachDB() && !isPgxDriver() { + return // context cancellation doesn't work for pq driver + } ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() @@ -807,7 +816,7 @@ ORDER BY "first10Artist"."Artist.ArtistId"; require.NoError(t, err) } -func Test_SchemaRename(t *testing.T) { +func TestMultiTenantDifferentSchema(t *testing.T) { Artist2 := Artist.FromSchema("chinook2") Album2 := Album.FromSchema("chinook2") @@ -828,10 +837,12 @@ func Test_SchemaRename(t *testing.T) { albumArtistID := Album2.ArtistId.From(first10Albums) - stmt := SELECT(first10Artist.AllColumns(), first10Albums.AllColumns()). - FROM(first10Artist. - INNER_JOIN(first10Albums, artistID.EQ(albumArtistID))). - ORDER_BY(artistID) + stmt := SELECT( + first10Artist.AllColumns(), + first10Albums.AllColumns(), + ).FROM(first10Artist. + INNER_JOIN(first10Albums, artistID.EQ(albumArtistID)), + ).ORDER_BY(artistID) testutils.AssertDebugStatementSql(t, stmt, ` SELECT "first10Artist"."Artist.ArtistId" AS "Artist.ArtistId", @@ -872,6 +883,182 @@ ORDER BY "first10Artist"."Artist.ArtistId"; require.Equal(t, dest[0].Album[0].Title, "Plays Metallica By Four Cellos") } +func TestMultiTenantSameSchemaDifferentTablePrefix(t *testing.T) { + + var selectAlbumsFrom = func(tenant string) SelectStatement { + Album := table.Album.WithPrefix(tenant) + + return SELECT( + Album.AllColumns, + ).FROM( + Album, + ).ORDER_BY( + Album.AlbumId.ASC(), + ).LIMIT(3) + } + + t.Run("tenant1", func(t *testing.T) { + stmt := selectAlbumsFrom("tenant1.") + + testutils.AssertStatementSql(t, stmt, ` +SELECT "Album"."AlbumId" AS "Album.AlbumId", + "Album"."Title" AS "Album.Title", + "Album"."ArtistId" AS "Album.ArtistId" +FROM chinook2."tenant1.Album" AS "Album" +ORDER BY "Album"."AlbumId" ASC +LIMIT $1; +`) + + var albums []model.Album + err := stmt.Query(db, &albums) + require.NoError(t, err) + + testutils.AssertJSON(t, albums, ` +[ + { + "AlbumId": 80, + "Title": "In Your Honor [Disc 2]", + "ArtistId": 84 + }, + { + "AlbumId": 81, + "Title": "One By One", + "ArtistId": 84 + }, + { + "AlbumId": 82, + "Title": "The Colour And The Shape", + "ArtistId": 84 + } +] +`) + }) + + t.Run("tenant2", func(t *testing.T) { + stmt := selectAlbumsFrom("tenant2.") + + testutils.AssertStatementSql(t, stmt, ` +SELECT "Album"."AlbumId" AS "Album.AlbumId", + "Album"."Title" AS "Album.Title", + "Album"."ArtistId" AS "Album.ArtistId" +FROM chinook2."tenant2.Album" AS "Album" +ORDER BY "Album"."AlbumId" ASC +LIMIT $1; +`) + + var albums []model.Album + err := stmt.Query(db, &albums) + require.NoError(t, err) + testutils.AssertJSON(t, albums, ` +[ + { + "AlbumId": 152, + "Title": "Master Of Puppets", + "ArtistId": 50 + }, + { + "AlbumId": 153, + "Title": "ReLoad", + "ArtistId": 50 + }, + { + "AlbumId": 154, + "Title": "Ride The Lightning", + "ArtistId": 50 + } +] +`) + }) +} + +func TestMultiTenantSameSchemaDifferentTableSuffix(t *testing.T) { + + var selectAlbumsFrom = func(tenant string) SelectStatement { + Album := table.Album.WithSuffix(tenant) + + return SELECT( + Album.AllColumns, + ).FROM( + Album, + ).ORDER_BY( + Album.AlbumId.ASC(), + ).LIMIT(3) + } + + t.Run("tenant1", func(t *testing.T) { + stmt := selectAlbumsFrom(".tenant1") + + testutils.AssertStatementSql(t, stmt, ` +SELECT "Album"."AlbumId" AS "Album.AlbumId", + "Album"."Title" AS "Album.Title", + "Album"."ArtistId" AS "Album.ArtistId" +FROM chinook2."Album.tenant1" AS "Album" +ORDER BY "Album"."AlbumId" ASC +LIMIT $1; +`) + + var albums []model.Album + err := stmt.Query(db, &albums) + require.NoError(t, err) + + testutils.AssertJSON(t, albums, ` +[ + { + "AlbumId": 80, + "Title": "In Your Honor [Disc 2]", + "ArtistId": 84 + }, + { + "AlbumId": 81, + "Title": "One By One", + "ArtistId": 84 + }, + { + "AlbumId": 82, + "Title": "The Colour And The Shape", + "ArtistId": 84 + } +] +`) + }) + + t.Run("tenant2", func(t *testing.T) { + stmt := selectAlbumsFrom(".tenant2") + + testutils.AssertStatementSql(t, stmt, ` +SELECT "Album"."AlbumId" AS "Album.AlbumId", + "Album"."Title" AS "Album.Title", + "Album"."ArtistId" AS "Album.ArtistId" +FROM chinook2."Album.tenant2" AS "Album" +ORDER BY "Album"."AlbumId" ASC +LIMIT $1; +`) + + var albums []model.Album + err := stmt.Query(db, &albums) + require.NoError(t, err) + testutils.AssertJSON(t, albums, ` +[ + { + "AlbumId": 152, + "Title": "Master Of Puppets", + "ArtistId": 50 + }, + { + "AlbumId": 153, + "Title": "ReLoad", + "ArtistId": 50 + }, + { + "AlbumId": 154, + "Title": "Ride The Lightning", + "ArtistId": 50 + } +] +`) + }) +} + var album1 = model.Album{ AlbumId: 1, Title: "For Those About To Rock We Salute You", @@ -891,6 +1078,8 @@ var album347 = model.Album{ } func TestAggregateFunc(t *testing.T) { + skipForCockroachDB(t) + stmt := SELECT( PERCENTILE_DISC(Float(0.1)).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceId).AS("percentile_disc_1"), PERCENTILE_DISC(Invoice.Total.DIV(Float(100))).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceDate.ASC()).AS("percentile_disc_2"), diff --git a/tests/postgres/delete_test.go b/tests/postgres/delete_test.go index abbb3449..ee8d3206 100644 --- a/tests/postgres/delete_test.go +++ b/tests/postgres/delete_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "database/sql" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" @@ -14,69 +15,49 @@ import ( ) func TestDeleteWithWhere(t *testing.T) { - initForDeleteTest(t) - - var expectedSQL = ` -DELETE FROM test_sample.link -WHERE link.name IN ('Gmail', 'Outlook'); -` deleteStmt := Link. DELETE(). WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) - testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") - - res, err := deleteStmt.ExecContext(context.Background(), db) + testutils.AssertDebugStatementSql(t, deleteStmt, ` +DELETE FROM test_sample.link +WHERE link.name IN ('Gmail'::text, 'Outlook'::text); +`, "Gmail", "Outlook") - require.NoError(t, err) - rows, err := res.RowsAffected() - require.NoError(t, err) - require.Equal(t, rows, int64(2)) + testutils.AssertExecAndRollback(t, deleteStmt, db, 2) requireQueryLogged(t, deleteStmt, int64(2)) } func TestDeleteWithWhereAndReturning(t *testing.T) { - initForDeleteTest(t) + deleteStmt := Link. + DELETE(). + WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))). + RETURNING(Link.AllColumns) - var expectedSQL = ` + testutils.AssertDebugStatementSql(t, deleteStmt, ` DELETE FROM test_sample.link -WHERE link.name IN ('Gmail', 'Outlook') +WHERE link.name IN ('Gmail'::text, 'Outlook'::text) RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; -` - deleteStmt := Link. - DELETE(). - WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))). - RETURNING(Link.AllColumns) - - testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") +`, "Gmail", "Outlook") - dest := []model.Link{} + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []model.Link - err := deleteStmt.Query(db, &dest) + err := deleteStmt.Query(tx, &dest) - require.NoError(t, err) + require.NoError(t, err) - require.Equal(t, len(dest), 2) - testutils.AssertDeepEqual(t, dest[0].Name, "Gmail") - testutils.AssertDeepEqual(t, dest[1].Name, "Outlook") - requireLogged(t, deleteStmt) -} - -func initForDeleteTest(t *testing.T) { - cleanUpLinkTable(t) - stmt := Link.INSERT(Link.URL, Link.Name, Link.Description). - VALUES("www.gmail.com", "Gmail", "Email service developed by Google"). - VALUES("www.outlook.live.com", "Outlook", "Email service developed by Microsoft") - - AssertExec(t, stmt, 2) + require.Equal(t, len(dest), 2) + testutils.AssertDeepEqual(t, dest[0].Name, "Gmail") + testutils.AssertDeepEqual(t, dest[1].Name, "Outlook") + requireLogged(t, deleteStmt) + }) } func TestDeleteQueryContext(t *testing.T) { - initForDeleteTest(t) - deleteStmt := Link. DELETE(). WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) @@ -86,16 +67,16 @@ func TestDeleteQueryContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} - err := deleteStmt.QueryContext(ctx, db, &dest) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + dest := []model.Link{} + err := deleteStmt.QueryContext(ctx, tx, &dest) - require.Error(t, err, "context deadline exceeded") - requireLogged(t, deleteStmt) + require.Error(t, err, "context deadline exceeded") + requireLogged(t, deleteStmt) + }) } func TestDeleteExecContext(t *testing.T) { - initForDeleteTest(t) - list := []Expression{String("Gmail"), String("Outlook")} deleteStmt := Link. @@ -107,15 +88,16 @@ func TestDeleteExecContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - _, err := deleteStmt.ExecContext(ctx, db) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := deleteStmt.ExecContext(ctx, tx) - require.Error(t, err, "context deadline exceeded") - requireLogged(t, deleteStmt) + require.Error(t, err, "context deadline exceeded") + requireLogged(t, deleteStmt) + }) } func TestDeleteFrom(t *testing.T) { - tx := beginTx(t) - defer tx.Rollback() + skipForCockroachDB(t) // USING is not supported stmt := table.Rental.DELETE(). USING( @@ -158,16 +140,17 @@ RETURNING rental.rental_id AS "rental.rental_id", store.last_update AS "store.last_update"; `) - var dest []struct { - Rental model2.Rental - Store model2.Store - } + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []struct { + Rental model2.Rental + Store model2.Store + } - err := stmt.Query(tx, &dest) + err := stmt.Query(tx, &dest) - require.NoError(t, err) - require.Len(t, dest, 3) - testutils.AssertJSON(t, dest[0], ` + require.NoError(t, err) + require.Len(t, dest, 3) + testutils.AssertJSON(t, dest[0], ` { "Rental": { "RentalID": 4, @@ -186,4 +169,5 @@ RETURNING rental.rental_id AS "rental.rental_id", } } `) + }) } diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 852745bd..328de179 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -93,56 +93,88 @@ func TestCmdGenerator(t *testing.T) { } func TestGeneratorIgnoreTables(t *testing.T) { - err := os.RemoveAll(genTestDir2) - require.NoError(t, err) + tests := []struct { + name string + args []string + }{ + { + name: "with dsn", + args: []string{ + "-dsn=" + fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable", + dbconfig.PgUser, + dbconfig.PgPassword, + dbconfig.PgHost, + dbconfig.PgPort, + "jetdb", + ), + "-schema=dvds", + "-ignore-tables=actor,ADDRESS,country, Film , cITY,", + "-ignore-views=Actor_info, FILM_LIST ,staff_list", + "-ignore-enums=mpaa_rating", + "-path=" + genTestDir2, + }, + }, + { + name: "without dsn", + args: []string{ + "-source=PostgreSQL", + "-host=localhost", + "-port=" + strconv.Itoa(dbconfig.PgPort), + "-user=jet", + "-password=jet", + "-dbname=jetdb", + "-schema=dvds", + "-ignore-tables=actor,ADDRESS,country, Film , cITY,", + "-ignore-views=Actor_info, FILM_LIST ,staff_list", + "-ignore-enums=mpaa_rating", + "-path=" + genTestDir2, + }, + }, + } - cmd := exec.Command("jet", - "-source=PostgreSQL", - "-host=localhost", - "-port="+strconv.Itoa(dbconfig.PgPort), - "-user=jet", - "-password=jet", - "-dbname=jetdb", - "-schema=dvds", - "-ignore-tables=actor,ADDRESS,country, Film , cITY,", - "-ignore-views=Actor_info, FILM_LIST ,staff_list", - "-ignore-enums=mpaa_rating", - "-path="+genTestDir2) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := os.RemoveAll(genTestDir2) + require.NoError(t, err) - fmt.Println(cmd.Args) - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stdout + cmd := exec.Command("jet", tt.args...) - err = cmd.Run() - require.NoError(t, err) + fmt.Println(cmd.Args) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout - // Table SQL Builder files - tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table") - require.NoError(t, err) + err = cmd.Run() + require.NoError(t, err) - testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "category.go", - "customer.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", - "payment.go", "rental.go", "staff.go", "store.go") + // Table SQL Builder files + tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table") + require.NoError(t, err) - // View SQL Builder files - viewSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/view") - require.NoError(t, err) + testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "category.go", + "customer.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", + "payment.go", "rental.go", "staff.go", "store.go") - testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "nicer_but_slower_film_list.go", - "sales_by_film_category.go", "customer_list.go", "sales_by_store.go") + // View SQL Builder files + viewSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/view") + require.NoError(t, err) - // Enums SQL Builder files - _, err = ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum") - require.Error(t, err, "open ./.gentestdata2/jetdb/dvds/enum: no such file or directory") + testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "nicer_but_slower_film_list.go", + "sales_by_film_category.go", "customer_list.go", "sales_by_store.go") - modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model") - require.NoError(t, err) + // Enums SQL Builder files + _, err = ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum") + require.Error(t, err, "open ./.gentestdata2/jetdb/dvds/enum: no such file or directory") - testutils.AssertFileNamesEqual(t, modelFiles, "category.go", - "customer.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", - "payment.go", "rental.go", "staff.go", "store.go", - "nicer_but_slower_film_list.go", "sales_by_film_category.go", - "customer_list.go", "sales_by_store.go") + modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model") + require.NoError(t, err) + + testutils.AssertFileNamesEqual(t, modelFiles, "category.go", + "customer.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", + "payment.go", "rental.go", "staff.go", "store.go", + "nicer_but_slower_film_list.go", "sales_by_film_category.go", + "customer_list.go", "sales_by_store.go") + }) + } } func TestGenerator(t *testing.T) { @@ -313,6 +345,16 @@ func (a ActorTable) FromSchema(schemaName string) *ActorTable { return newActorTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new ActorTable with assigned table prefix +func (a ActorTable) WithPrefix(prefix string) *ActorTable { + return newActorTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new ActorTable with assigned table suffix +func (a ActorTable) WithSuffix(suffix string) *ActorTable { + return newActorTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newActorTable(schemaName, tableName, alias string) *ActorTable { return &ActorTable{ actorTable: newActorTableImpl(schemaName, tableName, alias), @@ -412,6 +454,16 @@ func (a ActorInfoTable) FromSchema(schemaName string) *ActorInfoTable { return newActorInfoTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new ActorInfoTable with assigned table prefix +func (a ActorInfoTable) WithPrefix(prefix string) *ActorInfoTable { + return newActorInfoTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new ActorInfoTable with assigned table suffix +func (a ActorInfoTable) WithSuffix(suffix string) *ActorInfoTable { + return newActorInfoTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newActorInfoTable(schemaName, tableName, alias string) *ActorInfoTable { return &ActorInfoTable{ actorInfoTable: newActorInfoTableImpl(schemaName, tableName, alias), @@ -445,6 +497,8 @@ func newActorInfoTableImpl(schemaName, tableName, alias string) actorInfoTable { ` func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { + skipForCockroachDB(t) + enumDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/enum/") modelDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/model/") tableDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/table/") @@ -705,6 +759,16 @@ func (a AllTypesTable) FromSchema(schemaName string) *AllTypesTable { return newAllTypesTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new AllTypesTable with assigned table prefix +func (a AllTypesTable) WithPrefix(prefix string) *AllTypesTable { + return newAllTypesTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new AllTypesTable with assigned table suffix +func (a AllTypesTable) WithSuffix(suffix string) *AllTypesTable { + return newAllTypesTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newAllTypesTable(schemaName, tableName, alias string) *AllTypesTable { return &AllTypesTable{ allTypesTable: newAllTypesTableImpl(schemaName, tableName, alias), diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index 8a50e025..e34405eb 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "database/sql" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" @@ -13,9 +14,13 @@ import ( ) func TestInsertValues(t *testing.T) { - cleanUpLinkTable(t) + insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + VALUES(101, "http://www.google.com", "Google", DEFAULT). + VALUES(102, "http://www.yahoo.com", "Yahoo", nil). + RETURNING(Link.AllColumns) - var expectedSQL = ` + testutils.AssertDebugStatementSql(t, insertQuery, ` INSERT INTO test_sample.link (id, url, name, description) VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), (101, 'http://www.google.com', 'Google', DEFAULT), @@ -24,76 +29,61 @@ RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; -` - insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). - VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). - VALUES(101, "http://www.google.com", "Google", DEFAULT). - VALUES(102, "http://www.yahoo.com", "Yahoo", nil). - RETURNING(Link.AllColumns) - - testutils.AssertDebugStatementSql(t, insertQuery, expectedSQL, +`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", 101, "http://www.google.com", "Google", 102, "http://www.yahoo.com", "Yahoo", nil) - insertedLinks := []model.Link{} - - err := insertQuery.Query(db, &insertedLinks) - - require.NoError(t, err) - - require.Equal(t, len(insertedLinks), 3) - - testutils.AssertDeepEqual(t, insertedLinks[0], model.Link{ - ID: 100, - URL: "http://www.postgresqltutorial.com", - Name: "PostgreSQL Tutorial", - }) - - testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ - ID: 101, - URL: "http://www.google.com", - Name: "Google", - }) - - testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ - ID: 102, - URL: "http://www.yahoo.com", - Name: "Yahoo", + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var insertedLinks []model.Link + + err := insertQuery.Query(tx, &insertedLinks) + + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 3) + testutils.AssertDeepEqual(t, insertedLinks[0], model.Link{ + ID: 100, + URL: "http://www.postgresqltutorial.com", + Name: "PostgreSQL Tutorial", + }) + testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ + ID: 101, + URL: "http://www.google.com", + Name: "Google", + }) + testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ + ID: 102, + URL: "http://www.yahoo.com", + Name: "Yahoo", + }) + + var allLinks []model.Link + + err = Link.SELECT(Link.AllColumns). + WHERE(Link.ID.BETWEEN(Int(100), Int(199))). + ORDER_BY(Link.ID). + Query(tx, &allLinks) + + require.NoError(t, err) + testutils.AssertDeepEqual(t, insertedLinks, allLinks) }) - - allLinks := []model.Link{} - - err = Link.SELECT(Link.AllColumns). - WHERE(Link.ID.GT_EQ(Int(100))). - ORDER_BY(Link.ID). - Query(db, &allLinks) - - require.NoError(t, err) - - testutils.AssertDeepEqual(t, insertedLinks, allLinks) } func TestInsertEmptyColumnList(t *testing.T) { - cleanUpLinkTable(t) - - expectedSQL := ` -INSERT INTO test_sample.link -VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); -` - stmt := Link.INSERT(). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO test_sample.link +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); +`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) requireLogged(t, stmt) } func TestInsertOnConflict(t *testing.T) { - t.Run("do nothing", func(t *testing.T) { employee := model.Employee{EmployeeID: rand.Int31()} @@ -108,11 +98,12 @@ VALUES ($1, $2, $3, $4, $5), ($6, $7, $8, $9, $10) ON CONFLICT (employee_id) DO NOTHING; `) - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) requireLogged(t, stmt) }) t.Run("on constraint do nothing", func(t *testing.T) { + skipForCockroachDB(t) // does not support employee := model.Employee{EmployeeID: rand.Int31()} stmt := Employee.INSERT(Employee.AllColumns). @@ -126,12 +117,11 @@ VALUES ($1, $2, $3, $4, $5), ($6, $7, $8, $9, $10) ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; `) - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) requireLogged(t, stmt) }) t.Run("do update", func(t *testing.T) { - cleanUpLinkTable(t) stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). @@ -148,18 +138,19 @@ VALUES ($1, $2, $3, DEFAULT), ($4, $5, $6, DEFAULT) ON CONFLICT (id) DO UPDATE SET id = excluded.id, - url = $7 + url = $7::text RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; `) - AssertExec(t, stmt, 2) + testutils.AssertExecAndRollback(t, stmt, db, 2) }) t.Run("on constraint do update", func(t *testing.T) { - cleanUpLinkTable(t) + skipForCockroachDB(t) // does not support + stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). @@ -177,18 +168,18 @@ VALUES ($1, $2, $3, DEFAULT), ($4, $5, $6, DEFAULT) ON CONFLICT ON CONSTRAINT link_pkey DO UPDATE SET id = excluded.id, - url = $7 + url = $7::text RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; `) - AssertExec(t, stmt, 2) + testutils.AssertExecAndRollback(t, stmt, db, 2) }) t.Run("do update complex", func(t *testing.T) { - cleanUpLinkTable(t) + skipForCockroachDB(t) // does not support ROW stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). @@ -210,21 +201,15 @@ ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE SELECT MAX(link.id) + 1 FROM test_sample.link ), - (name, description) = ROW(excluded.name, 'new description') + (name, description) = ROW(excluded.name, 'new description'::text) WHERE link.description IS NOT NULL; `) - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) }) } func TestInsertModelObject(t *testing.T) { - cleanUpLinkTable(t) - var expectedSQL = ` -INSERT INTO test_sample.link (url, name) -VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); -` - linkData := model.Link{ URL: "http://www.duckduckgo.com", Name: "Duck Duck go", @@ -234,18 +219,15 @@ VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); INSERT(Link.URL, Link.Name). MODEL(linkData) - testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go") + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); +`, "http://www.duckduckgo.com", "Duck Duck go") - AssertExec(t, query, 1) + testutils.AssertExecAndRollback(t, query, db, 1) } func TestInsertModelObjectEmptyColumnList(t *testing.T) { - cleanUpLinkTable(t) - var expectedSQL = ` -INSERT INTO test_sample.link -VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); -` - linkData := model.Link{ ID: 1000, URL: "http://www.duckduckgo.com", @@ -256,19 +238,16 @@ VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); INSERT(). MODEL(linkData) - testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO test_sample.link +VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); +`, + int64(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) - AssertExec(t, query, 1) + testutils.AssertExecAndRollback(t, query, db, 1) } func TestInsertModelsObject(t *testing.T) { - expectedSQL := ` -INSERT INTO test_sample.link (url, name) -VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), - ('http://www.google.com', 'Google'), - ('http://www.yahoo.com', 'Yahoo'); -` - tutorial := model.Link{ URL: "http://www.postgresqltutorial.com", Name: "PostgreSQL Tutorial", @@ -288,23 +267,20 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), INSERT(Link.URL, Link.Name). MODELS([]model.Link{tutorial, google, yahoo}) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), + ('http://www.google.com', 'Google'), + ('http://www.yahoo.com', 'Yahoo'); +`, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.google.com", "Google", "http://www.yahoo.com", "Yahoo") - AssertExec(t, stmt, 3) + testutils.AssertExecAndRollback(t, stmt, db, 3) } func TestInsertUsingMutableColumns(t *testing.T) { - var expectedSQL = ` -INSERT INTO test_sample.link (url, name, description) -VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - ('http://www.google.com', 'Google', NULL), - ('http://www.google.com', 'Google', NULL), - ('http://www.yahoo.com', 'Yahoo', NULL); -` - google := model.Link{ URL: "http://www.google.com", Name: "Google", @@ -321,22 +297,32 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), MODEL(google). MODELS([]model.Link{google, yahoo}) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO test_sample.link (url, name, description) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + ('http://www.google.com', 'Google', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.yahoo.com', 'Yahoo', NULL); +`, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil, "http://www.yahoo.com", "Yahoo", nil) - AssertExec(t, stmt, 4) + testutils.AssertExecAndRollback(t, stmt, db, 4) } func TestInsertQuery(t *testing.T) { - _, err := Link.DELETE(). - WHERE(Link.ID.NOT_EQ(Int(0)).AND(Link.Name.EQ(String("Youtube")))). - Exec(db) - require.NoError(t, err) + query := Link. + INSERT(Link.URL, Link.Name). + QUERY( + SELECT(Link.URL, Link.Name). + FROM(Link). + WHERE(Link.ID.EQ(Int(0))), + ). + RETURNING(Link.AllColumns) - var expectedSQL = ` + testutils.AssertDebugStatementSql(t, query, ` INSERT INTO test_sample.link (url, name) ( SELECT link.url AS "link.url", link.name AS "link.name" @@ -347,38 +333,26 @@ RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; -` +`, int64(0)) - query := Link. - INSERT(Link.URL, Link.Name). - QUERY( - SELECT(Link.URL, Link.Name). - FROM(Link). - WHERE(Link.ID.EQ(Int(0))), - ). - RETURNING(Link.AllColumns) - - testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(0)) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []model.Link - dest := []model.Link{} + err := query.Query(tx, &dest) + require.NoError(t, err) - err = query.Query(db, &dest) + var youtubeLinks []model.Link + err = Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Youtube"))). + Query(tx, &youtubeLinks) - require.NoError(t, err) - - youtubeLinks := []model.Link{} - err = Link. - SELECT(Link.AllColumns). - WHERE(Link.Name.EQ(String("Youtube"))). - Query(db, &youtubeLinks) - - require.NoError(t, err) - require.Equal(t, len(youtubeLinks), 2) + require.NoError(t, err) + require.Equal(t, len(youtubeLinks), 2) + }) } func TestInsertWithQueryContext(t *testing.T) { - cleanUpLinkTable(t) - stmt := Link.INSERT(). VALUES(1100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). RETURNING(Link.AllColumns) @@ -388,15 +362,15 @@ func TestInsertWithQueryContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} - err := stmt.QueryContext(ctx, db, &dest) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + dest := []model.Link{} + err := stmt.QueryContext(ctx, tx, &dest) - require.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") + }) } func TestInsertWithExecContext(t *testing.T) { - cleanUpLinkTable(t) - stmt := Link.INSERT(). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) @@ -405,7 +379,7 @@ func TestInsertWithExecContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - _, err := stmt.ExecContext(ctx, db) - - require.Error(t, err, "context deadline exceeded") + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + testutils.AssertExecContextErr(ctx, t, stmt, tx, "context deadline exceeded") + }) } diff --git a/tests/postgres/lock_test.go b/tests/postgres/lock_test.go index c0286290..4caed101 100644 --- a/tests/postgres/lock_test.go +++ b/tests/postgres/lock_test.go @@ -12,6 +12,8 @@ import ( ) func TestLockTable(t *testing.T) { + skipForCockroachDB(t) // doesn't support + expectedSQL := ` LOCK TABLE dvds.address IN` @@ -62,6 +64,8 @@ LOCK TABLE dvds.address IN` } func TestLockExecContext(t *testing.T) { + skipForCockroachDB(t) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) defer cancel() diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index aa05e0fa..08af67a0 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -25,6 +25,24 @@ import ( var db *sql.DB var testRoot string +var source string + +const CockroachDB = "COCKROACH_DB" + +func init() { + source = os.Getenv("PG_SOURCE") +} + +func sourceIsCockroachDB() bool { + return source == CockroachDB +} + +func skipForCockroachDB(t *testing.T) { + if sourceIsCockroachDB() { + t.SkipNow() + } +} + func TestMain(m *testing.M) { rand.Seed(time.Now().Unix()) defer profile.Start().Stop() @@ -35,8 +53,15 @@ func TestMain(m *testing.M) { fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName) func() { + + connectionString := dbconfig.PostgresConnectString + + if sourceIsCockroachDB() { + connectionString = dbconfig.CockroachConnectString + } + var err error - db, err = sql.Open(driverName, dbconfig.PostgresConnectString) + db, err = sql.Open(driverName, connectionString) if err != nil { fmt.Println(err.Error()) panic("Failed to connect to test db") @@ -113,9 +138,3 @@ func isPgxDriver() bool { return false } - -func beginTx(t *testing.T) *sql.Tx { - tx, err := db.Begin() - require.NoError(t, err) - return tx -} diff --git a/tests/postgres/raw_statements_test.go b/tests/postgres/raw_statements_test.go index 4bbf90c5..e201c750 100644 --- a/tests/postgres/raw_statements_test.go +++ b/tests/postgres/raw_statements_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "database/sql" "testing" "time" @@ -85,12 +86,10 @@ func TestRawStatementSelectWithArguments(t *testing.T) { } func TestRawInsert(t *testing.T) { - cleanUpLinkTable(t) - stmt := RawStatement(` INSERT INTO test_sample.link (id, url, name, description) VALUES (@id1, @url1, @name1, DEFAULT), - (200, @url1, @name1, NULL), + (2000, @url1, @name1, NULL), (@id2, @url2, @name2, DEFAULT), (@id3, @url3, @name3, NULL) RETURNING link.id AS "link.id", @@ -98,45 +97,47 @@ RETURNING link.id AS "link.id", link.name AS "link.name", link.description AS "link.description"`, RawArgs{ - "@id1": 100, "@url1": "http://www.postgresqltutorial.com", "@name1": "PostgreSQL Tutorial", - "@id2": 101, "@url2": "http://www.google.com", "@name2": "Google", - "@id3": 102, "@url3": "http://www.yahoo.com", "@name3": "Yahoo", + "@id1": 1000, "@url1": "http://www.postgresqltutorial.com", "@name1": "PostgreSQL Tutorial", + "@id2": 1010, "@url2": "http://www.google.com", "@name2": "Google", + "@id3": 1020, "@url3": "http://www.yahoo.com", "@name3": "Yahoo", }) testutils.AssertStatementSql(t, stmt, ` INSERT INTO test_sample.link (id, url, name, description) VALUES ($1, $2, $3, DEFAULT), - (200, $2, $3, NULL), + (2000, $2, $3, NULL), ($4, $5, $6, DEFAULT), ($7, $8, $9, NULL) RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; -`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", - 101, "http://www.google.com", "Google", - 102, "http://www.yahoo.com", "Yahoo") +`, 1000, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", + 1010, "http://www.google.com", "Google", + 1020, "http://www.yahoo.com", "Yahoo") testutils.AssertDebugStatementSql(t, stmt, ` INSERT INTO test_sample.link (id, url, name, description) -VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - (200, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), - (101, 'http://www.google.com', 'Google', DEFAULT), - (102, 'http://www.yahoo.com', 'Yahoo', NULL) +VALUES (1000, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + (2000, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), + (1010, 'http://www.google.com', 'Google', DEFAULT), + (1020, 'http://www.yahoo.com', 'Yahoo', NULL) RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; `) - var links []model2.Link - err := stmt.Query(db, &links) - require.NoError(t, err) - require.Len(t, links, 4) - require.Equal(t, links[0].ID, int32(100)) - require.Equal(t, links[1].URL, "http://www.postgresqltutorial.com") - require.Equal(t, links[2].Name, "Google") - require.Nil(t, links[2].Description) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var links []model2.Link + err := stmt.Query(tx, &links) + require.NoError(t, err) + require.Len(t, links, 4) + require.Equal(t, links[0].ID, int64(1000)) + require.Equal(t, links[1].URL, "http://www.postgresqltutorial.com") + require.Equal(t, links[2].Name, "Google") + require.Nil(t, links[2].Description) + }) } func TestRawStatementRows(t *testing.T) { diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index f1d99992..a13a30b7 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -1,9 +1,9 @@ package postgres import ( + "github.com/google/uuid" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/go-jet/jet/v2/internal/testutils" @@ -14,30 +14,6 @@ import ( "github.com/shopspring/decimal" ) -func TestUUIDType(t *testing.T) { - - id := uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") - - query := AllTypes. - SELECT(AllTypes.UUID, AllTypes.UUIDPtr). - WHERE(AllTypes.UUID.EQ(UUID(id))) - - testutils.AssertDebugStatementSql(t, query, ` -SELECT all_types.uuid AS "all_types.uuid", - all_types.uuid_ptr AS "all_types.uuid_ptr" -FROM test_sample.all_types -WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; -`, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") - - result := model.AllTypes{} - - err := query.Query(db, &result) - require.NoError(t, err) - require.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) - testutils.AssertDeepEqual(t, result.UUIDPtr, testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) - requireLogged(t, query) -} - func TestExactDecimals(t *testing.T) { type floats struct { @@ -80,7 +56,7 @@ func TestExactDecimals(t *testing.T) { t.Run("should insert decimal", func(t *testing.T) { insertQuery := Floats.INSERT( - Floats.AllColumns, + Floats.MutableColumns, ).MODEL( floats{ Floats: model.Floats{ @@ -102,7 +78,7 @@ func TestExactDecimals(t *testing.T) { DecimalPtr: decimal.RequireFromString("3.3333333333333333333"), }, ).RETURNING( - Floats.AllColumns, + Floats.MutableColumns, ) testutils.AssertDebugStatementSql(t, insertQuery, ` @@ -199,7 +175,9 @@ func TestUUIDComplex(t *testing.T) { }) t.Run("single struct", func(t *testing.T) { - singleQuery := query.WHERE(Person.PersonID.EQ(String("b68dbff6-a87d-11e9-a7f2-98ded00c39c8"))) + uuid, err := uuid.Parse("b68dbff6-a87d-11e9-a7f2-98ded00c39c8") + require.NoError(t, err) + singleQuery := query.WHERE(Person.PersonID.EQ(UUID(uuid))) var dest struct { model.Person @@ -207,7 +185,7 @@ func TestUUIDComplex(t *testing.T) { model.PersonPhone } } - err := singleQuery.Query(db, &dest) + err = singleQuery.Query(db, &dest) require.NoError(t, err) testutils.AssertJSON(t, dest, ` @@ -304,7 +282,7 @@ SELECT person.person_id AS "person.person_id", FROM test_sample.person; `) - result := []model.Person{} + var result []model.Person err := query.Query(db, &result) @@ -333,7 +311,7 @@ FROM test_sample.person; `) } -func TestSelecSelfJoin1(t *testing.T) { +func TestSelectSelfJoin1(t *testing.T) { // clean up _, err := Employee.DELETE().WHERE(Employee.EmployeeID.GT(Int(100))).Exec(db) @@ -398,7 +376,7 @@ ORDER BY employee.employee_id; } func TestWierdNamesTable(t *testing.T) { - stmt := WeirdNamesTable.SELECT(WeirdNamesTable.AllColumns) + stmt := WeirdNamesTable.SELECT(WeirdNamesTable.MutableColumns) testutils.AssertDebugStatementSql(t, stmt, ` SELECT "WEIRD NAMES TABLE".weird_column_name1 AS "WEIRD NAMES TABLE.weird_column_name1", @@ -420,7 +398,7 @@ SELECT "WEIRD NAMES TABLE".weird_column_name1 AS "WEIRD NAMES TABLE.weird_column FROM test_sample."WEIRD NAMES TABLE"; `) - dest := []model.WeirdNamesTable{} + var dest []model.WeirdNamesTable err := stmt.Query(db, &dest) @@ -448,7 +426,7 @@ FROM test_sample."WEIRD NAMES TABLE"; } func TestReserwedWordEscape(t *testing.T) { - stmt := SELECT(User.AllColumns). + stmt := SELECT(User.MutableColumns). FROM(User) //fmt.Println(stmt.DebugSql()) @@ -480,6 +458,7 @@ FROM test_sample."User"; testutils.AssertJSON(t, dest, ` [ { + "ID": 0, "Column": "Column", "Check": "CHECK", "Ceil": "CEIL", @@ -497,54 +476,3 @@ FROM test_sample."User"; ] `) } - -func TestBytea(t *testing.T) { - byteArrHex := "\\x48656c6c6f20476f7068657221" - byteArrBin := []byte("\x48\x65\x6c\x6c\x6f\x20\x47\x6f\x70\x68\x65\x72\x21") - - insertStmt := AllTypes.INSERT(AllTypes.Bytea, AllTypes.ByteaPtr). - VALUES(byteArrHex, byteArrBin). - RETURNING(AllTypes.Bytea, AllTypes.ByteaPtr) - - testutils.AssertStatementSql(t, insertStmt, ` -INSERT INTO test_sample.all_types (bytea, bytea_ptr) -VALUES ($1, $2) -RETURNING all_types.bytea AS "all_types.bytea", - all_types.bytea_ptr AS "all_types.bytea_ptr"; -`, byteArrHex, byteArrBin) - - var inserted model.AllTypes - err := insertStmt.Query(db, &inserted) - require.NoError(t, err) - - require.Equal(t, string(*inserted.ByteaPtr), "Hello Gopher!") - // It is not possible to initiate bytea column using hex format '\xDEADBEEF' with pq driver. - // pq driver always encodes parameter string if destination column is of type bytea. - // Probably pq driver error. - // require.Equal(t, string(inserted.Bytea), "Hello Gopher!") - - stmt := SELECT( - AllTypes.Bytea, - AllTypes.ByteaPtr, - ).FROM( - AllTypes, - ).WHERE( - AllTypes.ByteaPtr.EQ(Bytea(byteArrBin)), - ) - - testutils.AssertStatementSql(t, stmt, ` -SELECT all_types.bytea AS "all_types.bytea", - all_types.bytea_ptr AS "all_types.bytea_ptr" -FROM test_sample.all_types -WHERE all_types.bytea_ptr = $1::bytea; -`, byteArrBin) - - var dest model.AllTypes - - err = stmt.Query(db, &dest) - require.NoError(t, err) - - require.Equal(t, string(*dest.ByteaPtr), "Hello Gopher!") - // Probably pq driver error. - // require.Equal(t, string(dest.Bytea), "Hello Gopher!") -} diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 30787090..24b5949d 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "github.com/volatiletech/null/v8" "testing" "time" @@ -1034,6 +1035,44 @@ func TestScanToPrimitiveElementsSlice(t *testing.T) { require.Len(t, dest[1].Title, 20) } +// https://github.com/go-jet/jet/issues/127 +func TestValuerTypeDebugSQL(t *testing.T) { + type customer struct { + CustomerID null.Int32 `sql:"primary_key"` + StoreID null.Int16 + FirstName null.String + LastName string + Email null.String + AddressID int16 + Activebool null.Bool + CreateDate null.Time + LastUpdate null.Time + Active null.Int8 + } + + stmt := Customer.INSERT(). + MODEL( + customer{ + CustomerID: null.Int32From(1234), + StoreID: null.Int16From(0), + FirstName: null.StringFrom("Joe"), + LastName: "", + Email: null.StringFromPtr(nil), + AddressID: 1, + Activebool: null.BoolFrom(true), + CreateDate: null.TimeFrom(time.Date(2020, 2, 2, 10, 0, 0, 0, time.UTC)), + LastUpdate: null.TimeFromPtr(nil), + Active: null.Int8From(1), + }, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO dvds.customer +VALUES (1234, 0, 'Joe', '', NULL, 1, TRUE, '2020-02-02 10:00:00Z', NULL, 1); +`) + testutils.AssertExecAndRollback(t, stmt, db) +} + var address256 = model.Address{ AddressID: 256, Address: "1497 Yuzhou Drive", diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 304acb8c..dce9b87b 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1,6 +1,8 @@ package postgres import ( + "context" + "github.com/go-jet/jet/v2/qrm" "testing" "time" @@ -24,9 +26,9 @@ FROM dvds.actor WHERE actor.actor_id = 2; ` - query := Actor. - SELECT(Actor.AllColumns). + query := SELECT(Actor.AllColumns). DISTINCT(). + FROM(Actor). WHERE(Actor.ActorID.EQ(Int(2))) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(2)) @@ -44,7 +46,6 @@ WHERE actor.actor_id = 2; } testutils.AssertDeepEqual(t, actor, expectedActor) - requireLogged(t, query) } @@ -166,7 +167,7 @@ SELECT customer.customer_id AS "customer.customer_id", FROM dvds.customer ORDER BY customer.customer_id ASC; ` - customers := []model.Customer{} + var customers []model.Customer query := Customer.SELECT(Customer.AllColumns).ORDER_BY(Customer.CustomerID.ASC()) @@ -416,8 +417,8 @@ FROM dvds.city INNER JOIN dvds.address ON (address.city_id = city.city_id) INNER JOIN dvds.customer ON (customer.address_id = address.address_id) WHERE ( - (city.city = 'London') - OR (city.city = 'York') + (city.city = 'London'::text) + OR (city.city = 'York'::text) ) ORDER BY city.city_id, address.address_id, customer.customer_id; `, "London", "York") @@ -492,7 +493,7 @@ SELECT city.city_id AS "my_city.id", FROM dvds.city INNER JOIN dvds.address ON (address.city_id = city.city_id) INNER JOIN dvds.customer ON (customer.address_id = address.address_id) -WHERE (city.city = 'London') OR (city.city = 'York') +WHERE (city.city = 'London'::text) OR (city.city = 'York'::text) ORDER BY city.city_id, address.address_id, customer.customer_id; `, "London", "York") @@ -550,7 +551,7 @@ SELECT city.city_id AS "city_id", FROM dvds.city INNER JOIN dvds.address ON (address.city_id = city.city_id) INNER JOIN dvds.customer ON (customer.address_id = address.address_id) -WHERE (city.city = 'London') OR (city.city = 'York') +WHERE (city.city = 'London'::text) OR (city.city = 'York'::text) ORDER BY city.city_id, address.address_id, customer.customer_id; `, "London", "York") @@ -607,7 +608,7 @@ SELECT city.city_id AS "city.city_id", FROM dvds.city INNER JOIN dvds.address ON (address.city_id = city.city_id) INNER JOIN dvds.customer ON (customer.address_id = address.address_id) -WHERE (city.city = 'London') OR (city.city = 'York') +WHERE (city.city = 'London'::text) OR (city.city = 'York'::text) ORDER BY city.city_id, address.address_id, customer.customer_id; `, "London", "York") @@ -685,9 +686,6 @@ func TestSelect_WithoutUniqueColumnSelected(t *testing.T) { err := query.Query(db, &customers) require.NoError(t, err) - - //spew.Dump(customers) - require.Equal(t, len(customers), 599) } @@ -770,27 +768,35 @@ ORDER BY customer.customer_id ASC; testutils.AssertDebugStatementSql(t, query, expectedSQL) - allCustomersAndAddress := []struct { + var allCustomersAndAddress []struct { Address *model.Address Customer *model.Customer - }{} + } err := query.Query(db, &allCustomersAndAddress) require.NoError(t, err) require.Equal(t, len(allCustomersAndAddress), 603) - testutils.AssertDeepEqual(t, allCustomersAndAddress[0].Customer, &customer0) - require.True(t, allCustomersAndAddress[0].Address != nil) + if sourceIsCockroachDB() { + nullsFirst := allCustomersAndAddress[0] + require.True(t, nullsFirst.Customer == nil) + require.True(t, nullsFirst.Address != nil) - lastCustomerAddress := allCustomersAndAddress[len(allCustomersAndAddress)-1] + testutils.AssertDeepEqual(t, allCustomersAndAddress[4].Customer, &customer0) + require.True(t, allCustomersAndAddress[0].Address != nil) + } else { // postgres + testutils.AssertDeepEqual(t, allCustomersAndAddress[0].Customer, &customer0) + require.True(t, allCustomersAndAddress[0].Address != nil) - require.True(t, lastCustomerAddress.Customer == nil) - require.True(t, lastCustomerAddress.Address != nil) + nullsLast := allCustomersAndAddress[len(allCustomersAndAddress)-1] + require.True(t, nullsLast.Customer == nil) + require.True(t, nullsLast.Address != nil) + } } -func TestSelectFullCrossJoin(t *testing.T) { +func TestSelectCrossJoin(t *testing.T) { expectedSQL := ` SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", @@ -1128,6 +1134,7 @@ ORDER BY film.film_id ASC; } func TestSelectGroupByHaving(t *testing.T) { + expectedSQL := ` SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", @@ -1197,6 +1204,9 @@ ORDER BY customer.customer_id, SUM(payment.amount) ASC; require.Equal(t, len(dest), 104) + if sourceIsCockroachDB() { + return // small precision difference in result + } //testutils.SaveJsonFile(dest, "postgres/testdata/customer_payment_sum.json") testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/customer_payment_sum.json") } @@ -1395,9 +1405,6 @@ ORDER BY payment.payment_date ASC; err := query.Query(db, &payments) require.NoError(t, err) - - //spew.Dump(payments) - require.Equal(t, len(payments), 9) testutils.AssertDeepEqual(t, payments[0], model.Payment{ PaymentID: 17793, @@ -1531,7 +1538,7 @@ func TestAllSetOperators(t *testing.T) { func TestSelectWithCase(t *testing.T) { expectedQuery := ` -SELECT (CASE payment.staff_id WHEN 1 THEN 'ONE' WHEN 2 THEN 'TWO' WHEN 3 THEN 'THREE' ELSE 'OTHER' END) AS "staff_id_num" +SELECT (CASE payment.staff_id WHEN 1 THEN 'ONE'::text WHEN 2 THEN 'TWO'::text WHEN 3 THEN 'THREE'::text ELSE 'OTHER'::text END) AS "staff_id_num" FROM dvds.payment ORDER BY payment.payment_id ASC LIMIT 20; @@ -1611,6 +1618,10 @@ FOR` require.NoError(t, err) } + if sourceIsCockroachDB() { + return // SKIP LOCKED lock wait policy is not supported + } + for lockType, lockTypeStr := range getRowLockTestData() { query.FOR(lockType.SKIP_LOCKED()) @@ -1660,7 +1671,7 @@ FROM dvds.actor INNER JOIN dvds.language ON (language.language_id = film.language_id) INNER JOIN dvds.film_category ON (film_category.film_id = film.film_id) INNER JOIN dvds.category ON (category.category_id = film_category.category_id) -WHERE ((language.name = 'English') AND (category.name != 'Action')) AND (film.length > 180) +WHERE ((language.name = 'English'::text) AND (category.name != 'Action'::text)) AND (film.length > 180) ORDER BY actor.actor_id ASC, film.film_id ASC; ` @@ -1927,10 +1938,11 @@ func TestSimpleView(t *testing.T) { query := SELECT( view.ActorInfo.AllColumns, - ). - FROM(view.ActorInfo). - ORDER_BY(view.ActorInfo.ActorID). - LIMIT(10) + ).FROM( + view.ActorInfo, + ).ORDER_BY( + view.ActorInfo.ActorID, + ).LIMIT(10) type ActorInfo struct { ActorID int @@ -1944,6 +1956,10 @@ func TestSimpleView(t *testing.T) { err := query.Query(db, &dest) require.NoError(t, err) + if sourceIsCockroachDB() { + return // skip for cockroach db, FilmInfo is set to '' in ddl + } + testutils.AssertJSON(t, dest[1:2], ` [ { @@ -2117,7 +2133,7 @@ FROM dvds.film language.name AS "language.name", language.last_update AS "language.last_update" FROM dvds.language - WHERE (language.name NOT IN ('spanish')) AND (film.language_id = language.language_id) + WHERE (language.name NOT IN ('spanish'::text)) AND (film.language_id = language.language_id) ) AS films WHERE film.film_id = 1 ORDER BY film.film_id @@ -2162,7 +2178,7 @@ FROM dvds.film, language.name AS "language.name", language.last_update AS "language.last_update" FROM dvds.language - WHERE (language.name NOT IN ('spanish')) AND (film.language_id = language.language_id) + WHERE (language.name NOT IN ('spanish'::text)) AND (film.language_id = language.language_id) ) AS films WHERE film.film_id = 1 ORDER BY film.film_id @@ -2630,6 +2646,8 @@ func GET_FILM_COUNT(lenFrom, lenTo IntegerExpression) IntegerExpression { } func TestCustomFunctionCall(t *testing.T) { + skipForCockroachDB(t) + stmt := SELECT( GET_FILM_COUNT(Int(100), Int(120)).AS("film_count"), ) @@ -2662,3 +2680,84 @@ SELECT dvds.get_film_count(100, 120) AS "film_count"; require.NoError(t, err) require.Equal(t, dest.FilmCount, 165) } + +func TestScanUsingConn(t *testing.T) { + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + defer conn.Close() + + stmt := SELECT(Actor.AllColumns). + FROM(Actor). + DISTINCT(). + WHERE(Actor.ActorID.EQ(Int(2))) + + var actor model.Actor + err = stmt.Query(conn, &actor) + require.NoError(t, err) + err = stmt.QueryContext(context.Background(), conn, &actor) + require.NoError(t, err) + testutils.AssertDeepEqual(t, actor, model.Actor{ + ActorID: 2, + FirstName: "Nick", + LastName: "Wahlberg", + LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2), + }) + + _, err = stmt.Exec(conn) + require.NoError(t, err) + _, err = stmt.ExecContext(context.Background(), conn) + require.NoError(t, err) + + t.Run("ensure qrm.DB still works", func(t *testing.T) { + var qrmDB qrm.DB = db + + err = stmt.Query(qrmDB, &actor) + require.NoError(t, err) + err = stmt.QueryContext(context.Background(), qrmDB, &actor) + require.NoError(t, err) + + _, err = stmt.Exec(qrmDB) + require.NoError(t, err) + _, err = stmt.ExecContext(context.Background(), qrmDB) + require.NoError(t, err) + }) +} + +var customer0 = model.Customer{ + CustomerID: 1, + StoreID: 1, + FirstName: "Mary", + LastName: "Smith", + Email: testutils.StringPtr("mary.smith@sakilacustomer.org"), + AddressID: 5, + Activebool: true, + CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), + Active: testutils.Int32Ptr(1), +} + +var customer1 = model.Customer{ + CustomerID: 2, + StoreID: 1, + FirstName: "Patricia", + LastName: "Johnson", + Email: testutils.StringPtr("patricia.johnson@sakilacustomer.org"), + AddressID: 6, + Activebool: true, + CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), + Active: testutils.Int32Ptr(1), +} + +var lastCustomer = model.Customer{ + CustomerID: 599, + StoreID: 2, + FirstName: "Austin", + LastName: "Cintron", + Email: testutils.StringPtr("austin.cintron@sakilacustomer.org"), + AddressID: 605, + Activebool: true, + CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), + Active: testutils.Int32Ptr(1), +} diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 6cde276b..e0c7da2a 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "database/sql" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" @@ -14,9 +15,7 @@ import ( ) func TestUpdateValues(t *testing.T) { - setupLinkTableForUpdateTest(t) - - t.Run("deprecated version", func(t *testing.T) { + t.Run("deprecated update", func(t *testing.T) { query := Link. UPDATE(Link.Name, Link.URL). SET("Bong", "http://bong.com"). @@ -25,31 +24,34 @@ func TestUpdateValues(t *testing.T) { testutils.AssertDebugStatementSql(t, query, ` UPDATE test_sample.link SET (name, url) = ('Bong', 'http://bong.com') -WHERE link.name = 'Bing'; +WHERE link.name = 'Bing'::text; `, "Bong", "http://bong.com", "Bing") - testutils.AssertExec(t, query, db, 1) - requireLogged(t, query) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { - links := []model.Link{} + testutils.AssertExec(t, query, tx, 1) + requireLogged(t, query) - selQuery := Link. - SELECT(Link.AllColumns). - WHERE(Link.Name.IN(String("Bong"))) + var links []model.Link - err := selQuery.Query(db, &links) + selQuery := Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.IN(String("Bong"))) - require.NoError(t, err) - require.Equal(t, len(links), 1) - testutils.AssertDeepEqual(t, links[0], model.Link{ - ID: 204, - URL: "http://bong.com", - Name: "Bong", + err := selQuery.Query(tx, &links) + + require.NoError(t, err) + require.Equal(t, len(links), 1) + testutils.AssertDeepEqual(t, links[0], model.Link{ + ID: 204, + URL: "http://bong.com", + Name: "Bong", + }) + requireLogged(t, selQuery) }) - requireLogged(t, selQuery) }) - t.Run("new version", func(t *testing.T) { + t.Run("new type safe update", func(t *testing.T) { stmt := Link.UPDATE(). SET( Link.Name.SET(String("DuckDuckGo")), @@ -59,18 +61,16 @@ WHERE link.name = 'Bing'; testutils.AssertDebugStatementSql(t, stmt, ` UPDATE test_sample.link -SET name = 'DuckDuckGo', - url = 'www.duckduckgo.com' -WHERE link.name = 'Yahoo'; +SET name = 'DuckDuckGo'::text, + url = 'www.duckduckgo.com'::text +WHERE link.name = 'Yahoo'::text; `) - testutils.AssertExec(t, stmt, db, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) requireLogged(t, stmt) }) } func TestUpdateWithSubQueries(t *testing.T) { - setupLinkTableForUpdateTest(t) - t.Run("deprecated version", func(t *testing.T) { query := Link. UPDATE(Link.Name, Link.URL). @@ -82,20 +82,19 @@ func TestUpdateWithSubQueries(t *testing.T) { ). WHERE(Link.Name.EQ(String("Bing"))) - expectedSQL := ` + testutils.AssertDebugStatementSql(t, query, ` UPDATE test_sample.link SET (name, url) = (( - SELECT 'Bong' + SELECT 'Bong'::text ), ( SELECT link.url AS "link.url" FROM test_sample.link - WHERE link.name = 'Bing' + WHERE link.name = 'Bing'::text )) -WHERE link.name = 'Bing'; -` - testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing") +WHERE link.name = 'Bing'::text; +`, "Bong", "Bing", "Bing") - AssertExec(t, query, 1) + testutils.AssertExecAndRollback(t, query, db, 1) requireLogged(t, query) }) @@ -113,50 +112,48 @@ WHERE link.name = 'Bing'; testutils.AssertStatementSql(t, query, ` UPDATE test_sample.link -SET name = $1, +SET name = $1::text, url = ( SELECT link.url AS "link.url" FROM test_sample.link - WHERE link.name = $2 + WHERE link.name = $2::text ) -WHERE link.name = $3; +WHERE link.name = $3::text; `, "Bong", "Bing", "Bing") - _, err := query.Exec(db) - require.NoError(t, err) + testutils.AssertExecAndRollback(t, query, db) requireLogged(t, query) }) } func TestUpdateAndReturning(t *testing.T) { - setupLinkTableForUpdateTest(t) + stmt := Link. + UPDATE(Link.Name, Link.URL). + SET("DuckDuckGo", "http://www.duckduckgo.com"). + WHERE(Link.Name.EQ(String("Ask"))). + RETURNING(Link.AllColumns) - expectedSQL := ` + testutils.AssertDebugStatementSql(t, stmt, ` UPDATE test_sample.link SET (name, url) = ('DuckDuckGo', 'http://www.duckduckgo.com') -WHERE link.name = 'Ask' +WHERE link.name = 'Ask'::text RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; -` +`, "DuckDuckGo", "http://www.duckduckgo.com", "Ask") - stmt := Link. - UPDATE(Link.Name, Link.URL). - SET("DuckDuckGo", "http://www.duckduckgo.com"). - WHERE(Link.Name.EQ(String("Ask"))). - RETURNING(Link.AllColumns) - - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "DuckDuckGo", "http://www.duckduckgo.com", "Ask") + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + links := []model.Link{} - links := []model.Link{} + err := stmt.Query(tx, &links) - err := stmt.Query(db, &links) + require.NoError(t, err) + require.Equal(t, len(links), 2) + require.Equal(t, links[0].Name, "DuckDuckGo") + require.Equal(t, links[1].Name, "DuckDuckGo") + requireLogged(t, stmt) + }) - require.NoError(t, err) - require.Equal(t, len(links), 2) - require.Equal(t, links[0].Name, "DuckDuckGo") - require.Equal(t, links[1].Name, "DuckDuckGo") - requireLogged(t, stmt) } func TestUpdateWithSelect(t *testing.T) { @@ -170,7 +167,7 @@ func TestUpdateWithSelect(t *testing.T) { ). WHERE(Link.ID.EQ(Int(0))) - expectedSQL := ` + testutils.AssertDebugStatementSql(t, stmt, ` UPDATE test_sample.link SET (id, url, name, description) = ( SELECT link.id AS "link.id", @@ -181,10 +178,9 @@ SET (id, url, name, description) = ( WHERE link.id = 0 ) WHERE link.id = 0; -` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) +`, int64(0), int64(0)) - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) }) t.Run("new version", func(t *testing.T) { @@ -210,12 +206,11 @@ SET (url, name, description) = ( WHERE link.id = 0; `, int64(0), int64(0)) - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) }) } func TestUpdateWithInvalidSelect(t *testing.T) { - t.Run("deprecated version", func(t *testing.T) { stmt := Link.UPDATE(Link.AllColumns). SET( @@ -236,7 +231,6 @@ SET (id, url, name, description) = ( WHERE link.id = 0; ` testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) - testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") }) @@ -250,8 +244,6 @@ WHERE link.id = 0; } func TestUpdateWithModelData(t *testing.T) { - setupLinkTableForUpdateTest(t) - link := model.Link{ ID: 201, URL: "http://www.duckduckgo.com", @@ -261,24 +253,20 @@ func TestUpdateWithModelData(t *testing.T) { stmt := Link. UPDATE(Link.AllColumns). MODEL(link). - WHERE(Link.ID.EQ(Int32(link.ID))) + WHERE(Link.ID.EQ(Int64(link.ID))) expectedSQL := ` UPDATE test_sample.link SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL) -WHERE link.id = 201::integer; +WHERE link.id = 201::bigint; ` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) - _, err := stmt.Exec(db) - require.NoError(t, err) + testutils.AssertExecAndRollback(t, stmt, db, 1) requireQueryLogged(t, stmt, 1) } func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { - - setupLinkTableForUpdateTest(t) - link := model.Link{ ID: 201, URL: "http://www.duckduckgo.com", @@ -290,27 +278,24 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { stmt := Link. UPDATE(updateColumnList). MODEL(link). - WHERE(Link.ID.EQ(Int32(link.ID))) + WHERE(Link.ID.EQ(Int64(link.ID))) - var expectedSQL = ` + testutils.AssertDebugStatementSql(t, stmt, ` UPDATE test_sample.link SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com') -WHERE link.id = 201::integer; -` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) +WHERE link.id = 201::bigint; +`, + nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201)) - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) } func TestUpdateWithInvalidModelData(t *testing.T) { defer func() { r := recover() - require.Equal(t, r, "missing struct field for column : id") }() - setupLinkTableForUpdateTest(t) - link := struct { Ident int URL string @@ -323,24 +308,13 @@ func TestUpdateWithInvalidModelData(t *testing.T) { Name: "DuckDuckGo", } - stmt := Link. + _ = Link. UPDATE(Link.AllColumns). - MODEL(link). + MODEL(link). // panics WHERE(Link.ID.EQ(Int(int64(link.Ident)))) - - var expectedSQL = ` -UPDATE test_sample.link -SET (id, url, name, description, rel) = ('http://www.duckduckgo.com', 'DuckDuckGo', NULL, NULL) -WHERE link.id = 201; -` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, nil, int64(201)) - - testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") } func TestUpdateQueryContext(t *testing.T) { - setupLinkTableForUpdateTest(t) - updateStmt := Link. UPDATE(Link.Name, Link.URL). SET("Bong", "http://bong.com"). @@ -351,15 +325,15 @@ func TestUpdateQueryContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} - err := updateStmt.QueryContext(ctx, db, &dest) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + dest := []model.Link{} + err := updateStmt.QueryContext(ctx, tx, &dest) - require.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") + }) } func TestUpdateExecContext(t *testing.T) { - setupLinkTableForUpdateTest(t) - updateStmt := Link. UPDATE(Link.Name, Link.URL). SET("Bong", "http://bong.com"). @@ -370,15 +344,10 @@ func TestUpdateExecContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - _, err := updateStmt.ExecContext(ctx, db) - - require.Error(t, err, "context deadline exceeded") + testutils.AssertExecContextErr(ctx, t, updateStmt, db, "context deadline exceeded") } func TestUpdateFrom(t *testing.T) { - tx := beginTx(t) - defer tx.Rollback() - stmt := table.Rental.UPDATE(). SET( table.Rental.RentalDate.SET(Timestamp(2020, 2, 2, 0, 0, 0)), @@ -416,16 +385,17 @@ RETURNING rental.rental_id AS "rental.rental_id", store.address_id AS "store.address_id"; `) - var dest []struct { - Rental model2.Rental - Store model2.Store - } + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []struct { + Rental model2.Rental + Store model2.Store + } - err := stmt.Query(tx, &dest) + err := stmt.Query(tx, &dest) - require.NoError(t, err) - require.Len(t, dest, 3) - testutils.AssertJSON(t, dest[0], ` + require.NoError(t, err) + require.Len(t, dest, 3) + testutils.AssertJSON(t, dest[0], ` { "Rental": { "RentalID": 4, @@ -444,24 +414,5 @@ RETURNING rental.rental_id AS "rental.rental_id", } } `) -} - -func setupLinkTableForUpdateTest(t *testing.T) { - - cleanUpLinkTable(t) - - _, err := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). - VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). - VALUES(201, "http://www.ask.com", "Ask", DEFAULT). - VALUES(202, "http://www.ask.com", "Ask", DEFAULT). - VALUES(203, "http://www.yahoo.com", "Yahoo", DEFAULT). - VALUES(204, "http://www.bing.com", "Bing", DEFAULT). - Exec(db) - - require.NoError(t, err) -} - -func cleanUpLinkTable(t *testing.T) { - _, err := Link.DELETE().WHERE(Link.ID.GT(Int(0))).Exec(db) - require.NoError(t, err) + }) } diff --git a/tests/postgres/util_test.go b/tests/postgres/util_test.go deleted file mode 100644 index 847056f3..00000000 --- a/tests/postgres/util_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package postgres - -import ( - "github.com/go-jet/jet/v2/internal/jet" - "github.com/go-jet/jet/v2/internal/testutils" - "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" - "github.com/stretchr/testify/require" - "testing" -) - -func AssertExec(t *testing.T, stmt jet.Statement, rowsAffected int64) { - res, err := stmt.Exec(db) - - require.NoError(t, err) - rows, err := res.RowsAffected() - require.NoError(t, err) - require.Equal(t, rows, rowsAffected) -} - -var customer0 = model.Customer{ - CustomerID: 1, - StoreID: 1, - FirstName: "Mary", - LastName: "Smith", - Email: testutils.StringPtr("mary.smith@sakilacustomer.org"), - AddressID: 5, - Activebool: true, - CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), - LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), -} - -var customer1 = model.Customer{ - CustomerID: 2, - StoreID: 1, - FirstName: "Patricia", - LastName: "Johnson", - Email: testutils.StringPtr("patricia.johnson@sakilacustomer.org"), - AddressID: 6, - Activebool: true, - CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), - LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), -} - -var lastCustomer = model.Customer{ - CustomerID: 599, - StoreID: 2, - FirstName: "Austin", - LastName: "Cintron", - Email: testutils.StringPtr("austin.cintron@sakilacustomer.org"), - AddressID: 605, - Activebool: true, - CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), - LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), -} diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go index c78ca8a4..21fca326 100644 --- a/tests/postgres/with_test.go +++ b/tests/postgres/with_test.go @@ -106,9 +106,11 @@ func TestWithStatementDeleteAndInsert(t *testing.T) { removeDiscontinuedOrders.AS( OrderDetails.DELETE(). WHERE(OrderDetails.ProductID.IN( - SELECT(Products.ProductID). - FROM(Products). - WHERE(Products.Discontinued.EQ(Int(1)))), + SELECT( + Products.ProductID, + ).FROM( + Products, + ).WHERE(Products.Discontinued.EQ(Int(1)))), ).RETURNING(OrderDetails.ProductID), ), updateDiscontinuedPrice.AS( @@ -121,7 +123,13 @@ func TestWithStatementDeleteAndInsert(t *testing.T) { ), logDiscontinuedProducts.AS( ProductLogs.INSERT(ProductLogs.AllColumns). - QUERY(SELECT(updateDiscontinuedPrice.AllColumns()).FROM(updateDiscontinuedPrice)). + QUERY( + SELECT( + updateDiscontinuedPrice.AllColumns(), + ).FROM( + updateDiscontinuedPrice, + ), + ). RETURNING( ProductLogs.ProductID, ProductLogs.ProductName, @@ -384,7 +392,7 @@ WITH cte1 AS ( SELECT territories.territory_id AS "territories.territory_id", territories.territory_description AS "territories.territory_description", territories.region_id AS "territories.region_id", - $1 AS "custom_column_1" + $1::text AS "custom_column_1" FROM northwind.territories ORDER BY territories.territory_id ASC ),cte2 AS ( @@ -392,7 +400,7 @@ WITH cte1 AS ( cte1."territories.territory_description" AS "territories.territory_description", cte1."territories.region_id" AS "territories.region_id", cte1.custom_column_1 AS "custom_column_1", - $2 AS "custom_column_2" + $2::text AS "custom_column_2" FROM cte1 ) SELECT cte2."territories.territory_id" AS "territories.territory_id", @@ -485,7 +493,7 @@ func TestRecursiveWithStatement(t *testing.T) { Employees, ).WHERE( Employees.EmployeeID.EQ(Int(2)), - ).UNION( + ).UNION_ALL( SELECT( Employees.AllColumns, ).FROM( @@ -790,13 +798,13 @@ WITH suppliers_fax AS ( suppliers_fax."suppliers.contact_name" AS "suppliers.contact_name", suppliers_fax."suppliers.country" AS "suppliers.country" FROM suppliers_fax - WHERE suppliers_fax."suppliers.country" NOT IN ('US', 'Australia') + WHERE suppliers_fax."suppliers.country" NOT IN ('US'::text, 'Australia'::text) ) SELECT not_from_us_or_aus."suppliers.supplier_id" AS "suppliers.supplier_id", not_from_us_or_aus."suppliers.contact_name" AS "suppliers.contact_name", not_from_us_or_aus."suppliers.country" AS "suppliers.country" FROM not_from_us_or_aus -WHERE not_from_us_or_aus."suppliers.contact_name" != 'John'; +WHERE not_from_us_or_aus."suppliers.contact_name" != 'John'::text; `) var dest []model.Suppliers diff --git a/tests/sqlite/generator_test.go b/tests/sqlite/generator_test.go index b8280fdc..96653870 100644 --- a/tests/sqlite/generator_test.go +++ b/tests/sqlite/generator_test.go @@ -1,16 +1,17 @@ package sqlite import ( - "github.com/go-jet/jet/v2/generator/sqlite" - "github.com/go-jet/jet/v2/internal/testutils" - "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" - "github.com/go-jet/jet/v2/tests/internal/utils/repo" - "github.com/stretchr/testify/require" "io/ioutil" "os" "os/exec" "reflect" "testing" + + "github.com/go-jet/jet/v2/generator/sqlite" + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" + "github.com/go-jet/jet/v2/tests/internal/utils/repo" + "github.com/stretchr/testify/require" ) func TestGeneratedModel(t *testing.T) { @@ -183,6 +184,16 @@ func (a ActorTable) FromSchema(schemaName string) *ActorTable { return newActorTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new ActorTable with assigned table prefix +func (a ActorTable) WithPrefix(prefix string) *ActorTable { + return newActorTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new ActorTable with assigned table suffix +func (a ActorTable) WithSuffix(suffix string) *ActorTable { + return newActorTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newActorTable(schemaName, tableName, alias string) *ActorTable { return &ActorTable{ actorTable: newActorTableImpl(schemaName, tableName, alias), @@ -264,6 +275,16 @@ func (a FilmListTable) FromSchema(schemaName string) *FilmListTable { return newFilmListTable(schemaName, a.TableName(), a.Alias()) } +// WithPrefix creates new FilmListTable with assigned table prefix +func (a FilmListTable) WithPrefix(prefix string) *FilmListTable { + return newFilmListTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new FilmListTable with assigned table suffix +func (a FilmListTable) WithSuffix(suffix string) *FilmListTable { + return newFilmListTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + func newFilmListTable(schemaName, tableName, alias string) *FilmListTable { return &FilmListTable{ filmListTable: newFilmListTableImpl(schemaName, tableName, alias), diff --git a/tests/sqlite/insert_test.go b/tests/sqlite/insert_test.go index f5939bb2..e1b3a547 100644 --- a/tests/sqlite/insert_test.go +++ b/tests/sqlite/insert_test.go @@ -2,6 +2,7 @@ package sqlite import ( "context" + "database/sql" "math/rand" "testing" @@ -15,9 +16,6 @@ import ( ) func TestInsertValues(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). VALUES(101, "http://www.google.com", "Google", "Search engine"). @@ -32,31 +30,32 @@ VALUES (?, ?, ?, ?), 101, "http://www.google.com", "Google", "Search engine", 102, "http://www.yahoo.com", "Yahoo", nil) - _, err := insertQuery.Exec(tx) - require.NoError(t, err) - requireLogged(t, insertQuery) - - insertedLinks := []model.Link{} - - err = SELECT(Link.AllColumns). - FROM(Link). - WHERE(Link.ID.GT_EQ(Int(100))). - ORDER_BY(Link.ID). - Query(tx, &insertedLinks) - - require.NoError(t, err) - require.Equal(t, len(insertedLinks), 3) - testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) - testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ - ID: 101, - URL: "http://www.google.com", - Name: "Google", - Description: testutils.StringPtr("Search engine"), - }) - testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ - ID: 102, - URL: "http://www.yahoo.com", - Name: "Yahoo", + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + testutils.AssertExec(t, insertQuery, tx) + requireLogged(t, insertQuery) + + var insertedLinks []model.Link + + err := SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.ID.GT_EQ(Int(100))). + ORDER_BY(Link.ID). + Query(tx, &insertedLinks) + + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 3) + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ + ID: 101, + URL: "http://www.google.com", + Name: "Google", + Description: testutils.StringPtr("Search engine"), + }) + testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ + ID: 102, + URL: "http://www.yahoo.com", + Name: "Yahoo", + }) }) } @@ -67,41 +66,35 @@ var postgreTutorial = model.Link{ } func TestInsertEmptyColumnList(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - - expectedSQL := ` -INSERT INTO link -VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL); -` - stmt := Link.INSERT(). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO link +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL); +`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil) - _, err := stmt.Exec(tx) - require.NoError(t, err) - requireLogged(t, stmt) + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + _, err := stmt.Exec(tx) + require.NoError(t, err) + requireLogged(t, stmt) - insertedLinks := []model.Link{} + var insertedLinks []model.Link - err = SELECT(Link.AllColumns). - FROM(Link). - WHERE(Link.ID.GT_EQ(Int(100))). - ORDER_BY(Link.ID). - Query(tx, &insertedLinks) + err = SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.ID.GT_EQ(Int(100))). + ORDER_BY(Link.ID). + Query(tx, &insertedLinks) - require.NoError(t, err) - require.Equal(t, len(insertedLinks), 1) - testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 1) + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + }) } func TestInsertModelObject(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - linkData := model.Link{ URL: "http://www.duckduckgo.com", Name: "Duck Duck go", @@ -115,19 +108,13 @@ INSERT INTO link (url, name) VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); `, "http://www.duckduckgo.com", "Duck Duck go") - _, err := query.Exec(tx) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) + }) } func TestInsertModelObjectEmptyColumnList(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - - var expectedSQL = ` -INSERT INTO link -VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); -` - linkData := model.Link{ ID: 1000, URL: "http://www.duckduckgo.com", @@ -138,23 +125,18 @@ VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); INSERT(). MODEL(linkData) - testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO link +VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); +`, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) - _, err := query.Exec(tx) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) + }) } func TestInsertModelsObject(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - - expectedSQL := ` -INSERT INTO link (url, name) -VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), - ('http://www.google.com', 'Google'), - ('http://www.yahoo.com', 'Yahoo'); -` - tutorial := model.Link{ URL: "http://www.postgresqltutorial.com", Name: "PostgreSQL Tutorial", @@ -176,27 +158,20 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), yahoo, }) - testutils.AssertDebugStatementSql(t, query, expectedSQL, + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO link (url, name) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), + ('http://www.google.com', 'Google'), + ('http://www.yahoo.com', 'Yahoo'); +`, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.google.com", "Google", "http://www.yahoo.com", "Yahoo") - _, err := query.Exec(tx) - require.NoError(t, err) + testutils.AssertExecAndRollback(t, query, sampleDB) } func TestInsertUsingMutableColumns(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - - var expectedSQL = ` -INSERT INTO link (url, name, description) -VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), - ('http://www.google.com', 'Google', NULL), - ('http://www.google.com', 'Google', NULL), - ('http://www.yahoo.com', 'Yahoo', NULL); -` - google := model.Link{ URL: "http://www.google.com", Name: "Google", @@ -213,20 +188,22 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), MODEL(google). MODELS([]model.Link{google, yahoo}) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO link (url, name, description) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.yahoo.com', 'Yahoo', NULL); +`, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil, "http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil, "http://www.yahoo.com", "Yahoo", nil) - _, err := stmt.Exec(tx) - require.NoError(t, err) + testutils.AssertExecAndRollback(t, stmt, sampleDB) } func TestInsertQuery(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - var expectedSQL = ` INSERT INTO link (url, name) SELECT link.url AS "link.url", @@ -242,24 +219,22 @@ WHERE link.id = 24; ) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(24)) - - _, err := query.Exec(tx) - require.NoError(t, err) - - youtubeLinks := []model.Link{} - err = Link. - SELECT(Link.AllColumns). - WHERE(Link.Name.EQ(String("Bing"))). - Query(tx, &youtubeLinks) - - require.NoError(t, err) - require.Equal(t, len(youtubeLinks), 2) + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) + + var youtubeLinks []model.Link + err = Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Bing"))). + Query(tx, &youtubeLinks) + + require.NoError(t, err) + require.Equal(t, len(youtubeLinks), 2) + }) } func TestInsert_DEFAULT_VALUES_RETURNING(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - stmt := Link.INSERT(). DEFAULT_VALUES(). RETURNING(Link.AllColumns) @@ -273,24 +248,23 @@ RETURNING link.id AS "link.id", link.description AS "link.description"; `) - var link model.Link - err := stmt.Query(tx, &link) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + var link model.Link + err := stmt.Query(tx, &link) + require.NoError(t, err) - require.EqualValues(t, link, model.Link{ - ID: 25, - URL: "www.", - Name: "_", - Description: nil, + require.EqualValues(t, link, model.Link{ + ID: 25, + URL: "www.", + Name: "_", + Description: nil, + }) }) } func TestInsertOnConflict(t *testing.T) { t.Run("do nothing", func(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - link := model.Link{ID: rand.Int31()} stmt := Link.INSERT(Link.AllColumns). @@ -304,14 +278,11 @@ VALUES (?, ?, ?, ?), (?, ?, ?, ?) ON CONFLICT (id) DO NOTHING; `) - testutils.AssertExec(t, stmt, tx, 1) + testutils.AssertExecAndRollback(t, stmt, sampleDB, 1) requireLogged(t, stmt) }) t.Run("do update", func(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). VALUES(22, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). @@ -336,14 +307,11 @@ RETURNING link.id AS "link.id", link.description AS "link.description"; `) - testutils.AssertExec(t, stmt, tx) + testutils.AssertExecAndRollback(t, stmt, sampleDB) requireLogged(t, stmt) }) t.Run("do update complex", func(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). ON_CONFLICT(Link.ID). @@ -370,7 +338,7 @@ ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE WHERE link.description IS NOT NULL; `) - testutils.AssertExec(t, stmt, tx) + testutils.AssertExecAndRollback(t, stmt, sampleDB) requireLogged(t, stmt) }) } @@ -384,7 +352,7 @@ func TestInsertContextDeadlineExceeded(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} + var dest []model.Link err := stmt.QueryContext(ctx, sampleDB, &dest) require.Error(t, err, "context deadline exceeded") diff --git a/tests/sqlite/main_test.go b/tests/sqlite/main_test.go index 4eb274eb..49758451 100644 --- a/tests/sqlite/main_test.go +++ b/tests/sqlite/main_test.go @@ -35,6 +35,7 @@ func TestMain(m *testing.M) { var err error db, err = sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath) throw.OnError(err) + defer db.Close() _, err = db.Exec(fmt.Sprintf("ATTACH DATABASE '%s' as 'chinook';", dbconfig.ChinookDBPath)) throw.OnError(err) @@ -42,8 +43,6 @@ func TestMain(m *testing.M) { sampleDB, err = sql.Open("sqlite3", dbconfig.TestSampleDBPath) throw.OnError(err) - defer db.Close() - ret := m.Run() if ret != 0 { diff --git a/tests/testdata b/tests/testdata index 895bf576..fdb0cc59 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 895bf5760d055c717df77c3b872af276f34d06f1 +Subproject commit fdb0cc598d2b534310d2b559ce9a2f75b5507c56