Skip to content

Commit

Permalink
Query parser: encode ?? as ?
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Sep 24, 2024
1 parent 75ce343 commit e3d1692
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl Query {
/// during query execution (`execute()`, `fetch()` etc).
///
/// WARNING: This means that the query must not have any extra `?`, even if
/// they are in a string literal!
/// they are in a string literal! Use `??` to have plain `?` in query.
///
/// [`Serialize`]: serde::Serialize
/// [`Identifier`]: crate::sql::Identifier
Expand Down
45 changes: 30 additions & 15 deletions src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub(crate) enum SqlBuilder {
pub(crate) enum Part {
Arg,
Fields,
Str(&'static str),
Text(String),
}

Expand All @@ -45,20 +46,28 @@ impl fmt::Display for SqlBuilder {

impl SqlBuilder {
pub(crate) fn new(template: &str) -> Self {
let mut iter = template.split('?');
let prefix = String::from(iter.next().unwrap());
let mut parts = vec![Part::Text(prefix)];
let mut parts = Vec::new();
let mut rest = template;
while let Some(idx) = rest.find('?') {
if rest[idx + 1..].starts_with('?') {
parts.push(Part::Text(rest[..idx + 1].to_string()));
rest = &rest[idx + 2..];
continue;
} else if idx != 0 {
parts.push(Part::Text(rest[..idx].to_string()));
}

for s in iter {
let text = if let Some(text) = s.strip_prefix("fields") {
rest = &rest[idx + 1..];
if let Some(restfields) = rest.strip_prefix("fields") {
parts.push(Part::Fields);
text
rest = restfields;
} else {
parts.push(Part::Arg);
s
};
}
}

parts.push(Part::Text(text.into()));
if !rest.is_empty() {
parts.push(Part::Text(rest.to_string()));
}

SqlBuilder::InProgress(parts)
Expand Down Expand Up @@ -96,16 +105,12 @@ impl SqlBuilder {
}
}

pub(crate) fn append(&mut self, suffix: &str) {
pub(crate) fn append(&mut self, suffix: &'static str) {
let Self::InProgress(parts) = self else {
return;
};

if let Some(Part::Text(text)) = parts.last_mut() {
text.push_str(suffix);
} else {
// Do nothing, it will fail in `finish()`.
}
parts.push(Part::Str(suffix));
}

pub(crate) fn finish(mut self) -> Result<String> {
Expand All @@ -114,6 +119,7 @@ impl SqlBuilder {
if let Self::InProgress(parts) = &self {
for part in parts {
match part {
Part::Str(text) => sql.push_str(text),
Part::Text(text) => sql.push_str(text),
Part::Arg => {
self.error("unbound query argument");
Expand Down Expand Up @@ -223,6 +229,15 @@ mod tests {
);
}

#[test]
fn question_escape() {
let sql = SqlBuilder::new("SELECT 1 FROM test WHERE a IN 'a??b'");
assert_eq!(
sql.finish().unwrap(),
r"SELECT 1 FROM test WHERE a IN 'a?b'"
);
}

#[test]
fn option_as_null() {
let mut sql = SqlBuilder::new("SELECT 1 FROM test WHERE a = ?");
Expand Down

0 comments on commit e3d1692

Please sign in to comment.