Skip to content

Commit

Permalink
Merge pull request #557 from mathijs-dumon/OracleInsertMany
Browse files Browse the repository at this point in the history
Fix for Oracle insert many
  • Loading branch information
ahmad-moussawi authored Oct 5, 2022
2 parents 449959a + 0db74cf commit d9d0441
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 41 deletions.
61 changes: 61 additions & 0 deletions QueryBuilder.Tests/Oracle/OracleInsertManyTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using SqlKata.Compilers;
using SqlKata.Tests.Infrastructure;
using Xunit;

namespace SqlKata.Tests.Oracle
{
public class OracleInsertManyTests : TestSupport
{
private const string TableName = "Table";
private readonly OracleCompiler compiler;

public OracleInsertManyTests()
{
compiler = Compilers.Get<OracleCompiler>(EngineCodes.Oracle);
}

[Fact]
public void InsertManyForOracle_ShouldRepeatColumnsAndAddSelectFromDual()
{
// Arrange:
var cols = new[] { "Name", "Price" };

var data = new[] {
new object[] { "A", 1000 },
new object[] { "B", 2000 },
new object[] { "C", 3000 },
};

var query = new Query(TableName)
.AsInsert(cols, data);


// Act:
var ctx = compiler.Compile(query);

// Assert:
Assert.Equal($@"INSERT ALL INTO ""{TableName}"" (""Name"", ""Price"") VALUES (?, ?) INTO ""{TableName}"" (""Name"", ""Price"") VALUES (?, ?) INTO ""{TableName}"" (""Name"", ""Price"") VALUES (?, ?) SELECT 1 FROM DUAL", ctx.RawSql);
}

[Fact]
public void InsertForOracle_SingleInsertShouldNotAddALLKeywordAndNotHaveSelectFromDual()
{
// Arrange:
var cols = new[] { "Name", "Price" };

var data = new[] {
new object[] { "A", 1000 }
};

var query = new Query(TableName)
.AsInsert(cols, data);


// Act:
var ctx = compiler.Compile(query);

// Assert:
Assert.Equal($@"INSERT INTO ""{TableName}"" (""Name"", ""Price"") VALUES (?, ?)", ctx.RawSql);
}
}
}
90 changes: 49 additions & 41 deletions QueryBuilder/Compilers/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ public partial class Compiler
protected virtual string LastId { get; set; } = "";
protected virtual string EscapeCharacter { get; set; } = "\\";


protected virtual string SingleInsertStartClause { get; set; } = "INSERT INTO";
protected virtual string MultiInsertStartClause { get; set; } = "INSERT INTO";


protected Compiler()
{
_compileConditionMethodsProvider = new ConditionsCompilerProvider(this);
Expand Down Expand Up @@ -391,81 +396,84 @@ protected virtual SqlResult CompileInsertQuery(Query query)
};

if (!ctx.Query.HasComponent("from", EngineCode))
{
throw new InvalidOperationException("No table set to insert");
}

var fromClause = ctx.Query.GetOneComponent<AbstractFrom>("from", EngineCode);

if (fromClause is null)
{
throw new InvalidOperationException("Invalid table expression");
}

string table = null;

if (fromClause is FromClause fromClauseCast)
{
table = Wrap(fromClauseCast.Table);
}

if (fromClause is RawFromClause rawFromClause)
{
table = WrapIdentifiers(rawFromClause.Expression);
ctx.Bindings.AddRange(rawFromClause.Bindings);
}

if (table is null)
{
throw new InvalidOperationException("Invalid table expression");
}

var inserts = ctx.Query.GetComponents<AbstractInsertClause>("insert", EngineCode);
if (inserts[0] is InsertQueryClause insertQueryClause)
return CompileInsertQueryClause(ctx, table, insertQueryClause);
else
return CompileValueInsertClauses(ctx, table, inserts.Cast<InsertClause>());
}

if (inserts[0] is InsertClause insertClause)
{
var columns = string.Join(", ", WrapArray(insertClause.Columns));
var values = string.Join(", ", Parameterize(ctx, insertClause.Values));
protected virtual SqlResult CompileInsertQueryClause(
SqlResult ctx, string table, InsertQueryClause clause)
{
string columns = GetInsertColumnsList(clause.Columns);

ctx.RawSql = $"INSERT INTO {table} ({columns}) VALUES ({values})";
var subCtx = CompileSelectQuery(clause.Query);
ctx.Bindings.AddRange(subCtx.Bindings);

if (insertClause.ReturnId && !string.IsNullOrEmpty(LastId))
{
ctx.RawSql += ";" + LastId;
}
}
else
{
var clause = inserts[0] as InsertQueryClause;
ctx.RawSql = $"{SingleInsertStartClause} {table}{columns} {subCtx.RawSql}";

var columns = "";
return ctx;
}

if (clause.Columns.Any())
{
columns = $" ({string.Join(", ", WrapArray(clause.Columns))}) ";
}
protected virtual SqlResult CompileValueInsertClauses(
SqlResult ctx, string table, IEnumerable<InsertClause> insertClauses)
{
bool isMultiValueInsert = insertClauses.Skip(1).Any();

var subCtx = CompileSelectQuery(clause.Query);
ctx.Bindings.AddRange(subCtx.Bindings);
var insertInto = (isMultiValueInsert) ? MultiInsertStartClause : SingleInsertStartClause;

ctx.RawSql = $"INSERT INTO {table}{columns}{subCtx.RawSql}";
}
var firstInsert = insertClauses.First();
string columns = GetInsertColumnsList(firstInsert.Columns);
var values = string.Join(", ", Parameterize(ctx, firstInsert.Values));

if (inserts.Count > 1)
{
foreach (var insert in inserts.GetRange(1, inserts.Count - 1))
{
var clause = insert as InsertClause;
ctx.RawSql = $"{insertInto} {table}{columns} VALUES ({values})";

ctx.RawSql += ", (" + string.Join(", ", Parameterize(ctx, clause.Values)) + ")";
if (isMultiValueInsert)
return CompileRemainingInsertClauses(ctx, table, insertClauses);

}
}
if (firstInsert.ReturnId && !string.IsNullOrEmpty(LastId))
ctx.RawSql += ";" + LastId;

return ctx;
}

protected virtual SqlResult CompileRemainingInsertClauses(SqlResult ctx, string table, IEnumerable<InsertClause> inserts)
{
foreach (var insert in inserts.Skip(1))
{
string values = string.Join(", ", Parameterize(ctx, insert.Values));
ctx.RawSql += $", ({values})";
}
return ctx;
}

protected string GetInsertColumnsList(List<string> columnList)
{
var columns = "";
if (columnList.Any())
columns = $" ({string.Join(", ", WrapArray(columnList))})";

return columns;
}

protected virtual SqlResult CompileCteQuery(SqlResult ctx, Query query)
{
Expand Down
20 changes: 20 additions & 0 deletions QueryBuilder/Compilers/OracleCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text.RegularExpressions;

namespace SqlKata.Compilers
{
Expand All @@ -12,6 +13,7 @@ public OracleCompiler()
ColumnAsKeyword = "";
TableAsKeyword = "";
parameterPrefix = ":p";
MultiInsertStartClause = "INSERT ALL INTO";
}

public override string EngineCode { get; } = EngineCodes.Oracle;
Expand Down Expand Up @@ -152,5 +154,23 @@ protected override string CompileBasicDateCondition(SqlResult ctx, BasicDateCond
return sql;

}

protected override SqlResult CompileRemainingInsertClauses(
SqlResult ctx, string table, IEnumerable<InsertClause> inserts)
{
foreach (var insert in inserts.Skip(1))
{
string columns = GetInsertColumnsList(insert.Columns);
string values = string.Join(", ", Parameterize(ctx, insert.Values));

string intoFormat = " INTO {0}{1} VALUES ({2})";
var nextInsert = string.Format(intoFormat, table, columns, values);

ctx.RawSql += nextInsert;
}

ctx.RawSql += " SELECT 1 FROM DUAL";
return ctx;
}
}
}

0 comments on commit d9d0441

Please sign in to comment.