Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 35 additions & 20 deletions sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ class Tokenizer(tokens.Tokenizer):
"DATETIME2": TokenType.DATETIME2,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"DECLARE": TokenType.DECLARE,
"EXEC": TokenType.COMMAND,
"EXEC": TokenType.EXECUTE,
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
"GO": TokenType.COMMAND,
"IMAGE": TokenType.IMAGE,
Expand All @@ -578,7 +578,7 @@ class Tokenizer(tokens.Tokenizer):
}
KEYWORDS.pop("/*+")

COMMANDS = {*tokens.Tokenizer.COMMANDS, TokenType.END}
COMMANDS = {*tokens.Tokenizer.COMMANDS, TokenType.END} - {TokenType.EXECUTE}

class Parser(parser.Parser):
SET_REQUIRES_ASSIGNMENT_DELIMITER = False
Expand Down Expand Up @@ -660,6 +660,7 @@ class Parser(parser.Parser):
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
TokenType.DECLARE: lambda self: self._parse_declare(),
TokenType.EXECUTE: lambda self: self._parse_execute(),
}

RANGE_PARSERS = {
Expand Down Expand Up @@ -703,6 +704,18 @@ class Parser(parser.Parser):
"ts": exp.Timestamp,
}

def _parse_execute(self) -> exp.Execute:
execute = self.expression(
exp.Execute,
this=self._parse_table(schema=True),
expressions=self._parse_csv(self._parse_expression),
)

if execute.name.lower() == "sp_executesql":
execute = self.expression(exp.ExecuteSql, **execute.args)

return execute

def _parse_datepart(self) -> exp.Extract:
this = self._parse_var(tokens=[TokenType.IDENTIFIER])
expression = self._match(TokenType.COMMA) and self._parse_bitwise()
Expand Down Expand Up @@ -866,19 +879,24 @@ def _parse_user_defined_function(
) -> t.Optional[exp.Expression]:
this = super()._parse_user_defined_function(kind=kind)

if (
kind == TokenType.FUNCTION
or isinstance(this, exp.UserDefinedFunction)
or self._match(TokenType.ALIAS, advance=False)
):
if kind == TokenType.FUNCTION or isinstance(this, exp.UserDefinedFunction):
return this

if not self._match(TokenType.WITH, advance=False):
expressions = self._parse_csv(self._parse_function_parameter)
else:
expressions = None
if kind == TokenType.PROCEDURE and this:
expressions = this.expressions
if not (
expressions or self._match_set((TokenType.ALIAS, TokenType.WITH), advance=False)
):
expressions = self._parse_csv(self._parse_function_parameter)

return self.expression(
exp.StoredProcedure,
this=this if isinstance(this, exp.Table) else this.this,
expressions=expressions,
wrapped=this.args.get("wrapped"),
)

return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
return self.expression(exp.UserDefinedFunction, this=this)

def _parse_into(self) -> t.Optional[exp.Into]:
into = super()._parse_into()
Expand Down Expand Up @@ -922,16 +940,13 @@ def _parse_create(self) -> exp.Create | exp.Command:

return create

def _parse_if(self) -> t.Optional[exp.Expression]:
index = self._index
def _parse_if(self) -> exp.IfBlock:
this = self._parse_condition()
true = self._parse_block()

if self._match_text_seq("OBJECT_ID"):
self._parse_wrapped_csv(self._parse_string)
if self._match_text_seq("IS", "NOT", "NULL") and self._match(TokenType.DROP):
return self._parse_drop(exists=True)
self._retreat(index)
false = self._match(TokenType.ELSE) and self._parse_block()

return super()._parse_if()
return self.expression(exp.IfBlock, this=this, true=true, false=false)

def _parse_unique(self) -> exp.UniqueColumnConstraint:
if self._match_texts(("CLUSTERED", "NONCLUSTERED")):
Expand Down
33 changes: 32 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,7 +1566,6 @@ class Create(DDL):
"indexes": False,
"no_schema_binding": False,
"begin": False,
"end": False,
"clone": False,
"concurrently": False,
"clustered": False,
Expand Down Expand Up @@ -8552,6 +8551,38 @@ class Variadic(Expression):
pass


class StoredProcedure(Expression):
arg_types = {"this": True, "expressions": False, "wrapped": False}


class Block(Expression):
arg_types = {"expressions": True}


class IfBlock(Expression):
arg_types = {"this": True, "true": True, "false": False}


class WhileBlock(Expression):
arg_types = {"this": True, "body": True}


class EndStatement(Expression):
arg_types = {}


class Execute(Expression):
arg_types = {"this": True, "expressions": False}

@property
def name(self) -> str:
return self.this.name


class ExecuteSql(Execute):
pass


ALL_FUNCTIONS = subclasses(__name__, Func, {AggFunc, Anonymous, Func})
FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()}

Expand Down
39 changes: 37 additions & 2 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class Generator(metaclass=_Generator):
exp.DynamicProperty: lambda *_: "DYNAMIC",
exp.EmptyProperty: lambda *_: "EMPTY",
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
exp.EndStatement: lambda *_: "END",
exp.EnviromentProperty: lambda self, e: f"ENVIRONMENT ({self.expressions(e, flat=True)})",
exp.EphemeralColumnConstraint: lambda self,
e: f"EPHEMERAL{(' ' + self.sql(e, 'this')) if e.this else ''}",
Expand Down Expand Up @@ -1218,11 +1219,10 @@ def create_sql(self, expression: exp.Create) -> str:
properties_sql = f" {properties_sql}"

begin = " BEGIN" if expression.args.get("begin") else ""
end = " END" if expression.args.get("end") else ""

expression_sql = self.sql(expression, "expression")
if expression_sql:
expression_sql = f"{begin}{self.sep()}{expression_sql}{end}"
expression_sql = f"{begin}{self.sep()}{expression_sql}"

if self.CREATE_FUNCTION_RETURN_AS or not isinstance(expression.expression, exp.Return):
postalias_props_sql = ""
Expand Down Expand Up @@ -5596,3 +5596,38 @@ def chr_sql(self, expression: exp.Chr, name: str = "CHR") -> str:
charset = self.sql(expression, "charset")
using = f" USING {charset}" if charset else ""
return self.func(name, this + using)

def storedprocedure_sql(self, expression: exp.StoredProcedure) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression)
expressions = (
self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
)
return f"{this}{expressions}" if expressions.strip() != "" else this

def ifblock_sql(self, expression: exp.IfBlock) -> str:
this = self.sql(expression, "this")
true = self.sql(expression, "true")
true = f" {true}" if true else " "
false = self.sql(expression, "false")
false = f"; ELSE BEGIN {false}" if false else ""
return f"IF {this} BEGIN{true}{false}"

def whileblock_sql(self, expression: exp.WhileBlock) -> str:
this = self.sql(expression, "this")
body = self.sql(expression, "body")
body = f" {body}" if body else " "
return f"WHILE {this} BEGIN{body}"

def block_sql(self, expression: exp.Block) -> str:
expressions = self.expressions(expression, sep="; ", flat=True)
return f"{expressions}" if expressions else ""

def execute_sql(self, expression: exp.Execute) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression)
expressions = f" {expressions}" if expressions else ""
return f"EXECUTE {this}{expressions}"

def executesql_sql(self, expression: exp.ExecuteSql) -> str:
return self.execute_sql(expression)
94 changes: 77 additions & 17 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1706,6 +1706,8 @@ def _parse_partitioned_by_bucket_or_truncate(self) -> t.Optional[exp.Expression]
"_prev",
"_prev_comments",
"_pipe_cte_counter",
"_chunks",
"_chunk_index",
)

# Autofilled
Expand Down Expand Up @@ -1738,6 +1740,10 @@ def reset(self):
self._prev_comments = None
self._pipe_cte_counter = 0

# State necessary for parsing imperative SQL
self._chunks: t.List[t.List[Token]] = []
self._chunk_index = 0

def parse(
self, raw_tokens: t.List[Token], sql: t.Optional[str] = None
) -> t.List[t.Optional[exp.Expression]]:
Expand Down Expand Up @@ -1792,6 +1798,38 @@ def parse_into(
errors=merge_errors(errors),
) from errors[-1]

def _parse_batch_statements(
self,
parse_method: t.Callable[[Parser], t.Optional[exp.Expression]],
sep_first_statement: bool = True,
) -> t.List[t.Optional[exp.Expression]]:
expressions = []

# Chunkification binds if/while statements with the first statement of the body
if sep_first_statement:
self._match(TokenType.BEGIN)
expressions.append(parse_method(self))

chunks_length = len(self._chunks)
while self._chunk_index < chunks_length:
self._advance_chunk()

if self._match(TokenType.ELSE, advance=False):
return expressions

if not self._next and self._match(TokenType.END):
expressions.append(exp.EndStatement())
continue

expressions.append(parse_method(self))

if self._index < len(self._tokens):
self.raise_error("Invalid expression / Unexpected token")

self.check_errors()

return expressions

def _parse(
self,
parse_method: t.Callable[[Parser], t.Optional[exp.Expression]],
Expand All @@ -1814,21 +1852,16 @@ def _parse(
else:
chunks[-1].append(token)

expressions = []

for tokens in chunks:
self._index = -1
self._tokens = tokens
self._advance()
self._chunks = chunks

expressions.append(parse_method(self))
expression = self._parse_batch_statements(
parse_method=parse_method, sep_first_statement=False
)

if self._index < len(self._tokens):
self.raise_error("Invalid expression / Unexpected token")
if expression and len(expression) > 1:
expression = [exp.Block(expressions=expression)]

self.check_errors()

return expressions
return expression

def check_errors(self) -> None:
"""Logs or raises any found errors, depending on the chosen error level setting."""
Expand Down Expand Up @@ -1935,6 +1968,12 @@ def _advance(self, times: int = 1) -> None:
self._prev = None
self._prev_comments = None

def _advance_chunk(self) -> None:
self._index = -1
self._tokens = self._chunks[self._chunk_index]
self._chunk_index += 1
self._advance()

def _retreat(self, index: int) -> None:
if index != self._index:
self._advance(index - self._index)
Expand Down Expand Up @@ -2056,6 +2095,24 @@ def _parse_ttl_action() -> t.Optional[exp.Expression]:
aggregates=aggregates,
)

def _parse_condition(self) -> t.Any:
return self._parse_wrapped(parse_method=self._parse_expression, optional=True)

def _parse_block(self) -> exp.Block:
return self.expression(
exp.Block,
expressions=self._parse_batch_statements(
parse_method=lambda self: self._parse_statement()
),
)

def _parse_whileblock(self) -> exp.WhileBlock:
return self.expression(
exp.WhileBlock,
this=self._parse_condition(),
body=self._parse_block(),
)

def _parse_statement(self) -> t.Optional[exp.Expression]:
if self._curr is None:
return None
Expand All @@ -2069,6 +2126,9 @@ def _parse_statement(self) -> t.Optional[exp.Expression]:
if self._match_set(self.dialect.tokenizer_class.COMMANDS):
return self._parse_command()

if self._match_text_seq("WHILE"):
return self._parse_whileblock()

expression = self._parse_expression()
expression = self._parse_set_operations(expression) if expression else self._parse_select()
return self._parse_query_modifiers(expression)
Expand Down Expand Up @@ -2164,7 +2224,6 @@ def _parse_create(self) -> exp.Create | exp.Command:
indexes = None
no_schema_binding = None
begin = None
end = None
clone = None

def extend_props(temp_props: t.Optional[exp.Properties]) -> None:
Expand Down Expand Up @@ -2196,9 +2255,11 @@ def extend_props(temp_props: t.Optional[exp.Properties]) -> None:
expression = self._parse_string()
extend_props(self._parse_properties())
else:
expression = self._parse_user_defined_function_expression()

end = self._match_text_seq("END")
expression = (
self._parse_user_defined_function_expression()
if create_token.token_type == TokenType.FUNCTION
else self._parse_block()
)

if return_:
expression = self.expression(exp.Return, this=expression)
Expand Down Expand Up @@ -2305,7 +2366,6 @@ def extend_props(temp_props: t.Optional[exp.Properties]) -> None:
indexes=indexes,
no_schema_binding=no_schema_binding,
begin=begin,
end=end,
clone=clone,
concurrently=concurrently,
clustered=clustered,
Expand Down
Loading