diff --git a/.circleci/config.yml b/.circleci/config.yml index 35e0a6d6..108ecfd8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -69,7 +69,7 @@ jobs: - run: name: Install MySQL CLI; command: | - sudo apt-get update && sudo apt-get install default-mysql-client + sudo apt-get --allow-releaseinfo-change update && sudo apt-get install default-mysql-client - run: name: Create MySQL user and databases @@ -88,7 +88,8 @@ jobs: - run: mkdir -p $TEST_RESULTS - - run: MY_SQL_SOURCE=MySQL go test -v ./... -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/qrm/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml + # this will run all tests and exclude test files from code coverage report + - run: MY_SQL_SOURCE=MySQL go test -v ./... -covermode=atomic -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/sqlite/...,github.com/go-jet/jet/qrm/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml - run: name: Upload code coverage @@ -138,7 +139,7 @@ jobs: - run: name: Install MySQL CLI; command: | - sudo apt-get update && sudo apt-get install default-mysql-client + sudo apt-get --allow-releaseinfo-change update && sudo apt-get install default-mysql-client - run: name: Init MariaDB database diff --git a/.gitignore b/.gitignore index f286e83d..153be121 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,5 @@ # Test files gen .gentestdata -.tests/testdata/ \ No newline at end of file +.tests/testdata/ +.gen \ No newline at end of file diff --git a/README.md b/README.md index e79e86c4..89ac87fe 100644 --- a/README.md +++ b/README.md @@ -9,48 +9,41 @@ Jet is a complete solution for efficient and high performance database access, consisting of type-safe SQL builder with code generation and automatic query result data mapping. -Jet currently supports `PostgreSQL`, `MySQL` and `MariaDB`. Future releases will add support for additional databases. +Jet currently supports `PostgreSQL`, `MySQL`, `MariaDB` and `SQLite`. Future releases will add support for additional databases. ![jet](https://github.com/go-jet/jet/wiki/image/jet.png) -Jet is the easiest and the fastest way to write complex type-safe SQL queries as a Go code and map database query result +Jet is the easiest, and the fastest way to write complex type-safe SQL queries as a Go code and map database query result into complex object composition. __It is not an ORM.__ ## Motivation https://medium.com/@go.jet/jet-5f3667efa0cc ## Contents - - [Features](#features) - - [Getting Started](#getting-started) - - [Prerequisites](#prerequisites) - - [Installation](#installation) - - [Quick Start](#quick-start) - - [Generate sql builder and model files](#generate-sql-builder-and-model-files) - - [Lets write some SQL queries in Go](#lets-write-some-sql-queries-in-go) - - [Execute query and store result](#execute-query-and-store-result) - - [Benefits](#benefits) - - [Dependencies](#dependencies) - - [Versioning](#versioning) - - [License](#license) +- [Features](#features) +- [Getting Started](#getting-started) + - [Prerequisites](#prerequisites) + - [Installation](#installation) + - [Quick Start](#quick-start) + - [Generate sql builder and model types](#generate-sql-builder-and-model-types) + - [Lets write some SQL queries in Go](#lets-write-some-sql-queries-in-go) + - [Execute query and store result](#execute-query-and-store-result) +- [Benefits](#benefits) +- [Dependencies](#dependencies) +- [Versioning](#versioning) +- [License](#license) ## Features - 1) Auto-generated type-safe SQL Builder - - PostgreSQL: - * [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, WINDOW, sub-queries)` - * [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, MODEL, MODELS, QUERY, ON_CONFLICT, RETURNING)`, + 1) Auto-generated type-safe SQL Builder. Statements supported: + * [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, LOCK_IN_SHARE_MODE, UNION, INTERSECT, EXCEPT, WINDOW, sub-queries)` + * [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, MODEL, MODELS, QUERY, ON_CONFLICT/ON_DUPLICATE_KEY_UPDATE, RETURNING)`, * [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, MODEL, WHERE, RETURNING)`, - * [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, RETURNING)`, - * [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(IN, NOWAIT)` - * [WITH](https://github.com/go-jet/jet/wiki/WITH) - - MySQL and MariaDB: - * [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, WINDOW, sub-queries)` - * [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, MODEL, MODELS, ON_DUPLICATE_KEY_UPDATE, query)`, - * [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, MODEL, WHERE)`, - * [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, ORDER_BY, LIMIT)`, - * [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(READ, WRITE)` + * [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, ORDER_BY, LIMIT, RETURNING)`, + * [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(IN, NOWAIT)`, `(READ, WRITE)` * [WITH](https://github.com/go-jet/jet/wiki/WITH) + 2) Auto-generated Data Model types - Go types mapped to database type (table, view or enum), used to store result of database queries. Can be combined to create desired query result destination. - 3) Query execution with result mapping to arbitrary destination structure. + 3) Query execution with result mapping to arbitrary destination. ## Getting Started @@ -67,43 +60,44 @@ Use the command bellow to add jet as a dependency into `go.mod` project: $ go get -u github.com/go-jet/jet/v2 ``` -Jet generator can be install in the following ways: +Jet generator can be installed in the following ways: 1) Install jet generator to GOPATH/bin folder: -```sh -cd $GOPATH/src/ && GO111MODULE=off go get -u github.com/go-jet/jet/cmd/jet -``` -*Make sure GOPATH/bin folder is added to the PATH environment variable.* + ```sh + cd $GOPATH/src/ && GO111MODULE=off go get -u github.com/go-jet/jet/cmd/jet + ``` + *Make sure GOPATH/bin folder is added to the PATH environment variable.* -2) Install jet generator to specific folder: - -```sh -git clone https://github.com/go-jet/jet.git -cd jet && go build -o dir_path ./cmd/jet -``` -*Make sure `dir_path` folder is added to the PATH environment variable.* +2) Install jet generator into specific folder: + + ```sh + git clone https://github.com/go-jet/jet.git + cd jet && go build -o dir_path ./cmd/jet + ``` + *Make sure `dir_path` folder is added to the PATH environment variable.* 3) (Go1.16+) Install jet generator using go install: -```sh -go install github.com/go-jet/jet/v2/cmd/jet@latest -``` -*Jet generator is installed to the directory named by the GOBIN environment variable, -which defaults to $GOPATH/bin or $HOME/go/bin if the GOPATH environment variable is not set.* + ```sh + go install github.com/go-jet/jet/v2/cmd/jet@latest + ``` + *Jet generator is installed to the directory named by the GOBIN environment variable, + which defaults to $GOPATH/bin or $HOME/go/bin if the GOPATH environment variable is not set.* ### Quick Start -For this quick start example we will use PostgreSQL sample _'dvd rental'_ database. Full database dump can be found in [./tests/testdata/init/postgres/dvds.sql](./tests/testdata/init/postgres/dvds.sql). +For this quick start example we will use PostgreSQL sample _'dvd rental'_ database. Full database dump can be found in +[./tests/testdata/init/postgres/dvds.sql](https://github.com/go-jet/jet-test-data/blob/master/init/postgres/dvds.sql). Schema diagram of interest for example can be found [here](./examples/quick-start/diagram.png). -#### Generate SQL Builder and Model files -To generate jet SQL Builder and Data Model files from postgres database, we need to call `jet` generator with postgres -connection parameters and root destination folder path for generated files.\ -Assuming we are running local postgres database, with user `jetuser`, user password `jetpass`, database `jetdb` and +#### Generate SQL Builder and Model types +To generate jet SQL Builder and Data Model types from postgres database, we need to call `jet` generator with postgres +connection parameters and root destination folder path for generated files. +Assuming we are running local postgres database, with user `user`, user password `pass`, database `jetdb` and schema `dvds` we will use this command: ```sh -jet -source=PostgreSQL -host=localhost -port=5432 -user=jetuser -password=jetpass -dbname=jetdb -schema=dvds -path=./.gen +jet -dsn=postgresql://user:pass@localhost:5432/jetdb -schema=dvds -path=./.gen ``` ```sh -Connecting to postgres database: host=localhost port=5432 user=jetuser password=jetpass dbname=jetdb sslmode=disable +Connecting to postgres database: postgresql://user:pass@localhost:5432/jetdb Retrieving schema information... FOUND 15 table(s), 7 view(s), 1 enum(s) Cleaning up destination directory... @@ -115,14 +109,19 @@ Generating view model files... Generating enum model files... Done ``` -Procedure is similar for MySQL or MariaDB, except source should be replaced with `MySql` or `MariaDB` and schema name should -be omitted (both databases doesn't have schema support). +Procedure is similar for MySQL, MariaDB and SQLite. For instance: +```sh +jet -source=mysql -dsn="user:pass@tcp(localhost:3306)/dbname" -path=./gen +jet -dsn="mariadb://user:pass@tcp(localhost:3306)/dvds" -path=./gen # source flag can be omitted if data source appears in dsn +jet -source=sqlite -dsn="/path/to/sqlite/database/file" -schema=dvds -path=./gen +jet -dsn="file:///path/to/sqlite/database/file" -schema=dvds -path=./gen # sqlite database assumed for 'file' data sources +``` _*User has to have a permission to read information schema tables._ As command output suggest, Jet will: - connect to postgres database and retrieve information about the _tables_, _views_ and _enums_ of `dvds` schema - delete everything in schema destination folder - `./gen/jetdb/dvds`, -- and finally generate SQL Builder and Model files for each schema table, view and enum. +- and finally generate SQL Builder and Model types for each schema table, view and enum. Generated files folder structure will look like this: @@ -147,14 +146,14 @@ Generated files folder structure will look like this: | | |-- mpaa_rating.go | | ... ``` -Types from `table`, `view` and `enum` are used to write type safe SQL in Go, and `model` types can be combined to store +Types from `table`, `view` and `enum` are used to write type safe SQL in Go, and `model` types are combined to store results of the SQL queries. -#### Lets write some SQL queries in Go +#### Let's write some SQL queries in Go -First we need to import jet and generated files from previous step: +First we need to import postgres SQLBuilder and generated packages from the previous step: ```go import ( // dot import so go code would resemble as much as native SQL @@ -165,7 +164,7 @@ import ( "github.com/go-jet/jet/v2/examples/quick-start/gen/jetdb/dvds/model" ) ``` -Lets say we want to retrieve the list of all _actors_ that acted in _films_ longer than 180 minutes, _film language_ is 'English' +Let's say we want to retrieve the list of all _actors_ that acted in _films_ longer than 180 minutes, _film language_ is 'English' and _film category_ is not 'Action'. ```java stmt := SELECT( @@ -189,17 +188,17 @@ stmt := SELECT( Film.FilmID.ASC(), ) ``` -_Package(dot) import is used so that statement would resemble as much as possible as native SQL._ +_Package(dot) import is used, so the statements would resemble as much as possible as native SQL._ Note that every column has a type. String column `Language.Name` and `Category.Name` can be compared only with string columns and expressions. `Actor.ActorID`, `FilmActor.ActorID`, `Film.Length` are integer columns and can be compared only with integer columns and expressions. -__How to get parametrized SQL query from statement?__ +__How to get a parametrized SQL query from the statement?__ ```go query, args := stmt.Sql() ``` -query - parametrized query\ -args - parameters for the query +query - parametrized query +args - query parameters
Click to see `query` and `args` @@ -248,7 +247,7 @@ __How to get debug SQL from statement?__ ```go debugSql := stmt.DebugSql() ``` -debugSql - query string that can be copy pasted to sql editor and executed. __It's not intended to be used in production!!!__ +debugSql - this query string can be copy-pasted to sql editor and executed. __It is not intended to be used in production, only for the purpose of debugging!!!__
Click to see debug sql @@ -291,14 +290,17 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; #### Execute query and store result -Well formed SQL is just a first half of the job. Lets see how can we make some sense of result set returned executing +Well-formed SQL is just a first half of the job. Let's see how can we make some sense of result set returned executing above statement. Usually this is the most complex and tedious work, but with Jet it is the easiest. First we have to create desired structure to store query result. -This is done be combining autogenerated model types or it can be done -manually(see [wiki](https://github.com/go-jet/jet/wiki/Query-Result-Mapping-(QRM)) for more information). +This is done be combining autogenerated model types, or it can be done +by combining custom model types(see [wiki](https://github.com/go-jet/jet/wiki/Query-Result-Mapping-(QRM)#custom-model-types) for more information). + +It's possible to overwrite default jet generator behavior, and all the aspects of generated model and SQLBuilder types can be +tailor-made([wiki](https://github.com/go-jet/jet/wiki/Generator#generator-customization)). -Let's say this is our desired structure: +Let's say this is our desired structure made of autogenerated types: ```go var dest []struct { model.Actor @@ -315,7 +317,7 @@ var dest []struct { `Langauge` field is just a single model struct. `Film` can belong to multiple categories. _*There is no limitation of how big or nested destination can be._ -Now lets execute a above statement on open database connection (or transaction) db and store result into `dest`. +Now lets execute above statement on open database connection (or transaction) db and store result into `dest`. ```go err := stmt.Query(db, &dest) @@ -524,7 +526,7 @@ found at project [Wiki](https://github.com/go-jet/jet/wiki) page. ## Benefits What are the benefits of writing SQL in Go using Jet? -The biggest benefit is speed. Speed is improved in 3 major areas: +The biggest benefit is speed. Speed is being improved in 3 major areas: ##### Speed of development @@ -538,32 +540,34 @@ Jet will always perform better as developers can write complex query and retriev Thus handler time lost on latency between server and database can be constant. Handler execution will be proportional only to the query complexity and the number of rows returned from database. -With Jet it is even possible to join the whole database and store the whole structured result in one database call. +With Jet, it is even possible to join the whole database and store the whole structured result in one database call. This is exactly what is being done in one of the tests: [TestJoinEverything](/tests/postgres/chinook_db_test.go#L40). The whole test database is joined and query result(~10,000 rows) is stored in a structured variable in less than 0.7s. ##### How quickly bugs are found -The most expensive bugs are the one on the production and the least expensive are those found during development. +The most expensive bugs are the one discovered on the production, and the least expensive are those found during development. With automatically generated type safe SQL, not only queries are written faster but bugs are found sooner. Lets return to quick start example, and take closer look at a line: ```go AND(Film.Length.GT(Int(180))), ``` -Lets say someone changes column `length` to `duration` from `film` table. The next go build will fail at that line and +Let's say someone changes column `length` to `duration` from `film` table. The next go build will fail at that line, and the bug will be caught at compile time. -Lets say someone changes the type of `length` column to some non integer type. Build will also fail at the same line -because integer columns and expressions can be only compered to other integer columns and expressions. +Let's say someone changes the type of `length` column to some non integer type. Build will also fail at the same line +because integer columns and expressions can be only compared to other integer columns and expressions. -Build will also fail if someone removes `length` column from `film` table, because `Film` field will be omitted from SQL Builder and Model types, next time `jet` generator is run. +Build will also fail if someone removes `length` column from `film` table. `Film` field will be omitted from SQL Builder and Model types, +next time `jet` generator is run. Without Jet these bugs will have to be either caught by some test or by manual testing. ## Dependencies At the moment Jet dependence only of: -- `github.com/lib/pq` _(Used by jet generator to read information about database schema from `PostgreSQL`)_ -- `github.com/go-sql-driver/mysql` _(Used by jet generator to read information about database from `MySQL` and `MariaDB`)_ +- `github.com/lib/pq` _(Used by jet generator to read `PostgreSQL` database information)_ +- `github.com/go-sql-driver/mysql` _(Used by jet generator to read `MySQL` and `MariaDB` database information)_ +- `github.com/mattn/go-sqlite3` _(Used by jet generator to read `SQLite` database information)_ - `github.com/google/uuid` _(Used in data model files and for debug purposes)_ To run the tests, additional dependencies are required: diff --git a/cmd/jet/main.go b/cmd/jet/main.go index 136e58b1..fbbccaae 100644 --- a/cmd/jet/main.go +++ b/cmd/jet/main.go @@ -3,19 +3,21 @@ package main import ( "flag" "fmt" + sqlitegen "github.com/go-jet/jet/v2/generator/sqlite" + "os" + "strings" + mysqlgen "github.com/go-jet/jet/v2/generator/mysql" postgresgen "github.com/go-jet/jet/v2/generator/postgres" - "github.com/go-jet/jet/v2/mysql" - "github.com/go-jet/jet/v2/postgres" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" - "os" - "strings" + _ "github.com/mattn/go-sqlite3" ) var ( source string + dsn string host string port int user string @@ -29,8 +31,9 @@ var ( ) func init() { - flag.StringVar(&source, "source", "", "Database system name (PostgreSQL, MySQL or MariaDB)") + flag.StringVar(&source, "source", "", "Database system name (PostgreSQL, MySQL, MariaDB or SQLite)") + flag.StringVar(&dsn, "dsn", "", "Data source name connection string (Example: postgresql://user@localhost:5432/otherdb?sslmode=trust)") flag.StringVar(&host, "host", "", "Database host path (Example: localhost)") flag.IntVar(&port, "port", 0, "Database port") flag.StringVar(&user, "user", "", "Database user") @@ -47,11 +50,22 @@ func main() { flag.Usage = func() { _, _ = fmt.Fprint(os.Stdout, ` -Jet generator 2.5.0 +Jet generator 2.6.0 Usage: + -dsn string + Data source name. Unified format for connecting to database. + PostgreSQL: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING + Example: + postgresql://user:pass@localhost:5432/dbname + MySQL: https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html + Example: + mysql://jet:jet@tcp(localhost:3306)/dvds + SQLite: https://www.sqlite.org/c3ref/open.html#urifilenameexamples + Example: + file://path/to/database/file -source string - Database system name (PostgreSQL, MySQL or MariaDB) + Database system name (PostgreSQL, MySQL, MariaDB or SQLite) -host string Database host path (Example: localhost) -port int @@ -65,25 +79,48 @@ Usage: -params string Additional connection string parameters(optional) -schema string - Database schema name. (default "public") (ignored for MySQL and MariaDB) + Database schema name. (default "public") (ignored for MySQL, MariaDB and SQLite) -sslmode string - Whether or not to use SSL(optional) (default "disable") (ignored for MySQL and MariaDB) + Whether or not to use SSL(optional) (default "disable") (ignored for MySQL, MariaDB and SQLite) -path string Destination dir for files generated. + +Example commands: + + $ jet -source=PostgreSQL -dbname=jetdb -host=localhost -port=5432 -user=jet -password=jet -schema=dvds -path=./gen + $ jet -dsn=postgresql://jet:jet@localhost:5432/jetdb -schema=dvds -path=./gen + $ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=dvds -path=./gen + $ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -schema=dvds -path=./gen `) } flag.Parse() - if source == "" || host == "" || port == 0 || user == "" || dbName == "" { - printErrorAndExit("\nERROR: required flag(s) missing") + if dsn == "" { + // validations for separated connection flags. + if source == "" || host == "" || port == 0 || user == "" || dbName == "" { + printErrorAndExit("ERROR: required flag(s) missing") + } + } else { + if source == "" { + // try to get source from schema + source = detectSchema(dsn) + } + + // validations when dsn != "" + if source == "" { + printErrorAndExit("ERROR: required -source flag missing.") + } } var err error switch strings.ToLower(strings.TrimSpace(source)) { - case strings.ToLower(postgres.Dialect.Name()), - strings.ToLower(postgres.Dialect.PackageName()): + case "postgresql", "postgres": + if dsn != "" { + err = postgresgen.GenerateDSN(dsn, schemaName, destDir) + break + } genData := postgresgen.DBConnection{ Host: host, Port: port, @@ -98,8 +135,11 @@ Usage: err = postgresgen.Generate(destDir, genData) - case strings.ToLower(mysql.Dialect.Name()), "mariadb": - + case "mysql", "mysqlx", "mariadb": + if dsn != "" { + err = mysqlgen.GenerateDSN(dsn, destDir) + break + } dbConn := mysqlgen.DBConnection{ Host: host, Port: port, @@ -110,9 +150,13 @@ Usage: } err = mysqlgen.Generate(destDir, dbConn) + case "sqlite": + if dsn == "" { + printErrorAndExit("ERROR: required -dsn flag missing.") + } + err = sqlitegen.GenerateDSN(dsn, destDir) default: - fmt.Println("ERROR: unsupported source " + source + ". " + postgres.Dialect.Name() + " and " + mysql.Dialect.Name() + " are currently supported.") - os.Exit(-4) + printErrorAndExit("ERROR: unknown data source " + source + ". Only postgres, mysql, mariadb and sqlite are supported.") } if err != nil { @@ -122,7 +166,22 @@ Usage: } func printErrorAndExit(error string) { - fmt.Println(error) + fmt.Println("\n", error) flag.Usage() os.Exit(-2) } + +func detectSchema(dsn string) string { + match := strings.SplitN(dsn, "://", 2) + if len(match) < 2 { // not found + return "" + } + + protocol := match[0] + + if protocol == "file" { + return "sqlite" + } + + return match[0] +} diff --git a/examples/quick-start/.gen/jetdb/dvds/model/mpaa_rating.go b/examples/quick-start/.gen/jetdb/dvds/model/mpaa_rating.go index c802aa93..bdb613a1 100644 --- a/examples/quick-start/.gen/jetdb/dvds/model/mpaa_rating.go +++ b/examples/quick-start/.gen/jetdb/dvds/model/mpaa_rating.go @@ -20,26 +20,32 @@ const ( ) func (e *MpaaRating) Scan(value interface{}) error { - if v, ok := value.(string); !ok { - return errors.New("jet: Invalid data for MpaaRating enum") - } else { - switch string(v) { - case "G": - *e = MpaaRating_G - case "PG": - *e = MpaaRating_Pg - case "PG-13": - *e = MpaaRating_Pg13 - case "R": - *e = MpaaRating_R - case "NC-17": - *e = MpaaRating_Nc17 - default: - return errors.New("jet: Inavlid data " + string(v) + "for MpaaRating enum") - } - - return nil + var enumValue string + switch val := value.(type) { + case string: + enumValue = val + case []byte: + enumValue = string(val) + default: + return errors.New("jet: Invalid scan value for AllTypesEnum enum. Enum value has to be of type string or []byte") } + + switch enumValue { + case "G": + *e = MpaaRating_G + case "PG": + *e = MpaaRating_Pg + case "PG-13": + *e = MpaaRating_Pg13 + case "R": + *e = MpaaRating_R + case "NC-17": + *e = MpaaRating_Nc17 + default: + return errors.New("jet: Invalid scan value '" + enumValue + "' for MpaaRating enum") + } + + return nil } func (e MpaaRating) String() string { diff --git a/generator/internal/metadata/column_meta_data.go b/generator/internal/metadata/column_meta_data.go deleted file mode 100644 index dceb7c03..00000000 --- a/generator/internal/metadata/column_meta_data.go +++ /dev/null @@ -1,168 +0,0 @@ -package metadata - -import ( - "database/sql" - "fmt" - "github.com/go-jet/jet/v2/internal/utils" - "strings" -) - -// ColumnMetaData struct -type ColumnMetaData struct { - Name string - IsNullable bool - DataType string - EnumName string - IsUnsigned bool - - SqlBuilderColumnType string - GoBaseType string - GoModelType string -} - -// NewColumnMetaData create new column meta data that describes one column in SQL database -func NewColumnMetaData(name string, isNullable bool, dataType string, enumName string, isUnsigned bool) ColumnMetaData { - columnMetaData := ColumnMetaData{ - Name: name, - IsNullable: isNullable, - DataType: dataType, - EnumName: enumName, - IsUnsigned: isUnsigned, - } - - columnMetaData.SqlBuilderColumnType = columnMetaData.getSqlBuilderColumnType() - columnMetaData.GoBaseType = columnMetaData.getGoBaseType() - columnMetaData.GoModelType = columnMetaData.getGoModelType() - - return columnMetaData -} - -// getSqlBuilderColumnType returns type of jet sql builder column -func (c ColumnMetaData) getSqlBuilderColumnType() string { - switch c.DataType { - case "boolean": - return "Bool" - case "smallint", "integer", "bigint", - "tinyint", "mediumint", "int", "year": //MySQL - return "Integer" - case "date": - return "Date" - case "timestamp without time zone", - "timestamp", "datetime": //MySQL: - return "Timestamp" - case "timestamp with time zone": - return "Timestampz" - case "time without time zone", - "time": //MySQL - return "Time" - case "time with time zone": - return "Timez" - case "interval": - return "Interval" - case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid", - "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", - "char", "varchar", "binary", "varbinary", - "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL - return "String" - case "real", "numeric", "decimal", "double precision", "float", - "double": // MySQL - return "Float" - default: - fmt.Println("- [SQL Builder] Unsupported sql column '" + c.Name + " " + c.DataType + "', using StringColumn instead.") - return "String" - } -} - -// getGoBaseType returns model type for column info. -func (c ColumnMetaData) getGoBaseType() string { - switch c.DataType { - case "USER-DEFINED", "enum": - return utils.ToGoIdentifier(c.EnumName) - case "boolean": - return "bool" - case "tinyint": - return "int8" - case "smallint", - "year": - return "int16" - case "integer", - "mediumint", "int": //MySQL - return "int32" - case "bigint": - return "int64" - case "date", "timestamp without time zone", "timestamp with time zone", "time with time zone", "time without time zone", - "timestamp", "datetime", "time": // MySQL - return "time.Time" - case "bytea", - "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": //MySQL - return "[]byte" - case "text", "character", "character varying", "tsvector", "bit", "bit varying", "money", "json", "jsonb", - "xml", "point", "interval", "line", "ARRAY", - "char", "varchar", "tinytext", "mediumtext", "longtext": // MySQL - return "string" - case "real": - return "float32" - case "numeric", "decimal", "double precision", "float", - "double": // MySQL - return "float64" - case "uuid": - return "uuid.UUID" - default: - fmt.Println("- [Model ] Unsupported sql column '" + c.Name + " " + c.DataType + "', using string instead.") - return "string" - } -} - -// GoModelType returns model type for column info with optional pointer if -// column can be NULL. -func (c ColumnMetaData) getGoModelType() string { - typeStr := c.GoBaseType - - if strings.Contains(typeStr, "int") && c.IsUnsigned { - typeStr = "u" + typeStr - } - - if c.IsNullable { - return "*" + typeStr - } - - return typeStr -} - -// GoModelTag returns model field tag for column -func (c ColumnMetaData) GoModelTag(isPrimaryKey bool) string { - tags := []string{} - - if isPrimaryKey { - tags = append(tags, "primary_key") - } - - if len(tags) > 0 { - return "`sql:\"" + strings.Join(tags, ",") + "\"`" - } - - return "" -} - -func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) []ColumnMetaData { - - rows, err := db.Query(querySet.ListOfColumnsQuery(), schemaName, tableName) - utils.PanicOnError(err) - defer rows.Close() - - ret := []ColumnMetaData{} - - for rows.Next() { - var name, isNullable, dataType, enumName string - var isUnsigned bool - err := rows.Scan(&name, &isNullable, &dataType, &enumName, &isUnsigned) - utils.PanicOnError(err) - - ret = append(ret, NewColumnMetaData(name, isNullable == "YES", dataType, enumName, isUnsigned)) - } - - err = rows.Err() - utils.PanicOnError(err) - - return ret -} diff --git a/generator/internal/metadata/dialect_query_set.go b/generator/internal/metadata/dialect_query_set.go deleted file mode 100644 index 6c918257..00000000 --- a/generator/internal/metadata/dialect_query_set.go +++ /dev/null @@ -1,15 +0,0 @@ -package metadata - -import ( - "database/sql" -) - -// DialectQuerySet is set of methods necessary to retrieve dialect meta data information -type DialectQuerySet interface { - ListOfTablesQuery() string - PrimaryKeysQuery() string - ListOfColumnsQuery() string - ListOfEnumsQuery() string - - GetEnumsMetaData(db *sql.DB, schemaName string) []MetaData -} diff --git a/generator/internal/metadata/enum_meta_data.go b/generator/internal/metadata/enum_meta_data.go deleted file mode 100644 index 8479c603..00000000 --- a/generator/internal/metadata/enum_meta_data.go +++ /dev/null @@ -1,12 +0,0 @@ -package metadata - -// EnumMetaData struct -type EnumMetaData struct { - EnumName string - Values []string -} - -// Name returns enum name -func (e EnumMetaData) Name() string { - return e.EnumName -} diff --git a/generator/internal/metadata/meta_data.go b/generator/internal/metadata/meta_data.go deleted file mode 100644 index 17d2f5c5..00000000 --- a/generator/internal/metadata/meta_data.go +++ /dev/null @@ -1,6 +0,0 @@ -package metadata - -// MetaData interface -type MetaData interface { - Name() string -} diff --git a/generator/internal/metadata/schema_meta_data.go b/generator/internal/metadata/schema_meta_data.go deleted file mode 100644 index bc855112..00000000 --- a/generator/internal/metadata/schema_meta_data.go +++ /dev/null @@ -1,61 +0,0 @@ -package metadata - -import ( - "database/sql" - "fmt" - "github.com/go-jet/jet/v2/internal/utils" -) - -// SchemaMetaData struct -type SchemaMetaData struct { - TablesMetaData []MetaData - ViewsMetaData []MetaData - EnumsMetaData []MetaData -} - -// IsEmpty returns true if schema info does not contain any table, views or enums metadata -func (s SchemaMetaData) IsEmpty() bool { - return len(s.TablesMetaData) == 0 && len(s.ViewsMetaData) == 0 && len(s.EnumsMetaData) == 0 -} - -const ( - baseTable = "BASE TABLE" - view = "VIEW" -) - -// GetSchemaMetaData returns schema information from db connection. -func GetSchemaMetaData(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData) { - - schemaInfo.TablesMetaData = getTablesMetaData(db, querySet, schemaName, baseTable) - schemaInfo.ViewsMetaData = getTablesMetaData(db, querySet, schemaName, view) - schemaInfo.EnumsMetaData = querySet.GetEnumsMetaData(db, schemaName) - - fmt.Println(" FOUND", len(schemaInfo.TablesMetaData), "table(s),", len(schemaInfo.ViewsMetaData), "view(s),", - len(schemaInfo.EnumsMetaData), "enum(s)") - - return -} - -func getTablesMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableType string) []MetaData { - - rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName, tableType) - utils.PanicOnError(err) - defer rows.Close() - - ret := []MetaData{} - for rows.Next() { - var tableName string - - err = rows.Scan(&tableName) - utils.PanicOnError(err) - - tableInfo := GetTableMetaData(db, querySet, schemaName, tableName) - - ret = append(ret, tableInfo) - } - - err = rows.Err() - utils.PanicOnError(err) - - return ret -} diff --git a/generator/internal/metadata/table_meta_data.go b/generator/internal/metadata/table_meta_data.go deleted file mode 100644 index c106dd49..00000000 --- a/generator/internal/metadata/table_meta_data.go +++ /dev/null @@ -1,103 +0,0 @@ -package metadata - -import ( - "database/sql" - "github.com/go-jet/jet/v2/internal/utils" - "strings" -) - -// TableMetaData metadata struct -type TableMetaData struct { - SchemaName string - name string - PrimaryKeys map[string]bool - Columns []ColumnMetaData -} - -// Name returns table info name -func (t TableMetaData) Name() string { - return t.name -} - -// IsPrimaryKey returns if column is a part of primary key -func (t TableMetaData) IsPrimaryKey(column string) bool { - return t.PrimaryKeys[column] -} - -// MutableColumns returns list of mutable columns for table -func (t TableMetaData) MutableColumns() []ColumnMetaData { - ret := []ColumnMetaData{} - - for _, column := range t.Columns { - if t.IsPrimaryKey(column.Name) { - continue - } - - ret = append(ret, column) - } - - return ret -} - -// GetImports returns model imports for table. -func (t TableMetaData) GetImports() []string { - imports := map[string]string{} - - for _, column := range t.Columns { - columnType := column.GoBaseType - - switch columnType { - case "time.Time": - imports["time.Time"] = "time" - case "uuid.UUID": - imports["uuid.UUID"] = "github.com/google/uuid" - } - } - - ret := []string{} - - for _, packageImport := range imports { - ret = append(ret, packageImport) - } - - return ret -} - -// GoStructName returns go struct name for sql builder -func (t TableMetaData) GoStructName() string { - return utils.ToGoIdentifier(t.name) + "Table" -} - -// GoStructImplName returns go struct impl name for sql builder -func (t TableMetaData) GoStructImplName() string { - name := utils.ToGoIdentifier(t.name) + "Table" - return string(strings.ToLower(name)[0]) + name[1:] -} - -// GetTableMetaData returns table info metadata -func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData) { - tableInfo.SchemaName = schemaName - tableInfo.name = tableName - - tableInfo.PrimaryKeys = getPrimaryKeys(db, querySet, schemaName, tableName) - tableInfo.Columns = getColumnsMetaData(db, querySet, schemaName, tableName) - return -} - -func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) map[string]bool { - - rows, err := db.Query(querySet.PrimaryKeysQuery(), schemaName, tableName) - utils.PanicOnError(err) - - primaryKeyMap := map[string]bool{} - - for rows.Next() { - primaryKey := "" - err := rows.Scan(&primaryKey) - utils.PanicOnError(err) - - primaryKeyMap[primaryKey] = true - } - - return primaryKeyMap -} diff --git a/generator/internal/template/generate.go b/generator/internal/template/generate.go deleted file mode 100644 index 34e9ca12..00000000 --- a/generator/internal/template/generate.go +++ /dev/null @@ -1,107 +0,0 @@ -package template - -import ( - "bytes" - "fmt" - "github.com/go-jet/jet/v2/generator/internal/metadata" - "github.com/go-jet/jet/v2/internal/jet" - "github.com/go-jet/jet/v2/internal/utils" - "path/filepath" - "text/template" -) - -// GenerateFiles generates Go files from tables and enums metadata -func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect jet.Dialect) { - if schemaInfo.IsEmpty() { - return - } - - fmt.Println("Destination directory:", destDir) - fmt.Println("Cleaning up destination directory...") - err := utils.CleanUpGeneratedFiles(destDir) - utils.PanicOnError(err) - - tableSQLBuilderTemplate := getTableSQLBuilderTemplate(dialect) - generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect) - generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect) - generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect) - - generateModelFiles(destDir, "table", tableModelTemplate, schemaInfo.TablesMetaData, dialect) - generateModelFiles(destDir, "view", tableModelTemplate, schemaInfo.ViewsMetaData, dialect) - generateModelFiles(destDir, "enum", enumModelTemplate, schemaInfo.EnumsMetaData, dialect) - - fmt.Println("Done") -} - -func getTableSQLBuilderTemplate(dialect jet.Dialect) string { - if dialect.Name() == "PostgreSQL" { - return tablePostgreSQLBuilderTemplate - } - - return tableSQLBuilderTemplate -} - -func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) { - if len(metaData) == 0 { - return - } - fmt.Printf("Generating %s sql builder files...\n", fileTypes) - generateGoFiles(destDir, fileTypes, sqlBuilderTemplate, metaData, dialect) -} - -func generateModelFiles(destDir, fileTypes, modelTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) { - if len(metaData) == 0 { - return - } - fmt.Printf("Generating %s model files...\n", fileTypes) - generateGoFiles(destDir, "model", modelTemplate, metaData, dialect) -} - -func generateGoFiles(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) { - modelDirPath := filepath.Join(dirPath, packageName) - - err := utils.EnsureDirPath(modelDirPath) - utils.PanicOnError(err) - - autoGenWarning, err := GenerateTemplate(autoGenWarningTemplate, nil, dialect) - utils.PanicOnError(err) - - for _, metaData := range metaDataList { - text, err := GenerateTemplate(template, metaData, dialect, map[string]interface{}{"package": packageName}) - utils.PanicOnError(err) - - err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...)) - utils.PanicOnError(err) - } - - return -} - -// GenerateTemplate generates template with template text and template data. -func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect, params ...map[string]interface{}) ([]byte, error) { - - t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ - "ToGoIdentifier": utils.ToGoIdentifier, - "ToGoEnumValueIdentifier": utils.ToGoEnumValueIdentifier, - "dialect": func() jet.Dialect { - return dialect - }, - "param": func(name string) interface{} { - if len(params) > 0 { - return params[0][name] - } - return "" - }, - }).Parse(templateText) - - if err != nil { - return nil, err - } - - var buf bytes.Buffer - if err := t.Execute(&buf, templateData); err != nil { - return nil, err - } - - return buf.Bytes(), nil -} diff --git a/generator/internal/template/templates.go b/generator/internal/template/templates.go deleted file mode 100644 index 186ade02..00000000 --- a/generator/internal/template/templates.go +++ /dev/null @@ -1,213 +0,0 @@ -package template - -var autoGenWarningTemplate = ` -// -// Code generated by go-jet DO NOT EDIT. -// -// WARNING: Changes to this file may cause incorrect behavior -// and will be lost if the code is regenerated -// - -` - -var tableSQLBuilderTemplate = ` -{{define "column-list" -}} - {{- range $i, $c := . }} - {{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column - {{- end}} -{{- end}} - -package {{param "package"}} - -import ( - "github.com/go-jet/jet/v2/{{dialect.PackageName}}" -) - -var {{ToGoIdentifier .Name}} = new{{.GoStructName}}("{{.SchemaName}}", "{{.Name}}", "") - -type {{.GoStructName}} struct { - {{dialect.PackageName}}.Table - - //Columns -{{- range .Columns}} - {{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}} -{{- end}} - - AllColumns {{dialect.PackageName}}.ColumnList - MutableColumns {{dialect.PackageName}}.ColumnList -} - -// AS creates new {{.GoStructName}} with assigned alias -func (a {{.GoStructName}}) AS(alias string) {{.GoStructName}} { - return new{{.GoStructName}}(a.SchemaName(), a.TableName(), alias) -} - -// Schema creates new {{.GoStructName}} with assigned schema name -func (a {{.GoStructName}}) FromSchema(schemaName string) {{.GoStructName}} { - return new{{.GoStructName}}(schemaName, a.TableName(), a.Alias()) -} - -func new{{.GoStructName}}(schemaName, tableName, alias string) {{.GoStructName}} { - var ( - {{- range .Columns}} - {{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}") - {{- end}} - allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} } - mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} } - ) - - return {{.GoStructName}}{ - Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...), - - //Columns -{{- range .Columns}} - {{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column, -{{- end}} - - AllColumns: allColumns, - MutableColumns: mutableColumns, - } -} -` - -var tablePostgreSQLBuilderTemplate = ` -{{define "column-list" -}} - {{- range $i, $c := . }} - {{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column - {{- end}} -{{- end}} - -package {{param "package"}} - -import ( - "github.com/go-jet/jet/v2/{{dialect.PackageName}}" -) - -var {{ToGoIdentifier .Name}} = new{{.GoStructName}}("{{.SchemaName}}", "{{.Name}}", "") - -type {{.GoStructImplName}} struct { - {{dialect.PackageName}}.Table - - //Columns -{{- range .Columns}} - {{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}} -{{- end}} - - AllColumns {{dialect.PackageName}}.ColumnList - MutableColumns {{dialect.PackageName}}.ColumnList -} - -type {{.GoStructName}} struct { - {{.GoStructImplName}} - - EXCLUDED {{.GoStructImplName}} -} - -// AS creates new {{.GoStructName}} with assigned alias -func (a {{.GoStructName}}) AS(alias string) *{{.GoStructName}} { - return new{{.GoStructName}}(a.SchemaName(), a.TableName(), alias) -} - -// Schema creates new {{.GoStructName}} with assigned schema name -func (a {{.GoStructName}}) FromSchema(schemaName string) *{{.GoStructName}} { - return new{{.GoStructName}}(schemaName, a.TableName(), a.Alias()) -} - -func new{{.GoStructName}}(schemaName, tableName, alias string) *{{.GoStructName}} { - return &{{.GoStructName}}{ - {{.GoStructImplName}}: new{{.GoStructName}}Impl(schemaName, tableName, alias), - EXCLUDED: new{{.GoStructName}}Impl("", "excluded", ""), - } -} - -func new{{.GoStructName}}Impl(schemaName, tableName, alias string) {{.GoStructImplName}} { - var ( - {{- range .Columns}} - {{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}") - {{- end}} - allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} } - mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} } - ) - - return {{.GoStructImplName}}{ - Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...), - - //Columns -{{- range .Columns}} - {{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column, -{{- end}} - - AllColumns: allColumns, - MutableColumns: mutableColumns, - } -} -` - -var tableModelTemplate = `package model - -{{ if .GetImports }} -import ( -{{- range .GetImports}} - "{{.}}" -{{- end}} -) -{{end}} - - -type {{ToGoIdentifier .Name}} struct { -{{- range .Columns}} - {{ToGoIdentifier .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsPrimaryKey .Name)}}" + ` -{{- end}} -} - - -` -var enumSQLBuilderTemplate = `package enum - -import "github.com/go-jet/jet/v2/{{dialect.PackageName}}" - -var {{ToGoIdentifier $.Name}} = &struct { -{{- range $index, $element := .Values}} - {{ToGoEnumValueIdentifier $.Name $element}} {{dialect.PackageName}}.StringExpression -{{- end}} -} { -{{- range $index, $element := .Values}} - {{ToGoEnumValueIdentifier $.Name $element}}: {{dialect.PackageName}}.NewEnumValue("{{$element}}"), -{{- end}} -} -` - -var enumModelTemplate = `package model - -import "errors" - -type {{ToGoIdentifier $.Name}} string - -const ( -{{- range $index, $element := .Values}} - {{ToGoIdentifier $.Name}}_{{ToGoIdentifier $element}} {{ToGoIdentifier $.Name}} = "{{$element}}" -{{- end}} -) - -func (e *{{ToGoIdentifier $.Name}}) Scan(value interface{}) error { - if v, ok := value.(string); !ok { - return errors.New("jet: Invalid data for {{ToGoIdentifier $.Name}} enum") - } else { - switch string(v) { -{{- range $index, $element := .Values}} - case "{{$element}}": - *e = {{ToGoIdentifier $.Name}}_{{ToGoIdentifier $element}} -{{- end}} - default: - return errors.New("jet: Inavlid data " + string(v) + "for {{ToGoIdentifier $.Name}} enum") - } - - return nil - } -} - -func (e {{ToGoIdentifier $.Name}}) String() string { - return string(e) -} - -` diff --git a/generator/metadata/column_meta_data.go b/generator/metadata/column_meta_data.go new file mode 100644 index 00000000..74184b68 --- /dev/null +++ b/generator/metadata/column_meta_data.go @@ -0,0 +1,27 @@ +package metadata + +// Column struct +type Column struct { + Name string + IsPrimaryKey bool + IsNullable bool + DataType DataType +} + +// DataTypeKind is database type kind(base, enum, user-defined, array) +type DataTypeKind string + +// DataTypeKind possible values +const ( + BaseType DataTypeKind = "base" + EnumType DataTypeKind = "enum" + UserDefinedType DataTypeKind = "user-defined" + ArrayType DataTypeKind = "array" +) + +// DataType contains information about column data type +type DataType struct { + Name string + Kind DataTypeKind + IsUnsigned bool +} diff --git a/generator/metadata/dialect_query_set.go b/generator/metadata/dialect_query_set.go new file mode 100644 index 00000000..66a32a65 --- /dev/null +++ b/generator/metadata/dialect_query_set.go @@ -0,0 +1,36 @@ +package metadata + +import ( + "database/sql" + "fmt" +) + +// TableType is type of database table(view or base) +type TableType string + +// SQL table types +const ( + BaseTable TableType = "BASE TABLE" + ViewTable TableType = "VIEW" +) + +// DialectQuerySet is set of methods necessary to retrieve dialect meta data information +type DialectQuerySet interface { + GetTablesMetaData(db *sql.DB, schemaName string, tableType TableType) []Table + GetEnumsMetaData(db *sql.DB, schemaName string) []Enum +} + +// GetSchema retrieves Schema information from database +func GetSchema(db *sql.DB, querySet DialectQuerySet, schemaName string) Schema { + ret := Schema{ + Name: schemaName, + TablesMetaData: querySet.GetTablesMetaData(db, schemaName, BaseTable), + ViewsMetaData: querySet.GetTablesMetaData(db, schemaName, ViewTable), + EnumsMetaData: querySet.GetEnumsMetaData(db, schemaName), + } + + fmt.Println(" FOUND", len(ret.TablesMetaData), "table(s),", len(ret.ViewsMetaData), "view(s),", + len(ret.EnumsMetaData), "enum(s)") + + return ret +} diff --git a/generator/metadata/enum_meta_data.go b/generator/metadata/enum_meta_data.go new file mode 100644 index 00000000..7aea3d6e --- /dev/null +++ b/generator/metadata/enum_meta_data.go @@ -0,0 +1,7 @@ +package metadata + +// Enum metadata struct +type Enum struct { + Name string `sql:"primary_key"` + Values []string +} diff --git a/generator/metadata/schema_meta_data.go b/generator/metadata/schema_meta_data.go new file mode 100644 index 00000000..c4c505a1 --- /dev/null +++ b/generator/metadata/schema_meta_data.go @@ -0,0 +1,14 @@ +package metadata + +// Schema struct +type Schema struct { + Name string + TablesMetaData []Table + ViewsMetaData []Table + EnumsMetaData []Enum +} + +// IsEmpty returns true if schema info does not contain any table, views or enums metadata +func (s Schema) IsEmpty() bool { + return len(s.TablesMetaData) == 0 && len(s.ViewsMetaData) == 0 && len(s.EnumsMetaData) == 0 +} diff --git a/generator/metadata/table_meta_data.go b/generator/metadata/table_meta_data.go new file mode 100644 index 00000000..6479dc20 --- /dev/null +++ b/generator/metadata/table_meta_data.go @@ -0,0 +1,22 @@ +package metadata + +// Table metadata struct +type Table struct { + Name string + Columns []Column +} + +// MutableColumns returns list of mutable columns for table +func (t Table) MutableColumns() []Column { + var ret []Column + + for _, column := range t.Columns { + if column.IsPrimaryKey { + continue + } + + ret = append(ret, column) + } + + return ret +} diff --git a/generator/mysql/mysql_generator.go b/generator/mysql/mysql_generator.go index 7f5d99a3..0fbd5f91 100644 --- a/generator/mysql/mysql_generator.go +++ b/generator/mysql/mysql_generator.go @@ -3,11 +3,14 @@ package mysql import ( "database/sql" "fmt" - "github.com/go-jet/jet/v2/generator/internal/metadata" - "github.com/go-jet/jet/v2/generator/internal/template" + "strings" + + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/mysql" - "path" + mysqldr "github.com/go-sql-driver/mysql" ) // DBConnection contains MySQL connection details @@ -22,34 +25,68 @@ type DBConnection struct { } // Generate generates jet files at destination dir from database connection details -func Generate(destDir string, dbConn DBConnection) (err error) { +func Generate(destDir string, dbConn DBConnection, generatorTemplate ...template.Template) (err error) { defer utils.ErrorCatch(&err) - db := openConnection(dbConn) + connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName) + if dbConn.Params != "" { + connectionString += "?" + dbConn.Params + } + + db := openConnection(connectionString) defer utils.DBClose(db) - fmt.Println("Retrieving database information...") - // No schemas in MySQL - dbInfo := metadata.GetSchemaMetaData(db, dbConn.DBName, &mySqlQuerySet{}) + generate(db, dbConn.DBName, destDir, generatorTemplate...) - genPath := path.Join(destDir, dbConn.DBName) + return nil +} + +// GenerateDSN opens connection via DSN string and does everything what Generate does. +func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error) { + defer utils.ErrorCatch(&err) - template.GenerateFiles(genPath, dbInfo, mysql.Dialect) + // Special case for go mysql driver. It does not understand schema, + // so we need to trim it before passing to generator + // https://github.com/go-sql-driver/mysql#dsn-data-source-name + idx := strings.Index(dsn, "://") + if idx != -1 { + dsn = dsn[idx+len("://"):] + } + + cfg, err := mysqldr.ParseDSN(dsn) + throw.OnError(err) + if cfg.DBName == "" { + panic("database name is required") + } + + db := openConnection(dsn) + defer utils.DBClose(db) + + generate(db, cfg.DBName, destDir, templates...) return nil } -func openConnection(dbConn DBConnection) *sql.DB { - var connectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName) - if dbConn.Params != "" { - connectionString += "?" + dbConn.Params - } +func openConnection(connectionString string) *sql.DB { fmt.Println("Connecting to MySQL database: " + connectionString) db, err := sql.Open("mysql", connectionString) - utils.PanicOnError(err) + throw.OnError(err) err = db.Ping() - utils.PanicOnError(err) + throw.OnError(err) return db } + +func generate(db *sql.DB, dbName, destDir string, templates ...template.Template) { + fmt.Println("Retrieving database information...") + // No schemas in MySQL + schemaMetaData := metadata.GetSchema(db, &mySqlQuerySet{}, dbName) + + genTemplate := template.Default(mysql.Dialect) + if len(templates) > 0 { + genTemplate = templates[0] + } + + template.ProcessSchema(destDir, schemaMetaData, genTemplate) +} diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index 1b4e2b2c..a409eb76 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -1,81 +1,91 @@ package mysql import ( + "context" "database/sql" - "github.com/go-jet/jet/v2/generator/internal/metadata" - "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/qrm" "strings" ) // mySqlQuerySet is dialect query set for MySQL type mySqlQuerySet struct{} -func (m *mySqlQuerySet) ListOfTablesQuery() string { - return ` -SELECT table_name +func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table { + query := ` +SELECT table_name as "table.name" FROM INFORMATION_SCHEMA.tables WHERE table_schema = ? and table_type = ?; ` -} + var tables []metadata.Table -func (m *mySqlQuerySet) PrimaryKeysQuery() string { - return ` -SELECT k.column_name -FROM information_schema.table_constraints t -JOIN information_schema.key_column_usage k -USING(constraint_name,table_schema,table_name) -WHERE t.constraint_type='PRIMARY KEY' - AND t.table_schema= ? - AND t.table_name= ?; -` + err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) + throw.OnError(err) + + for i := range tables { + tables[i].Columns = m.GetTableColumnsMetaData(db, schemaName, tables[i].Name) + } + + return tables } -func (m *mySqlQuerySet) ListOfColumnsQuery() string { - return ` -SELECT COLUMN_NAME, - IS_NULLABLE, IF(COLUMN_TYPE = 'tinyint(1)', 'boolean', DATA_TYPE), - IF(DATA_TYPE = 'enum', CONCAT(TABLE_NAME, '_', COLUMN_NAME), ''), - COLUMN_TYPE LIKE '%unsigned%' -FROM information_schema.columns -WHERE table_schema = ? and table_name = ? +func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column { + query := ` +WITH primaryKeys AS ( + SELECT k.column_name + FROM information_schema.table_constraints t + JOIN information_schema.key_column_usage k USING(constraint_name,table_schema,table_name) + WHERE table_schema = ? AND table_name = ? AND t.constraint_type='PRIMARY KEY' +) +SELECT COLUMN_NAME AS "column.Name", + IS_NULLABLE = "YES" AS "column.IsNullable", + (EXISTS(SELECT 1 FROM primaryKeys AS pk WHERE pk.column_name = columns.column_name)) AS "column.IsPrimaryKey", + IF (COLUMN_TYPE = 'tinyint(1)', + 'boolean', + IF (DATA_TYPE='enum', + CONCAT(TABLE_NAME, '_', COLUMN_NAME), + DATA_TYPE) + ) AS "dataType.Name", + IF (DATA_TYPE = 'enum', 'enum', 'base') AS "dataType.Kind", + COLUMN_TYPE LIKE '%unsigned%' AS "dataType.IsUnsigned" +FROM information_schema.columns +WHERE table_schema = ? AND table_name = ? ORDER BY ordinal_position; ` + var columns []metadata.Column + err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns) + throw.OnError(err) + + return columns } -func (m *mySqlQuerySet) ListOfEnumsQuery() string { - return ` -SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ), SUBSTRING(c.COLUMN_TYPE,5) +func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum { + query := ` +SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ) as "name", + SUBSTRING(c.COLUMN_TYPE,5) as "values" FROM information_schema.columns as c INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name) WHERE c.table_schema = ? AND DATA_TYPE = 'enum'; ` -} - -func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData { - - rows, err := db.Query(m.ListOfEnumsQuery(), schemaName) - utils.PanicOnError(err) - defer rows.Close() + var queryResult []struct { + Name string + Values string + } - ret := []metadata.MetaData{} + err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult) + throw.OnError(err) - for rows.Next() { - var enumName string - var enumValues string - err = rows.Scan(&enumName, &enumValues) - utils.PanicOnError(err) + var ret []metadata.Enum - enumValues = strings.Replace(enumValues[1:len(enumValues)-1], "'", "", -1) + for _, result := range queryResult { + enumValues := strings.Replace(result.Values[1:len(result.Values)-1], "'", "", -1) - ret = append(ret, metadata.EnumMetaData{ - EnumName: enumName, - Values: strings.Split(enumValues, ","), + ret = append(ret, metadata.Enum{ + Name: result.Name, + Values: strings.Split(enumValues, ","), }) } - err = rows.Err() - utils.PanicOnError(err) - return ret - } diff --git a/generator/postgres/postgres_generator.go b/generator/postgres/postgres_generator.go index 970fd2de..1c3e8859 100644 --- a/generator/postgres/postgres_generator.go +++ b/generator/postgres/postgres_generator.go @@ -3,12 +3,16 @@ package postgres import ( "database/sql" "fmt" - "github.com/go-jet/jet/v2/generator/internal/metadata" - "github.com/go-jet/jet/v2/generator/internal/template" - "github.com/go-jet/jet/v2/internal/utils" - "github.com/go-jet/jet/v2/postgres" + "net/url" "path" "strconv" + + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/generator/template" + "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/postgres" + "github.com/jackc/pgconn" ) // DBConnection contains postgres connection details @@ -25,38 +29,53 @@ type DBConnection struct { } // Generate generates jet files at destination dir from database connection details -func Generate(destDir string, dbConn DBConnection) (err error) { +func Generate(destDir string, dbConn DBConnection, genTemplate ...template.Template) (err error) { + dsn := fmt.Sprintf("postgresql://%s:%s@%s:%s/%s?sslmode=%s", + url.PathEscape(dbConn.User), + url.PathEscape(dbConn.Password), + dbConn.Host, + strconv.Itoa(dbConn.Port), + url.PathEscape(dbConn.DBName), + dbConn.SslMode, + ) + + return GenerateDSN(dsn, dbConn.SchemaName, destDir, genTemplate...) +} + +// GenerateDSN generates jet files using dsn connection string +func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) (err error) { defer utils.ErrorCatch(&err) - db, err := openConnection(dbConn) - utils.PanicOnError(err) + cfg, err := pgconn.ParseConfig(dsn) + throw.OnError(err) + if cfg.Database == "" { + panic("database name is required") + } + db := openConnection(dsn) defer utils.DBClose(db) fmt.Println("Retrieving schema information...") - schemaInfo := metadata.GetSchemaMetaData(db, dbConn.SchemaName, &postgresQuerySet{}) + generatorTemplate := template.Default(postgres.Dialect) + if len(templates) > 0 { + generatorTemplate = templates[0] + } - genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName) - template.GenerateFiles(genPath, schemaInfo, postgres.Dialect) + schemaMetadata := metadata.GetSchema(db, &postgresQuerySet{}, schema) + dirPath := path.Join(destDir, cfg.Database) + + template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate) return } -func openConnection(dbConn DBConnection) (*sql.DB, error) { - connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s", - dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params) +func openConnection(dsn string) *sql.DB { + fmt.Println("Connecting to postgres database: " + dsn) - fmt.Println("Connecting to postgres database: " + connectionString) - - db, err := sql.Open("postgres", connectionString) - if err != nil { - return nil, err - } + db, err := sql.Open("postgres", dsn) + throw.OnError(err) err = db.Ping() + throw.OnError(err) - if err != nil { - return nil, err - } - - return db, nil + return db } diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index 0fc8fdcc..e2fb9698 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -1,81 +1,83 @@ package postgres import ( + "context" "database/sql" - "github.com/go-jet/jet/v2/generator/internal/metadata" - "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/qrm" ) // postgresQuerySet is dialect query set for PostgreSQL type postgresQuerySet struct{} -func (p *postgresQuerySet) ListOfTablesQuery() string { - return ` -SELECT table_name +func (p postgresQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table { + query := ` +SELECT table_name as "table.name" FROM information_schema.tables -where table_schema = $1 and table_type = $2; +WHERE table_schema = $1 and table_type = $2; ` -} + var tables []metadata.Table -func (p *postgresQuerySet) PrimaryKeysQuery() string { - return ` -SELECT c.column_name -FROM information_schema.key_column_usage AS c -LEFT JOIN information_schema.table_constraints AS t -ON t.constraint_name = c.constraint_name -WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY'; -` + err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) + throw.OnError(err) + + for i := range tables { + tables[i].Columns = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name) + } + + return tables } -func (p *postgresQuerySet) ListOfColumnsQuery() string { - return ` -SELECT column_name, is_nullable, data_type, udt_name, FALSE -FROM information_schema.columns +func (p postgresQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column { + query := ` +WITH primaryKeys AS ( + SELECT column_name + FROM information_schema.key_column_usage AS c + LEFT JOIN information_schema.table_constraints AS t + ON t.constraint_name = c.constraint_name + WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY' +) +SELECT column_name as "column.Name", + is_nullable = 'YES' as "column.isNullable", + (EXISTS(SELECT 1 from primaryKeys as pk where pk.column_name = columns.column_name)) as "column.IsPrimaryKey", + dataType.kind as "dataType.Kind", + (case dataType.Kind when 'base' then data_type else LTRIM(udt_name, '_') end) as "dataType.Name", + FALSE as "dataType.isUnsigned" +FROM information_schema.columns, + LATERAL (select (case data_type + when 'ARRAY' then 'array' + when 'USER-DEFINED' then + case (select typtype from pg_type where typname = columns.udt_name) + when 'e' then 'enum' + else 'user-defined' + end + else 'base' + end) as Kind) as dataType where table_schema = $1 and table_name = $2 -order by ordinal_position;` +order by ordinal_position; +` + var columns []metadata.Column + err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns) + throw.OnError(err) + + return columns } -func (p *postgresQuerySet) ListOfEnumsQuery() string { - return ` -SELECT t.typname, - e.enumlabel +func (p postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum { + query := ` +SELECT t.typname as "enum.name", + e.enumlabel as "values" FROM pg_catalog.pg_type t JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace WHERE n.nspname = $1 ORDER BY n.nspname, t.typname, e.enumsortorder;` -} -func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData { - rows, err := db.Query(p.ListOfEnumsQuery(), schemaName) - utils.PanicOnError(err) - defer rows.Close() + var result []metadata.Enum - enumsInfosMap := map[string][]string{} - for rows.Next() { - var enumName string - var enumValue string - err = rows.Scan(&enumName, &enumValue) - utils.PanicOnError(err) - - enumValues := enumsInfosMap[enumName] - - enumValues = append(enumValues, enumValue) - - enumsInfosMap[enumName] = enumValues - } - - err = rows.Err() - utils.PanicOnError(err) - - ret := []metadata.MetaData{} - - for enumName, enumValues := range enumsInfosMap { - ret = append(ret, metadata.EnumMetaData{ - EnumName: enumName, - Values: enumValues, - }) - } + err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result) + throw.OnError(err) - return ret + return result } diff --git a/generator/sqlite/query_set.go b/generator/sqlite/query_set.go new file mode 100644 index 00000000..e1d5e4d1 --- /dev/null +++ b/generator/sqlite/query_set.go @@ -0,0 +1,80 @@ +package sqlite + +import ( + "context" + "database/sql" + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/qrm" + "strings" +) + +// sqliteQuerySet is dialect query set for SQLite +type sqliteQuerySet struct{} + +func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table { + query := ` + SELECT name as "table.name" + FROM sqlite_master + WHERE type=? AND name != 'sqlite_sequence' + ORDER BY name; +` + sqlTableType := "table" + + if tableType == metadata.ViewTable { + sqlTableType = "view" + } + + var tables []metadata.Table + + err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables) + throw.OnError(err) + + for i := range tables { + tables[i].Columns = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name) + } + + return tables +} + +func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column { + query := fmt.Sprintf(`select * from pragma_table_info(?);`) + var columnInfos []struct { + Name string + Type string + NotNull int32 + Pk int32 + } + + err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos) + throw.OnError(err) + + var columns []metadata.Column + + for _, columnInfo := range columnInfos { + columnType := getColumnType(columnInfo.Type) + + columns = append(columns, metadata.Column{ + Name: columnInfo.Name, + IsPrimaryKey: columnInfo.Pk != 0, + IsNullable: columnInfo.NotNull != 1, + DataType: metadata.DataType{ + Name: columnType, + Kind: metadata.BaseType, + IsUnsigned: false, + }, + }) + } + + return columns +} + +// will convert VARCHAR(10) -> VARCHAR, etc... +func getColumnType(columnType string) string { + return strings.TrimSpace(strings.Split(columnType, "(")[0]) +} + +func (p sqliteQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum { + return nil +} diff --git a/generator/sqlite/sqlite_generator.go b/generator/sqlite/sqlite_generator.go new file mode 100644 index 00000000..78873941 --- /dev/null +++ b/generator/sqlite/sqlite_generator.go @@ -0,0 +1,32 @@ +package sqlite + +import ( + "database/sql" + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/generator/template" + "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/sqlite" +) + +// GenerateDSN generates jet files using dsn connection string +func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error) { + defer utils.ErrorCatch(&err) + + db, err := sql.Open("sqlite3", dsn) + throw.OnError(err) + defer utils.DBClose(db) + + fmt.Println("Retrieving schema information...") + + generatorTemplate := template.Default(sqlite.Dialect) + if len(templates) > 0 { + generatorTemplate = templates[0] + } + + schemaMetadata := metadata.GetSchema(db, &sqliteQuerySet{}, "") + + template.ProcessSchema(destDir, schemaMetadata, generatorTemplate) + return +} diff --git a/generator/template/file_templates.go b/generator/template/file_templates.go new file mode 100644 index 00000000..e3020cea --- /dev/null +++ b/generator/template/file_templates.go @@ -0,0 +1,229 @@ +package template + +var autoGenWarningTemplate = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +` + +var tableSQLBuilderTemplate = ` +{{define "column-list" -}} + {{- range $i, $c := . }} +{{- $field := columnField $c}} + {{- if gt $i 0 }}, {{end}}{{$field.Name}}Column + {{- end}} +{{- end}} + +package {{package}} + +import ( + "github.com/go-jet/jet/v2/{{dialect.PackageName}}" +) + +var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "") + +type {{tableTemplate.TypeName}} struct { + {{dialect.PackageName}}.Table + + //Columns +{{- range $i, $c := .Columns}} +{{- $field := columnField $c}} + {{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}} +{{- end}} + + AllColumns {{dialect.PackageName}}.ColumnList + MutableColumns {{dialect.PackageName}}.ColumnList +} + +// AS creates new {{tableTemplate.TypeName}} with assigned alias +func (a {{tableTemplate.TypeName}}) AS(alias string) {{tableTemplate.TypeName}} { + return new{{tableTemplate.TypeName}}(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new {{tableTemplate.TypeName}} with assigned schema name +func (a {{tableTemplate.TypeName}}) FromSchema(schemaName string) {{tableTemplate.TypeName}} { + return new{{tableTemplate.TypeName}}(schemaName, a.TableName(), a.Alias()) +} + +func new{{tableTemplate.TypeName}}(schemaName, tableName, alias string) {{tableTemplate.TypeName}} { + var ( +{{- range $i, $c := .Columns}} +{{- $field := columnField $c}} + {{$field.Name}}Column = {{dialect.PackageName}}.{{$field.Type}}Column("{{$c.Name}}") +{{- end}} + allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} } + mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} } + ) + + return {{tableTemplate.TypeName}}{ + Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns +{{- range $i, $c := .Columns}} +{{- $field := columnField $c}} + {{$field.Name}}: {{$field.Name}}Column, +{{- end}} + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} +` + +var tableSQLBuilderTemplateWithEXCLUDED = ` +{{define "column-list" -}} + {{- range $i, $c := . }} +{{- $field := columnField $c}} + {{- if gt $i 0 }}, {{end}}{{$field.Name}}Column + {{- end}} +{{- end}} + +package {{package}} + +import ( + "github.com/go-jet/jet/v2/{{dialect.PackageName}}" +) + +var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "") + +type {{structImplName}} struct { + {{dialect.PackageName}}.Table + + //Columns +{{- range $i, $c := .Columns}} +{{- $field := columnField $c}} + {{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}} +{{- end}} + + AllColumns {{dialect.PackageName}}.ColumnList + MutableColumns {{dialect.PackageName}}.ColumnList +} + +type {{tableTemplate.TypeName}} struct { + {{structImplName}} + + EXCLUDED {{structImplName}} +} + +// AS creates new {{tableTemplate.TypeName}} with assigned alias +func (a {{tableTemplate.TypeName}}) AS(alias string) *{{tableTemplate.TypeName}} { + return new{{tableTemplate.TypeName}}(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new {{tableTemplate.TypeName}} with assigned schema name +func (a {{tableTemplate.TypeName}}) FromSchema(schemaName string) *{{tableTemplate.TypeName}} { + return new{{tableTemplate.TypeName}}(schemaName, a.TableName(), a.Alias()) +} + +func new{{tableTemplate.TypeName}}(schemaName, tableName, alias string) *{{tableTemplate.TypeName}} { + return &{{tableTemplate.TypeName}}{ + {{structImplName}}: new{{tableTemplate.TypeName}}Impl(schemaName, tableName, alias), + EXCLUDED: new{{tableTemplate.TypeName}}Impl("", "excluded", ""), + } +} + +func new{{tableTemplate.TypeName}}Impl(schemaName, tableName, alias string) {{structImplName}} { + var ( +{{- range $i, $c := .Columns}} +{{- $field := columnField $c}} + {{$field.Name}}Column = {{dialect.PackageName}}.{{$field.Type}}Column("{{$c.Name}}") +{{- end}} + allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} } + mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} } + ) + + return {{structImplName}}{ + Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns +{{- range $i, $c := .Columns}} +{{- $field := columnField $c}} + {{$field.Name}}: {{$field.Name}}Column, +{{- end}} + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} +` + +var tableModelFileTemplate = `package {{package}} + +{{ with modelImports }} +import ( +{{- range .}} + "{{.}}" +{{- end}} +) +{{end}} + +{{$modelTableTemplate := tableTemplate}} +type {{$modelTableTemplate.TypeName}} struct { +{{- range .Columns}} +{{- $field := structField .}} + {{$field.Name}} {{$field.Type.Name}} ` + "{{$field.TagsString}}" + ` +{{- end}} +} + +` + +var enumSQLBuilderTemplate = `package {{package}} + +import "github.com/go-jet/jet/v2/{{dialect.PackageName}}" + +var {{enumTemplate.InstanceName}} = &struct { +{{- range $index, $value := .Values}} + {{enumValueName $value}} {{dialect.PackageName}}.StringExpression +{{- end}} +} { +{{- range $index, $value := .Values}} + {{enumValueName $value}}: {{dialect.PackageName}}.NewEnumValue("{{$value}}"), +{{- end}} +} +` + +var enumModelTemplate = `package {{package}} +{{- $enumTemplate := enumTemplate}} + +import "errors" + +type {{$enumTemplate.TypeName}} string + +const ( +{{- range $_, $value := .Values}} + {{valueName $value}} {{$enumTemplate.TypeName}} = "{{$value}}" +{{- end}} +) + +func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error { + var enumValue string + switch val := value.(type) { + case string: + enumValue = val + case []byte: + enumValue = string(val) + default: + return errors.New("jet: Invalid scan value for AllTypesEnum enum. Enum value has to be of type string or []byte") + } + + switch enumValue { +{{- range $_, $value := .Values}} + case "{{$value}}": + *e = {{valueName $value}} +{{- end}} + default: + return errors.New("jet: Invalid scan value '" + enumValue + "' for {{$enumTemplate.TypeName}} enum") + } + + return nil +} + +func (e {{$enumTemplate.TypeName}}) String() string { + return string(e) +} + +` diff --git a/generator/template/generator_template.go b/generator/template/generator_template.go new file mode 100644 index 00000000..38e8fbb5 --- /dev/null +++ b/generator/template/generator_template.go @@ -0,0 +1,60 @@ +package template + +import ( + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/jet" +) + +// Template is generator template used for file generation +type Template struct { + Dialect jet.Dialect + Schema func(schemaMetaData metadata.Schema) Schema +} + +// Default is default generator template implementation +func Default(dialect jet.Dialect) Template { + return Template{ + Dialect: dialect, + Schema: DefaultSchema, + } +} + +// UseSchema replaces current schema generate function with a new implementation and returns new generator template +func (t Template) UseSchema(schemaFunc func(schemaMetaData metadata.Schema) Schema) Template { + t.Schema = schemaFunc + return t +} + +// Schema is schema generator template used to generate schema(model and sql builder) files +type Schema struct { + Path string + Model Model + SQLBuilder SQLBuilder +} + +// UsePath replaces path and returns new schema template +func (s Schema) UsePath(path string) Schema { + s.Path = path + return s +} + +// UseModel returns new schema template with replaced template for model files generation +func (s Schema) UseModel(model Model) Schema { + s.Model = model + return s +} + +// UseSQLBuilder returns new schema with replaced template for sql builder files generation +func (s Schema) UseSQLBuilder(sqlBuilder SQLBuilder) Schema { + s.SQLBuilder = sqlBuilder + return s +} + +// DefaultSchema returns default schema template implementation +func DefaultSchema(schemaMetaData metadata.Schema) Schema { + return Schema{ + Path: schemaMetaData.Name, + Model: DefaultModel(), + SQLBuilder: DefaultSQLBuilder(), + } +} diff --git a/generator/template/model_template.go b/generator/template/model_template.go new file mode 100644 index 00000000..032afc87 --- /dev/null +++ b/generator/template/model_template.go @@ -0,0 +1,327 @@ +package template + +import ( + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/utils" + "github.com/google/uuid" + "path" + "reflect" + "strings" + "time" +) + +// Model is template for model files generation +type Model struct { + Skip bool + Path string + Table func(table metadata.Table) TableModel + View func(table metadata.Table) ViewModel + Enum func(enum metadata.Enum) EnumModel +} + +// PackageName returns package name of model types +func (m Model) PackageName() string { + return path.Base(m.Path) +} + +// UsePath returns new Model template with replaced file path +func (m Model) UsePath(path string) Model { + m.Path = path + return m +} + +// UseTable returns new Model template with replaced template for table model files generation +func (m Model) UseTable(tableModelFunc func(table metadata.Table) TableModel) Model { + m.Table = tableModelFunc + return m +} + +// UseView returns new Model template with replaced template for view model files generation +func (m Model) UseView(tableModelFunc func(table metadata.Table) TableModel) Model { + m.View = tableModelFunc + return m +} + +// UseEnum returns new Model template with replaced template for enum model files generation +func (m Model) UseEnum(enumFunc func(enumMetaData metadata.Enum) EnumModel) Model { + m.Enum = enumFunc + return m +} + +// DefaultModel returns default Model template implementation +func DefaultModel() Model { + return Model{ + Skip: false, + Path: "/model", + Table: DefaultTableModel, + View: DefaultViewModel, + Enum: DefaultEnumModel, + } +} + +// TableModel is template for table model files generation +type TableModel struct { + Skip bool + FileName string + TypeName string + Field func(columnMetaData metadata.Column) TableModelField +} + +// ViewModel is template for view model files generation +type ViewModel = TableModel + +// DefaultViewModel is default view template implementation +var DefaultViewModel = DefaultTableModel + +// DefaultTableModel is default table template implementation +func DefaultTableModel(tableMetaData metadata.Table) TableModel { + return TableModel{ + FileName: utils.ToGoFileName(tableMetaData.Name), + TypeName: utils.ToGoIdentifier(tableMetaData.Name), + Field: DefaultTableModelField, + } +} + +// UseFileName returns new TableModel with new file name set +func (t TableModel) UseFileName(fileName string) TableModel { + t.FileName = fileName + return t +} + +// UseTypeName returns new TableModel with new type name set +func (t TableModel) UseTypeName(typeName string) TableModel { + t.TypeName = typeName + return t +} + +// UseField returns new TableModel with new TableModelField template function +func (t TableModel) UseField(structFieldFunc func(columnMetaData metadata.Column) TableModelField) TableModel { + t.Field = structFieldFunc + return t +} + +func getTableModelImports(modelType TableModel, tableMetaData metadata.Table) []string { + importPaths := map[string]bool{} + for _, columnMetaData := range tableMetaData.Columns { + field := modelType.Field(columnMetaData) + importPath := field.Type.ImportPath + + if importPath != "" { + importPaths[importPath] = true + } + } + + var ret []string + for importPath := range importPaths { + ret = append(ret, importPath) + } + + return ret +} + +// EnumModel is template for enum model files generation +type EnumModel struct { + Skip bool + FileName string + TypeName string + ValueName func(value string) string +} + +// UseFileName returns new EnumModel with new file name set +func (em EnumModel) UseFileName(fileName string) EnumModel { + em.FileName = fileName + return em +} + +// UseTypeName returns new EnumModel with new type name set +func (em EnumModel) UseTypeName(typeName string) EnumModel { + em.TypeName = typeName + return em +} + +// DefaultEnumModel returns default implementation for EnumModel +func DefaultEnumModel(enumMetaData metadata.Enum) EnumModel { + typeName := utils.ToGoIdentifier(enumMetaData.Name) + + return EnumModel{ + FileName: utils.ToGoFileName(enumMetaData.Name), + TypeName: typeName, + ValueName: func(value string) string { + return typeName + "_" + utils.ToGoIdentifier(value) + }, + } +} + +// TableModelField is template for table model field generation +type TableModelField struct { + Name string + Type Type + Tags []string +} + +// DefaultTableModelField returns default TableModelField implementation +func DefaultTableModelField(columnMetaData metadata.Column) TableModelField { + var tags []string + + if columnMetaData.IsPrimaryKey { + tags = append(tags, `sql:"primary_key"`) + } + + return TableModelField{ + Name: utils.ToGoIdentifier(columnMetaData.Name), + Type: getType(columnMetaData), + Tags: tags, + } +} + +// UseType returns new TypeModelField with a new field type set +func (f TableModelField) UseType(t Type) TableModelField { + f.Type = t + return f +} + +// UseName returns new TableModelField implementation with new field name set +func (f TableModelField) UseName(name string) TableModelField { + f.Name = name + return f +} + +// UseTags returns new TableModelField implementation with additional tags added. +func (f TableModelField) UseTags(tags ...string) TableModelField { + f.Tags = append(f.Tags, tags...) + return f +} + +// TagsString returns tags string representation +func (f TableModelField) TagsString() string { + if len(f.Tags) == 0 { + return "" + } + + return fmt.Sprintf("`%s`", strings.Join(f.Tags, " ")) +} + +// Type represents type of the struct field +type Type struct { + ImportPath string + Name string +} + +// NewType creates new type for dummy object +func NewType(dummyObject interface{}) Type { + return Type{ + ImportPath: getImportPath(dummyObject), + Name: getTypeName(dummyObject), + } +} + +func getTypeName(t interface{}) string { + typeStr := reflect.TypeOf(t).String() + typeStr = strings.Replace(typeStr, "[]uint8", "[]byte", -1) + + return typeStr +} + +func getImportPath(dummyData interface{}) string { + dataType := reflect.TypeOf(dummyData) + if dataType.Kind() == reflect.Ptr { + return dataType.Elem().PkgPath() + } + return dataType.PkgPath() +} + +func getType(columnMetadata metadata.Column) Type { + userDefinedType := getUserDefinedType(columnMetadata) + + if userDefinedType != "" { + if columnMetadata.IsNullable { + return Type{Name: "*" + userDefinedType} + } + return Type{Name: userDefinedType} + } + + return NewType(getGoType(columnMetadata)) +} + +func getUserDefinedType(column metadata.Column) string { + switch column.DataType.Kind { + case metadata.EnumType: + return utils.ToGoIdentifier(column.DataType.Name) + case metadata.UserDefinedType, metadata.ArrayType: + return "string" + } + + return "" +} + +func getGoType(column metadata.Column) interface{} { + defaultGoType := toGoType(column) + + if column.IsNullable { + return reflect.New(reflect.TypeOf(defaultGoType)).Interface() + } + + return defaultGoType +} + +// toGoType returns model type for column info. +func toGoType(column metadata.Column) interface{} { + switch strings.ToLower(column.DataType.Name) { + case "user-defined", "enum": + return "" + case "boolean", "bool": + return false + case "tinyint": + if column.DataType.IsUnsigned { + return uint8(0) + } + return int8(0) + case "smallint", "int2", + "year": + if column.DataType.IsUnsigned { + return uint16(0) + } + return int16(0) + case "integer", "int4", + "mediumint", "int": //MySQL + if column.DataType.IsUnsigned { + return uint32(0) + } + return int32(0) + case "bigint", "int8": + if column.DataType.IsUnsigned { + return uint64(0) + } + return int64(0) + case "date", + "timestamp without time zone", "timestamp", + "timestamp with time zone", "timestamptz", + "time without time zone", "time", + "time with time zone", "timetz", + "datetime": // MySQL + return time.Time{} + case "bytea", + "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": //MySQL + return []byte("") + case "text", + "character", "bpchar", + "character varying", "varchar", "nvarchar", + "tsvector", "bit", "bit varying", "varbit", + "money", "json", "jsonb", + "xml", "point", "interval", "line", "array", + "char", "tinytext", "mediumtext", "longtext": // MySQL + return "" + case "real", "float4": + return float32(0.0) + case "numeric", "decimal", + "double precision", "float8", "float", + "double": // MySQL + return float64(0.0) + case "uuid": + return uuid.UUID{} + default: + fmt.Println("- [Model ] Unsupported sql column '" + column.Name + " " + column.DataType.Name + "', using string instead.") + return "" + } +} diff --git a/generator/template/model_template_test.go b/generator/template/model_template_test.go new file mode 100644 index 00000000..a7bbe287 --- /dev/null +++ b/generator/template/model_template_test.go @@ -0,0 +1,45 @@ +package template + +import ( + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/stretchr/testify/require" + "testing" +) + +func Test_TableModelField(t *testing.T) { + require.Equal(t, DefaultTableModelField(metadata.Column{ + Name: "col_name", + IsPrimaryKey: true, + IsNullable: true, + DataType: metadata.DataType{ + Name: "smallint", + Kind: "base", + IsUnsigned: true, + }, + }), TableModelField{ + Name: "ColName", + Type: Type{ + ImportPath: "", + Name: "*uint16", + }, + Tags: []string{"sql:\"primary_key\""}, + }) + + require.Equal(t, DefaultTableModelField(metadata.Column{ + Name: "time_column_1", + IsPrimaryKey: false, + IsNullable: true, + DataType: metadata.DataType{ + Name: "timestamp with time zone", + Kind: "base", + IsUnsigned: false, + }, + }), TableModelField{ + Name: "TimeColumn1", + Type: Type{ + ImportPath: "time", + Name: "*time.Time", + }, + Tags: nil, + }) +} diff --git a/generator/template/process.go b/generator/template/process.go new file mode 100644 index 00000000..46a598de --- /dev/null +++ b/generator/template/process.go @@ -0,0 +1,269 @@ +package template + +import ( + "bytes" + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/jet" + "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" + "path" + "strings" + "text/template" +) + +// ProcessSchema will process schema metadata and constructs go files using generator Template +func ProcessSchema(dirPath string, schemaMetaData metadata.Schema, generatorTemplate Template) { + if schemaMetaData.IsEmpty() { + return + } + + schemaTemplate := generatorTemplate.Schema(schemaMetaData) + schemaPath := path.Join(dirPath, schemaTemplate.Path) + + fmt.Println("Destination directory:", schemaPath) + fmt.Println("Cleaning up destination directory...") + err := utils.CleanUpGeneratedFiles(schemaPath) + throw.OnError(err) + + processModel(schemaPath, schemaMetaData, schemaTemplate) + processSQLBuilder(schemaPath, generatorTemplate.Dialect, schemaMetaData, schemaTemplate) +} + +func processModel(dirPath string, schemaMetaData metadata.Schema, schemaTemplate Schema) { + modelTemplate := schemaTemplate.Model + + if modelTemplate.Skip { + fmt.Println("Skipping the generation of model types.") + return + } + + modelDirPath := path.Join(dirPath, modelTemplate.Path) + + err := utils.EnsureDirPath(modelDirPath) + throw.OnError(err) + + processTableModels("table", modelDirPath, schemaMetaData.TablesMetaData, modelTemplate) + processTableModels("view", modelDirPath, schemaMetaData.ViewsMetaData, modelTemplate) + processEnumModels(modelDirPath, schemaMetaData.EnumsMetaData, modelTemplate) +} + +func processSQLBuilder(dirPath string, dialect jet.Dialect, schemaMetaData metadata.Schema, schemaTemplate Schema) { + sqlBuilderTemplate := schemaTemplate.SQLBuilder + + if sqlBuilderTemplate.Skip { + fmt.Println("Skipping the generation of SQL Builder types.") + return + } + + sqlBuilderPath := path.Join(dirPath, sqlBuilderTemplate.Path) + + processTableSQLBuilder("table", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.TablesMetaData, sqlBuilderTemplate) + processTableSQLBuilder("view", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.ViewsMetaData, sqlBuilderTemplate) + processEnumSQLBuilder(sqlBuilderPath, dialect, schemaMetaData.EnumsMetaData, sqlBuilderTemplate) +} + +func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []metadata.Enum, sqlBuilder SQLBuilder) { + if len(enumsMetaData) == 0 { + return + } + + fmt.Printf("Generating enum sql builder files\n") + + for _, enumMetaData := range enumsMetaData { + enumTemplate := sqlBuilder.Enum(enumMetaData) + + if enumTemplate.Skip { + continue + } + + enumSQLBuilderPath := path.Join(dirPath, enumTemplate.Path) + + err := utils.EnsureDirPath(enumSQLBuilderPath) + throw.OnError(err) + + text, err := generateTemplate( + autoGenWarningTemplate+enumSQLBuilderTemplate, + enumMetaData, + template.FuncMap{ + "package": func() string { + return enumTemplate.PackageName() + }, + "dialect": func() jet.Dialect { + return dialect + }, + "enumTemplate": func() EnumSQLBuilder { + return enumTemplate + }, + "enumValueName": func(enumValue string) string { + return enumTemplate.ValueName(enumValue) + }, + }) + throw.OnError(err) + + err = utils.SaveGoFile(enumSQLBuilderPath, enumTemplate.FileName, text) + throw.OnError(err) + } +} + +func processTableSQLBuilder(fileTypes, dirPath string, + dialect jet.Dialect, + schemaMetaData metadata.Schema, + tablesMetaData []metadata.Table, + sqlBuilderTemplate SQLBuilder) { + + if len(tablesMetaData) == 0 { + return + } + + fmt.Printf("Generating %s sql builder files\n", fileTypes) + + for _, tableMetaData := range tablesMetaData { + + var tableSQLBuilderTemplate TableSQLBuilder + + if fileTypes == "view" { + tableSQLBuilderTemplate = sqlBuilderTemplate.View(tableMetaData) + } else { + tableSQLBuilderTemplate = sqlBuilderTemplate.Table(tableMetaData) + } + + if tableSQLBuilderTemplate.Skip { + continue + } + + tableSQLBuilderPath := path.Join(dirPath, tableSQLBuilderTemplate.Path) + + err := utils.EnsureDirPath(tableSQLBuilderPath) + throw.OnError(err) + + text, err := generateTemplate( + autoGenWarningTemplate+getTableSQLBuilderTemplate(dialect), + tableMetaData, + template.FuncMap{ + "package": func() string { + return tableSQLBuilderTemplate.PackageName() + }, + "dialect": func() jet.Dialect { + return dialect + }, + "schemaName": func() string { + return schemaMetaData.Name + }, + "tableTemplate": func() TableSQLBuilder { + return tableSQLBuilderTemplate + }, + "structImplName": func() string { // postgres only + structName := tableSQLBuilderTemplate.TypeName + return string(strings.ToLower(structName)[0]) + structName[1:] + }, + "columnField": func(columnMetaData metadata.Column) TableSQLBuilderColumn { + return tableSQLBuilderTemplate.Column(columnMetaData) + }, + }) + throw.OnError(err) + + err = utils.SaveGoFile(tableSQLBuilderPath, tableSQLBuilderTemplate.FileName, text) + throw.OnError(err) + } +} + +func getTableSQLBuilderTemplate(dialect jet.Dialect) string { + if dialect.Name() == "PostgreSQL" || dialect.Name() == "SQLite" { + return tableSQLBuilderTemplateWithEXCLUDED + } + + return tableSQLBuilderTemplate +} + +func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadata.Table, modelTemplate Model) { + if len(tablesMetaData) == 0 { + return + } + fmt.Printf("Generating %s model files...\n", fileTypes) + + for _, tableMetaData := range tablesMetaData { + var tableTemplate TableModel + + if fileTypes == "table" { + tableTemplate = modelTemplate.Table(tableMetaData) + } else { + tableTemplate = modelTemplate.View(tableMetaData) + } + + if tableTemplate.Skip { + continue + } + + text, err := generateTemplate( + autoGenWarningTemplate+tableModelFileTemplate, + tableMetaData, + template.FuncMap{ + "package": func() string { + return modelTemplate.PackageName() + }, + "modelImports": func() []string { + return getTableModelImports(tableTemplate, tableMetaData) + }, + "tableTemplate": func() TableModel { + return tableTemplate + }, + "structField": func(columnMetaData metadata.Column) TableModelField { + return tableTemplate.Field(columnMetaData) + }, + }) + throw.OnError(err) + + err = utils.SaveGoFile(modelDirPath, tableTemplate.FileName, text) + throw.OnError(err) + } +} + +func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemplate Model) { + if len(enumsMetaData) == 0 { + return + } + fmt.Print("Generating enum model files...\n") + + for _, enumMetaData := range enumsMetaData { + enumTemplate := modelTemplate.Enum(enumMetaData) + + if enumTemplate.Skip { + continue + } + + text, err := generateTemplate( + autoGenWarningTemplate+enumModelTemplate, + enumMetaData, + template.FuncMap{ + "package": func() string { + return modelTemplate.PackageName() + }, + "enumTemplate": func() EnumModel { + return enumTemplate + }, + "valueName": func(value string) string { + return enumTemplate.ValueName(value) + }, + }) + throw.OnError(err) + + err = utils.SaveGoFile(modelDir, enumTemplate.FileName, text) + throw.OnError(err) + } +} + +func generateTemplate(templateText string, templateData interface{}, funcMap template.FuncMap) ([]byte, error) { + t, err := template.New("sqlBuilderTableTemplate").Funcs(funcMap).Parse(templateText) + + if err != nil { + return nil, err + } + + var buf bytes.Buffer + if err := t.Execute(&buf, templateData); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} diff --git a/generator/template/sql_builder_template.go b/generator/template/sql_builder_template.go new file mode 100644 index 00000000..099c0e3d --- /dev/null +++ b/generator/template/sql_builder_template.go @@ -0,0 +1,226 @@ +package template + +import ( + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/utils" + "path" + "strings" + "unicode" +) + +// SQLBuilder is template for generating sql builder files +type SQLBuilder struct { + Skip bool + Path string + Table func(table metadata.Table) TableSQLBuilder + View func(view metadata.Table) TableSQLBuilder + Enum func(enum metadata.Enum) EnumSQLBuilder +} + +// DefaultSQLBuilder returns default SQLBuilder implementation +func DefaultSQLBuilder() SQLBuilder { + return SQLBuilder{ + Path: "", + Table: DefaultTableSQLBuilder, + View: DefaultViewSQLBuilder, + Enum: DefaultEnumSQLBuilder, + } +} + +// UsePath returns new SQLBuilder with new relative path set +func (sb SQLBuilder) UsePath(path string) SQLBuilder { + sb.Path = path + return sb +} + +// UseTable returns new SQLBuilder with new TableSQLBuilder template function set +func (sb SQLBuilder) UseTable(tableFunc func(table metadata.Table) TableSQLBuilder) SQLBuilder { + sb.Table = tableFunc + return sb +} + +// UseView returns new SQLBuilder with new ViewSQLBuilder template function set +func (sb SQLBuilder) UseView(viewFunc func(table metadata.Table) ViewSQLBuilder) SQLBuilder { + sb.View = viewFunc + return sb +} + +// UseEnum returns new SQLBuilder with new EnumSQLBuilder template function set +func (sb SQLBuilder) UseEnum(enumFunc func(enum metadata.Enum) EnumSQLBuilder) SQLBuilder { + sb.Enum = enumFunc + return sb +} + +// TableSQLBuilder is template for generating table SQLBuilder files +type TableSQLBuilder struct { + Skip bool + Path string + FileName string + InstanceName string + TypeName string + Column func(columnMetaData metadata.Column) TableSQLBuilderColumn +} + +// ViewSQLBuilder is template for generating view SQLBuilder files +type ViewSQLBuilder = TableSQLBuilder + +// DefaultTableSQLBuilder returns default implementation for TableSQLBuilder +func DefaultTableSQLBuilder(tableMetaData metadata.Table) TableSQLBuilder { + return TableSQLBuilder{ + Path: "/table", + FileName: utils.ToGoFileName(tableMetaData.Name), + InstanceName: utils.ToGoIdentifier(tableMetaData.Name), + TypeName: utils.ToGoIdentifier(tableMetaData.Name) + "Table", + Column: DefaultTableSQLBuilderColumn, + } +} + +// DefaultViewSQLBuilder returns default implementation for ViewSQLBuilder +func DefaultViewSQLBuilder(viewMetaData metadata.Table) ViewSQLBuilder { + tableSQLBuilder := DefaultTableSQLBuilder(viewMetaData) + tableSQLBuilder.Path = "/view" + return tableSQLBuilder +} + +// PackageName returns package name of table sql builder types +func (tb TableSQLBuilder) PackageName() string { + return path.Base(tb.Path) +} + +// UsePath returns new TableSQLBuilder with new relative path set +func (tb TableSQLBuilder) UsePath(path string) TableSQLBuilder { + tb.Path = path + return tb +} + +// UseFileName returns new TableSQLBuilder with new file name set +func (tb TableSQLBuilder) UseFileName(name string) TableSQLBuilder { + tb.FileName = name + return tb +} + +// UseInstanceName returns new TableSQLBuilder with new instance name set +func (tb TableSQLBuilder) UseInstanceName(name string) TableSQLBuilder { + tb.InstanceName = name + return tb +} + +// UseTypeName returns new TableSQLBuilder with new type name set +func (tb TableSQLBuilder) UseTypeName(name string) TableSQLBuilder { + tb.TypeName = name + return tb +} + +// UseColumn returns new TableSQLBuilder with new column template function set +func (tb TableSQLBuilder) UseColumn(columnsFunc func(column metadata.Column) TableSQLBuilderColumn) TableSQLBuilder { + tb.Column = columnsFunc + return tb +} + +// TableSQLBuilderColumn is template for table sql builder column +type TableSQLBuilderColumn struct { + Name string + Type string +} + +// DefaultTableSQLBuilderColumn returns default implementation of TableSQLBuilderColumn +func DefaultTableSQLBuilderColumn(columnMetaData metadata.Column) TableSQLBuilderColumn { + return TableSQLBuilderColumn{ + Name: utils.ToGoIdentifier(columnMetaData.Name), + Type: getSqlBuilderColumnType(columnMetaData), + } +} + +// getSqlBuilderColumnType returns type of jet sql builder column +func getSqlBuilderColumnType(columnMetaData metadata.Column) string { + if columnMetaData.DataType.Kind != metadata.BaseType { + return "String" + } + + switch strings.ToLower(columnMetaData.DataType.Name) { + case "boolean": + return "Bool" + case "smallint", "integer", "bigint", + "tinyint", "mediumint", "int", "year": //MySQL + return "Integer" + case "date": + return "Date" + case "timestamp without time zone", + "timestamp", "datetime": //MySQL: + return "Timestamp" + case "timestamp with time zone": + return "Timestampz" + case "time without time zone", + "time": //MySQL + return "Time" + case "time with time zone": + return "Timez" + case "interval": + return "Interval" + case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid", + "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", + "char", "varchar", "nvarchar", "binary", "varbinary", + "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL + return "String" + case "real", "numeric", "decimal", "double precision", "float", + "double": // MySQL + return "Float" + default: + fmt.Println("- [SQL Builder] Unsupported sql column '" + columnMetaData.Name + " " + columnMetaData.DataType.Name + "', using StringColumn instead.") + return "String" + } +} + +// EnumSQLBuilder is template for generating enum SQLBuilder files +type EnumSQLBuilder struct { + Skip bool + Path string + FileName string + InstanceName string + ValueName func(enumValue string) string +} + +// DefaultEnumSQLBuilder returns default implementation of EnumSQLBuilder +func DefaultEnumSQLBuilder(enumMetaData metadata.Enum) EnumSQLBuilder { + return EnumSQLBuilder{ + Path: "/enum", + FileName: utils.ToGoFileName(enumMetaData.Name), + InstanceName: utils.ToGoIdentifier(enumMetaData.Name), + ValueName: func(enumValue string) string { + return defaultEnumValueName(enumMetaData.Name, enumValue) + }, + } +} + +// PackageName returns enum sql builder package name +func (e EnumSQLBuilder) PackageName() string { + return path.Base(e.Path) +} + +// UsePath returns new EnumSQLBuilder with new path set +func (e EnumSQLBuilder) UsePath(path string) EnumSQLBuilder { + e.Path = path + return e +} + +// UseFileName returns new EnumSQLBuilder with new file name set +func (e EnumSQLBuilder) UseFileName(name string) EnumSQLBuilder { + e.FileName = name + return e +} + +// UseInstanceName returns new EnumSQLBuilder with instance name set +func (e EnumSQLBuilder) UseInstanceName(name string) EnumSQLBuilder { + e.InstanceName = name + return e +} + +func defaultEnumValueName(enumName, enumValue string) string { + enumValueName := utils.ToGoIdentifier(enumValue) + if !unicode.IsLetter([]rune(enumValueName)[0]) { + return utils.ToGoIdentifier(enumName) + enumValueName + } + + return enumValueName +} diff --git a/generator/template/sql_builder_template_test.go b/generator/template/sql_builder_template_test.go new file mode 100644 index 00000000..b3719d7a --- /dev/null +++ b/generator/template/sql_builder_template_test.go @@ -0,0 +1,11 @@ +package template + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestToGoEnumValueIdentifier(t *testing.T) { + require.Equal(t, defaultEnumValueName("enum_name", "enum_value"), "EnumValue") + require.Equal(t, defaultEnumValueName("NumEnum", "100"), "NumEnum100") +} diff --git a/go.mod b/go.mod index 9dd3e02a..12ed6d7b 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,10 @@ require ( github.com/go-sql-driver/mysql v1.5.0 github.com/google/go-cmp v0.5.0 //tests github.com/google/uuid v1.1.1 + github.com/jackc/pgconn v1.8.1 github.com/jackc/pgx/v4 v4.11.0 //tests github.com/lib/pq v1.7.0 + github.com/mattn/go-sqlite3 v1.14.8 github.com/pkg/profile v1.5.0 //tests github.com/shopspring/decimal v1.2.0 // tests github.com/stretchr/testify v1.6.1 // tests diff --git a/go.sum b/go.sum index 47af4791..d0f5f98b 100644 --- a/go.sum +++ b/go.sum @@ -42,7 +42,6 @@ github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7 github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -219,6 +218,8 @@ github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU= +github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= @@ -457,7 +458,6 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/3rdparty/snaker/snaker.go b/internal/3rdparty/snaker/snaker.go index aadd9286..32a19e6c 100644 --- a/internal/3rdparty/snaker/snaker.go +++ b/internal/3rdparty/snaker/snaker.go @@ -9,8 +9,12 @@ import ( ) // SnakeToCamel returns a string converted from snake case to uppercase -func SnakeToCamel(s string) string { - return snakeToCamel(s, true) +func SnakeToCamel(s string, firstLetterUppercase ...bool) string { + upperCase := true + if len(firstLetterUppercase) > 0 { + upperCase = firstLetterUppercase[0] + } + return snakeToCamel(s, upperCase) } func snakeToCamel(s string, upperCase bool) string { diff --git a/internal/jet/clause.go b/internal/jet/clause.go index a6a49d85..446a5451 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -217,12 +217,13 @@ func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder, opti // ClauseSetStmtOperator struct type ClauseSetStmtOperator struct { - Operator string - All bool - Selects []SerializerStatement - OrderBy ClauseOrderBy - Limit ClauseLimit - Offset ClauseOffset + Operator string + All bool + Selects []SerializerStatement + OrderBy ClauseOrderBy + Limit ClauseLimit + Offset ClauseOffset + SkipSelectWrap bool } // Projections returns set of projections for ClauseSetStmtOperator @@ -242,6 +243,10 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLB for i, selectStmt := range s.Selects { out.NewLine() if i > 0 { + if s.SkipSelectWrap { + out.NewLine() + } + out.WriteString(s.Operator) if s.All { @@ -254,7 +259,11 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLB panic("jet: select statement of '" + s.Operator + "' is nil") } - selectStmt.serialize(statementType, out, FallTrough(options)...) + if s.SkipSelectWrap { + options = append(FallTrough(options), NoWrap) + } + + selectStmt.serialize(statementType, out, options...) } s.OrderBy.Serialize(statementType, out) @@ -360,10 +369,6 @@ type ClauseValuesQuery struct { // Serialize serializes clause into SQLBuilder func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { - if len(v.Rows) == 0 && v.Query == nil { - panic("jet: VALUES or QUERY has to be specified for INSERT statement") - } - if len(v.Rows) > 0 && v.Query != nil { panic("jet: VALUES or QUERY has to be specified for INSERT statement") } @@ -405,7 +410,8 @@ func (v *ClauseValues) Serialize(statementType StatementType, out *SQLBuilder, o // ClauseQuery struct type ClauseQuery struct { - Query SerializerStatement + Query SerializerStatement + SkipSelectWrap bool } // Serialize serializes clause into SQLBuilder @@ -414,7 +420,11 @@ func (v *ClauseQuery) Serialize(statementType StatementType, out *SQLBuilder, op return } - v.Query.serialize(statementType, out, FallTrough(options)...) + if v.SkipSelectWrap { + options = append(FallTrough(options), NoWrap) + } + + v.Query.serialize(statementType, out, options...) } // ClauseDelete struct @@ -561,3 +571,26 @@ type KeywordClause struct { func (k KeywordClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { k.serialize(statementType, out, FallTrough(options)...) } + +// ClauseReturning type +type ClauseReturning struct { + ProjectionList []Projection +} + +// Serialize for ClauseReturning +func (r *ClauseReturning) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { + if len(r.ProjectionList) == 0 { + return + } + + out.NewLine() + out.WriteString("RETURNING") + out.IncreaseIdent() + out.WriteProjections(statementType, r.ProjectionList) + out.DecreaseIdent() +} + +// Projections for ClauseReturning +func (r ClauseReturning) Projections() ProjectionList { + return r.ProjectionList +} diff --git a/internal/jet/column_list.go b/internal/jet/column_list.go index 8483c76b..3ff829c5 100644 --- a/internal/jet/column_list.go +++ b/internal/jet/column_list.go @@ -11,6 +11,28 @@ func (cl ColumnList) SET(expression Expression) ColumnAssigment { } } +// Except will create new column list in which columns contained in excluded column names are removed +func (cl ColumnList) Except(excludedColumns ...Column) ColumnList { + excludedColumnList := UnwidColumnList(excludedColumns) + excludedColumnNames := map[string]bool{} + + for _, excludedColumn := range excludedColumnList { + excludedColumnNames[excludedColumn.Name()] = true + } + + var ret ColumnList + + for _, column := range cl { + if excludedColumnNames[column.Name()] { + continue + } + + ret = append(ret, column) + } + + return ret +} + func (cl ColumnList) fromImpl(subQuery SelectTable) Projection { newProjectionList := ProjectionList{} diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 606e7e1c..9a647e96 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -2,7 +2,7 @@ package jet // ROW is construct one table row from list of expressions. func ROW(expressions ...Expression) Expression { - return newFunc("ROW", expressions, nil) + return NewFunc("ROW", expressions, nil) } // ------------------ Mathematical functions ---------------// @@ -265,118 +265,118 @@ func OCTET_LENGTH(stringExpression StringExpression) IntegerExpression { // LOWER returns string expression in lower case func LOWER(stringExpression StringExpression) StringExpression { - return newStringFunc("LOWER", stringExpression) + return NewStringFunc("LOWER", stringExpression) } // UPPER returns string expression in upper case func UPPER(stringExpression StringExpression) StringExpression { - return newStringFunc("UPPER", stringExpression) + return NewStringFunc("UPPER", stringExpression) } // BTRIM removes the longest string consisting only of characters // in characters (a space by default) from the start and end of string func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) StringExpression { if len(trimChars) > 0 { - return newStringFunc("BTRIM", stringExpression, trimChars[0]) + return NewStringFunc("BTRIM", stringExpression, trimChars[0]) } - return newStringFunc("BTRIM", stringExpression) + return NewStringFunc("BTRIM", stringExpression) } // LTRIM removes the longest string containing only characters // from characters (a space by default) from the start of string func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { if len(trimChars) > 0 { - return newStringFunc("LTRIM", str, trimChars[0]) + return NewStringFunc("LTRIM", str, trimChars[0]) } - return newStringFunc("LTRIM", str) + return NewStringFunc("LTRIM", str) } // RTRIM removes the longest string containing only characters // from characters (a space by default) from the end of string func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { if len(trimChars) > 0 { - return newStringFunc("RTRIM", str, trimChars[0]) + return NewStringFunc("RTRIM", str, trimChars[0]) } - return newStringFunc("RTRIM", str) + return NewStringFunc("RTRIM", str) } // CHR returns character with the given code. func CHR(integerExpression IntegerExpression) StringExpression { - return newStringFunc("CHR", integerExpression) + return NewStringFunc("CHR", integerExpression) } // CONCAT adds two or more expressions together func CONCAT(expressions ...Expression) StringExpression { - return newStringFunc("CONCAT", expressions...) + return NewStringFunc("CONCAT", expressions...) } // CONCAT_WS adds two or more expressions together with a separator. func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression { - return newStringFunc("CONCAT_WS", append([]Expression{separator}, expressions...)...) + return NewStringFunc("CONCAT_WS", append([]Expression{separator}, expressions...)...) } // CONVERT converts string to dest_encoding. The original encoding is // specified by src_encoding. The string must be valid in this encoding. func CONVERT(str StringExpression, srcEncoding StringExpression, destEncoding StringExpression) StringExpression { - return newStringFunc("CONVERT", str, srcEncoding, destEncoding) + return NewStringFunc("CONVERT", str, srcEncoding, destEncoding) } // CONVERT_FROM converts string to the database encoding. The original // encoding is specified by src_encoding. The string must be valid in this encoding. func CONVERT_FROM(str StringExpression, srcEncoding StringExpression) StringExpression { - return newStringFunc("CONVERT_FROM", str, srcEncoding) + return NewStringFunc("CONVERT_FROM", str, srcEncoding) } // CONVERT_TO converts string to dest_encoding. func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression { - return newStringFunc("CONVERT_TO", str, toEncoding) + return NewStringFunc("CONVERT_TO", str, toEncoding) } // ENCODE encodes binary data into a textual representation. // Supported formats are: base64, hex, escape. escape converts zero bytes and // high-bit-set bytes to octal sequences (\nnn) and doubles backslashes. func ENCODE(data StringExpression, format StringExpression) StringExpression { - return newStringFunc("ENCODE", data, format) + return NewStringFunc("ENCODE", data, format) } // DECODE decodes binary data from textual representation in string. // Options for format are same as in encode. func DECODE(data StringExpression, format StringExpression) StringExpression { - return newStringFunc("DECODE", data, format) + return NewStringFunc("DECODE", data, format) } // FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression { args := []Expression{formatStr} args = append(args, formatArgs...) - return newStringFunc("FORMAT", args...) + return NewStringFunc("FORMAT", args...) } // INITCAP converts the first letter of each word to upper case // and the rest to lower case. Words are sequences of alphanumeric // characters separated by non-alphanumeric characters. func INITCAP(str StringExpression) StringExpression { - return newStringFunc("INITCAP", str) + return NewStringFunc("INITCAP", str) } // LEFT returns first n characters in the string. // When n is negative, return all but last |n| characters. func LEFT(str StringExpression, n IntegerExpression) StringExpression { - return newStringFunc("LEFT", str, n) + return NewStringFunc("LEFT", str, n) } // RIGHT returns last n characters in the string. // When n is negative, return all but first |n| characters. func RIGHT(str StringExpression, n IntegerExpression) StringExpression { - return newStringFunc("RIGHT", str, n) + return NewStringFunc("RIGHT", str, n) } // LENGTH returns number of characters in string with a given encoding func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression { if len(encoding) > 0 { - return newStringFunc("LENGTH", str, encoding[0]) + return NewStringFunc("LENGTH", str, encoding[0]) } - return newStringFunc("LENGTH", str) + return NewStringFunc("LENGTH", str) } // LPAD fills up the string to length length by prepending the characters @@ -384,40 +384,40 @@ func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression // then it is truncated (on the right). func LPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression { if len(text) > 0 { - return newStringFunc("LPAD", str, length, text[0]) + return NewStringFunc("LPAD", str, length, text[0]) } - return newStringFunc("LPAD", str, length) + return NewStringFunc("LPAD", str, length) } // RPAD fills up the string to length length by appending the characters // fill (a space by default). If the string is already longer than length then it is truncated. func RPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression { if len(text) > 0 { - return newStringFunc("RPAD", str, length, text[0]) + return NewStringFunc("RPAD", str, length, text[0]) } - return newStringFunc("RPAD", str, length) + return NewStringFunc("RPAD", str, length) } // MD5 calculates the MD5 hash of string, returning the result in hexadecimal func MD5(stringExpression StringExpression) StringExpression { - return newStringFunc("MD5", stringExpression) + return NewStringFunc("MD5", stringExpression) } // REPEAT repeats string the specified number of times func REPEAT(str StringExpression, n IntegerExpression) StringExpression { - return newStringFunc("REPEAT", str, n) + return NewStringFunc("REPEAT", str, n) } // REPLACE replaces all occurrences in string of substring from with substring to func REPLACE(text, from, to StringExpression) StringExpression { - return newStringFunc("REPLACE", text, from, to) + return NewStringFunc("REPLACE", text, from, to) } // REVERSE returns reversed string. func REVERSE(stringExpression StringExpression) StringExpression { - return newStringFunc("REVERSE", stringExpression) + return NewStringFunc("REVERSE", stringExpression) } // STRPOS returns location of specified substring (same as position(substring in string), @@ -429,22 +429,22 @@ func STRPOS(str, substring StringExpression) IntegerExpression { // SUBSTR extracts substring func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression { if len(count) > 0 { - return newStringFunc("SUBSTR", str, from, count[0]) + return NewStringFunc("SUBSTR", str, from, count[0]) } - return newStringFunc("SUBSTR", str, from) + return NewStringFunc("SUBSTR", str, from) } // TO_ASCII convert string to ASCII from another encoding func TO_ASCII(str StringExpression, encoding ...StringExpression) StringExpression { if len(encoding) > 0 { - return newStringFunc("TO_ASCII", str, encoding[0]) + return NewStringFunc("TO_ASCII", str, encoding[0]) } - return newStringFunc("TO_ASCII", str) + return NewStringFunc("TO_ASCII", str) } // TO_HEX converts number to its equivalent hexadecimal representation func TO_HEX(number IntegerExpression) StringExpression { - return newStringFunc("TO_HEX", number) + return NewStringFunc("TO_HEX", number) } // REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise. @@ -460,12 +460,12 @@ func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType // TO_CHAR converts expression to string with format func TO_CHAR(expression Expression, format StringExpression) StringExpression { - return newStringFunc("TO_CHAR", expression, format) + return NewStringFunc("TO_CHAR", expression, format) } // TO_DATE converts string to date using format func TO_DATE(dateStr, format StringExpression) DateExpression { - return newDateFunc("TO_DATE", dateStr, format) + return NewDateFunc("TO_DATE", dateStr, format) } // TO_NUMBER converts string to numeric using format @@ -482,7 +482,7 @@ func TO_TIMESTAMP(timestampzStr, format StringExpression) TimestampzExpression { // CURRENT_DATE returns current date func CURRENT_DATE() DateExpression { - dateFunc := newDateFunc("CURRENT_DATE") + dateFunc := NewDateFunc("CURRENT_DATE") dateFunc.noBrackets = true return dateFunc } @@ -522,9 +522,9 @@ func LOCALTIME(precision ...int) TimeExpression { var timeFunc *timeFunc if len(precision) > 0 { - timeFunc = newTimeFunc("LOCALTIME", FixedLiteral(precision[0])) + timeFunc = NewTimeFunc("LOCALTIME", FixedLiteral(precision[0])) } else { - timeFunc = newTimeFunc("LOCALTIME") + timeFunc = NewTimeFunc("LOCALTIME") } timeFunc.noBrackets = true @@ -558,26 +558,26 @@ func NOW() TimestampzExpression { func COALESCE(value Expression, values ...Expression) Expression { var allValues = []Expression{value} allValues = append(allValues, values...) - return newFunc("COALESCE", allValues, nil) + return NewFunc("COALESCE", allValues, nil) } // NULLIF function returns a null value if value1 equals value2; otherwise it returns value1. func NULLIF(value1, value2 Expression) Expression { - return newFunc("NULLIF", []Expression{value1, value2}, nil) + return NewFunc("NULLIF", []Expression{value1, value2}, nil) } // GREATEST selects the largest value from a list of expressions func GREATEST(value Expression, values ...Expression) Expression { var allValues = []Expression{value} allValues = append(allValues, values...) - return newFunc("GREATEST", allValues, nil) + return NewFunc("GREATEST", allValues, nil) } // LEAST selects the smallest value from a list of expressions func LEAST(value Expression, values ...Expression) Expression { var allValues = []Expression{value} allValues = append(allValues, values...) - return newFunc("LEAST", allValues, nil) + return NewFunc("LEAST", allValues, nil) } //--------------------------------------------------------------------// @@ -590,7 +590,8 @@ type funcExpressionImpl struct { noBrackets bool } -func newFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl { +// NewFunc creates new function with name and expressions parameters +func NewFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl { funcExp := &funcExpressionImpl{ name: name, expressions: expressions, @@ -608,7 +609,7 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr // NewFloatWindowFunc creates new float function with name and expressions func newWindowFunc(name string, expressions ...Expression) windowExpression { - newFun := newFunc(name, expressions, nil) + newFun := NewFunc(name, expressions, nil) windowExpr := newWindowExpression(newFun) newFun.ExpressionInterfaceImpl.Parent = windowExpr @@ -645,7 +646,7 @@ type boolFunc struct { func newBoolFunc(name string, expressions ...Expression) BoolExpression { boolFunc := &boolFunc{} - boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) + boolFunc.funcExpressionImpl = *NewFunc(name, expressions, boolFunc) boolFunc.boolInterfaceImpl.parent = boolFunc boolFunc.ExpressionInterfaceImpl.Parent = boolFunc @@ -656,7 +657,7 @@ func newBoolFunc(name string, expressions ...Expression) BoolExpression { func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpression { boolFunc := &boolFunc{} - boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) + boolFunc.funcExpressionImpl = *NewFunc(name, expressions, boolFunc) intWindowFunc := newBoolWindowExpression(boolFunc) boolFunc.boolInterfaceImpl.parent = intWindowFunc boolFunc.ExpressionInterfaceImpl.Parent = intWindowFunc @@ -673,7 +674,7 @@ type floatFunc struct { func NewFloatFunc(name string, expressions ...Expression) FloatExpression { floatFunc := &floatFunc{} - floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) + floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc) floatFunc.floatInterfaceImpl.parent = floatFunc return floatFunc @@ -683,7 +684,7 @@ func NewFloatFunc(name string, expressions ...Expression) FloatExpression { func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpression { floatFunc := &floatFunc{} - floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) + floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc) floatWindowFunc := newFloatWindowExpression(floatFunc) floatFunc.floatInterfaceImpl.parent = floatWindowFunc floatFunc.ExpressionInterfaceImpl.Parent = floatWindowFunc @@ -699,7 +700,7 @@ type integerFunc struct { func newIntegerFunc(name string, expressions ...Expression) IntegerExpression { floatFunc := &integerFunc{} - floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) + floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc) floatFunc.integerInterfaceImpl.parent = floatFunc return floatFunc @@ -709,7 +710,7 @@ func newIntegerFunc(name string, expressions ...Expression) IntegerExpression { func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowExpression { integerFunc := &integerFunc{} - integerFunc.funcExpressionImpl = *newFunc(name, expressions, integerFunc) + integerFunc.funcExpressionImpl = *NewFunc(name, expressions, integerFunc) intWindowFunc := newIntegerWindowExpression(integerFunc) integerFunc.integerInterfaceImpl.parent = intWindowFunc integerFunc.ExpressionInterfaceImpl.Parent = intWindowFunc @@ -722,10 +723,11 @@ type stringFunc struct { stringInterfaceImpl } -func newStringFunc(name string, expressions ...Expression) StringExpression { +// NewStringFunc creates new string function with name and expression parameters +func NewStringFunc(name string, expressions ...Expression) StringExpression { stringFunc := &stringFunc{} - stringFunc.funcExpressionImpl = *newFunc(name, expressions, stringFunc) + stringFunc.funcExpressionImpl = *NewFunc(name, expressions, stringFunc) stringFunc.stringInterfaceImpl.parent = stringFunc return stringFunc @@ -736,10 +738,11 @@ type dateFunc struct { dateInterfaceImpl } -func newDateFunc(name string, expressions ...Expression) *dateFunc { +// NewDateFunc creates new date function with name and expression parameters +func NewDateFunc(name string, expressions ...Expression) *dateFunc { dateFunc := &dateFunc{} - dateFunc.funcExpressionImpl = *newFunc(name, expressions, dateFunc) + dateFunc.funcExpressionImpl = *NewFunc(name, expressions, dateFunc) dateFunc.dateInterfaceImpl.parent = dateFunc return dateFunc @@ -750,10 +753,11 @@ type timeFunc struct { timeInterfaceImpl } -func newTimeFunc(name string, expressions ...Expression) *timeFunc { +// NewTimeFunc creates new time function with name and expression parameters +func NewTimeFunc(name string, expressions ...Expression) *timeFunc { timeFun := &timeFunc{} - timeFun.funcExpressionImpl = *newFunc(name, expressions, timeFun) + timeFun.funcExpressionImpl = *NewFunc(name, expressions, timeFun) timeFun.timeInterfaceImpl.parent = timeFun return timeFun @@ -767,7 +771,7 @@ type timezFunc struct { func newTimezFunc(name string, expressions ...Expression) *timezFunc { timezFun := &timezFunc{} - timezFun.funcExpressionImpl = *newFunc(name, expressions, timezFun) + timezFun.funcExpressionImpl = *NewFunc(name, expressions, timezFun) timezFun.timezInterfaceImpl.parent = timezFun return timezFun @@ -782,7 +786,7 @@ type timestampFunc struct { func NewTimestampFunc(name string, expressions ...Expression) *timestampFunc { timestampFunc := ×tampFunc{} - timestampFunc.funcExpressionImpl = *newFunc(name, expressions, timestampFunc) + timestampFunc.funcExpressionImpl = *NewFunc(name, expressions, timestampFunc) timestampFunc.timestampInterfaceImpl.parent = timestampFunc return timestampFunc @@ -796,7 +800,7 @@ type timestampzFunc struct { func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc { timestampzFunc := ×tampzFunc{} - timestampzFunc.funcExpressionImpl = *newFunc(name, expressions, timestampzFunc) + timestampzFunc.funcExpressionImpl = *NewFunc(name, expressions, timestampzFunc) timestampzFunc.timestampzInterfaceImpl.parent = timestampzFunc return timestampzFunc @@ -804,5 +808,5 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc { // Func can be used to call an custom or as of yet unsupported function in the database. func Func(name string, expressions ...Expression) Expression { - return newFunc(name, expressions, nil) + return NewFunc(name, expressions, nil) } diff --git a/internal/jet/interval.go b/internal/jet/interval.go index 5b371e15..debcb57a 100644 --- a/internal/jet/interval.go +++ b/internal/jet/interval.go @@ -19,7 +19,7 @@ func (i *IsIntervalImpl) isInterval() {} // NewInterval creates new interval from serializer func NewInterval(s Serializer) *IntervalImpl { newInterval := &IntervalImpl{ - interval: s, + Value: s, } return newInterval @@ -27,11 +27,11 @@ func NewInterval(s Serializer) *IntervalImpl { // IntervalImpl is implementation of Interval type type IntervalImpl struct { - interval Serializer + Value Serializer IsIntervalImpl } func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString("INTERVAL") - i.interval.serialize(statement, out, FallTrough(options)...) + i.Value.serialize(statement, out, FallTrough(options)...) } diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index d7cf47a9..450b0abc 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -375,9 +375,14 @@ type wrap struct { expressions []Expression } -func (n *wrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (n *wrap) serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString("(") - serializeExpressionList(statement, n.expressions, ", ", out) + + if len(n.expressions) == 1 { + options = append(options, NoWrap, Ident) + } + serializeExpressionList(statementType, n.expressions, ", ", out, options...) + out.WriteString(")") } diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index b8cf04a0..866d60e9 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -7,8 +7,10 @@ type SerializeOption int const ( NoWrap SerializeOption = iota SkipNewLine + Ident fallTroughOptions // fall trough options + ShortName ) diff --git a/internal/jet/statement.go b/internal/jet/statement.go index da3650db..1d050459 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -195,10 +195,19 @@ func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, opti out.IncreaseIdent() } + if contains(options, Ident) { + out.IncreaseIdent() + } + for _, clause := range s.Clauses { clause.Serialize(statement, out, FallTrough(options)...) } + if contains(options, Ident) { + out.DecreaseIdent() + out.NewLine() + } + if !contains(options, NoWrap) { out.DecreaseIdent() out.NewLine() diff --git a/internal/jet/utils.go b/internal/jet/utils.go index b2fff487..eab44030 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -21,14 +21,19 @@ func SerializeClauseList(statement StatementType, clauses []Serializer, out *SQL } } -func serializeExpressionList(statement StatementType, expressions []Expression, separator string, out *SQLBuilder) { - - for i, value := range expressions { +func serializeExpressionList( + statement StatementType, + expressions []Expression, + separator string, + out *SQLBuilder, + options ...SerializeOption) { + + for i, expression := range expressions { if i > 0 { out.WriteString(separator) } - value.serialize(statement, out) + expression.serialize(statement, out, options...) } } diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index dd5e7906..c1419aa0 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" "github.com/go-jet/jet/v2/internal/jet" - "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/qrm" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -20,6 +20,11 @@ import ( "github.com/google/go-cmp/cmp" ) +// UnixTimeComparer will compare time equality while ignoring time zone +var UnixTimeComparer = cmp.Comparer(func(t1, t2 time.Time) bool { + return t1.Unix() == t2.Unix() +}) + // AssertExec assert statement execution for successful execution and number of rows affected func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) { res, err := stmt.Exec(db) @@ -66,7 +71,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) { filePath := getFullPath(testRelativePath) err := ioutil.WriteFile(filePath, jsonText, 0644) - utils.PanicOnError(err) + throw.OnError(err) } // AssertJSONFile check if data json representation is the same as json at testRelativePath @@ -113,7 +118,7 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st _, args := query.Sql() if len(expectedArgs) > 0 { - AssertDeepEqual(t, args, expectedArgs, "arguments are not equal") + AssertDeepEqual(t, args, expectedArgs) } debugSql := query.DebugSql() @@ -223,9 +228,9 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st } // AssertDeepEqual checks if actual and expected objects are deeply equal. -func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { - if !assert.True(t, cmp.Equal(actual, expected), msg) { - printDiff(actual, expected) +func AssertDeepEqual(t *testing.T, actual, expected interface{}, option ...cmp.Option) { + if !assert.True(t, cmp.Equal(actual, expected, option...)) { + printDiff(actual, expected, option...) t.FailNow() } } @@ -237,7 +242,8 @@ func assertQueryString(t *testing.T, actual, expected string) { } } -func printDiff(actual, expected interface{}) { +func printDiff(actual, expected interface{}, options ...cmp.Option) { + fmt.Println(cmp.Diff(actual, expected, options...)) fmt.Println("Actual: ") fmt.Println(actual) fmt.Println("Expected: ") diff --git a/internal/testutils/time_utils.go b/internal/testutils/time_utils.go index 2bf66530..b48129cc 100644 --- a/internal/testutils/time_utils.go +++ b/internal/testutils/time_utils.go @@ -1,7 +1,7 @@ package testutils import ( - "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" "strings" "time" ) @@ -10,7 +10,7 @@ import ( func Date(t string) *time.Time { newTime, err := time.Parse("2006-01-02", t) - utils.PanicOnError(err) + throw.OnError(err) return &newTime } @@ -26,7 +26,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time { newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000") - utils.PanicOnError(err) + throw.OnError(err) return &newTime } @@ -35,7 +35,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time { func TimeWithoutTimeZone(t string) *time.Time { newTime, err := time.Parse("15:04:05", t) - utils.PanicOnError(err) + throw.OnError(err) return &newTime } @@ -44,7 +44,7 @@ func TimeWithoutTimeZone(t string) *time.Time { func TimeWithTimeZone(t string) *time.Time { newTimez, err := time.Parse("15:04:05 -0700", t) - utils.PanicOnError(err) + throw.OnError(err) return &newTimez } @@ -60,7 +60,7 @@ func TimestampWithTimeZone(t string, precision int) *time.Time { newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t) - utils.PanicOnError(err) + throw.OnError(err) return &newTime } diff --git a/internal/utils/min/min.go b/internal/utils/min/min.go new file mode 100644 index 00000000..0e92146e --- /dev/null +++ b/internal/utils/min/min.go @@ -0,0 +1,9 @@ +package min + +// Int returns minimum of two int values +func Int(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/internal/utils/throw/throw.go b/internal/utils/throw/throw.go new file mode 100644 index 00000000..9595c8b1 --- /dev/null +++ b/internal/utils/throw/throw.go @@ -0,0 +1,8 @@ +package throw + +// OnError will panic if err is not nill +func OnError(err error) { + if err != nil { + panic(err) + } +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 55005b4d..6f6f1782 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -10,7 +10,6 @@ import ( "reflect" "strings" "time" - "unicode" ) // ToGoIdentifier converts database to Go identifier. @@ -18,16 +17,6 @@ func ToGoIdentifier(databaseIdentifier string) string { return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier)) } -// ToGoEnumValueIdentifier converts enum value name to Go identifier name. -func ToGoEnumValueIdentifier(enumName, enumValue string) string { - enumValueIdentifier := ToGoIdentifier(enumValue) - if !unicode.IsLetter([]rune(enumValueIdentifier)[0]) { - return ToGoIdentifier(enumName) + enumValueIdentifier - } - - return enumValueIdentifier -} - // ToGoFileName converts database identifier to Go file name. func ToGoFileName(databaseIdentifier string) string { return strings.ToLower(replaceInvalidChars(databaseIdentifier)) @@ -35,7 +24,11 @@ func ToGoFileName(databaseIdentifier string) string { // SaveGoFile saves go file at folder dir, with name fileName and contents text. func SaveGoFile(dirPath, fileName string, text []byte) error { - newGoFilePath := filepath.Join(dirPath, fileName) + ".go" + newGoFilePath := filepath.Join(dirPath, fileName) + + if !strings.HasSuffix(newGoFilePath, ".go") { + newGoFilePath += ".go" + } file, err := os.Create(newGoFilePath) @@ -160,13 +153,6 @@ func MustBeInitializedPtr(val interface{}, errorStr string) { } } -// PanicOnError panics if err is not nil -func PanicOnError(err error) { - if err != nil { - panic(err) - } -} - // ErrorCatch is used in defer to recover from panics and to set err func ErrorCatch(err *error) { recovered := recover() diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index f2b4f84e..f374929e 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -25,11 +25,6 @@ func TestToGoIdentifier(t *testing.T) { require.Equal(t, ToGoIdentifier("My-Table"), "MyTable") } -func TestToGoEnumValueIdentifier(t *testing.T) { - require.Equal(t, ToGoEnumValueIdentifier("enum_name", "enum_value"), "EnumValue") - require.Equal(t, ToGoEnumValueIdentifier("NumEnum", "100"), "NumEnum100") -} - func TestErrorCatchErr(t *testing.T) { var err error diff --git a/mysql/insert_statement_test.go b/mysql/insert_statement_test.go index dbabc3f2..7b396d08 100644 --- a/mysql/insert_statement_test.go +++ b/mysql/insert_statement_test.go @@ -7,7 +7,6 @@ import ( ) func TestInvalidInsert(t *testing.T) { - assertStatementSqlErr(t, table1.INSERT(table1Col1), "jet: VALUES or QUERY has to be specified for INSERT statement") assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") } diff --git a/mysql/select_statement.go b/mysql/select_statement.go index 8ebab036..ffb8054f 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -43,7 +43,7 @@ type SelectStatement interface { DISTINCT() SelectStatement FROM(tables ...ReadableTable) SelectStatement WHERE(expression BoolExpression) SelectStatement - GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement + GROUP_BY(groupByClauses ...GroupByClause) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement WINDOW(name string) windowExpand ORDER_BY(orderByClauses ...OrderByClause) SelectStatement @@ -118,7 +118,7 @@ func (s *selectStatementImpl) WHERE(condition BoolExpression) SelectStatement { return s } -func (s *selectStatementImpl) GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement { +func (s *selectStatementImpl) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement { s.GroupBy.List = groupByClauses return s } diff --git a/mysql/types.go b/mysql/types.go index 8c6608f8..c82962fb 100644 --- a/mysql/types.go +++ b/mysql/types.go @@ -20,5 +20,8 @@ type PrintableStatement = jet.PrintableStatement // OrderByClause is the combination of an expression and the wanted ordering to use as input for ORDER BY. type OrderByClause = jet.OrderByClause +// GroupByClause interface to use as input for GROUP_BY +type GroupByClause = jet.GroupByClause + // SetLogger sets automatic statement logging var SetLogger = jet.SetLoggerFunc diff --git a/postgres/clause.go b/postgres/clause.go index 6174d4fd..3a23fd07 100644 --- a/postgres/clause.go +++ b/postgres/clause.go @@ -4,33 +4,10 @@ import ( "github.com/go-jet/jet/v2/internal/jet" ) -type clauseReturning struct { - ProjectionList []jet.Projection -} - -func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(r.ProjectionList) == 0 { - return - } - - out.NewLine() - out.WriteString("RETURNING") - out.IncreaseIdent() - out.WriteProjections(statementType, r.ProjectionList) - out.DecreaseIdent() -} - -func (r clauseReturning) Projections() ProjectionList { - return r.ProjectionList -} - -// ========================================== // - type onConflict interface { ON_CONSTRAINT(name string) conflictTarget WHERE(indexPredicate BoolExpression) conflictTarget - DO_NOTHING() InsertStatement - DO_UPDATE(action conflictAction) InsertStatement + conflictTarget } type conflictTarget interface { diff --git a/postgres/delete_statement.go b/postgres/delete_statement.go index ca2816cb..2bfbd8c4 100644 --- a/postgres/delete_statement.go +++ b/postgres/delete_statement.go @@ -16,7 +16,7 @@ type deleteStatementImpl struct { Delete jet.ClauseStatementBegin Where jet.ClauseWhere - Returning clauseReturning + Returning jet.ClauseReturning } func newDeleteStatement(table WritableTable) DeleteStatement { diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index 9b7b3d16..d98d8f3a 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -46,33 +46,33 @@ func TestExists(t *testing.T) { func TestIN(t *testing.T) { assertSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), - `($1 IN (( + `($1 IN ( SELECT table1.col1 AS "table1.col1" FROM db.table1 -)))`, float64(1.11)) +))`, float64(1.11)) assertSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), - `(ROW($1, table1.col1) IN (( + `(ROW($1, table1.col1) IN ( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2 -)))`, int64(12)) +))`, int64(12)) } func TestNOT_IN(t *testing.T) { assertSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), - `($1 NOT IN (( + `($1 NOT IN ( SELECT table1.col1 AS "table1.col1" FROM db.table1 -)))`, float64(1.11)) +))`, float64(1.11)) assertSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), - `(ROW($1, table1.col1) NOT IN (( + `(ROW($1, table1.col1) NOT IN ( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2 -)))`, int64(12)) +))`, int64(12)) } func TestReservedWordEscaped(t *testing.T) { diff --git a/postgres/insert_statement.go b/postgres/insert_statement.go index a134a126..763e533f 100644 --- a/postgres/insert_statement.go +++ b/postgres/insert_statement.go @@ -22,7 +22,11 @@ type InsertStatement interface { func newInsertStatement(table WritableTable, columns []jet.Column) InsertStatement { newInsert := &insertStatementImpl{} newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, - &newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnConflict, &newInsert.Returning) + &newInsert.Insert, + &newInsert.ValuesQuery, + &newInsert.OnConflict, + &newInsert.Returning, + ) newInsert.Insert.Table = table newInsert.Insert.Columns = columns @@ -35,7 +39,7 @@ type insertStatementImpl struct { Insert jet.ClauseInsert ValuesQuery jet.ClauseValuesQuery - Returning clauseReturning + Returning jet.ClauseReturning OnConflict onConflictClause } diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index 609d38ae..ad687b5f 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -8,7 +8,6 @@ import ( ) func TestInvalidInsert(t *testing.T) { - assertStatementSqlErr(t, table1.INSERT(table1Col1), "jet: VALUES or QUERY has to be specified for INSERT statement") assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") } @@ -155,7 +154,7 @@ func TestInsert_ON_CONFLICT(t *testing.T) { ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE( SET(table1ColBool.SET(Bool(true)), table2ColInt.SET(Int(1)), - ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))), + ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))), ).WHERE(table1Col1.GT(Int(2))), ). RETURNING(table1Col1, table1ColBool) diff --git a/postgres/interval_expression.go b/postgres/interval_expression.go index b8468cf7..6f5ab586 100644 --- a/postgres/interval_expression.go +++ b/postgres/interval_expression.go @@ -116,7 +116,7 @@ func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { panic("jet: invalid number of quantity and unit fields") } - fields := []string{} + var fields []string for i := 0; i < len(quantityAndUnit); i += 2 { quantity := strconv.FormatFloat(quantityAndUnit[i], 'f', -1, 64) diff --git a/postgres/select_statement.go b/postgres/select_statement.go index 516ae25b..8fb9cb6d 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -1,8 +1,9 @@ package postgres import ( - "github.com/go-jet/jet/v2/internal/jet" "math" + + "github.com/go-jet/jet/v2/internal/jet" ) // RowLock is interface for SELECT statement row lock types @@ -46,7 +47,7 @@ type SelectStatement interface { DISTINCT() SelectStatement FROM(tables ...ReadableTable) SelectStatement WHERE(expression BoolExpression) SelectStatement - GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement + GROUP_BY(groupByClauses ...GroupByClause) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement WINDOW(name string) windowExpand ORDER_BY(orderByClauses ...OrderByClause) SelectStatement @@ -121,7 +122,7 @@ func (s *selectStatementImpl) WHERE(condition BoolExpression) SelectStatement { return s } -func (s *selectStatementImpl) GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement { +func (s *selectStatementImpl) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement { s.GroupBy.List = groupByClauses return s } diff --git a/postgres/types.go b/postgres/types.go index 05354b73..6fed21b6 100644 --- a/postgres/types.go +++ b/postgres/types.go @@ -20,5 +20,8 @@ type PrintableStatement = jet.PrintableStatement // OrderByClause is the combination of an expression and the wanted ordering to use as input for ORDER BY. type OrderByClause = jet.OrderByClause +// GroupByClause interface to use as input for GROUP_BY +type GroupByClause = jet.GroupByClause + // SetLogger sets automatic statement logging var SetLogger = jet.SetLoggerFunc diff --git a/postgres/update_statement.go b/postgres/update_statement.go index 594efa47..58c5ba40 100644 --- a/postgres/update_statement.go +++ b/postgres/update_statement.go @@ -22,7 +22,7 @@ type updateStatementImpl struct { Set clauseSet SetNew jet.SetClauseNew Where jet.ClauseWhere - Returning clauseReturning + Returning jet.ClauseReturning } func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStatement { diff --git a/qrm/internal/null_types.go b/qrm/internal/null_types.go index 5a39094a..ab75cf62 100644 --- a/qrm/internal/null_types.go +++ b/qrm/internal/null_types.go @@ -1,263 +1,175 @@ package internal import ( + "database/sql" "database/sql/driver" "fmt" + "github.com/go-jet/jet/v2/internal/utils/min" + "reflect" "strconv" "time" ) -//===============================================================// - -// NullByteArray struct -type NullByteArray struct { - ByteArray []byte - Valid bool +// NullBool struct +type NullBool struct { + sql.NullBool } // Scan implements the Scanner interface. -func (nb *NullByteArray) Scan(value interface{}) error { +func (nb *NullBool) Scan(value interface{}) error { switch v := value.(type) { - case nil: - nb.Valid = false - return nil - case []byte: - nb.ByteArray = append(v[:0:0], v...) + case bool: + nb.Bool, nb.Valid = v, true + case int8, int16, int32, int64, int: + intVal := reflect.ValueOf(v).Int() + + if intVal != 0 && intVal != 1 { + return fmt.Errorf("can't assign %T(%d) to bool", value, value) + } + + nb.Bool = intVal == 1 + nb.Valid = true + case uint8, uint16, uint32, uint64, uint: + uintVal := reflect.ValueOf(v).Uint() + + if uintVal != 0 && uintVal != 1 { + return fmt.Errorf("can't assign %T(%d) to bool", value, value) + } + + nb.Bool = uintVal == 1 nb.Valid = true - return nil default: - return fmt.Errorf("can't scan []byte from %v", value) + return nb.NullBool.Scan(value) } -} -// Value implements the driver Valuer interface. -func (nb NullByteArray) Value() (driver.Value, error) { - if !nb.Valid { - return nil, nil - } - return nb.ByteArray, nil + return nil } -//===============================================================// - // NullTime struct type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL + sql.NullTime } // Scan implements the Scanner interface. -func (nt *NullTime) Scan(value interface{}) (err error) { - switch v := value.(type) { - case nil: - nt.Valid = false - return - case time.Time: - nt.Time, nt.Valid = v, true - return - case []byte: - nt.Time, nt.Valid = parseTime(string(v)) - return - case string: - nt.Time, nt.Valid = parseTime(v) - return - default: - return fmt.Errorf("can't scan time.Time from %v", value) - } -} +func (nt *NullTime) Scan(value interface{}) error { + err := nt.NullTime.Scan(value) -// Value implements the driver Valuer interface. -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil + if err == nil { + return nil } - return nt.Time, nil -} - -const formatTime = "2006-01-02 15:04:05.999999" -func parseTime(timeStr string) (t time.Time, valid bool) { + // Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value. + // At this point we try to parse those values using some of the predefined formats + nt.Time, nt.Valid = tryParseAsTime(value) - var format string - - switch len(timeStr) { - case 8: - format = formatTime[11:19] - case 10, 19, 21, 22, 23, 24, 25, 26: - format = formatTime[:len(timeStr)] - default: - return t, false + if !nt.Valid { + return fmt.Errorf("can't scan time.Time from %q", value) } - t, err := time.Parse(format, timeStr) - return t, err == nil + return nil } -//===============================================================// - -// NullInt8 struct -type NullInt8 struct { - Int8 int8 - Valid bool +var formats = []string{ + "2006-01-02 15:04:05-07:00", // sqlite + "2006-01-02 15:04:05.999999", // go-sql-driver/mysql + "15:04:05-07", // pgx + "15:04:05.999999", // pgx } -// Scan implements the Scanner interface. -func (n *NullInt8) Scan(value interface{}) (err error) { +func tryParseAsTime(value interface{}) (time.Time, bool) { + + var timeStr string + switch v := value.(type) { - case nil: - n.Valid = false - return - case int64: - n.Int8, n.Valid = int8(v), true - return - case int8: - n.Int8, n.Valid = v, true - return + case string: + timeStr = v case []byte: - intV, err := strconv.ParseInt(string(v), 10, 8) - if err == nil { - n.Int8, n.Valid = int8(intV), true - } - return err + timeStr = string(v) + case int64: + return time.Unix(v, 0), true // sqlite default: - return fmt.Errorf("can't scan int8 from %v", value) + return time.Time{}, false } -} - -// Value implements the driver Valuer interface. -func (n NullInt8) Value() (driver.Value, error) { - if !n.Valid { - return nil, nil - } - return n.Int8, nil -} -//===============================================================// + for _, format := range formats { + formatLen := min.Int(len(format), len(timeStr)) + t, err := time.Parse(format[:formatLen], timeStr) -// NullInt16 struct -type NullInt16 struct { - Int16 int16 - Valid bool -} - -// Scan implements the Scanner interface. -func (n *NullInt16) Scan(value interface{}) error { - - switch v := value.(type) { - case nil: - n.Valid = false - return nil - case int64: - n.Int16, n.Valid = int16(v), true - return nil - case int16: - n.Int16, n.Valid = v, true - return nil - case int8: - n.Int16, n.Valid = int16(v), true - return nil - case uint8: - n.Int16, n.Valid = int16(v), true - return nil - case []byte: - intV, err := strconv.ParseInt(string(v), 10, 16) - if err == nil { - n.Int16, n.Valid = int16(intV), true + if err != nil { + continue } - return nil - default: - return fmt.Errorf("can't scan int16 from %v", value) - } -} -// Value implements the driver Valuer interface. -func (n NullInt16) Value() (driver.Value, error) { - if !n.Valid { - return nil, nil + return t, true } - return n.Int16, nil -} -//===============================================================// + return time.Time{}, false +} -// NullInt32 struct -type NullInt32 struct { - Int32 int32 - Valid bool +// NullUInt64 struct +type NullUInt64 struct { + UInt64 uint64 + Valid bool } // Scan implements the Scanner interface. -func (n *NullInt32) Scan(value interface{}) error { +func (n *NullUInt64) Scan(value interface{}) error { + var stringValue string switch v := value.(type) { case nil: n.Valid = false return nil case int64: - n.Int32, n.Valid = int32(v), true + n.UInt64, n.Valid = uint64(v), true + return nil + case uint64: + n.UInt64, n.Valid = v, true return nil case int32: - n.Int32, n.Valid = v, true + n.UInt64, n.Valid = uint64(v), true + return nil + case uint32: + n.UInt64, n.Valid = uint64(v), true return nil case int16: - n.Int32, n.Valid = int32(v), true + n.UInt64, n.Valid = uint64(v), true return nil case uint16: - n.Int32, n.Valid = int32(v), true + n.UInt64, n.Valid = uint64(v), true return nil case int8: - n.Int32, n.Valid = int32(v), true + n.UInt64, n.Valid = uint64(v), true return nil case uint8: - n.Int32, n.Valid = int32(v), true + n.UInt64, n.Valid = uint64(v), true return nil - case []byte: - intV, err := strconv.ParseInt(string(v), 10, 32) - if err == nil { - n.Int32, n.Valid = int32(intV), true - } + case int: + n.UInt64, n.Valid = uint64(v), true return nil + case uint: + n.UInt64, n.Valid = uint64(v), true + return nil + case []byte: + stringValue = string(v) + case string: + stringValue = v default: - return fmt.Errorf("can't scan int32 from %v", value) + return fmt.Errorf("can't scan uint64 from %v", value) } -} -// Value implements the driver Valuer interface. -func (n NullInt32) Value() (driver.Value, error) { - if !n.Valid { - return nil, nil + uintV, err := strconv.ParseUint(stringValue, 10, 64) + if err != nil { + return err } - return n.Int32, nil -} + n.UInt64 = uintV + n.Valid = true -//===============================================================// - -// NullFloat32 struct -type NullFloat32 struct { - Float32 float32 - Valid bool -} - -// Scan implements the Scanner interface. -func (n *NullFloat32) Scan(value interface{}) error { - switch v := value.(type) { - case nil: - n.Valid = false - return nil - case float64: - n.Float32, n.Valid = float32(v), true - return nil - case float32: - n.Float32, n.Valid = v, true - return nil - default: - return fmt.Errorf("can't scan float32 from %v", value) - } + return nil } // Value implements the driver Valuer interface. -func (n NullFloat32) Value() (driver.Value, error) { +func (n NullUInt64) Value() (driver.Value, error) { if !n.Valid { return nil, nil } - return n.Float32, nil + return n.UInt64, nil } diff --git a/qrm/internal/null_types_test.go b/qrm/internal/null_types_test.go index 8f4addea..a15b104d 100644 --- a/qrm/internal/null_types_test.go +++ b/qrm/internal/null_types_test.go @@ -7,141 +7,85 @@ import ( "time" ) -func TestNullByteArray(t *testing.T) { - var array NullByteArray +func TestNullBool(t *testing.T) { + var nullBool NullBool - require.NoError(t, array.Scan(nil)) - require.Equal(t, array.Valid, false) + require.NoError(t, nullBool.Scan(nil)) + require.Equal(t, nullBool.Valid, false) - require.NoError(t, array.Scan([]byte("bytea"))) - require.Equal(t, array.Valid, true) - require.Equal(t, string(array.ByteArray), string([]byte("bytea"))) + require.NoError(t, nullBool.Scan(int64(1))) + require.Equal(t, nullBool.Valid, true) + value, _ := nullBool.Value() + require.Equal(t, value, true) - require.Error(t, array.Scan(12), "can't scan []byte from 12") + require.NoError(t, nullBool.Scan(uint32(0))) + require.Equal(t, nullBool.Valid, true) + value, _ = nullBool.Value() + require.Equal(t, value, false) + + require.EqualError(t, nullBool.Scan(uint16(22)), "can't assign uint16(22) to bool") } func TestNullTime(t *testing.T) { - var array NullTime + var nullTime NullTime - require.NoError(t, array.Scan(nil)) - require.Equal(t, array.Valid, false) + require.NoError(t, nullTime.Scan(nil)) + require.Equal(t, nullTime.Valid, false) time := time.Now() - require.NoError(t, array.Scan(time)) - require.Equal(t, array.Valid, true) - value, _ := array.Value() + require.NoError(t, nullTime.Scan(time)) + require.Equal(t, nullTime.Valid, true) + value, _ := nullTime.Value() require.Equal(t, value, time) - require.NoError(t, array.Scan([]byte("13:10:11"))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() + require.NoError(t, nullTime.Scan([]byte("13:10:11"))) + require.Equal(t, nullTime.Valid, true) + value, _ = nullTime.Value() require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") - require.NoError(t, array.Scan("13:10:11")) - require.Equal(t, array.Valid, true) - value, _ = array.Value() + require.NoError(t, nullTime.Scan("13:10:11")) + require.Equal(t, nullTime.Valid, true) + value, _ = nullTime.Value() require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") - require.Error(t, array.Scan(12), "can't scan time.Time from 12") -} - -func TestNullInt8(t *testing.T) { - var array NullInt8 - - require.NoError(t, array.Scan(nil)) - require.Equal(t, array.Valid, false) - - require.NoError(t, array.Scan(int64(11))) - require.Equal(t, array.Valid, true) - value, _ := array.Value() - require.Equal(t, value, int8(11)) - - require.Error(t, array.Scan("text"), "can't scan int8 from text") -} - -func TestNullInt16(t *testing.T) { - var array NullInt16 - - require.NoError(t, array.Scan(nil)) - require.Equal(t, array.Valid, false) - - require.NoError(t, array.Scan(int64(11))) - require.Equal(t, array.Valid, true) - value, _ := array.Value() - require.Equal(t, value, int16(11)) - - require.NoError(t, array.Scan(int16(20))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int16(20)) - - require.NoError(t, array.Scan(int8(30))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int16(30)) - - require.NoError(t, array.Scan(uint8(30))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int16(30)) - - require.Error(t, array.Scan("text"), "can't scan int16 from text") + require.Error(t, nullTime.Scan(12), "can't scan time.Time from 12") } -func TestNullInt32(t *testing.T) { - var array NullInt32 - - require.NoError(t, array.Scan(nil)) - require.Equal(t, array.Valid, false) +func TestNullUInt64(t *testing.T) { + var nullUInt64 NullUInt64 - require.NoError(t, array.Scan(int64(11))) - require.Equal(t, array.Valid, true) - value, _ := array.Value() - require.Equal(t, value, int32(11)) + require.NoError(t, nullUInt64.Scan(nil)) + require.Equal(t, nullUInt64.Valid, false) - require.NoError(t, array.Scan(int32(32))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int32(32)) + require.NoError(t, nullUInt64.Scan(int64(11))) + require.Equal(t, nullUInt64.Valid, true) + value, _ := nullUInt64.Value() + require.Equal(t, value, uint64(11)) - require.NoError(t, array.Scan(int16(20))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int32(20)) - - require.NoError(t, array.Scan(uint16(16))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int32(16)) - - require.NoError(t, array.Scan(int8(30))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int32(30)) - - require.NoError(t, array.Scan(uint8(30))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int32(30)) - - require.Error(t, array.Scan("text"), "can't scan int32 from text") -} + require.NoError(t, nullUInt64.Scan(int32(32))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(32)) -func TestNullFloat32(t *testing.T) { - var array NullFloat32 + require.NoError(t, nullUInt64.Scan(int16(20))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(20)) - require.NoError(t, array.Scan(nil)) - require.Equal(t, array.Valid, false) + require.NoError(t, nullUInt64.Scan(uint16(16))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(16)) - require.NoError(t, array.Scan(float64(64))) - require.Equal(t, array.Valid, true) - value, _ := array.Value() - require.Equal(t, value, float32(64)) + require.NoError(t, nullUInt64.Scan(int8(30))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(30)) - require.NoError(t, array.Scan(float32(32))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, float32(32)) + require.NoError(t, nullUInt64.Scan(uint8(30))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(30)) - require.Error(t, array.Scan(12), "can't scan float32 from 12") + require.Error(t, nullUInt64.Scan("text"), "can't scan int32 from text") } diff --git a/qrm/qrm.go b/qrm/qrm.go index 51bbffa2..45024023 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -27,7 +27,10 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr if destinationPtrType.Elem().Kind() == reflect.Slice { _, err := queryToSlice(ctx, db, query, args, destPtr) - return err + if err != nil { + return fmt.Errorf("jet: %w", err) + } + return nil } else if destinationPtrType.Elem().Kind() == reflect.Struct { tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSliceValue := tempSlicePtrValue.Elem() @@ -35,7 +38,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface()) if err != nil { - return err + return fmt.Errorf("jet: %w", err) } if rowsProcessed == 0 { @@ -214,7 +217,7 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value } rowElemPtr := scanContext.rowElemValuePtr(index) - if !rowElemPtr.IsNil() { + if rowElemPtr.IsValid() && !rowElemPtr.IsNil() { updated = true err = appendElemToSlice(slicePtrValue, rowElemPtr) if err != nil { @@ -275,10 +278,16 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re err = scanner.Scan(cellValue) if err != nil { - panic("jet: " + err.Error() + ", " + fieldToString(&field) + " of type " + structType.String()) + err = fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err) + return } } else { - setReflectValue(reflect.ValueOf(cellValue), fieldValue) + err = setReflectValue(reflect.ValueOf(cellValue), fieldValue) + + if err != nil { + err = fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err) + return + } } } } diff --git a/qrm/scan_context.go b/qrm/scan_context.go index 9d9e059b..dbc4b877 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -2,9 +2,7 @@ package qrm import ( "database/sql" - "database/sql/driver" "fmt" - "github.com/go-jet/jet/v2/internal/utils" "reflect" "strings" ) @@ -45,7 +43,7 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) { } return &scanContext{ - row: createScanValue(columnTypes), + row: createScanSlice(len(columnTypes)), uniqueDestObjectsMap: make(map[string]int), groupKeyInfoCache: make(map[string]groupKeyInfo), @@ -55,6 +53,17 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) { }, nil } +func createScanSlice(columnCount int) []interface{} { + scanSlice := make([]interface{}, columnCount) + scanPtrSlice := make([]interface{}, columnCount) + + for i := range scanPtrSlice { + scanPtrSlice[i] = &scanSlice[i] // if destination is pointer to interface sql.Scan will just forward driver value + } + + return scanPtrSlice +} + type typeInfo struct { fieldMappings []fieldMapping } @@ -209,22 +218,23 @@ func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int { } func (s *scanContext) rowElem(index int) interface{} { + cellValue := reflect.ValueOf(s.row[index]) - valuer, ok := s.row[index].(driver.Valuer) - - utils.MustBeTrue(ok, "jet: internal error, scan value doesn't implement driver.Valuer") - - value, err := valuer.Value() - - utils.PanicOnError(err) + if cellValue.IsValid() && !cellValue.IsNil() { + return cellValue.Elem().Interface() + } - return value + return nil } func (s *scanContext) rowElemValuePtr(index int) reflect.Value { rowElem := s.rowElem(index) rowElemValue := reflect.ValueOf(rowElem) + if !rowElemValue.IsValid() { + return reflect.Value{} + } + if rowElemValue.Kind() == reflect.Ptr { return rowElemValue } diff --git a/qrm/utill.go b/qrm/utill.go index f4857974..6926c423 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -7,7 +7,6 @@ import ( "github.com/go-jet/jet/v2/qrm/internal" "github.com/google/uuid" "reflect" - "strconv" "strings" "time" ) @@ -56,21 +55,30 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e sliceValue := slicePtrValue.Elem() sliceElemType := sliceValue.Type().Elem() - newElemValue := objPtrValue + var newSliceElemValue reflect.Value - if sliceElemType.Kind() != reflect.Ptr { - newElemValue = objPtrValue.Elem() - } + if objPtrValue.Type().AssignableTo(sliceElemType) { + newSliceElemValue = objPtrValue + } else if objPtrValue.Elem().Type().AssignableTo(sliceElemType) { + newSliceElemValue = objPtrValue.Elem() + } else { + newSliceElemValue = reflect.New(sliceElemType).Elem() - if newElemValue.Type().ConvertibleTo(sliceElemType) { - newElemValue = newElemValue.Convert(sliceElemType) - } + var err error + + if newSliceElemValue.Kind() == reflect.Ptr { + newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem())) + err = tryAssign(objPtrValue.Elem(), newSliceElemValue.Elem()) + } else { + err = tryAssign(objPtrValue.Elem(), newSliceElemValue) + } - if !newElemValue.Type().AssignableTo(sliceElemType) { - panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice") + if err != nil { + return fmt.Errorf("can't append %T to %T slice: %w", objPtrValue.Elem().Interface(), sliceValue.Interface(), err) + } } - sliceValue.Set(reflect.Append(sliceValue, newElemValue)) + sliceValue.Set(reflect.Append(sliceValue, newSliceElemValue)) return nil } @@ -121,7 +129,6 @@ func toCommonIdentifier(name string) string { } func initializeValueIfNilPtr(value reflect.Value) { - if !value.IsValid() || !value.CanSet() { return } @@ -173,172 +180,160 @@ func isSimpleModelType(objType reflect.Type) bool { return objType == timeType || objType == uuidType || objType == byteArrayType } -func isIntegerType(value reflect.Type) bool { - switch value { - case int8Type, unit8Type, int16Type, uint16Type, - int32Type, uint32Type, int64Type, uint64Type: +func isIntegerType(objType reflect.Type) bool { + objType = indirectType(objType) + + switch objType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return true } return false } -func isNumber(valueType reflect.Type) bool { - return isIntegerType(valueType) || valueType == float64Type || valueType == float32Type +func isFloatType(value reflect.Type) bool { + switch value.Kind() { + case reflect.Float32, reflect.Float64: + return true + } + + return false } -func tryAssign(source, destination reflect.Value) bool { +func tryAssign(source, destination reflect.Value) error { + + if source.Type() != destination.Type() && + !isFloatType(destination.Type()) && // to preserve precision during conversion + !(isIntegerType(source.Type()) && destination.Kind() == reflect.String) && // default conversion will convert int to 1 rune string + source.Type().ConvertibleTo(destination.Type()) { - switch { - case source.Type().ConvertibleTo(destination.Type()): source = source.Convert(destination.Type()) - case isIntegerType(source.Type()) && destination.Type() == boolType: - intValue := source.Int() + } - if intValue == 1 { - source = reflect.ValueOf(true) - } else if intValue == 0 { - source = reflect.ValueOf(false) + if source.Type().AssignableTo(destination.Type()) { + switch b := source.Interface().(type) { + case []byte: + destination.SetBytes(cloneBytes(b)) + default: + destination.Set(source) } - case source.Type() == stringType && isNumber(destination.Type()): - // if source is string and destination is a number(int8, int32, float32, ...), we first parse string to float64 number - // and then parsed number is converted into destination type - f, err := strconv.ParseFloat(source.String(), 64) + return nil + } + + sourceInterface := source.Interface() + + switch destination.Interface().(type) { + case bool: + var nullBool internal.NullBool + + err := nullBool.Scan(sourceInterface) + if err != nil { - return false + return err } - source = reflect.ValueOf(f) - if source.Type().ConvertibleTo(destination.Type()) { - source = source.Convert(destination.Type()) + destination.SetBool(nullBool.Bool) + + case float32, float64: + var nullFloat sql.NullFloat64 + + err := nullFloat.Scan(sourceInterface) + if err != nil { + return err } - } - if source.Type().AssignableTo(destination.Type()) { - destination.Set(source) - return true - } + if nullFloat.Valid { + destination.SetFloat(nullFloat.Float64) + } + case int, int8, int16, int32, int64: + var integer sql.NullInt64 - return false -} + err := integer.Scan(sourceInterface) + if err != nil { + return err + } -func setReflectValue(source, destination reflect.Value) { + if integer.Valid { + destination.SetInt(integer.Int64) + } - if tryAssign(source, destination) { - return - } + case uint, uint8, uint16, uint32, uint64: + var uInt internal.NullUInt64 - if destination.Kind() == reflect.Ptr { - if source.Kind() == reflect.Ptr { - if !source.IsNil() { - if destination.IsNil() { - initializeValueIfNilPtr(destination) - } - - if tryAssign(source.Elem(), destination.Elem()) { - return - } - } else { - return - } - } else { - if source.CanAddr() { - source = source.Addr() - } else { - sourceCopy := reflect.New(source.Type()) - sourceCopy.Elem().Set(source) + err := uInt.Scan(sourceInterface) - source = sourceCopy - } + if err != nil { + return err + } - if tryAssign(source, destination) { - return - } + if uInt.Valid { + destination.SetUint(uInt.UInt64) + } - if tryAssign(source.Elem(), destination.Elem()) { - return - } + case string: + var str sql.NullString + + err := str.Scan(sourceInterface) + if err != nil { + return err } - } else { - if source.Kind() == reflect.Ptr { - if source.IsNil() { - return - } - source = source.Elem() + + if str.Valid { + destination.SetString(str.String) + } + + case time.Time: + var nullTime internal.NullTime + + err := nullTime.Scan(sourceInterface) + if err != nil { + return err } - if tryAssign(source, destination) { - return + if nullTime.Valid { + destination.Set(reflect.ValueOf(nullTime.Time)) } + + default: + return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface()) } - panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String()) + return nil } -func createScanValue(columnTypes []*sql.ColumnType) []interface{} { - values := make([]interface{}, len(columnTypes)) +func setReflectValue(source, destination reflect.Value) error { - for i, sqlColumnType := range columnTypes { - columnType := newScanType(sqlColumnType) + if destination.Kind() == reflect.Ptr { + if destination.IsNil() { + initializeValueIfNilPtr(destination) + } - columnValue := reflect.New(columnType) + if source.Kind() == reflect.Ptr { + if source.IsNil() { + return nil // source is nil, destination should keep its zero value + } + source = source.Elem() + } - values[i] = columnValue.Interface() - } + if err := tryAssign(source, destination.Elem()); err != nil { + return err + } - return values -} + } else { + if source.Kind() == reflect.Ptr { + if source.IsNil() { + return nil // source is nil, destination should keep its zero value + } + source = source.Elem() + } -var boolType = reflect.TypeOf(true) -var int8Type = reflect.TypeOf(int8(1)) -var unit8Type = reflect.TypeOf(uint8(1)) -var int16Type = reflect.TypeOf(int16(1)) -var uint16Type = reflect.TypeOf(uint16(1)) -var int32Type = reflect.TypeOf(int32(1)) -var uint32Type = reflect.TypeOf(uint32(1)) -var int64Type = reflect.TypeOf(int64(1)) -var uint64Type = reflect.TypeOf(uint64(1)) -var float32Type = reflect.TypeOf(float32(1)) -var float64Type = reflect.TypeOf(float64(1)) -var stringType = reflect.TypeOf("") - -var nullBoolType = reflect.TypeOf(sql.NullBool{}) -var nullInt8Type = reflect.TypeOf(internal.NullInt8{}) -var nullInt16Type = reflect.TypeOf(internal.NullInt16{}) -var nullInt32Type = reflect.TypeOf(internal.NullInt32{}) -var nullInt64Type = reflect.TypeOf(sql.NullInt64{}) -var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{}) -var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{}) -var nullStringType = reflect.TypeOf(sql.NullString{}) -var nullTimeType = reflect.TypeOf(internal.NullTime{}) -var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{}) - -func newScanType(columnType *sql.ColumnType) reflect.Type { - - switch columnType.DatabaseTypeName() { - case "TINYINT": - return nullInt8Type - case "INT2", "SMALLINT", "YEAR": - return nullInt16Type - case "INT4", "MEDIUMINT", "INT": - return nullInt32Type - case "INT8", "BIGINT": - return nullInt64Type - case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML": - return nullStringType - case "FLOAT4": - return nullFloat32Type - case "FLOAT8", "FLOAT", "DOUBLE": - return nullFloat64Type - case "BOOL": - return nullBoolType - case "BYTEA", "BINARY", "VARBINARY", "BLOB": - return nullByteArrayType - case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ": - return nullTimeType - default: - return nullStringType + if err := tryAssign(source, destination); err != nil { + return err + } } + + return nil } func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool { @@ -385,3 +380,12 @@ func fieldToString(field *reflect.StructField) string { return " at '" + field.Name + " " + field.Type.String() + "'" } + +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } + c := make([]byte, len(b)) + copy(c, b) + return c +} diff --git a/qrm/utill_test.go b/qrm/utill_test.go index e23fa158..991a682d 100644 --- a/qrm/utill_test.go +++ b/qrm/utill_test.go @@ -58,25 +58,24 @@ func TestTryAssign(t *testing.T) { testValue := reflect.ValueOf(&destination).Elem() // convertible - require.True(t, tryAssign(reflect.ValueOf(convertible), testValue.FieldByName("Convertible"))) + require.NoError(t, tryAssign(reflect.ValueOf(convertible), testValue.FieldByName("Convertible"))) require.Equal(t, int64(16), destination.Convertible) // 1/0 to bool - require.True(t, tryAssign(reflect.ValueOf(intBool1), testValue.FieldByName("IntBool1"))) + require.NoError(t, tryAssign(reflect.ValueOf(intBool1), testValue.FieldByName("IntBool1"))) require.Equal(t, true, destination.IntBool1) - require.True(t, tryAssign(reflect.ValueOf(intBool0), testValue.FieldByName("IntBool0"))) + require.NoError(t, tryAssign(reflect.ValueOf(intBool0), testValue.FieldByName("IntBool0"))) require.Equal(t, false, destination.IntBool0) - require.False(t, tryAssign(reflect.ValueOf(intBool2), testValue.FieldByName("IntBool2"))) - require.Equal(t, false, destination.IntBool2) + require.EqualError(t, tryAssign(reflect.ValueOf(intBool2), testValue.FieldByName("IntBool2")), "can't assign int32(2) to bool") // string to float - require.True(t, tryAssign(reflect.ValueOf(floatStr), testValue.FieldByName("FloatStr"))) + require.NoError(t, tryAssign(reflect.ValueOf(floatStr), testValue.FieldByName("FloatStr"))) require.Equal(t, 1.11, destination.FloatStr) - require.False(t, tryAssign(reflect.ValueOf(floatErr), testValue.FieldByName("FloatErr"))) + require.EqualError(t, tryAssign(reflect.ValueOf(floatErr), testValue.FieldByName("FloatErr")), "converting driver.Value type string (\"1.abcd2\") to a float64: invalid syntax") require.Equal(t, 0.00, destination.FloatErr) // string to string - require.True(t, tryAssign(reflect.ValueOf(str), testValue.FieldByName("Str"))) + require.NoError(t, tryAssign(reflect.ValueOf(str), testValue.FieldByName("Str"))) require.Equal(t, str, destination.Str) } diff --git a/sqlite/cast.go b/sqlite/cast.go new file mode 100644 index 00000000..517fb95a --- /dev/null +++ b/sqlite/cast.go @@ -0,0 +1,55 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/jet" +) + +type cast interface { + AS(castType string) Expression + AS_TEXT() StringExpression + AS_NUMERIC() FloatExpression + AS_INTEGER() IntegerExpression + AS_REAL() FloatExpression + AS_BLOB() StringExpression +} + +type castImpl struct { + jet.Cast +} + +// CAST function converts a expr (of any type) into latter specified datatype. +func CAST(expr Expression) cast { + castImpl := &castImpl{} + castImpl.Cast = jet.NewCastImpl(expr) + return castImpl +} + +// AS casts expressions to castType +func (c *castImpl) AS(castType string) Expression { + return c.Cast.AS(castType) +} + +// AS_TEXT cast expression to TEXT type +func (c *castImpl) AS_TEXT() StringExpression { + return StringExp(c.AS("TEXT")) +} + +// AS_NUMERIC cast expression to NUMERIC type +func (c *castImpl) AS_NUMERIC() FloatExpression { + return FloatExp(c.AS("NUMERIC")) +} + +// AS_INTEGER cast expression to INTEGER type +func (c *castImpl) AS_INTEGER() IntegerExpression { + return IntExp(c.AS("INTEGER")) +} + +// AS_REAL cast expression to REAL type +func (c *castImpl) AS_REAL() FloatExpression { + return FloatExp(c.AS("REAL")) +} + +// AS_BLOB cast expression to BLOB type +func (c *castImpl) AS_BLOB() StringExpression { + return StringExp(c.AS("BLOB")) +} diff --git a/sqlite/cast_test.go b/sqlite/cast_test.go new file mode 100644 index 00000000..c0ef9146 --- /dev/null +++ b/sqlite/cast_test.go @@ -0,0 +1,14 @@ +package sqlite + +import ( + "testing" +) + +func TestCAST(t *testing.T) { + assertSerialize(t, CAST(Float(11.22)).AS("bigint"), `CAST(? AS bigint)`) + assertSerialize(t, CAST(Int(22)).AS_TEXT(), `CAST(? AS TEXT)`) + assertSerialize(t, CAST(Int(22)).AS_NUMERIC(), `CAST(? AS NUMERIC)`) + assertSerialize(t, CAST(String("22")).AS_INTEGER(), `CAST(? AS INTEGER)`) + assertSerialize(t, CAST(String("22.2")).AS_REAL(), `CAST(? AS REAL)`) + assertSerialize(t, CAST(String("blob")).AS_BLOB(), `CAST(? AS BLOB)`) +} diff --git a/sqlite/columns.go b/sqlite/columns.go new file mode 100644 index 00000000..88ae4f6b --- /dev/null +++ b/sqlite/columns.go @@ -0,0 +1,58 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// Column is common column interface for all types of columns. +type Column = jet.ColumnExpression + +// ColumnList function returns list of columns that be used as projection or column list for UPDATE and INSERT statement. +type ColumnList = jet.ColumnList + +// ColumnBool is interface for SQL boolean columns. +type ColumnBool = jet.ColumnBool + +// BoolColumn creates named bool column. +var BoolColumn = jet.BoolColumn + +// ColumnString is interface for SQL text, character, character varying +// bytea, uuid columns and enums types. +type ColumnString = jet.ColumnString + +// StringColumn creates named string column. +var StringColumn = jet.StringColumn + +// ColumnInteger is interface for SQL smallint, integer, bigint columns. +type ColumnInteger = jet.ColumnInteger + +// IntegerColumn creates named integer column. +var IntegerColumn = jet.IntegerColumn + +// ColumnFloat is interface for SQL real, numeric, decimal or double precision column. +type ColumnFloat = jet.ColumnFloat + +// FloatColumn creates named float column. +var FloatColumn = jet.FloatColumn + +// ColumnTime is interface for SQL time column. +type ColumnTime = jet.ColumnTime + +// TimeColumn creates named time column +var TimeColumn = jet.TimeColumn + +// ColumnDate is interface of SQL date columns. +type ColumnDate = jet.ColumnDate + +// DateColumn creates named date column. +var DateColumn = jet.DateColumn + +// ColumnDateTime is interface of SQL timestamp columns. +type ColumnDateTime = jet.ColumnTimestamp + +// DateTimeColumn creates named timestamp column +var DateTimeColumn = jet.TimestampColumn + +//ColumnTimestamp is interface of SQL timestamp columns. +type ColumnTimestamp = jet.ColumnTimestamp + +// TimestampColumn creates named timestamp column +var TimestampColumn = jet.TimestampColumn diff --git a/sqlite/delete_statement.go b/sqlite/delete_statement.go new file mode 100644 index 00000000..dee85c06 --- /dev/null +++ b/sqlite/delete_statement.go @@ -0,0 +1,61 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// DeleteStatement is interface for MySQL DELETE statement +type DeleteStatement interface { + Statement + + WHERE(expression BoolExpression) DeleteStatement + ORDER_BY(orderByClauses ...OrderByClause) DeleteStatement + LIMIT(limit int64) DeleteStatement + RETURNING(projections ...jet.Projection) DeleteStatement +} + +type deleteStatementImpl struct { + jet.SerializerStatement + + Delete jet.ClauseStatementBegin + Where jet.ClauseWhere + OrderBy jet.ClauseOrderBy + Limit jet.ClauseLimit + Returning jet.ClauseReturning +} + +func newDeleteStatement(table Table) DeleteStatement { + newDelete := &deleteStatementImpl{} + newDelete.SerializerStatement = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete, + &newDelete.Delete, + &newDelete.Where, + &newDelete.OrderBy, + &newDelete.Limit, + &newDelete.Returning, + ) + + newDelete.Delete.Name = "DELETE FROM" + newDelete.Delete.Tables = append(newDelete.Delete.Tables, table) + newDelete.Where.Mandatory = true + newDelete.Limit.Count = -1 + + return newDelete +} + +func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement { + d.Where.Condition = expression + return d +} + +func (d *deleteStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) DeleteStatement { + d.OrderBy.List = orderByClauses + return d +} + +func (d *deleteStatementImpl) LIMIT(limit int64) DeleteStatement { + d.Limit.Count = limit + return d +} + +func (d *deleteStatementImpl) RETURNING(projections ...jet.Projection) DeleteStatement { + d.Returning.ProjectionList = projections + return d +} diff --git a/sqlite/delete_statement_test.go b/sqlite/delete_statement_test.go new file mode 100644 index 00000000..6620c9f6 --- /dev/null +++ b/sqlite/delete_statement_test.go @@ -0,0 +1,26 @@ +package sqlite + +import ( + "testing" +) + +func TestDeleteUnconditionally(t *testing.T) { + assertStatementSqlErr(t, table1.DELETE(), `jet: WHERE clause not set`) + assertStatementSqlErr(t, table1.DELETE().WHERE(nil), `jet: WHERE clause not set`) +} + +func TestDeleteWithWhere(t *testing.T) { + assertStatementSql(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), ` +DELETE FROM db.table1 +WHERE table1.col1 = ?; +`, int64(1)) +} + +func TestDeleteWithWhereOrderByLimit(t *testing.T) { + assertStatementSql(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).ORDER_BY(table1Col1).LIMIT(1), ` +DELETE FROM db.table1 +WHERE table1.col1 = ? +ORDER BY table1.col1 +LIMIT ?; +`, int64(1), int64(1)) +} diff --git a/sqlite/dialect.go b/sqlite/dialect.go new file mode 100644 index 00000000..93e1d2f1 --- /dev/null +++ b/sqlite/dialect.go @@ -0,0 +1,225 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/jet" +) + +// Dialect is implementation of SQL Builder for SQLite databases. +var Dialect = newDialect() + +func newDialect() jet.Dialect { + operatorSerializeOverrides := map[string]jet.SerializeOverride{} + operatorSerializeOverrides["IS DISTINCT FROM"] = sqlite_IS_DISTINCT_FROM + operatorSerializeOverrides["IS NOT DISTINCT FROM"] = sqlite_IS_NOT_DISTINCT_FROM + operatorSerializeOverrides["#"] = sqliteBitXOR + + mySQLDialectParams := jet.DialectParams{ + Name: "SQLite", + PackageName: "sqlite", + OperatorSerializeOverrides: operatorSerializeOverrides, + AliasQuoteChar: '"', + IdentifierQuoteChar: '`', + ArgumentPlaceholder: func(int) string { + return "?" + }, + ReservedWords: reservedWords2, + } + + return jet.NewDialect(mySQLDialectParams) +} + +func sqliteBitXOR(expressions ...jet.Serializer) jet.SerializerFunc { + return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(expressions) < 2 { + panic("jet: invalid number of expressions for operator XOR") + } + + // (~(a&b))&(a|b) + a := expressions[0] + b := expressions[1] + + out.WriteString("(~(") + jet.Serialize(a, statement, out, options...) + out.WriteByte('&') + jet.Serialize(b, statement, out, options...) + out.WriteString("))&(") + jet.Serialize(a, statement, out, options...) + out.WriteByte('|') + jet.Serialize(b, statement, out, options...) + out.WriteByte(')') + } +} + +func sqlite_IS_NOT_DISTINCT_FROM(expressions ...jet.Serializer) jet.SerializerFunc { + return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(expressions) < 2 { + panic("jet: invalid number of expressions for operator") + } + + jet.Serialize(expressions[0], statement, out) + out.WriteString("IS") + jet.Serialize(expressions[1], statement, out) + } +} + +func sqlite_IS_DISTINCT_FROM(expressions ...jet.Serializer) jet.SerializerFunc { + return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(expressions) < 2 { + panic("jet: invalid number of expressions for operator") + } + + jet.Serialize(expressions[0], statement, out) + out.WriteString("IS NOT") + jet.Serialize(expressions[1], statement, out) + } +} + +var reservedWords2 = []string{ + "ABORT", + "ACTION", + "ADD", + "AFTER", + "ALL", + "ALTER", + "ALWAYS", + "ANALYZE", + "AND", + "AS", + "ASC", + "ATTACH", + "AUTOINCREMENT", + "BEFORE", + "BEGIN", + "BETWEEN", + "BY", + "CASCADE", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "COMMIT", + "CONFLICT", + "CONSTRAINT", + "CREATE", + "CROSS", + "CURRENT", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATABASE", + "DEFAULT", + "DEFERRABLE", + "DEFERRED", + "DELETE", + "DESC", + "DETACH", + "DISTINCT", + "DO", + "DROP", + "EACH", + "ELSE", + "END", + "ESCAPE", + "EXCEPT", + "EXCLUDE", + "EXCLUSIVE", + "EXISTS", + "EXPLAIN", + "FAIL", + "FILTER", + "FIRST", + "FOLLOWING", + "FOR", + "FOREIGN", + "FROM", + "FULL", + "GENERATED", + "GLOB", + "GROUP", + "GROUPS", + "HAVING", + "IF", + "IGNORE", + "IMMEDIATE", + "IN", + "INDEX", + "INDEXED", + "INITIALLY", + "INNER", + "INSERT", + "INSTEAD", + "INTERSECT", + "INTO", + "IS", + "ISNULL", + "JOIN", + "KEY", + "LAST", + "LEFT", + "LIKE", + "LIMIT", + "MATCH", + "MATERIALIZED", + "NATURAL", + "NO", + "NOT", + "NOTHING", + "NOTNULL", + "NULL", + "NULLS", + "OF", + "OFFSET", + "ON", + "OR", + "ORDER", + "OTHERS", + "OUTER", + "OVER", + "PARTITION", + "PLAN", + "PRAGMA", + "PRECEDING", + "PRIMARY", + "QUERY", + "RAISE", + "RANGE", + "RECURSIVE", + "REFERENCES", + "REGEXP", + "REINDEX", + "RELEASE", + "RENAME", + "REPLACE", + "RESTRICT", + "RETURNING", + "RIGHT", + "ROLLBACK", + "ROW", + "ROWS", + "SAVEPOINT", + "SELECT", + "SET", + "TABLE", + "TEMP", + "TEMPORARY", + "THEN", + "TIES", + "TO", + "TRANSACTION", + "TRIGGER", + "UNBOUNDED", + "UNION", + "UNIQUE", + "UPDATE", + "USING", + "VACUUM", + "VALUES", + "VIEW", + "VIRTUAL", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "WITHOUT", +} diff --git a/sqlite/dialect_test.go b/sqlite/dialect_test.go new file mode 100644 index 00000000..e90357fd --- /dev/null +++ b/sqlite/dialect_test.go @@ -0,0 +1,59 @@ +package sqlite + +import ( + "testing" +) + +func TestBoolExpressionIS_DISTINCT_FROM(t *testing.T) { + assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(table2ColBool), "(table1.col_bool IS NOT table2.col_bool)") + assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(Bool(false)), "(table1.col_bool IS NOT ?)", false) +} + +func TestBoolExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { + assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(table2ColBool), "(table1.col_bool IS table2.col_bool)") + assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(Bool(false)), "(table1.col_bool IS ?)", false) +} + +func TestBoolLiteral(t *testing.T) { + assertSerialize(t, Bool(true), "?", true) + assertSerialize(t, Bool(false), "?", false) +} + +func TestIntegerExpressionDIV(t *testing.T) { + assertSerialize(t, table1ColInt.DIV(table2ColInt), "(table1.col_int / table2.col_int)") + assertSerialize(t, table1ColInt.DIV(Int(11)), "(table1.col_int / ?)", int64(11)) +} + +func TestIntExpressionPOW(t *testing.T) { + assertSerialize(t, table1ColInt.POW(table2ColInt), "POW(table1.col_int, table2.col_int)") + assertSerialize(t, table1ColInt.POW(Int(11)), "POW(table1.col_int, ?)", int64(11)) +} + +func TestIntExpressionBIT_XOR(t *testing.T) { + assertSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "((~(table1.col_int & table2.col_int))&(table1.col_int | table2.col_int))") + assertSerialize(t, table1ColInt.BIT_XOR(Int(11)), "((~(table1.col_int & ?))&(table1.col_int | ?))", int64(11), int64(11)) +} + +func TestExists(t *testing.T) { + assertSerialize(t, EXISTS( + table2. + SELECT(Int(1)). + WHERE(table1Col1.EQ(table2Col3)), + ), + `(EXISTS ( + SELECT ? + FROM db.table2 + WHERE table1.col1 = table2.col3 +))`, int64(1)) +} + +func TestString_REGEXP_LIKE_operator(t *testing.T) { + assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 REGEXP table2.col_str)") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 REGEXP ?)", "JOHN") + +} + +func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) { + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 NOT REGEXP table2.col_str)") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 NOT REGEXP ?)", "JOHN") +} diff --git a/sqlite/expressions.go b/sqlite/expressions.go new file mode 100644 index 00000000..d1d47374 --- /dev/null +++ b/sqlite/expressions.go @@ -0,0 +1,97 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// Expression is common interface for all expressions. +// Can be Bool, Int, Float, String, Date, Time or Timestamp expressions. +type Expression = jet.Expression + +// BoolExpression interface +type BoolExpression = jet.BoolExpression + +// StringExpression interface +type StringExpression = jet.StringExpression + +// NumericExpression is shared interface for integer or real expression +type NumericExpression = jet.NumericExpression + +// IntegerExpression interface +type IntegerExpression = jet.IntegerExpression + +// FloatExpression interface +type FloatExpression = jet.FloatExpression + +// TimeExpression interface +type TimeExpression = jet.TimeExpression + +// DateExpression interface +type DateExpression = jet.DateExpression + +// DateTimeExpression interface +type DateTimeExpression = jet.TimestampExpression + +// TimestampExpression interface +type TimestampExpression = jet.TimestampExpression + +// BoolExp is bool expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as bool expression. +// Does not add sql cast to generated sql builder output. +var BoolExp = jet.BoolExp + +// StringExp is string expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as string expression. +// Does not add sql cast to generated sql builder output. +var StringExp = jet.StringExp + +// IntExp is int expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as int expression. +// Does not add sql cast to generated sql builder output. +var IntExp = jet.IntExp + +// FloatExp is date expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as float expression. +// Does not add sql cast to generated sql builder output. +var FloatExp = jet.FloatExp + +// TimeExp is time expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as time expression. +// Does not add sql cast to generated sql builder output. +var TimeExp = jet.TimeExp + +// DateExp is date expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as date expression. +// Does not add sql cast to generated sql builder output. +var DateExp = jet.DateExp + +// DateTimeExp is timestamp expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as timestamp expression. +// Does not add sql cast to generated sql builder output. +var DateTimeExp = jet.TimestampExp + +// TimestampExp is timestamp expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as timestamp expression. +// Does not add sql cast to generated sql builder output. +var TimestampExp = jet.TimestampExp + +// RawArgs is type used to pass optional arguments to Raw method +type RawArgs = map[string]interface{} + +// Raw can be used for any unsupported functions, operators or expressions. +// For example: Raw("current_database()") +// Raw helper methods for each of the sqlite types +var ( + Raw = jet.Raw + + RawInt = jet.RawInt + RawFloat = jet.RawFloat + RawString = jet.RawString + RawTime = jet.RawTime + RawTimestamp = jet.RawTimestamp + RawDate = jet.RawDate +) + +// Func can be used to call an custom or as of yet unsupported function in the database. +var Func = jet.Func + +// NewEnumValue creates new named enum value +var NewEnumValue = jet.NewEnumValue diff --git a/sqlite/expressions_test.go b/sqlite/expressions_test.go new file mode 100644 index 00000000..2c2bbef3 --- /dev/null +++ b/sqlite/expressions_test.go @@ -0,0 +1,52 @@ +package sqlite + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestRaw(t *testing.T) { + assertSerialize(t, Raw("current_database()"), "(current_database())") + assertDebugSerialize(t, Raw("current_database()"), "(current_database())") + + assertSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}), + "(? + table.colInt + ?)", 11, 22) + assertDebugSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}), + "(11 + table.colInt + 22)") + + assertSerialize(t, + Int(700).ADD(RawInt("#1 + table.colInt + #2", RawArgs{"#1": 11, "#2": 22})), + "(? + (? + table.colInt + ?))", + int64(700), 11, 22) + assertDebugSerialize(t, + Int(700).ADD(RawInt("#1 + table.colInt + #2", RawArgs{"#1": 11, "#2": 22})), + "(700 + (11 + table.colInt + 22))") +} + +func TestRawDuplicateArguments(t *testing.T) { + assertSerialize(t, Raw(":arg + table.colInt + :arg", RawArgs{":arg": 11}), + "(? + table.colInt + ?)", 11, 11) + + assertSerialize(t, Raw("#age + table.colInt + #year + #age + #year + 11", RawArgs{"#age": 11, "#year": 2000}), + "(? + table.colInt + ? + ? + ? + 11)", 11, 2000, 11, 2000) + + assertSerialize(t, Raw("#1 + all_types.integer + #2 + #1 + #2 + #3 + #4", + RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}), + `(? + all_types.integer + ? + ? + ? + ? + ?)`, 11, 22, 11, 22, 33, 44) +} + +func TestRawInvalidArguments(t *testing.T) { + defer func() { + r := recover() + require.Equal(t, "jet: named argument 'first_arg' does not appear in raw query", r) + }() + + assertSerialize(t, Raw("table.colInt + :second_arg", RawArgs{"first_arg": 11}), "(table.colInt + ?)", 22) +} + +func TestRawType(t *testing.T) { + assertSerialize(t, RawFloat("table.colInt + &float", RawArgs{"&float": 11.22}).EQ(Float(3.14)), + "((table.colInt + ?) = ?)", 11.22, 3.14) + assertSerialize(t, RawString("table.colStr || str", RawArgs{"str": "doe"}).EQ(String("john doe")), + "((table.colStr || ?) = ?)", "doe", "john doe") +} diff --git a/sqlite/functions.go b/sqlite/functions.go new file mode 100644 index 00000000..2b70714b --- /dev/null +++ b/sqlite/functions.go @@ -0,0 +1,342 @@ +package sqlite + +import ( + "fmt" + "github.com/go-jet/jet/v2/internal/jet" + "time" +) + +// ROW is construct one table row from list of expressions. +func ROW(expressions ...Expression) Expression { + return jet.NewFunc("", expressions, nil) +} + +// ------------------ Mathematical functions ---------------// + +// ABSf calculates absolute value from float expression +var ABSf = jet.ABSf + +// ABSi calculates absolute value from int expression +var ABSi = jet.ABSi + +// POW calculates power of base with exponent +var POW = jet.POW + +// POWER calculates power of base with exponent +var POWER = jet.POWER + +// SQRT calculates square root of numeric expression +var SQRT = jet.SQRT + +// CBRT calculates cube root of numeric expression +func CBRT(number jet.NumericExpression) jet.FloatExpression { + return POWER(number, Float(1.0).DIV(Float(3.0))) +} + +// CEIL calculates ceil of float expression +var CEIL = jet.CEIL + +// FLOOR calculates floor of float expression +var FLOOR = jet.FLOOR + +// ROUND calculates round of a float expressions with optional precision +var ROUND = jet.ROUND + +// SIGN returns sign of float expression +var SIGN = jet.SIGN + +// TRUNC calculates trunc of float expression with precision +var TRUNC = TRUNCATE + +// TRUNCATE calculates trunc of float expression with precision +var TRUNCATE = func(floatExpression jet.FloatExpression, precision jet.IntegerExpression) jet.FloatExpression { + return jet.NewFloatFunc("TRUNCATE", floatExpression, precision) +} + +// LN calculates natural algorithm of float expression +var LN = jet.LN + +// LOG calculates logarithm of float expression +var LOG = jet.LOG + +// ----------------- Aggregate functions -------------------// + +// AVG is aggregate function used to calculate avg value from numeric expression +var AVG = jet.AVG + +// BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none. +//var BIT_AND = jet.BIT_AND + +// BIT_OR is aggregate function used to calculates the bitwise OR of all non-null input values, or null if none. +//var BIT_OR = jet.BIT_OR + +// COUNT is aggregate function. Returns number of input rows for which the value of expression is not null. +var COUNT = jet.COUNT + +// MAX is aggregate function. Returns maximum value of expression across all input values +var MAX = jet.MAX + +// MAXi is aggregate function. Returns maximum value of int expression across all input values +var MAXi = jet.MAXi + +// MAXf is aggregate function. Returns maximum value of float expression across all input values +var MAXf = jet.MAXf + +// MIN is aggregate function. Returns minimum value of int expression across all input values +var MIN = jet.MIN + +// MINi is aggregate function. Returns minimum value of int expression across all input values +var MINi = jet.MINi + +// MINf is aggregate function. Returns minimum value of float expression across all input values +var MINf = jet.MINf + +// SUM is aggregate function. Returns sum of all expressions +var SUM = jet.SUM + +// SUMi is aggregate function. Returns sum of integer expression. +var SUMi = jet.SUMi + +// SUMf is aggregate function. Returns sum of float expression. +var SUMf = jet.SUMf + +// -------------------- Window functions -----------------------// + +// ROW_NUMBER returns number of the current row within its partition, counting from 1 +var ROW_NUMBER = jet.ROW_NUMBER + +// RANK of the current row with gaps; same as row_number of its first peer +var RANK = jet.RANK + +// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups +var DENSE_RANK = jet.DENSE_RANK + +// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1) +var PERCENT_RANK = jet.PERCENT_RANK + +// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows +var CUME_DIST = jet.CUME_DIST + +// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible +var NTILE = jet.NTILE + +// LAG returns value evaluated at the row that is offset rows before the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +var LAG = jet.LAG + +// LEAD returns value evaluated at the row that is offset rows after the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +var LEAD = jet.LEAD + +// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame +var FIRST_VALUE = jet.FIRST_VALUE + +// LAST_VALUE returns value evaluated at the row that is the last row of the window frame +var LAST_VALUE = jet.LAST_VALUE + +// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row +var NTH_VALUE = jet.NTH_VALUE + +//--------------------- String functions ------------------// + +// BIT_LENGTH returns number of bits in string expression +//var BIT_LENGTH = jet.BIT_LENGTH +// +//// CHAR_LENGTH returns number of characters in string expression +//var CHAR_LENGTH = jet.CHAR_LENGTH +// +//// OCTET_LENGTH returns number of bytes in string expression +//var OCTET_LENGTH = jet.OCTET_LENGTH + +// LOWER returns string expression in lower case +var LOWER = jet.LOWER + +// UPPER returns string expression in upper case +var UPPER = jet.UPPER + +// LTRIM removes the longest string containing only characters +// from characters (a space by default) from the start of string +var LTRIM = jet.LTRIM + +// RTRIM removes the longest string containing only characters +// from characters (a space by default) from the end of string +var RTRIM = jet.RTRIM + +// CONCAT adds two or more expressions together +//var CONCAT = jet.CONCAT + +// CONCAT_WS adds two or more expressions together with a separator. +//var CONCAT_WS = jet.CONCAT_WS + +// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. +//var FORMAT = jet.FORMAT + +// LEFTSTR returns first n characters in the string. +// When n is negative, return all but last |n| characters. +//func LEFTSTR(str StringExpression, n IntegerExpression) StringExpression { +// return jet.NewStringFunc("LEFTSTR", str, n) +//} +// +//// RIGHT returns last n characters in the string. +//// When n is negative, return all but first |n| characters. +//func RIGHTSTR(str StringExpression, n IntegerExpression) StringExpression { +// return jet.NewStringFunc("RIGHTSTR", str, n) +//} + +// LENGTH returns number of characters in string with a given encoding +func LENGTH(str jet.StringExpression) jet.StringExpression { + return jet.LENGTH(str) +} + +// LPAD fills up the string to length length by prepending the characters +// fill (a space by default). If the string is already longer than length +// then it is truncated (on the right). +//func LPAD(str jet.StringExpression, length jet.IntegerExpression, text jet.StringExpression) jet.StringExpression { +// return jet.LPAD(str, length, text) +//} + +// RPAD fills up the string to length length by appending the characters +// fill (a space by default). If the string is already longer than length then it is truncated. +//func RPAD(str jet.StringExpression, length jet.IntegerExpression, text jet.StringExpression) jet.StringExpression { +// return jet.RPAD(str, length, text) +//} + +// MD5 calculates the MD5 hash of string, returning the result in hexadecimal +//var MD5 = jet.MD5 + +// REPEAT repeats string the specified number of times +//var REPEAT = jet.REPEAT + +// REPLACE replaces all occurrences in string of substring from with substring to +var REPLACE = jet.REPLACE + +// REVERSE returns reversed string. +var REVERSE = jet.REVERSE + +// SUBSTR extracts substring +var SUBSTR = jet.SUBSTR + +// REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise. +var REGEXP_LIKE = jet.REGEXP_LIKE + +//----------------- Date/Time Functions and Operators ------------// + +// CURRENT_DATE returns current date +var CURRENT_DATE = jet.CURRENT_DATE + +// CURRENT_TIME returns current time with time zone +func CURRENT_TIME() TimeExpression { + return TimeExp(jet.CURRENT_TIME()) +} + +// CURRENT_TIMESTAMP returns current timestamp with time zone +func CURRENT_TIMESTAMP() TimestampExpression { + return TimestampExp(jet.CURRENT_TIMESTAMP()) +} + +//// NOW returns current datetime +//func NOW() DateTimeExpression { +// //if len(fsp) > 0 { +// // return jet.NewTimestampFunc("NOW", jet.FixedLiteral(int64(fsp[0]))) +// //} +// //return jet.NewTimestampFunc("NOW") +// return DATETIME(jet.FixedLiteral("now")) +//} + +// time-value modifiers +var ( + YEARS = modifier("YEARS") + MONTHS = modifier("MONTHS") + DAYS = modifier("DAYS") + HOURS = modifier("HOURS") + MINUTES = modifier("MINUTES") + SECONDS = modifier("SECONDS") + + START_OF_YEAR = String("start of year") + START_OF_MONTH = String("start of month") + UNIXEPOCH = String("unixepoch") + LOCALTIME = String("localtime") + UTC = String("UTC") + + WEEKDAY = func(value int) Expression { + return String(fmt.Sprintf("WEEKDAY %d", value)) + } +) + +func modifier(modifierName string) func(value float64) Expression { + return func(value float64) Expression { + return String(fmt.Sprintf("%g %s", value, modifierName)) + } +} + +// DATE function creates new date from time-value and zero or more time modifiers +func DATE(timeValue interface{}, modifiers ...Expression) DateExpression { + exprList := getFuncExprList(timeValue, modifiers...) + + return jet.NewDateFunc("DATE", exprList...) +} + +// TIME function creates new time from time-value and zero or more time modifiers +func TIME(timeValue interface{}, modifiers ...Expression) TimeExpression { + exprList := getFuncExprList(timeValue, modifiers...) + + return jet.NewTimeFunc("TIME", exprList...) +} + +// DATETIME function creates new DateTime from time-value and zero or more time modifiers +func DATETIME(timeValue interface{}, modifiers ...Expression) DateTimeExpression { + exprList := getFuncExprList(timeValue, modifiers...) + + return jet.NewTimestampFunc("DATETIME", exprList...) +} + +// JULIANDAY returns the number of days since noon in Greenwich on November 24, 4714 B.C +func JULIANDAY(timeValue interface{}, modifiers ...Expression) FloatExpression { + exprList := getFuncExprList(timeValue, modifiers...) + return jet.NewFloatFunc("JULIANDAY", exprList...) +} + +// STRFTIME routine returns the date formatted according to the format string specified as the first argument. +func STRFTIME(format StringExpression, timeValue interface{}, modifiers ...Expression) StringExpression { + exprList := append([]Expression{format}, getFuncExprList(timeValue, modifiers...)...) + return jet.NewStringFunc("strftime", exprList...) +} + +func getFuncExprList(timeValue interface{}, modifiers ...Expression) []Expression { + return append([]Expression{getTimeValueExpression(timeValue)}, modifiers...) +} + +func getTimeValueExpression(timeValue interface{}) Expression { + switch t := timeValue.(type) { + case string: + return String(t) + case Expression: + return t + case time.Time, int64: + return jet.Literal(t) + } + + panic(fmt.Sprintf("jet: Invalid time value %T(%q)", timeValue, timeValue)) +} + +// TIMESTAMP return a datetime value based on the arguments: +func TIMESTAMP(str StringExpression) TimestampExpression { + return jet.NewTimestampFunc("TIMESTAMP", str) +} + +// UNIX_TIMESTAMP returns unix timestamp +func UNIX_TIMESTAMP(str StringExpression) TimestampExpression { + return jet.NewTimestampFunc("UNIX_TIMESTAMP", str) +} + +//----------- Comparison operators ---------------// + +// EXISTS checks for existence of the rows in subQuery +var EXISTS = jet.EXISTS + +// CASE create CASE operator with optional list of expressions +var CASE = jet.CASE diff --git a/sqlite/insert_statement.go b/sqlite/insert_statement.go new file mode 100644 index 00000000..3912cc32 --- /dev/null +++ b/sqlite/insert_statement.go @@ -0,0 +1,117 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// InsertStatement is interface for SQL INSERT statements +type InsertStatement interface { + Statement + + VALUES(value interface{}, values ...interface{}) InsertStatement + MODEL(data interface{}) InsertStatement + MODELS(data interface{}) InsertStatement + QUERY(selectStatement SelectStatement) InsertStatement + DEFAULT_VALUES() InsertStatement + + ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict + RETURNING(projections ...Projection) InsertStatement +} + +func newInsertStatement(table Table, columns []jet.Column) InsertStatement { + newInsert := &insertStatementImpl{ + DefaultValues: jet.ClauseOptional{Name: "DEFAULT VALUES", InNewLine: true}, + } + + newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, + &newInsert.Insert, + &newInsert.ValuesQuery, + &newInsert.OnDuplicateKey, + &newInsert.DefaultValues, + &newInsert.OnConflict, + &newInsert.Returning, + ) + + newInsert.Insert.Table = table + newInsert.Insert.Columns = columns + newInsert.ValuesQuery.SkipSelectWrap = true + + return newInsert +} + +type insertStatementImpl struct { + jet.SerializerStatement + + Insert jet.ClauseInsert + ValuesQuery jet.ClauseValuesQuery + OnDuplicateKey onDuplicateKeyUpdateClause + DefaultValues jet.ClauseOptional + OnConflict onConflictClause + Returning jet.ClauseReturning +} + +func (is *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { + is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values)) + return is +} + +// MODEL will insert row of values, where value for each column is extracted from filed of structure data. +// If data is not struct or there is no field for every column selected, this method will panic. +func (is *insertStatementImpl) MODEL(data interface{}) InsertStatement { + is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromModel(is.Insert.GetColumns(), data)) + return is +} + +func (is *insertStatementImpl) MODELS(data interface{}) InsertStatement { + is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowsFromModels(is.Insert.GetColumns(), data)...) + return is +} + +func (is *insertStatementImpl) ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement { + is.OnDuplicateKey = assigments + return is +} + +func (is *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { + is.ValuesQuery.Query = selectStatement + return is +} + +func (is *insertStatementImpl) DEFAULT_VALUES() InsertStatement { + is.DefaultValues.Show = true + return is +} + +func (is *insertStatementImpl) RETURNING(projections ...jet.Projection) InsertStatement { + is.Returning.ProjectionList = projections + return is +} + +type onDuplicateKeyUpdateClause []jet.ColumnAssigment + +// Serialize for SetClause +func (s onDuplicateKeyUpdateClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(s) == 0 { + return + } + out.NewLine() + out.WriteString("ON DUPLICATE KEY UPDATE") + out.IncreaseIdent(24) + + for i, assigment := range s { + if i > 0 { + out.WriteString(",") + out.NewLine() + } + + jet.Serialize(assigment, statementType, out, jet.ShortName.WithFallTrough(options)...) + } + + out.DecreaseIdent(24) +} + +func (is *insertStatementImpl) ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict { + is.OnConflict = onConflictClause{ + insertStatement: is, + indexExpressions: indexExpressions, + } + return &is.OnConflict +} diff --git a/sqlite/insert_statement_test.go b/sqlite/insert_statement_test.go new file mode 100644 index 00000000..5bb639e8 --- /dev/null +++ b/sqlite/insert_statement_test.go @@ -0,0 +1,150 @@ +package sqlite + +import ( + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func TestInvalidInsert(t *testing.T) { + assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") +} + +func TestInsertNilValue(t *testing.T) { + assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), ` +INSERT INTO db.table1 (col1) +VALUES (?); +`, nil) +} + +func TestInsertSingleValue(t *testing.T) { + assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), ` +INSERT INTO db.table1 (col1) +VALUES (?); +`, int(1)) +} + +func TestInsertWithColumnList(t *testing.T) { + columnList := ColumnList{table3ColInt} + + columnList = append(columnList, table3StrCol) + + assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), ` +INSERT INTO db.table3 (col_int, col2) +VALUES (?, ?); +`, 1, 3) +} + +func TestInsertDate(t *testing.T) { + date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) + + assertStatementSql(t, table1.INSERT(table1ColTimestamp).VALUES(date), ` +INSERT INTO db.table1 (col_timestamp) +VALUES (?); +`, date) +} + +func TestInsertMultipleValues(t *testing.T) { + assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` +INSERT INTO db.table1 (col1, col_float, col3) +VALUES (?, ?, ?); +`, 1, 2, 3) +} + +func TestInsertMultipleRows(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1ColFloat). + VALUES(1, 2). + VALUES(11, 22). + VALUES(111, 222) + + assertStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_float) +VALUES (?, ?), + (?, ?), + (?, ?); +`, 1, 2, 11, 22, 111, 222) +} + +func TestInsertValuesFromModel(t *testing.T) { + type Table1Model struct { + Col1 *int + ColFloat float64 + } + + one := 1 + + toInsert := Table1Model{ + Col1: &one, + ColFloat: 1.11, + } + + stmt := table1.INSERT(table1Col1, table1ColFloat). + MODEL(toInsert). + MODEL(&toInsert) + + expectedSQL := ` +INSERT INTO db.table1 (col1, col_float) +VALUES (?, ?), + (?, ?); +` + + assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) +} + +func TestInsertValuesFromModelColumnMismatch(t *testing.T) { + defer func() { + r := recover() + require.Equal(t, r, "missing struct field for column : col1") + }() + type Table1Model struct { + Col1Prim int + Col2 string + } + + newData := Table1Model{ + Col1Prim: 1, + Col2: "one", + } + + table1. + INSERT(table1Col1, table1ColFloat). + MODEL(newData) +} + +func TestInsertFromNonStructModel(t *testing.T) { + + defer func() { + r := recover() + require.Equal(t, r, "jet: data has to be a struct") + }() + + table2.INSERT(table2ColInt).MODEL([]int{}) +} + +func TestInsert_ON_CONFLICT(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1ColBool). + VALUES("one", "two"). + VALUES("1", "2"). + VALUES("theta", "beta"). + ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE( + SET(table1ColBool.SET(Bool(true)), + table2ColInt.SET(Int(1)), + ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))), + ).WHERE(table1Col1.GT(Int(2))), + ). + RETURNING(table1Col1, table1ColBool) + + assertStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_bool) +VALUES (?, ?), + (?, ?), + (?, ?) +ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE + SET col_bool = ?, + col_int = ?, + (col1, col_bool) = (?, ?) + WHERE table1.col1 > ? +RETURNING table1.col1 AS "table1.col1", + table1.col_bool AS "table1.col_bool"; +`) +} diff --git a/sqlite/literal.go b/sqlite/literal.go new file mode 100644 index 00000000..2df5dd74 --- /dev/null +++ b/sqlite/literal.go @@ -0,0 +1,70 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/jet" + "time" +) + +// Keywords +var ( + STAR = jet.STAR + NULL = jet.NULL +) + +// Bool creates new bool literal expression +var Bool = jet.Bool + +// Int is constructor for 64 bit signed integer expressions literals. +var Int = jet.Int + +// Int8 is constructor for 8 bit signed integer expressions literals. +var Int8 = jet.Int8 + +// Int16 is constructor for 16 bit signed integer expressions literals. +var Int16 = jet.Int16 + +// Int32 is constructor for 32 bit signed integer expressions literals. +var Int32 = jet.Int32 + +// Int64 is constructor for 64 bit signed integer expressions literals. +var Int64 = jet.Int + +// Uint8 is constructor for 8 bit unsigned integer expressions literals. +var Uint8 = jet.Uint8 + +// Uint16 is constructor for 16 bit unsigned integer expressions literals. +var Uint16 = jet.Uint16 + +// Uint32 is constructor for 32 bit unsigned integer expressions literals. +var Uint32 = jet.Uint32 + +// Uint64 is constructor for 64 bit unsigned integer expressions literals. +var Uint64 = jet.Uint64 + +// Float creates new float literal expression from float64 value +var Float = jet.Float + +// Decimal creates new float literal expression from string value +var Decimal = jet.Decimal + +// String creates new string literal expression +var String = jet.String + +// UUID is a helper function to create string literal expression from uuid object +// value can be any uuid type with a String method +var UUID = jet.UUID + +// Date creates new date literal expression +func Date(year int, month time.Month, day int) DateExpression { + return DATE(jet.Date(year, month, day)) +} + +// Time creates new time literal expression +func Time(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression { + return TIME(jet.Time(hour, minute, second, nanoseconds...)) +} + +// DateTime creates new datetime(timestamp) literal expression +func DateTime(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) DateTimeExpression { + return DATETIME(jet.Timestamp(year, month, day, hour, minute, second, nanoseconds...)) +} diff --git a/sqlite/literal_test.go b/sqlite/literal_test.go new file mode 100644 index 00000000..8a409315 --- /dev/null +++ b/sqlite/literal_test.go @@ -0,0 +1,80 @@ +package sqlite + +import ( + "math" + "testing" + "time" +) + +func TestBool(t *testing.T) { + assertSerialize(t, Bool(false), `?`, false) +} + +func TestInt(t *testing.T) { + assertSerialize(t, Int(11), `?`, int64(11)) +} + +func TestInt8(t *testing.T) { + val := int8(math.MinInt8) + assertSerialize(t, Int8(val), `?`, val) +} + +func TestInt16(t *testing.T) { + val := int16(math.MinInt16) + assertSerialize(t, Int16(val), `?`, val) +} + +func TestInt32(t *testing.T) { + val := int32(math.MinInt32) + assertSerialize(t, Int32(val), `?`, val) +} + +func TestInt64(t *testing.T) { + val := int64(math.MinInt64) + assertSerialize(t, Int64(val), `?`, val) +} + +func TestUint8(t *testing.T) { + val := uint8(math.MaxUint8) + assertSerialize(t, Uint8(val), `?`, val) +} + +func TestUint16(t *testing.T) { + val := uint16(math.MaxUint16) + assertSerialize(t, Uint16(val), `?`, val) +} + +func TestUint32(t *testing.T) { + val := uint32(math.MaxUint32) + assertSerialize(t, Uint32(val), `?`, val) +} + +func TestUint64(t *testing.T) { + val := uint64(math.MaxUint64) + assertSerialize(t, Uint64(val), `?`, val) +} + +func TestFloat(t *testing.T) { + assertSerialize(t, Float(12.34), `?`, float64(12.34)) +} + +func TestString(t *testing.T) { + assertSerialize(t, String("Some text"), `?`, "Some text") +} + +var testTime = time.Now() + +func TestDate(t *testing.T) { + assertSerialize(t, Date(2014, time.January, 2), "DATE(?)", "2014-01-02") + assertSerialize(t, DATE(testTime), "DATE(?)", testTime) +} + +func TestTime(t *testing.T) { + assertSerialize(t, Time(10, 15, 30), `TIME(?)`, "10:15:30") + assertSerialize(t, TIME(testTime), "TIME(?)", testTime) +} + +func TestDateTime(t *testing.T) { + assertSerialize(t, DateTime(2010, time.March, 30, 10, 15, 30), `DATETIME(?)`, "2010-03-30 10:15:30") + assertSerialize(t, DATETIME(testTime), `DATETIME(?)`, testTime) +} diff --git a/sqlite/on_conflict_clause.go b/sqlite/on_conflict_clause.go new file mode 100644 index 00000000..d131b9ea --- /dev/null +++ b/sqlite/on_conflict_clause.go @@ -0,0 +1,84 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/jet" +) + +type onConflict interface { + WHERE(indexPredicate BoolExpression) conflictTarget + conflictTarget +} + +type conflictTarget interface { + DO_NOTHING() InsertStatement + DO_UPDATE(action conflictAction) InsertStatement +} + +type onConflictClause struct { + insertStatement InsertStatement + indexExpressions []jet.ColumnExpression + whereClause jet.ClauseWhere + do jet.Serializer +} + +func (o *onConflictClause) WHERE(indexPredicate BoolExpression) conflictTarget { + o.whereClause.Condition = indexPredicate + return o +} + +func (o *onConflictClause) DO_NOTHING() InsertStatement { + o.do = jet.Keyword("DO NOTHING") + return o.insertStatement +} + +func (o *onConflictClause) DO_UPDATE(action conflictAction) InsertStatement { + o.do = action + return o.insertStatement +} + +func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(o.indexExpressions) == 0 && o.do == nil { + return + } + + out.NewLine() + out.WriteString("ON CONFLICT") + if len(o.indexExpressions) > 0 { + out.WriteString("(") + jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName) + out.WriteString(")") + } + + o.whereClause.Serialize(statementType, out, jet.SkipNewLine, jet.ShortName) + + out.IncreaseIdent(7) + jet.Serialize(o.do, statementType, out) + out.DecreaseIdent(7) +} + +type conflictAction interface { + jet.Serializer + WHERE(condition BoolExpression) conflictAction +} + +// SET creates conflict action for ON_CONFLICT clause +func SET(assigments ...ColumnAssigment) conflictAction { + conflictAction := updateConflictActionImpl{} + conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"} + conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where) + conflictAction.set = assigments + return &conflictAction +} + +type updateConflictActionImpl struct { + jet.Serializer + + doUpdate jet.KeywordClause + set jet.SetClauseNew + where jet.ClauseWhere +} + +func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction { + u.where.Condition = condition + return u +} diff --git a/sqlite/operators.go b/sqlite/operators.go new file mode 100644 index 00000000..8ebecbf4 --- /dev/null +++ b/sqlite/operators.go @@ -0,0 +1,9 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// NOT returns negation of bool expression result +var NOT = jet.NOT + +// BIT_NOT inverts every bit in integer expression result +var BIT_NOT = jet.BIT_NOT diff --git a/sqlite/select_statement.go b/sqlite/select_statement.go new file mode 100644 index 00000000..4406dcd3 --- /dev/null +++ b/sqlite/select_statement.go @@ -0,0 +1,186 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/jet" +) + +// RowLock is interface for SELECT statement row lock types +type RowLock = jet.RowLock + +// Row lock types +var ( + UPDATE = jet.NewRowLock("UPDATE") + SHARE = jet.NewRowLock("SHARE") +) + +// Window function clauses +var ( + PARTITION_BY = jet.PARTITION_BY + ORDER_BY = jet.ORDER_BY + UNBOUNDED = jet.UNBOUNDED + CURRENT_ROW = jet.CURRENT_ROW +) + +// PRECEDING window frame clause +func PRECEDING(offset interface{}) jet.FrameExtent { + return jet.PRECEDING(toJetFrameOffset(offset)) +} + +// FOLLOWING window frame clause +func FOLLOWING(offset interface{}) jet.FrameExtent { + return jet.FOLLOWING(toJetFrameOffset(offset)) +} + +// Window is used to specify window reference from WINDOW clause +var Window = jet.WindowName + +// SelectStatement is interface for MySQL SELECT statement +type SelectStatement interface { + Statement + jet.HasProjections + Expression + + DISTINCT() SelectStatement + FROM(tables ...ReadableTable) SelectStatement + WHERE(expression BoolExpression) SelectStatement + GROUP_BY(groupByClauses ...GroupByClause) SelectStatement + HAVING(boolExpression BoolExpression) SelectStatement + WINDOW(name string) windowExpand + ORDER_BY(orderByClauses ...OrderByClause) SelectStatement + LIMIT(limit int64) SelectStatement + OFFSET(offset int64) SelectStatement + FOR(lock RowLock) SelectStatement + LOCK_IN_SHARE_MODE() SelectStatement + + UNION(rhs SelectStatement) setStatement + UNION_ALL(rhs SelectStatement) setStatement + + AsTable(alias string) SelectTable +} + +//SELECT creates new SelectStatement with list of projections +func SELECT(projection Projection, projections ...Projection) SelectStatement { + return newSelectStatement(nil, append([]Projection{projection}, projections...)) +} + +func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { + newSelect := &selectStatementImpl{} + newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, + &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy, + &newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock) + + newSelect.Select.ProjectionList = projections + if table != nil { + newSelect.From.Tables = []jet.Serializer{table} + } + newSelect.Limit.Count = -1 + newSelect.Offset.Count = -1 + newSelect.ShareLock.Name = "LOCK IN SHARE MODE" + newSelect.ShareLock.InNewLine = true + + newSelect.setOperatorsImpl.parent = newSelect + + return newSelect +} + +type selectStatementImpl struct { + jet.ExpressionStatement + setOperatorsImpl + + Select jet.ClauseSelect + From jet.ClauseFrom + Where jet.ClauseWhere + GroupBy jet.ClauseGroupBy + Having jet.ClauseHaving + Window jet.ClauseWindow + OrderBy jet.ClauseOrderBy + Limit jet.ClauseLimit + Offset jet.ClauseOffset + For jet.ClauseFor + ShareLock jet.ClauseOptional +} + +func (s *selectStatementImpl) DISTINCT() SelectStatement { + s.Select.Distinct = true + return s +} + +func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement { + s.From.Tables = nil + for _, table := range tables { + s.From.Tables = append(s.From.Tables, table) + } + return s +} + +func (s *selectStatementImpl) WHERE(condition BoolExpression) SelectStatement { + s.Where.Condition = condition + return s +} + +func (s *selectStatementImpl) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement { + s.GroupBy.List = groupByClauses + return s +} + +func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatement { + s.Having.Condition = boolExpression + return s +} + +func (s *selectStatementImpl) WINDOW(name string) windowExpand { + s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name}) + return windowExpand{selectStatement: s} +} + +func (s *selectStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) SelectStatement { + s.OrderBy.List = orderByClauses + return s +} + +func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement { + s.Limit.Count = limit + return s +} + +func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { + s.Offset.Count = offset + return s +} + +func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement { + s.For.Lock = lock + return s +} + +func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement { + s.ShareLock.Show = true + return s +} + +func (s *selectStatementImpl) AsTable(alias string) SelectTable { + return newSelectTable(s, alias) +} + +//----------------------------------------------------- + +type windowExpand struct { + selectStatement *selectStatementImpl +} + +func (w windowExpand) AS(window ...jet.Window) SelectStatement { + if len(window) == 0 { + return w.selectStatement + } + windowsDefinition := w.selectStatement.Window.Definitions + windowsDefinition[len(windowsDefinition)-1].Window = window[0] + return w.selectStatement +} + +func toJetFrameOffset(offset interface{}) jet.Serializer { + if offset == UNBOUNDED { + return jet.UNBOUNDED + } + + return jet.FixedLiteral(offset) +} diff --git a/sqlite/select_statement_test.go b/sqlite/select_statement_test.go new file mode 100644 index 00000000..0ba76f0f --- /dev/null +++ b/sqlite/select_statement_test.go @@ -0,0 +1,156 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + "testing" +) + +func TestInvalidSelect(t *testing.T) { + assertStatementSqlErr(t, SELECT(nil), "jet: Projection is nil") +} + +func TestSelectColumnList(t *testing.T) { + columnList := ColumnList{table2ColInt, table2ColFloat, table3ColInt} + + assertStatementSql(t, SELECT(columnList).FROM(table2), ` +SELECT table2.col_int AS "table2.col_int", + table2.col_float AS "table2.col_float", + table3.col_int AS "table3.col_int" +FROM db.table2; +`) +} + +func TestSelectLiterals(t *testing.T) { + assertStatementSql(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), ` +SELECT ?, + ?, + ? +FROM db.table1; +`, int64(1), 2.2, false) +} + +func TestSelectDistinct(t *testing.T) { + assertStatementSql(t, SELECT(table1ColBool).DISTINCT().FROM(table1), ` +SELECT DISTINCT table1.col_bool AS "table1.col_bool" +FROM db.table1; +`) +} + +func TestSelectFrom(t *testing.T) { + assertStatementSql(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), ` +SELECT table1.col_int AS "table1.col_int", + table2.col_float AS "table2.col_float" +FROM db.table1; +`) + assertStatementSql(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), ` +SELECT table1.col_int AS "table1.col_int", + table2.col_float AS "table2.col_float" +FROM db.table1 + INNER JOIN db.table2 ON (table1.col_int = table2.col_int); +`) + assertStatementSql(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), ` +SELECT table1.col_int AS "table1.col_int", + table2.col_float AS "table2.col_float" +FROM db.table1 + INNER JOIN db.table2 ON (table1.col_int = table2.col_int); +`) +} + +func TestSelectWhere(t *testing.T) { + assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` +SELECT table1.col_int AS "table1.col_int" +FROM db.table1 +WHERE ?; +`, true) + assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` +SELECT table1.col_int AS "table1.col_int" +FROM db.table1 +WHERE table1.col_int >= ?; +`, int64(10)) +} + +func TestSelectGroupBy(t *testing.T) { + assertStatementSql(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), ` +SELECT table2.col_int AS "table2.col_int" +FROM db.table2 +GROUP BY table2.col_float; +`) +} + +func TestSelectHaving(t *testing.T) { + assertStatementSql(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` +SELECT table3.col_int AS "table3.col_int" +FROM db.table3 +HAVING table1.col_bool = ?; +`, true) +} + +func TestSelectOrderBy(t *testing.T) { + assertStatementSql(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), ` +SELECT table2.col_float AS "table2.col_float" +FROM db.table2 +ORDER BY table2.col_int DESC; +`) + assertStatementSql(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC(), table2ColInt.ASC()), ` +SELECT table2.col_float AS "table2.col_float" +FROM db.table2 +ORDER BY table2.col_int DESC, table2.col_int ASC; +`) +} + +func TestSelectLimitOffset(t *testing.T) { + assertStatementSql(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), ` +SELECT table2.col_int AS "table2.col_int" +FROM db.table2 +LIMIT ?; +`, int64(10)) + assertStatementSql(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), ` +SELECT table2.col_int AS "table2.col_int" +FROM db.table2 +LIMIT ? +OFFSET ?; +`, int64(10), int64(2)) +} + +func TestSelectLock(t *testing.T) { + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR UPDATE; +`) + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(SHARE().NOWAIT()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR SHARE NOWAIT; +`) +} + +func TestSelect_LOCK_IN_SHARE_MODE(t *testing.T) { + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).LOCK_IN_SHARE_MODE(), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +LOCK IN SHARE MODE; +`) +} + +func TestSelect_NOT_EXISTS(t *testing.T) { + testutils.AssertStatementSql(t, + SELECT(table1ColInt). + FROM(table1). + WHERE( + NOT(EXISTS( + SELECT(table2ColInt). + FROM(table2). + WHERE( + table1ColInt.EQ(table2ColInt), + ), + ))), ` +SELECT table1.col_int AS "table1.col_int" +FROM db.table1 +WHERE (NOT (EXISTS ( + SELECT table2.col_int AS "table2.col_int" + FROM db.table2 + WHERE table1.col_int = table2.col_int + ))); +`) +} diff --git a/sqlite/select_table.go b/sqlite/select_table.go new file mode 100644 index 00000000..4117e064 --- /dev/null +++ b/sqlite/select_table.go @@ -0,0 +1,24 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// SelectTable is interface for MySQL sub-queries +type SelectTable interface { + readableTable + jet.SelectTable +} + +type selectTableImpl struct { + jet.SelectTable + readableTableInterfaceImpl +} + +func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable { + subQuery := &selectTableImpl{ + SelectTable: jet.NewSelectTable(selectStmt, alias), + } + + subQuery.readableTableInterfaceImpl.parent = subQuery + + return subQuery +} diff --git a/sqlite/set_statement.go b/sqlite/set_statement.go new file mode 100644 index 00000000..18bcca56 --- /dev/null +++ b/sqlite/set_statement.go @@ -0,0 +1,99 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// UNION effectively appends the result of sub-queries(select statements) into single query. +// It eliminates duplicate rows from its result. +func UNION(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement { + return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...)) +} + +// UNION_ALL effectively appends the result of sub-queries(select statements) into single query. +// It does not eliminates duplicate rows from its result. +func UNION_ALL(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement { + return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...)) +} + +type setStatement interface { + setOperators + + ORDER_BY(orderByClauses ...OrderByClause) setStatement + + LIMIT(limit int64) setStatement + OFFSET(offset int64) setStatement + + AsTable(alias string) SelectTable +} + +type setOperators interface { + jet.Statement + jet.HasProjections + jet.Expression + + UNION(rhs SelectStatement) setStatement + UNION_ALL(rhs SelectStatement) setStatement +} + +type setOperatorsImpl struct { + parent setOperators +} + +func (s *setOperatorsImpl) UNION(rhs SelectStatement) setStatement { + return UNION(s.parent, rhs) +} + +func (s *setOperatorsImpl) UNION_ALL(rhs SelectStatement) setStatement { + return UNION_ALL(s.parent, rhs) +} + +type setStatementImpl struct { + jet.ExpressionStatement + + setOperatorsImpl + + setOperator jet.ClauseSetStmtOperator +} + +func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStatement) setStatement { + newSetStatement := &setStatementImpl{} + newSetStatement.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SetStatementType, newSetStatement, + &newSetStatement.setOperator) + + newSetStatement.setOperator.Operator = operator + newSetStatement.setOperator.All = all + newSetStatement.setOperator.Selects = selects + newSetStatement.setOperator.Limit.Count = -1 + newSetStatement.setOperator.Offset.Count = -1 + newSetStatement.setOperator.SkipSelectWrap = true + + newSetStatement.setOperatorsImpl.parent = newSetStatement + + return newSetStatement +} + +func (s *setStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) setStatement { + s.setOperator.OrderBy.List = orderByClauses + return s +} + +func (s *setStatementImpl) LIMIT(limit int64) setStatement { + s.setOperator.Limit.Count = limit + return s +} + +func (s *setStatementImpl) OFFSET(offset int64) setStatement { + s.setOperator.Offset.Count = offset + return s +} + +func (s *setStatementImpl) AsTable(alias string) SelectTable { + return newSelectTable(s, alias) +} + +const ( + union = "UNION" +) + +func toSelectList(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) []jet.SerializerStatement { + return append([]jet.SerializerStatement{lhs, rhs}, selects...) +} diff --git a/sqlite/set_statement_test.go b/sqlite/set_statement_test.go new file mode 100644 index 00000000..c822089b --- /dev/null +++ b/sqlite/set_statement_test.go @@ -0,0 +1,31 @@ +package sqlite + +import ( + "testing" +) + +func TestSelectSets(t *testing.T) { + select1 := SELECT(table1ColBool).FROM(table1) + select2 := SELECT(table2ColBool).FROM(table2) + + assertStatementSql(t, select1.UNION(select2), ` + +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 + +UNION + +SELECT table2.col_bool AS "table2.col_bool" +FROM db.table2; +`) + assertStatementSql(t, select1.UNION_ALL(select2), ` + +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 + +UNION ALL + +SELECT table2.col_bool AS "table2.col_bool" +FROM db.table2; +`) +} diff --git a/sqlite/statement.go b/sqlite/statement.go new file mode 100644 index 00000000..754ae41a --- /dev/null +++ b/sqlite/statement.go @@ -0,0 +1,8 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// RawStatement creates new sql statements from raw query and optional map of named arguments +func RawStatement(rawQuery string, namedArguments ...RawArgs) Statement { + return jet.RawStatement(Dialect, rawQuery, namedArguments...) +} diff --git a/sqlite/table.go b/sqlite/table.go new file mode 100644 index 00000000..6d70f7fe --- /dev/null +++ b/sqlite/table.go @@ -0,0 +1,122 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// Table is interface for MySQL tables +type Table interface { + jet.SerializerTable + readableTable + + INSERT(columns ...jet.Column) InsertStatement + UPDATE(columns ...jet.Column) UpdateStatement + DELETE() DeleteStatement +} + +type readableTable interface { + // Generates a select query on the current tableName. + SELECT(projection Projection, projections ...Projection) SelectStatement + + // Creates a inner join tableName Expression using onCondition. + INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable + + // Creates a left join tableName Expression using onCondition. + LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable + + // Creates a right join tableName Expression using onCondition. + RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable + + // Creates a full join tableName Expression using onCondition. + FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable + + // Creates a cross join tableName Expression using onCondition. + CROSS_JOIN(table ReadableTable) joinSelectUpdateTable +} + +type joinSelectUpdateTable interface { + ReadableTable + UPDATE(columns ...jet.Column) UpdateStatement +} + +// ReadableTable interface +type ReadableTable interface { + readableTable + jet.Serializer +} + +type readableTableInterfaceImpl struct { + parent ReadableTable +} + +// Generates a select query on the current tableName. +func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { + return newSelectStatement(r.parent, append([]Projection{projection1}, projections...)) +} + +// Creates a inner join tableName Expression using onCondition. +func (r readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { + return newJoinTable(r.parent, table, jet.InnerJoin, onCondition) +} + +// Creates a left join tableName Expression using onCondition. +func (r readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { + return newJoinTable(r.parent, table, jet.LeftJoin, onCondition) +} + +// Creates a right join tableName Expression using onCondition. +func (r readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { + return newJoinTable(r.parent, table, jet.RightJoin, onCondition) +} + +func (r readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { + return newJoinTable(r.parent, table, jet.FullJoin, onCondition) +} + +func (r readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectUpdateTable { + return newJoinTable(r.parent, table, jet.CrossJoin, nil) +} + +// NewTable creates new table with schema Name, table Name and list of columns +func NewTable(schemaName, name, alias string, columns ...jet.ColumnExpression) Table { + t := &tableImpl{ + SerializerTable: jet.NewTable(schemaName, name, alias, columns...), + } + + t.readableTableInterfaceImpl.parent = t + t.parent = t + + return t +} + +type tableImpl struct { + jet.SerializerTable + readableTableInterfaceImpl + parent Table +} + +func (t *tableImpl) INSERT(columns ...jet.Column) InsertStatement { + return newInsertStatement(t.parent, jet.UnwidColumnList(columns)) +} + +func (t *tableImpl) UPDATE(columns ...jet.Column) UpdateStatement { + return newUpdateStatement(t.parent, jet.UnwidColumnList(columns)) +} + +func (t *tableImpl) DELETE() DeleteStatement { + return newDeleteStatement(t.parent) +} + +type joinTable struct { + tableImpl + jet.JoinTable +} + +func newJoinTable(lhs jet.Serializer, rhs jet.Serializer, joinType jet.JoinType, onCondition BoolExpression) Table { + newJoinTable := &joinTable{ + JoinTable: jet.NewJoinTable(lhs, rhs, joinType, onCondition), + } + + newJoinTable.readableTableInterfaceImpl.parent = newJoinTable + newJoinTable.parent = newJoinTable + + return newJoinTable +} diff --git a/sqlite/table_test.go b/sqlite/table_test.go new file mode 100644 index 00000000..a68d5622 --- /dev/null +++ b/sqlite/table_test.go @@ -0,0 +1,101 @@ +package sqlite + +import ( + "testing" +) + +func TestJoinNilInputs(t *testing.T) { + assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), + "jet: right hand side of join operation is nil table") + assertSerializeErr(t, table2.INNER_JOIN(table1, nil), + "jet: join condition is nil") +} + +func TestINNER_JOIN(t *testing.T) { + assertSerialize(t, table1. + INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`) + assertSerialize(t, table1. + INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)). + INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +INNER JOIN db.table2 ON (table1.col_int = table2.col_int) +INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`) + assertSerialize(t, table1. + INNER_JOIN(table2, table1ColInt.EQ(Int(1))). + INNER_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +INNER JOIN db.table2 ON (table1.col_int = ?) +INNER JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) +} + +func TestLEFT_JOIN(t *testing.T) { + assertSerialize(t, table1. + LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`) + assertSerialize(t, table1. + LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)). + LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +LEFT JOIN db.table2 ON (table1.col_int = table2.col_int) +LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`) + assertSerialize(t, table1. + LEFT_JOIN(table2, table1ColInt.EQ(Int(1))). + LEFT_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +LEFT JOIN db.table2 ON (table1.col_int = ?) +LEFT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) +} + +func TestRIGHT_JOIN(t *testing.T) { + assertSerialize(t, table1. + RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`) + assertSerialize(t, table1. + RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)). + RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int) +RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`) + assertSerialize(t, table1. + RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))). + RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +RIGHT JOIN db.table2 ON (table1.col_int = ?) +RIGHT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) +} + +func TestFULL_JOIN(t *testing.T) { + assertSerialize(t, table1. + FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`) + assertSerialize(t, table1. + FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)). + FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +FULL JOIN db.table2 ON (table1.col_int = table2.col_int) +FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`) + assertSerialize(t, table1. + FULL_JOIN(table2, table1ColInt.EQ(Int(1))). + FULL_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +FULL JOIN db.table2 ON (table1.col_int = ?) +FULL JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) +} + +func TestCROSS_JOIN(t *testing.T) { + assertSerialize(t, table1. + CROSS_JOIN(table2), + `db.table1 +CROSS JOIN db.table2`) + assertSerialize(t, table1. + CROSS_JOIN(table2). + CROSS_JOIN(table3), + `db.table1 +CROSS JOIN db.table2 +CROSS JOIN db.table3`) +} diff --git a/sqlite/types.go b/sqlite/types.go new file mode 100644 index 00000000..755be1d8 --- /dev/null +++ b/sqlite/types.go @@ -0,0 +1,27 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK) +type Statement = jet.Statement + +// Projection is interface for all projection types. Types that can be part of, for instance SELECT clause. +type Projection = jet.Projection + +// ProjectionList can be used to create conditional constructed projection list. +type ProjectionList = jet.ProjectionList + +// ColumnAssigment is interface wrapper around column assigment +type ColumnAssigment = jet.ColumnAssigment + +// PrintableStatement is a statement which sql query can be logged +type PrintableStatement = jet.PrintableStatement + +// OrderByClause is the combination of an expression and the wanted ordering to use as input for ORDER BY. +type OrderByClause = jet.OrderByClause + +// GroupByClause interface to use as input for GROUP_BY +type GroupByClause = jet.GroupByClause + +// SetLogger sets automatic statement logging +var SetLogger = jet.SetLoggerFunc diff --git a/sqlite/update_statement.go b/sqlite/update_statement.go new file mode 100644 index 00000000..53cf72d1 --- /dev/null +++ b/sqlite/update_statement.go @@ -0,0 +1,70 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// UpdateStatement is interface of SQL UPDATE statement +type UpdateStatement interface { + jet.Statement + + SET(value interface{}, values ...interface{}) UpdateStatement + MODEL(data interface{}) UpdateStatement + + WHERE(expression BoolExpression) UpdateStatement + RETURNING(projections ...jet.Projection) UpdateStatement +} + +type updateStatementImpl struct { + jet.SerializerStatement + + Update jet.ClauseUpdate + Set jet.SetClause + SetNew jet.SetClauseNew + Where jet.ClauseWhere + Returning jet.ClauseReturning +} + +func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { + update := &updateStatementImpl{} + update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, + &update.Update, + &update.Set, + &update.SetNew, + &update.Where, + &update.Returning) + + update.Update.Table = table + update.Set.Columns = columns + update.Where.Mandatory = true + + return update +} + +func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { + columnAssigment, isColumnAssigment := value.(ColumnAssigment) + + if isColumnAssigment { + u.SetNew = []ColumnAssigment{columnAssigment} + for _, value := range values { + u.SetNew = append(u.SetNew, value.(ColumnAssigment)) + } + } else { + u.Set.Values = jet.UnwindRowFromValues(value, values) + } + + return u +} + +func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement { + u.Set.Values = jet.UnwindRowFromModel(u.Set.Columns, data) + return u +} + +func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement { + u.Where.Condition = expression + return u +} + +func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement { + u.Returning.ProjectionList = projections + return u +} diff --git a/sqlite/update_statement_test.go b/sqlite/update_statement_test.go new file mode 100644 index 00000000..5c468a3b --- /dev/null +++ b/sqlite/update_statement_test.go @@ -0,0 +1,82 @@ +package sqlite + +import ( + "fmt" + "strings" + "testing" +) + +func TestUpdateWithOneValue(t *testing.T) { + expectedSQL := ` +UPDATE db.table1 +SET col_int = ? +WHERE table1.col_int >= ?; +` + stmt := table1.UPDATE(table1ColInt). + SET(1). + WHERE(table1ColInt.GT_EQ(Int(33))) + + fmt.Println(stmt.Sql()) + + assertStatementSql(t, stmt, expectedSQL, 1, int64(33)) +} + +func TestUpdateWithValues(t *testing.T) { + expectedSQL := ` +UPDATE db.table1 +SET col_int = ?, + col_float = ? +WHERE table1.col_int >= ?; +` + stmt := table1.UPDATE(table1ColInt, table1ColFloat). + SET(1, 22.2). + WHERE(table1ColInt.GT_EQ(Int(33))) + + fmt.Println(stmt.Sql()) + + assertStatementSql(t, stmt, expectedSQL, 1, 22.2, int64(33)) +} + +func TestUpdateOneColumnWithSelect(t *testing.T) { + expectedSQL := ` +UPDATE db.table1 +SET col_float = ( + SELECT table1.col_float AS "table1.col_float" + FROM db.table1 + ) +WHERE table1.col1 = ?; +` + stmt := table1. + UPDATE(table1ColFloat). + SET( + table1.SELECT(table1ColFloat), + ). + WHERE(table1Col1.EQ(Int(2))) + + assertStatementSql(t, stmt, expectedSQL, int64(2)) +} + +func TestUpdateReservedWorldColumn(t *testing.T) { + type table struct { + Load string + } + + loadColumn := StringColumn("Load") + assertStatementSql(t, + table1.UPDATE(loadColumn). + MODEL( + table{ + Load: "foo", + }, + ). + WHERE(loadColumn.EQ(String("bar"))), strings.Replace(` +UPDATE db.table1 +SET ''Load'' = ? +WHERE ''Load'' = ?; +`, "''", "`", -1), "foo", "bar") +} + +func TestInvalidInputs(t *testing.T) { + assertStatementSqlErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set") + assertStatementSqlErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list for SET clause") +} diff --git a/sqlite/utils_test.go b/sqlite/utils_test.go new file mode 100644 index 00000000..3f9b9f36 --- /dev/null +++ b/sqlite/utils_test.go @@ -0,0 +1,55 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/jet" + "github.com/go-jet/jet/v2/internal/testutils" + "testing" +) + +var table1Col1 = IntegerColumn("col1") +var table1ColBool = BoolColumn("col_bool") +var table1ColInt = IntegerColumn("col_int") +var table1ColFloat = FloatColumn("col_float") +var table1ColString = StringColumn("col_string") +var table1Col3 = IntegerColumn("col3") +var table1ColTimestamp = TimestampColumn("col_timestamp") +var table1ColDate = DateColumn("col_date") +var table1ColTime = TimeColumn("col_time") + +var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1ColString, table1Col3, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTime) + +var table2Col3 = IntegerColumn("col3") +var table2Col4 = IntegerColumn("col4") +var table2ColInt = IntegerColumn("col_int") +var table2ColFloat = FloatColumn("col_float") +var table2ColStr = StringColumn("col_str") +var table2ColBool = BoolColumn("col_bool") +var table2ColTimestamp = TimestampColumn("col_timestamp") +var table2ColDate = DateColumn("col_date") + +var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColDate, table2ColTimestamp) + +var table3Col1 = IntegerColumn("col1") +var table3ColInt = IntegerColumn("col_int") +var table3StrCol = StringColumn("col2") +var table3 = NewTable("db", "table3", "", table3Col1, table3ColInt, table3StrCol) + +func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { + testutils.AssertSerialize(t, Dialect, clause, query, args...) +} + +func assertDebugSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { + testutils.AssertDebugSerialize(t, Dialect, clause, query, args...) +} + +func assertSerializeErr(t *testing.T, clause jet.Serializer, errString string) { + testutils.AssertSerializeErr(t, Dialect, clause, errString) +} + +func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) { + testutils.AssertProjectionSerialize(t, Dialect, projection, query, args...) +} + +var assertPanicErr = testutils.AssertPanicErr +var assertStatementSql = testutils.AssertStatementSql +var assertStatementSqlErr = testutils.AssertStatementSqlErr diff --git a/sqlite/with_statement.go b/sqlite/with_statement.go new file mode 100644 index 00000000..7940dcd5 --- /dev/null +++ b/sqlite/with_statement.go @@ -0,0 +1,26 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// CommonTableExpression contains information about a CTE. +type CommonTableExpression struct { + readableTableInterfaceImpl + jet.CommonTableExpression +} + +// WITH function creates new WITH statement from list of common table expressions +func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, cte...) +} + +// CTE creates new named CommonTableExpression +func CTE(name string) CommonTableExpression { + cte := CommonTableExpression{ + readableTableInterfaceImpl: readableTableInterfaceImpl{}, + CommonTableExpression: jet.CTE(name), + } + + cte.parent = &cte + + return cte +} diff --git a/tests/dbconfig/dbconfig.go b/tests/dbconfig/dbconfig.go index cf48420d..ef89c1b6 100644 --- a/tests/dbconfig/dbconfig.go +++ b/tests/dbconfig/dbconfig.go @@ -1,18 +1,21 @@ package dbconfig -import "fmt" +import ( + "fmt" + "github.com/go-jet/jet/v2/tests/internal/utils/repo" +) // Postgres test database connection parameters const ( - Host = "localhost" - Port = 5432 - User = "jet" - Password = "jet" - DBName = "jetdb" + PgHost = "localhost" + PgPort = 5432 + PgUser = "jet" + PgPassword = "jet" + PgDBName = "jetdb" ) // PostgresConnectString is PostgreSQL test database connection string -var PostgresConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName) +var PostgresConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", PgHost, PgPort, PgUser, PgPassword, PgDBName) // MySQL test database connection parameters const ( @@ -24,3 +27,10 @@ const ( // MySQLConnectionString is MySQL driver connection string to test database var MySQLConnectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/", MySQLUser, MySQLPassword, MySqLHost, MySQLPort) + +// sqllite +var ( + SakilaDBPath = repo.GetTestDataFilePath("/init/sqlite/sakila.db") + ChinookDBPath = repo.GetTestDataFilePath("/init/sqlite/chinook.db") + TestSampleDBPath = repo.GetTestDataFilePath("/init/sqlite/test_sample.db") +) diff --git a/tests/init/init.go b/tests/init/init.go index a2f6eb39..aa04fb5f 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -4,16 +4,21 @@ import ( "database/sql" "flag" "fmt" + "github.com/go-jet/jet/v2/generator/sqlite" + "github.com/go-jet/jet/v2/tests/internal/utils/repo" + "io/ioutil" + "os" + "os/exec" + "strings" + "github.com/go-jet/jet/v2/generator/mysql" "github.com/go-jet/jet/v2/generator/postgres" - "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/tests/dbconfig" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" - "io/ioutil" - "os" - "os/exec" - "strings" + + _ "github.com/mattn/go-sqlite3" ) var testSuite string @@ -38,8 +43,23 @@ func main() { return } + if testSuite == "sqlite" { + initSQLiteDB() + return + } + initMySQLDB() initPostgresDB() + initSQLiteDB() +} + +func initSQLiteDB() { + err := sqlite.GenerateDSN(dbconfig.SakilaDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/sakila")) + throw.OnError(err) + err = sqlite.GenerateDSN(dbconfig.ChinookDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/chinook")) + throw.OnError(err) + err = sqlite.GenerateDSN(dbconfig.TestSampleDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/test_sample")) + throw.OnError(err) } func initMySQLDB() { @@ -62,7 +82,7 @@ func initMySQLDB() { cmd.Stdout = os.Stdout err := cmd.Run() - utils.PanicOnError(err) + throw.OnError(err) err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{ Host: dbconfig.MySqLHost, @@ -72,7 +92,7 @@ func initMySQLDB() { DBName: dbName, }) - utils.PanicOnError(err) + throw.OnError(err) } } @@ -99,24 +119,24 @@ func initPostgresDB() { execFile(db, "./testdata/init/postgres/"+schemaName+".sql") err = postgres.Generate("./.gentestdata", postgres.DBConnection{ - Host: dbconfig.Host, - Port: 5432, - User: dbconfig.User, - Password: dbconfig.Password, - DBName: dbconfig.DBName, + Host: dbconfig.PgHost, + Port: dbconfig.PgPort, + User: dbconfig.PgUser, + Password: dbconfig.PgPassword, + DBName: dbconfig.PgDBName, SchemaName: schemaName, SslMode: "disable", }) - utils.PanicOnError(err) + throw.OnError(err) } } func execFile(db *sql.DB, sqlFilePath string) { testSampleSql, err := ioutil.ReadFile(sqlFilePath) - utils.PanicOnError(err) + throw.OnError(err) _, err = db.Exec(string(testSampleSql)) - utils.PanicOnError(err) + throw.OnError(err) } func printOnError(err error) { diff --git a/tests/internal/utils/file/file.go b/tests/internal/utils/file/file.go new file mode 100644 index 00000000..6d08d22a --- /dev/null +++ b/tests/internal/utils/file/file.go @@ -0,0 +1,25 @@ +package file + +import ( + "github.com/stretchr/testify/require" + "io/ioutil" + "os" + "path" + "testing" +) + +// Exists expects file to exist on path constructed from pathElems and returns content of the file +func Exists(t *testing.T, pathElems ...string) (fileContent string) { + modelFilePath := path.Join(pathElems...) + file, err := ioutil.ReadFile(modelFilePath) + require.Nil(t, err) + require.NotEmpty(t, file) + return string(file) +} + +// NotExists expects file not to exist on path constructed from pathElems +func NotExists(t *testing.T, pathElems ...string) { + modelFilePath := path.Join(pathElems...) + _, err := ioutil.ReadFile(modelFilePath) + require.True(t, os.IsNotExist(err)) +} diff --git a/tests/internal/utils/repo/repo.go b/tests/internal/utils/repo/repo.go new file mode 100644 index 00000000..3d240390 --- /dev/null +++ b/tests/internal/utils/repo/repo.go @@ -0,0 +1,33 @@ +package repo + +import ( + "os/exec" + "path/filepath" + "strings" +) + +// GetRootDirPath will return this repo full dir path +func GetRootDirPath() string { + cmd := exec.Command("git", "rev-parse", "--show-toplevel") + byteArr, err := cmd.Output() + if err != nil { + panic(err) + } + + return strings.TrimSpace(string(byteArr)) +} + +// GetTestsDirPath will return tests folder full path +func GetTestsDirPath() string { + return filepath.Join(GetRootDirPath(), "tests") +} + +// GetTestsFilePath will return full file path of the file in the tests folder +func GetTestsFilePath(subPath string) string { + return filepath.Join(GetTestsDirPath(), subPath) +} + +// GetTestDataFilePath will return full file path of the file in the testdata folder +func GetTestDataFilePath(subPath string) string { + return filepath.Join(GetTestsDirPath(), "testdata", subPath) +} diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index d96c1d3b..2132d7a1 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -104,18 +104,18 @@ func TestExpressionOperators(t *testing.T) { SELECT all_types.'integer' IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN (?, ?)) AS "result.in", - (all_types.small_int_ptr IN (( + (all_types.small_int_ptr IN ( SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types - ))) AS "result.in_select", + )) AS "result.in_select", (CURRENT_USER()) AS "result.raw", (? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg", (? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", - (all_types.small_int_ptr NOT IN (( + (all_types.small_int_ptr NOT IN ( SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types - ))) AS "result.not_in_select" + )) AS "result.not_in_select" FROM test_sample.all_types LIMIT ?; `, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2)) @@ -467,10 +467,10 @@ func TestStringOperators(t *testing.T) { AllTypes.Text.NOT_LIKE(String("_b_")), AllTypes.Text.REGEXP_LIKE(String("aba")), AllTypes.Text.REGEXP_LIKE(String("aba"), false), - String("ABA").REGEXP_LIKE(String("aba"), true), + //String("ABA").REGEXP_LIKE(String("aba"), true), AllTypes.Text.NOT_REGEXP_LIKE(String("aba")), AllTypes.Text.NOT_REGEXP_LIKE(String("aba"), false), - String("ABA").NOT_REGEXP_LIKE(String("aba"), true), + //String("ABA").NOT_REGEXP_LIKE(String("aba"), true), BIT_LENGTH(AllTypes.Text), CHAR_LENGTH(AllTypes.Char), @@ -962,7 +962,7 @@ func TestAllTypesInsert(t *testing.T) { tx, err := db.Begin() require.NoError(t, err) - stmt := AllTypes.INSERT(AllTypes.AllColumns). + stmt := AllTypes.INSERT(AllTypes.AllColumns.Except(AllTypes.TimestampPtr)). MODEL(toInsert) //fmt.Println(stmt.DebugSql()) @@ -970,7 +970,7 @@ func TestAllTypesInsert(t *testing.T) { testutils.AssertExec(t, stmt, tx, 1) var dest model.AllTypes - err = AllTypes.SELECT(AllTypes.AllColumns). + err = AllTypes.SELECT(AllTypes.AllColumns.Except(AllTypes.TimestampPtr)). WHERE(AllTypes.BigInt.EQ(Int(toInsert.BigInt))). Query(tx, &dest) diff --git a/tests/mysql/generator_template_test.go b/tests/mysql/generator_template_test.go new file mode 100644 index 00000000..e915e0fc --- /dev/null +++ b/tests/mysql/generator_template_test.go @@ -0,0 +1,389 @@ +package mysql + +import ( + "database/sql" + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + mysql2 "github.com/go-jet/jet/v2/generator/mysql" + "github.com/go-jet/jet/v2/generator/template" + "github.com/go-jet/jet/v2/internal/3rdparty/snaker" + "github.com/go-jet/jet/v2/internal/utils" + postgres2 "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/tests/dbconfig" + file2 "github.com/go-jet/jet/v2/tests/internal/utils/file" + "github.com/stretchr/testify/require" + "path" + "testing" +) + +const tempTestDir = "./.tempTestDir" + +var defaultModelPath = path.Join(tempTestDir, "dvds/model") +var defaultActorModelFilePath = path.Join(tempTestDir, "dvds/model", "actor.go") +var defaultTableSQLBuilderFilePath = path.Join(tempTestDir, "dvds/table") +var defaultViewSQLBuilderFilePath = path.Join(tempTestDir, "dvds/view") +var defaultEnumSQLBuilderFilePath = path.Join(tempTestDir, "dvds/enum") +var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "dvds/table", "actor.go") + +var dbConnection = mysql2.DBConnection{ + Host: dbconfig.MySqLHost, + Port: dbconfig.MySQLPort, + User: dbconfig.MySQLUser, + Password: dbconfig.MySQLPassword, + DBName: "dvds", +} + +func TestGeneratorTemplate_Schema_ChangePath(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData).UsePath("new/schema/path") + }), + ) + + require.Nil(t, err) + + file2.Exists(t, tempTestDir, "new/schema/path/model/actor.go") + file2.Exists(t, tempTestDir, "new/schema/path/table/actor.go") + file2.Exists(t, tempTestDir, "new/schema/path/view/actor_info.go") + file2.Exists(t, tempTestDir, "new/schema/path/enum/film_rating.go") +} + +func TestGeneratorTemplate_Model_SkipGeneration(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.Model{ + Skip: true, + }) + }), + ) + + require.Nil(t, err) + + file2.NotExists(t, defaultActorModelFilePath) + file2.Exists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.Exists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.Exists(t, defaultEnumSQLBuilderFilePath, "film_rating.go") +} + +func TestGeneratorTemplate_SQLBuilder_SkipGeneration(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.SQLBuilder{ + Skip: true, + }) + }), + ) + + require.Nil(t, err) + + file2.Exists(t, defaultActorModelFilePath) + file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.NotExists(t, defaultEnumSQLBuilderFilePath, "film_rating.go") +} + +func TestGeneratorTemplate_Model_ChangePath(t *testing.T) { + const newModelPath = "/new/model/path" + + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel().UsePath(newModelPath)) + }), + ) + require.Nil(t, err) + + file2.Exists(t, tempTestDir, "dvds", newModelPath, "actor.go") + file2.NotExists(t, defaultActorModelFilePath) +} + +func TestGeneratorTemplate_SQLBuilder_ChangePath(t *testing.T) { + const newModelPath = "/new/sql-builder/path" + + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder().UsePath(newModelPath)) + }), + ) + require.Nil(t, err) + + file2.Exists(t, tempTestDir, "dvds", newModelPath, "table", "actor.go") + file2.Exists(t, tempTestDir, "dvds", newModelPath, "view", "actor_info.go") + file2.Exists(t, tempTestDir, "dvds", newModelPath, "enum", "film_rating.go") + + file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.NotExists(t, defaultEnumSQLBuilderFilePath, "film_rating.go") +} + +func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.DefaultTableModel(table). + UseFileName(schemaMetaData.Name + "_" + table.Name). + UseTypeName(utils.ToGoIdentifier(table.Name) + "Table") + }). + UseView(func(table metadata.Table) template.ViewModel { + return template.DefaultViewModel(table). + UseFileName(schemaMetaData.Name + "_" + table.Name + "_view"). + UseTypeName(utils.ToGoIdentifier(table.Name) + "View") + }). + UseEnum(func(enumMetaData metadata.Enum) template.EnumModel { + return template.DefaultEnumModel(enumMetaData). + UseFileName(enumMetaData.Name + "_enum"). + UseTypeName(utils.ToGoIdentifier(enumMetaData.Name) + "Enum") + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultModelPath, "dvds_actor.go") + require.Contains(t, actor, "type ActorTable struct {") + + actorInfo := file2.Exists(t, defaultModelPath, "dvds_actor_info_view.go") + require.Contains(t, actorInfo, "type ActorInfoView struct {") + + mpaaRating := file2.Exists(t, defaultModelPath, "film_rating_enum.go") + require.Contains(t, mpaaRating, "type FilmRatingEnum string") +} + +func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.TableModel{ + Skip: true, + } + }). + UseEnum(func(enumMetaData metadata.Enum) template.EnumModel { + return template.EnumModel{ + Skip: true, + } + }), + ) + }), + ) + require.Nil(t, err) + + file2.NotExists(t, defaultModelPath, "actor.go") + file2.Exists(t, defaultModelPath, "actor_info.go") + file2.NotExists(t, defaultModelPath, "film_rating.go") +} + +func TestGeneratorTemplate_SQLBuilder_SkipTableAndEnum(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + return template.TableSQLBuilder{ + Skip: true, + } + }). + UseView(func(table metadata.Table) template.TableSQLBuilder { + return template.TableSQLBuilder{ + Skip: true, + } + }). + UseEnum(func(enumMetaData metadata.Enum) template.EnumSQLBuilder { + return template.EnumSQLBuilder{ + Skip: true, + } + }), + ) + }), + ) + require.Nil(t, err) + + file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.NotExists(t, defaultEnumSQLBuilderFilePath, "film_rating.go") +} + +func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + return template.DefaultTableSQLBuilder(table). + UseFileName(schemaMetaData.Name + "_" + table.Name + "_table"). + UseTypeName(utils.ToGoIdentifier(table.Name) + "TableSQLBuilder"). + UseInstanceName("T_" + utils.ToGoIdentifier(table.Name)) + }). + UseView(func(table metadata.Table) template.ViewSQLBuilder { + return template.DefaultViewSQLBuilder(table). + UseFileName(schemaMetaData.Name + "_" + table.Name + "_view"). + UseTypeName(utils.ToGoIdentifier(table.Name) + "ViewSQLBuilder"). + UseInstanceName("V_" + utils.ToGoIdentifier(table.Name)) + }). + UseEnum(func(enum metadata.Enum) template.EnumSQLBuilder { + return template.DefaultEnumSQLBuilder(enum). + UseFileName(schemaMetaData.Name + "_" + enum.Name + "_enum"). + UseInstanceName(utils.ToGoIdentifier(enum.Name) + "EnumSQLBuilder") + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultTableSQLBuilderFilePath, "dvds_actor_table.go") + require.Contains(t, actor, "type ActorTableSQLBuilder struct {") + require.Contains(t, actor, "var T_Actor = newActorTableSQLBuilder(\"dvds\", \"actor\", \"\")") + actorInfo := file2.Exists(t, defaultViewSQLBuilderFilePath, "dvds_actor_info_view.go") + require.Contains(t, actorInfo, "type ActorInfoViewSQLBuilder struct {") + require.Contains(t, actorInfo, "var V_ActorInfo = newActorInfoViewSQLBuilder(\"dvds\", \"actor_info\", \"\")") + mpaaRating := file2.Exists(t, defaultEnumSQLBuilderFilePath, "dvds_film_rating_enum.go") + require.Contains(t, mpaaRating, "var FilmRatingEnumSQLBuilder = &struct {") +} + +func TestGeneratorTemplate_Model_AddTags(t *testing.T) { + + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.DefaultTableModel(table). + UseField(func(columnMetaData metadata.Column) template.TableModelField { + defaultTableModelField := template.DefaultTableModelField(columnMetaData) + return defaultTableModelField.UseTags( + fmt.Sprintf(`json:"%s"`, snaker.SnakeToCamel(columnMetaData.Name, false)), + fmt.Sprintf(`xml:"%s"`, columnMetaData.Name), + ) + }) + }). + UseView(func(table metadata.Table) template.ViewModel { + return template.DefaultViewModel(table). + UseField(func(columnMetaData metadata.Column) template.TableModelField { + defaultTableModelField := template.DefaultTableModelField(columnMetaData) + if table.Name == "actor_info" && columnMetaData.Name == "actor_id" { + return defaultTableModelField.UseTags(`sql:"primary_key"`) + } + return defaultTableModelField + }) + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultActorModelFilePath) + require.Contains(t, actor, "ActorID uint16 `sql:\"primary_key\" json:\"actorID\" xml:\"actor_id\"`") + require.Contains(t, actor, "FirstName string `json:\"firstName\" xml:\"first_name\"`") + + actorInfo := file2.Exists(t, defaultModelPath, "actor_info.go") + require.Contains(t, actorInfo, "ActorID uint16 `sql:\"primary_key\"`") +} + +func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.DefaultTableModel(table). + UseField(func(columnMetaData metadata.Column) template.TableModelField { + defaultTableModelField := template.DefaultTableModelField(columnMetaData) + + switch defaultTableModelField.Type.Name { + case "*string": + defaultTableModelField.Type = template.NewType(sql.NullString{}) + case "*int32": + defaultTableModelField.Type = template.NewType(sql.NullInt32{}) + case "*int64": + defaultTableModelField.Type = template.NewType(sql.NullInt64{}) + case "*bool": + defaultTableModelField.Type = template.NewType(sql.NullBool{}) + case "*float64": + defaultTableModelField.Type = template.NewType(sql.NullFloat64{}) + case "*time.Time": + defaultTableModelField.Type = template.NewType(sql.NullTime{}) + } + return defaultTableModelField + }) + }), + ) + }), + ) + + require.Nil(t, err) + + data := file2.Exists(t, defaultModelPath, "film.go") + require.Contains(t, data, "\"database/sql\"") + require.Contains(t, data, "Description sql.NullString") + require.Contains(t, data, "ReleaseYear *int16") + require.Contains(t, data, "SpecialFeatures sql.NullString") +} + +func TestGeneratorTemplate_SQLBuilder_ChangeColumnTypes(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + return template.DefaultTableSQLBuilder(table). + UseColumn(func(column metadata.Column) template.TableSQLBuilderColumn { + defaultColumn := template.DefaultTableSQLBuilderColumn(column) + + if defaultColumn.Name == "ActorID" { + defaultColumn.Type = "String" + } + + return defaultColumn + }) + }), + ) + }), + ) + + require.Nil(t, err) + + actor := file2.Exists(t, defaultActorSQLBuilderFilePath) + require.Contains(t, actor, "ActorID postgres.ColumnString") +} diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index c9dcc1a6..033f699d 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -1,14 +1,16 @@ package mysql import ( - "github.com/go-jet/jet/v2/generator/mysql" - "github.com/go-jet/jet/v2/internal/testutils" - "github.com/go-jet/jet/v2/tests/dbconfig" - "github.com/stretchr/testify/require" + "fmt" "io/ioutil" "os" "os/exec" "testing" + + "github.com/go-jet/jet/v2/generator/mysql" + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/tests/dbconfig" + "github.com/stretchr/testify/require" ) const genTestDirRoot = "./.gentestdata3" @@ -30,6 +32,21 @@ func TestGenerator(t *testing.T) { assertGeneratedFiles(t) } + for i := 0; i < 3; i++ { + dsn := fmt.Sprintf("%[1]s:%[2]s@tcp(%[3]s:%[4]d)/%[5]s", + dbconfig.MySQLUser, + dbconfig.MySQLPassword, + dbconfig.MySqLHost, + dbconfig.MySQLPort, + "dvds", + ) + err := mysql.GenerateDSN(dsn, genTestDir3) + + require.NoError(t, err) + + assertGeneratedFiles(t) + } + err := os.RemoveAll(genTestDirRoot) require.NoError(t, err) } @@ -51,6 +68,25 @@ func TestCmdGenerator(t *testing.T) { err = os.RemoveAll(genTestDirRoot) require.NoError(t, err) + + // check that generation via DSN works + dsn := fmt.Sprintf("mysql://%[1]s:%[2]s@tcp(%[3]s:%[4]d)/%[5]s", + dbconfig.MySQLUser, + dbconfig.MySQLPassword, + dbconfig.MySqLHost, + dbconfig.MySQLPort, + "dvds", + ) + cmd = exec.Command("jet", "-dsn="+dsn, "-path="+genTestDir3) + + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + err = cmd.Run() + require.NoError(t, err) + + err = os.RemoveAll(genTestDirRoot) + require.NoError(t, err) } func assertGeneratedFiles(t *testing.T) { diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index 4f39d6c8..55fc706b 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -278,7 +278,7 @@ ON DUPLICATE KEY UPDATE id = (id + ?), err := SELECT(Link.AllColumns). FROM(Link). - WHERE(Link.ID.EQ(Int(int64(randId)).ADD(Int(11)))). + WHERE(Link.ID.EQ(Int32(randId).ADD(Int(11)))). Query(db, &newLinks) require.NoError(t, err) diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index 3b6aa759..dc289245 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -66,9 +66,7 @@ func TestUpdateWithSubQueries(t *testing.T) { expectedSQL := ` UPDATE test_sample.link -SET name = ( - SELECT ? - ), +SET name = ?, url = ( SELECT link2.url AS "link2.url" FROM test_sample.link2 @@ -80,7 +78,7 @@ WHERE link.name = ?; query := Link. UPDATE(Link.Name, Link.URL). SET( - SELECT(String("Bong")), + String("Bong"), SELECT(Link2.URL). FROM(Link2). WHERE(Link2.Name.EQ(String("Youtube"))), @@ -96,7 +94,7 @@ WHERE link.name = ?; query := Link. UPDATE(). SET( - Link.Name.SET(StringExp(SELECT(String("Bong")))), + Link.Name.SET(String("Bong")), Link.URL.SET(StringExp( SELECT(Link2.URL). FROM(Link2). @@ -123,7 +121,7 @@ func TestUpdateWithModelData(t *testing.T) { stmt := Link. UPDATE(Link.AllColumns). MODEL(link). - WHERE(Link.ID.EQ(Int(int64(link.ID)))) + WHERE(Link.ID.EQ(Int32(link.ID))) expectedSQL := ` UPDATE test_sample.link @@ -133,7 +131,7 @@ SET id = ?, description = ? WHERE link.id = ?; ` - testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) + testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) testutils.AssertExec(t, stmt, db) requireLogged(t, stmt) @@ -154,7 +152,7 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { stmt := Link. UPDATE(updateColumnList). MODEL(link). - WHERE(Link.ID.EQ(Int(int64(link.ID)))) + WHERE(Link.ID.EQ(Int32(link.ID))) var expectedSQL = ` UPDATE test_sample.link @@ -163,9 +161,8 @@ SET description = NULL, url = 'http://www.duckduckgo.com' WHERE link.id = 201; ` - //fmt.Println(stmt.DebugSql()) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) testutils.AssertExec(t, stmt, db) requireLogged(t, stmt) @@ -183,7 +180,7 @@ func TestUpdateWithModelDataAndMutableColumns(t *testing.T) { stmt := Link. UPDATE(Link.MutableColumns). MODEL(link). - WHERE(Link.ID.EQ(Int(int64(link.ID)))) + WHERE(Link.ID.EQ(Int32(link.ID))) var expectedSQL = ` UPDATE test_sample.link @@ -194,7 +191,7 @@ WHERE link.id = 201; ` //fmt.Println(stmt.DebugSql()) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) testutils.AssertExec(t, stmt, db) } diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 29986daf..82ac82bf 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -17,11 +17,12 @@ import ( ) func TestAllTypesSelect(t *testing.T) { - skipForPgxDriver(t) // pgx driver returns time with time zone as string - dest := []model.AllTypes{} - err := AllTypes.SELECT(AllTypes.AllColumns).Query(db, &dest) + err := AllTypes.SELECT( + AllTypes.AllColumns, + ).LIMIT(2). + Query(db, &dest) require.NoError(t, err) testutils.AssertDeepEqual(t, dest[0], allTypesRow0) @@ -29,8 +30,6 @@ func TestAllTypesSelect(t *testing.T) { } func TestAllTypesViewSelect(t *testing.T) { - skipForPgxDriver(t) // pgx driver returns time with time zone as string - type AllTypesView model.AllTypes dest := []AllTypesView{} @@ -43,7 +42,7 @@ func TestAllTypesViewSelect(t *testing.T) { } func TestAllTypesInsertModel(t *testing.T) { - skipForPgxDriver(t) // pgx driver does not handle well time with time zone + skipForPgxDriver(t) // pgx driver bug ERROR: date/time field value out of range: "0000-01-01 12:05:06Z" (SQLSTATE 22008) query := AllTypes.INSERT(AllTypes.AllColumns). MODEL(allTypesRow0). @@ -60,8 +59,6 @@ func TestAllTypesInsertModel(t *testing.T) { } func TestAllTypesInsertQuery(t *testing.T) { - skipForPgxDriver(t) // pgx driver does not handle well time with time zone - query := AllTypes.INSERT(AllTypes.AllColumns). QUERY( AllTypes. @@ -80,8 +77,6 @@ func TestAllTypesInsertQuery(t *testing.T) { } func TestAllTypesFromSubQuery(t *testing.T) { - skipForPgxDriver(t) - subQuery := SELECT(AllTypes.AllColumns). FROM(AllTypes). AsTable("allTypesSubQuery") @@ -246,18 +241,18 @@ func TestExpressionOperators(t *testing.T) { SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN ($1, $2)) AS "result.in", - (all_types.small_int_ptr IN (( + (all_types.small_int_ptr IN ( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types - ))) AS "result.in_select", + )) AS "result.in_select", (CURRENT_USER) AS "result.raw", ($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg", ($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN ($9, $10, NULL)) AS "result.not_in", - (all_types.small_int_ptr NOT IN (( + (all_types.small_int_ptr NOT IN ( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types - ))) AS "result.not_in_select" + )) AS "result.not_in_select" FROM test_sample.all_types LIMIT $11; `, int64(11), int64(22), 78, 56, 11, 22, 33, 44, int64(11), int64(22), int64(2)) @@ -302,10 +297,10 @@ LIMIT $11; func TestExpressionCast(t *testing.T) { - skipForPgxDriver(t) // for some reason, pgx driver, 150:char(12) returns as int value + skipForPgxDriver(t) // pgx driver bug 'cannot convert 151 to Text' query := AllTypes.SELECT( - CAST(Int(150)).AS_CHAR(12).AS("char12"), + CAST(Int(151)).AS_CHAR(12).AS("char12"), CAST(String("TRUE")).AS_BOOL(), CAST(String("111")).AS_SMALLINT(), CAST(String("111")).AS_INTEGER(), @@ -349,7 +344,7 @@ func TestExpressionCast(t *testing.T) { } func TestStringOperators(t *testing.T) { - skipForPgxDriver(t) // pgx driver returns text column as int value + skipForPgxDriver(t) // pgx driver bug 'cannot convert 11 to Text' query := AllTypes.SELECT( AllTypes.Text.EQ(AllTypes.Char), @@ -866,8 +861,6 @@ func TestInterval(t *testing.T) { } func TestSubQueryColumnReference(t *testing.T) { - skipForPgxDriver(t) // pgx driver returns time with time zone as string value - type expected struct { sql string args []interface{} @@ -1044,8 +1037,6 @@ FROM` } func TestTimeLiterals(t *testing.T) { - skipForPgxDriver(t) // pgx driver returns time with time zone as string - loc, err := time.LoadLocation("Europe/Berlin") require.NoError(t, err) @@ -1060,8 +1051,6 @@ func TestTimeLiterals(t *testing.T) { ).FROM(AllTypes). LIMIT(1) - //fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, ` SELECT $1::date AS "date", $2::time without time zone AS "time", @@ -1073,25 +1062,29 @@ LIMIT $6; `) var dest struct { - Date time.Time - Time time.Time - Timez time.Time - Timestamp time.Time - //Timestampz time.Time + Date time.Time + Time time.Time + Timez time.Time + Timestamp time.Time + Timestampz time.Time } err = query.Query(db, &dest) require.NoError(t, err) - //testutils.PrintJson(dest) + // pq driver will return time with time zone in local timezone, + // while pgx driver will return time in UTC time zone + dest.Timez = dest.Timez.UTC() + dest.Timestampz = dest.Timestampz.UTC() testutils.AssertJSON(t, dest, ` { "Date": "2009-11-17T00:00:00Z", "Time": "0000-01-01T20:34:58.651387Z", - "Timez": "0000-01-01T20:34:58.651387+01:00", - "Timestamp": "2009-11-17T20:34:58.651387Z" + "Timez": "0000-01-01T19:34:58.651387Z", + "Timestamp": "2009-11-17T20:34:58.651387Z", + "Timestampz": "2009-11-17T19:34:58.651387Z" } `) requireLogged(t, query) diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go new file mode 100644 index 00000000..85dd01e4 --- /dev/null +++ b/tests/postgres/generator_template_test.go @@ -0,0 +1,387 @@ +package postgres + +import ( + "database/sql" + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/generator/postgres" + "github.com/go-jet/jet/v2/generator/template" + "github.com/go-jet/jet/v2/internal/3rdparty/snaker" + "github.com/go-jet/jet/v2/internal/utils" + postgres2 "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/tests/dbconfig" + file2 "github.com/go-jet/jet/v2/tests/internal/utils/file" + "github.com/stretchr/testify/require" + "path" + "testing" +) + +const tempTestDir = "./.tempTestDir" + +var defaultModelPath = path.Join(tempTestDir, "jetdb/dvds/model") +var defaultActorModelFilePath = path.Join(tempTestDir, "jetdb/dvds/model", "actor.go") +var defaultTableSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/table") +var defaultViewSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/view") +var defaultEnumSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/enum") +var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/table", "actor.go") + +var dbConnection = postgres.DBConnection{ + Host: dbconfig.PgHost, + Port: 5432, + User: dbconfig.PgUser, + Password: dbconfig.PgPassword, + DBName: dbconfig.PgDBName, + SchemaName: "dvds", + SslMode: "disable", +} + +func TestGeneratorTemplate_Schema_ChangePath(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData).UsePath("new/schema/path") + }), + ) + + require.Nil(t, err) + + file2.Exists(t, tempTestDir, "jetdb/new/schema/path/model/actor.go") + file2.Exists(t, tempTestDir, "jetdb/new/schema/path/table/actor.go") + file2.Exists(t, tempTestDir, "jetdb/new/schema/path/view/actor_info.go") + file2.Exists(t, tempTestDir, "jetdb/new/schema/path/enum/mpaa_rating.go") +} + +func TestGeneratorTemplate_Model_SkipGeneration(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.Model{ + Skip: true, + }) + }), + ) + + require.Nil(t, err) + + file2.NotExists(t, defaultActorModelFilePath) +} + +func TestGeneratorTemplate_SQLBuilder_SkipGeneration(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.SQLBuilder{ + Skip: true, + }) + }), + ) + + require.Nil(t, err) + + file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.NotExists(t, defaultEnumSQLBuilderFilePath, "mpaa_rating.go") +} + +func TestGeneratorTemplate_Model_ChangePath(t *testing.T) { + const newModelPath = "/new/model/path" + + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel().UsePath(newModelPath)) + }), + ) + require.Nil(t, err) + + file2.Exists(t, tempTestDir, "jetdb", "dvds", newModelPath, "actor.go") + file2.NotExists(t, defaultActorModelFilePath) +} + +func TestGeneratorTemplate_SQLBuilder_ChangePath(t *testing.T) { + const newModelPath = "/new/sql-builder/path" + + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder().UsePath(newModelPath)) + }), + ) + require.Nil(t, err) + + file2.Exists(t, tempTestDir, "jetdb", "dvds", newModelPath, "table", "actor.go") + file2.Exists(t, tempTestDir, "jetdb", "dvds", newModelPath, "view", "actor_info.go") + file2.Exists(t, tempTestDir, "jetdb", "dvds", newModelPath, "enum", "mpaa_rating.go") + + file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.NotExists(t, defaultEnumSQLBuilderFilePath, "mpaa_rating.go") +} + +func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.DefaultTableModel(table). + UseFileName(schemaMetaData.Name + "_" + table.Name). + UseTypeName(utils.ToGoIdentifier(table.Name) + "Table") + }). + UseView(func(table metadata.Table) template.ViewModel { + return template.DefaultViewModel(table). + UseFileName(schemaMetaData.Name + "_" + table.Name + "_view"). + UseTypeName(utils.ToGoIdentifier(table.Name) + "View") + }). + UseEnum(func(enumMetaData metadata.Enum) template.EnumModel { + return template.DefaultEnumModel(enumMetaData). + UseFileName(enumMetaData.Name + "_enum"). + UseTypeName(utils.ToGoIdentifier(enumMetaData.Name) + "Enum") + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultModelPath, "dvds_actor.go") + require.Contains(t, actor, "type ActorTable struct {") + + actorInfo := file2.Exists(t, defaultModelPath, "dvds_actor_info_view.go") + require.Contains(t, actorInfo, "type ActorInfoView struct {") + + mpaaRating := file2.Exists(t, defaultModelPath, "mpaa_rating_enum.go") + require.Contains(t, mpaaRating, "type MpaaRatingEnum string") +} + +func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.TableModel{ + Skip: true, + } + }). + UseEnum(func(enumMetaData metadata.Enum) template.EnumModel { + return template.EnumModel{ + Skip: true, + } + }), + ) + }), + ) + require.Nil(t, err) + + file2.NotExists(t, defaultModelPath, "actor.go") + file2.Exists(t, defaultModelPath, "actor_info.go") + file2.NotExists(t, defaultModelPath, "mpaa_rating.go") +} + +func TestGeneratorTemplate_SQLBuilder_SkipTableAndEnum(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + return template.TableSQLBuilder{ + Skip: true, + } + }). + UseView(func(table metadata.Table) template.TableSQLBuilder { + return template.TableSQLBuilder{ + Skip: true, + } + }). + UseEnum(func(enumMetaData metadata.Enum) template.EnumSQLBuilder { + return template.EnumSQLBuilder{ + Skip: true, + } + }), + ) + }), + ) + require.Nil(t, err) + + file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.NotExists(t, defaultEnumSQLBuilderFilePath, "mpaa_rating.go") +} + +func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + return template.DefaultTableSQLBuilder(table). + UseFileName(schemaMetaData.Name + "_" + table.Name + "_table"). + UseTypeName(utils.ToGoIdentifier(table.Name) + "TableSQLBuilder"). + UseInstanceName("T_" + utils.ToGoIdentifier(table.Name)) + }). + UseView(func(table metadata.Table) template.ViewSQLBuilder { + return template.DefaultViewSQLBuilder(table). + UseFileName(schemaMetaData.Name + "_" + table.Name + "_view"). + UseTypeName(utils.ToGoIdentifier(table.Name) + "ViewSQLBuilder"). + UseInstanceName("V_" + utils.ToGoIdentifier(table.Name)) + }). + UseEnum(func(enum metadata.Enum) template.EnumSQLBuilder { + return template.DefaultEnumSQLBuilder(enum). + UseFileName(schemaMetaData.Name + "_" + enum.Name + "_enum"). + UseInstanceName(utils.ToGoIdentifier(enum.Name) + "EnumSQLBuilder") + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultTableSQLBuilderFilePath, "dvds_actor_table.go") + require.Contains(t, actor, "type ActorTableSQLBuilder struct {") + require.Contains(t, actor, "var T_Actor = newActorTableSQLBuilder(\"dvds\", \"actor\", \"\")") + actorInfo := file2.Exists(t, defaultViewSQLBuilderFilePath, "dvds_actor_info_view.go") + require.Contains(t, actorInfo, "type ActorInfoViewSQLBuilder struct {") + require.Contains(t, actorInfo, "var V_ActorInfo = newActorInfoViewSQLBuilder(\"dvds\", \"actor_info\", \"\")") + mpaaRating := file2.Exists(t, defaultEnumSQLBuilderFilePath, "dvds_mpaa_rating_enum.go") + require.Contains(t, mpaaRating, "var MpaaRatingEnumSQLBuilder = &struct {") +} + +func TestGeneratorTemplate_Model_AddTags(t *testing.T) { + + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.DefaultTableModel(table). + UseField(func(columnMetaData metadata.Column) template.TableModelField { + defaultTableModelField := template.DefaultTableModelField(columnMetaData) + return defaultTableModelField.UseTags( + fmt.Sprintf(`json:"%s"`, snaker.SnakeToCamel(columnMetaData.Name, false)), + fmt.Sprintf(`xml:"%s"`, columnMetaData.Name), + ) + }) + }). + UseView(func(table metadata.Table) template.ViewModel { + return template.DefaultViewModel(table). + UseField(func(columnMetaData metadata.Column) template.TableModelField { + defaultTableModelField := template.DefaultTableModelField(columnMetaData) + if table.Name == "actor_info" && columnMetaData.Name == "actor_id" { + return defaultTableModelField.UseTags(`sql:"primary_key"`) + } + return defaultTableModelField + }) + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultActorModelFilePath) + require.Contains(t, actor, "ActorID int32 `sql:\"primary_key\" json:\"actorID\" xml:\"actor_id\"`") + require.Contains(t, actor, "FirstName string `json:\"firstName\" xml:\"first_name\"`") + + actorInfo := file2.Exists(t, defaultModelPath, "actor_info.go") + require.Contains(t, actorInfo, "ActorID *int32 `sql:\"primary_key\"`") +} + +func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.DefaultTableModel(table). + UseField(func(columnMetaData metadata.Column) template.TableModelField { + defaultTableModelField := template.DefaultTableModelField(columnMetaData) + + switch defaultTableModelField.Type.Name { + case "*string": + defaultTableModelField.Type = template.NewType(sql.NullString{}) + case "*int32": + defaultTableModelField.Type = template.NewType(sql.NullInt32{}) + case "*int64": + defaultTableModelField.Type = template.NewType(sql.NullInt64{}) + case "*bool": + defaultTableModelField.Type = template.NewType(sql.NullBool{}) + case "*float64": + defaultTableModelField.Type = template.NewType(sql.NullFloat64{}) + case "*time.Time": + defaultTableModelField.Type = template.NewType(sql.NullTime{}) + } + return defaultTableModelField + }) + }), + ) + }), + ) + + require.Nil(t, err) + + data := file2.Exists(t, defaultModelPath, "film.go") + require.Contains(t, data, "\"database/sql\"") + require.Contains(t, data, "Description sql.NullString") + require.Contains(t, data, "ReleaseYear sql.NullInt32") + require.Contains(t, data, "SpecialFeatures sql.NullString") +} + +func TestGeneratorTemplate_SQLBuilder_ChangeColumnTypes(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + return template.DefaultTableSQLBuilder(table). + UseColumn(func(column metadata.Column) template.TableSQLBuilderColumn { + defaultColumn := template.DefaultTableSQLBuilderColumn(column) + + if defaultColumn.Name == "ActorID" { + defaultColumn.Type = "String" + } + + return defaultColumn + }) + }), + ) + }), + ) + + require.Nil(t, err) + + actor := file2.Exists(t, defaultActorSQLBuilderFilePath) + require.Contains(t, actor, "ActorID postgres.ColumnString") +} diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 0571157c..77c8aee1 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -1,16 +1,19 @@ package postgres import ( - "github.com/go-jet/jet/v2/generator/postgres" - "github.com/go-jet/jet/v2/internal/testutils" - "github.com/go-jet/jet/v2/tests/dbconfig" - "github.com/stretchr/testify/require" + "fmt" "io/ioutil" "os" "os/exec" + "path/filepath" "reflect" "testing" + "github.com/go-jet/jet/v2/generator/postgres" + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/tests/dbconfig" + "github.com/stretchr/testify/require" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" ) @@ -61,20 +64,40 @@ func TestCmdGenerator(t *testing.T) { err = os.RemoveAll(genTestDir2) require.NoError(t, err) + + // Check that connection via DSN works + dsn := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable", + dbconfig.PgUser, + dbconfig.PgPassword, + dbconfig.PgHost, + dbconfig.PgPort, + "jetdb", + ) + cmd = exec.Command("jet", "-dsn="+dsn, "-schema=dvds", "-path="+genTestDir2) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + err = cmd.Run() + require.NoError(t, err) + + assertGeneratedFiles(t) + + err = os.RemoveAll(genTestDir2) + require.NoError(t, err) } func TestGenerator(t *testing.T) { for i := 0; i < 3; i++ { err := postgres.Generate(genTestDir2, postgres.DBConnection{ - Host: dbconfig.Host, - Port: dbconfig.Port, - User: dbconfig.User, - Password: dbconfig.Password, + Host: dbconfig.PgHost, + Port: dbconfig.PgPort, + User: dbconfig.PgUser, + Password: dbconfig.PgPassword, SslMode: "disable", Params: "", - DBName: dbconfig.DBName, + DBName: dbconfig.PgDBName, SchemaName: "dvds", }) @@ -83,10 +106,42 @@ func TestGenerator(t *testing.T) { assertGeneratedFiles(t) } + for i := 0; i < 3; i++ { + dsn := fmt.Sprintf("postgresql://%[1]s:%[2]s@%[3]s:%[4]d/%[5]s?sslmode=disable", + dbconfig.PgUser, + dbconfig.PgPassword, + dbconfig.PgHost, + dbconfig.PgPort, + dbconfig.PgDBName, + ) + err := postgres.GenerateDSN(dsn, "dvds", genTestDir2) + + require.NoError(t, err) + + assertGeneratedFiles(t) + } + err := os.RemoveAll(genTestDir2) require.NoError(t, err) } +func TestGeneratorSpecialCharacters(t *testing.T) { + t.SkipNow() + err := postgres.Generate(genTestDir2, postgres.DBConnection{ + Host: dbconfig.PgHost, + Port: dbconfig.PgPort, + User: "!@#$%^&* () {}[];+-", + Password: "!@#$%^&* () {}[];+-", + SslMode: "disable", + Params: "", + + DBName: "!@#$%^&* () {}[];+-", + SchemaName: "!@#$%^&* () {}[];+-", + }) + + require.NoError(t, err) +} + func assertGeneratedFiles(t *testing.T) { // Table SQL Builder files tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table") @@ -331,16 +386,16 @@ func newActorInfoTableImpl(schemaName, tableName, alias string) actorInfoTable { ` func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { - enumDir := testRoot + ".gentestdata/jetdb/test_sample/enum/" - modelDir := testRoot + ".gentestdata/jetdb/test_sample/model/" - tableDir := testRoot + ".gentestdata/jetdb/test_sample/table/" + enumDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/enum/") + modelDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/model/") + tableDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/table/") enumFiles, err := ioutil.ReadDir(enumDir) require.NoError(t, err) testutils.AssertFileNamesEqual(t, enumFiles, "mood.go", "level.go") - testutils.AssertFileContent(t, enumDir+"mood.go", moodEnumContent) - testutils.AssertFileContent(t, enumDir+"level.go", levelEnumContent) + testutils.AssertFileContent(t, enumDir+"/mood.go", moodEnumContent) + testutils.AssertFileContent(t, enumDir+"/level.go", levelEnumContent) modelFiles, err := ioutil.ReadDir(modelDir) require.NoError(t, err) @@ -348,7 +403,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go", "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go", "floats.go") - testutils.AssertFileContent(t, modelDir+"all_types.go", allTypesModelContent) + testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent) tableFiles, err := ioutil.ReadDir(tableDir) require.NoError(t, err) @@ -356,7 +411,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go", "person.go", "person_phone.go", "weird_names_table.go", "user.go", "floats.go") - testutils.AssertFileContent(t, tableDir+"all_types.go", allTypesTableContent) + testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent) } var moodEnumContent = ` diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index 9c9875c2..8a50e025 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -140,8 +140,7 @@ ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; Link.ID.SET(Link.EXCLUDED.ID), Link.URL.SET(String("http://www.postgresqltutorial2.com")), ), - ). - RETURNING(Link.AllColumns) + ).RETURNING(Link.AllColumns) testutils.AssertStatementSql(t, stmt, ` INSERT INTO test_sample.link (id, url, name, description) diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index ef9337c7..4e8aade8 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -4,10 +4,9 @@ import ( "context" "database/sql" "fmt" + "github.com/go-jet/jet/v2/tests/internal/utils/repo" "math/rand" "os" - "os/exec" - "strings" "testing" "time" @@ -31,7 +30,9 @@ func TestMain(m *testing.M) { setTestRoot() - for _, driverName := range []string{"postgres", "pgx"} { + for _, driverName := range []string{"pgx", "postgres"} { + fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName) + func() { var err error db, err = sql.Open(driverName, dbconfig.PostgresConnectString) @@ -51,13 +52,7 @@ func TestMain(m *testing.M) { } func setTestRoot() { - cmd := exec.Command("git", "rev-parse", "--show-toplevel") - byteArr, err := cmd.Output() - if err != nil { - panic(err) - } - - testRoot = strings.TrimSpace(string(byteArr)) + "/tests/" + testRoot = repo.GetTestsDirPath() } var loggedSQL string @@ -79,8 +74,16 @@ func requireLogged(t *testing.T, statement postgres.Statement) { } func skipForPgxDriver(t *testing.T) { + if isPgxDriver() { + t.SkipNow() + } +} + +func isPgxDriver() bool { switch db.Driver().(type) { case *stdlib.Driver: - t.SkipNow() + return true } + + return false } diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 4a80ba07..ce3cc46b 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -78,16 +78,31 @@ func TestScanToValidDestination(t *testing.T) { require.NoError(t, err) }) - t.Run("pointer to slice of strings", func(t *testing.T) { - err := oneInventoryQuery.Query(db, &[]int32{}) + t.Run("pointer to slice of integers", func(t *testing.T) { + var dest []int32 + err := oneInventoryQuery.Query(db, &dest) require.NoError(t, err) + require.Equal(t, dest[0], int32(1)) }) - t.Run("pointer to slice of strings", func(t *testing.T) { - err := oneInventoryQuery.Query(db, &[]*int32{}) + t.Run("pointer to slice integer pointers", func(t *testing.T) { + var dest []*int32 + err := oneInventoryQuery.Query(db, &dest) require.NoError(t, err) + require.Equal(t, dest[0], testutils.Int32Ptr(1)) + }) + + t.Run("NULL to integer", func(t *testing.T) { + var dest struct { + Int64 int64 + UInt64 uint64 + } + err := SELECT(NULL.AS("int64"), NULL.AS("uint64")).Query(db, &dest) + require.NoError(t, err) + require.Equal(t, dest.Int64, int64(0)) + require.Equal(t, dest.UInt64, uint64(0)) }) } @@ -189,7 +204,9 @@ func TestScanToStruct(t *testing.T) { dest := Inventory{} - testutils.AssertQueryPanicErr(t, query, db, &dest, `jet: Scan: unable to scan type int32 into UUID, at 'InventoryID uuid.UUID' of type postgres.Inventory`) + err := query.Query(db, &dest) + require.Error(t, err) + require.EqualError(t, err, "jet: can't scan int64('\\x01') to 'InventoryID uuid.UUID': Scan: unable to scan type int64 into UUID") }) t.Run("type mismatch base type", func(t *testing.T) { @@ -200,7 +217,9 @@ func TestScanToStruct(t *testing.T) { dest := []Inventory{} - testutils.AssertQueryPanicErr(t, query.OFFSET(10), db, &dest, `jet: can't set int16 to bool`) + err := query.OFFSET(10).Query(db, &dest) + require.Error(t, err) + require.EqualError(t, err, "jet: can't assign int64('\\x02') to 'FilmID bool': can't assign int64(2) to bool") }) } @@ -451,8 +470,9 @@ func TestScanToSlice(t *testing.T) { t.Run("slice type mismatch", func(t *testing.T) { var dest []bool - testutils.AssertQueryPanicErr(t, query, db, &dest, `jet: can't append int32 to []bool slice`) - //require.Error(t, err, `jet: can't append int32 to []bool slice `) + err := query.Query(db, &dest) + require.Error(t, err) + require.EqualError(t, err, `jet: can't append int64 to []bool slice: can't assign int64(2) to bool`) }) }) @@ -764,16 +784,8 @@ func TestRowsScan(t *testing.T) { requireLogged(t, stmt) } -func TestScanNumericToNumber(t *testing.T) { +func TestScanNumericToFloat(t *testing.T) { type Number struct { - Int8 int8 - UInt8 uint8 - Int16 int16 - UInt16 uint16 - Int32 int32 - UInt32 uint32 - Int64 int64 - UInt64 uint64 Float32 float32 Float64 float64 } @@ -781,14 +793,6 @@ func TestScanNumericToNumber(t *testing.T) { numeric := CAST(Decimal("1234567890.111")).AS_NUMERIC() stmt := SELECT( - numeric.AS("number.int8"), - numeric.AS("number.uint8"), - numeric.AS("number.int16"), - numeric.AS("number.uint16"), - numeric.AS("number.int32"), - numeric.AS("number.uint32"), - numeric.AS("number.int64"), - numeric.AS("number.uint64"), numeric.AS("number.float32"), numeric.AS("number.float64"), ) @@ -796,19 +800,73 @@ func TestScanNumericToNumber(t *testing.T) { var number Number err := stmt.Query(db, &number) require.NoError(t, err) - - require.Equal(t, number.Int8, int8(-46)) // overflow - require.Equal(t, number.UInt8, uint8(210)) // overflow - require.Equal(t, number.Int16, int16(722)) // overflow - require.Equal(t, number.UInt16, uint16(722)) // overflow - require.Equal(t, number.Int32, int32(1234567890)) - require.Equal(t, number.UInt32, uint32(1234567890)) - require.Equal(t, number.Int64, int64(1234567890)) - require.Equal(t, number.UInt64, uint64(1234567890)) require.Equal(t, number.Float32, float32(1.234568e+09)) require.Equal(t, number.Float64, float64(1.234567890111e+09)) } +func TestScanNumericToIntegerError(t *testing.T) { + + var dest struct { + Integer int32 + } + + err := SELECT( + CAST(Decimal("1234567890.111")).AS_NUMERIC().AS("integer"), + ).Query(db, &dest) + + require.Error(t, err) + + if isPgxDriver() { + require.Contains(t, err.Error(), `jet: can't assign string("1234567890.111") to 'Integer int32': converting driver.Value type string ("1234567890.111") to a int64: invalid syntax`) + } else { + require.Contains(t, err.Error(), `jet: can't assign []uint8("1234567890.111") to 'Integer int32': converting driver.Value type []uint8 ("1234567890.111") to a int64: invalid syntax`) + } + +} + +// QueryContext panic when the scanned value is nil and the destination is a slice of primitive +// https://github.com/go-jet/jet/issues/91 +func TestScanToPrimitiveElementsSlice(t *testing.T) { + tx, err := db.Begin() + require.NoError(t, err) + defer tx.Rollback() + + // add actor without associated film (so that destination Title array is NULL). + _, err = Actor.INSERT(). + MODEL( + model.Actor{ + ActorID: 201, + FirstName: "Brigitte", + LastName: "Bardot", + LastUpdate: time.Time{}, + }, + ).Exec(tx) + require.NoError(t, err) + + stmt := SELECT( + Actor.ActorID.AS("actor_id"), + Film.Title.AS("title"), + ).FROM( + Actor. + LEFT_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.ActorID)). + LEFT_JOIN(Film, Film.FilmID.EQ(FilmActor.FilmID)), + ).WHERE( + Actor.ActorID.GT(Int(199)), + ).ORDER_BY(Actor.ActorID.DESC()) + + var dest []struct { + ActorID int `sql:"primary_key"` + Title []string + } + + err = stmt.Query(tx, &dest) + require.NoError(t, err) + require.Equal(t, dest[0].ActorID, 201) + require.Equal(t, dest[0].Title, []string(nil)) + require.Equal(t, dest[1].ActorID, 200) + require.Len(t, dest[1].Title, 20) +} + var address256 = model.Address{ AddressID: 256, Address: "1497 Yuzhou Drive", diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 59fc44a8..96359296 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -75,11 +75,12 @@ LIMIT 30; query := SELECT( Payment.AllColumns, Customer.AllColumns, - ). - FROM(Payment. - INNER_JOIN(Customer, Payment.CustomerID.EQ(Customer.CustomerID))). - ORDER_BY(Payment.PaymentID.ASC()). - LIMIT(30) + ).FROM( + Payment. + INNER_JOIN(Customer, Payment.CustomerID.EQ(Customer.CustomerID)), + ).ORDER_BY( + Payment.PaymentID.ASC(), + ).LIMIT(30) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(30)) @@ -1992,3 +1993,66 @@ LIMIT 1; require.Equal(t, dest, dest2) }) } + +func TestSelectColumnListWithExcludedColumns(t *testing.T) { + + t.Run("one column", func(t *testing.T) { + stmt := SELECT( + Address.AllColumns.Except(Address.LastUpdate), + ).FROM( + Address, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT address.address_id AS "address.address_id", + address.address AS "address.address", + address.address2 AS "address.address2", + address.district AS "address.district", + address.city_id AS "address.city_id", + address.postal_code AS "address.postal_code", + address.phone AS "address.phone" +FROM dvds.address; +`) + var dest []model.Address + require.NoError(t, stmt.Query(db, &dest)) + require.Len(t, dest, 603) + }) + + t.Run("multiple columns", func(t *testing.T) { + expectedSQL := ` +SELECT address.address_id AS "address.address_id", + address.address AS "address.address", + address.address2 AS "address.address2", + address.district AS "address.district", + address.city_id AS "address.city_id" +FROM dvds.address; +` + // list of columns + stmt := SELECT( + Address.AllColumns.Except(Address.PostalCode, Address.Phone, Address.LastUpdate), + ).FROM( + Address, + ) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL) + + // column list + excludedColumns := ColumnList{Address.PostalCode, Address.Phone, Address.LastUpdate, Film.Title} // Film.Title is ignored + stmt = SELECT( + Address.AllColumns.Except(excludedColumns), + ).FROM(Address) + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL) + + // column list with just column names + excludedColumns = ColumnList{StringColumn("postal_code"), StringColumn("phone"), TimestampColumn("last_update")} + stmt = SELECT( + Address.AllColumns.Except(excludedColumns), + ).FROM(Address) + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL) + + var dest []model.Address + require.NoError(t, stmt.Query(db, &dest)) + require.Len(t, dest, 603) + }) +} diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 043bf78f..5ec44a11 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -259,14 +259,14 @@ func TestUpdateWithModelData(t *testing.T) { stmt := Link. UPDATE(Link.AllColumns). MODEL(link). - WHERE(Link.ID.EQ(Int(int64(link.ID)))) + WHERE(Link.ID.EQ(Int32(link.ID))) expectedSQL := ` UPDATE test_sample.link SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL) WHERE link.id = 201; ` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) AssertExec(t, stmt, 1) } @@ -286,14 +286,14 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { stmt := Link. UPDATE(updateColumnList). MODEL(link). - WHERE(Link.ID.EQ(Int(int64(link.ID)))) + WHERE(Link.ID.EQ(Int32(link.ID))) var expectedSQL = ` UPDATE test_sample.link SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com') WHERE link.id = 201; ` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) AssertExec(t, stmt, 1) } diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go index 8eadf212..8a16fd4c 100644 --- a/tests/postgres/with_test.go +++ b/tests/postgres/with_test.go @@ -22,20 +22,23 @@ func TestWithRegionalSales(t *testing.T) { SELECT( Orders.ShipRegion, SUM(OrderDetails.Quantity).AS(regionalSalesTotalSales.Name()), - ). - FROM(Orders.INNER_JOIN(OrderDetails, OrderDetails.OrderID.EQ(Orders.OrderID))). - GROUP_BY(Orders.ShipRegion), + ).FROM( + Orders.INNER_JOIN(OrderDetails, OrderDetails.OrderID.EQ(Orders.OrderID)), + ).GROUP_BY(Orders.ShipRegion), ), topRegion.AS( - SELECT(regionalSalesShipRegion). - FROM(regionalSales). - WHERE(regionalSalesTotalSales.GT( + SELECT( + regionalSalesShipRegion, + ).FROM( + regionalSales, + ).WHERE( + regionalSalesTotalSales.GT( IntExp( SELECT(SUM(regionalSalesTotalSales)). FROM(regionalSales), ).DIV(Int(50)), ), - ), + ), ), )( SELECT( @@ -43,13 +46,17 @@ func TestWithRegionalSales(t *testing.T) { OrderDetails.ProductID, COUNT(STAR).AS("product_units"), SUM(OrderDetails.Quantity).AS("product_sales"), - ). - FROM(Orders.INNER_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID))). - WHERE(Orders.ShipRegion.IN( - topRegion.SELECT(topRegionShipRegion)), - ). - GROUP_BY(Orders.ShipRegion, OrderDetails.ProductID). - ORDER_BY(SUM(OrderDetails.Quantity).DESC()), + ).FROM( + Orders. + INNER_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID)), + ).WHERE( + Orders.ShipRegion.IN(topRegion.SELECT(topRegionShipRegion)), + ).GROUP_BY( + Orders.ShipRegion, + OrderDetails.ProductID, + ).ORDER_BY( + SUM(OrderDetails.Quantity).DESC(), + ), ) //fmt.Println(stmt.DebugSql()) @@ -75,10 +82,10 @@ SELECT orders.ship_region AS "orders.ship_region", SUM(order_details.quantity) AS "product_sales" FROM northwind.orders INNER JOIN northwind.order_details ON (orders.order_id = order_details.order_id) -WHERE orders.ship_region IN (( +WHERE orders.ship_region IN ( SELECT top_region."orders.ship_region" AS "orders.ship_region" FROM top_region - )) + ) GROUP BY orders.ship_region, order_details.product_id ORDER BY SUM(order_details.quantity) DESC; `) @@ -141,19 +148,19 @@ func TestWithStatementDeleteAndInsert(t *testing.T) { testutils.AssertStatementSql(t, stmt, ` WITH remove_discontinued_orders AS ( DELETE FROM northwind.order_details - WHERE order_details.product_id IN (( + WHERE order_details.product_id IN ( SELECT products.product_id AS "products.product_id" FROM northwind.products WHERE products.discontinued = $1 - )) + ) RETURNING order_details.product_id AS "order_details.product_id" ),update_discontinued_price AS ( UPDATE northwind.products SET unit_price = $2 - WHERE products.product_id IN (( + WHERE products.product_id IN ( SELECT remove_discontinued_orders."order_details.product_id" AS "order_details.product_id" FROM remove_discontinued_orders - )) + ) RETURNING products.product_id AS "products.product_id", products.product_name AS "products.product_name", products.supplier_id AS "products.supplier_id", diff --git a/tests/sqlite/alltypes_test.go b/tests/sqlite/alltypes_test.go new file mode 100644 index 00000000..523d959b --- /dev/null +++ b/tests/sqlite/alltypes_test.go @@ -0,0 +1,911 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/view" + "github.com/go-jet/jet/v2/tests/testdata/results/common" + "github.com/google/uuid" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" + "strings" + "testing" + "time" +) + +func TestAllTypes(t *testing.T) { + + dest := []model.AllTypes{} + + err := SELECT(AllTypes.AllColumns). + FROM(AllTypes). + Query(sampleDB, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, allTypesJSON) +} + +var allTypesJSON = ` +[ + { + "Boolean": false, + "BooleanPtr": true, + "TinyInt": -3, + "TinyIntPtr": 3, + "SmallInt": 14, + "SmallIntPtr": 14, + "MediumInt": -150, + "MediumIntPtr": 150, + "Integer": -1600, + "IntegerPtr": 1600, + "BigInt": 5000, + "BigIntPtr": 50000, + "Decimal": 1.11, + "DecimalPtr": 1.01, + "Numeric": 2.22, + "NumericPtr": 2.02, + "Float": 3.33, + "FloatPtr": 3.03, + "Double": 4.44, + "DoublePtr": 4.04, + "Real": 5.55, + "RealPtr": 5.05, + "Time": "0000-01-01T10:11:12.33Z", + "TimePtr": "0000-01-01T10:11:12.123456Z", + "Date": "2008-07-04T00:00:00Z", + "DatePtr": "2008-07-04T00:00:00Z", + "DateTime": "2011-12-18T13:17:17Z", + "DateTimePtr": "2011-12-18T13:17:17Z", + "Timestamp": "2007-12-31T23:00:01Z", + "TimestampPtr": "2007-12-31T23:00:01Z", + "Char": "char1", + "CharPtr": "char-ptr", + "VarChar": "varchar", + "VarCharPtr": "varchar-ptr", + "Text": "text", + "TextPtr": "text-ptr", + "Blob": "YmxvYjE=", + "BlobPtr": "YmxvYi1wdHI=" + }, + { + "Boolean": false, + "BooleanPtr": null, + "TinyInt": -3, + "TinyIntPtr": null, + "SmallInt": 14, + "SmallIntPtr": null, + "MediumInt": -150, + "MediumIntPtr": null, + "Integer": -1600, + "IntegerPtr": null, + "BigInt": 5000, + "BigIntPtr": null, + "Decimal": 1.11, + "DecimalPtr": null, + "Numeric": 2.22, + "NumericPtr": null, + "Float": 3.33, + "FloatPtr": null, + "Double": 4.44, + "DoublePtr": null, + "Real": 5.55, + "RealPtr": null, + "Time": "0000-01-01T10:11:12.33Z", + "TimePtr": null, + "Date": "2008-07-04T00:00:00Z", + "DatePtr": null, + "DateTime": "2011-12-18T13:17:17Z", + "DateTimePtr": null, + "Timestamp": "2007-12-31T23:00:01Z", + "TimestampPtr": null, + "Char": "char2", + "CharPtr": null, + "VarChar": "varchar", + "VarCharPtr": null, + "Text": "text", + "TextPtr": null, + "Blob": "YmxvYjI=", + "BlobPtr": null + } +] +` + +func TestAllTypesViewSelect(t *testing.T) { + var dest []model.AllTypesView + + stmt := SELECT(view.AllTypesView.AllColumns). + FROM(view.AllTypesView) + + err := stmt.Query(sampleDB, &dest) + + require.NoError(t, err) + require.Equal(t, len(dest), 2) + + testutils.AssertJSON(t, dest, allTypesJSON) +} + +func TestAllTypesInsert(t *testing.T) { + tx := beginSampleDBTx(t) + + stmt := AllTypes.INSERT(AllTypes.AllColumns). + MODEL(toInsert). + RETURNING(AllTypes.AllColumns) + + var inserted model.AllTypes + err := stmt.Query(tx, &inserted) + + require.NoError(t, err) + testutils.AssertDeepEqual(t, toInsert, inserted, testutils.UnixTimeComparer) + + var dest model.AllTypes + err = AllTypes.SELECT(AllTypes.AllColumns). + WHERE(AllTypes.BigInt.EQ(Int(toInsert.BigInt))). + Query(tx, &dest) + + require.NoError(t, err) + testutils.AssertDeepEqual(t, dest, toInsert, testutils.UnixTimeComparer) + + err = tx.Rollback() + require.NoError(t, err) +} + +var toInsert = model.AllTypes{ + Boolean: false, + BooleanPtr: testutils.BoolPtr(true), + TinyInt: 1, + SmallInt: 3, + MediumInt: 5, + Integer: 7, + BigInt: 9, + TinyIntPtr: testutils.Int8Ptr(11), + SmallIntPtr: testutils.Int16Ptr(33), + MediumIntPtr: testutils.Int32Ptr(55), + IntegerPtr: testutils.Int32Ptr(77), + BigIntPtr: testutils.Int64Ptr(99), + Decimal: 11.22, + DecimalPtr: testutils.Float64Ptr(33.44), + Numeric: 55.66, + NumericPtr: testutils.Float64Ptr(77.88), + Float: 99.00, + FloatPtr: testutils.Float64Ptr(11.22), + Double: 33.44, + DoublePtr: testutils.Float64Ptr(55.66), + Real: 77.88, + RealPtr: testutils.Float32Ptr(99.00), + Time: time.Date(1, 1, 1, 1, 1, 1, 10, time.UTC), + TimePtr: testutils.TimePtr(time.Date(2, 2, 2, 2, 2, 2, 200, time.UTC)), + Date: time.Now(), + DatePtr: testutils.TimePtr(time.Now()), + DateTime: time.Now(), + DateTimePtr: testutils.TimePtr(time.Now()), + Timestamp: time.Now(), + TimestampPtr: testutils.TimePtr(time.Now()), + Char: "abcd", + CharPtr: testutils.StringPtr("absd"), + VarChar: "abcd", + VarCharPtr: testutils.StringPtr("absd"), + Blob: []byte("large file"), + BlobPtr: testutils.ByteArrayPtr([]byte("very large file")), + Text: "some text", + TextPtr: testutils.StringPtr("text"), +} + +func TestUUID(t *testing.T) { + query := SELECT( + //Raw("uuid()").AS("uuid"), + String("dc8daae3-b83b-11e9-8eb4-98ded00c39c6").AS("str_uuid"), + ) + + var dest struct { + UUID uuid.UUID + StrUUID *uuid.UUID + } + + err := query.Query(sampleDB, &dest) + + require.NoError(t, err) + require.Equal(t, dest.StrUUID.String(), "dc8daae3-b83b-11e9-8eb4-98ded00c39c6") + requireLogged(t, query) +} + +func TestExpressionOperators(t *testing.T) { + query := SELECT( + AllTypes.Integer.IS_NULL().AS("result.is_null"), + AllTypes.DatePtr.IS_NOT_NULL().AS("result.is_not_null"), + AllTypes.SmallIntPtr.IN(Int(11), Int(22)).AS("result.in"), + AllTypes.SmallIntPtr.IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.in_select"), + + Raw("length(121232459)").AS("result.raw"), + Raw(":first + COALESCE(all_types.small_int_ptr, 0) + :second", RawArgs{":first": 78, ":second": 56}). + AS("result.raw_arg"), + Raw("#1 + all_types.integer + #2 + #1 + #3 + #4", RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}). + AS("result.raw_arg2"), + + AllTypes.SmallIntPtr.NOT_IN(Int(11), Int(22), NULL).AS("result.not_in"), + AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), + ).FROM( + AllTypes, + ).LIMIT(2) + + testutils.AssertStatementSql(t, query, strings.Replace(` +SELECT all_types.integer IS NULL AS "result.is_null", + all_types.date_ptr IS NOT NULL AS "result.is_not_null", + (all_types.small_int_ptr IN (?, ?)) AS "result.in", + (all_types.small_int_ptr IN ( + SELECT all_types.integer AS "all_types.integer" + FROM all_types + )) AS "result.in_select", + (length(121232459)) AS "result.raw", + (? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg", + (? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2", + (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", + (all_types.small_int_ptr NOT IN ( + SELECT all_types.integer AS "all_types.integer" + FROM all_types + )) AS "result.not_in_select" +FROM all_types +LIMIT ?; +`, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2)) + + var dest []struct { + common.ExpressionTestResult `alias:"result.*"` + } + + err := query.Query(sampleDB, &dest) + require.NoError(t, err) + + require.Equal(t, *dest[0].IsNull, false) + require.Equal(t, *dest[0].IsNotNull, true) + require.Equal(t, *dest[0].In, false) + require.Equal(t, *dest[0].InSelect, false) + require.Equal(t, *dest[0].Raw, "9") + require.Equal(t, *dest[0].RawArg, int32(148)) + require.Equal(t, *dest[0].RawArg2, int32(-1479)) + require.Nil(t, dest[0].NotIn) + require.Equal(t, *dest[0].NotInSelect, true) +} + +func TestBoolOperators(t *testing.T) { + query := AllTypes.SELECT( + AllTypes.Boolean.EQ(AllTypes.BooleanPtr).AS("EQ1"), + AllTypes.Boolean.EQ(Bool(true)).AS("EQ2"), + AllTypes.Boolean.NOT_EQ(AllTypes.BooleanPtr).AS("NEq1"), + AllTypes.Boolean.NOT_EQ(Bool(false)).AS("NEq2"), + AllTypes.Boolean.IS_DISTINCT_FROM(AllTypes.BooleanPtr).AS("distinct1"), + AllTypes.Boolean.IS_DISTINCT_FROM(Bool(true)).AS("distinct2"), + AllTypes.Boolean.IS_NOT_DISTINCT_FROM(AllTypes.BooleanPtr).AS("not_distinct_1"), + AllTypes.Boolean.IS_NOT_DISTINCT_FROM(Bool(true)).AS("NOTDISTINCT2"), + AllTypes.Boolean.IS_TRUE().AS("ISTRUE"), + AllTypes.Boolean.IS_NOT_TRUE().AS("isnottrue"), + AllTypes.Boolean.IS_FALSE().AS("is_False"), + AllTypes.Boolean.IS_NOT_FALSE().AS("is not false"), + AllTypes.Boolean.IS_NULL().AS("is unknown"), + AllTypes.Boolean.IS_NOT_NULL().AS("is_not_unknown"), + + AllTypes.Boolean.AND(AllTypes.Boolean).EQ(AllTypes.Boolean.AND(AllTypes.Boolean)).AS("complex1"), + AllTypes.Boolean.OR(AllTypes.Boolean).EQ(AllTypes.Boolean.AND(AllTypes.Boolean)).AS("complex2"), + ) + + testutils.AssertStatementSql(t, query, ` +SELECT (all_types.boolean = all_types.boolean_ptr) AS "EQ1", + (all_types.boolean = ?) AS "EQ2", + (all_types.boolean != all_types.boolean_ptr) AS "NEq1", + (all_types.boolean != ?) AS "NEq2", + (all_types.boolean IS NOT all_types.boolean_ptr) AS "distinct1", + (all_types.boolean IS NOT ?) AS "distinct2", + (all_types.boolean IS all_types.boolean_ptr) AS "not_distinct_1", + (all_types.boolean IS ?) AS "NOTDISTINCT2", + all_types.boolean IS TRUE AS "ISTRUE", + all_types.boolean IS NOT TRUE AS "isnottrue", + all_types.boolean IS FALSE AS "is_False", + all_types.boolean IS NOT FALSE AS "is not false", + all_types.boolean IS NULL AS "is unknown", + all_types.boolean IS NOT NULL AS "is_not_unknown", + ((all_types.boolean AND all_types.boolean) = (all_types.boolean AND all_types.boolean)) AS "complex1", + ((all_types.boolean OR all_types.boolean) = (all_types.boolean AND all_types.boolean)) AS "complex2" +FROM all_types; +`, true, false, true, true) + + var dest []struct { + Eq1 *bool + Eq2 *bool + NEq1 *bool + NEq2 *bool + Distinct1 *bool + Distinct2 *bool + NotDistinct1 *bool + NotDistinct2 *bool + IsTrue *bool + IsNotTrue *bool + IsFalse *bool + IsNotFalse *bool + IsUnknown *bool + IsNotUnknown *bool + + Complex1 *bool + Complex2 *bool + } + + err := query.Query(sampleDB, &dest) + + require.NoError(t, err) + + testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json") +} + +func TestFloatOperators(t *testing.T) { + + query := AllTypes.SELECT( + AllTypes.Numeric.EQ(AllTypes.Numeric).AS("eq1"), + AllTypes.Decimal.EQ(Float(12.22)).AS("eq2"), + AllTypes.Real.EQ(Float(12.12)).AS("eq3"), + AllTypes.Numeric.IS_DISTINCT_FROM(AllTypes.Numeric).AS("distinct1"), + AllTypes.Decimal.IS_DISTINCT_FROM(Float(12)).AS("distinct2"), + AllTypes.Real.IS_DISTINCT_FROM(Float(12.12)).AS("distinct3"), + AllTypes.Numeric.IS_NOT_DISTINCT_FROM(AllTypes.Numeric).AS("not_distinct1"), + AllTypes.Decimal.IS_NOT_DISTINCT_FROM(Float(12)).AS("not_distinct2"), + AllTypes.Real.IS_NOT_DISTINCT_FROM(Float(12.12)).AS("not_distinct3"), + AllTypes.Numeric.LT(Float(124)).AS("lt1"), + AllTypes.Numeric.LT(Float(34.56)).AS("lt2"), + AllTypes.Numeric.GT(Float(124)).AS("gt1"), + AllTypes.Numeric.GT(Float(34.56)).AS("gt2"), + + AllTypes.Decimal.ADD(AllTypes.Decimal).AS("add1"), + AllTypes.Decimal.ADD(Float(11.22)).AS("add2"), + AllTypes.Decimal.SUB(AllTypes.DecimalPtr).AS("sub1"), + AllTypes.Decimal.SUB(Float(11.22)).AS("sub2"), + AllTypes.Decimal.MUL(AllTypes.DecimalPtr).AS("mul1"), + AllTypes.Decimal.MUL(Float(11.22)).AS("mul2"), + AllTypes.Decimal.DIV(AllTypes.DecimalPtr).AS("div1"), + AllTypes.Decimal.DIV(Float(11.22)).AS("div2"), + AllTypes.Decimal.MOD(AllTypes.DecimalPtr).AS("mod1"), + AllTypes.Decimal.MOD(Float(11.22)).AS("mod2"), + + // sqlite driver has to enable SQLITE_ENABLE_MATH_FUNCTIONS before commented math functions can be used + + //AllTypes.Decimal.POW(AllTypes.DecimalPtr).AS("pow1"), + //AllTypes.Decimal.POW(Float(2.1)).AS("pow2"), + + ABSf(AllTypes.Decimal).AS("abs"), + //POWER(AllTypes.Decimal, Float(2.1)).AS("power"), + //SQRT(AllTypes.Decimal).AS("sqrt"), + //CBRT(AllTypes.Decimal).AS("cbrt"), + + //CEIL(AllTypes.Real).AS("ceil"), + //FLOOR(AllTypes.Real).AS("floor"), + ROUND(AllTypes.Decimal).AS("round1"), + ROUND(AllTypes.Decimal, Int(2)).AS("round2"), + //TRUNC(AllTypes.Decimal, Int(1)).AS("trunc"), + SIGN(AllTypes.Real).AS("sign"), + ).LIMIT(1) + + testutils.AssertStatementSql(t, query, ` +SELECT (all_types.numeric = all_types.numeric) AS "eq1", + (all_types.decimal = ?) AS "eq2", + (all_types.real = ?) AS "eq3", + (all_types.numeric IS NOT all_types.numeric) AS "distinct1", + (all_types.decimal IS NOT ?) AS "distinct2", + (all_types.real IS NOT ?) AS "distinct3", + (all_types.numeric IS all_types.numeric) AS "not_distinct1", + (all_types.decimal IS ?) AS "not_distinct2", + (all_types.real IS ?) AS "not_distinct3", + (all_types.numeric < ?) AS "lt1", + (all_types.numeric < ?) AS "lt2", + (all_types.numeric > ?) AS "gt1", + (all_types.numeric > ?) AS "gt2", + (all_types.decimal + all_types.decimal) AS "add1", + (all_types.decimal + ?) AS "add2", + (all_types.decimal - all_types.decimal_ptr) AS "sub1", + (all_types.decimal - ?) AS "sub2", + (all_types.decimal * all_types.decimal_ptr) AS "mul1", + (all_types.decimal * ?) AS "mul2", + (all_types.decimal / all_types.decimal_ptr) AS "div1", + (all_types.decimal / ?) AS "div2", + (all_types.decimal % all_types.decimal_ptr) AS "mod1", + (all_types.decimal % ?) AS "mod2", + ABS(all_types.decimal) AS "abs", + ROUND(all_types.decimal) AS "round1", + ROUND(all_types.decimal, ?) AS "round2", + SIGN(all_types.real) AS "sign" +FROM all_types +LIMIT ?; +`) + + var dest struct { + common.FloatExpressionTestResult `alias:"."` + } + + err := query.Query(sampleDB, &dest) + + require.NoError(t, err) + require.Equal(t, *dest.Eq1, true) + require.Equal(t, *dest.Distinct1, false) + require.Equal(t, *dest.Lt1, true) + require.Equal(t, *dest.Add1, 2.22) + require.Equal(t, *dest.Mod2, float64(1)) + require.Equal(t, *dest.Round1, float64(1)) + require.Equal(t, *dest.Round2, float64(1.11)) + require.Equal(t, *dest.Sign, float64(1)) + + //testutils.AssertJSONFile(t, dest, "./testdata/results/common/float_operators.json") +} + +func TestIntegerOperators(t *testing.T) { + query := AllTypes.SELECT( + AllTypes.BigInt, + AllTypes.BigIntPtr, + AllTypes.SmallInt, + AllTypes.SmallIntPtr, + + AllTypes.BigInt.EQ(AllTypes.BigInt).AS("eq1"), + AllTypes.BigInt.EQ(Int(12)).AS("eq2"), + + AllTypes.BigInt.NOT_EQ(AllTypes.BigIntPtr).AS("neq1"), + AllTypes.BigInt.NOT_EQ(Int(12)).AS("neq2"), + + AllTypes.BigInt.IS_DISTINCT_FROM(AllTypes.BigInt).AS("distinct1"), + AllTypes.BigInt.IS_DISTINCT_FROM(Int(12)).AS("distinct2"), + + AllTypes.BigInt.IS_NOT_DISTINCT_FROM(AllTypes.BigInt).AS("not distinct1"), + AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int(12)).AS("not distinct2"), + + AllTypes.BigInt.LT(AllTypes.BigIntPtr).AS("lt1"), + AllTypes.BigInt.LT(Int(65)).AS("lt2"), + + AllTypes.BigInt.LT_EQ(AllTypes.BigIntPtr).AS("lte1"), + AllTypes.BigInt.LT_EQ(Int(65)).AS("lte2"), + + AllTypes.BigInt.GT(AllTypes.BigIntPtr).AS("gt1"), + AllTypes.BigInt.GT(Int(65)).AS("gt2"), + + AllTypes.BigInt.GT_EQ(AllTypes.BigIntPtr).AS("gte1"), + AllTypes.BigInt.GT_EQ(Int(65)).AS("gte2"), + + AllTypes.BigInt.ADD(AllTypes.BigInt).AS("add1"), + AllTypes.BigInt.ADD(Int(11)).AS("add2"), + + AllTypes.BigInt.SUB(AllTypes.BigInt).AS("sub1"), + AllTypes.BigInt.SUB(Int(11)).AS("sub2"), + + AllTypes.BigInt.MUL(AllTypes.BigInt).AS("mul1"), + AllTypes.BigInt.MUL(Int(11)).AS("mul2"), + + AllTypes.BigInt.DIV(AllTypes.BigInt).AS("div1"), + AllTypes.BigInt.DIV(Int(11)).AS("div2"), + + AllTypes.BigInt.MOD(AllTypes.BigInt).AS("mod1"), + AllTypes.BigInt.MOD(Int(11)).AS("mod2"), + + //AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int(3))).AS("pow1"), + //AllTypes.SmallInt.POW(Int(6)).AS("pow2"), + + AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and1"), + AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and2"), + + AllTypes.SmallInt.BIT_OR(AllTypes.SmallInt).AS("bit or 1"), + AllTypes.SmallInt.BIT_OR(Int(22)).AS("bit or 2"), + + AllTypes.SmallInt.BIT_XOR(AllTypes.SmallInt).AS("bit xor 1"), + AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"), + + BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), + BIT_NOT(Int(-1).MUL(Int(11))).AS("bit_not_2"), + + AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int(2))).AS("bit shift left 1"), + AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"), + + AllTypes.SmallInt.BIT_SHIFT_RIGHT(AllTypes.SmallInt.DIV(Int(5))).AS("bit shift right 1"), + AllTypes.SmallInt.BIT_SHIFT_RIGHT(Int(1)).AS("bit shift right 2"), + + ABSi(AllTypes.BigInt).AS("abs"), + //SQRT(ABSi(AllTypes.BigInt)).AS("sqrt"), + //CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"), + ).LIMIT(2) + + var dest []struct { + common.AllTypesIntegerExpResult `alias:"."` + } + + err := query.Query(sampleDB, &dest) + + require.NoError(t, err) + + require.Equal(t, *dest[0].Eq1, true) + require.Equal(t, *dest[0].Distinct2, true) + require.Equal(t, *dest[0].Lt2, false) + require.Equal(t, *dest[0].Add1, int64(10000)) + require.Equal(t, *dest[0].Mul1, int64(25000000)) + require.Equal(t, *dest[0].Div2, int64(454)) + require.Equal(t, *dest[0].BitAnd1, int64(14)) + require.Equal(t, *dest[0].BitXor2, int64(5)) + require.Equal(t, *dest[0].BitShiftLeft1, int64(1792)) + require.Equal(t, *dest[0].BitShiftRight2, int64(7)) + +} + +func TestStringOperators(t *testing.T) { + + query := SELECT( + AllTypes.Text.EQ(AllTypes.Char), + AllTypes.Text.EQ(String("Text")), + AllTypes.Text.NOT_EQ(AllTypes.VarCharPtr), + AllTypes.Text.NOT_EQ(String("Text")), + AllTypes.Text.GT(AllTypes.Text), + AllTypes.Text.GT(String("Text")), + AllTypes.Text.GT_EQ(AllTypes.TextPtr), + AllTypes.Text.GT_EQ(String("Text")), + AllTypes.Text.LT(AllTypes.Char), + AllTypes.Text.LT(String("Text")), + AllTypes.Text.LT_EQ(AllTypes.VarCharPtr), + AllTypes.Text.LT_EQ(String("Text")), + AllTypes.Text.CONCAT(String("text2")), + AllTypes.Text.CONCAT(Int(11)), + AllTypes.Text.LIKE(String("abc")), + AllTypes.Text.NOT_LIKE(String("_b_")), + //AllTypes.Text.REGEXP_LIKE(String("aba")), + //AllTypes.Text.REGEXP_LIKE(String("aba"), false), + //String("ABA").REGEXP_LIKE(String("aba"), true), + //AllTypes.Text.NOT_REGEXP_LIKE(String("aba")), + //AllTypes.Text.NOT_REGEXP_LIKE(String("aba"), false), + //String("ABA").NOT_REGEXP_LIKE(String("aba"), true), + + //BIT_LENGTH(AllTypes.Text), + //CHAR_LENGTH(AllTypes.Char), + //OCTET_LENGTH(AllTypes.Text), + LOWER(AllTypes.VarCharPtr), + UPPER(AllTypes.Char), + LTRIM(AllTypes.VarCharPtr), + RTRIM(AllTypes.VarCharPtr), + //CONCAT(String("string1"), Int(1), Float(11.12)), + //CONCAT_WS(String("string1"), Int(1), Float(11.12)), + //FORMAT(String("Hello %s, %1$s"), String("World")), + //LEFTSTR(String("abcde"), Int(2)), + //RIGHTSTR(String("abcde"), Int(2)), + LENGTH(String("jose")), + //LPAD(String("Hi"), Int(5), String("xy")), + //RPAD(String("Hi"), Int(5), String("xy")), + //MD5(AllTypes.VarCharPtr), + //REPEAT(AllTypes.Text, Int(33)), + REPLACE(AllTypes.Char, String("BA"), String("AB")), + //REVERSE(AllTypes.VarCharPtr), + SUBSTR(AllTypes.CharPtr, Int(3)), + SUBSTR(AllTypes.CharPtr, Int(3), Int(2)), + ).FROM(AllTypes) + + dest := []struct{}{} + err := query.Query(sampleDB, &dest) + + require.NoError(t, err) +} + +func TestReservedWord(t *testing.T) { + stmt := SELECT(ReservedWords.AllColumns). + FROM(ReservedWords) + + testutils.AssertDebugStatementSql(t, stmt, strings.Replace(` +SELECT ''ReservedWords''.''column'' AS "ReservedWords.column", + ''ReservedWords''.use AS "ReservedWords.use", + ''ReservedWords''.ceil AS "ReservedWords.ceil", + ''ReservedWords''.''commit'' AS "ReservedWords.commit", + ''ReservedWords''.''create'' AS "ReservedWords.create", + ''ReservedWords''.''default'' AS "ReservedWords.default", + ''ReservedWords''.''desc'' AS "ReservedWords.desc", + ''ReservedWords''.empty AS "ReservedWords.empty", + ''ReservedWords''.float AS "ReservedWords.float", + ''ReservedWords''.''join'' AS "ReservedWords.join", + ''ReservedWords''.''like'' AS "ReservedWords.like", + ''ReservedWords''.max AS "ReservedWords.max", + ''ReservedWords''.rank AS "ReservedWords.rank" +FROM ''ReservedWords''; +`, "''", "`", -1)) + + var dest model.ReservedWords + err := stmt.Query(sampleDB, &dest) + require.NoError(t, err) + require.Equal(t, dest, model.ReservedWords{ + Column: "Column", + Use: "CHECK", + Ceil: "CEIL", + Commit: "COMMIT", + Create: "CREATE", + Default: "DEFAULT", + Desc: "DESC", + Empty: "EMPTY", + Float: "FLOAT", + Join: "JOIN", + Like: "LIKE", + Max: "MAX", + Rank: "RANK", + }) +} + +func TestExactDecimals(t *testing.T) { + + type exactDecimals struct { + model.ExactDecimals + Decimal decimal.Decimal + DecimalPtr decimal.Decimal + } + + t.Run("should query decimal", func(t *testing.T) { + query := SELECT( + ExactDecimals.AllColumns, + ).FROM( + ExactDecimals, + ).WHERE(ExactDecimals.Decimal.EQ(String("1.11111111111111111111"))) + + var result exactDecimals + + err := query.Query(sampleDB, &result) + require.NoError(t, err) + + require.Equal(t, "1.11111111111111111111", result.Decimal.String()) + require.Equal(t, "0", result.DecimalPtr.String()) // NULL + + require.Equal(t, "1.11111111111111111111", result.ExactDecimals.Decimal) // precision loss + require.Equal(t, (*string)(nil), result.ExactDecimals.DecimalPtr) + require.Equal(t, "2.22222222222222222222", result.ExactDecimals.Numeric) + require.Equal(t, (*string)(nil), result.ExactDecimals.NumericPtr) // NULL + }) + + t.Run("should insert decimal", func(t *testing.T) { + + insertQuery := ExactDecimals.INSERT( + ExactDecimals.AllColumns, + ).MODEL( + exactDecimals{ + ExactDecimals: model.ExactDecimals{ + // overwritten by wrapped(exactDecimals) scope + Decimal: "0.1", + DecimalPtr: nil, + + // not overwritten + Numeric: "6.7", + NumericPtr: testutils.StringPtr("7.7"), + }, + Decimal: decimal.RequireFromString("91.23"), + DecimalPtr: decimal.RequireFromString("45.67"), + }, + ).RETURNING(ExactDecimals.AllColumns) + + testutils.AssertDebugStatementSql(t, insertQuery, strings.Replace(` +INSERT INTO exact_decimals (decimal, decimal_ptr, numeric, numeric_ptr) +VALUES ('91.23', '45.67', '6.7', '7.7') +RETURNING exact_decimals.decimal AS "exact_decimals.decimal", + exact_decimals.decimal_ptr AS "exact_decimals.decimal_ptr", + exact_decimals.numeric AS "exact_decimals.numeric", + exact_decimals.numeric_ptr AS "exact_decimals.numeric_ptr"; +`, "''", "`", -1)) + + tx := beginSampleDBTx(t) + defer tx.Rollback() + + var result exactDecimals + + err := insertQuery.Query(tx, &result) + require.NoError(t, err) + + require.Equal(t, "91.23", result.Decimal.String()) + require.Equal(t, "45.67", result.DecimalPtr.String()) + + require.Equal(t, "6.7", result.ExactDecimals.Numeric) + require.Equal(t, "7.7", *result.ExactDecimals.NumericPtr) + require.Equal(t, "91.23", result.ExactDecimals.Decimal) + require.Equal(t, "45.67", *result.ExactDecimals.DecimalPtr) + }) +} + +var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC) + +func TestDateExpressions(t *testing.T) { + + query := AllTypes.SELECT( + //Date(2009, 11, 17, 2, MONTH, 1, DAY), + + //DateT(timeT, START_OF_THE_MONTH), + AllTypes.Date.AS("date"), + DATE("2009-11-17").AS("date1"), + DATE("2013-10-07 08:23:19.120", DAYS(1)).AS("date2"), + DATE(AllTypes.Date, START_OF_YEAR, DAYS(2)).AS("date3"), + DATE(timeT, START_OF_MONTH).AS("date3"), + DATE("now", WEEKDAY(1)).AS("date4"), + DATE(timeT.Unix(), UNIXEPOCH).AS("date5"), + DATE(time.Now(), UTC).AS("date6"), + DATE(time.Now().UTC(), LOCALTIME).AS("date7"), + + AllTypes.Date.EQ(AllTypes.Date), + AllTypes.Date.EQ(Date(2019, 6, 6)), + + AllTypes.DatePtr.NOT_EQ(AllTypes.Date), + AllTypes.DatePtr.NOT_EQ(Date(2019, 1, 6)), + + AllTypes.Date.IS_DISTINCT_FROM(AllTypes.Date).AS("distinct1"), + AllTypes.Date.IS_DISTINCT_FROM(Date(2008, 7, 4)).AS("distinct2"), + + AllTypes.Date.IS_NOT_DISTINCT_FROM(AllTypes.Date), + AllTypes.Date.IS_NOT_DISTINCT_FROM(Date(2019, 3, 6)), + + AllTypes.Date.LT(AllTypes.Date), + AllTypes.Date.LT(Date(2019, 4, 6)), + + AllTypes.Date.LT_EQ(AllTypes.Date), + AllTypes.Date.LT_EQ(Date(2019, 5, 5)), + + AllTypes.Date.GT(AllTypes.Date), + AllTypes.Date.GT(Date(2019, 1, 4)), + + AllTypes.Date.GT_EQ(AllTypes.Date), + AllTypes.Date.GT_EQ(Date(2019, 2, 3)), + + //AllTypes.Date.ADD(INTERVAL2(2, HOUR)), + //AllTypes.Date.ADD(INTERVAL2(1, DAY, 7, MONTH)), + //AllTypes.Date.ADD(INTERVALd(25 * time.Hour + 100 * time.Millisecond)), + //AllTypes.Date.ADD(INTERVALd(-25 * time.Hour - 100 * time.Millisecond)), + // + //AllTypes.Date.SUB(INTERVAL(20, MINUTE)), + //AllTypes.Date.SUB(INTERVALe(AllTypes.SmallInt, MINUTE)), + //AllTypes.Date.SUB(INTERVALd(3*time.Minute)), + + CURRENT_DATE().AS("current_date"), + ) + + var dest struct { + Date string + Date1 time.Time + Date2 string + Date3 time.Time + Date4 string + Date5 time.Time + Date6 string + Date7 time.Time + Distinct1 bool + Distinct2 bool + CurrentDate time.Time + } + err := query.Query(sampleDB, &dest) + require.NoError(t, err) + + require.Equal(t, dest.Date, "2008-07-04T00:00:00Z") + require.Equal(t, dest.Date1.Unix(), int64(1258416000)) +} + +func TestTimeExpressions(t *testing.T) { + + query := AllTypes.SELECT( + TIME(AllTypes.Time).AS("time1"), + TIME(timeT).AS("time2"), + TIME("04:23:19.120-04:00", HOURS(1), MINUTES(2), SECONDS(1.234)).AS("time3"), + TIME(timeT.Unix(), UNIXEPOCH).AS("time4"), + TIME(time.Now(), UTC).AS("time5"), + TIME(time.Now().UTC(), LOCALTIME).AS("time6"), + + Time(timeT.Clock()), + + AllTypes.Time.EQ(AllTypes.Time), + AllTypes.Time.EQ(Time(23, 6, 6)), + AllTypes.Time.EQ(Time(22, 6, 6, 11*time.Millisecond)), + AllTypes.Time.EQ(Time(21, 6, 6, 11111*time.Microsecond)), + + AllTypes.TimePtr.NOT_EQ(AllTypes.Time), + AllTypes.TimePtr.NOT_EQ(Time(20, 16, 6)), + + AllTypes.Time.IS_DISTINCT_FROM(AllTypes.Time), + AllTypes.Time.IS_DISTINCT_FROM(Time(19, 26, 6)), + + AllTypes.Time.IS_NOT_DISTINCT_FROM(AllTypes.Time), + AllTypes.Time.IS_NOT_DISTINCT_FROM(Time(18, 36, 6)), + + AllTypes.Time.LT(AllTypes.Time), + AllTypes.Time.LT(Time(17, 46, 6)), + + AllTypes.Time.LT_EQ(AllTypes.Time), + AllTypes.Time.LT_EQ(Time(16, 56, 56)), + + AllTypes.Time.GT(AllTypes.Time), + AllTypes.Time.GT(Time(15, 16, 46)), + + AllTypes.Time.GT_EQ(AllTypes.Time), + AllTypes.Time.GT_EQ(Time(14, 26, 36)), + + //AllTypes.Time.ADD(INTERVAL(10, MINUTE)), + //AllTypes.Time.ADD(INTERVALe(AllTypes.Integer, MINUTE)), + //AllTypes.Time.ADD(INTERVALd(3*time.Hour)), + // + //AllTypes.Time.SUB(INTERVAL(20, MINUTE)), + //AllTypes.Time.SUB(INTERVALe(AllTypes.SmallInt, MINUTE)), + //AllTypes.Time.SUB(INTERVALd(3*time.Minute)), + // + //AllTypes.Time.ADD(INTERVAL(20, MINUTE)).SUB(INTERVAL(11, HOUR)), + + CURRENT_TIME(), + ) + + var dest struct { + Time1 string + Time2 time.Time + Time3 string + Time4 time.Time + Time5 string + Time6 time.Time + } + err := query.Query(sampleDB, &dest) + require.NoError(t, err) + + require.Equal(t, dest.Time1, "10:11:12") + require.Equal(t, dest.Time2.UTC().String(), "0000-01-01 20:34:58 +0000 UTC") + require.Equal(t, dest.Time3, "09:25:20") +} + +func TestDateTimeExpressions(t *testing.T) { + + var dateTime = DateTime(2019, 6, 6, 10, 2, 46) + + query := SELECT( + DATETIME("now").AS("now"), + DATETIME("2013-10-07T08:23:19.120Z", YEARS(2), MONTHS(1), DAYS(1)).AS("datetime1"), + DATETIME(AllTypes.DateTime, MONTHS(1), DAYS(1)).AS("datetime2"), + DATETIME(timeT.Unix(), UNIXEPOCH).AS("datetime3"), + DATETIME(time.Now(), UTC).AS("datetime4"), + DATETIME(timeT.UTC(), LOCALTIME).AS("datetime5"), + + JULIANDAY(timeT, DAYS(1)).AS("JulianDay"), + STRFTIME(String("%H:%M"), timeT, SECONDS(1.22)).AS("strftime"), + + AllTypes.DateTime.EQ(AllTypes.DateTime), + AllTypes.DateTime.EQ(dateTime), + + AllTypes.DateTimePtr.NOT_EQ(AllTypes.DateTime), + AllTypes.DateTimePtr.NOT_EQ(DateTime(2019, 6, 6, 10, 2, 46, 100*time.Millisecond)), + + AllTypes.DateTime.IS_DISTINCT_FROM(AllTypes.DateTime), + AllTypes.DateTime.IS_DISTINCT_FROM(dateTime), + + AllTypes.DateTime.IS_NOT_DISTINCT_FROM(AllTypes.DateTime), + AllTypes.DateTime.IS_NOT_DISTINCT_FROM(dateTime), + + AllTypes.DateTime.LT(AllTypes.DateTime), + AllTypes.DateTime.LT(dateTime), + + AllTypes.DateTime.LT_EQ(AllTypes.DateTime), + AllTypes.DateTime.LT_EQ(dateTime), + + AllTypes.DateTime.GT(AllTypes.DateTime), + AllTypes.DateTime.GT(dateTime), + + AllTypes.DateTime.GT_EQ(AllTypes.DateTime), + AllTypes.DateTime.GT_EQ(dateTime), + + //AllTypes.DateTime.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), + //AllTypes.DateTime.ADD(INTERVALe(AllTypes.BigInt, HOUR)), + //AllTypes.DateTime.ADD(INTERVALd(2*time.Hour)), + // + //AllTypes.DateTime.SUB(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), + //AllTypes.DateTime.SUB(INTERVALe(AllTypes.IntegerPtr, HOUR)), + //AllTypes.DateTime.SUB(INTERVALd(3*time.Hour)), + + CURRENT_TIMESTAMP(), + ).FROM(AllTypes) + + var dest struct { + Now time.Time + DateTime1 time.Time + DateTime2 time.Time + DateTime3 time.Time + DateTime4 time.Time + DateTime5 time.Time + JulianDay float64 + StrfTime string + } + + err := query.Query(sampleDB, &dest) + require.NoError(t, err) + require.True(t, dest.Now.After(time.Now().Add(-1*time.Minute))) + require.Equal(t, dest.DateTime1.String(), "2015-11-08 08:23:19 +0000 UTC") + require.Equal(t, dest.DateTime2.String(), "2012-01-19 13:17:17 +0000 UTC") + require.Equal(t, dest.DateTime3.String(), "2009-11-17 20:34:58 +0000 UTC") + require.True(t, dest.DateTime4.After(time.Now().Add(-1*time.Minute))) + require.Equal(t, dest.JulianDay, 2.4551543576232754e+06) + require.Equal(t, dest.StrfTime, "20:34") +} diff --git a/tests/sqlite/cast_test.go b/tests/sqlite/cast_test.go new file mode 100644 index 00000000..a20a60c8 --- /dev/null +++ b/tests/sqlite/cast_test.go @@ -0,0 +1,41 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/stretchr/testify/require" + "testing" +) + +func TestCast(t *testing.T) { + query := SELECT( + CAST(String("test")).AS("CHARACTER").AS("result.AS1"), + CAST(Float(11.33)).AS_TEXT().AS("result.text"), + CAST(String("33.44")).AS_REAL().AS("result.real"), + CAST(String("33")).AS_INTEGER().AS("result.integer"), + CAST(String("Blob blob")).AS_BLOB().AS("result.blob"), + ) + + type Result struct { + As1 string + Text string + Real float64 + Integer int64 + Blob []byte + } + + var dest Result + + err := query.Query(db, &dest) + require.NoError(t, err) + + testutils.AssertDeepEqual(t, dest, Result{ + As1: "test", + Text: "11.33", + Real: 33.44, + Integer: 33, + Blob: []byte("Blob blob"), + }) + + requireLogged(t, query) +} diff --git a/tests/sqlite/delete_test.go b/tests/sqlite/delete_test.go new file mode 100644 index 00000000..7045772b --- /dev/null +++ b/tests/sqlite/delete_test.go @@ -0,0 +1,83 @@ +package sqlite + +import ( + "context" + "testing" + "time" + + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table" + "github.com/stretchr/testify/require" +) + +func TestDelete_WHERE_RETURNING(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + var expectedSQL = ` +DELETE FROM link +WHERE link.name IN ('Bing', 'Yahoo') +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +` + deleteStmt := Link.DELETE(). + WHERE(Link.Name.IN(String("Bing"), String("Yahoo"))). + RETURNING(Link.AllColumns) + + testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Bing", "Yahoo") + var dest []model.Link + err := deleteStmt.Query(tx, &dest) + require.NoError(t, err) + require.Len(t, dest, 2) + requireLogged(t, deleteStmt) +} + +func TestDeleteWithWhereOrderByLimit(t *testing.T) { + t.SkipNow() // Until https://github.com/mattn/go-sqlite3/pull/802 is fixed + tx := beginSampleDBTx(t) + defer tx.Rollback() + + sampleDB.Stats() + + var expectedSQL = ` +DELETE FROM link +WHERE link.name IN ('Bing', 'Yahoo') +ORDER BY link.name +LIMIT 1; +` + deleteStmt := Link.DELETE(). + WHERE(Link.Name.IN(String("Bing"), String("Yahoo"))). + ORDER_BY(Link.Name). + LIMIT(1) + + testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Bing", "Yahoo", int64(1)) + testutils.AssertExec(t, deleteStmt, tx, 1) + requireLogged(t, deleteStmt) +} + +func TestDeleteContextDeadlineExceeded(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + deleteStmt := Link. + DELETE(). + WHERE(Link.Name.IN(String("Bing"), String("Yahoo"))) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) + defer cancel() + + time.Sleep(10 * time.Millisecond) + + dest := []model.Link{} + err := deleteStmt.QueryContext(ctx, tx, &dest) + require.Error(t, err, "context deadline exceeded") + + _, err = deleteStmt.ExecContext(ctx, tx) + require.Error(t, err, "context deadline exceeded") + + requireLogged(t, deleteStmt) +} diff --git a/tests/sqlite/generator_test.go b/tests/sqlite/generator_test.go new file mode 100644 index 00000000..ac7ab5d3 --- /dev/null +++ b/tests/sqlite/generator_test.go @@ -0,0 +1,298 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/generator/sqlite" + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" + "github.com/go-jet/jet/v2/tests/internal/utils/repo" + "github.com/stretchr/testify/require" + "io/ioutil" + "os" + "os/exec" + "reflect" + "testing" +) + +func TestGeneratedModel(t *testing.T) { + actor := model.Actor{} + + require.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32") + actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID") + require.True(t, ok) + require.Equal(t, actorIDField.Tag.Get("sql"), "primary_key") + require.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string") + require.Equal(t, reflect.TypeOf(actor.LastName).String(), "string") + require.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time") + + filmActor := model.FilmActor{} + + require.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int32") + filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID") + require.True(t, ok) + require.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") + + require.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int32") + actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID") + require.True(t, ok) + require.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") + + staff := model.Staff{} + + require.Equal(t, reflect.TypeOf(staff.Email).String(), "*string") + require.Equal(t, reflect.TypeOf(staff.Picture).String(), "*[]uint8") +} + +var testDatabaseFilePath = repo.GetTestDataFilePath("/init/sqlite/sakila.db") +var genDestDir = repo.GetTestsFilePath("/sqlite/.gen") + +func TestGenerator(t *testing.T) { + for i := 0; i < 3; i++ { + err := sqlite.GenerateDSN(testDatabaseFilePath, genDestDir) + require.NoError(t, err) + + assertGeneratedFiles(t) + } + + err := os.RemoveAll(genDestDir) + require.NoError(t, err) +} + +func TestCmdGenerator(t *testing.T) { + cmd := exec.Command("jet", "-source=SQLite", "-dsn=file://"+testDatabaseFilePath, "-path="+genDestDir) + + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + err := cmd.Run() + require.NoError(t, err) + + assertGeneratedFiles(t) + + err = os.RemoveAll(genDestDir) + require.NoError(t, err) +} + +func assertGeneratedFiles(t *testing.T) { + // Table SQL Builder files + tableSQLBuilderFiles, err := ioutil.ReadDir(genDestDir + "/table") + require.NoError(t, err) + + testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", + "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", + "payment.go", "rental.go", "staff.go", "store.go") + + testutils.AssertFileContent(t, genDestDir+"/table/actor.go", actorSQLBuilderFile) + + // View SQL Builder files + viewSQLBuilderFiles, err := ioutil.ReadDir(genDestDir + "/view") + require.NoError(t, err) + + testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "film_list.go", "sales_by_film_category.go", + "customer_list.go", "sales_by_store.go", "staff_list.go") + + testutils.AssertFileContent(t, genDestDir+"/view/film_list.go", filmListSQLBuilderFile) + + // Model files + modelFiles, err := ioutil.ReadDir(genDestDir + "/model") + require.NoError(t, err) + + testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", + "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", + "payment.go", "rental.go", "staff.go", "store.go", + "film_list.go", "sales_by_film_category.go", + "customer_list.go", "sales_by_store.go", "staff_list.go") + + testutils.AssertFileContent(t, genDestDir+"/model/address.go", addressModelFile) +} + +const actorSQLBuilderFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package table + +import ( + "github.com/go-jet/jet/v2/sqlite" +) + +var Actor = newActorTable("", "actor", "") + +type actorTable struct { + sqlite.Table + + //Columns + ActorID sqlite.ColumnInteger + FirstName sqlite.ColumnString + LastName sqlite.ColumnString + LastUpdate sqlite.ColumnTimestamp + + AllColumns sqlite.ColumnList + MutableColumns sqlite.ColumnList +} + +type ActorTable struct { + actorTable + + EXCLUDED actorTable +} + +// AS creates new ActorTable with assigned alias +func (a ActorTable) AS(alias string) *ActorTable { + return newActorTable(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new ActorTable with assigned schema name +func (a ActorTable) FromSchema(schemaName string) *ActorTable { + return newActorTable(schemaName, a.TableName(), a.Alias()) +} + +func newActorTable(schemaName, tableName, alias string) *ActorTable { + return &ActorTable{ + actorTable: newActorTableImpl(schemaName, tableName, alias), + EXCLUDED: newActorTableImpl("", "excluded", ""), + } +} + +func newActorTableImpl(schemaName, tableName, alias string) actorTable { + var ( + ActorIDColumn = sqlite.IntegerColumn("actor_id") + FirstNameColumn = sqlite.StringColumn("first_name") + LastNameColumn = sqlite.StringColumn("last_name") + LastUpdateColumn = sqlite.TimestampColumn("last_update") + allColumns = sqlite.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn} + mutableColumns = sqlite.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn} + ) + + return actorTable{ + Table: sqlite.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns + ActorID: ActorIDColumn, + FirstName: FirstNameColumn, + LastName: LastNameColumn, + LastUpdate: LastUpdateColumn, + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} +` + +const filmListSQLBuilderFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package view + +import ( + "github.com/go-jet/jet/v2/sqlite" +) + +var FilmList = newFilmListTable("", "film_list", "") + +type filmListTable struct { + sqlite.Table + + //Columns + Fid sqlite.ColumnInteger + Title sqlite.ColumnString + Description sqlite.ColumnString + Category sqlite.ColumnString + Price sqlite.ColumnFloat + Length sqlite.ColumnInteger + Rating sqlite.ColumnString + Actors sqlite.ColumnString + + AllColumns sqlite.ColumnList + MutableColumns sqlite.ColumnList +} + +type FilmListTable struct { + filmListTable + + EXCLUDED filmListTable +} + +// AS creates new FilmListTable with assigned alias +func (a FilmListTable) AS(alias string) *FilmListTable { + return newFilmListTable(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new FilmListTable with assigned schema name +func (a FilmListTable) FromSchema(schemaName string) *FilmListTable { + return newFilmListTable(schemaName, a.TableName(), a.Alias()) +} + +func newFilmListTable(schemaName, tableName, alias string) *FilmListTable { + return &FilmListTable{ + filmListTable: newFilmListTableImpl(schemaName, tableName, alias), + EXCLUDED: newFilmListTableImpl("", "excluded", ""), + } +} + +func newFilmListTableImpl(schemaName, tableName, alias string) filmListTable { + var ( + FidColumn = sqlite.IntegerColumn("FID") + TitleColumn = sqlite.StringColumn("title") + DescriptionColumn = sqlite.StringColumn("description") + CategoryColumn = sqlite.StringColumn("category") + PriceColumn = sqlite.FloatColumn("price") + LengthColumn = sqlite.IntegerColumn("length") + RatingColumn = sqlite.StringColumn("rating") + ActorsColumn = sqlite.StringColumn("actors") + allColumns = sqlite.ColumnList{FidColumn, TitleColumn, DescriptionColumn, CategoryColumn, PriceColumn, LengthColumn, RatingColumn, ActorsColumn} + mutableColumns = sqlite.ColumnList{FidColumn, TitleColumn, DescriptionColumn, CategoryColumn, PriceColumn, LengthColumn, RatingColumn, ActorsColumn} + ) + + return filmListTable{ + Table: sqlite.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns + Fid: FidColumn, + Title: TitleColumn, + Description: DescriptionColumn, + Category: CategoryColumn, + Price: PriceColumn, + Length: LengthColumn, + Rating: RatingColumn, + Actors: ActorsColumn, + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} +` + +const addressModelFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package model + +import ( + "time" +) + +type Address struct { + AddressID int32 ` + "`sql:\"primary_key\"`" + ` + Address string + Address2 *string + District string + CityID int32 + PostalCode *string + Phone string + LastUpdate time.Time +} +` diff --git a/tests/sqlite/insert_test.go b/tests/sqlite/insert_test.go new file mode 100644 index 00000000..f5939bb2 --- /dev/null +++ b/tests/sqlite/insert_test.go @@ -0,0 +1,393 @@ +package sqlite + +import ( + "context" + "math/rand" + + "testing" + "time" + + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table" + "github.com/stretchr/testify/require" +) + +func TestInsertValues(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). + VALUES(101, "http://www.google.com", "Google", "Search engine"). + VALUES(102, "http://www.yahoo.com", "Yahoo", nil) + + testutils.AssertStatementSql(t, insertQuery, ` +INSERT INTO link (id, url, name, description) +VALUES (?, ?, ?, ?), + (?, ?, ?, ?), + (?, ?, ?, ?); +`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil, + 101, "http://www.google.com", "Google", "Search engine", + 102, "http://www.yahoo.com", "Yahoo", nil) + + _, err := insertQuery.Exec(tx) + require.NoError(t, err) + requireLogged(t, insertQuery) + + insertedLinks := []model.Link{} + + err = SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.ID.GT_EQ(Int(100))). + ORDER_BY(Link.ID). + Query(tx, &insertedLinks) + + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 3) + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ + ID: 101, + URL: "http://www.google.com", + Name: "Google", + Description: testutils.StringPtr("Search engine"), + }) + testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ + ID: 102, + URL: "http://www.yahoo.com", + Name: "Yahoo", + }) +} + +var postgreTutorial = model.Link{ + ID: 100, + URL: "http://www.postgresqltutorial.com", + Name: "PostgreSQL Tutorial", +} + +func TestInsertEmptyColumnList(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + expectedSQL := ` +INSERT INTO link +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL); +` + + stmt := Link.INSERT(). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil) + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil) + + _, err := stmt.Exec(tx) + require.NoError(t, err) + requireLogged(t, stmt) + + insertedLinks := []model.Link{} + + err = SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.ID.GT_EQ(Int(100))). + ORDER_BY(Link.ID). + Query(tx, &insertedLinks) + + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 1) + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) +} + +func TestInsertModelObject(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + linkData := model.Link{ + URL: "http://www.duckduckgo.com", + Name: "Duck Duck go", + } + + query := Link.INSERT(Link.URL, Link.Name). + MODEL(linkData) + + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO link (url, name) +VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); +`, "http://www.duckduckgo.com", "Duck Duck go") + + _, err := query.Exec(tx) + require.NoError(t, err) +} + +func TestInsertModelObjectEmptyColumnList(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + var expectedSQL = ` +INSERT INTO link +VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); +` + + linkData := model.Link{ + ID: 1000, + URL: "http://www.duckduckgo.com", + Name: "Duck Duck go", + } + + query := Link. + INSERT(). + MODEL(linkData) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) + + _, err := query.Exec(tx) + require.NoError(t, err) +} + +func TestInsertModelsObject(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + expectedSQL := ` +INSERT INTO link (url, name) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), + ('http://www.google.com', 'Google'), + ('http://www.yahoo.com', 'Yahoo'); +` + + tutorial := model.Link{ + URL: "http://www.postgresqltutorial.com", + Name: "PostgreSQL Tutorial", + } + google := model.Link{ + URL: "http://www.google.com", + Name: "Google", + } + yahoo := model.Link{ + URL: "http://www.yahoo.com", + Name: "Yahoo", + } + + query := Link. + INSERT(Link.URL, Link.Name). + MODELS([]model.Link{ + tutorial, + google, + yahoo, + }) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, + "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", + "http://www.google.com", "Google", + "http://www.yahoo.com", "Yahoo") + + _, err := query.Exec(tx) + require.NoError(t, err) +} + +func TestInsertUsingMutableColumns(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + var expectedSQL = ` +INSERT INTO link (url, name, description) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.yahoo.com', 'Yahoo', NULL); +` + + google := model.Link{ + URL: "http://www.google.com", + Name: "Google", + } + + yahoo := model.Link{ + URL: "http://www.yahoo.com", + Name: "Yahoo", + } + + stmt := Link. + INSERT(Link.MutableColumns). + VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). + MODEL(google). + MODELS([]model.Link{google, yahoo}) + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil, + "http://www.google.com", "Google", nil, + "http://www.google.com", "Google", nil, + "http://www.yahoo.com", "Yahoo", nil) + + _, err := stmt.Exec(tx) + require.NoError(t, err) +} + +func TestInsertQuery(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + var expectedSQL = ` +INSERT INTO link (url, name) +SELECT link.url AS "link.url", + link.name AS "link.name" +FROM link +WHERE link.id = 24; +` + query := Link.INSERT(Link.URL, Link.Name). + QUERY( + SELECT(Link.URL, Link.Name). + FROM(Link). + WHERE(Link.ID.EQ(Int(24))), + ) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(24)) + + _, err := query.Exec(tx) + require.NoError(t, err) + + youtubeLinks := []model.Link{} + err = Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Bing"))). + Query(tx, &youtubeLinks) + + require.NoError(t, err) + require.Equal(t, len(youtubeLinks), 2) +} + +func TestInsert_DEFAULT_VALUES_RETURNING(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + stmt := Link.INSERT(). + DEFAULT_VALUES(). + RETURNING(Link.AllColumns) + + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO link +DEFAULT VALUES +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +`) + + var link model.Link + err := stmt.Query(tx, &link) + require.NoError(t, err) + + require.EqualValues(t, link, model.Link{ + ID: 25, + URL: "www.", + Name: "_", + Description: nil, + }) +} + +func TestInsertOnConflict(t *testing.T) { + + t.Run("do nothing", func(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + link := model.Link{ID: rand.Int31()} + + stmt := Link.INSERT(Link.AllColumns). + MODEL(link). + MODEL(link). + ON_CONFLICT(Link.ID).DO_NOTHING() + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO link (id, url, name, description) +VALUES (?, ?, ?, ?), + (?, ?, ?, ?) +ON CONFLICT (id) DO NOTHING; +`) + testutils.AssertExec(t, stmt, tx, 1) + requireLogged(t, stmt) + }) + + t.Run("do update", func(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). + VALUES(22, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). + ON_CONFLICT(Link.ID). + DO_UPDATE( + SET( + Link.ID.SET(Link.EXCLUDED.ID), + Link.URL.SET(String("http://www.postgresqltutorial2.com")), + ), + ).RETURNING(Link.AllColumns) + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO link (id, url, name, description) +VALUES (?, ?, ?, ?), + (?, ?, ?, ?) +ON CONFLICT (id) DO UPDATE + SET id = excluded.id, + url = ? +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +`) + + testutils.AssertExec(t, stmt, tx) + requireLogged(t, stmt) + }) + + t.Run("do update complex", func(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). + ON_CONFLICT(Link.ID). + WHERE(Link.ID.MUL(Int(2)).GT(Int(10))). + DO_UPDATE( + SET( + Link.ID.SET( + IntExp(SELECT(MAXi(Link.ID).ADD(Int(1))). + FROM(Link)), + ), + ColumnList{Link.Name, Link.Description}.SET(ROW(Link.EXCLUDED.Name, String(""))), + ).WHERE(Link.Description.IS_NOT_NULL()), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO link (id, url, name, description) +VALUES (21, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL) +ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE + SET id = ( + SELECT MAX(link.id) + 1 + FROM link + ), + (name, description) = (excluded.name, '') + WHERE link.description IS NOT NULL; +`) + + testutils.AssertExec(t, stmt, tx) + requireLogged(t, stmt) + }) +} + +func TestInsertContextDeadlineExceeded(t *testing.T) { + stmt := Link.INSERT(). + VALUES(1100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) + defer cancel() + + time.Sleep(10 * time.Millisecond) + + dest := []model.Link{} + err := stmt.QueryContext(ctx, sampleDB, &dest) + require.Error(t, err, "context deadline exceeded") + + _, err = stmt.ExecContext(ctx, db) + require.Error(t, err, "context deadline exceeded") +} diff --git a/tests/sqlite/main_test.go b/tests/sqlite/main_test.go new file mode 100644 index 00000000..710f7ad5 --- /dev/null +++ b/tests/sqlite/main_test.go @@ -0,0 +1,90 @@ +package sqlite + +import ( + "context" + "database/sql" + "fmt" + "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/dbconfig" + "github.com/stretchr/testify/require" + "math/rand" + "os" + "os/exec" + "strings" + "testing" + "time" + + "github.com/pkg/profile" + + _ "github.com/mattn/go-sqlite3" +) + +var db *sql.DB +var sampleDB *sql.DB +var testRoot string + +func TestMain(m *testing.M) { + rand.Seed(time.Now().Unix()) + defer profile.Start().Stop() + + setTestRoot() + + var err error + db, err = sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath) + throw.OnError(err) + + _, err = db.Exec(fmt.Sprintf("ATTACH DATABASE '%s' as 'chinook';", dbconfig.ChinookDBPath)) + throw.OnError(err) + + sampleDB, err = sql.Open("sqlite3", dbconfig.TestSampleDBPath) + throw.OnError(err) + + defer db.Close() + + ret := m.Run() + + if ret != 0 { + os.Exit(ret) + } +} + +func setTestRoot() { + cmd := exec.Command("git", "rev-parse", "--show-toplevel") + byteArr, err := cmd.Output() + if err != nil { + panic(err) + } + + testRoot = strings.TrimSpace(string(byteArr)) + "/tests/" +} + +var loggedSQL string +var loggedSQLArgs []interface{} +var loggedDebugSQL string + +func init() { + sqlite.SetLogger(func(ctx context.Context, statement sqlite.PrintableStatement) { + loggedSQL, loggedSQLArgs = statement.Sql() + loggedDebugSQL = statement.DebugSql() + }) +} + +func requireLogged(t *testing.T, statement sqlite.Statement) { + query, args := statement.Sql() + require.Equal(t, loggedSQL, query) + require.Equal(t, loggedSQLArgs, args) + require.Equal(t, loggedDebugSQL, statement.DebugSql()) +} + +func beginSampleDBTx(t *testing.T) *sql.Tx { + tx, err := sampleDB.Begin() + require.NoError(t, err) + return tx +} + +func beginDBTx(t *testing.T) *sql.Tx { + tx, err := db.Begin() + require.NoError(t, err) + return tx +} diff --git a/tests/sqlite/raw_statement_test.go b/tests/sqlite/raw_statement_test.go new file mode 100644 index 00000000..974dfda1 --- /dev/null +++ b/tests/sqlite/raw_statement_test.go @@ -0,0 +1,121 @@ +package sqlite + +import ( + "context" + "testing" + "time" + + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" + "github.com/stretchr/testify/require" +) + +func TestRawStatementSelect(t *testing.T) { + stmt := RawStatement(` + SELECT actor.first_name AS "actor.first_name" + FROM actor + WHERE actor.actor_id = 2`) + + testutils.AssertStatementSql(t, stmt, ` + SELECT actor.first_name AS "actor.first_name" + FROM actor + WHERE actor.actor_id = 2; +`) + testutils.AssertDebugStatementSql(t, stmt, ` + SELECT actor.first_name AS "actor.first_name" + FROM actor + WHERE actor.actor_id = 2; +`) + var actor model.Actor + err := stmt.Query(db, &actor) + require.NoError(t, err) + require.Equal(t, actor.FirstName, "NICK") +} + +func TestRawStatementSelectWithArguments(t *testing.T) { + stmt := RawStatement(` + SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM actor + WHERE actor.actor_id IN (#actorID1, #actorID2, #actorID3) AND ((#actorID1 / #actorID2) <> (#actorID2 * #actorID3)) + ORDER BY actor.actor_id`, + RawArgs{ + "#actorID1": int64(1), + "#actorID2": int64(2), + "#actorID3": int64(3), + }, + ) + + testutils.AssertStatementSql(t, stmt, ` + SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM actor + WHERE actor.actor_id IN (?, ?, ?) AND ((? / ?) <> (? * ?)) + ORDER BY actor.actor_id; +`, int64(1), int64(2), int64(3), int64(1), int64(2), int64(2), int64(3)) + + testutils.AssertDebugStatementSql(t, stmt, ` + SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM actor + WHERE actor.actor_id IN (1, 2, 3) AND ((1 / 2) <> (2 * 3)) + ORDER BY actor.actor_id; +`) + + var actor []model.Actor + err := stmt.Query(db, &actor) + require.NoError(t, err) + + testutils.AssertDeepEqual(t, actor[1], model.Actor{ + ActorID: 2, + FirstName: "NICK", + LastName: "WAHLBERG", + LastUpdate: *testutils.TimestampWithoutTimeZone("2019-04-11 18:11:48", 2), + }) +} + +func TestRawStatementRows(t *testing.T) { + stmt := RawStatement(` + SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM actor + ORDER BY actor.actor_id`) + + rows, err := stmt.Rows(context.Background(), db) + require.NoError(t, err) + + for rows.Next() { + var actor model.Actor + err := rows.Scan(&actor) + require.NoError(t, err) + + require.NotEqual(t, actor.ActorID, int16(0)) + require.NotEqual(t, actor.FirstName, "") + require.NotEqual(t, actor.LastName, "") + require.NotEqual(t, actor.LastUpdate, time.Time{}) + + if actor.ActorID == 54 { + require.Equal(t, actor.ActorID, int32(54)) + require.Equal(t, actor.FirstName, "PENELOPE") + require.Equal(t, actor.LastName, "PINKETT") + require.Equal(t, actor.LastUpdate.Format(time.RFC3339), "2019-04-11T18:11:48Z") + } + } + + err = rows.Close() + require.NoError(t, err) + + err = rows.Err() + require.NoError(t, err) + + requireLogged(t, stmt) +} diff --git a/tests/sqlite/select_test.go b/tests/sqlite/select_test.go new file mode 100644 index 00000000..95527f1e --- /dev/null +++ b/tests/sqlite/select_test.go @@ -0,0 +1,749 @@ +package sqlite + +import ( + "context" + model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/model" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/table" + "strings" + "testing" + "time" + + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/view" + + "github.com/stretchr/testify/require" +) + +func TestSelect_ScanToStruct(t *testing.T) { + query := Actor. + SELECT(Actor.AllColumns). + DISTINCT(). + WHERE(Actor.ActorID.EQ(Int(2))) + + testutils.AssertStatementSql(t, query, ` +SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM actor +WHERE actor.actor_id = ?; +`, int64(2)) + + actor := model.Actor{} + err := query.Query(db, &actor) + + require.NoError(t, err) + + testutils.AssertDeepEqual(t, actor, actor2) + requireLogged(t, query) +} + +var actor2 = model.Actor{ + ActorID: 2, + FirstName: "NICK", + LastName: "WAHLBERG", + LastUpdate: *testutils.TimestampWithoutTimeZone("2019-04-11 18:11:48", 2), +} + +func TestSelect_ScanToSlice(t *testing.T) { + query := SELECT(Actor.AllColumns). + FROM(Actor). + ORDER_BY(Actor.ActorID) + + testutils.AssertStatementSql(t, query, ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM actor +ORDER BY actor.actor_id; +`) + dest := []model.Actor{} + + err := query.Query(db, &dest) + + require.NoError(t, err) + + require.Equal(t, len(dest), 200) + testutils.AssertDeepEqual(t, dest[1], actor2) + + //testutils.SaveJSONFile(dest, "./testdata/results/sqlite/all_actors.json") + testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/all_actors.json") + requireLogged(t, query) +} + +func TestSelectGroupByHaving(t *testing.T) { + expectedSQL := ` +SELECT customer.customer_id AS "customer.customer_id", + customer.store_id AS "customer.store_id", + customer.first_name AS "customer.first_name", + customer.last_name AS "customer.last_name", + customer.email AS "customer.email", + customer.address_id AS "customer.address_id", + customer.active AS "customer.active", + customer.create_date AS "customer.create_date", + customer.last_update AS "customer.last_update", + SUM(payment.amount) AS "amount.sum", + AVG(payment.amount) AS "amount.avg", + MAX(payment.payment_date) AS "amount.max_date", + MAX(payment.amount) AS "amount.max", + MIN(payment.payment_date) AS "amount.min_date", + MIN(payment.amount) AS "amount.min", + COUNT(payment.amount) AS "amount.count" +FROM payment + INNER JOIN customer ON (customer.customer_id = payment.customer_id) +GROUP BY payment.customer_id +HAVING SUM(payment.amount) > 125.6 +ORDER BY payment.customer_id, SUM(payment.amount) ASC; +` + query := Payment. + INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)). + SELECT( + Customer.AllColumns, + + SUMf(Payment.Amount).AS("amount.sum"), + AVG(Payment.Amount).AS("amount.avg"), + MAX(Payment.PaymentDate).AS("amount.max_date"), + MAXf(Payment.Amount).AS("amount.max"), + MIN(Payment.PaymentDate).AS("amount.min_date"), + MINf(Payment.Amount).AS("amount.min"), + COUNT(Payment.Amount).AS("amount.count"), + ). + GROUP_BY(Payment.CustomerID). + HAVING( + SUMf(Payment.Amount).GT(Float(125.6)), + ). + ORDER_BY( + Payment.CustomerID, SUMf(Payment.Amount).ASC(), + ) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, float64(125.6)) + + var dest []struct { + model.Customer + + Amount struct { + Sum float64 + Avg float64 + Max float64 + Min float64 + Count int64 + } `alias:"amount"` + } + + err := query.Query(db, &dest) + + require.NoError(t, err) + require.Equal(t, len(dest), 174) + //testutils.SaveJSONFile(dest, "./testdata/results/sqlite/customer_payment_sum.json") + testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/customer_payment_sum.json") + requireLogged(t, query) +} + +func TestSubQuery(t *testing.T) { + + rRatingFilms := + SELECT( + Film.FilmID, + Film.Title, + Film.Rating, + ).FROM( + Film, + ).WHERE(Film.Rating.EQ(String("R"))). + AsTable("rFilms") + + rFilmID := Film.FilmID.From(rRatingFilms) + + main := + SELECT( + Actor.AllColumns, + FilmActor.AllColumns, + rRatingFilms.AllColumns(), + ).FROM( + rRatingFilms. + INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(rFilmID)). + INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID)), + ).ORDER_BY( + rFilmID, + Actor.ActorID, + ) + + var dest []struct { + model.Film + Actors []model.Actor + } + + err := main.Query(db, &dest) + require.NoError(t, err) + + //testutils.SaveJSONFile(dest, "./testdata/results/sqlite/r_rating_films.json") + testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/r_rating_films.json") +} + +func TestSelectAndUnionInProjection(t *testing.T) { + query := UNION( + SELECT( + Payment.PaymentID, + ).FROM(Payment), + + SELECT( + STAR, + ).FROM( + SELECT(Payment.PaymentID). + FROM(Payment).LIMIT(1).OFFSET(2).AsTable("p"), + ), + ).LIMIT(1).OFFSET(10) + + testutils.AssertDebugStatementSql(t, query, ` + +SELECT payment.payment_id AS "payment.payment_id" +FROM payment + +UNION + +SELECT * +FROM ( + SELECT payment.payment_id AS "payment.payment_id" + FROM payment + LIMIT 1 + OFFSET 2 + ) AS p +LIMIT 1 +OFFSET 10; +`, int64(1), int64(2), int64(1), int64(10)) + + dest := []struct{}{} + err := query.Query(db, &dest) + require.NoError(t, err) +} + +func TestSelectUNION(t *testing.T) { + expectedSQL := ` + +SELECT payment.payment_id AS "payment.payment_id" +FROM payment +WHERE payment.payment_id > ? + +UNION + +SELECT payment.payment_id AS "payment.payment_id" +FROM payment +WHERE payment.amount < ? +LIMIT ?; +` + query := UNION( + SELECT(Payment.PaymentID). + FROM(Payment). + WHERE(Payment.PaymentID.GT(Int(11))), + + SELECT(Payment.PaymentID). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(2000.0))), + ).LIMIT(1) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(11), 2000.0, int64(1)) + + query2 := + SELECT( + Payment.PaymentID, + ).FROM( + Payment, + ).WHERE( + Payment.PaymentID.GT(Int(11)), + ).UNION( + SELECT(Payment.PaymentID). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(2000.0))), + ).LIMIT(1) + + testutils.AssertStatementSql(t, query2, expectedSQL, int64(11), 2000.0, int64(1)) + + dest := []struct{}{} + err := query.Query(db, &dest) + require.NoError(t, err) +} + +func TestSelectUNION_ALL(t *testing.T) { + expectedSQL := ` + +SELECT payment.payment_id AS "payment.payment_id" +FROM payment +WHERE payment.payment_id > ? + +UNION ALL + +SELECT payment.payment_id AS "payment.payment_id" +FROM payment +WHERE payment.amount < ? +LIMIT ?; +` + query := UNION_ALL( + SELECT(Payment.PaymentID). + FROM(Payment). + WHERE(Payment.PaymentID.GT(Int(11))), + + SELECT(Payment.PaymentID). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(2000.0))), + ).LIMIT(1) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(11), 2000.0, int64(1)) + + dest := []struct{}{} + err := query.Query(db, &dest) + require.NoError(t, err) +} + +func TestJoinQueryStruct(t *testing.T) { + + expectedSQL := ` +SELECT film_actor.actor_id AS "film_actor.actor_id", + film_actor.film_id AS "film_actor.film_id", + film_actor.last_update AS "film_actor.last_update", + film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.original_language_id AS "film.original_language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.special_features AS "film.special_features", + film.last_update AS "film.last_update", + language.language_id AS "language.language_id", + language.name AS "language.name", + language.last_update AS "language.last_update", + actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update", + inventory.inventory_id AS "inventory.inventory_id", + inventory.film_id AS "inventory.film_id", + inventory.store_id AS "inventory.store_id", + inventory.last_update AS "inventory.last_update", + rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM language + INNER JOIN film ON (film.language_id = language.language_id) + INNER JOIN film_actor ON (film_actor.film_id = film.film_id) + INNER JOIN actor ON (actor.actor_id = film_actor.actor_id) + LEFT JOIN inventory ON (inventory.film_id = film.film_id) + LEFT JOIN rental ON (rental.inventory_id = inventory.inventory_id) +ORDER BY language.language_id ASC, film.film_id ASC, actor.actor_id ASC, inventory.inventory_id ASC, rental.rental_id ASC +LIMIT ?; +` + for i := 0; i < 2; i++ { + query := + SELECT( + FilmActor.AllColumns, + Film.AllColumns, + Language.AllColumns, + Actor.AllColumns, + Inventory.AllColumns, + Rental.AllColumns, + ). + FROM( + Language. + INNER_JOIN(Film, Film.LanguageID.EQ(Language.LanguageID)). + INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(Film.FilmID)). + INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID)). + LEFT_JOIN(Inventory, Inventory.FilmID.EQ(Film.FilmID)). + LEFT_JOIN(Rental, Rental.InventoryID.EQ(Inventory.InventoryID)), + ).ORDER_BY( + Language.LanguageID.ASC(), + Film.FilmID.ASC(), + Actor.ActorID.ASC(), + Inventory.InventoryID.ASC(), + Rental.RentalID.ASC(), + ). + LIMIT(1000) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(1000)) + + var dest []struct { + model.Language + + Films []struct { + model.Film + + Actors []struct { + model.Actor + } + + Inventories *[]struct { + model.Inventory + + Rentals *[]model.Rental + } + } + } + + err := query.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/lang_film_actor_inventory_rental.json") + } +} + +func TestExpressionWrappers(t *testing.T) { + query := SELECT( + BoolExp(Raw("true")), + IntExp(Raw("11")), + FloatExp(Raw("11.22")), + StringExp(Raw("'stringer'")), + TimeExp(Raw("'raw'")), + TimestampExp(Raw("'raw'")), + DateTimeExp(Raw("'raw'")), + DateExp(Raw("'date'")), + ) + + testutils.AssertStatementSql(t, query, ` +SELECT true, + 11, + 11.22, + 'stringer', + 'raw', + 'raw', + 'raw', + 'date'; +`) + + dest := []struct{}{} + err := query.Query(db, &dest) + require.NoError(t, err) +} + +func TestWindowFunction(t *testing.T) { + var expectedSQL = ` +SELECT AVG(payment.amount) OVER (), + AVG(payment.amount) OVER (PARTITION BY payment.customer_id), + MAX(payment.amount) OVER (ORDER BY payment.payment_date DESC), + MIN(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC ROWS BETWEEN 1 PRECEDING AND 6 FOLLOWING), + SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + MAX(payment.customer_id) OVER (ORDER BY payment.payment_date DESC ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING), + MIN(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + SUM(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + ROW_NUMBER() OVER (ORDER BY payment.payment_date), + RANK() OVER (ORDER BY payment.payment_date), + DENSE_RANK() OVER (ORDER BY payment.payment_date), + CUME_DIST() OVER (ORDER BY payment.payment_date), + NTILE(11) OVER (ORDER BY payment.payment_date), + LAG(payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date), + FIRST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date), + LAST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date), + NTH_VALUE(payment.amount, 3) OVER (ORDER BY payment.payment_date) +FROM payment +WHERE payment.payment_id < ? +GROUP BY payment.amount, payment.customer_id, payment.payment_date; +` + query := + SELECT( + AVG(Payment.Amount).OVER(), + AVG(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID)), + MAXf(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate.DESC())), + MINf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID). + ORDER_BY(Payment.PaymentDate.DESC()).ROWS(PRECEDING(1), FOLLOWING(6))), + SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID). + ORDER_BY(Payment.PaymentDate.DESC()).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), + MAXi(Payment.CustomerID).OVER(ORDER_BY(Payment.PaymentDate.DESC()).ROWS(CURRENT_ROW, FOLLOWING(UNBOUNDED))), + MINi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + SUMi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + ROW_NUMBER().OVER(ORDER_BY(Payment.PaymentDate)), + RANK().OVER(ORDER_BY(Payment.PaymentDate)), + DENSE_RANK().OVER(ORDER_BY(Payment.PaymentDate)), + CUME_DIST().OVER(ORDER_BY(Payment.PaymentDate)), + NTILE(11).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)), + FIRST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + NTH_VALUE(Payment.Amount, 3).OVER(ORDER_BY(Payment.PaymentDate)), + ).FROM( + Payment, + ).GROUP_BY( + Payment.Amount, + Payment.CustomerID, + Payment.PaymentDate, + ).WHERE(Payment.PaymentID.LT(Int(10))) + + testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) + + dest := []struct{}{} + err := query.Query(db, &dest) + require.NoError(t, err) +} + +func TestWindowClause(t *testing.T) { + var expectedSQL = ` +SELECT AVG(payment.amount) OVER (), + AVG(payment.amount) OVER (w1), + AVG(payment.amount) OVER (w2 ORDER BY payment.customer_id RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + AVG(payment.amount) OVER (w3 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +FROM payment +WHERE payment.payment_id < ? +WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id) +ORDER BY payment.customer_id; +` + query := SELECT( + AVG(Payment.Amount).OVER(), + AVG(Payment.Amount).OVER(Window("w1")), + AVG(Payment.Amount).OVER( + Window("w2"). + ORDER_BY(Payment.CustomerID). + RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)), + ), + AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), + ).FROM( + Payment, + ).WHERE( + Payment.PaymentID.LT(Int(10)), + ). + WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)). + WINDOW("w2").AS(Window("w1")). + WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)). + ORDER_BY( + Payment.CustomerID, + ) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) + + dest := []struct{}{} + err := query.Query(db, &dest) + + require.NoError(t, err) +} + +func TestSimpleView(t *testing.T) { + query := + SELECT( + view.CustomerList.AllColumns, + ).FROM( + view.CustomerList, + ).ORDER_BY( + view.CustomerList.ID, + ).LIMIT(10) + + var dest []model.CustomerList + + err := query.Query(db, &dest) + require.NoError(t, err) + + require.Equal(t, len(dest), 10) + require.Equal(t, dest[2], model.CustomerList{ + ID: testutils.Int32Ptr(3), + Name: testutils.StringPtr("LINDA WILLIAMS"), + Address: testutils.StringPtr("692 Joliet Street"), + ZipCode: testutils.StringPtr("83579"), + Phone: testutils.StringPtr(" "), + City: testutils.StringPtr("Athenai"), + Country: testutils.StringPtr("Greece"), + Notes: testutils.StringPtr("active"), + Sid: testutils.Int32Ptr(1), + }) +} + +func TestJoinViewWithTable(t *testing.T) { + query := + SELECT( + view.CustomerList.AllColumns, + Rental.AllColumns, + ).FROM( + view.CustomerList. + INNER_JOIN(Rental, view.CustomerList.ID.EQ(Rental.CustomerID)), + ).ORDER_BY( + view.CustomerList.ID, + ).WHERE( + view.CustomerList.ID.LT_EQ(Int(2)), + ) + + var dest []struct { + model.CustomerList `sql:"primary_key=ID"` + Rentals []model.Rental + } + + err := query.Query(db, &dest) + require.NoError(t, err) + + require.Equal(t, len(dest), 2) + require.Equal(t, len(dest[0].Rentals), 32) + require.Equal(t, len(dest[1].Rentals), 27) +} + +func TestConditionalProjectionList(t *testing.T) { + projectionList := ProjectionList{} + + columnsToSelect := []string{"customer_id", "create_date"} + + for _, columnName := range columnsToSelect { + switch columnName { + case Customer.CustomerID.Name(): + projectionList = append(projectionList, Customer.CustomerID) + case Customer.Email.Name(): + projectionList = append(projectionList, Customer.Email) + case Customer.CreateDate.Name(): + projectionList = append(projectionList, Customer.CreateDate) + } + } + + stmt := SELECT(projectionList). + FROM(Customer). + LIMIT(3) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT customer.customer_id AS "customer.customer_id", + customer.create_date AS "customer.create_date" +FROM customer +LIMIT 3; +`) + var dest []model.Customer + err := stmt.Query(db, &dest) + require.NoError(t, err) + + require.Equal(t, len(dest), 3) +} + +func TestUseAttachedDatabase(t *testing.T) { + Artists := table.Artists.FromSchema("chinook") + Albums := table.Albums.FromSchema("chinook") + + stmt := + SELECT( + Artists.AllColumns, + Albums.AllColumns, + ).FROM( + Albums. + INNER_JOIN(Artists, Artists.ArtistId.EQ(Albums.ArtistId)), + ).ORDER_BY( + Artists.ArtistId, + ).LIMIT(10) + + testutils.AssertDebugStatementSql(t, stmt, strings.Replace(` +SELECT artists.''ArtistId'' AS "artists.ArtistId", + artists.''Name'' AS "artists.Name", + albums.''AlbumId'' AS "albums.AlbumId", + albums.''Title'' AS "albums.Title", + albums.''ArtistId'' AS "albums.ArtistId" +FROM chinook.albums + INNER JOIN chinook.artists ON (artists.''ArtistId'' = albums.''ArtistId'') +ORDER BY artists.''ArtistId'' +LIMIT 10; +`, "''", "`", -1)) + + var dest []struct { + model2.Artists + Albums []model2.Albums + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 7) +} + +func TestRowsScan(t *testing.T) { + stmt := + SELECT( + Inventory.AllColumns, + ).FROM( + Inventory, + ).ORDER_BY( + Inventory.InventoryID.ASC(), + ) + + rows, err := stmt.Rows(context.Background(), db) + require.NoError(t, err) + + for rows.Next() { + var inventory model.Inventory + err = rows.Scan(&inventory) + require.NoError(t, err) + + require.NotEqual(t, inventory.InventoryID, uint32(0)) + require.NotEqual(t, inventory.FilmID, uint16(0)) + require.NotEqual(t, inventory.StoreID, uint16(0)) + require.NotEqual(t, inventory.LastUpdate, time.Time{}) + + if inventory.InventoryID == 2103 { + require.Equal(t, inventory.FilmID, int32(456)) + require.Equal(t, inventory.StoreID, int32(2)) + require.Equal(t, inventory.LastUpdate.Format(time.RFC3339), "2019-04-11T18:11:48Z") + } + } + + err = rows.Close() + require.NoError(t, err) + err = rows.Err() + require.NoError(t, err) + + requireLogged(t, stmt) +} + +func TestScanNumericToNumber(t *testing.T) { + type Number struct { + Int8 int8 + UInt8 uint8 + Int16 int16 + UInt16 uint16 + Int32 int32 + UInt32 uint32 + Int64 int64 + UInt64 uint64 + Float32 float32 + Float64 float64 + } + + numeric := CAST(String("1234567890.111")).AS_REAL() + + stmt := SELECT( + numeric.AS("number.int8"), + numeric.AS("number.uint8"), + numeric.AS("number.int16"), + numeric.AS("number.uint16"), + numeric.AS("number.int32"), + numeric.AS("number.uint32"), + numeric.AS("number.int64"), + numeric.AS("number.uint64"), + numeric.AS("number.float32"), + numeric.AS("number.float64"), + ) + + var number Number + err := stmt.Query(db, &number) + require.NoError(t, err) + + require.Equal(t, number.Int8, int8(-46)) // overflow + require.Equal(t, number.UInt8, uint8(210)) // overflow + require.Equal(t, number.Int16, int16(722)) // overflow + require.Equal(t, number.UInt16, uint16(722)) // overflow + require.Equal(t, number.Int32, int32(1234567890)) + require.Equal(t, number.UInt32, uint32(1234567890)) + require.Equal(t, number.Int64, int64(1234567890)) + require.Equal(t, number.UInt64, uint64(1234567890)) + require.Equal(t, number.Float32, float32(1.234568e+09)) + require.Equal(t, number.Float64, float64(1.234567890111e+09)) +} diff --git a/tests/sqlite/update_test.go b/tests/sqlite/update_test.go new file mode 100644 index 00000000..61135a8f --- /dev/null +++ b/tests/sqlite/update_test.go @@ -0,0 +1,290 @@ +package sqlite + +import ( + "context" + "testing" + "time" + + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table" + "github.com/stretchr/testify/require" +) + +func TestUpdateValues(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + var expectedSQL = ` +UPDATE link +SET name = 'Bong', + url = 'http://bong.com' +WHERE link.name = 'Bing'; +` + t.Run("old version", func(t *testing.T) { + query := Link.UPDATE(Link.Name, Link.URL). + SET("Bong", "http://bong.com"). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") + testutils.AssertExec(t, query, tx) + requireLogged(t, query) + }) + + t.Run("new version", func(t *testing.T) { + stmt := Link.UPDATE(). + SET( + Link.Name.SET(String("Bong")), + Link.URL.SET(String("http://bong.com")), + ). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing") + testutils.AssertExec(t, stmt, tx) + requireLogged(t, stmt) + }) + + links := []model.Link{} + + err := SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.Name.EQ(String("Bong"))). + Query(tx, &links) + + require.NoError(t, err) + require.Equal(t, len(links), 1) + testutils.AssertDeepEqual(t, links[0], model.Link{ + ID: 24, + URL: "http://bong.com", + Name: "Bong", + }) +} + +func TestUpdateWithSubQueries(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + expectedSQL := ` +UPDATE link +SET name = ?, + url = ( + SELECT link.url AS "link.url" + FROM link + WHERE link.name = ? + ) +WHERE link.name = ?; +` + t.Run("old version", func(t *testing.T) { + query := Link. + UPDATE(Link.Name, Link.URL). + SET( + String("Bong"), + SELECT(Link.URL). + FROM(Link). + WHERE(Link.Name.EQ(String("Ask"))), + ). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Ask", "Bing") + testutils.AssertExec(t, query, tx) + requireLogged(t, query) + }) + + t.Run("new version", func(t *testing.T) { + query := Link. + UPDATE(). + SET( + Link.Name.SET(String("Bong")), + Link.URL.SET(StringExp( + SELECT(Link.URL). + FROM(Link). + WHERE(Link.Name.EQ(String("Ask"))), + )), + ). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Ask", "Bing") + testutils.AssertExec(t, query, tx) + requireLogged(t, query) + }) +} + +func TestUpdateWithModelDataAndReturning(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + link := model.Link{ + ID: 20, + URL: "http://www.duckduckgo.com", + Name: "DuckDuckGo", + } + + stmt := Link.UPDATE(Link.AllColumns). + MODEL(link). + WHERE(Link.ID.EQ(Int32(link.ID))). + RETURNING( + Link.AllColumns, + String("str").AS("dest.literal"), + NOT(Bool(false)).AS("dest.unary_operator"), + Link.ID.ADD(Int(11)).AS("dest.binary_operator"), + CAST(Link.ID).AS_TEXT().AS("dest.cast_operator"), + Link.Name.LIKE(String("Bing")).AS("dest.like_operator"), + Link.Description.IS_NULL().AS("dest.is_null"), + CASE(Link.Name). + WHEN(String("Yahoo")).THEN(String("search")). + WHEN(String("GMail")).THEN(String("mail")). + ELSE(String("unknown")).AS("dest.case_operator"), + ) + + expectedSQL := ` +UPDATE link +SET id = ?, + url = ?, + name = ?, + description = ? +WHERE link.id = ? +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description", + ? AS "dest.literal", + (NOT ?) AS "dest.unary_operator", + (link.id + ?) AS "dest.binary_operator", + CAST(link.id AS TEXT) AS "dest.cast_operator", + (link.name LIKE ?) AS "dest.like_operator", + link.description IS NULL AS "dest.is_null", + (CASE link.name WHEN ? THEN ? WHEN ? THEN ? ELSE ? END) AS "dest.case_operator"; +` + testutils.AssertStatementSql(t, stmt, expectedSQL, int32(20), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(20), + "str", false, int64(11), "Bing", "Yahoo", "search", "GMail", "mail", "unknown") + + type Dest struct { + model.Link + Literal string + UnaryOperator bool + BinaryOperator int64 + CastOperator string + LikeOperator bool + IsNull bool + CaseOperator string + } + + var dest Dest + + err := stmt.Query(tx, &dest) + require.NoError(t, err) + require.EqualValues(t, dest, Dest{ + Link: link, + Literal: "str", + UnaryOperator: true, + BinaryOperator: 31, + CastOperator: "20", + LikeOperator: false, + IsNull: true, + CaseOperator: "unknown", + }) + requireLogged(t, stmt) +} + +func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + link := model.Link{ + ID: 20, + URL: "http://www.duckduckgo.com", + Name: "DuckDuckGo", + } + + updateColumnList := ColumnList{Link.Description, Link.Name, Link.URL} + + stmt := Link.UPDATE(updateColumnList). + MODEL(link). + WHERE(Link.ID.EQ(Int32(link.ID))) + + var expectedSQL = ` +UPDATE link +SET description = NULL, + name = 'DuckDuckGo', + url = 'http://www.duckduckgo.com' +WHERE link.id = 20; +` + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(20)) + + testutils.AssertExec(t, stmt, tx) + requireLogged(t, stmt) +} + +func TestUpdateWithModelDataAndMutableColumns(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + link := model.Link{ + ID: 201, + URL: "http://www.duckduckgo.com", + Name: "DuckDuckGo", + } + + stmt := Link.UPDATE(Link.MutableColumns). + MODEL(link). + WHERE(Link.ID.EQ(Int32(link.ID))) + + var expectedSQL = ` +UPDATE link +SET url = 'http://www.duckduckgo.com', + name = 'DuckDuckGo', + description = NULL +WHERE link.id = 201; +` + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) + testutils.AssertExec(t, stmt, tx) +} + +func TestUpdateWithInvalidModelData(t *testing.T) { + defer func() { + r := recover() + require.Equal(t, r, "missing struct field for column : id") + }() + + link := struct { + Ident int + URL string + Name string + Description *string + Rel *string + }{ + Ident: 201, + URL: "http://www.duckduckgo.com", + Name: "DuckDuckGo", + } + + stmt := Link.UPDATE(Link.AllColumns). + MODEL(link). + WHERE(Link.ID.EQ(Int(int64(link.Ident)))) + + stmt.Sql() +} + +func TestUpdateContextDeadlineExceeded(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + updateStmt := Link.UPDATE(Link.Name, Link.URL). + SET("Bong", "http://bong.com"). + WHERE(Link.Name.EQ(String("Bing"))) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) + defer cancel() + + time.Sleep(10 * time.Millisecond) + + dest := []model.Link{} + err := updateStmt.QueryContext(ctx, tx, &dest) + require.Error(t, err, "context deadline exceeded") + + _, err = updateStmt.ExecContext(ctx, tx) + require.Error(t, err, "context deadline exceeded") +} diff --git a/tests/sqlite/with_test.go b/tests/sqlite/with_test.go new file mode 100644 index 00000000..f2b623ab --- /dev/null +++ b/tests/sqlite/with_test.go @@ -0,0 +1,234 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table" + "github.com/stretchr/testify/require" + "strings" + "testing" +) + +func TestWITH_And_SELECT(t *testing.T) { + salesRep := CTE("sales_rep") + salesRepStaffID := Staff.StaffID.From(salesRep) + salesRepFullName := StringColumn("sales_rep_full_name").From(salesRep) + customerSalesRep := CTE("customer_sales_rep") + + stmt := WITH( + salesRep.AS( + SELECT( + Staff.StaffID, + Staff.FirstName.CONCAT(Staff.LastName).AS(salesRepFullName.Name()), + ).FROM(Staff), + ), + customerSalesRep.AS( + SELECT( + Customer.FirstName.CONCAT(Customer.LastName).AS("customer_name"), + salesRepFullName, + ).FROM( + salesRep. + INNER_JOIN(Store, Store.ManagerStaffID.EQ(salesRepStaffID)). + INNER_JOIN(Customer, Customer.StoreID.EQ(Store.StoreID)), + ), + ), + )( + SELECT(customerSalesRep.AllColumns()). + FROM(customerSalesRep), + ) + + testutils.AssertStatementSql(t, stmt, strings.Replace(` +WITH sales_rep AS ( + SELECT staff.staff_id AS "staff.staff_id", + (staff.first_name || staff.last_name) AS "sales_rep_full_name" + FROM staff +),customer_sales_rep AS ( + SELECT (customer.first_name || customer.last_name) AS "customer_name", + sales_rep.sales_rep_full_name AS "sales_rep_full_name" + FROM sales_rep + INNER JOIN store ON (store.manager_staff_id = sales_rep.''staff.staff_id'') + INNER JOIN customer ON (customer.store_id = store.store_id) +) +SELECT customer_sales_rep.customer_name AS "customer_name", + customer_sales_rep.sales_rep_full_name AS "sales_rep_full_name" +FROM customer_sales_rep; +`, "''", "`", -1)) + + var dest []struct { + CustomerName string + SalesRepFullName string + } + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Equal(t, len(dest), 599) +} + +func TestWITH_And_INSERT(t *testing.T) { + paymentsToInsert := CTE("payments_to_insert") + + stmt := WITH( + paymentsToInsert.AS( + SELECT(Payment.AllColumns). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(0.5))), + ), + )( + Payment.INSERT(Payment.AllColumns). + QUERY( + SELECT( + paymentsToInsert.AllColumns(), + ).FROM( + paymentsToInsert, + ).WHERE(Bool(true)), //https://stackoverflow.com/questions/66230093/error-while-doing-upsert-in-sqlite-3-34-error-near-do-syntax-error + ).ON_CONFLICT().DO_UPDATE( + SET( + Payment.PaymentID.SET(Payment.PaymentID.ADD(Int(100000))), + ), + ), + ) + + testutils.AssertDebugStatementSql(t, stmt, strings.Replace(` +WITH payments_to_insert AS ( + SELECT payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update" + FROM payment + WHERE payment.amount < 0.5 +) +INSERT INTO payment (payment_id, customer_id, staff_id, rental_id, amount, payment_date, last_update) +SELECT payments_to_insert.''payment.payment_id'' AS "payment.payment_id", + payments_to_insert.''payment.customer_id'' AS "payment.customer_id", + payments_to_insert.''payment.staff_id'' AS "payment.staff_id", + payments_to_insert.''payment.rental_id'' AS "payment.rental_id", + payments_to_insert.''payment.amount'' AS "payment.amount", + payments_to_insert.''payment.payment_date'' AS "payment.payment_date", + payments_to_insert.''payment.last_update'' AS "payment.last_update" +FROM payments_to_insert +WHERE TRUE +ON CONFLICT DO UPDATE + SET payment_id = (payment.payment_id + 100000); +`, "''", "`", -1)) + + tx := beginDBTx(t) + defer tx.Rollback() + + testutils.AssertExec(t, stmt, tx, 24) +} + +func TestWITH_SELECT_UPDATE(t *testing.T) { + paymentsToUpdate := CTE("payments_to_update") + paymentsToDeleteID := Payment.PaymentID.From(paymentsToUpdate) + + stmt := WITH( + paymentsToUpdate.AS( + SELECT(Payment.AllColumns). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(0.5))), + ), + )( + Payment.UPDATE(). + SET(Payment.Amount.SET(Float(0.0))). + WHERE(Payment.PaymentID.IN( + SELECT(paymentsToDeleteID). + FROM(paymentsToUpdate), + ), + ), + ) + + testutils.AssertDebugStatementSql(t, stmt, strings.Replace(` +WITH payments_to_update AS ( + SELECT payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update" + FROM payment + WHERE payment.amount < 0.5 +) +UPDATE payment +SET amount = 0 +WHERE payment.payment_id IN ( + SELECT payments_to_update.''payment.payment_id'' AS "payment.payment_id" + FROM payments_to_update + ); +`, "''", "`", -1)) + + tx := beginDBTx(t) + defer tx.Rollback() + + testutils.AssertExec(t, stmt, tx) +} + +func TestWITH_And_DELETE(t *testing.T) { + paymentsToDelete := CTE("payments_to_delete") + paymentsToDeleteID := Payment.PaymentID.From(paymentsToDelete) + + stmt := WITH( + paymentsToDelete.AS( + SELECT( + Payment.AllColumns, + ).FROM( + Payment, + ).WHERE( + Payment.Amount.LT(Float(0.5)), + ), + ), + )( + Payment.DELETE(). + WHERE( + Payment.PaymentID.IN( + SELECT( + paymentsToDeleteID, + ).FROM( + paymentsToDelete, + ), + ), + ), + ) + + testutils.AssertDebugStatementSql(t, stmt, strings.Replace(` +WITH payments_to_delete AS ( + SELECT payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update" + FROM payment + WHERE payment.amount < 0.5 +) +DELETE FROM payment +WHERE payment.payment_id IN ( + SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id" + FROM payments_to_delete + ); +`, "''", "`", -1)) + + tx := beginDBTx(t) + defer tx.Rollback() + + testutils.AssertExec(t, stmt, tx, 24) +} + +func TestOperatorIN(t *testing.T) { + stmt := SELECT(Payment.PaymentID.IN(SELECT(Int(11)), Int(22))). + FROM(Payment) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT payment.payment_id IN (( + SELECT 11 + ), 22) +FROM payment; +`) + + var dest []struct{} + err := stmt.Query(db, &dest) + require.NoError(t, err) +} diff --git a/tests/testdata b/tests/testdata index a6c1975a..946bc1e5 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit a6c1975a167645f913496131ae81d4cabc070046 +Subproject commit 946bc1e5d3e162154eade8b79ff915e4c4986efd