33use bytes:: { Bytes , BytesMut } ;
44use futures:: channel:: mpsc;
55use futures:: {
6- future, join, pin_mut, stream, try_join, FutureExt , SinkExt , StreamExt , TryStreamExt ,
6+ future, join, pin_mut, stream, try_join, Future , FutureExt , SinkExt , StreamExt , TryStreamExt ,
77} ;
8+ use pin_project_lite:: pin_project;
89use std:: fmt:: Write ;
10+ use std:: pin:: Pin ;
11+ use std:: task:: { Context , Poll } ;
912use std:: time:: Duration ;
1013use tokio:: net:: TcpStream ;
1114use tokio:: time;
@@ -22,6 +25,35 @@ mod parse;
2225mod runtime;
2326mod types;
2427
28+ pin_project ! {
29+ /// Polls `F` at most `polls_left` times returning `Some(F::Output)` if
30+ /// [`Future`] returned [`Poll::Ready`] or [`None`] otherwise.
31+ struct Cancellable <F > {
32+ #[ pin]
33+ fut: F ,
34+ polls_left: usize ,
35+ }
36+ }
37+
38+ impl < F : Future > Future for Cancellable < F > {
39+ type Output = Option < F :: Output > ;
40+
41+ fn poll ( self : Pin < & mut Self > , ctx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
42+ let this = self . project ( ) ;
43+ match this. fut . poll ( ctx) {
44+ Poll :: Ready ( r) => Poll :: Ready ( Some ( r) ) ,
45+ Poll :: Pending => {
46+ * this. polls_left = this. polls_left . saturating_sub ( 1 ) ;
47+ if * this. polls_left == 0 {
48+ Poll :: Ready ( None )
49+ } else {
50+ Poll :: Pending
51+ }
52+ }
53+ }
54+ }
55+ }
56+
2557async fn connect_raw ( s : & str ) -> Result < ( Client , Connection < TcpStream , NoTlsStream > ) , Error > {
2658 let socket = TcpStream :: connect ( "127.0.0.1:5433" ) . await . unwrap ( ) ;
2759 let config = s. parse :: < Config > ( ) . unwrap ( ) ;
@@ -35,6 +67,20 @@ async fn connect(s: &str) -> Client {
3567 client
3668}
3769
70+ async fn current_transaction_id ( client : & Client ) -> i64 {
71+ client
72+ . query ( "SELECT txid_current()" , & [ ] )
73+ . await
74+ . unwrap ( )
75+ . pop ( )
76+ . unwrap ( )
77+ . get :: < _ , i64 > ( "txid_current" )
78+ }
79+
80+ async fn in_transaction ( client : & Client ) -> bool {
81+ current_transaction_id ( client) . await == current_transaction_id ( client) . await
82+ }
83+
3884#[ tokio:: test]
3985async fn plain_password_missing ( ) {
4086 connect_raw ( "user=pass_user dbname=postgres" )
@@ -377,6 +423,80 @@ async fn transaction_rollback() {
377423 assert_eq ! ( rows. len( ) , 0 ) ;
378424}
379425
426+ #[ tokio:: test]
427+ async fn transaction_future_cancellation ( ) {
428+ let mut client = connect ( "user=postgres" ) . await ;
429+
430+ for i in 0 .. {
431+ let done = {
432+ let txn = client. transaction ( ) ;
433+ let fut = Cancellable {
434+ fut : txn,
435+ polls_left : i,
436+ } ;
437+ fut. await
438+ . map ( |res| res. expect ( "transaction failed" ) )
439+ . is_some ( )
440+ } ;
441+
442+ assert ! ( !in_transaction( & client) . await ) ;
443+
444+ if done {
445+ break ;
446+ }
447+ }
448+ }
449+
450+ #[ tokio:: test]
451+ async fn transaction_commit_future_cancellation ( ) {
452+ let mut client = connect ( "user=postgres" ) . await ;
453+
454+ for i in 0 .. {
455+ let done = {
456+ let txn = client. transaction ( ) . await . unwrap ( ) ;
457+ let commit = txn. commit ( ) ;
458+ let fut = Cancellable {
459+ fut : commit,
460+ polls_left : i,
461+ } ;
462+ fut. await
463+ . map ( |res| res. expect ( "transaction failed" ) )
464+ . is_some ( )
465+ } ;
466+
467+ assert ! ( !in_transaction( & client) . await ) ;
468+
469+ if done {
470+ break ;
471+ }
472+ }
473+ }
474+
475+ #[ tokio:: test]
476+ async fn transaction_rollback_future_cancellation ( ) {
477+ let mut client = connect ( "user=postgres" ) . await ;
478+
479+ for i in 0 .. {
480+ let done = {
481+ let txn = client. transaction ( ) . await . unwrap ( ) ;
482+ let rollback = txn. rollback ( ) ;
483+ let fut = Cancellable {
484+ fut : rollback,
485+ polls_left : i,
486+ } ;
487+ fut. await
488+ . map ( |res| res. expect ( "transaction failed" ) )
489+ . is_some ( )
490+ } ;
491+
492+ assert ! ( !in_transaction( & client) . await ) ;
493+
494+ if done {
495+ break ;
496+ }
497+ }
498+ }
499+
380500#[ tokio:: test]
381501async fn transaction_rollback_drop ( ) {
382502 let mut client = connect ( "user=postgres" ) . await ;
0 commit comments