11use crate :: codec:: { BackendMessage , BackendMessages , FrontendMessage , PostgresCodec } ;
2- use crate :: config:: Config ;
2+ use crate :: config:: { self , Config } ;
33use crate :: connect_tls:: connect_tls;
44use crate :: maybe_tls_stream:: MaybeTlsStream ;
55use crate :: tls:: { ChannelBinding , TlsConnect } ;
@@ -141,8 +141,13 @@ where
141141 T : AsyncRead + AsyncWrite + Unpin ,
142142{
143143 match stream. try_next ( ) . await . map_err ( Error :: io) ? {
144- Some ( Message :: AuthenticationOk ) => return Ok ( ( ) ) ,
144+ Some ( Message :: AuthenticationOk ) => {
145+ can_skip_channel_binding ( config) ?;
146+ return Ok ( ( ) ) ;
147+ }
145148 Some ( Message :: AuthenticationCleartextPassword ) => {
149+ can_skip_channel_binding ( config) ?;
150+
146151 let pass = config
147152 . password
148153 . as_ref ( )
@@ -151,6 +156,8 @@ where
151156 authenticate_password ( stream, pass) . await ?;
152157 }
153158 Some ( Message :: AuthenticationMd5Password ( body) ) => {
159+ can_skip_channel_binding ( config) ?;
160+
154161 let user = config
155162 . user
156163 . as_ref ( )
@@ -164,12 +171,7 @@ where
164171 authenticate_password ( stream, output. as_bytes ( ) ) . await ?;
165172 }
166173 Some ( Message :: AuthenticationSasl ( body) ) => {
167- let pass = config
168- . password
169- . as_ref ( )
170- . ok_or_else ( || Error :: config ( "password missing" . into ( ) ) ) ?;
171-
172- authenticate_sasl ( stream, body, channel_binding, pass) . await ?;
174+ authenticate_sasl ( stream, body, channel_binding, config) . await ?;
173175 }
174176 Some ( Message :: AuthenticationKerberosV5 )
175177 | Some ( Message :: AuthenticationScmCredential )
@@ -192,6 +194,16 @@ where
192194 }
193195}
194196
197+ fn can_skip_channel_binding ( config : & Config ) -> Result < ( ) , Error > {
198+ match config. channel_binding {
199+ config:: ChannelBinding :: Disable | config:: ChannelBinding :: Prefer => Ok ( ( ) ) ,
200+ config:: ChannelBinding :: Require => Err ( Error :: authentication (
201+ "server did not use channel binding" . into ( ) ,
202+ ) ) ,
203+ config:: ChannelBinding :: __NonExhaustive => unreachable ! ( ) ,
204+ }
205+ }
206+
195207async fn authenticate_password < S , T > (
196208 stream : & mut StartupStream < S , T > ,
197209 password : & [ u8 ] ,
@@ -213,12 +225,17 @@ async fn authenticate_sasl<S, T>(
213225 stream : & mut StartupStream < S , T > ,
214226 body : AuthenticationSaslBody ,
215227 channel_binding : ChannelBinding ,
216- password : & [ u8 ] ,
228+ config : & Config ,
217229) -> Result < ( ) , Error >
218230where
219231 S : AsyncRead + AsyncWrite + Unpin ,
220232 T : AsyncRead + AsyncWrite + Unpin ,
221233{
234+ let password = config
235+ . password
236+ . as_ref ( )
237+ . ok_or_else ( || Error :: config ( "password missing" . into ( ) ) ) ?;
238+
222239 let mut has_scram = false ;
223240 let mut has_scram_plus = false ;
224241 let mut mechanisms = body. mechanisms ( ) ;
@@ -232,12 +249,15 @@ where
232249
233250 let channel_binding = channel_binding
234251 . tls_server_end_point
252+ . filter ( |_| config. channel_binding != config:: ChannelBinding :: Disable )
235253 . map ( sasl:: ChannelBinding :: tls_server_end_point) ;
236254
237255 let ( channel_binding, mechanism) = if has_scram_plus {
238256 match channel_binding {
239257 Some ( channel_binding) => ( channel_binding, sasl:: SCRAM_SHA_256_PLUS ) ,
240- None => ( sasl:: ChannelBinding :: unsupported ( ) , sasl:: SCRAM_SHA_256 ) ,
258+ None => {
259+ ( sasl:: ChannelBinding :: unsupported ( ) , sasl:: SCRAM_SHA_256 )
260+ } ,
241261 }
242262 } else if has_scram {
243263 match channel_binding {
@@ -248,6 +268,10 @@ where
248268 return Err ( Error :: authentication ( "unsupported SASL mechanism" . into ( ) ) ) ;
249269 } ;
250270
271+ if mechanism != sasl:: SCRAM_SHA_256_PLUS {
272+ can_skip_channel_binding ( config) ?;
273+ }
274+
251275 let mut scram = ScramSha256 :: new ( password, channel_binding) ;
252276
253277 let mut buf = vec ! [ ] ;
0 commit comments