1- use crate :: codec:: { BackendMessages , FrontendMessage , PostgresCodec } ;
1+ use crate :: codec:: { BackendMessage , BackendMessages , FrontendMessage , PostgresCodec } ;
2+ use crate :: error:: DbError ;
23use crate :: maybe_tls_stream:: MaybeTlsStream ;
4+ use crate :: { AsyncMessage , Error , Notification } ;
5+ use fallible_iterator:: FallibleIterator ;
36use futures:: channel:: mpsc;
4- use std:: collections:: HashMap ;
7+ use futures:: { ready, Sink , Stream , StreamExt } ;
8+ use log:: trace;
9+ use postgres_protocol:: message:: backend:: Message ;
10+ use postgres_protocol:: message:: frontend;
11+ use std:: collections:: { HashMap , VecDeque } ;
12+ use std:: future:: Future ;
13+ use std:: io;
14+ use std:: pin:: Pin ;
15+ use std:: task:: { Context , Poll } ;
516use tokio:: codec:: Framed ;
17+ use tokio:: io:: { AsyncRead , AsyncWrite } ;
618
719pub enum RequestMessages {
820 Single ( FrontendMessage ) ,
@@ -13,13 +25,40 @@ pub struct Request {
1325 pub sender : mpsc:: Sender < BackendMessages > ,
1426}
1527
28+ pub struct Response {
29+ sender : mpsc:: Sender < BackendMessages > ,
30+ }
31+
32+ #[ derive( PartialEq , Debug ) ]
33+ enum State {
34+ Active ,
35+ Terminating ,
36+ Closing ,
37+ }
38+
39+ /// A connection to a PostgreSQL database.
40+ ///
41+ /// This is one half of what is returned when a new connection is established. It performs the actual IO with the
42+ /// server, and should generally be spawned off onto an executor to run in the background.
43+ ///
44+ /// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
45+ /// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
46+ #[ must_use = "futures do nothing unless polled" ]
1647pub struct Connection < S , T > {
1748 stream : Framed < MaybeTlsStream < S , T > , PostgresCodec > ,
1849 parameters : HashMap < String , String > ,
1950 receiver : mpsc:: UnboundedReceiver < Request > ,
51+ pending_request : Option < RequestMessages > ,
52+ pending_response : Option < BackendMessage > ,
53+ responses : VecDeque < Response > ,
54+ state : State ,
2055}
2156
22- impl < S , T > Connection < S , T > {
57+ impl < S , T > Connection < S , T >
58+ where
59+ S : AsyncRead + AsyncWrite + Unpin ,
60+ T : AsyncRead + AsyncWrite + Unpin ,
61+ {
2362 pub ( crate ) fn new (
2463 stream : Framed < MaybeTlsStream < S , T > , PostgresCodec > ,
2564 parameters : HashMap < String , String > ,
@@ -29,6 +68,240 @@ impl<S, T> Connection<S, T> {
2968 stream,
3069 parameters,
3170 receiver,
71+ pending_request : None ,
72+ pending_response : None ,
73+ responses : VecDeque :: new ( ) ,
74+ state : State :: Active ,
75+ }
76+ }
77+
78+ /// Returns the value of a runtime parameter for this connection.
79+ pub fn parameter ( & self , name : & str ) -> Option < & str > {
80+ self . parameters . get ( name) . map ( |s| & * * s)
81+ }
82+
83+ fn poll_response (
84+ & mut self ,
85+ cx : & mut Context < ' _ > ,
86+ ) -> Poll < Option < Result < BackendMessage , Error > > > {
87+ if let Some ( message) = self . pending_response . take ( ) {
88+ trace ! ( "retrying pending response" ) ;
89+ return Poll :: Ready ( Some ( Ok ( message) ) ) ;
90+ }
91+
92+ Pin :: new ( & mut self . stream )
93+ . poll_next ( cx)
94+ . map ( |o| o. map ( |r| r. map_err ( Error :: io) ) )
95+ }
96+
97+ fn poll_read ( & mut self , cx : & mut Context < ' _ > ) -> Result < Option < AsyncMessage > , Error > {
98+ if self . state != State :: Active {
99+ trace ! ( "poll_read: done" ) ;
100+ return Ok ( None ) ;
101+ }
102+
103+ loop {
104+ let message = match self . poll_response ( cx) ? {
105+ Poll :: Ready ( Some ( message) ) => message,
106+ Poll :: Ready ( None ) => return Err ( Error :: closed ( ) ) ,
107+ Poll :: Pending => {
108+ trace ! ( "poll_read: waiting on response" ) ;
109+ return Ok ( None ) ;
110+ }
111+ } ;
112+
113+ let ( mut messages, request_complete) = match message {
114+ BackendMessage :: Async ( Message :: NoticeResponse ( body) ) => {
115+ let error = DbError :: parse ( & mut body. fields ( ) ) . map_err ( Error :: parse) ?;
116+ return Ok ( Some ( AsyncMessage :: Notice ( error) ) ) ;
117+ }
118+ BackendMessage :: Async ( Message :: NotificationResponse ( body) ) => {
119+ let notification = Notification {
120+ process_id : body. process_id ( ) ,
121+ channel : body. channel ( ) . map_err ( Error :: parse) ?. to_string ( ) ,
122+ payload : body. message ( ) . map_err ( Error :: parse) ?. to_string ( ) ,
123+ } ;
124+ return Ok ( Some ( AsyncMessage :: Notification ( notification) ) ) ;
125+ }
126+ BackendMessage :: Async ( Message :: ParameterStatus ( body) ) => {
127+ self . parameters . insert (
128+ body. name ( ) . map_err ( Error :: parse) ?. to_string ( ) ,
129+ body. value ( ) . map_err ( Error :: parse) ?. to_string ( ) ,
130+ ) ;
131+ continue ;
132+ }
133+ BackendMessage :: Async ( _) => unreachable ! ( ) ,
134+ BackendMessage :: Normal {
135+ messages,
136+ request_complete,
137+ } => ( messages, request_complete) ,
138+ } ;
139+
140+ let mut response = match self . responses . pop_front ( ) {
141+ Some ( response) => response,
142+ None => match messages. next ( ) . map_err ( Error :: parse) ? {
143+ Some ( Message :: ErrorResponse ( error) ) => return Err ( Error :: db ( error) ) ,
144+ _ => return Err ( Error :: unexpected_message ( ) ) ,
145+ } ,
146+ } ;
147+
148+ match response. sender . poll_ready ( cx) {
149+ Poll :: Ready ( Ok ( ( ) ) ) => {
150+ let _ = response. sender . start_send ( messages) ;
151+ if !request_complete {
152+ self . responses . push_front ( response) ;
153+ }
154+ }
155+ Poll :: Ready ( Err ( _) ) => {
156+ // we need to keep paging through the rest of the messages even if the receiver's hung up
157+ if !request_complete {
158+ self . responses . push_front ( response) ;
159+ }
160+ }
161+ Poll :: Pending => {
162+ self . responses . push_front ( response) ;
163+ self . pending_response = Some ( BackendMessage :: Normal {
164+ messages,
165+ request_complete,
166+ } ) ;
167+ trace ! ( "poll_read: waiting on sender" ) ;
168+ return Ok ( None ) ;
169+ }
170+ }
171+ }
172+ }
173+
174+ fn poll_request ( & mut self , cx : & mut Context < ' _ > ) -> Poll < Option < RequestMessages > > {
175+ if let Some ( messages) = self . pending_request . take ( ) {
176+ trace ! ( "retrying pending request" ) ;
177+ return Poll :: Ready ( Some ( messages) ) ;
178+ }
179+
180+ match self . receiver . poll_next_unpin ( cx) {
181+ Poll :: Ready ( Some ( request) ) => {
182+ trace ! ( "polled new request" ) ;
183+ self . responses . push_back ( Response {
184+ sender : request. sender ,
185+ } ) ;
186+ Poll :: Ready ( Some ( request. messages ) )
187+ }
188+ Poll :: Ready ( None ) => Poll :: Ready ( None ) ,
189+ Poll :: Pending => Poll :: Pending ,
32190 }
33191 }
192+
193+ fn poll_write ( & mut self , cx : & mut Context < ' _ > ) -> Result < bool , Error > {
194+ loop {
195+ if self . state == State :: Closing {
196+ trace ! ( "poll_write: done" ) ;
197+ return Ok ( false ) ;
198+ }
199+
200+ let request = match self . poll_request ( cx) {
201+ Poll :: Ready ( Some ( request) ) => request,
202+ Poll :: Ready ( None ) if self . responses . is_empty ( ) && self . state == State :: Active => {
203+ trace ! ( "poll_write: at eof, terminating" ) ;
204+ self . state = State :: Terminating ;
205+ let mut request = vec ! [ ] ;
206+ frontend:: terminate ( & mut request) ;
207+ RequestMessages :: Single ( FrontendMessage :: Raw ( request) )
208+ }
209+ Poll :: Ready ( None ) => {
210+ trace ! (
211+ "poll_write: at eof, pending responses {}" ,
212+ self . responses. len( )
213+ ) ;
214+ return Ok ( true ) ;
215+ }
216+ Poll :: Pending => {
217+ trace ! ( "poll_write: waiting on request" ) ;
218+ return Ok ( true ) ;
219+ }
220+ } ;
221+
222+ if let Poll :: Pending = Pin :: new ( & mut self . stream )
223+ . poll_ready ( cx)
224+ . map_err ( Error :: io) ?
225+ {
226+ trace ! ( "poll_write: waiting on socket" ) ;
227+ self . pending_request = Some ( request) ;
228+ return Ok ( false ) ;
229+ }
230+
231+ match request {
232+ RequestMessages :: Single ( request) => {
233+ Pin :: new ( & mut self . stream )
234+ . start_send ( request)
235+ . map_err ( Error :: io) ?;
236+ if self . state == State :: Terminating {
237+ trace ! ( "poll_write: sent eof, closing" ) ;
238+ self . state = State :: Closing ;
239+ }
240+ }
241+ }
242+ }
243+ }
244+
245+ fn poll_flush ( & mut self , cx : & mut Context < ' _ > ) -> Result < ( ) , Error > {
246+ match Pin :: new ( & mut self . stream )
247+ . poll_flush ( cx)
248+ . map_err ( Error :: io) ?
249+ {
250+ Poll :: Ready ( ( ) ) => trace ! ( "poll_flush: flushed" ) ,
251+ Poll :: Pending => trace ! ( "poll_flush: waiting on socket" ) ,
252+ }
253+ Ok ( ( ) )
254+ }
255+
256+ fn poll_shutdown ( & mut self , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Error > > {
257+ if self . state != State :: Closing {
258+ return Poll :: Pending ;
259+ }
260+
261+ match Pin :: new ( & mut self . stream )
262+ . poll_close ( cx)
263+ . map_err ( Error :: io) ?
264+ {
265+ Poll :: Ready ( ( ) ) => {
266+ trace ! ( "poll_shutdown: complete" ) ;
267+ Poll :: Ready ( Ok ( ( ) ) )
268+ }
269+ Poll :: Pending => {
270+ trace ! ( "poll_shutdown: waiting on socket" ) ;
271+ Poll :: Pending
272+ }
273+ }
274+ }
275+
276+ pub fn poll_message (
277+ mut self : Pin < & mut Self > ,
278+ cx : & mut Context < ' _ > ,
279+ ) -> Poll < Option < Result < AsyncMessage , Error > > > {
280+ let message = self . poll_read ( cx) ?;
281+ let want_flush = self . poll_write ( cx) ?;
282+ if want_flush {
283+ self . poll_flush ( cx) ?;
284+ }
285+ match message {
286+ Some ( message) => Poll :: Ready ( Some ( Ok ( message) ) ) ,
287+ None => match self . poll_shutdown ( cx) {
288+ Poll :: Ready ( Ok ( ( ) ) ) => Poll :: Ready ( None ) ,
289+ Poll :: Ready ( Err ( e) ) => Poll :: Ready ( Some ( Err ( e) ) ) ,
290+ Poll :: Pending => Poll :: Pending ,
291+ } ,
292+ }
293+ }
294+ }
295+
296+ impl < S , T > Future for Connection < S , T >
297+ where
298+ S : AsyncRead + AsyncWrite + Unpin ,
299+ T : AsyncRead + AsyncWrite + Unpin ,
300+ {
301+ type Output = Result < ( ) , Error > ;
302+
303+ fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Error > > {
304+ while let Some ( _) = ready ! ( Pin :: as_mut( & mut self ) . poll_message( cx) ?) { }
305+ Poll :: Ready ( Ok ( ( ) ) )
306+ }
34307}
0 commit comments