From 0db74cff5dfbc0a5ac0d926f40327b40bf94704f Mon Sep 17 00:00:00 2001 From: Mathijs Dumon Date: Thu, 24 Feb 2022 15:07:45 +0100 Subject: [PATCH] Fix for Oracle insert many --- .../Oracle/OracleInsertManyTests.cs | 61 +++++++++++++ QueryBuilder/Compilers/Compiler.cs | 90 ++++++++++--------- QueryBuilder/Compilers/OracleCompiler.cs | 20 +++++ 3 files changed, 130 insertions(+), 41 deletions(-) create mode 100644 QueryBuilder.Tests/Oracle/OracleInsertManyTests.cs diff --git a/QueryBuilder.Tests/Oracle/OracleInsertManyTests.cs b/QueryBuilder.Tests/Oracle/OracleInsertManyTests.cs new file mode 100644 index 00000000..f25cf2ba --- /dev/null +++ b/QueryBuilder.Tests/Oracle/OracleInsertManyTests.cs @@ -0,0 +1,61 @@ +using SqlKata.Compilers; +using SqlKata.Tests.Infrastructure; +using Xunit; + +namespace SqlKata.Tests.Oracle +{ + public class OracleInsertManyTests : TestSupport + { + private const string TableName = "Table"; + private readonly OracleCompiler compiler; + + public OracleInsertManyTests() + { + compiler = Compilers.Get(EngineCodes.Oracle); + } + + [Fact] + public void InsertManyForOracle_ShouldRepeatColumnsAndAddSelectFromDual() + { + // Arrange: + var cols = new[] { "Name", "Price" }; + + var data = new[] { + new object[] { "A", 1000 }, + new object[] { "B", 2000 }, + new object[] { "C", 3000 }, + }; + + var query = new Query(TableName) + .AsInsert(cols, data); + + + // Act: + var ctx = compiler.Compile(query); + + // Assert: + Assert.Equal($@"INSERT ALL INTO ""{TableName}"" (""Name"", ""Price"") VALUES (?, ?) INTO ""{TableName}"" (""Name"", ""Price"") VALUES (?, ?) INTO ""{TableName}"" (""Name"", ""Price"") VALUES (?, ?) SELECT 1 FROM DUAL", ctx.RawSql); + } + + [Fact] + public void InsertForOracle_SingleInsertShouldNotAddALLKeywordAndNotHaveSelectFromDual() + { + // Arrange: + var cols = new[] { "Name", "Price" }; + + var data = new[] { + new object[] { "A", 1000 } + }; + + var query = new Query(TableName) + .AsInsert(cols, data); + + + // Act: + var ctx = compiler.Compile(query); + + // Assert: + Assert.Equal($@"INSERT INTO ""{TableName}"" (""Name"", ""Price"") VALUES (?, ?)", ctx.RawSql); + } + } +} diff --git a/QueryBuilder/Compilers/Compiler.cs b/QueryBuilder/Compilers/Compiler.cs index a71b0743..aec2405f 100644 --- a/QueryBuilder/Compilers/Compiler.cs +++ b/QueryBuilder/Compilers/Compiler.cs @@ -17,6 +17,11 @@ public partial class Compiler protected virtual string LastId { get; set; } = ""; protected virtual string EscapeCharacter { get; set; } = "\\"; + + protected virtual string SingleInsertStartClause { get; set; } = "INSERT INTO"; + protected virtual string MultiInsertStartClause { get; set; } = "INSERT INTO"; + + protected Compiler() { _compileConditionMethodsProvider = new ConditionsCompilerProvider(this); @@ -361,24 +366,15 @@ protected virtual SqlResult CompileInsertQuery(Query query) }; if (!ctx.Query.HasComponent("from", EngineCode)) - { throw new InvalidOperationException("No table set to insert"); - } var fromClause = ctx.Query.GetOneComponent("from", EngineCode); - if (fromClause is null) - { throw new InvalidOperationException("Invalid table expression"); - } string table = null; - if (fromClause is FromClause fromClauseCast) - { table = Wrap(fromClauseCast.Table); - } - if (fromClause is RawFromClause rawFromClause) { table = WrapIdentifiers(rawFromClause.Expression); @@ -386,56 +382,68 @@ protected virtual SqlResult CompileInsertQuery(Query query) } if (table is null) - { throw new InvalidOperationException("Invalid table expression"); - } var inserts = ctx.Query.GetComponents("insert", EngineCode); + if (inserts[0] is InsertQueryClause insertQueryClause) + return CompileInsertQueryClause(ctx, table, insertQueryClause); + else + return CompileValueInsertClauses(ctx, table, inserts.Cast()); + } - if (inserts[0] is InsertClause insertClause) - { - var columns = string.Join(", ", WrapArray(insertClause.Columns)); - var values = string.Join(", ", Parameterize(ctx, insertClause.Values)); + protected virtual SqlResult CompileInsertQueryClause( + SqlResult ctx, string table, InsertQueryClause clause) + { + string columns = GetInsertColumnsList(clause.Columns); - ctx.RawSql = $"INSERT INTO {table} ({columns}) VALUES ({values})"; + var subCtx = CompileSelectQuery(clause.Query); + ctx.Bindings.AddRange(subCtx.Bindings); - if (insertClause.ReturnId && !string.IsNullOrEmpty(LastId)) - { - ctx.RawSql += ";" + LastId; - } - } - else - { - var clause = inserts[0] as InsertQueryClause; + ctx.RawSql = $"{SingleInsertStartClause} {table}{columns} {subCtx.RawSql}"; - var columns = ""; + return ctx; + } - if (clause.Columns.Any()) - { - columns = $" ({string.Join(", ", WrapArray(clause.Columns))}) "; - } + protected virtual SqlResult CompileValueInsertClauses( + SqlResult ctx, string table, IEnumerable insertClauses) + { + bool isMultiValueInsert = insertClauses.Skip(1).Any(); - var subCtx = CompileSelectQuery(clause.Query); - ctx.Bindings.AddRange(subCtx.Bindings); + var insertInto = (isMultiValueInsert) ? MultiInsertStartClause : SingleInsertStartClause; - ctx.RawSql = $"INSERT INTO {table}{columns}{subCtx.RawSql}"; - } + var firstInsert = insertClauses.First(); + string columns = GetInsertColumnsList(firstInsert.Columns); + var values = string.Join(", ", Parameterize(ctx, firstInsert.Values)); - if (inserts.Count > 1) - { - foreach (var insert in inserts.GetRange(1, inserts.Count - 1)) - { - var clause = insert as InsertClause; + ctx.RawSql = $"{insertInto} {table}{columns} VALUES ({values})"; - ctx.RawSql += ", (" + string.Join(", ", Parameterize(ctx, clause.Values)) + ")"; + if (isMultiValueInsert) + return CompileRemainingInsertClauses(ctx, table, insertClauses); - } - } + if (firstInsert.ReturnId && !string.IsNullOrEmpty(LastId)) + ctx.RawSql += ";" + LastId; + return ctx; + } + protected virtual SqlResult CompileRemainingInsertClauses(SqlResult ctx, string table, IEnumerable inserts) + { + foreach (var insert in inserts.Skip(1)) + { + string values = string.Join(", ", Parameterize(ctx, insert.Values)); + ctx.RawSql += $", ({values})"; + } return ctx; } + protected string GetInsertColumnsList(List columnList) + { + var columns = ""; + if (columnList.Any()) + columns = $" ({string.Join(", ", WrapArray(columnList))})"; + + return columns; + } protected virtual SqlResult CompileCteQuery(SqlResult ctx, Query query) { diff --git a/QueryBuilder/Compilers/OracleCompiler.cs b/QueryBuilder/Compilers/OracleCompiler.cs index a48a13e7..5bebb983 100644 --- a/QueryBuilder/Compilers/OracleCompiler.cs +++ b/QueryBuilder/Compilers/OracleCompiler.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Text.RegularExpressions; namespace SqlKata.Compilers { @@ -12,6 +13,7 @@ public OracleCompiler() ColumnAsKeyword = ""; TableAsKeyword = ""; parameterPrefix = ":p"; + MultiInsertStartClause = "INSERT ALL INTO"; } public override string EngineCode { get; } = EngineCodes.Oracle; @@ -152,5 +154,23 @@ protected override string CompileBasicDateCondition(SqlResult ctx, BasicDateCond return sql; } + + protected override SqlResult CompileRemainingInsertClauses( + SqlResult ctx, string table, IEnumerable inserts) + { + foreach (var insert in inserts.Skip(1)) + { + string columns = GetInsertColumnsList(insert.Columns); + string values = string.Join(", ", Parameterize(ctx, insert.Values)); + + string intoFormat = " INTO {0}{1} VALUES ({2})"; + var nextInsert = string.Format(intoFormat, table, columns, values); + + ctx.RawSql += nextInsert; + } + + ctx.RawSql += " SELECT 1 FROM DUAL"; + return ctx; + } } }