Skip to content

Commit be49982

Browse files
committed
Make SSL infrastructure implementation agnostic
1 parent 627f101 commit be49982

5 files changed

Lines changed: 161 additions & 117 deletions

File tree

src/error.rs

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
pub use ugh_privacy::DbError;
22

33
use byteorder;
4-
use openssl::ssl::error::SslError;
54
use phf;
65
use std::error;
76
use std::convert::From;
@@ -29,8 +28,8 @@ pub enum ConnectError {
2928
UnsupportedAuthentication,
3029
/// The Postgres server does not support SSL encryption.
3130
NoSslSupport,
32-
/// There was an error initializing the SSL session.
33-
SslError(SslError),
31+
/// There was an error initializing the SSL session
32+
SslError(Box<error::Error>),
3433
/// There was an error communicating with the server.
3534
IoError(io::Error),
3635
/// The server sent an unexpected response.
@@ -67,7 +66,7 @@ impl error::Error for ConnectError {
6766
fn cause(&self) -> Option<&error::Error> {
6867
match *self {
6968
ConnectError::DbError(ref err) => Some(err),
70-
ConnectError::SslError(ref err) => Some(err),
69+
ConnectError::SslError(ref err) => Some(&**err),
7170
ConnectError::IoError(ref err) => Some(err),
7271
_ => None
7372
}
@@ -86,12 +85,6 @@ impl From<DbError> for ConnectError {
8685
}
8786
}
8887

89-
impl From<SslError> for ConnectError {
90-
fn from(err: SslError) -> ConnectError {
91-
ConnectError::SslError(err)
92-
}
93-
}
94-
9588
impl From<byteorder::Error> for ConnectError {
9689
fn from(err: byteorder::Error) -> ConnectError {
9790
ConnectError::IoError(From::from(err))

src/io_util.rs

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,82 @@
1-
use openssl::ssl::{SslStream, MaybeSslStream};
1+
use openssl::ssl::{SslStream, SslContext};
2+
use std::error::Error;
23
use std::io;
34
use std::io::prelude::*;
45
use std::net::TcpStream;
56
#[cfg(feature = "unix_socket")]
67
use unix_socket::UnixStream;
78
use byteorder::ReadBytesExt;
89

9-
use {ConnectParams, SslMode, ConnectTarget, ConnectError};
10+
use {ConnectParams, ConnectTarget, ConnectError};
1011
use message;
1112
use message::WriteMessage;
1213
use message::FrontendMessage::SslRequest;
1314

1415
const DEFAULT_PORT: u16 = 5432;
1516

17+
pub trait StreamWrapper<S: Read+Write>: Read+Write+Send {
18+
fn get_ref(&self) -> &S;
19+
fn get_mut(&mut self) -> &mut S;
20+
}
21+
22+
impl<S: Read+Write+Send> StreamWrapper<S> for SslStream<S> {
23+
fn get_ref(&self) -> &S {
24+
self.get_ref()
25+
}
26+
27+
fn get_mut(&mut self) -> &mut S {
28+
self.get_mut()
29+
}
30+
}
31+
32+
pub trait NegotiateSsl {
33+
fn negotiate_ssl<S>(&mut self, stream: S) -> Result<Box<StreamWrapper<S>>, Box<Error>>
34+
where S: Read+Write+Send+'static;
35+
}
36+
37+
impl NegotiateSsl for SslContext {
38+
fn negotiate_ssl<S>(&mut self, stream: S) -> Result<Box<StreamWrapper<S>>, Box<Error>>
39+
where S: Read+Write+Send+'static {
40+
let stream = try!(SslStream::new(self, stream));
41+
Ok(Box::new(stream))
42+
}
43+
}
44+
45+
/// Specifies the SSL support requested for a new connection.
46+
pub enum SslMode<N = NoSsl> {
47+
/// The connection will not use SSL.
48+
None,
49+
/// The connection will use SSL if the backend supports it.
50+
Prefer(N),
51+
/// The connection must use SSL.
52+
Require(N),
53+
}
54+
55+
pub enum NoSsl {}
56+
57+
impl NegotiateSsl for NoSsl {
58+
fn negotiate_ssl<S: Read+Write>(&mut self, stream: S)
59+
-> Result<Box<StreamWrapper<S>>, Box<Error>> {
60+
match *self {}
61+
}
62+
}
63+
1664
pub enum InternalStream {
1765
Tcp(TcpStream),
1866
#[cfg(feature = "unix_socket")]
1967
Unix(UnixStream),
2068
}
2169

70+
impl StreamWrapper<InternalStream> for InternalStream {
71+
fn get_ref(&self) -> &InternalStream {
72+
self
73+
}
74+
75+
fn get_mut(&mut self) -> &mut InternalStream {
76+
self
77+
}
78+
}
79+
2280
impl Read for InternalStream {
2381
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
2482
match *self {
@@ -62,14 +120,15 @@ fn open_socket(params: &ConnectParams) -> Result<InternalStream, ConnectError> {
62120
}
63121
}
64122

65-
pub fn initialize_stream(params: &ConnectParams, ssl: &SslMode)
66-
-> Result<MaybeSslStream<InternalStream>, ConnectError> {
123+
pub fn initialize_stream<N>(params: &ConnectParams, ssl: &mut SslMode<N>)
124+
-> Result<Box<StreamWrapper<InternalStream>>, ConnectError>
125+
where N: NegotiateSsl {
67126
let mut socket = try!(open_socket(params));
68127

69-
let (ssl_required, ctx) = match *ssl {
70-
SslMode::None => return Ok(MaybeSslStream::Normal(socket)),
71-
SslMode::Prefer(ref ctx) => (false, ctx),
72-
SslMode::Require(ref ctx) => (true, ctx)
128+
let (ssl_required, negotiator) = match *ssl {
129+
SslMode::None => return Ok(Box::new(socket)),
130+
SslMode::Prefer(ref mut negotiator) => (false, negotiator),
131+
SslMode::Require(ref mut negotiator) => (true, negotiator),
73132
};
74133

75134
try!(socket.write_message(&SslRequest { code: message::SSL_CODE }));
@@ -79,12 +138,12 @@ pub fn initialize_stream(params: &ConnectParams, ssl: &SslMode)
79138
if ssl_required {
80139
return Err(ConnectError::NoSslSupport);
81140
} else {
82-
return Ok(MaybeSslStream::Normal(socket));
141+
return Ok(Box::new(socket));
83142
}
84143
}
85144

86-
match SslStream::new(ctx, socket) {
87-
Ok(stream) => Ok(MaybeSslStream::Ssl(stream)),
145+
match negotiator.negotiate_ssl(socket) {
146+
Ok(stream) => Ok(stream),
88147
Err(err) => Err(ConnectError::SslError(err))
89148
}
90149
}

src/lib.rs

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ extern crate debug_builders;
5959
use bufstream::BufStream;
6060
use debug_builders::DebugStruct;
6161
use openssl::crypto::hash::{self, Hasher};
62-
use openssl::ssl::{SslContext, MaybeSslStream};
6362
use serialize::hex::ToHex;
6463
use std::ascii::AsciiExt;
6564
use std::borrow::{ToOwned, Cow};
@@ -80,6 +79,7 @@ use std::path::PathBuf;
8079
pub use error::{Error, ConnectError, SqlState, DbError, ErrorPosition};
8180
#[doc(inline)]
8281
pub use types::{Oid, Type, Kind, ToSql, FromSql};
82+
pub use io_util::{SslMode, NegotiateSsl, StreamWrapper, NoSsl};
8383
use types::IsNull;
8484
#[doc(inline)]
8585
pub use types::Slice;
@@ -387,8 +387,9 @@ pub struct CancelData {
387387
/// # let _ =
388388
/// postgres::cancel_query(url, &SslMode::None, cancel_data);
389389
/// ```
390-
pub fn cancel_query<T>(params: T, ssl: &SslMode, data: CancelData)
391-
-> result::Result<(), ConnectError> where T: IntoConnectParams {
390+
pub fn cancel_query<T, N>(params: T, ssl: &mut SslMode<N>, data: CancelData)
391+
-> result::Result<(), ConnectError>
392+
where T: IntoConnectParams, N: NegotiateSsl {
392393
let params = try!(params.into_connect_params());
393394
let mut socket = try!(io_util::initialize_stream(&params, ssl));
394395

@@ -464,7 +465,7 @@ struct CachedStatement {
464465
}
465466

466467
struct InnerConnection {
467-
stream: BufStream<MaybeSslStream<InternalStream>>,
468+
stream: BufStream<Box<StreamWrapper<InternalStream>>>,
468469
notice_handler: Box<HandleNotice>,
469470
notifications: VecDeque<Notification>,
470471
cancel_data: CancelData,
@@ -486,8 +487,9 @@ impl Drop for InnerConnection {
486487
}
487488

488489
impl InnerConnection {
489-
fn connect<T>(params: T, ssl: &SslMode) -> result::Result<InnerConnection, ConnectError>
490-
where T: IntoConnectParams {
490+
fn connect<T, N>(params: T, ssl: &mut SslMode<N>)
491+
-> result::Result<InnerConnection, ConnectError>
492+
where T: IntoConnectParams, N: NegotiateSsl {
491493
let params = try!(params.into_connect_params());
492494
let stream = try!(io_util::initialize_stream(&params, ssl));
493495

@@ -1005,8 +1007,9 @@ impl Connection {
10051007
/// let conn = try!(Connection::connect(params, &SslMode::None));
10061008
/// # Ok(()) };
10071009
/// ```
1008-
pub fn connect<T>(params: T, ssl: &SslMode) -> result::Result<Connection, ConnectError>
1009-
where T: IntoConnectParams {
1010+
pub fn connect<T, N>(params: T, ssl: &mut SslMode<N>)
1011+
-> result::Result<Connection, ConnectError>
1012+
where T: IntoConnectParams, N: NegotiateSsl {
10101013
InnerConnection::connect(params, ssl).map(|conn| {
10111014
Connection { conn: RefCell::new(conn) }
10121015
})
@@ -1244,17 +1247,6 @@ impl Connection {
12441247
}
12451248
}
12461249

1247-
/// Specifies the SSL support requested for a new connection.
1248-
#[derive(Debug)]
1249-
pub enum SslMode {
1250-
/// The connection will not use SSL.
1251-
None,
1252-
/// The connection will use SSL if the backend supports it.
1253-
Prefer(SslContext),
1254-
/// The connection must use SSL.
1255-
Require(SslContext)
1256-
}
1257-
12581250
/// Represents a transaction on a database connection.
12591251
///
12601252
/// The transaction will roll back by default.

0 commit comments

Comments
 (0)