Skip to content

Commit a4bdcb1

Browse files
committed
Overhaul error type
1 parent a05adff commit a4bdcb1

21 files changed

Lines changed: 396 additions & 378 deletions

File tree

postgres-protocol/src/authentication/sasl.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ pub struct ScramSha256 {
141141

142142
impl ScramSha256 {
143143
/// Constructs a new instance which will use the provided password for authentication.
144-
pub fn new(password: &[u8], channel_binding: ChannelBinding) -> io::Result<ScramSha256> {
144+
pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
145145
// rand 0.5's ThreadRng is cryptographically secure
146146
let mut rng = rand::thread_rng();
147147
let nonce = (0..NONCE_LENGTH)
@@ -151,25 +151,20 @@ impl ScramSha256 {
151151
v = 0x7e
152152
}
153153
v as char
154-
})
155-
.collect::<String>();
154+
}).collect::<String>();
156155

157156
ScramSha256::new_inner(password, channel_binding, nonce)
158157
}
159158

160-
fn new_inner(
161-
password: &[u8],
162-
channel_binding: ChannelBinding,
163-
nonce: String,
164-
) -> io::Result<ScramSha256> {
165-
Ok(ScramSha256 {
159+
fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 {
160+
ScramSha256 {
166161
message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
167162
state: State::Update {
168163
nonce,
169164
password: normalize(password),
170165
channel_binding: channel_binding,
171166
},
172-
})
167+
}
173168
}
174169

175170
/// Returns the message which should be sent to the backend in an `SASLResponse` message.
@@ -487,7 +482,7 @@ mod test {
487482
password.as_bytes(),
488483
ChannelBinding::unsupported(),
489484
nonce.to_string(),
490-
).unwrap();
485+
);
491486
assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first);
492487

493488
scram.update(server_first.as_bytes()).unwrap();

postgres/src/lib.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ impl InnerConnection {
464464
error::connect("a password was requested but not provided".into())
465465
})?;
466466

467-
let mut scram = ScramSha256::new(pass.as_bytes(), channel_binding)?;
467+
let mut scram = ScramSha256::new(pass.as_bytes(), channel_binding);
468468

469469
self.stream.write_message(|buf| {
470470
frontend::sasl_initial_response(mechanism, scram.message(), buf)
@@ -763,8 +763,7 @@ impl InnerConnection {
763763
field.name().to_owned(),
764764
self.get_type(field.type_oid())?,
765765
))
766-
})
767-
.collect()
766+
}).collect()
768767
.map_err(From::from),
769768
None => Ok(vec![]),
770769
}
@@ -820,7 +819,8 @@ impl InnerConnection {
820819
let (name, type_, elem_oid, rngsubtype, basetype, schema, relid) = {
821820
let name =
822821
String::from_sql_nullable(&Type::NAME, get_raw(0)).map_err(error::conversion)?;
823-
let type_ = i8::from_sql_nullable(&Type::CHAR, get_raw(1)).map_err(error::conversion)?;
822+
let type_ =
823+
i8::from_sql_nullable(&Type::CHAR, get_raw(1)).map_err(error::conversion)?;
824824
let elem_oid =
825825
Oid::from_sql_nullable(&Type::OID, get_raw(2)).map_err(error::conversion)?;
826826
let rngsubtype = Option::<Oid>::from_sql_nullable(&Type::OID, get_raw(3))
@@ -829,7 +829,8 @@ impl InnerConnection {
829829
Oid::from_sql_nullable(&Type::OID, get_raw(4)).map_err(error::conversion)?;
830830
let schema =
831831
String::from_sql_nullable(&Type::NAME, get_raw(5)).map_err(error::conversion)?;
832-
let relid = Oid::from_sql_nullable(&Type::OID, get_raw(6)).map_err(error::conversion)?;
832+
let relid =
833+
Oid::from_sql_nullable(&Type::OID, get_raw(6)).map_err(error::conversion)?;
833834
(name, type_, elem_oid, rngsubtype, basetype, schema, relid)
834835
};
835836

@@ -894,7 +895,7 @@ impl InnerConnection {
894895
let mut variants = vec![];
895896
for row in rows {
896897
variants.push(
897-
String::from_sql_nullable(&Type::NAME, row.get(0)).map_err(error::conversion)?
898+
String::from_sql_nullable(&Type::NAME, row.get(0)).map_err(error::conversion)?,
898899
);
899900
}
900901

@@ -930,8 +931,8 @@ impl InnerConnection {
930931
let mut fields = vec![];
931932
for row in rows {
932933
let (name, type_) = {
933-
let name =
934-
String::from_sql_nullable(&Type::NAME, row.get(0)).map_err(error::conversion)?;
934+
let name = String::from_sql_nullable(&Type::NAME, row.get(0))
935+
.map_err(error::conversion)?;
935936
let type_ =
936937
Oid::from_sql_nullable(&Type::OID, row.get(1)).map_err(error::conversion)?;
937938
(name, type_)

tokio-postgres/src/error/mod.rs

Lines changed: 132 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
33
use fallible_iterator::FallibleIterator;
44
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
5-
use std::convert::From;
65
use std::error;
76
use std::fmt;
87
use std::io;
8+
use tokio_timer;
99

1010
pub use self::sqlstate::*;
1111

@@ -333,147 +333,175 @@ pub enum ErrorPosition {
333333
},
334334
}
335335

336-
#[doc(hidden)]
337-
pub fn connect(e: Box<error::Error + Sync + Send>) -> Error {
338-
Error(Box::new(ErrorKind::ConnectParams(e)))
336+
#[derive(Debug, PartialEq)]
337+
enum Kind {
338+
Io,
339+
UnexpectedMessage,
340+
Tls,
341+
ToSql,
342+
FromSql,
343+
CopyInStream,
344+
Closed,
345+
Db,
346+
Parse,
347+
Encode,
348+
MissingUser,
349+
MissingPassword,
350+
UnsupportedAuthentication,
351+
Connect,
352+
Timer,
353+
Authentication,
339354
}
340355

341-
#[doc(hidden)]
342-
pub fn tls(e: Box<error::Error + Sync + Send>) -> Error {
343-
Error(Box::new(ErrorKind::Tls(e)))
356+
struct ErrorInner {
357+
kind: Kind,
358+
cause: Option<Box<error::Error + Sync + Send>>,
344359
}
345360

346-
#[doc(hidden)]
347-
pub fn db(e: DbError) -> Error {
348-
Error(Box::new(ErrorKind::Db(e)))
349-
}
361+
/// An error communicating with the Postgres server.
362+
pub struct Error(Box<ErrorInner>);
350363

351-
#[doc(hidden)]
352-
pub fn __db(e: ErrorResponseBody) -> Error {
353-
match DbError::new(&mut e.fields()) {
354-
Ok(e) => Error(Box::new(ErrorKind::Db(e))),
355-
Err(e) => Error(Box::new(ErrorKind::Io(e))),
364+
impl fmt::Debug for Error {
365+
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
366+
fmt.debug_struct("Error")
367+
.field("kind", &self.0.kind)
368+
.field("cause", &self.0.cause)
369+
.finish()
356370
}
357371
}
358372

359-
#[doc(hidden)]
360-
pub fn __user<T>(e: T) -> Error
361-
where
362-
T: Into<Box<error::Error + Sync + Send>>,
363-
{
364-
Error(Box::new(ErrorKind::Conversion(e.into())))
365-
}
366-
367-
#[doc(hidden)]
368-
pub fn io(e: io::Error) -> Error {
369-
Error(Box::new(ErrorKind::Io(e)))
370-
}
371-
372-
#[doc(hidden)]
373-
pub fn conversion(e: Box<error::Error + Sync + Send>) -> Error {
374-
Error(Box::new(ErrorKind::Conversion(e)))
375-
}
376-
377-
#[derive(Debug)]
378-
enum ErrorKind {
379-
ConnectParams(Box<error::Error + Sync + Send>),
380-
Tls(Box<error::Error + Sync + Send>),
381-
Db(DbError),
382-
Io(io::Error),
383-
Conversion(Box<error::Error + Sync + Send>),
384-
}
385-
386-
/// An error communicating with the Postgres server.
387-
#[derive(Debug)]
388-
pub struct Error(Box<ErrorKind>);
389-
390373
impl fmt::Display for Error {
391374
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
392375
fmt.write_str(error::Error::description(self))?;
393-
match *self.0 {
394-
ErrorKind::ConnectParams(ref err) => write!(fmt, ": {}", err),
395-
ErrorKind::Tls(ref err) => write!(fmt, ": {}", err),
396-
ErrorKind::Db(ref err) => write!(fmt, ": {}", err),
397-
ErrorKind::Io(ref err) => write!(fmt, ": {}", err),
398-
ErrorKind::Conversion(ref err) => write!(fmt, ": {}", err),
376+
if let Some(ref cause) = self.0.cause {
377+
write!(fmt, ": {}", cause)?;
399378
}
379+
Ok(())
400380
}
401381
}
402382

403383
impl error::Error for Error {
404384
fn description(&self) -> &str {
405-
match *self.0 {
406-
ErrorKind::ConnectParams(_) => "invalid connection parameters",
407-
ErrorKind::Tls(_) => "TLS handshake error",
408-
ErrorKind::Db(_) => "database error",
409-
ErrorKind::Io(_) => "IO error",
410-
ErrorKind::Conversion(_) => "type conversion error",
385+
match self.0.kind {
386+
Kind::Io => "error communicating with the server",
387+
Kind::UnexpectedMessage => "unexpected message from server",
388+
Kind::Tls => "error performing TLS handshake",
389+
Kind::ToSql => "error serializing a value",
390+
Kind::FromSql => "error deserializing a value",
391+
Kind::CopyInStream => "error from a copy_in stream",
392+
Kind::Closed => "connection closed",
393+
Kind::Db => "db error",
394+
Kind::Parse => "error parsing response from server",
395+
Kind::Encode => "error encoding message to server",
396+
Kind::MissingUser => "username not provided",
397+
Kind::MissingPassword => "password not provided",
398+
Kind::UnsupportedAuthentication => "unsupported authentication method requested",
399+
Kind::Connect => "error connecting to server",
400+
Kind::Timer => "timer error",
401+
Kind::Authentication => "authentication error",
411402
}
412403
}
413404

414405
fn cause(&self) -> Option<&error::Error> {
415-
match *self.0 {
416-
ErrorKind::ConnectParams(ref err) => Some(&**err),
417-
ErrorKind::Tls(ref err) => Some(&**err),
418-
ErrorKind::Db(ref err) => Some(err),
419-
ErrorKind::Io(ref err) => Some(err),
420-
ErrorKind::Conversion(ref err) => Some(&**err),
421-
}
406+
self.0.cause.as_ref().map(|e| &**e as &error::Error)
422407
}
423408
}
424409

425410
impl Error {
426-
/// Returns the SQLSTATE error code associated with this error if it is a DB
427-
/// error.
411+
/// Returns the error's cause.
412+
///
413+
/// This is the same as `Error::cause` except that it provides extra bounds
414+
/// required to be able to downcast the error.
415+
pub fn cause2(&self) -> Option<&(error::Error + 'static + Sync + Send)> {
416+
self.0.cause.as_ref().map(|e| &**e)
417+
}
418+
419+
/// Consumes the error, returning its cause.
420+
pub fn into_cause(self) -> Option<Box<error::Error + Sync + Send>> {
421+
self.0.cause
422+
}
423+
424+
/// Returns the SQLSTATE error code associated with the error.
425+
///
426+
/// This is a convenience method that downcasts the cause to a `DbError`
427+
/// and returns its code.
428428
pub fn code(&self) -> Option<&SqlState> {
429-
self.as_db().map(|e| &e.code)
429+
self.cause2()
430+
.and_then(|e| e.downcast_ref::<DbError>())
431+
.map(|e| e.code())
430432
}
431433

432-
/// Returns the inner error if this is a connection parameter error.
433-
pub fn as_connection(&self) -> Option<&(error::Error + 'static + Sync + Send)> {
434-
match *self.0 {
435-
ErrorKind::ConnectParams(ref err) => Some(&**err),
436-
_ => None,
437-
}
434+
fn new(kind: Kind, cause: Option<Box<error::Error + Sync + Send>>) -> Error {
435+
Error(Box::new(ErrorInner { kind, cause }))
438436
}
439437

440-
/// Returns the `DbError` associated with this error if it is a DB error.
441-
pub fn as_db(&self) -> Option<&DbError> {
442-
match *self.0 {
443-
ErrorKind::Db(ref err) => Some(err),
444-
_ => None,
445-
}
438+
pub(crate) fn closed() -> Error {
439+
Error::new(Kind::Closed, None)
446440
}
447441

448-
/// Returns the inner error if this is a conversion error.
449-
pub fn as_conversion(&self) -> Option<&(error::Error + 'static + Sync + Send)> {
450-
match *self.0 {
451-
ErrorKind::Conversion(ref err) => Some(&**err),
452-
_ => None,
453-
}
442+
pub(crate) fn unexpected_message() -> Error {
443+
Error::new(Kind::UnexpectedMessage, None)
454444
}
455445

456-
/// Returns the inner `io::Error` associated with this error if it is an IO
457-
/// error.
458-
pub fn as_io(&self) -> Option<&io::Error> {
459-
match *self.0 {
460-
ErrorKind::Io(ref err) => Some(err),
461-
_ => None,
446+
pub(crate) fn db(error: ErrorResponseBody) -> Error {
447+
match DbError::new(&mut error.fields()) {
448+
Ok(e) => Error::new(Kind::Db, Some(Box::new(e))),
449+
Err(e) => Error::new(Kind::Parse, Some(Box::new(e))),
462450
}
463451
}
464-
}
465452

466-
impl From<io::Error> for Error {
467-
fn from(err: io::Error) -> Error {
468-
Error(Box::new(ErrorKind::Io(err)))
453+
pub(crate) fn parse(e: io::Error) -> Error {
454+
Error::new(Kind::Parse, Some(Box::new(e)))
469455
}
470-
}
471456

472-
impl From<Error> for io::Error {
473-
fn from(err: Error) -> io::Error {
474-
match *err.0 {
475-
ErrorKind::Io(e) => e,
476-
_ => io::Error::new(io::ErrorKind::Other, err),
477-
}
457+
pub(crate) fn encode(e: io::Error) -> Error {
458+
Error::new(Kind::Encode, Some(Box::new(e)))
459+
}
460+
461+
pub(crate) fn to_sql(e: Box<error::Error + Sync + Send>) -> Error {
462+
Error::new(Kind::ToSql, Some(e))
463+
}
464+
465+
pub(crate) fn from_sql(e: Box<error::Error + Sync + Send>) -> Error {
466+
Error::new(Kind::FromSql, Some(e))
467+
}
468+
469+
pub(crate) fn copy_in_stream<E>(e: E) -> Error
470+
where
471+
E: Into<Box<error::Error + Sync + Send>>,
472+
{
473+
Error::new(Kind::CopyInStream, Some(e.into()))
474+
}
475+
476+
pub(crate) fn missing_user() -> Error {
477+
Error::new(Kind::MissingUser, None)
478+
}
479+
480+
pub(crate) fn missing_password() -> Error {
481+
Error::new(Kind::MissingPassword, None)
482+
}
483+
484+
pub(crate) fn unsupported_authentication() -> Error {
485+
Error::new(Kind::UnsupportedAuthentication, None)
486+
}
487+
488+
pub(crate) fn tls(e: Box<error::Error + Sync + Send>) -> Error {
489+
Error::new(Kind::Tls, Some(e))
490+
}
491+
492+
pub(crate) fn connect(e: io::Error) -> Error {
493+
Error::new(Kind::Connect, Some(Box::new(e)))
494+
}
495+
496+
pub(crate) fn timer(e: tokio_timer::Error) -> Error {
497+
Error::new(Kind::Timer, Some(Box::new(e)))
498+
}
499+
500+
pub(crate) fn io(e: io::Error) -> Error {
501+
Error::new(Kind::Io, Some(Box::new(e)))
502+
}
503+
504+
pub(crate) fn authentication(e: io::Error) -> Error {
505+
Error::new(Kind::Authentication, Some(Box::new(e)))
478506
}
479507
}

0 commit comments

Comments
 (0)