Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse arguments of stored procedures to make them available in DOM #6

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion sqlparser/dom.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package sqlparser

import (
"fmt"
"gopkg.in/yaml.v3"
"io"
"strings"

"gopkg.in/yaml.v3"
)

type Unparsed struct {
Expand Down Expand Up @@ -59,12 +60,35 @@ func (p PosString) String() string {
return p.Value
}

// Parameter describes one entry of a parameter list found in a `create`
// statement.
type Parameter struct {
	// Start and Stop are source positions belonging to the parameter; both
	// are reset to the zero Pos by WithoutPos.
	Start, Stop Pos

	// VariableName is the parameter name as written in the source,
	// including the leading @.
	VariableName string

	// Datatype is the declared type of the parameter.
	Datatype Type

	// The attributes below are only relevant for procedures.

	// DefaultValue is the literal token following `=`, if any.
	DefaultValue Unparsed

	// IsReadonly and IsOutput record the `readonly` / `output` option.
	IsReadonly, IsOutput bool
}

// WithoutPos returns a copy of p in which every position (those of the
// parameter itself and those of its default value) is reset to the zero Pos.
// Handy for comparing parse results without caring about source locations.
func (p Parameter) WithoutPos() (result Parameter) {
	// p is received by value, so it is already a private copy we may edit.
	p.Start, p.Stop = Pos{}, Pos{}
	p.DefaultValue.Start, p.DefaultValue.Stop = Pos{}, Pos{}
	return p
}

// Create is the parsed representation of a single `create` statement.
type Create struct {
	// CreateType is "procedure", "function" or "type".
	CreateType string

	// QuotedName is the proc/func/type name, including [].
	QuotedName PosString

	Body []Unparsed

	// DependsOn lists dependencies of the statement. Docstring holds the
	// comment lines before the create statement; note that these are also
	// part of Body.
	DependsOn, Docstring []PosString

	Parameters []Parameter
}

func (c Create) DocstringAsString() string {
Expand Down Expand Up @@ -100,9 +124,14 @@ func (c Create) ParseYamlInDocstring(out any) error {
return yaml.Unmarshal([]byte(yamldoc), out)
}

// Type is the declared type of a parameter. It comes in two flavours:
//
//   - a basic type, described by BaseType and Args
//     (e.g. BaseType "varchar" with Args ["max"]);
//   - a table type, described by TableTypeSchema and TableTypeName.
type Type struct {
	// Basic type.
	BaseType string
	Args     []string

	// Table type.
	TableTypeSchema, TableTypeName string
}

func (t Type) String() (result string) {
Expand Down Expand Up @@ -165,11 +194,16 @@ func (c Create) WithoutPos() Create {
for _, x := range c.Body {
body = append(body, x.WithoutPos())
}
var parameters []Parameter
for _, x := range c.Parameters {
parameters = append(parameters, x.WithoutPos())
}
return Create{
CreateType: c.CreateType,
QuotedName: c.QuotedName,
DependsOn: c.DependsOn,
Body: body,
Parameters: parameters,
}
}

Expand Down
161 changes: 124 additions & 37 deletions sqlparser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,6 @@ func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) {

}

// AdvanceAndCopy moves the scanner forward until it rests on a token that is
// neither whitespace nor a comment, copying every token it passes (and the
// one it stops on) to target.
//
// The exception is the batch separator ('go') and EOF: the scanner stops on
// these as well, but they are *not* copied.
func AdvanceAndCopy(s *Scanner, target *[]Unparsed) {
	for {
		switch s.NextToken() {
		case EOFToken, BatchSeparatorToken:
			// Stop here without copying.
			return
		case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken:
			// Insignificant token: copy it and keep scanning.
			CopyToken(s, target)
		default:
			// Significant token: copy it and stop on it.
			CopyToken(s, target)
			return
		}
	}
}

func CreateUnparsed(s *Scanner) Unparsed {
return Unparsed{
Type: s.TokenType(),
Expand All @@ -80,27 +59,29 @@ func (d *Document) unexpectedTokenError(s *Scanner) {
d.addError(s, "Unexpected: "+s.Token())
}

func (doc *Document) parseTypeExpression(s *Scanner) (t Type) {
func (doc *Document) parseTypeExpression(s *Scanner, allowTableTypes bool, target *[]Unparsed) (result Type) {
parseArgs := func() {
// parses *after* the initial (; consumes trailing )
for {
CopyToken(s, target)
switch {
case s.TokenType() == NumberToken:
t.Args = append(t.Args, s.Token())
result.Args = append(result.Args, s.Token())
case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max":
t.Args = append(t.Args, "max")
result.Args = append(result.Args, "max")
default:
doc.unexpectedTokenError(s)
doc.recoverToNextStatement(s)
return
}
s.NextNonWhitespaceCommentToken()
NextTokenCopyingWhitespace(s, target)
CopyToken(s, target)
switch {
case s.TokenType() == CommaToken:
s.NextNonWhitespaceCommentToken()
NextTokenCopyingWhitespace(s, target)
continue
case s.TokenType() == RightParenToken:
s.NextNonWhitespaceCommentToken()
NextTokenCopyingWhitespace(s, target)
return
default:
doc.unexpectedTokenError(s)
Expand All @@ -110,14 +91,37 @@ func (doc *Document) parseTypeExpression(s *Scanner) (t Type) {
}
}

if s.TokenType() != UnquotedIdentifierToken {
panic("assertion failed, bug in caller")
if s.TokenType() != UnquotedIdentifierToken && s.TokenType() != QuotedIdentifierToken {
doc.addError(s, "expected type, got: "+s.Token())
return
}
t.BaseType = s.Token()
s.NextNonWhitespaceCommentToken()
if s.TokenType() == LeftParenToken {
s.NextNonWhitespaceCommentToken()
parseArgs()
// We will assume that a table type will have a schema name; types in the 'default schema' we just don't support.
// So an identifier followed by a `.` indicates table type.
firstToken := s.Token()
CopyToken(s, target)
NextTokenCopyingWhitespace(s, target)

if s.TokenType() == DotToken {
if !allowTableTypes {
doc.addError(s, "expected basic type (no table types), got: .")
return
}
CopyToken(s, target)
NextTokenCopyingWhitespace(s, target)

// parse a table type
result.TableTypeSchema = firstToken
result.TableTypeName = s.Token()
CopyToken(s, target)
NextTokenCopyingWhitespace(s, target)
} else {
// parse a basic type
result.BaseType = firstToken
if s.TokenType() == LeftParenToken {
CopyToken(s, target)
NextTokenCopyingWhitespace(s, target)
parseArgs()
}
}
return
}
Expand Down Expand Up @@ -146,7 +150,10 @@ loop:
doc.addError(s, "sqlcode constants needs a type declared explicitly")
s.NextNonWhitespaceCommentToken()
case UnquotedIdentifierToken:
variableType = doc.parseTypeExpression(s)
// parseTypeExpression is also used in a context where we are copying Unparsed nodes into stored procedure body;
// to use it here too just use a dummy output
var dummy []Unparsed
variableType = doc.parseTypeExpression(s, false, &dummy)
}

if s.TokenType() != EqualToken {
Expand Down Expand Up @@ -366,6 +373,81 @@ func (d *Document) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString
}
}

// parseArgumentList parses a parenthesized procedure parameter list such as
//
//	(@a bigint, @b varchar(max) = N'x' output, @c [code].MyTableType readonly)
//
// The scanner must be positioned on the opening `(`; anything else is a bug
// in the caller and panics. The tokens consumed are copied to target through
// CopyToken / NextTokenCopyingWhitespace. When the list is well-formed the
// scanner is left *on* the closing `)`, which is not copied by this function.
//
// On a syntax error the error is registered on the document and the
// parameters completed up to that point are returned; the scanner is left on
// the offending token.
func (d *Document) parseArgumentList(s *Scanner, target *[]Unparsed) (result []Parameter) {
	if s.TokenType() != LeftParenToken {
		panic("assertion failed: should only be called on the ( position")
	}

	// lastCopiedStop returns the Stop position of the most recently copied
	// token that is neither whitespace nor a comment. Since the scanner has
	// already advanced past a parameter by the time we know it has ended,
	// this is how we recover the end position of the parameter.
	// NOTE(review): relies on CopyToken recording Start/Stop on the copied
	// node, as CreateUnparsed appears to do; confirm.
	lastCopiedStop := func() Pos {
		nodes := *target
		for i := len(nodes) - 1; i >= 0; i-- {
			switch nodes[i].Type {
			case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken:
				continue
			}
			return nodes[i].Stop
		}
		return Pos{}
	}

	// Copy the `(`
	CopyToken(s, target)
	NextTokenCopyingWhitespace(s, target)

	for s.TokenType() != RightParenToken {
		var parameter Parameter

		// `@parameter`
		if s.TokenType() != VariableIdentifierToken {
			d.addError(s, "expected a parameter name starting with @, got: "+s.Token())
			return result
		}

		parameter.Start = s.Start()
		parameter.VariableName = s.Token()
		CopyToken(s, target)
		NextTokenCopyingWhitespace(s, target)

		// Datatype; this can either be a table type or a basic type.
		parameter.Datatype = d.parseTypeExpression(s, true, target)

		// Optional default value. Only a single string or number literal is
		// recognized here, not a full expression.
		if s.TokenType() == EqualToken {
			CopyToken(s, target)
			NextTokenCopyingWhitespace(s, target)
			switch s.TokenType() {
			case NVarcharLiteralToken, VarcharLiteralToken, NumberToken:
				parameter.DefaultValue = CreateUnparsed(s)
			default:
				d.addError(s, "expecting default value literal, got: "+s.Token())
				return result
			}
			CopyToken(s, target)
			NextTokenCopyingWhitespace(s, target)
		}

		// Optional trailing option: *either* readonly or output; both would
		// not be relevant on the same parameter. T-SQL accepts OUT as a
		// synonym for OUTPUT, so both spellings are recognized.
		if s.TokenType() == UnquotedIdentifierToken {
			switch s.TokenLower() {
			case "readonly":
				parameter.IsReadonly = true
			case "output", "out":
				parameter.IsOutput = true
			default:
				d.addError(s, "parsing argument list, unexpected: "+s.Token())
				return result
			}
			CopyToken(s, target)
			NextTokenCopyingWhitespace(s, target)
		}

		// The parameter is complete; record where its last token ended.
		parameter.Stop = lastCopiedStop()

		// At this point we should have a comma or a right paren.
		switch s.TokenType() {
		case CommaToken:
			CopyToken(s, target)
			NextTokenCopyingWhitespace(s, target)
			// A trailing comma is not an error in this parser; SQL will
			// complain about it later.
		case RightParenToken:
			// Nothing to do; the loop condition ends the loop.
		default:
			d.addError(s, "parsing argument list, unexpected: "+s.Token())
			return result
		}

		result = append(result, parameter)
	}
	return result
}

// parseCreate parses anything that starts with "create". Position is
// *on* the create token.
// At this stage in sqlcode parser development we're only interested
Expand Down Expand Up @@ -411,8 +493,13 @@ func (d *Document) parseCreate(s *Scanner, createCountInBatch int) (result Creat
return
}

// We have matched "create <createType> [code].<quotedName>"; at this
// point we copy the rest until the batch ends; *but* track dependencies
// We have matched "create <createType> [code].<quotedName>". Try to parse
// parameters. Procedures do not need an argument list, so only do this if we see ()
if createType == "procedure" && s.tokenType == LeftParenToken {
result.Parameters = d.parseArgumentList(s, &result.Body)
}

// At this point we copy the rest until the batch ends; *but* track dependencies
// + some other details mentioned below

tailloop:
Expand Down
69 changes: 66 additions & 3 deletions sqlparser/parser_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package sqlparser

import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestParserSmokeTest(t *testing.T) {
Expand Down Expand Up @@ -271,7 +272,6 @@ create procedure [code].FirstProc as table (x int)
assert.Equal(t, emsg, doc.Errors[0].Message)
}


func TestGoWithoutNewline(t *testing.T) {
doc := ParseString("test.sql", `
create procedure [code].Foo() as begin
Expand Down Expand Up @@ -352,3 +352,66 @@ create procedure [code].Foo as begin end
err.Error())

}

// TestProcedureArgs parses a procedure whose parameter list exercises basic
// types with and without arguments, a multi-line default value, the output
// and readonly options and a table type with an awkward quoted name, and
// checks the resulting Parameters (positions stripped).
func TestProcedureArgs(t *testing.T) {
	const sql = `
create procedure [code].Foo
(
@a bigint,

@b varchar(max) = N'asdfas

lkjlkjlkjasdf'

output,
@c [code].[something:asf asdf -- as
df/ MyTableType] readonly,
@d numeric(1,2),@e tinyint
) as begin end
`
	doc := parseAndVerifyCreate(t, "test.sql", sql)

	want := []Parameter{
		{
			VariableName: "@a",
			Datatype:     Type{BaseType: "bigint"},
		},
		{
			VariableName: "@b",
			Datatype:     Type{BaseType: "varchar", Args: []string{"max"}},
			DefaultValue: Unparsed{Type: NVarcharLiteralToken, RawValue: "N'asdfas\n\nlkjlkjlkjasdf'"},
			IsOutput:     true,
		},
		{
			VariableName: "@c",
			Datatype: Type{
				TableTypeSchema: "[code]",
				TableTypeName:   "[something:asf asdf -- as\ndf/ MyTableType]",
			},
			IsReadonly: true,
		},
		{
			VariableName: "@d",
			Datatype:     Type{BaseType: "numeric", Args: []string{"1", "2"}},
		},
		{
			VariableName: "@e",
			Datatype:     Type{BaseType: "tinyint"},
		},
	}

	got := doc.Creates[0].WithoutPos().Parameters
	assert.Equal(t, want, got)
}

// parseAndVerifyCreate parses createStatement, which is expected to contain
// exactly one `create` statement, and verifies that serializing the parsed
// statement reproduces the input (ignoring leading newlines and spaces).
// The test is aborted if either check fails. The parsed Document is returned
// for further assertions.
func parseAndVerifyCreate(t *testing.T, filename FileRef, createStatement string) Document {
	// Report failures at the calling test's line rather than inside this helper.
	t.Helper()

	doc := ParseString(filename, createStatement)
	require.Equal(t, 1, len(doc.Creates))
	require.Equal(t, strings.TrimLeft(createStatement, "\n "), doc.Creates[0].String())
	return doc
}