Skip to content

Commit 0df1b3d

Browse files
feat(sql) Add SQL support for UPDATE, INSERT, UPSERT
1 parent 0eaebbf commit 0df1b3d

File tree

3 files changed

+520
-9
lines changed

3 files changed

+520
-9
lines changed

docs/docs/content/documentation/plugins/sql.md

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,124 @@ employees.sql("SELECT name, salary / 12 AS monthly_salary FROM self")
147147
employees.sql("SELECT name, salary + 5000 AS adjusted_salary FROM self")
148148
```
149149

150+
## UPDATE Statements
151+
152+
Modify existing rows using UPDATE with SET and optional WHERE:
153+
154+
```python
155+
# Update all rows
156+
employees.sql("UPDATE self SET salary = 50000.0")
157+
158+
# Update with WHERE clause
159+
employees.sql("UPDATE self SET salary = 100000.0 WHERE level = 'senior'")
160+
161+
# Update multiple columns
162+
employees.sql("UPDATE self SET salary = 75000.0, level = 'mid' WHERE id = 5")
163+
164+
# Update with expressions (using column values)
165+
employees.sql("UPDATE self SET salary = salary * 1.1 WHERE rating > 4.0")
166+
167+
# Update with complex WHERE
168+
employees.sql("UPDATE self SET salary = salary + 5000.0 WHERE (dept = 'eng' OR dept = 'sales') AND years > 3")
169+
```
170+
171+
**Note:** UPDATE returns a new table with the changes applied. The original table is not mutated.
172+
173+
## INSERT Statements
174+
175+
Add new rows to a table using INSERT with VALUES:
176+
177+
```python
178+
>>> employees = Table({
179+
... "id": Vector([1, 2], ray_type=I64),
180+
... "name": Vector(["Alice", "Bob"], ray_type=Symbol),
181+
... "salary": Vector([50000.0, 60000.0], ray_type=F64),
182+
... })
183+
184+
# Insert single row with column names
185+
>>> result = employees.sql("INSERT INTO self (id, name, salary) VALUES (3, 'Charlie', 70000.0)")
186+
>>> print(len(result))
187+
3
188+
189+
# Insert multiple rows (into the 3-row result above, since `employees` itself is not mutated)
189+
>>> result = result.sql("""
191+
... INSERT INTO self (id, name, salary)
192+
... VALUES (4, 'Diana', 55000.0), (5, 'Eve', 65000.0)
193+
... """)
194+
>>> print(len(result))
195+
5
196+
197+
# Insert without column names (values must match table column order)
198+
>>> result = employees.sql("INSERT INTO self VALUES (6, 'Frank', 72000.0)")
199+
```
200+
201+
**Note:** INSERT returns a new table with the added rows. The original table is not mutated.
202+
203+
## UPSERT (INSERT ... ON CONFLICT)
204+
205+
Perform upsert operations (insert or update) using the `ON CONFLICT` clause:
206+
207+
```python
208+
>>> products = Table({
209+
... "id": Vector([1, 2], ray_type=I64),
210+
... "name": Vector(["Widget", "Gadget"], ray_type=Symbol),
211+
... "price": Vector([10.0, 20.0], ray_type=F64),
212+
... })
213+
214+
# Update existing row (id=1 exists)
215+
>>> result = products.sql("""
216+
... INSERT INTO self (id, name, price)
217+
... VALUES (1, 'Widget Pro', 15.0)
218+
... ON CONFLICT (id) DO UPDATE
219+
... """)
220+
>>> print(result["name"][0].value)
221+
Widget Pro
222+
223+
# Insert new row (id=3 doesn't exist)
224+
>>> result = products.sql("""
225+
... INSERT INTO self (id, name, price)
226+
... VALUES (3, 'Gizmo', 30.0)
227+
... ON CONFLICT (id) DO UPDATE
228+
... """)
229+
>>> print(len(result))
230+
3
231+
232+
# Upsert multiple rows at once
233+
>>> result = products.sql("""
234+
... INSERT INTO self (id, name, price)
235+
... VALUES (1, 'Widget Updated', 12.0), (4, 'Doohickey', 25.0)
236+
... ON CONFLICT (id) DO UPDATE
237+
... """)
238+
```
239+
240+
### Composite Keys
241+
242+
Use multiple columns as the conflict key by listing them in the `ON CONFLICT` clause:
243+
244+
```python
245+
>>> inventory = Table({
246+
... "region": Vector(["US", "EU"], ray_type=Symbol),
247+
... "sku": Vector(["A001", "A001"], ray_type=Symbol),
248+
... "quantity": Vector([100, 50], ray_type=I64),
249+
... })
250+
251+
# Use (region, sku) as composite key
252+
>>> result = inventory.sql("""
253+
... INSERT INTO self (region, sku, quantity)
254+
... VALUES ('US', 'A001', 150)
255+
... ON CONFLICT (region, sku) DO UPDATE
256+
... """)
257+
```
258+
259+
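**Important:** The conflict key columns must match the first N columns in the INSERT column list (in order). This is required because Rayforce uses positional key columns. For example, `ON CONFLICT (region, sku)` requires the column list to start with `region, sku`; a statement such as `INSERT INTO self (sku, region, quantity) ... ON CONFLICT (region, sku) DO UPDATE` raises a `ValueError`.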
260+
261+
**Note:** UPSERT returns a new table with the changes applied. The original table is not mutated. `ON CONFLICT DO NOTHING` is not supported.
262+
150263
## Limitations
151264

152265
The current SQL implementation supports common query patterns but has some limitations:
153266

154-
- Only `SELECT` statements are supported (no `INSERT`, `UPDATE`, `DELETE`)
267+
- `DELETE` statements are not supported
155268
- `JOIN` operations are not yet supported via SQL (use the native `.inner_join()`, `.left_join()` methods)
156269
- Subqueries are not supported
157270
- `HAVING` clause is not supported

rayforce/plugins/sql.py

Lines changed: 193 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,45 @@ class ParsedSelect:
7777
order_by: list[tuple[str, bool]] = field(default_factory=list) # (col, is_desc)
7878

7979

80+
@dataclass
class ParsedUpdate:
    """Parsed form of an ``UPDATE ... SET ... [WHERE ...]`` statement."""

    # Target column name -> parsed expression producing its new value.
    assignments: dict[str, ParsedExpr]
    # Parsed WHERE predicate; None when the statement has no WHERE clause.
    where_clause: ParsedExpr | None = None
84+
85+
86+
@dataclass
class ParsedInsert:
    """Parsed form of an ``INSERT INTO ... VALUES ...`` statement."""

    # Explicit column list, or None when the statement did not name columns.
    columns: list[str] | None
    # One inner list of parsed value expressions per VALUES row.
    values: list[list[ParsedExpr]]
90+
91+
92+
@dataclass
class ParsedUpsert:
    """Parsed form of an ``INSERT ... ON CONFLICT (...) DO UPDATE`` statement."""

    # Explicit column list, or None when the statement did not name columns.
    columns: list[str] | None
    # One inner list of parsed value expressions per VALUES row.
    values: list[list[ParsedExpr]]
    # Number of leading columns that form the conflict key.
    key_columns: int
97+
98+
99+
# Union of the parsed statement types: the return type of SQLParser.parse and
# the input type of SQLCompiler.compile.
ParsedQuery = ParsedSelect | ParsedUpdate | ParsedInsert | ParsedUpsert
100+
101+
80102
class SQLParser:
81-
def parse(self, sql: str) -> ParsedQuery:
    """Parse one SQL statement into its structured representation.

    Args:
        sql: Text of a single SQL statement.

    Returns:
        A ParsedSelect, ParsedUpdate, ParsedInsert or ParsedUpsert, chosen
        from the statement's top-level kind.

    Raises:
        ValueError: If the statement is not a SELECT, UPDATE or INSERT.
    """
    ast = _ensure_sqlglot().parse_one(sql)
    kind = ast.key

    if kind == "select":
        return self._parse_select(ast)
    if kind == "update":
        return self._parse_update(ast)
    if kind == "insert":
        # An INSERT carrying a conflict clause is handled as an UPSERT.
        has_conflict = ast.args.get("conflict")
        handler = self._parse_upsert if has_conflict else self._parse_insert
        return handler(ast)

    raise ValueError(
        f"Only SELECT, UPDATE, INSERT, and UPSERT statements are supported, got: {ast.key}"
    )
89119

90120
def _parse_select(self, node: exp.Select) -> ParsedSelect:
91121
import sqlglot.expressions as exp
@@ -127,6 +157,97 @@ def _parse_select(self, node: exp.Select) -> ParsedSelect:
127157
order_by=order_by,
128158
)
129159

160+
def _parse_update(self, node: exp.Expression) -> ParsedUpdate:
    """Convert an UPDATE AST node into a ParsedUpdate.

    Args:
        node: The sqlglot UPDATE expression.

    Returns:
        A ParsedUpdate holding one parsed expression per assigned column and
        the parsed WHERE predicate (None when there is no WHERE clause).
    """
    import sqlglot.expressions as exp

    assignments: dict[str, ParsedExpr] = {}
    for assignment in node.expressions:
        # Only `target = value` entries are recorded; anything else is skipped.
        if not isinstance(assignment, exp.EQ):
            continue
        target = assignment.this
        col_name = target.name if isinstance(target, exp.Column) else str(target)
        assignments[col_name] = self._parse_expr(assignment.expression)

    where_node = node.args.get("where")
    where_clause = self._parse_expr(where_node.this) if where_node else None

    return ParsedUpdate(assignments=assignments, where_clause=where_clause)
174+
175+
def _parse_insert(self, node: exp.Expression) -> ParsedInsert:
    """Convert an INSERT AST node into a ParsedInsert.

    Args:
        node: The sqlglot INSERT expression.

    Returns:
        A ParsedInsert with the optional column list and the parsed VALUES rows.

    Raises:
        ValueError: If no VALUES rows could be extracted from the statement.
    """
    import sqlglot.expressions as exp

    # The column list is present only when the insert target carries expressions.
    target = node.this
    columns: list[str] | None = None
    if hasattr(target, "expressions") and target.expressions:
        columns = []
        for col in target.expressions:
            columns.append(col.name if isinstance(col, exp.Column) else str(col))

    values: list[list[ParsedExpr]] = []
    values_clause = node.args.get("expression")
    if values_clause and hasattr(values_clause, "expressions"):
        values = [
            [self._parse_expr(val) for val in row_tuple.expressions]
            for row_tuple in values_clause.expressions
            if hasattr(row_tuple, "expressions")
        ]

    if not values:
        raise ValueError("INSERT statement must have VALUES")

    return ParsedInsert(columns=columns, values=values)
197+
198+
def _parse_upsert(self, node: exp.Expression) -> ParsedUpsert:
    """Convert an ``INSERT ... ON CONFLICT (keys) DO UPDATE`` node into a ParsedUpsert.

    Args:
        node: The sqlglot INSERT expression; its ``conflict`` arg holds the
            conflict clause.

    Returns:
        A ParsedUpsert with the optional column list, the parsed VALUES rows
        and the number of leading key columns.

    Raises:
        ValueError: If there are no VALUES rows, the action is DO NOTHING,
            no conflict key columns are given (including
            ``ON DUPLICATE KEY UPDATE``), or the conflict keys are not the
            leading columns of the column list.
    """
    import sqlglot.expressions as exp

    columns: list[str] | None = None
    if hasattr(node.this, "expressions") and node.this.expressions:
        columns = [
            col.name if isinstance(col, exp.Column) else str(col)
            for col in node.this.expressions
        ]

    values: list[list[ParsedExpr]] = []
    values_clause = node.args.get("expression")
    if values_clause and hasattr(values_clause, "expressions"):
        for row_tuple in values_clause.expressions:
            if hasattr(row_tuple, "expressions"):
                values.append([self._parse_expr(val) for val in row_tuple.expressions])

    if not values:
        raise ValueError("UPSERT statement must have VALUES")

    # NOTE: any SET assignments attached to the conflict clause are not read here.
    conflict = node.args.get("conflict")
    conflict_keys: list[str] = []
    if conflict:
        action = conflict.args.get("action")
        if action and str(action) == "DO NOTHING":
            raise ValueError("ON CONFLICT DO NOTHING is not supported, use DO UPDATE")

        # Get conflict key columns. The arg can be present with an explicit
        # None (no key list given), in which case a `.get` default would not
        # apply and iterating would raise TypeError; `or []` covers both cases
        # so the ValueErrors below stay reachable.
        keys = conflict.args.get("conflict_keys") or []
        conflict_keys = [k.name if hasattr(k, "name") else str(k) for k in keys]

        # MySQL-style ON DUPLICATE KEY doesn't have conflict_keys
        if conflict.args.get("duplicate") and not conflict_keys:
            raise ValueError(
                "ON DUPLICATE KEY UPDATE requires explicit key columns. "
                "Use: ON CONFLICT (col1, col2) DO UPDATE"
            )

    if not conflict_keys:
        raise ValueError("UPSERT requires ON CONFLICT (key_columns) clause")

    # Validate that conflict keys match the first N columns
    if columns:
        for i, key in enumerate(conflict_keys):
            if i >= len(columns) or columns[i] != key:
                raise ValueError(
                    f"Conflict key '{key}' must match the first {len(conflict_keys)} columns. "
                    f"Expected '{columns[i] if i < len(columns) else 'N/A'}' at position {i}"
                )

    return ParsedUpsert(columns=columns, values=values, key_columns=len(conflict_keys))
250+
130251
def _parse_expr(self, node: exp.Expression) -> ParsedExpr:
131252
import sqlglot.expressions as exp
132253

@@ -240,7 +361,18 @@ def _parse_expr(self, node: exp.Expression) -> ParsedExpr:
240361

241362

242363
class SQLCompiler:
243-
def compile(self, parsed: ParsedQuery, table: Table) -> Table:
    """Run a parsed statement against ``table`` via the matching compiler.

    Args:
        parsed: Result of ``SQLParser.parse``.
        table: Table the statement is applied to.

    Returns:
        The table produced by the statement-specific ``_compile_*`` method.

    Raises:
        ValueError: If ``parsed`` is not one of the supported parsed types.
    """
    # Checked in order; the first matching type wins.
    dispatch = (
        (ParsedSelect, self._compile_select),
        (ParsedUpdate, self._compile_update),
        (ParsedInsert, self._compile_insert),
        (ParsedUpsert, self._compile_upsert),
    )
    for query_type, handler in dispatch:
        if isinstance(parsed, query_type):
            return handler(parsed, table)

    raise ValueError(f"Unsupported query type: {type(parsed).__name__}")
374+
375+
def _compile_select(self, parsed: ParsedSelect, table: Table) -> Table:
244376
select_args: list[str] = []
245377
select_kwargs: dict[str, t.Any] = {}
246378

@@ -274,6 +406,61 @@ def compile(self, parsed: ParsedSelect, table: Table) -> Table:
274406

275407
return query.execute()
276408

409+
def _compile_update(self, parsed: ParsedUpdate, table: Table) -> Table:
    """Build and execute an update query on ``table`` from a ParsedUpdate.

    Args:
        parsed: Assignments and optional WHERE predicate to apply.
        table: Table to run the update against.

    Returns:
        Whatever the executed update query returns.
    """
    update_kwargs: dict[str, t.Any] = {
        col_name: self._compile_expr(expr)
        for col_name, expr in parsed.assignments.items()
    }
    query = table.update(**update_kwargs)

    # The filter is attached only when a WHERE clause was parsed.
    if parsed.where_clause:
        query = query.where(self._compile_expr(parsed.where_clause))

    return query.execute()
422+
423+
def _compile_insert(self, parsed: ParsedInsert, table: Table) -> Table:
    """Build and execute an insert query on ``table`` from a ParsedInsert.

    With a column list, values are passed to ``table.insert`` as keyword
    arguments (one list per column). Without one, a single row is passed as
    positional scalars and multiple rows are transposed into one positional
    list per column.

    Args:
        parsed: Optional column list plus the VALUES rows to insert.
        table: Table to run the insert against.

    Returns:
        Whatever the executed insert query returns.

    Raises:
        ValueError: If there are no rows, or a row's value count differs from
            the column list (or from the first row when no columns are named).
    """
    compiled_rows: list[list[t.Any]] = [
        [self._compile_expr(val) for val in row] for row in parsed.values
    ]
    if not compiled_rows:
        raise ValueError("INSERT statement must have VALUES")

    # Reject ragged rows up front: otherwise short rows fail with a bare
    # IndexError and long rows silently lose their extra values.
    expected = len(parsed.columns) if parsed.columns else len(compiled_rows[0])
    for row_number, row in enumerate(compiled_rows, start=1):
        if len(row) != expected:
            raise ValueError(
                f"INSERT row {row_number} has {len(row)} values, expected {expected}"
            )

    if parsed.columns:
        # INSERT with column names: use kwargs style
        insert_kwargs: dict[str, list[t.Any]] = {col: [] for col in parsed.columns}
        for row in compiled_rows:
            for col, value in zip(parsed.columns, row):
                insert_kwargs[col].append(value)
        return table.insert(**insert_kwargs).execute()

    if len(compiled_rows) == 1:
        return table.insert(*compiled_rows[0]).execute()

    # Transpose rows into per-column lists (widths were validated above).
    col_values = [list(column) for column in zip(*compiled_rows)]
    return table.insert(*col_values).execute()
443+
444+
def _compile_upsert(self, parsed: ParsedUpsert, table: Table) -> Table:
    """Build and execute an upsert query on ``table`` from a ParsedUpsert.

    Values are shaped exactly as for INSERT (keyword lists with a column
    list; positional scalars for one row or positional per-column lists for
    several rows without one), and ``key_columns`` is forwarded to
    ``table.upsert``.

    Args:
        parsed: Optional column list, VALUES rows and leading key-column count.
        table: Table to run the upsert against.

    Returns:
        Whatever the executed upsert query returns.

    Raises:
        ValueError: If there are no rows, a row's value count differs from
            the column list (or from the first row when no columns are
            named), or ``key_columns`` exceeds the number of values per row.
    """
    compiled_rows: list[list[t.Any]] = [
        [self._compile_expr(val) for val in row] for row in parsed.values
    ]
    if not compiled_rows:
        raise ValueError("UPSERT statement must have VALUES")

    # Reject ragged rows up front: otherwise short rows fail with a bare
    # IndexError and long rows silently lose their extra values.
    expected = len(parsed.columns) if parsed.columns else len(compiled_rows[0])
    for row_number, row in enumerate(compiled_rows, start=1):
        if len(row) != expected:
            raise ValueError(
                f"UPSERT row {row_number} has {len(row)} values, expected {expected}"
            )

    # Key columns are the leading columns, so there cannot be more of them
    # than values in a row.
    if parsed.key_columns > expected:
        raise ValueError(
            f"UPSERT declares {parsed.key_columns} key columns but rows have "
            f"only {expected} values"
        )

    if parsed.columns:
        upsert_kwargs: dict[str, list[t.Any]] = {col: [] for col in parsed.columns}
        for row in compiled_rows:
            for col, value in zip(parsed.columns, row):
                upsert_kwargs[col].append(value)
        return table.upsert(**upsert_kwargs, key_columns=parsed.key_columns).execute()

    if len(compiled_rows) == 1:
        return table.upsert(*compiled_rows[0], key_columns=parsed.key_columns).execute()

    # Transpose rows into per-column lists (widths were validated above).
    col_values = [list(column) for column in zip(*compiled_rows)]
    return table.upsert(*col_values, key_columns=parsed.key_columns).execute()
463+
277464
def _compile_expr(self, expr: ParsedExpr) -> t.Any:
278465
from rayforce.types.table import Column
279466

0 commit comments

Comments
 (0)