Skip to content

Commit b57b1a4

Browse files
committed
Untested SSL support
cc rust-postgres#15
1 parent 4a3ee11 commit b57b1a4

4 files changed

Lines changed: 119 additions & 67 deletions

File tree

lib.rs

Lines changed: 78 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ extern mod ssl = "github.com/sfackler/rust-ssl";
7373
use extra::container::Deque;
7474
use extra::ringbuf::RingBuf;
7575
use extra::url::{UserInfo, Url};
76-
use ssl::SslStream;
76+
use ssl::{SslStream, SslContext};
77+
use ssl::error::SslError;
7778
use std::cell::Cell;
7879
use std::hashmap::HashMap;
7980
use std::rt::io::{Writer, io_error, Decorator};
@@ -117,6 +118,7 @@ use self::message::{FrontendMessage,
117118
PasswordMessage,
118119
Query,
119120
StartupMessage,
121+
SslRequest,
120122
Sync,
121123
Terminate};
122124
use self::message::{RowDescriptionEntry, WriteMessage, ReadMessage};
@@ -196,7 +198,11 @@ pub enum PostgresConnectError {
196198
MissingPassword,
197199
/// The Postgres server requested an authentication method not supported
198200
/// by the driver
199-
UnsupportedAuthentication
201+
UnsupportedAuthentication,
202+
/// The Postgres server does not support SSL encryption
203+
NoSslSupport,
204+
/// There was an error initializing the SSL session
205+
SslError(SslError)
200206
}
201207

202208
/// Represents the position of an error in a query
@@ -333,7 +339,7 @@ pub struct PostgresCancelData {
333339
/// A `PostgresCancelData` object can be created via
334340
/// `PostgresConnection::cancel_data`. The object can cancel any query made on
335341
/// that connection.
336-
pub fn cancel_query(url: &str, data: PostgresCancelData)
342+
pub fn cancel_query(url: &str, ssl: SslMode, data: PostgresCancelData)
337343
-> Option<PostgresConnectError> {
338344
let Url { host, port, _ }: Url = match FromStr::from_str(url) {
339345
Some(url) => url,
@@ -344,7 +350,7 @@ pub fn cancel_query(url: &str, data: PostgresCancelData)
344350
None => DEFAULT_PORT
345351
};
346352

347-
let mut socket = match open_socket(host, port) {
353+
let mut socket = match initialize_stream(host, port, ssl) {
348354
Ok(socket) => socket,
349355
Err(err) => return Some(err)
350356
};
@@ -354,6 +360,7 @@ pub fn cancel_query(url: &str, data: PostgresCancelData)
354360
process_id: data.process_id,
355361
secret_key: data.secret_key
356362
});
363+
socket.flush();
357364

358365
None
359366
}
@@ -381,6 +388,34 @@ fn open_socket(host: &str, port: Port)
381388
Err(SocketError)
382389
}
383390

391+
fn initialize_stream(host: &str, port: Port, ssl: SslMode)
392+
-> Result<InternalStream, PostgresConnectError> {
393+
let mut socket = match open_socket(host, port) {
394+
Ok(socket) => socket,
395+
Err(err) => return Err(err)
396+
};
397+
398+
let (ssl_required, ctx) = match ssl {
399+
NoSsl => return Ok(Normal(socket)),
400+
PreferSsl(ctx) => (false, ctx),
401+
RequireSsl(ctx) => (true, ctx)
402+
};
403+
404+
socket.write_message(&SslRequest { code: message::SSL_CODE });
405+
socket.flush();
406+
407+
let resp = socket.read_u8();
408+
409+
if resp == 'N' as u8 && ssl_required {
410+
return Err(NoSslSupport);
411+
}
412+
413+
match SslStream::try_new(ctx, socket) {
414+
Ok(stream) => Ok(Ssl(stream)),
415+
Err(err) => Err(SslError(err))
416+
}
417+
}
418+
384419
enum InternalStream {
385420
Normal(TcpStream),
386421
Ssl(SslStream<TcpStream>)
@@ -422,14 +457,14 @@ struct InnerPostgresConnection {
422457
impl Drop for InnerPostgresConnection {
423458
fn drop(&mut self) {
424459
do io_error::cond.trap(|_| {}).inside {
425-
self.write_messages([&Terminate]);
460+
self.write_messages([Terminate]);
426461
}
427462
}
428463
}
429464

430465
impl InnerPostgresConnection {
431-
fn try_connect(url: &str) -> Result<InnerPostgresConnection,
432-
PostgresConnectError> {
466+
fn try_connect(url: &str, ssl: SslMode)
467+
-> Result<InnerPostgresConnection, PostgresConnectError> {
433468
let Url {
434469
host,
435470
port,
@@ -452,13 +487,13 @@ impl InnerPostgresConnection {
452487
None => DEFAULT_PORT
453488
};
454489

455-
let stream = match open_socket(host, port) {
490+
let stream = match initialize_stream(host, port, ssl) {
456491
Ok(stream) => stream,
457492
Err(err) => return Err(err)
458493
};
459494

460495
let mut conn = InnerPostgresConnection {
461-
stream: BufferedStream::new(Normal(stream)),
496+
stream: BufferedStream::new(stream),
462497
next_stmt_id: 0,
463498
notice_handler: ~DefaultNoticeHandler as ~PostgresNoticeHandler,
464499
notifications: RingBuf::new(),
@@ -475,7 +510,7 @@ impl InnerPostgresConnection {
475510
// path contains the leading /
476511
args.push((~"database", path.slice_from(1).to_owned()));
477512
}
478-
conn.write_messages([&StartupMessage {
513+
conn.write_messages([StartupMessage {
479514
version: message::PROTOCOL_VERSION,
480515
parameters: args.as_slice()
481516
}]);
@@ -501,8 +536,8 @@ impl InnerPostgresConnection {
501536
Ok(conn)
502537
}
503538

504-
fn write_messages(&mut self, messages: &[&FrontendMessage]) {
505-
for &message in messages.iter() {
539+
fn write_messages(&mut self, messages: &[FrontendMessage]) {
540+
for message in messages.iter() {
506541
self.stream.write_message(message);
507542
}
508543
self.stream.flush();
@@ -534,7 +569,7 @@ impl InnerPostgresConnection {
534569
Some(pass) => pass,
535570
None => return Some(MissingPassword)
536571
};
537-
self.write_messages([&PasswordMessage { password: pass }]);
572+
self.write_messages([PasswordMessage { password: pass }]);
538573
}
539574
AuthenticationMD5Password { salt } => {
540575
let UserInfo { user, pass } = user;
@@ -550,7 +585,7 @@ impl InnerPostgresConnection {
550585
md5.input_str(output);
551586
md5.input(salt);
552587
let output = "md5" + md5.result_str();
553-
self.write_messages([&PasswordMessage {
588+
self.write_messages([PasswordMessage {
554589
password: output.as_slice()
555590
}]);
556591
}
@@ -583,16 +618,16 @@ impl InnerPostgresConnection {
583618

584619
let types = [];
585620
self.write_messages([
586-
&Parse {
621+
Parse {
587622
name: stmt_name,
588623
query: query,
589624
param_types: types
590625
},
591-
&Describe {
626+
Describe {
592627
variant: 'S' as u8,
593628
name: stmt_name
594629
},
595-
&Sync]);
630+
Sync]);
596631

597632
match self.read_message() {
598633
ParseComplete => {}
@@ -659,9 +694,9 @@ impl PostgresConnection {
659694
/// The password may be omitted if not required. The default Postgres port
660695
/// (5432) is used if none is specified. The database name defaults to the
661696
/// username if not specified.
662-
pub fn try_connect(url: &str) -> Result<PostgresConnection,
663-
PostgresConnectError> {
664-
do InnerPostgresConnection::try_connect(url).map |conn| {
697+
pub fn try_connect(url: &str, ssl: SslMode)
698+
-> Result<PostgresConnection, PostgresConnectError> {
699+
do InnerPostgresConnection::try_connect(url, ssl).map |conn| {
665700
PostgresConnection {
666701
conn: Cell::new(conn)
667702
}
@@ -673,8 +708,8 @@ impl PostgresConnection {
673708
/// # Failure
674709
///
675710
/// Fails if there was an error connecting to the database.
676-
pub fn connect(url: &str) -> PostgresConnection {
677-
match PostgresConnection::try_connect(url) {
711+
pub fn connect(url: &str, ssl: SslMode) -> PostgresConnection {
712+
match PostgresConnection::try_connect(url, ssl) {
678713
Ok(conn) => conn,
679714
Err(err) => fail!("Failed to connect: {}", err.to_str())
680715
}
@@ -780,7 +815,7 @@ impl PostgresConnection {
780815

781816
fn quick_query(&self, query: &str) {
782817
do self.conn.with_mut_ref |conn| {
783-
conn.write_messages([&Query { query: query }]);
818+
conn.write_messages([Query { query: query }]);
784819

785820
loop {
786821
match conn.read_message() {
@@ -806,13 +841,23 @@ impl PostgresConnection {
806841
}
807842
}
808843

809-
fn write_messages(&self, messages: &[&FrontendMessage]) {
844+
fn write_messages(&self, messages: &[FrontendMessage]) {
810845
do self.conn.with_mut_ref |conn| {
811846
conn.write_messages(messages)
812847
}
813848
}
814849
}
815850

851+
/// Specifies the SSL support requested for a new connection
852+
pub enum SslMode<'a> {
853+
/// The connection will not use SSL
854+
NoSsl,
855+
/// The connection will use SSL if the backend supports it
856+
PreferSsl(&'a SslContext),
857+
/// The connection must use SSL
858+
RequireSsl(&'a SslContext)
859+
}
860+
816861
/// Represents a transaction on a database connection
817862
pub struct PostgresTransaction<'self> {
818863
priv conn: &'self PostgresConnection,
@@ -974,11 +1019,11 @@ impl<'self> Drop for NormalPostgresStatement<'self> {
9741019
fn drop(&mut self) {
9751020
do io_error::cond.trap(|_| {}).inside {
9761021
self.conn.write_messages([
977-
&Close {
1022+
Close {
9781023
variant: 'S' as u8,
9791024
name: self.name.as_slice()
9801025
},
981-
&Sync]);
1026+
Sync]);
9821027
loop {
9831028
match self.conn.read_message() {
9841029
ReadyForQuery {_} => break,
@@ -1008,18 +1053,18 @@ impl<'self> NormalPostgresStatement<'self> {
10081053
}).collect();
10091054

10101055
self.conn.write_messages([
1011-
&Bind {
1056+
Bind {
10121057
portal: portal_name,
10131058
statement: self.name.as_slice(),
10141059
formats: formats,
10151060
values: values,
10161061
result_formats: result_formats
10171062
},
1018-
&Execute {
1063+
Execute {
10191064
portal: portal_name,
10201065
max_rows: row_limit as i32
10211066
},
1022-
&Sync]);
1067+
Sync]);
10231068

10241069
match self.conn.read_message() {
10251070
BindComplete => None,
@@ -1201,11 +1246,11 @@ impl<'self> Drop for PostgresResult<'self> {
12011246
fn drop(&mut self) {
12021247
do io_error::cond.trap(|_| {}).inside {
12031248
self.stmt.conn.write_messages([
1204-
&Close {
1249+
Close {
12051250
variant: 'P' as u8,
12061251
name: self.name.as_slice()
12071252
},
1208-
&Sync]);
1253+
Sync]);
12091254
loop {
12101255
match self.stmt.conn.read_message() {
12111256
ReadyForQuery {_} => break,
@@ -1238,11 +1283,11 @@ impl<'self> PostgresResult<'self> {
12381283

12391284
fn execute(&mut self) {
12401285
self.stmt.conn.write_messages([
1241-
&Execute {
1286+
Execute {
12421287
portal: self.name,
12431288
max_rows: self.row_limit as i32
12441289
},
1245-
&Sync]);
1290+
Sync]);
12461291
self.read_rows();
12471292
}
12481293
}

message.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use super::types::Oid;
1010

1111
pub static PROTOCOL_VERSION: i32 = 0x0003_0000;
1212
pub static CANCEL_CODE: i32 = 80877102;
13+
pub static SSL_CODE: i32 = 80877103;
1314

1415
#[deriving(ToStr)]
1516
pub enum BackendMessage {
@@ -115,6 +116,9 @@ pub enum FrontendMessage<'self> {
115116
version: i32,
116117
parameters: &'self [(~str, ~str)]
117118
},
119+
SslRequest {
120+
code: i32
121+
},
118122
Sync,
119123
Terminate
120124
}
@@ -214,6 +218,7 @@ impl<W: Writer> WriteMessage for W {
214218
}
215219
buf.write_u8(0);
216220
}
221+
SslRequest { code } => buf.write_be_i32(code),
217222
Sync => {
218223
ident = Some('S');
219224
}

pool.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ use super::{PostgresNotificationIterator,
1010
NormalPostgresStatement,
1111
PostgresDbError,
1212
PostgresConnectError,
13-
PostgresTransaction};
13+
PostgresTransaction,
14+
NoSsl};
1415
use super::types::ToSql;
1516

1617
struct InnerConnectionPool {
@@ -20,7 +21,7 @@ struct InnerConnectionPool {
2021

2122
impl InnerConnectionPool {
2223
fn new_connection(&mut self) -> Option<PostgresConnectError> {
23-
match PostgresConnection::try_connect(self.url) {
24+
match PostgresConnection::try_connect(self.url, NoSsl) {
2425
Ok(conn) => {
2526
self.pool.push(conn);
2627
None

0 commit comments

Comments
 (0)