@@ -73,7 +73,7 @@ use serialize::hex::ToHex;
7373use std:: cell:: { Cell , RefCell } ;
7474use std:: collections:: HashMap ;
7575use std:: from_str:: FromStr ;
76- use std:: io:: { BufferedStream , IoResult } ;
76+ use std:: io:: { BufferedStream , IoResult , MemWriter } ;
7777use std:: io:: net:: ip:: Port ;
7878use std:: mem;
7979use std:: fmt;
@@ -125,6 +125,8 @@ use message::{AuthenticationCleartextPassword,
125125use message:: { Bind ,
126126 CancelRequest ,
127127 Close ,
128+ CopyData ,
129+ CopyDone ,
128130 CopyFail ,
129131 Describe ,
130132 Execute ,
@@ -146,6 +148,7 @@ mod io;
146148pub mod pool;
147149mod message;
148150pub mod types;
151+ mod util;
149152
150153static CANARY : u32 = 0xdeadbeef ;
151154
@@ -520,8 +523,8 @@ impl InnerPostgresConnection {
520523 mem:: replace ( & mut self . notice_handler , handler)
521524 }
522525
523- fn prepare < ' a > ( & mut self , query : & str , conn : & ' a PostgresConnection )
524- -> PostgresResult < PostgresStatement < ' a > > {
526+ fn raw_prepare ( & mut self , query : & str )
527+ -> PostgresResult < ( String , Vec < PostgresType > , Vec < ResultDescription > ) > {
525528 let stmt_name = format ! ( "s{}" , self . next_stmt_id) ;
526529 self . next_stmt_id += 1 ;
527530
@@ -572,6 +575,12 @@ impl InnerPostgresConnection {
572575 try!( self . set_type_names ( param_types. iter_mut ( ) ) ) ;
573576 try!( self . set_type_names ( result_desc. iter_mut ( ) . map ( |d| & mut d. ty ) ) ) ;
574577
578+ Ok ( ( stmt_name, param_types, result_desc) )
579+ }
580+
581+ fn prepare < ' a > ( & mut self , query : & str , conn : & ' a PostgresConnection )
582+ -> PostgresResult < PostgresStatement < ' a > > {
583+ let ( stmt_name, param_types, result_desc) = try!( self . raw_prepare ( query) ) ;
575584 Ok ( PostgresStatement {
576585 conn : conn,
577586 name : stmt_name,
@@ -582,6 +591,54 @@ impl InnerPostgresConnection {
582591 } )
583592 }
584593
594+ fn prepare_copy_in < ' a > ( & mut self , table : & str , rows : & [ & str ] , conn : & ' a PostgresConnection )
595+ -> PostgresResult < PostgresCopyInStatement < ' a > > {
596+ let mut query = MemWriter :: new ( ) ;
597+ let _ = write ! ( query, "SELECT " ) ;
598+ let _ = util:: comma_join ( & mut query, rows. iter ( ) . map ( |& e| e) ) ;
599+ let _ = write ! ( query, " FROM {}" , table) ;
600+ let query = String :: from_utf8 ( query. unwrap ( ) ) . unwrap ( ) ;
601+ let ( stmt_name, _, result_desc) = try!( self . raw_prepare ( query. as_slice ( ) ) ) ;
602+
603+ let column_types = result_desc. iter ( ) . map ( |desc| desc. ty . clone ( ) ) . collect ( ) ;
604+ try!( self . close_statement ( stmt_name. as_slice ( ) ) ) ;
605+
606+ let mut query = MemWriter :: new ( ) ;
607+ let _ = write ! ( query, "COPY {} (" , table) ;
608+ let _ = util:: comma_join ( & mut query, rows. iter ( ) . map ( |& e| e) ) ;
609+ let _ = write ! ( query, ") FROM STDIN WITH (FORMAT binary)" ) ;
610+ let query = String :: from_utf8 ( query. unwrap ( ) ) . unwrap ( ) ;
611+ let ( stmt_name, _, _) = try!( self . raw_prepare ( query. as_slice ( ) ) ) ;
612+
613+ Ok ( PostgresCopyInStatement {
614+ conn : conn,
615+ name : stmt_name,
616+ column_types : column_types,
617+ next_portal_id : Cell :: new ( 0 ) ,
618+ finished : false ,
619+ } )
620+ }
621+
622+ fn close_statement ( & mut self , stmt_name : & str ) -> PostgresResult < ( ) > {
623+ try_pg ! ( self . write_messages( [
624+ Close {
625+ variant: b'S' ,
626+ name: stmt_name,
627+ } ,
628+ Sync ] ) ) ;
629+ loop {
630+ match try_pg ! ( self . read_message_( ) ) {
631+ ReadyForQuery { .. } => break ,
632+ ErrorResponse { fields } => {
633+ try!( self . wait_for_ready ( ) ) ;
634+ return Err ( PgDbError ( PostgresDbError :: new ( fields) ) ) ;
635+ }
636+ _ => { }
637+ }
638+ }
639+ Ok ( ( ) )
640+ }
641+
585642 fn set_type_names < ' a , I > ( & mut self , mut it : I ) -> PostgresResult < ( ) >
586643 where I : Iterator < & ' a mut PostgresType > {
587644 for ty in it {
@@ -759,6 +816,15 @@ impl PostgresConnection {
759816 conn. prepare ( query, self )
760817 }
761818
819+ pub fn prepare_copy_in < ' a > ( & ' a self , table : & str , rows : & [ & str ] )
820+ -> PostgresResult < PostgresCopyInStatement < ' a > > {
821+ let mut conn = self . conn . borrow_mut ( ) ;
822+ if conn. trans_depth != 0 {
823+ return Err ( PgWrongTransaction ) ;
824+ }
825+ conn. prepare_copy_in ( table, rows, self )
826+ }
827+
762828 /// Begins a new transaction.
763829 ///
764830 /// Returns a `PostgresTransaction` object which should be used instead of
@@ -1057,24 +1123,9 @@ impl<'conn> Drop for PostgresStatement<'conn> {
10571123
10581124impl < ' conn > PostgresStatement < ' conn > {
10591125 fn finish_inner ( & mut self ) -> PostgresResult < ( ) > {
1060- check_desync ! ( self . conn) ;
1061- try_pg ! ( self . conn. write_messages( [
1062- Close {
1063- variant: b'S' ,
1064- name: self . name. as_slice( )
1065- } ,
1066- Sync ] ) ) ;
1067- loop {
1068- match try_pg ! ( self . conn. read_message_( ) ) {
1069- ReadyForQuery { .. } => break ,
1070- ErrorResponse { fields } => {
1071- try!( self . conn . wait_for_ready ( ) ) ;
1072- return Err ( PgDbError ( PostgresDbError :: new ( fields) ) ) ;
1073- }
1074- _ => { }
1075- }
1076- }
1077- Ok ( ( ) )
1126+ let mut conn = self . conn . conn . borrow_mut ( ) ;
1127+ check_desync ! ( conn) ;
1128+ conn. close_statement ( self . name . as_slice ( ) )
10781129 }
10791130
10801131 fn inner_execute ( & self , portal_name : & str , row_limit : i32 , params : & [ & ToSql ] )
@@ -1495,3 +1546,133 @@ impl<'trans, 'stmt> Iterator<PostgresResult<PostgresRow<'stmt>>>
14951546 self . result . size_hint ( )
14961547 }
14971548}
1549+
1550+ pub struct PostgresCopyInStatement < ' a > {
1551+ conn : & ' a PostgresConnection ,
1552+ name : String ,
1553+ column_types : Vec < PostgresType > ,
1554+ next_portal_id : Cell < uint > ,
1555+ finished : bool ,
1556+ }
1557+
1558+ #[ unsafe_destructor]
1559+ impl < ' a > Drop for PostgresCopyInStatement < ' a > {
1560+ fn drop ( & mut self ) {
1561+ if !self . finished {
1562+ let _ = self . finish_inner ( ) ;
1563+ }
1564+ }
1565+ }
1566+
1567+ impl < ' a > PostgresCopyInStatement < ' a > {
1568+ fn finish_inner ( & mut self ) -> PostgresResult < ( ) > {
1569+ let mut conn = self . conn . conn . borrow_mut ( ) ;
1570+ check_desync ! ( conn) ;
1571+ conn. close_statement ( self . name . as_slice ( ) )
1572+ }
1573+
1574+ pub fn execute < ' b , I , J > ( & self , mut rows : I ) -> PostgresResult < ( ) >
1575+ where I : Iterator < J > , J : Iterator < & ' b ToSql + ' b > {
1576+ let mut conn = self . conn . conn . borrow_mut ( ) ;
1577+
1578+ try_pg ! ( conn. write_messages( [
1579+ Bind {
1580+ portal: "" ,
1581+ statement: self . name. as_slice( ) ,
1582+ formats: [ ] ,
1583+ values: [ ] ,
1584+ result_formats: [ ]
1585+ } ,
1586+ Execute {
1587+ portal: "" ,
1588+ max_rows: 0 ,
1589+ } ,
1590+ Sync ] ) ) ;
1591+
1592+ match try_pg ! ( conn. read_message_( ) ) {
1593+ BindComplete => { } ,
1594+ ErrorResponse { fields } => {
1595+ try!( conn. wait_for_ready ( ) ) ;
1596+ return Err ( PgDbError ( PostgresDbError :: new ( fields) ) ) ;
1597+ }
1598+ _ => {
1599+ conn. desynchronized = true ;
1600+ return Err ( PgBadResponse ) ;
1601+ }
1602+ }
1603+
1604+ match try_pg ! ( conn. read_message_( ) ) {
1605+ CopyInResponse { .. } => { }
1606+ _ => {
1607+ conn. desynchronized = true ;
1608+ return Err ( PgBadResponse ) ;
1609+ }
1610+ }
1611+
1612+ let mut buf = MemWriter :: new ( ) ;
1613+ let _ = buf. write ( b"PGCOPY\n \xff \r \n \x00 " ) ;
1614+ let _ = buf. write_be_i32 ( 0 ) ;
1615+ let _ = buf. write_be_i32 ( 0 ) ;
1616+
1617+ for mut row in rows {
1618+ let _ = buf. write_be_i16 ( self . column_types . len ( ) as i16 ) ;
1619+
1620+ let mut count = 0 ;
1621+ for ( i, ( val, ty) ) in row. by_ref ( ) . zip ( self . column_types . iter ( ) ) . enumerate ( ) {
1622+ match try!( val. to_sql ( ty) ) {
1623+ ( _, None ) => {
1624+ let _ = buf. write_be_i32 ( -1 ) ;
1625+ }
1626+ ( _, Some ( val) ) => {
1627+ let _ = buf. write_be_i32 ( val. len ( ) as i32 ) ;
1628+ let _ = buf. write ( val. as_slice ( ) ) ;
1629+ }
1630+ }
1631+ count = i+1 ;
1632+ }
1633+
1634+ if row. next ( ) . is_some ( ) || count != self . column_types . len ( ) {
1635+ // FIXME
1636+ fail ! ( )
1637+ }
1638+
1639+ try_pg ! ( conn. write_messages( [
1640+ CopyData {
1641+ data: buf. unwrap( ) . as_slice( )
1642+ } ] ) ) ;
1643+ buf = MemWriter :: new ( ) ;
1644+ }
1645+
1646+ let _ = buf. write_be_i16 ( -1 ) ;
1647+ try_pg ! ( conn. write_messages( [
1648+ CopyData {
1649+ data: buf. unwrap( ) . as_slice( ) ,
1650+ } ,
1651+ CopyDone ,
1652+ Sync ] ) ) ;
1653+
1654+ match try_pg ! ( conn. read_message_( ) ) {
1655+ CommandComplete { .. } => { } ,
1656+ ErrorResponse { fields } => {
1657+ try!( conn. wait_for_ready ( ) ) ;
1658+ return Err ( PgDbError ( PostgresDbError :: new ( fields) ) ) ;
1659+ }
1660+ _ => {
1661+ conn. desynchronized = true ;
1662+ return Err ( PgBadResponse ) ;
1663+ }
1664+ }
1665+
1666+ conn. wait_for_ready ( )
1667+ }
1668+
1669+ /// Consumes the statement, clearing it from the Postgres session.
1670+ ///
1671+ /// Functionally identical to the `Drop` implementation of the
1672+ /// `PostgresCopyInStatement` except that it returns any error to the
1673+ /// caller.
1674+ pub fn finish ( mut self ) -> PostgresResult < ( ) > {
1675+ self . finished = true ;
1676+ self . finish_inner ( )
1677+ }
1678+ }
0 commit comments