Skip to content

Commit 69e5d79

Browse files
committed
Handle qualified AGGREGATE
1 parent b413d38 commit 69e5d79

File tree

1 file changed

+141
-13
lines changed

1 file changed

+141
-13
lines changed

yardstick-rs/src/sql/measures.rs

Lines changed: 141 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -203,26 +203,135 @@ pub fn has_as_measure(sql: &str) -> bool {
203203

204204
/// Check if SQL contains AGGREGATE( function
205205
pub fn has_aggregate_function(sql: &str) -> bool {
206-
let sql_upper = sql.to_uppercase();
207-
let mut search_pos = 0;
206+
let chars: Vec<char> = sql.chars().collect();
207+
let len = chars.len();
208+
let mut i = 0;
208209

210+
let is_ident_start = |c: char| c.is_alphabetic() || c == '_';
209211
let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
210212

211-
while let Some(offset) = sql_upper[search_pos..].find("AGGREGATE") {
212-
let start = search_pos + offset;
213-
if start > 0 {
214-
if let Some(prev) = sql_upper[..start].chars().last() {
215-
if is_ident_char(prev) {
216-
search_pos = start + 1;
217-
continue;
213+
let skip_whitespace = |mut idx: usize| -> usize {
214+
while idx < len && chars[idx].is_whitespace() {
215+
idx += 1;
216+
}
217+
idx
218+
};
219+
220+
let parse_identifier = |start: usize| -> (String, usize) {
221+
let mut idx = start + 1;
222+
while idx < len && is_ident_char(chars[idx]) {
223+
idx += 1;
224+
}
225+
let token: String = chars[start..idx].iter().collect();
226+
(token, idx)
227+
};
228+
229+
let parse_quoted_identifier = |start: usize| -> (String, usize) {
230+
let mut token = String::new();
231+
let mut idx = start;
232+
while idx < len {
233+
match chars[idx] {
234+
'"' => {
235+
if idx + 1 < len && chars[idx + 1] == '"' {
236+
token.push('"');
237+
idx += 2;
238+
} else {
239+
idx += 1;
240+
break;
241+
}
242+
}
243+
c => {
244+
token.push(c);
245+
idx += 1;
218246
}
219247
}
220248
}
221-
let after = &sql_upper[start + "AGGREGATE".len()..];
222-
if after.trim_start().starts_with('(') {
223-
return true;
249+
(token, idx)
250+
};
251+
252+
let parse_qualified_chain = |first: String, mut idx: usize| -> (String, usize) {
253+
let mut last = first;
254+
loop {
255+
idx = skip_whitespace(idx);
256+
if idx >= len || chars[idx] != '.' {
257+
break;
258+
}
259+
idx += 1;
260+
idx = skip_whitespace(idx);
261+
if idx >= len {
262+
break;
263+
}
264+
if chars[idx] == '"' {
265+
let (token, next) = parse_quoted_identifier(idx + 1);
266+
last = token;
267+
idx = next;
268+
} else if is_ident_start(chars[idx]) {
269+
let (token, next) = parse_identifier(idx);
270+
last = token;
271+
idx = next;
272+
} else {
273+
break;
274+
}
275+
}
276+
(last, idx)
277+
};
278+
279+
let is_aggregate_token = |token: &str| token.eq_ignore_ascii_case("AGGREGATE");
280+
281+
while i < len {
282+
match chars[i] {
283+
'\'' => {
284+
i += 1;
285+
while i < len {
286+
if chars[i] == '\'' {
287+
if i + 1 < len && chars[i + 1] == '\'' {
288+
i += 2;
289+
continue;
290+
}
291+
i += 1;
292+
break;
293+
}
294+
i += 1;
295+
}
296+
}
297+
'-' if i + 1 < len && chars[i + 1] == '-' => {
298+
i += 2;
299+
while i < len && chars[i] != '\n' {
300+
i += 1;
301+
}
302+
}
303+
'/' if i + 1 < len && chars[i + 1] == '*' => {
304+
i += 2;
305+
while i + 1 < len {
306+
if chars[i] == '*' && chars[i + 1] == '/' {
307+
i += 2;
308+
break;
309+
}
310+
i += 1;
311+
}
312+
}
313+
'"' => {
314+
let (token, next) = parse_quoted_identifier(i + 1);
315+
let (last, after_chain) = parse_qualified_chain(token, next);
316+
let after_ws = skip_whitespace(after_chain);
317+
if after_ws < len && chars[after_ws] == '(' && is_aggregate_token(&last) {
318+
return true;
319+
}
320+
i = after_chain;
321+
}
322+
c if is_ident_start(c) => {
323+
let (token, next) = parse_identifier(i);
324+
let (last, after_chain) = parse_qualified_chain(token, next);
325+
let after_ws = skip_whitespace(after_chain);
326+
if after_ws < len && chars[after_ws] == '(' && is_aggregate_token(&last) {
327+
return true;
328+
}
329+
i = after_chain;
330+
}
331+
_ => {
332+
i += 1;
333+
}
224334
}
225-
search_pos = start + 1;
226335
}
227336

228337
false
@@ -3818,7 +3927,13 @@ mod tests {
38183927
fn test_has_aggregate_function() {
38193928
assert!(has_aggregate_function("SELECT AGGREGATE(revenue) FROM foo"));
38203929
assert!(has_aggregate_function("SELECT AGGREGATE (revenue) FROM foo"));
3930+
assert!(has_aggregate_function("SELECT \"AGGREGATE\"(revenue) FROM foo"));
3931+
assert!(has_aggregate_function("SELECT schema.AGGREGATE(revenue) FROM foo"));
3932+
assert!(has_aggregate_function(
3933+
"SELECT \"schema\".\"AGGREGATE\" (revenue) FROM foo"
3934+
));
38213935
assert!(!has_aggregate_function("SELECT TOTAL_AGGREGATE(revenue) FROM foo"));
3936+
assert!(!has_aggregate_function("SELECT \"TOTAL_AGGREGATE\"(revenue) FROM foo"));
38223937
assert!(!has_aggregate_function("SELECT myaggregate(revenue) FROM foo"));
38233938
assert!(!has_aggregate_function("SELECT SUM(amount) FROM foo"));
38243939
}
@@ -3852,6 +3967,19 @@ mod tests {
38523967
);
38533968
}
38543969

3970+
#[test]
3971+
fn test_extract_dimension_columns_ignores_quoted_and_qualified_aggregate() {
3972+
let cols = extract_dimension_columns_from_select(
3973+
"SELECT region, \"AGGREGATE\"(revenue) FROM sales_v",
3974+
);
3975+
assert_eq!(cols, vec!["region".to_string()]);
3976+
3977+
let cols = extract_dimension_columns_from_select(
3978+
"SELECT region, schema.AGGREGATE(revenue) FROM sales_v",
3979+
);
3980+
assert_eq!(cols, vec!["region".to_string()]);
3981+
}
3982+
38553983
#[test]
38563984
#[serial]
38573985
fn test_process_create_view_basic() {

0 commit comments

Comments
 (0)