@@ -77,15 +77,45 @@ class ParsedSelect:
7777 order_by : list [tuple [str , bool ]] = field (default_factory = list ) # (col, is_desc)
7878
7979
@dataclass
class ParsedUpdate:
    """Parsed form of an ``UPDATE`` statement.

    ``assignments`` maps each target column name to the parsed expression
    assigned to it.  ``where_clause`` holds the parsed ``WHERE`` predicate,
    or ``None`` when the statement has no ``WHERE``.
    """

    assignments: dict[str, ParsedExpr]
    where_clause: ParsedExpr | None = None
84+
85+
@dataclass
class ParsedInsert:
    """Parsed form of an ``INSERT ... VALUES`` statement.

    ``columns`` is the explicit column list, or ``None`` when the statement
    names no columns.  ``values`` holds one inner list of parsed expressions
    per ``VALUES`` row.
    """

    columns: list[str] | None
    values: list[list[ParsedExpr]]
90+
91+
@dataclass
class ParsedUpsert:
    """Parsed form of an ``INSERT ... ON CONFLICT`` (upsert) statement.

    ``columns`` and ``values`` have the same meaning as on ``ParsedInsert``.
    ``key_columns`` is the number of conflict-key columns; the parser sets it
    to the length of the ``ON CONFLICT (...)`` key list.
    """

    columns: list[str] | None
    values: list[list[ParsedExpr]]
    key_columns: int
97+
98+
# Union of the statement shapes produced by the parser and accepted by the compiler.
ParsedQuery = (
    ParsedSelect
    | ParsedUpdate
    | ParsedInsert
    | ParsedUpsert
)
100+
101+
80102class SQLParser :
def parse(self, sql: str) -> ParsedQuery:
    """Parse one SQL statement into its parsed representation.

    The statement is parsed with sqlglot and routed by its top-level kind:
    ``select``, ``update`` or ``insert``.  An ``insert`` that carries a
    conflict clause is handled as an upsert.

    Raises:
        ValueError: if the statement is of any other kind.
    """
    tree = _ensure_sqlglot().parse_one(sql)
    kind = tree.key

    if kind == "insert":
        # An INSERT with an ON CONFLICT clause is routed to the upsert parser.
        if tree.args.get("conflict"):
            return self._parse_upsert(tree)
        return self._parse_insert(tree)
    if kind == "update":
        return self._parse_update(tree)
    if kind == "select":
        return self._parse_select(tree)

    raise ValueError(
        f"Only SELECT, UPDATE, INSERT, and UPSERT statements are supported, got: {kind}"
    )
89119
90120 def _parse_select (self , node : exp .Select ) -> ParsedSelect :
91121 import sqlglot .expressions as exp
@@ -127,6 +157,97 @@ def _parse_select(self, node: exp.Select) -> ParsedSelect:
127157 order_by = order_by ,
128158 )
129159
def _parse_update(self, node: exp.Expression) -> ParsedUpdate:
    """Build a ``ParsedUpdate`` from an UPDATE syntax tree.

    Every ``SET`` item that is an equality contributes one assignment keyed
    by the target column name; items of any other shape are skipped.  The
    optional ``WHERE`` predicate is parsed when present.
    """
    import sqlglot.expressions as exp

    assignments: dict[str, ParsedExpr] = {}
    for item in node.expressions:
        if not isinstance(item, exp.EQ):
            continue
        target = item.this
        # Column nodes give their bare name; anything else falls back to its SQL text.
        key = target.name if isinstance(target, exp.Column) else str(target)
        assignments[key] = self._parse_expr(item.expression)

    where_node = node.args.get("where")
    predicate = self._parse_expr(where_node.this) if where_node else None

    return ParsedUpdate(assignments=assignments, where_clause=predicate)
174+
def _parse_insert(self, node: exp.Expression) -> ParsedInsert:
    """Build a ``ParsedInsert`` from an INSERT syntax tree.

    Only ``INSERT ... VALUES (...)`` is supported.  The explicit column list
    (if any) is read from the insert target, and every ``VALUES`` row is
    parsed expression by expression.

    Raises:
        ValueError: if the statement has no ``VALUES`` clause (for example
            ``INSERT ... SELECT``) or the clause contains no rows.
    """
    import sqlglot.expressions as exp

    columns: list[str] | None = None
    listed = getattr(node.this, "expressions", None)
    if listed:
        # Entries in the column list may be Identifier nodes rather than
        # Column nodes; ``.name`` yields the bare name for both, whereas
        # ``str()`` would keep SQL quoting for quoted identifiers.
        columns = [
            col.name if isinstance(col, (exp.Column, exp.Identifier)) else str(col)
            for col in listed
        ]

    # Every sqlglot expression exposes ``.expressions``, so a hasattr() probe
    # cannot tell VALUES apart from e.g. a SELECT source; check the node type.
    source = node.args.get("expression")
    if not isinstance(source, exp.Values):
        raise ValueError("INSERT statement must have VALUES")

    values: list[list[ParsedExpr]] = [
        [self._parse_expr(val) for val in row.expressions]
        for row in source.expressions
    ]
    if not values:
        raise ValueError("INSERT statement must have VALUES")

    return ParsedInsert(columns=columns, values=values)
197+
def _parse_upsert(self, node: exp.Expression) -> ParsedUpsert:
    """Build a ``ParsedUpsert`` from an INSERT tree that has a conflict clause.

    Columns and ``VALUES`` rows are extracted as for a plain insert.  The
    conflict clause must list explicit key columns, and when the statement
    also has a column list, those keys must equal its leading columns in
    the same order.

    Raises:
        ValueError: if there is no ``VALUES`` clause or it has no rows, if
            the action is ``DO NOTHING``, if no conflict key columns are
            given (including ``ON DUPLICATE KEY UPDATE``), or if the keys
            do not match the leading columns.
    """
    import sqlglot.expressions as exp

    columns: list[str] | None = None
    listed = getattr(node.this, "expressions", None)
    if listed:
        # Use ``.name`` for Identifier nodes too, so quoted column names are
        # compared against the conflict keys without their SQL quoting.
        columns = [
            col.name if isinstance(col, (exp.Column, exp.Identifier)) else str(col)
            for col in listed
        ]

    # Every sqlglot expression exposes ``.expressions``, so a hasattr() probe
    # cannot tell VALUES apart from e.g. a SELECT source; check the node type.
    source = node.args.get("expression")
    if not isinstance(source, exp.Values):
        raise ValueError("UPSERT statement must have VALUES")

    values: list[list[ParsedExpr]] = [
        [self._parse_expr(val) for val in row.expressions]
        for row in source.expressions
    ]
    if not values:
        raise ValueError("UPSERT statement must have VALUES")

    conflict = node.args.get("conflict")
    conflict_keys: list[str] = []
    if conflict:
        action = conflict.args.get("action")
        if action and str(action) == "DO NOTHING":
            raise ValueError("ON CONFLICT DO NOTHING is not supported, use DO UPDATE")

        # The arg may be present with a None value, in which case a
        # ``.get(..., [])`` default is not applied; normalise to a list so
        # the explicit errors below are reached instead of a TypeError.
        keys = conflict.args.get("conflict_keys") or []
        conflict_keys = [k.name if hasattr(k, "name") else str(k) for k in keys]

        # MySQL-style ON DUPLICATE KEY doesn't have conflict_keys
        if conflict.args.get("duplicate") and not conflict_keys:
            raise ValueError(
                "ON DUPLICATE KEY UPDATE requires explicit key columns. "
                "Use: ON CONFLICT (col1, col2) DO UPDATE"
            )

    if not conflict_keys:
        raise ValueError("UPSERT requires ON CONFLICT (key_columns) clause")

    # The result carries only a count of key columns, so the keys must be
    # exactly the first N listed columns, in order.
    if columns:
        for i, key in enumerate(conflict_keys):
            if i >= len(columns) or columns[i] != key:
                raise ValueError(
                    f"Conflict key '{key}' must match the first {len(conflict_keys)} columns. "
                    f"Expected '{columns[i] if i < len(columns) else 'N/A'}' at position {i}"
                )

    return ParsedUpsert(columns=columns, values=values, key_columns=len(conflict_keys))
250+
130251 def _parse_expr (self , node : exp .Expression ) -> ParsedExpr :
131252 import sqlglot .expressions as exp
132253
@@ -240,7 +361,18 @@ def _parse_expr(self, node: exp.Expression) -> ParsedExpr:
240361
241362
242363class SQLCompiler :
def compile(self, parsed: ParsedQuery, table: Table) -> Table:
    """Run a parsed statement against ``table`` and return the result.

    The parsed object is routed by type to the matching ``_compile_*``
    method.

    Raises:
        ValueError: if ``parsed`` is not one of the supported parsed types.
    """
    handlers = (
        (ParsedSelect, self._compile_select),
        (ParsedUpdate, self._compile_update),
        (ParsedInsert, self._compile_insert),
        (ParsedUpsert, self._compile_upsert),
    )
    for query_type, handler in handlers:
        if isinstance(parsed, query_type):
            return handler(parsed, table)
    raise ValueError(f"Unsupported query type: {type(parsed).__name__}")
374+
375+ def _compile_select (self , parsed : ParsedSelect , table : Table ) -> Table :
244376 select_args : list [str ] = []
245377 select_kwargs : dict [str , t .Any ] = {}
246378
@@ -274,6 +406,61 @@ def compile(self, parsed: ParsedSelect, table: Table) -> Table:
274406
275407 return query .execute ()
276408
def _compile_update(self, parsed: ParsedUpdate, table: Table) -> Table:
    """Apply a parsed UPDATE to ``table`` and return the executed result.

    Each assignment is compiled and passed to ``table.update`` as a keyword
    argument; the compiled ``WHERE`` predicate, when present, is attached
    with ``.where`` before the query is executed.
    """
    updates: dict[str, t.Any] = {
        name: self._compile_expr(value)
        for name, value in parsed.assignments.items()
    }

    statement = table.update(**updates)
    if parsed.where_clause:
        statement = statement.where(self._compile_expr(parsed.where_clause))

    return statement.execute()
422+
423+ def _compile_insert (self , parsed : ParsedInsert , table : Table ) -> Table :
424+ compiled_rows : list [list [t .Any ]] = []
425+ for row in parsed .values :
426+ compiled_row = [self ._compile_expr (val ) for val in row ]
427+ compiled_rows .append (compiled_row )
428+
429+ if parsed .columns :
430+ # INSERT with column names: use kwargs style
431+ insert_kwargs : dict [str , list [t .Any ]] = {col : [] for col in parsed .columns }
432+ for row in compiled_rows :
433+ for i , col in enumerate (parsed .columns ):
434+ insert_kwargs [col ].append (row [i ])
435+ return table .insert (** insert_kwargs ).execute ()
436+
437+ if len (compiled_rows ) == 1 :
438+ return table .insert (* compiled_rows [0 ]).execute ()
439+
440+ num_cols = len (compiled_rows [0 ]) # transpose
441+ col_values = [[row [i ] for row in compiled_rows ] for i in range (num_cols )]
442+ return table .insert (* col_values ).execute ()
443+
444+ def _compile_upsert (self , parsed : ParsedUpsert , table : Table ) -> Table :
445+ compiled_rows : list [list [t .Any ]] = []
446+ for row in parsed .values :
447+ compiled_row = [self ._compile_expr (val ) for val in row ]
448+ compiled_rows .append (compiled_row )
449+
450+ if parsed .columns :
451+ upsert_kwargs : dict [str , list [t .Any ]] = {col : [] for col in parsed .columns }
452+ for row in compiled_rows :
453+ for i , col in enumerate (parsed .columns ):
454+ upsert_kwargs [col ].append (row [i ])
455+ return table .upsert (** upsert_kwargs , key_columns = parsed .key_columns ).execute ()
456+
457+ if len (compiled_rows ) == 1 :
458+ return table .upsert (* compiled_rows [0 ], key_columns = parsed .key_columns ).execute ()
459+
460+ num_cols = len (compiled_rows [0 ]) # transpose
461+ col_values = [[row [i ] for row in compiled_rows ] for i in range (num_cols )]
462+ return table .upsert (* col_values , key_columns = parsed .key_columns ).execute ()
463+
277464 def _compile_expr (self , expr : ParsedExpr ) -> t .Any :
278465 from rayforce .types .table import Column
279466
0 commit comments