1- use openssl:: ssl:: { SslStream , MaybeSslStream } ;
1+ use openssl:: ssl:: { SslStream , SslContext } ;
2+ use std:: error:: Error ;
23use std:: io;
34use std:: io:: prelude:: * ;
45use std:: net:: TcpStream ;
56#[ cfg( feature = "unix_socket" ) ]
67use unix_socket:: UnixStream ;
78use byteorder:: ReadBytesExt ;
89
9- use { ConnectParams , SslMode , ConnectTarget , ConnectError } ;
10+ use { ConnectParams , ConnectTarget , ConnectError } ;
1011use message;
1112use message:: WriteMessage ;
1213use message:: FrontendMessage :: SslRequest ;
1314
1415const DEFAULT_PORT : u16 = 5432 ;
1516
17+ pub trait StreamWrapper < S : Read +Write > : Read +Write +Send {
18+ fn get_ref ( & self ) -> & S ;
19+ fn get_mut ( & mut self ) -> & mut S ;
20+ }
21+
22+ impl < S : Read +Write +Send > StreamWrapper < S > for SslStream < S > {
23+ fn get_ref ( & self ) -> & S {
24+ self . get_ref ( )
25+ }
26+
27+ fn get_mut ( & mut self ) -> & mut S {
28+ self . get_mut ( )
29+ }
30+ }
31+
32+ pub trait NegotiateSsl {
33+ fn negotiate_ssl < S > ( & mut self , stream : S ) -> Result < Box < StreamWrapper < S > > , Box < Error > >
34+ where S : Read +Write +Send +' static ;
35+ }
36+
37+ impl NegotiateSsl for SslContext {
38+ fn negotiate_ssl < S > ( & mut self , stream : S ) -> Result < Box < StreamWrapper < S > > , Box < Error > >
39+ where S : Read +Write +Send +' static {
40+ let stream = try!( SslStream :: new ( self , stream) ) ;
41+ Ok ( Box :: new ( stream) )
42+ }
43+ }
44+
45+ /// Specifies the SSL support requested for a new connection.
46+ pub enum SslMode < N = NoSsl > {
47+ /// The connection will not use SSL.
48+ None ,
49+ /// The connection will use SSL if the backend supports it.
50+ Prefer ( N ) ,
51+ /// The connection must use SSL.
52+ Require ( N ) ,
53+ }
54+
55+ pub enum NoSsl { }
56+
57+ impl NegotiateSsl for NoSsl {
58+ fn negotiate_ssl < S : Read +Write > ( & mut self , stream : S )
59+ -> Result < Box < StreamWrapper < S > > , Box < Error > > {
60+ match * self { }
61+ }
62+ }
63+
1664pub enum InternalStream {
1765 Tcp ( TcpStream ) ,
1866 #[ cfg( feature = "unix_socket" ) ]
1967 Unix ( UnixStream ) ,
2068}
2169
70+ impl StreamWrapper < InternalStream > for InternalStream {
71+ fn get_ref ( & self ) -> & InternalStream {
72+ self
73+ }
74+
75+ fn get_mut ( & mut self ) -> & mut InternalStream {
76+ self
77+ }
78+ }
79+
2280impl Read for InternalStream {
2381 fn read ( & mut self , buf : & mut [ u8 ] ) -> io:: Result < usize > {
2482 match * self {
@@ -62,14 +120,15 @@ fn open_socket(params: &ConnectParams) -> Result<InternalStream, ConnectError> {
62120 }
63121}
64122
65- pub fn initialize_stream ( params : & ConnectParams , ssl : & SslMode )
66- -> Result < MaybeSslStream < InternalStream > , ConnectError > {
123+ pub fn initialize_stream < N > ( params : & ConnectParams , ssl : & mut SslMode < N > )
124+ -> Result < Box < StreamWrapper < InternalStream > > , ConnectError >
125+ where N : NegotiateSsl {
67126 let mut socket = try!( open_socket ( params) ) ;
68127
69- let ( ssl_required, ctx ) = match * ssl {
70- SslMode :: None => return Ok ( MaybeSslStream :: Normal ( socket) ) ,
71- SslMode :: Prefer ( ref ctx ) => ( false , ctx ) ,
72- SslMode :: Require ( ref ctx ) => ( true , ctx )
128+ let ( ssl_required, negotiator ) = match * ssl {
129+ SslMode :: None => return Ok ( Box :: new ( socket) ) ,
130+ SslMode :: Prefer ( ref mut negotiator ) => ( false , negotiator ) ,
131+ SslMode :: Require ( ref mut negotiator ) => ( true , negotiator ) ,
73132 } ;
74133
75134 try!( socket. write_message ( & SslRequest { code : message:: SSL_CODE } ) ) ;
@@ -79,12 +138,12 @@ pub fn initialize_stream(params: &ConnectParams, ssl: &SslMode)
79138 if ssl_required {
80139 return Err ( ConnectError :: NoSslSupport ) ;
81140 } else {
82- return Ok ( MaybeSslStream :: Normal ( socket) ) ;
141+ return Ok ( Box :: new ( socket) ) ;
83142 }
84143 }
85144
86- match SslStream :: new ( ctx , socket) {
87- Ok ( stream) => Ok ( MaybeSslStream :: Ssl ( stream) ) ,
145+ match negotiator . negotiate_ssl ( socket) {
146+ Ok ( stream) => Ok ( stream) ,
88147 Err ( err) => Err ( ConnectError :: SslError ( err) )
89148 }
90149}
0 commit comments