@@ -4,7 +4,7 @@ use base64;
44use generic_array:: typenum:: U32 ;
55use generic_array:: GenericArray ;
66use hmac:: { Hmac , Mac } ;
7- use rand:: { OsRng , Rng } ;
7+ use rand:: { self , Rng } ;
88use sha2:: { Digest , Sha256 } ;
99use std:: fmt:: Write ;
1010use std:: io;
@@ -17,6 +17,8 @@ const NONCE_LENGTH: usize = 24;
1717
1818/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
1919pub const SCRAM_SHA_256 : & ' static str = "SCRAM-SHA-256" ;
20+ /// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
21+ pub const SCRAM_SHA_256_PLUS : & ' static str = "SCRAM-SHA-256-PLUS" ;
2022
2123// since postgres passwords are not required to exclude saslprep-prohibited
2224// characters or even be valid UTF8, we run saslprep if possible and otherwise
@@ -54,10 +56,61 @@ fn hi(str: &[u8], salt: &[u8], i: u32) -> GenericArray<u8, U32> {
5456 hi
5557}
5658
59+ enum ChannelBindingInner {
60+ Unrequested ,
61+ Unsupported ,
62+ TlsUnique ( Vec < u8 > ) ,
63+ TlsServerEndPoint ( Vec < u8 > ) ,
64+ }
65+
66+ /// The channel binding configuration for a SCRAM authentication exchange.
67+ pub struct ChannelBinding ( ChannelBindingInner ) ;
68+
69+ impl ChannelBinding {
70+ /// The server did not request channel binding.
71+ pub fn unrequested ( ) -> ChannelBinding {
72+ ChannelBinding ( ChannelBindingInner :: Unrequested )
73+ }
74+
75+ /// The server requested channel binding but the client is unable to provide it.
76+ pub fn unsupported ( ) -> ChannelBinding {
77+ ChannelBinding ( ChannelBindingInner :: Unsupported )
78+ }
79+
80+ /// The server requested channel binding and the client will use the `tls-unique` method.
81+ pub fn tls_unique ( finished : Vec < u8 > ) -> ChannelBinding {
82+ ChannelBinding ( ChannelBindingInner :: TlsUnique ( finished) )
83+ }
84+
85+ /// The server requested channel binding and the client will use the `tls-server-end-point`
86+ /// method.
87+ pub fn tls_server_end_point ( signature : Vec < u8 > ) -> ChannelBinding {
88+ ChannelBinding ( ChannelBindingInner :: TlsServerEndPoint ( signature) )
89+ }
90+
91+ fn gs2_header ( & self ) -> & ' static str {
92+ match self . 0 {
93+ ChannelBindingInner :: Unrequested => "y,," ,
94+ ChannelBindingInner :: Unsupported => "n,," ,
95+ ChannelBindingInner :: TlsUnique ( _) => "p=tls-unique,," ,
96+ ChannelBindingInner :: TlsServerEndPoint ( _) => "p=tls-server-end-point,," ,
97+ }
98+ }
99+
100+ fn cbind_data ( & self ) -> & [ u8 ] {
101+ match self . 0 {
102+ ChannelBindingInner :: Unrequested | ChannelBindingInner :: Unsupported => & [ ] ,
103+ ChannelBindingInner :: TlsUnique ( ref buf)
104+ | ChannelBindingInner :: TlsServerEndPoint ( ref buf) => buf,
105+ }
106+ }
107+ }
108+
57109enum State {
58110 Update {
59111 nonce : String ,
60112 password : Vec < u8 > ,
113+ channel_binding : ChannelBinding ,
61114 } ,
62115 Finish {
63116 salted_password : GenericArray < u8 , U32 > ,
@@ -66,7 +119,8 @@ enum State {
66119 Done ,
67120}
68121
69- /// A type which handles the client side of the SCRAM-SHA-256 authentication process.
122+ /// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
123+ /// process.
70124///
71125/// During the authentication process, if the backend sends an `AuthenticationSASL` message which
72126/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
@@ -85,11 +139,11 @@ pub struct ScramSha256 {
85139 state : State ,
86140}
87141
88- #[ allow( missing_docs) ]
89142impl ScramSha256 {
90143 /// Constructs a new instance which will use the provided password for authentication.
91- pub fn new ( password : & [ u8 ] ) -> io:: Result < ScramSha256 > {
92- let mut rng = OsRng :: new ( ) ?;
144+ pub fn new ( password : & [ u8 ] , channel_binding : ChannelBinding ) -> io:: Result < ScramSha256 > {
145+ // rand 0.5's ThreadRng is cryptographically secure
146+ let mut rng = rand:: thread_rng ( ) ;
93147 let nonce = ( 0 ..NONCE_LENGTH )
94148 . map ( |_| {
95149 let mut v = rng. gen_range ( 0x21u8 , 0x7e ) ;
@@ -100,21 +154,20 @@ impl ScramSha256 {
100154 } )
101155 . collect :: < String > ( ) ;
102156
103- ScramSha256 :: new_inner ( password, nonce)
157+ ScramSha256 :: new_inner ( password, channel_binding , nonce)
104158 }
105159
106- fn new_inner ( password : & [ u8 ] , nonce : String ) -> io:: Result < ScramSha256 > {
107- // the docs say to use pg_same_as_startup_message as the username, but
108- // psql uses an empty string, so we'll go with that.
109- let message = format ! ( "n,,n=,r={}" , nonce) ;
110-
111- let password = normalize ( password) ;
112-
160+ fn new_inner (
161+ password : & [ u8 ] ,
162+ channel_binding : ChannelBinding ,
163+ nonce : String ,
164+ ) -> io:: Result < ScramSha256 > {
113165 Ok ( ScramSha256 {
114- message : message ,
166+ message : format ! ( "{}n=,r={}" , channel_binding . gs2_header ( ) , nonce ) ,
115167 state : State :: Update {
116- nonce : nonce,
117- password : password,
168+ nonce,
169+ password : normalize ( password) ,
170+ channel_binding : channel_binding,
118171 } ,
119172 } )
120173 }
@@ -131,10 +184,15 @@ impl ScramSha256 {
131184 ///
132185 /// This should be called when an `AuthenticationSASLContinue` message is received.
133186 pub fn update ( & mut self , message : & [ u8 ] ) -> io:: Result < ( ) > {
134- let ( client_nonce, password) = match mem:: replace ( & mut self . state , State :: Done ) {
135- State :: Update { nonce, password } => ( nonce, password) ,
136- _ => return Err ( io:: Error :: new ( io:: ErrorKind :: Other , "invalid SCRAM state" ) ) ,
137- } ;
187+ let ( client_nonce, password, channel_binding) =
188+ match mem:: replace ( & mut self . state , State :: Done ) {
189+ State :: Update {
190+ nonce,
191+ password,
192+ channel_binding,
193+ } => ( nonce, password, channel_binding) ,
194+ _ => return Err ( io:: Error :: new ( io:: ErrorKind :: Other , "invalid SCRAM state" ) ) ,
195+ } ;
138196
139197 let message =
140198 str:: from_utf8 ( message) . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: InvalidInput , e) ) ?;
@@ -161,8 +219,13 @@ impl ScramSha256 {
161219 hash. input ( client_key. as_slice ( ) ) ;
162220 let stored_key = hash. result ( ) ;
163221
222+ let mut cbind_input = vec ! [ ] ;
223+ cbind_input. extend ( channel_binding. gs2_header ( ) . as_bytes ( ) ) ;
224+ cbind_input. extend ( channel_binding. cbind_data ( ) ) ;
225+ let cbind_input = base64:: encode ( & cbind_input) ;
226+
164227 self . message . clear ( ) ;
165- write ! ( & mut self . message, "c=biws ,r={}" , parsed. nonce) . unwrap ( ) ;
228+ write ! ( & mut self . message, "c={} ,r={}" , cbind_input , parsed. nonce) . unwrap ( ) ;
166229
167230 let auth_message = format ! ( "n=,r={},{},{}" , client_nonce, message, self . message) ;
168231
@@ -420,7 +483,11 @@ mod test {
420483 1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
421484 let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=" ;
422485
423- let mut scram = ScramSha256 :: new_inner ( password. as_bytes ( ) , nonce. to_string ( ) ) . unwrap ( ) ;
486+ let mut scram = ScramSha256 :: new_inner (
487+ password. as_bytes ( ) ,
488+ ChannelBinding :: unsupported ( ) ,
489+ nonce. to_string ( ) ,
490+ ) . unwrap ( ) ;
424491 assert_eq ! ( str :: from_utf8( scram. message( ) ) . unwrap( ) , client_first) ;
425492
426493 scram. update ( server_first. as_bytes ( ) ) . unwrap ( ) ;
0 commit comments