Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion sqlparser_bench/benches/sqlparser_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,30 @@ fn parse_table_factor_paren_chain(c: &mut Criterion) {
group.finish();
}

/// Benchmark parsing pathological `CAST(CASE (CAST(CASE (...` chains that
/// previously caused 2^N work in `parse_function_args` on dialects with
/// expression-named function arguments (the argument expression was parsed
/// once to detect the named form, then re-parsed on the unnamed path).
fn parse_function_arg_call_chain(c: &mut Criterion) {
let mut group = c.benchmark_group("parse_function_arg_call_chain");
let dialect = PostgreSqlDialect {};

for &n in &[10usize, 20, 30] {
let sql = String::from("SELECT ") + &"CAST(CASE (".repeat(n) + &")".repeat(n);

group.bench_function(format!("chain_{n}"), |b| {
b.iter(|| {
let _ = Parser::new(&dialect)
.with_recursion_limit(256)
.try_with_sql(std::hint::black_box(&sql))
.and_then(|mut p| p.parse_statements());
});
});
}

group.finish();
}

criterion_group!(
benches,
basic_queries,
Expand All @@ -282,6 +306,7 @@ criterion_group!(
parse_compound_keyword_chain,
parse_prefix_keyword_call_chain,
parse_prefix_case_chain,
parse_table_factor_paren_chain
parse_table_factor_paren_chain,
parse_function_arg_call_chain
);
criterion_main!(benches);
90 changes: 59 additions & 31 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18676,37 +18676,59 @@ impl<'a> Parser<'a> {

/// Parse a single function argument, handling named and unnamed variants.
pub fn parse_function_args(&mut self) -> Result<FunctionArg, ParserError> {
let arg = if self.dialect.supports_named_fn_args_with_expr_name() {
self.maybe_parse(|p| {
let name = p.parse_expr()?;
let operator = p.parse_function_named_arg_operator()?;
let arg = p.parse_wildcard_expr()?.into();
Ok(FunctionArg::ExprNamed {
name,
arg,
operator,
})
})?
} else {
self.maybe_parse(|p| {
let name = p.parse_identifier()?;
let operator = p.parse_function_named_arg_operator()?;
let arg = p.parse_wildcard_expr()?.into();
Ok(FunctionArg::Named {
name,
arg,
operator,
})
})?
};
// Parse the argument expression once, then check for a named-arg
// operator. Parsing it speculatively and re-parsing on the unnamed
// path is O(2^depth) on nested calls like `CAST(CASE (CAST(CASE (…`.
if self.dialect.supports_named_fn_args_with_expr_name() {
let expr = self.parse_wildcard_expr()?;
// A wildcard is never a named-arg name; only the unnamed form applies.
if !matches!(expr, Expr::Wildcard(_) | Expr::QualifiedWildcard(..)) {
if let Some(operator) =
self.maybe_parse(|p| p.parse_function_named_arg_operator())?
{
let arg = self.parse_wildcard_expr()?.into();
return Ok(FunctionArg::ExprNamed {
name: expr,
arg,
operator,
});
}
}
let arg_expr = self.function_arg_expr_from_wildcard(expr)?;
return Ok(FunctionArg::Unnamed(
self.maybe_parse_aliased_function_arg(arg_expr)?,
));
}

let arg = self.maybe_parse(|p| {
let name = p.parse_identifier()?;
let operator = p.parse_function_named_arg_operator()?;
let arg = p.parse_wildcard_expr()?.into();
Ok(FunctionArg::Named {
name,
arg,
operator,
})
})?;
if let Some(arg) = arg {
return Ok(arg);
}
let wildcard_expr = self.parse_wildcard_expr()?;
let arg_expr: FunctionArgExpr = match wildcard_expr {
let arg_expr = self.function_arg_expr_from_wildcard(wildcard_expr)?;
Ok(FunctionArg::Unnamed(
self.maybe_parse_aliased_function_arg(arg_expr)?,
))
}

/// Wrap an already-parsed expression as a function argument, parsing any
/// trailing wildcard options (e.g. Snowflake's `HASH(* EXCLUDE(col))`).
fn function_arg_expr_from_wildcard(
&mut self,
wildcard_expr: Expr,
) -> Result<FunctionArgExpr, ParserError> {
Ok(match wildcard_expr {
Expr::Wildcard(ref token) if self.dialect.supports_select_wildcard_exclude() => {
// Support `* EXCLUDE(col1, col2, ...)` inside function calls (e.g. Snowflake's
// `HASH(* EXCLUDE(col))`). Parse the options the same way SELECT items do.
// Parse the options the same way SELECT items do.
let opts = self.parse_wildcard_additional_options(token.0.clone())?;
if opts.opt_exclude.is_some()
|| opts.opt_except.is_some()
Expand All @@ -18720,9 +18742,16 @@ impl<'a> Parser<'a> {
}
}
other => other.into(),
};
// Aliased argument, e.g. `XMLFOREST(a AS x)` in PostgreSQL
let arg_expr = match arg_expr {
})
}

/// Parse an optional `AS <alias>` on an unnamed function argument
/// (e.g. `XMLFOREST(a AS x)` in PostgreSQL).
fn maybe_parse_aliased_function_arg(
&mut self,
arg_expr: FunctionArgExpr,
) -> Result<FunctionArgExpr, ParserError> {
Ok(match arg_expr {
FunctionArgExpr::Expr(expr)
if self.dialect.supports_aliased_function_args()
&& self.parse_keyword(Keyword::AS) =>
Expand All @@ -18733,8 +18762,7 @@ impl<'a> Parser<'a> {
})
}
other => other,
};
Ok(FunctionArg::Unnamed(arg_expr))
})
}

fn parse_function_named_arg_operator(&mut self) -> Result<FunctionArgOperator, ParserError> {
Expand Down
48 changes: 30 additions & 18 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15609,11 +15609,10 @@ fn parse_create_table_select() {

#[test]
fn test_reserved_keywords_for_identifiers() {
let dialects = all_dialects_where(|d| {
d.is_reserved_for_identifier(Keyword::INTERVAL)
&& !d.supports_named_fn_args_with_expr_name()
});
// Dialects that reserve the word INTERVAL will not allow it as an unquoted identifier
// Dialects that reserve INTERVAL will not allow it as an unquoted
// identifier, and report the failure consistently at the token that fails
// to start an expression (`)`), independent of named-argument support.
let dialects = all_dialects_where(|d| d.is_reserved_for_identifier(Keyword::INTERVAL));
let sql = "SELECT MAX(interval) FROM tbl";
assert_eq!(
dialects.parse_sql_statements(sql),
Expand All @@ -15622,19 +15621,6 @@ fn test_reserved_keywords_for_identifiers() {
))
);

// Dialects with expression-named function arguments parse the argument
// expression twice, so the second attempt reports the memoized failure
// at the start of the expression
let dialects = all_dialects_where(|d| {
d.is_reserved_for_identifier(Keyword::INTERVAL) && d.supports_named_fn_args_with_expr_name()
});
assert_eq!(
dialects.parse_sql_statements(sql),
Err(ParserError::ParserError(
"Expected: an expression, found: interval".to_string()
))
);

// Dialects that do not reserve the word INTERVAL will allow it
let dialects = all_dialects_where(|d| !d.is_reserved_for_identifier(Keyword::INTERVAL));
let sql = "SELECT MAX(interval) FROM tbl";
Expand Down Expand Up @@ -19572,3 +19558,29 @@ fn parse_unlogged_table_logging_controls_in_all_dialects() {
_ => unreachable!("Expected ALTER TABLE"),
}
}

/// Regression test for the 2^N parse-time blowup in `parse_function_args` on
/// inputs like `CAST(CASE (CAST(CASE (...`. On dialects with expression-named
/// function arguments (e.g. PostgreSQL), the named-arg arm parsed the whole
/// argument expression and then re-parsed it on the unnamed path, doubling
/// work per level. Post-fix the leading expression is parsed exactly once.
#[test]
fn parse_function_arg_call_chain_no_exponential_blowup() {
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

let sql = String::from("SELECT ") + &"CAST(CASE (".repeat(30) + &")".repeat(30);

let (tx, rx) = mpsc::channel();
thread::spawn(move || {
let _ = Parser::new(&PostgreSqlDialect {})
.with_recursion_limit(256)
.try_with_sql(&sql)
.and_then(|mut p| p.parse_statements());
let _ = tx.send(());
});

rx.recv_timeout(Duration::from_secs(5))
.expect("parser should reject this quickly, not loop exponentially");
}
Loading