NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed and whose `<...>` spans had been stripped.
Indentation follows rustfmt defaults and every hunk was re-counted against its
`@@` header. The `Result<..>` type parameters (in particular
`FunctionArgOperator` on the last context line of src/parser/mod.rs) and the
`AS <alias>` doc text are reconstructed -- confirm against the target tree
before applying.

diff --git a/sqlparser_bench/benches/sqlparser_bench.rs b/sqlparser_bench/benches/sqlparser_bench.rs
index 8654a313f..13274f4c0 100644
--- a/sqlparser_bench/benches/sqlparser_bench.rs
+++ b/sqlparser_bench/benches/sqlparser_bench.rs
@@ -273,6 +273,30 @@ fn parse_table_factor_paren_chain(c: &mut Criterion) {
     group.finish();
 }
 
+/// Benchmark parsing pathological `CAST(CASE (CAST(CASE (...` chains that
+/// previously caused 2^N work in `parse_function_args` on dialects with
+/// expression-named function arguments (the argument expression was parsed
+/// once to detect the named form, then re-parsed on the unnamed path).
+fn parse_function_arg_call_chain(c: &mut Criterion) {
+    let mut group = c.benchmark_group("parse_function_arg_call_chain");
+    let dialect = PostgreSqlDialect {};
+
+    for &n in &[10usize, 20, 30] {
+        let sql = String::from("SELECT ") + &"CAST(CASE (".repeat(n) + &")".repeat(n);
+
+        group.bench_function(format!("chain_{n}"), |b| {
+            b.iter(|| {
+                let _ = Parser::new(&dialect)
+                    .with_recursion_limit(256)
+                    .try_with_sql(std::hint::black_box(&sql))
+                    .and_then(|mut p| p.parse_statements());
+            });
+        });
+    }
+
+    group.finish();
+}
+
 criterion_group!(
     benches,
     basic_queries,
@@ -282,6 +306,7 @@ criterion_group!(
     parse_compound_keyword_chain,
     parse_prefix_keyword_call_chain,
     parse_prefix_case_chain,
-    parse_table_factor_paren_chain
+    parse_table_factor_paren_chain,
+    parse_function_arg_call_chain
 );
 criterion_main!(benches);
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index 1e0c5fbbe..ac1b5e376 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -18676,37 +18676,59 @@ impl<'a> Parser<'a> {
 
     /// Parse a single function argument, handling named and unnamed variants.
     pub fn parse_function_args(&mut self) -> Result<FunctionArg, ParserError> {
-        let arg = if self.dialect.supports_named_fn_args_with_expr_name() {
-            self.maybe_parse(|p| {
-                let name = p.parse_expr()?;
-                let operator = p.parse_function_named_arg_operator()?;
-                let arg = p.parse_wildcard_expr()?.into();
-                Ok(FunctionArg::ExprNamed {
-                    name,
-                    arg,
-                    operator,
-                })
-            })?
-        } else {
-            self.maybe_parse(|p| {
-                let name = p.parse_identifier()?;
-                let operator = p.parse_function_named_arg_operator()?;
-                let arg = p.parse_wildcard_expr()?.into();
-                Ok(FunctionArg::Named {
-                    name,
-                    arg,
-                    operator,
-                })
-            })?
-        };
+        // Parse the argument expression once, then check for a named-arg
+        // operator. Parsing it speculatively and re-parsing on the unnamed
+        // path is O(2^depth) on nested calls like `CAST(CASE (CAST(CASE (…`.
+        if self.dialect.supports_named_fn_args_with_expr_name() {
+            let expr = self.parse_wildcard_expr()?;
+            // A wildcard is never a named-arg name; only the unnamed form applies.
+            if !matches!(expr, Expr::Wildcard(_) | Expr::QualifiedWildcard(..)) {
+                if let Some(operator) =
+                    self.maybe_parse(|p| p.parse_function_named_arg_operator())?
+                {
+                    let arg = self.parse_wildcard_expr()?.into();
+                    return Ok(FunctionArg::ExprNamed {
+                        name: expr,
+                        arg,
+                        operator,
+                    });
+                }
+            }
+            let arg_expr = self.function_arg_expr_from_wildcard(expr)?;
+            return Ok(FunctionArg::Unnamed(
+                self.maybe_parse_aliased_function_arg(arg_expr)?,
+            ));
+        }
+
+        let arg = self.maybe_parse(|p| {
+            let name = p.parse_identifier()?;
+            let operator = p.parse_function_named_arg_operator()?;
+            let arg = p.parse_wildcard_expr()?.into();
+            Ok(FunctionArg::Named {
+                name,
+                arg,
+                operator,
+            })
+        })?;
         if let Some(arg) = arg {
             return Ok(arg);
         }
         let wildcard_expr = self.parse_wildcard_expr()?;
-        let arg_expr: FunctionArgExpr = match wildcard_expr {
+        let arg_expr = self.function_arg_expr_from_wildcard(wildcard_expr)?;
+        Ok(FunctionArg::Unnamed(
+            self.maybe_parse_aliased_function_arg(arg_expr)?,
+        ))
+    }
+
+    /// Wrap an already-parsed expression as a function argument, parsing any
+    /// trailing wildcard options (e.g. Snowflake's `HASH(* EXCLUDE(col))`).
+    fn function_arg_expr_from_wildcard(
+        &mut self,
+        wildcard_expr: Expr,
+    ) -> Result<FunctionArgExpr, ParserError> {
+        Ok(match wildcard_expr {
             Expr::Wildcard(ref token) if self.dialect.supports_select_wildcard_exclude() => {
-                // Support `* EXCLUDE(col1, col2, ...)` inside function calls (e.g. Snowflake's
-                // `HASH(* EXCLUDE(col))`). Parse the options the same way SELECT items do.
+                // Parse the options the same way SELECT items do.
                 let opts = self.parse_wildcard_additional_options(token.0.clone())?;
                 if opts.opt_exclude.is_some()
                     || opts.opt_except.is_some()
@@ -18720,9 +18742,16 @@ impl<'a> Parser<'a> {
                 }
             }
             other => other.into(),
-        };
-        // Aliased argument, e.g. `XMLFOREST(a AS x)` in PostgreSQL
-        let arg_expr = match arg_expr {
+        })
+    }
+
+    /// Parse an optional `AS <alias>` on an unnamed function argument
+    /// (e.g. `XMLFOREST(a AS x)` in PostgreSQL).
+    fn maybe_parse_aliased_function_arg(
+        &mut self,
+        arg_expr: FunctionArgExpr,
+    ) -> Result<FunctionArgExpr, ParserError> {
+        Ok(match arg_expr {
             FunctionArgExpr::Expr(expr)
                 if self.dialect.supports_aliased_function_args()
                     && self.parse_keyword(Keyword::AS) =>
@@ -18733,8 +18762,7 @@ impl<'a> Parser<'a> {
                 })
             }
             other => other,
-        };
-        Ok(FunctionArg::Unnamed(arg_expr))
+        })
     }
 
     fn parse_function_named_arg_operator(&mut self) -> Result<FunctionArgOperator, ParserError> {
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index cfc04e9ba..634b9aea2 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -15609,11 +15609,10 @@ fn parse_create_table_select() {
 
 #[test]
 fn test_reserved_keywords_for_identifiers() {
-    let dialects = all_dialects_where(|d| {
-        d.is_reserved_for_identifier(Keyword::INTERVAL)
-            && !d.supports_named_fn_args_with_expr_name()
-    });
-    // Dialects that reserve the word INTERVAL will not allow it as an unquoted identifier
+    // Dialects that reserve INTERVAL will not allow it as an unquoted
+    // identifier, and report the failure consistently at the token that fails
+    // to start an expression (`)`), independent of named-argument support.
+    let dialects = all_dialects_where(|d| d.is_reserved_for_identifier(Keyword::INTERVAL));
     let sql = "SELECT MAX(interval) FROM tbl";
     assert_eq!(
         dialects.parse_sql_statements(sql),
@@ -15622,19 +15621,6 @@ fn test_reserved_keywords_for_identifiers() {
         ))
     );
 
-    // Dialects with expression-named function arguments parse the argument
-    // expression twice, so the second attempt reports the memoized failure
-    // at the start of the expression
-    let dialects = all_dialects_where(|d| {
-        d.is_reserved_for_identifier(Keyword::INTERVAL) && d.supports_named_fn_args_with_expr_name()
-    });
-    assert_eq!(
-        dialects.parse_sql_statements(sql),
-        Err(ParserError::ParserError(
-            "Expected: an expression, found: interval".to_string()
-        ))
-    );
-
     // Dialects that do not reserve the word INTERVAL will allow it
     let dialects = all_dialects_where(|d| !d.is_reserved_for_identifier(Keyword::INTERVAL));
     let sql = "SELECT MAX(interval) FROM tbl";
@@ -19572,3 +19558,29 @@ fn parse_unlogged_table_logging_controls_in_all_dialects() {
         _ => unreachable!("Expected ALTER TABLE"),
     }
 }
+
+/// Regression test for the 2^N parse-time blowup in `parse_function_args` on
+/// inputs like `CAST(CASE (CAST(CASE (...`. On dialects with expression-named
+/// function arguments (e.g. PostgreSQL), the named-arg arm parsed the whole
+/// argument expression and then re-parsed it on the unnamed path, doubling
+/// work per level. Post-fix the leading expression is parsed exactly once.
+#[test]
+fn parse_function_arg_call_chain_no_exponential_blowup() {
+    use std::sync::mpsc;
+    use std::thread;
+    use std::time::Duration;
+
+    let sql = String::from("SELECT ") + &"CAST(CASE (".repeat(30) + &")".repeat(30);
+
+    let (tx, rx) = mpsc::channel();
+    thread::spawn(move || {
+        let _ = Parser::new(&PostgreSqlDialect {})
+            .with_recursion_limit(256)
+            .try_with_sql(&sql)
+            .and_then(|mut p| p.parse_statements());
+        let _ = tx.send(());
+    });
+
+    rx.recv_timeout(Duration::from_secs(5))
+        .expect("parser should reject this quickly, not loop exponentially");
+}