1- use futures:: { Future , Stream } ;
1+ use futures:: sync:: mpsc;
2+ use futures:: { try_ready, Async , AsyncSink , Future , Poll , Sink , Stream } ;
3+ use std:: io:: { self , Read } ;
24use tokio_postgres:: types:: { ToSql , Type } ;
35use tokio_postgres:: { Error , Row } ;
46#[ cfg( feature = "runtime" ) ]
@@ -52,6 +54,29 @@ impl Client {
5254 self . 0 . query ( & statement. 0 , params) . collect ( ) . wait ( )
5355 }
5456
57+ pub fn copy_in < T , R > (
58+ & mut self ,
59+ query : & T ,
60+ params : & [ & dyn ToSql ] ,
61+ reader : R ,
62+ ) -> Result < u64 , Error >
63+ where
64+ T : ?Sized + Query ,
65+ R : Read ,
66+ {
67+ let statement = query. __statement ( self ) ?;
68+ let ( sender, receiver) = mpsc:: channel ( 1 ) ;
69+ let future = self . 0 . copy_in ( & statement. 0 , params, CopyInStream ( receiver) ) ;
70+
71+ CopyInFuture {
72+ future,
73+ sender,
74+ reader,
75+ pending : None ,
76+ }
77+ . wait ( )
78+ }
79+
5580 pub fn batch_execute ( & mut self , query : & str ) -> Result < ( ) , Error > {
5681 self . 0 . batch_execute ( query) . wait ( )
5782 }
@@ -71,3 +96,80 @@ impl From<tokio_postgres::Client> for Client {
7196 Client ( c)
7297 }
7398}
99+
100+ enum CopyData {
101+ Data ( Vec < u8 > ) ,
102+ Error ( io:: Error ) ,
103+ Done ,
104+ }
105+
106+ struct CopyInStream ( mpsc:: Receiver < CopyData > ) ;
107+
108+ impl Stream for CopyInStream {
109+ type Item = Vec < u8 > ;
110+ type Error = io:: Error ;
111+
112+ fn poll ( & mut self ) -> Poll < Option < Vec < u8 > > , io:: Error > {
113+ match self . 0 . poll ( ) . expect ( "mpsc::Receiver can't error" ) {
114+ Async :: Ready ( Some ( CopyData :: Data ( buf) ) ) => Ok ( Async :: Ready ( Some ( buf) ) ) ,
115+ Async :: Ready ( Some ( CopyData :: Error ( e) ) ) => Err ( e) ,
116+ Async :: Ready ( Some ( CopyData :: Done ) ) => Ok ( Async :: Ready ( None ) ) ,
117+ Async :: Ready ( None ) => Err ( io:: Error :: new ( io:: ErrorKind :: Other , "writer disconnected" ) ) ,
118+ Async :: NotReady => Ok ( Async :: NotReady ) ,
119+ }
120+ }
121+ }
122+
123+ struct CopyInFuture < R > {
124+ future : tokio_postgres:: CopyIn < CopyInStream > ,
125+ sender : mpsc:: Sender < CopyData > ,
126+ reader : R ,
127+ pending : Option < CopyData > ,
128+ }
129+
130+ impl < R > CopyInFuture < R > {
131+ fn poll_send_data ( & mut self , data : CopyData ) -> Poll < ( ) , Error > {
132+ match self . sender . start_send ( data) {
133+ Ok ( AsyncSink :: Ready ) => Ok ( Async :: Ready ( ( ) ) ) ,
134+ Ok ( AsyncSink :: NotReady ( pending) ) => {
135+ self . pending = Some ( pending) ;
136+ return Ok ( Async :: NotReady ) ;
137+ }
138+ // the future's hung up on its end of the channel, so we'll wait for it to report an error
139+ Err ( _) => {
140+ self . pending = Some ( CopyData :: Done ) ;
141+ return Ok ( Async :: NotReady ) ;
142+ }
143+ }
144+ }
145+ }
146+
147+ impl < R > Future for CopyInFuture < R >
148+ where
149+ R : Read ,
150+ {
151+ type Item = u64 ;
152+ type Error = Error ;
153+
154+ fn poll ( & mut self ) -> Poll < u64 , Error > {
155+ if let Async :: Ready ( n) = self . future . poll ( ) ? {
156+ return Ok ( Async :: Ready ( n) ) ;
157+ }
158+
159+ loop {
160+ let data = match self . pending . take ( ) {
161+ Some ( pending) => pending,
162+ None => {
163+ let mut buf = vec ! [ ] ;
164+ match self . reader . by_ref ( ) . take ( 4096 ) . read_to_end ( & mut buf) {
165+ Ok ( 0 ) => CopyData :: Done ,
166+ Ok ( _) => CopyData :: Data ( buf) ,
167+ Err ( e) => CopyData :: Error ( e) ,
168+ }
169+ }
170+ } ;
171+
172+ try_ready ! ( self . poll_send_data( data) ) ;
173+ }
174+ }
175+ }
0 commit comments