Skip to content

Commit ca5abd1

Browse files
feat(table) - Add pivot() method
1 parent ddbac63 commit ca5abd1

File tree

2 files changed

+294
-0
lines changed

2 files changed

+294
-0
lines changed

rayforce/types/table.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,15 @@ def window_join1(
562562
) -> WindowJoin1:
563563
return WindowJoin1(t.cast("_TableProtocol", self), on, join_with, interval, **aggregations)
564564

565+
def pivot(
    self,
    index: str | list[str],
    columns: str,
    values: str,
    aggfunc: t.Literal["sum", "count", "avg", "min", "max"] = "min",
) -> Table:
    """Pivot this table and return the resulting table.

    Builds a ``PivotQuery`` from the arguments and executes it immediately,
    so the caller receives the pivoted table rather than a query object.

    Args:
        index: Column name, or list of column names, to group rows by.
        columns: Column whose distinct values become the new columns.
        values: Column that is aggregated into each cell.
        aggfunc: Aggregation name; defaults to ``"min"``.

    Returns:
        The table produced by ``PivotQuery.execute()``.
    """
    source = t.cast("_TableProtocol", self)
    query = PivotQuery(source, index, columns, values, aggfunc)
    return query.execute()
573+
565574

566575
class Table(
567576
TableInitMixin,
@@ -1009,6 +1018,61 @@ def execute(self) -> Table:
10091018
return Table(new_table)
10101019

10111020

1021+
class PivotQuery:
    """Reshape a table so each distinct value of one column becomes a column.

    For every distinct value found in ``columns``, the rows matching that value
    are grouped by ``index`` and ``values`` is aggregated with ``aggfunc``; the
    per-value results are then joined side by side on the index columns.

    Execution issues one filter, one aggregation and one join per distinct
    pivot value, plus two up-front queries (distinct values, index groups).
    """

    # Public aggregation names accepted by ``aggfunc`` and the operation each maps to.
    AGGFUNC_MAP: t.ClassVar[dict[str, Operation]] = {
        "sum": Operation.SUM,
        "count": Operation.COUNT,
        "avg": Operation.AVG,
        "min": Operation.MIN,
        "max": Operation.MAX,
    }

    def __init__(
        self,
        table: _TableProtocol,
        index: str | list[str],
        columns: str,
        values: str,
        aggfunc: str = "min",
    ) -> None:
        """Validate ``aggfunc`` and store the pivot parameters.

        Args:
            table: Table to pivot.
            index: Column name, or list of column names, to group rows by.
                A single name is normalised to a one-element list.
            columns: Column whose distinct values become output columns.
            values: Column that is aggregated into each cell.
            aggfunc: One of the keys of ``AGGFUNC_MAP``.

        Raises:
            errors.RayforceValueError: If ``aggfunc`` is not a key of
                ``AGGFUNC_MAP``.
        """
        if aggfunc not in self.AGGFUNC_MAP:
            raise errors.RayforceValueError(
                f"Invalid aggfunc '{aggfunc}'. Must be one of: {list(self.AGGFUNC_MAP.keys())}"
            )
        self.table = t.cast("Table", table)
        self.index = [index] if isinstance(index, str) else list(index)
        self.columns = columns
        self.values = values
        self.aggfunc = aggfunc

    def execute(self) -> Table:
        """Run the pivot and return the wide table.

        Returns:
            A table with the index columns followed by one column per distinct
            pivot value, named ``str(value)``. Every index group of the source
            table is present; cells with no matching rows hold whatever
            ``left_join`` produces for unmatched rows (presumably nulls --
            TODO confirm against the join implementation).

        Raises:
            errors.RayforceValueError: If the pivot column has no values, or if
                the generated column names collide with each other or with an
                index column.
        """
        distinct = self.table.select(_col=Column(self.columns).distinct()).execute()
        unique_values = [v.value if hasattr(v, "value") else v for v in distinct["_col"]]
        if not unique_values:
            raise errors.RayforceValueError(f"No values in pivot column '{self.columns}'")

        names = [str(val) for val in unique_values]
        self._check_output_names(names)

        operation = self.AGGFUNC_MAP[self.aggfunc]

        # Start from every index group in the source table. Joining onto the
        # first per-value aggregate instead would silently drop any group that
        # has no rows for that first pivot value.
        result = self._index_groups(names)

        for val, name in zip(unique_values, names):
            filtered = (
                self.table.select(*self.index, self.values)
                .where(Column(self.columns) == val)
                .execute()
            )
            aggregated = (
                filtered.select(**{name: Expression(operation, Column(self.values))})
                .by(*self.index)
                .execute()
            )
            result = result.left_join(aggregated, on=self.index).execute()

        return result

    def _check_output_names(self, names: list[str]) -> None:
        """Reject pivot values whose column names would be ambiguous."""
        if len(set(names)) != len(names):
            raise errors.RayforceValueError(
                f"Pivot column '{self.columns}' has distinct values with the same string form"
            )
        clashes = sorted(set(names) & set(self.index))
        if clashes:
            raise errors.RayforceValueError(
                f"Pivot values {clashes} collide with index column names {self.index}"
            )

    def _index_groups(self, names: list[str]) -> Table:
        """Return a table holding one row per index group and only the index columns."""
        # A throwaway aggregate is needed to run the group-by; pick a name that
        # cannot clash with the index or the generated pivot columns.
        helper = "_pivot_base"
        while helper in names or helper in self.index:
            helper += "_"
        grouped = (
            self.table.select(**{helper: Expression(Operation.COUNT, Column(self.values))})
            .by(*self.index)
            .execute()
        )
        return grouped.select(*self.index).execute()
1074+
1075+
10121076
class TableColumnInterval:
10131077
def __init__(
10141078
self,
@@ -1053,6 +1117,7 @@ def compile(self) -> r.RayObject:
10531117
"InnerJoin",
10541118
"InsertQuery",
10551119
"LeftJoin",
1120+
"PivotQuery",
10561121
"Table",
10571122
"TableColumnInterval",
10581123
"UpdateQuery",

tests/types/table/test_pivot.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
from rayforce import I64, Symbol, Table, Vector
2+
3+
4+
def _get_pivot_values(result, index_col: str) -> dict:
    """Convert a pivoted table into ``{index_value: {column: cell_value}}``.

    Args:
        result: Table-like object exposing ``columns()`` and ``result[name]``,
            where each column yields items that have a ``to_python()`` method.
        index_col: Name of the column whose values become the outer keys.

    Returns:
        A dict keyed by the Python value of each ``index_col`` entry. Each inner
        dict maps every other column name to the Python value in that row.
        If an index value repeats, the last row with that value wins.
    """
    value_columns = [str(c) for c in result.columns() if str(c) != index_col]
    # Extract each column once up front rather than re-fetching it for every row.
    cells_by_column = {col: [v.to_python() for v in result[col]] for col in value_columns}
    return {
        key.to_python(): {col: cells_by_column[col][i] for col in value_columns}
        for i, key in enumerate(result[index_col])
    }
11+
12+
13+
def test_pivot_simple():
    """Pivot on a single index column using the default aggregation."""
    source = Table(
        {
            "symbol": Vector(items=["AAPL", "AAPL", "GOOG", "GOOG"], ray_type=Symbol),
            "metric": Vector(items=["price", "volume", "price", "volume"], ray_type=Symbol),
            "value": Vector(items=[150, 1000, 2800, 500], ray_type=I64),
        }
    )

    pivoted = source.pivot(index="symbol", columns="metric", values="value")

    column_names = [str(name) for name in pivoted.columns()]
    for wanted in ("symbol", "price", "volume"):
        assert wanted in column_names
    assert len(pivoted) == 2

    # Each (row, column) pair holds exactly one source value.
    cells = _get_pivot_values(pivoted, "symbol")
    expected = {
        "AAPL": {"price": 150, "volume": 1000},
        "GOOG": {"price": 2800, "volume": 500},
    }
    for row_key, row in expected.items():
        for col_key, want in row.items():
            assert cells[row_key][col_key] == want
35+
36+
37+
def test_pivot_with_multiple_index_columns():
    """Pivot grouped by two index columns passed as a list."""
    dates = ["2024-01-01", "2024-01-01", "2024-01-02", "2024-01-02"]
    source = Table(
        {
            "date": Vector(items=dates, ray_type=Symbol),
            "symbol": Vector(items=["AAPL", "AAPL", "AAPL", "AAPL"], ray_type=Symbol),
            "metric": Vector(items=["open", "close", "open", "close"], ray_type=Symbol),
            "value": Vector(items=[150, 152, 153, 155], ray_type=I64),
        }
    )

    pivoted = source.pivot(index=["date", "symbol"], columns="metric", values="value")

    column_names = [str(name) for name in pivoted.columns()]
    for wanted in ("date", "symbol", "open", "close"):
        assert wanted in column_names
    assert len(pivoted) == 2

    # Rows are looked up by date only; the date is unique per row in this data.
    cells = _get_pivot_values(pivoted, "date")
    expected = {
        "2024-01-01": {"open": 150, "close": 152},
        "2024-01-02": {"open": 153, "close": 155},
    }
    for row_key, row in expected.items():
        for col_key, want in row.items():
            assert cells[row_key][col_key] == want
63+
64+
65+
def test_pivot_with_sum_aggfunc():
    """Cells with several source rows are summed when aggfunc is "sum"."""
    source = Table(
        {
            "category": Vector(items=["A", "A", "A", "B", "B"], ray_type=Symbol),
            "type": Vector(items=["x", "x", "y", "x", "y"], ray_type=Symbol),
            "value": Vector(items=[10, 20, 30, 40, 50], ray_type=I64),
        }
    )

    pivoted = source.pivot(index="category", columns="type", values="value", aggfunc="sum")

    column_names = [str(name) for name in pivoted.columns()]
    for wanted in ("x", "y"):
        assert wanted in column_names
    assert len(pivoted) == 2

    # A/x combines two rows (10 + 20); every other cell has a single row.
    cells = _get_pivot_values(pivoted, "category")
    expected = {
        "A": {"x": 30, "y": 30},
        "B": {"x": 40, "y": 50},
    }
    for row_key, row in expected.items():
        for col_key, want in row.items():
            assert cells[row_key][col_key] == want
87+
88+
89+
def test_pivot_with_count_aggfunc():
    """Cells hold the number of source rows when aggfunc is "count"."""
    source = Table(
        {
            "category": Vector(items=["A", "A", "A", "B", "B"], ray_type=Symbol),
            "type": Vector(items=["x", "x", "y", "x", "y"], ray_type=Symbol),
            "value": Vector(items=[10, 20, 30, 40, 50], ray_type=I64),
        }
    )

    pivoted = source.pivot(index="category", columns="type", values="value", aggfunc="count")

    assert len(pivoted) == 2

    # Only A/x has two contributing rows; the remaining cells have one each.
    cells = _get_pivot_values(pivoted, "category")
    expected = {
        "A": {"x": 2, "y": 1},
        "B": {"x": 1, "y": 1},
    }
    for row_key, row in expected.items():
        for col_key, want in row.items():
            assert cells[row_key][col_key] == want
108+
109+
110+
def test_pivot_with_avg_aggfunc():
    """Cells hold the mean of their source rows when aggfunc is "avg"."""
    source = Table(
        {
            "category": Vector(items=["A", "A", "B"], ray_type=Symbol),
            "metric": Vector(items=["x", "x", "x"], ray_type=Symbol),
            "value": Vector(items=[10, 20, 30], ray_type=I64),
        }
    )

    pivoted = source.pivot(index="category", columns="metric", values="value", aggfunc="avg")

    assert len(pivoted) == 2

    # A/x averages 10 and 20; B/x has the single value 30.
    cells = _get_pivot_values(pivoted, "category")
    expected = {
        "A": {"x": 15},
        "B": {"x": 30},
    }
    for row_key, row in expected.items():
        for col_key, want in row.items():
            assert cells[row_key][col_key] == want
127+
128+
129+
def test_pivot_with_min_aggfunc():
    """Cells hold the smallest source value when aggfunc is "min"."""
    source = Table(
        {
            "category": Vector(items=["A", "A", "A", "B", "B"], ray_type=Symbol),
            "type": Vector(items=["x", "x", "y", "x", "y"], ray_type=Symbol),
            "value": Vector(items=[10, 20, 30, 40, 50], ray_type=I64),
        }
    )

    pivoted = source.pivot(index="category", columns="type", values="value", aggfunc="min")

    # A/x picks the smaller of 10 and 20; every other cell has a single row.
    cells = _get_pivot_values(pivoted, "category")
    expected = {
        "A": {"x": 10, "y": 30},
        "B": {"x": 40, "y": 50},
    }
    for row_key, row in expected.items():
        for col_key, want in row.items():
            assert cells[row_key][col_key] == want
146+
147+
148+
def test_pivot_with_max_aggfunc():
    """Cells hold the largest source value when aggfunc is "max"."""
    source = Table(
        {
            "category": Vector(items=["A", "A", "A", "B", "B"], ray_type=Symbol),
            "type": Vector(items=["x", "x", "y", "x", "y"], ray_type=Symbol),
            "value": Vector(items=[10, 20, 30, 40, 50], ray_type=I64),
        }
    )

    pivoted = source.pivot(index="category", columns="type", values="value", aggfunc="max")

    # A/x picks the larger of 10 and 20; every other cell has a single row.
    cells = _get_pivot_values(pivoted, "category")
    expected = {
        "A": {"x": 20, "y": 30},
        "B": {"x": 40, "y": 50},
    }
    for row_key, row in expected.items():
        for col_key, want in row.items():
            assert cells[row_key][col_key] == want
165+
166+
167+
def test_pivot_single_value_per_cell():
    """A fully populated 2x2 grid pivots to the original values unchanged."""
    source = Table(
        {
            "row": Vector(items=["r1", "r1", "r2", "r2"], ray_type=Symbol),
            "col": Vector(items=["c1", "c2", "c1", "c2"], ray_type=Symbol),
            "val": Vector(items=[1, 2, 3, 4], ray_type=I64),
        }
    )

    pivoted = source.pivot(index="row", columns="col", values="val")

    assert len(pivoted) == 2
    column_names = [str(name) for name in pivoted.columns()]
    for wanted in ("c1", "c2"):
        assert wanted in column_names

    cells = _get_pivot_values(pivoted, "row")
    expected = {
        "r1": {"c1": 1, "c2": 2},
        "r2": {"c1": 3, "c2": 4},
    }
    for row_key, row in expected.items():
        for col_key, want in row.items():
            assert cells[row_key][col_key] == want
188+
189+
190+
def test_pivot_preserves_order():
    """Pivot values given in non-alphabetical order all appear as columns."""
    source = Table(
        {
            "id": Vector(items=["a", "a", "a"], ray_type=Symbol),
            "key": Vector(items=["third", "first", "second"], ray_type=Symbol),
            "value": Vector(items=[3, 1, 2], ray_type=I64),
        }
    )

    pivoted = source.pivot(index="id", columns="key", values="value", aggfunc="min")

    # Source order of the keys: third, first, second.
    # NOTE(review): despite the test name, only membership is asserted below,
    # not the relative order of the generated columns.
    column_names = [str(name) for name in pivoted.columns()]
    for wanted in ("third", "first", "second"):
        assert wanted in column_names

    cells = _get_pivot_values(pivoted, "id")
    expected = {"a": {"third": 3, "first": 1, "second": 2}}
    for row_key, row in expected.items():
        for col_key, want in row.items():
            assert cells[row_key][col_key] == want
211+
212+
213+
def test_pivot_followed_by_select():
    """The table returned by pivot() can be queried further with select()."""
    source = Table(
        {
            "symbol": Vector(items=["AAPL", "AAPL", "GOOG", "GOOG"], ray_type=Symbol),
            "metric": Vector(items=["price", "volume", "price", "volume"], ray_type=Symbol),
            "value": Vector(items=[150, 1000, 2800, 500], ray_type=I64),
        }
    )

    pivoted = source.pivot(index="symbol", columns="metric", values="value")
    narrowed = pivoted.select("symbol", "price").execute()

    column_names = [str(name) for name in narrowed.columns()]
    for wanted in ("symbol", "price"):
        assert wanted in column_names

0 commit comments

Comments
 (0)