Skip to content

Commit 707b87a

Browse files
committed
Fix parameter parsing and add test
Our behavior matches libpq's - in particular it allows any escape sequence and trailing \'s...
1 parent 7297661 commit 707b87a

3 files changed

Lines changed: 109 additions & 39 deletions

File tree

tokio-postgres/src/builder.rs

Lines changed: 74 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
use std::borrow::Cow;
2-
use std::collections::HashMap;
1+
use std::collections::hash_map::{self, HashMap};
32
use std::iter;
43
use std::str::{self, FromStr};
54
use tokio_io::{AsyncRead, AsyncWrite};
@@ -44,6 +43,11 @@ impl Builder {
4443
self
4544
}
4645

46+
/// FIXME do we want this?
47+
pub fn iter(&self) -> Iter<'_> {
48+
Iter(self.params.iter())
49+
}
50+
4751
pub fn connect<S, T>(&self, stream: S, tls_mode: T) -> Connect<S, T>
4852
where
4953
S: AsyncRead + AsyncWrite,
@@ -61,13 +65,30 @@ impl FromStr for Builder {
6165
let mut builder = Builder::new();
6266

6367
while let Some((key, value)) = parser.parameter()? {
64-
builder.param(key, &value);
68+
builder.params.insert(key.to_string(), value);
6569
}
6670

6771
Ok(builder)
6872
}
6973
}
7074

75+
#[derive(Debug, Clone)]
76+
pub struct Iter<'a>(hash_map::Iter<'a, String, String>);
77+
78+
impl<'a> Iterator for Iter<'a> {
79+
type Item = (&'a str, &'a str);
80+
81+
fn next(&mut self) -> Option<(&'a str, &'a str)> {
82+
self.0.next().map(|(k, v)| (&**k, &**v))
83+
}
84+
}
85+
86+
impl<'a> ExactSizeIterator for Iter<'a> {
87+
fn len(&self) -> usize {
88+
self.0.len()
89+
}
90+
}
91+
7192
struct Parser<'a> {
7293
s: &'a str,
7394
it: iter::Peekable<str::CharIndices<'a>>,
@@ -82,9 +103,7 @@ impl<'a> Parser<'a> {
82103
}
83104

84105
fn skip_ws(&mut self) {
85-
while let Some(&(_, ' ')) = self.it.peek() {
86-
self.it.next();
87-
}
106+
self.take_while(|c| c.is_whitespace());
88107
}
89108

90109
fn take_while<F>(&mut self, f: F) -> &'a str
@@ -133,7 +152,8 @@ impl<'a> Parser<'a> {
133152

134153
fn keyword(&mut self) -> Option<&'a str> {
135154
let s = self.take_while(|c| match c {
136-
' ' | '=' => false,
155+
c if c.is_whitespace() => false,
156+
'=' => false,
137157
_ => true,
138158
});
139159

@@ -144,52 +164,67 @@ impl<'a> Parser<'a> {
144164
}
145165
}
146166

147-
fn value(&mut self) -> Result<Cow<'a, str>, Error> {
148-
let raw = if self.eat_if('\'') {
149-
let s = self.take_while(|c| c != '\'');
167+
fn value(&mut self) -> Result<String, Error> {
168+
let value = if self.eat_if('\'') {
169+
let value = self.quoted_value()?;
150170
self.eat('\'')?;
151-
s
171+
value
152172
} else {
153-
let s = self.take_while(|c| c != ' ');
154-
if s.is_empty() {
155-
return Err(Error::connection_syntax("unexpected EOF".into()));
156-
}
157-
s
173+
self.simple_value()?
158174
};
159175

160-
self.unescape_value(raw)
176+
Ok(value)
161177
}
162178

163-
fn unescape_value(&mut self, raw: &'a str) -> Result<Cow<'a, str>, Error> {
164-
if !raw.contains('\\') {
165-
return Ok(Cow::Borrowed(raw));
179+
fn simple_value(&mut self) -> Result<String, Error> {
180+
let mut value = String::new();
181+
182+
while let Some(&(_, c)) = self.it.peek() {
183+
if c.is_whitespace() {
184+
break;
185+
}
186+
187+
self.it.next();
188+
if c == '\\' {
189+
if let Some((_, c2)) = self.it.next() {
190+
value.push(c2);
191+
}
192+
} else {
193+
value.push(c);
194+
}
195+
}
196+
197+
if value.is_empty() {
198+
return Err(Error::connection_syntax("unexpected EOF".into()));
166199
}
167200

168-
let mut s = String::with_capacity(raw.len());
169-
170-
let mut it = raw.chars();
171-
while let Some(c) = it.next() {
172-
let to_push = if c == '\\' {
173-
match it.next() {
174-
Some('\'') => '\'',
175-
Some('\\') => '\\',
176-
Some(c) => {
177-
return Err(Error::connection_syntax(
178-
format!("invalid escape `\\{}`", c).into(),
179-
));
180-
}
181-
None => return Err(Error::connection_syntax("unexpected EOF".into())),
201+
Ok(value)
202+
}
203+
204+
fn quoted_value(&mut self) -> Result<String, Error> {
205+
let mut value = String::new();
206+
207+
while let Some(&(_, c)) = self.it.peek() {
208+
if c == '\'' {
209+
return Ok(value);
210+
}
211+
212+
self.it.next();
213+
if c == '\\' {
214+
if let Some((_, c2)) = self.it.next() {
215+
value.push(c2);
182216
}
183217
} else {
184-
c
185-
};
186-
s.push(to_push);
218+
value.push(c);
219+
}
187220
}
188221

189-
Ok(Cow::Owned(s))
222+
Err(Error::connection_syntax(
223+
"unterminated quoted connection parameter value".into(),
224+
))
190225
}
191226

192-
fn parameter(&mut self) -> Result<Option<(&'a str, Cow<'a, str>)>, Error> {
227+
fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
193228
self.skip_ws();
194229
let keyword = match self.keyword() {
195230
Some(keyword) => keyword,

tokio-postgres/tests/test/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use tokio_postgres::error::SqlState;
1313
use tokio_postgres::types::{Kind, Type};
1414
use tokio_postgres::{AsyncMessage, Client, Connection, NoTls};
1515

16+
mod parse;
1617
mod types;
1718

1819
fn connect(

tokio-postgres/tests/test/parse.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
use std::collections::HashMap;
2+
3+
#[test]
4+
fn pairs_ok() {
5+
let params = r"user=foo password=' fizz \'buzz\\ ' thing = ''"
6+
.parse::<tokio_postgres::Builder>()
7+
.unwrap();
8+
let params = params.iter().collect::<HashMap<_, _>>();
9+
10+
let mut expected = HashMap::new();
11+
expected.insert("user", "foo");
12+
expected.insert("password", r" fizz 'buzz\ ");
13+
expected.insert("thing", "");
14+
expected.insert("client_encoding", "UTF8");
15+
expected.insert("timezone", "GMT");
16+
17+
assert_eq!(params, expected);
18+
}
19+
20+
#[test]
21+
fn pairs_ws() {
22+
let params = " user\t=\r\n\x0bfoo \t password = hunter2 "
23+
.parse::<tokio_postgres::Builder>()
24+
.unwrap();;
25+
let params = params.iter().collect::<HashMap<_, _>>();
26+
27+
let mut expected = HashMap::new();
28+
expected.insert("user", "foo");
29+
expected.insert("password", r"hunter2");
30+
expected.insert("client_encoding", "UTF8");
31+
expected.insert("timezone", "GMT");
32+
33+
assert_eq!(params, expected);
34+
}

0 commit comments

Comments
 (0)