#[desc="A native PostgreSQL driver"]; #[license="MIT"]; extern mod extra; use extra::container::Deque; use extra::digest::Digest; use extra::ringbuf::RingBuf; use extra::md5::Md5; use extra::url::{UserInfo, Url}; use std::cell::Cell; use std::hashmap::HashMap; use std::rt::io::{Writer, io_error, Decorator}; use std::rt::io::buffered::BufferedStream; use std::rt::io::net; use std::rt::io::net::ip::{Port, SocketAddr}; use std::rt::io::net::tcp::TcpStream; use std::task; use std::util; use error::hack::PostgresSqlState; use message::{BackendMessage, AuthenticationOk, AuthenticationKerberosV5, AuthenticationCleartextPassword, AuthenticationMD5Password, AuthenticationSCMCredential, AuthenticationGSS, AuthenticationSSPI, BackendKeyData, BindComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, NoData, NoticeResponse, ParameterDescription, ParameterStatus, ParseComplete, PortalSuspended, ReadyForQuery, RowDescription}; use message::{FrontendMessage, Bind, Close, Describe, Execute, Parse, PasswordMessage, Query, StartupMessage, Sync, Terminate}; use message::{RowDescriptionEntry, WriteMessage, ReadMessage}; use types::{PostgresType, ToSql, FromSql}; pub mod error; pub mod pool; mod message; pub mod types; pub trait PostgresNoticeHandler { fn handle(&mut self, notice: PostgresDbError); } pub struct DefaultNoticeHandler; impl PostgresNoticeHandler for DefaultNoticeHandler { fn handle(&mut self, notice: PostgresDbError) { info2!("{}: {}", notice.severity, notice.message); } } #[deriving(ToStr)] pub enum PostgresConnectError { InvalidUrl, MissingUser, DnsError, SocketError, DbError(PostgresDbError), MissingPassword, UnsupportedAuthentication } #[deriving(ToStr)] pub enum PostgresErrorPosition { Position(uint), InternalPosition { position: uint, query: ~str } } #[deriving(ToStr)] pub struct PostgresDbError { // This could almost be an enum, except the values can be localized :( severity: ~str, code: PostgresSqlState, message: ~str, detail: Option<~str>, hint: Option<~str>, position: Option, where: Option<~str>, file: ~str, line: uint, routine: ~str } impl PostgresDbError { fn new(fields: ~[(u8, ~str)]) -> PostgresDbError { // move_rev_iter is more efficient than move_iter let mut map: HashMap = fields.move_rev_iter().collect(); PostgresDbError { severity: map.pop(&('S' as u8)).unwrap(), code: FromStr::from_str(map.pop(&('C' as u8)).unwrap()).unwrap(), message: map.pop(&('M' as u8)).unwrap(), detail: map.pop(&('D' as u8)), hint: map.pop(&('H' as u8)), position: match map.pop(&('P' as u8)) { Some(pos) => Some(Position(FromStr::from_str(pos).unwrap())), None => match map.pop(&('p' as u8)) { Some(pos) => Some(InternalPosition { position: FromStr::from_str(pos).unwrap(), query: map.pop(&('q' as u8)).unwrap() }), None => None } }, where: map.pop(&('W' as u8)), file: map.pop(&('F' as u8)).unwrap(), line: FromStr::from_str(map.pop(&('L' as u8)).unwrap()).unwrap(), routine: map.pop(&('R' as u8)).unwrap() } } fn pretty_error(&self, query: &str) -> ~str { match self.position { Some(Position(pos)) => format!("{}: {} at position {} in\n{}", self.severity, self.message, pos, query), Some(InternalPosition { position, query: ref inner_query }) => format!("{}: {} at position {} in\n{} called from\n{}", self.severity, self.message, position, *inner_query, query), None => format!("{}: {} in\n{}", self.severity, self.message, query) } } } struct InnerPostgresConnection { stream: BufferedStream, next_stmt_id: int, notice_handler: ~PostgresNoticeHandler } impl Drop for InnerPostgresConnection { fn drop(&mut self) { do io_error::cond.trap(|_| {}).inside { self.write_messages([&Terminate]); } } } impl InnerPostgresConnection { fn try_connect(url: &str) -> Result { let Url { host, port, user, path, query: args, _ }: Url = match FromStr::from_str(url) { Some(url) => url, None => return Err(InvalidUrl) }; let user = match user { Some(user) => user, None => return Err(MissingUser) }; let mut args = args; let port = match port { Some(port) => FromStr::from_str(port).unwrap(), None => 5432 }; let stream = match InnerPostgresConnection::open_socket(host, port) { Ok(stream) => stream, Err(err) => return Err(err) }; let mut conn = InnerPostgresConnection { stream: BufferedStream::new(stream), next_stmt_id: 0, notice_handler: ~DefaultNoticeHandler as ~PostgresNoticeHandler }; args.push((~"client_encoding", ~"UTF8")); // Postgres uses the value of TimeZone as the time zone for TIMESTAMP // WITH TIME ZONE values. Timespec converts to GMT internally. args.push((~"TimeZone", ~"GMT")); // We have to clone here since we need the user again for auth args.push((~"user", user.user.clone())); if !path.is_empty() { args.push((~"database", path)); } conn.write_messages([&StartupMessage { version: message::PROTOCOL_VERSION, parameters: args.as_slice() }]); match conn.handle_auth(user) { Some(err) => return Err(err), None => () } loop { match conn.read_message() { BackendKeyData {_} => (), ReadyForQuery {_} => break, _ => fail!() } } Ok(conn) } fn open_socket(host: &str, port: Port) -> Result { let addrs = do io_error::cond.trap(|_| {}).inside { net::get_host_addresses(host) }; let addrs = match addrs { Some(addrs) => addrs, None => return Err(DnsError) }; for addr in addrs.iter() { let socket = do io_error::cond.trap(|_| {}).inside { TcpStream::connect(SocketAddr { ip: *addr, port: port }) }; match socket { Some(socket) => return Ok(socket), None => () } } Err(SocketError) } fn write_messages(&mut self, messages: &[&FrontendMessage]) { for &message in messages.iter() { self.stream.write_message(message); } self.stream.flush(); } fn read_message(&mut self) -> BackendMessage { loop { match self.stream.read_message() { NoticeResponse { fields } => self.notice_handler.handle(PostgresDbError::new(fields)), ParameterStatus { parameter, value } => debug!("Parameter %s = %s", parameter, value), msg => return msg } } } fn handle_auth(&mut self, user: UserInfo) -> Option { match self.read_message() { AuthenticationOk => return None, AuthenticationCleartextPassword => { let pass = match user.pass { Some(pass) => pass, None => return Some(MissingPassword) }; self.write_messages([&PasswordMessage { password: pass }]); } AuthenticationMD5Password { salt } => { let UserInfo { user, pass } = user; let pass = match pass { Some(pass) => pass, None => return Some(MissingPassword) }; let input = pass + user; let mut md5 = Md5::new(); md5.input_str(input); let output = md5.result_str(); md5.reset(); md5.input_str(output); md5.input(salt); let output = "md5" + md5.result_str(); self.write_messages([&PasswordMessage { password: output.as_slice() }]); } AuthenticationKerberosV5 | AuthenticationSCMCredential | AuthenticationGSS | AuthenticationSSPI => return Some(UnsupportedAuthentication), _ => fail!() } match self.read_message() { AuthenticationOk => None, ErrorResponse { fields } => Some(DbError(PostgresDbError::new(fields))), _ => fail!() } } fn set_notice_handler(&mut self, handler: ~PostgresNoticeHandler) -> ~PostgresNoticeHandler { util::replace(&mut self.notice_handler, handler) } pub fn try_prepare<'a>(&mut self, query: &str, conn: &'a PostgresConnection) -> Result, PostgresDbError> { let stmt_name = format!("statement_{}", self.next_stmt_id); self.next_stmt_id += 1; let types = []; self.write_messages([ &Parse { name: stmt_name, query: query, param_types: types }, &Describe { variant: 'S' as u8, name: stmt_name }, &Sync]); match self.read_message() { ParseComplete => (), ErrorResponse { fields } => { self.wait_for_ready(); return Err(PostgresDbError::new(fields)); } _ => fail!() } let param_types = match self.read_message() { ParameterDescription { types } => types.iter().map(|ty| { PostgresType::from_oid(*ty) }) .collect(), _ => fail!() }; let result_desc = match self.read_message() { RowDescription { descriptions } => { let mut res: ~[ResultDescription] = descriptions .move_rev_iter().map(|desc| { ResultDescription::from_row_description_entry(desc) }).collect(); res.reverse(); res }, NoData => ~[], _ => fail!() }; self.wait_for_ready(); Ok(NormalPostgresStatement { conn: conn, name: stmt_name, param_types: param_types, result_desc: result_desc, next_portal_id: Cell::new(0) }) } fn wait_for_ready(&mut self) { match self.read_message() { ReadyForQuery {_} => (), _ => fail!() } } } // FIXME should be a newtype pub struct PostgresConnection { priv conn: Cell } impl PostgresConnection { pub fn connect(url: &str) -> PostgresConnection { match PostgresConnection::try_connect(url) { Ok(conn) => conn, Err(err) => fail2!("Failed to connect: {}", err.to_str()) } } pub fn try_connect(url: &str) -> Result { do InnerPostgresConnection::try_connect(url).map_move |conn| { PostgresConnection { conn: Cell::new(conn) } } } pub fn set_notice_handler(&self, handler: ~PostgresNoticeHandler) -> ~PostgresNoticeHandler { let mut conn = self.conn.take(); let handler = conn.set_notice_handler(handler); self.conn.put_back(conn); handler } pub fn prepare<'a>(&'a self, query: &str) -> NormalPostgresStatement<'a> { match self.try_prepare(query) { Ok(stmt) => stmt, Err(err) => fail2!("Error preparing statement:\n{}", err.pretty_error(query)) } } pub fn try_prepare<'a>(&'a self, query: &str) -> Result, PostgresDbError> { do self.conn.with_mut_ref |conn| { conn.try_prepare(query, self) } } pub fn in_transaction(&self, blk: &fn(&PostgresTransaction) -> T) -> T { self.quick_query("BEGIN"); blk(&PostgresTransaction { conn: self, commit: Cell::new(true), nested: false }) } pub fn update(&self, query: &str, params: &[&ToSql]) -> uint { match self.try_update(query, params) { Ok(res) => res, Err(err) => fail2!("Error running update:\n{}", err.pretty_error(query)) } } pub fn try_update(&self, query: &str, params: &[&ToSql]) -> Result { do self.try_prepare(query).and_then |stmt| { stmt.try_update(params) } } fn quick_query(&self, query: &str) { do self.conn.with_mut_ref |conn| { conn.write_messages([&Query { query: query }]); loop { match conn.read_message() { ReadyForQuery {_} => break, ErrorResponse { fields } => fail2!("Error: {}", PostgresDbError::new(fields).to_str()), _ => () } } } } fn wait_for_ready(&self) { do self.conn.with_mut_ref |conn| { conn.wait_for_ready() } } fn read_message(&self) -> BackendMessage { do self.conn.with_mut_ref |conn| { conn.read_message() } } fn write_messages(&self, messages: &[&FrontendMessage]) { do self.conn.with_mut_ref |conn| { conn.write_messages(messages) } } } pub struct PostgresTransaction<'self> { priv conn: &'self PostgresConnection, priv commit: Cell, priv nested: bool } #[unsafe_destructor] impl<'self> Drop for PostgresTransaction<'self> { fn drop(&mut self) { do io_error::cond.trap(|_| {}).inside { if task::failing() || !self.commit.take() { if self.nested { self.conn.quick_query("ROLLBACK TO sp"); } else { self.conn.quick_query("ROLLBACK"); } } else { if self.nested { self.conn.quick_query("RELEASE sp"); } else { self.conn.quick_query("COMMIT"); } } } } } impl<'self> PostgresTransaction<'self> { pub fn prepare<'a>(&'a self, query: &str) -> TransactionalPostgresStatement<'a> { TransactionalPostgresStatement(self.conn.prepare(query)) } pub fn try_prepare<'a>(&'a self, query: &str) -> Result, PostgresDbError> { self.conn.try_prepare(query).map_move(TransactionalPostgresStatement) } pub fn update(&self, query: &str, params: &[&ToSql]) -> uint { self.conn.update(query, params) } pub fn try_update(&self, query: &str, params: &[&ToSql]) -> Result { self.conn.try_update(query, params) } pub fn in_transaction(&self, blk: &fn(&PostgresTransaction) -> T) -> T { self.conn.quick_query("SAVEPOINT sp"); blk(&PostgresTransaction { conn: self.conn, commit: Cell::new(true), nested: true }) } pub fn will_commit(&self) -> bool { let commit = self.commit.take(); self.commit.put_back(commit); commit } pub fn set_commit(&self) { self.commit.take(); self.commit.put_back(true); } pub fn set_rollback(&self) { self.commit.take(); self.commit.put_back(false); } } pub trait PostgresStatement { fn param_types<'a>(&'a self) -> &'a [PostgresType]; fn result_descriptions<'a>(&'a self) -> &'a [ResultDescription]; fn update(&self, params: &[&ToSql]) -> uint; fn try_update(&self, params: &[&ToSql]) -> Result; fn query<'a>(&'a self, params: &[&ToSql]) -> PostgresResult<'a>; fn try_query<'a>(&'a self, params: &[&ToSql]) -> Result, PostgresDbError>; fn find_col_named(&self, col: &str) -> Option; } pub struct NormalPostgresStatement<'self> { priv conn: &'self PostgresConnection, priv name: ~str, priv param_types: ~[PostgresType], priv result_desc: ~[ResultDescription], priv next_portal_id: Cell } #[unsafe_destructor] impl<'self> Drop for NormalPostgresStatement<'self> { fn drop(&mut self) { do io_error::cond.trap(|_| {}).inside { self.conn.write_messages([ &Close { variant: 'S' as u8, name: self.name.as_slice() }, &Sync]); loop { match self.conn.read_message() { ReadyForQuery {_} => break, _ => () } } } } } impl<'self> NormalPostgresStatement<'self> { fn execute(&self, portal_name: &str, row_limit: uint, params: &[&ToSql]) -> Option { let mut formats = ~[]; let mut values = ~[]; assert!(self.param_types.len() == params.len(), "Expected %u parameters but found %u", self.param_types.len(), params.len()); for (¶m, &ty) in params.iter().zip(self.param_types.iter()) { let (format, value) = param.to_sql(ty); formats.push(format as i16); values.push(value); }; let result_formats: ~[i16] = self.result_desc.iter().map(|desc| { desc.ty.result_format() as i16 }).collect(); self.conn.write_messages([ &Bind { portal: portal_name, statement: self.name.as_slice(), formats: formats, values: values, result_formats: result_formats }, &Execute { portal: portal_name, max_rows: row_limit as i32 }, &Sync]); match self.conn.read_message() { BindComplete => None, ErrorResponse { fields } => { self.conn.wait_for_ready(); Some(PostgresDbError::new(fields)) } _ => fail!() } } fn lazy_query<'a>(&'a self, row_limit: uint, params: &[&ToSql]) -> PostgresResult<'a> { match self.try_lazy_query(row_limit, params) { Ok(result) => result, Err(err) => fail2!("Error executing query:\n{}", err.to_str()) } } fn try_lazy_query<'a>(&'a self, row_limit: uint, params: &[&ToSql]) -> Result, PostgresDbError> { let id = self.next_portal_id.take(); let portal_name = format!("{}_portal_{}", self.name, id); self.next_portal_id.put_back(id + 1); match self.execute(portal_name, row_limit, params) { Some(err) => { return Err(err); } None => () } let mut result = PostgresResult { stmt: self, name: portal_name, data: RingBuf::new(), row_limit: row_limit, more_rows: true }; result.read_rows(); Ok(result) } } impl<'self> PostgresStatement for NormalPostgresStatement<'self> { fn param_types<'a>(&'a self) -> &'a [PostgresType] { self.param_types.as_slice() } fn result_descriptions<'a>(&'a self) -> &'a [ResultDescription] { self.result_desc.as_slice() } fn update(&self, params: &[&ToSql]) -> uint { match self.try_update(params) { Ok(count) => count, Err(err) => fail2!("Error running update\n{}", err.to_str()) } } fn try_update(&self, params: &[&ToSql]) -> Result { match self.execute("", 0, params) { Some(err) => { return Err(err); } None => () } let num; loop { match self.conn.read_message() { CommandComplete { tag } => { let s = tag.split_iter(' ').last().unwrap(); num = match FromStr::from_str(s) { None => 0, Some(n) => n }; break; } DataRow {_} => (), EmptyQueryResponse => { num = 0; break; } NoticeResponse {_} => (), ErrorResponse { fields } => { self.conn.wait_for_ready(); return Err(PostgresDbError::new(fields)); } _ => fail!() } } self.conn.wait_for_ready(); Ok(num) } fn query<'a>(&'a self, params: &[&ToSql]) -> PostgresResult<'a> { self.lazy_query(0, params) } fn try_query<'a>(&'a self, params: &[&ToSql]) -> Result, PostgresDbError> { self.try_lazy_query(0, params) } fn find_col_named(&self, col: &str) -> Option { do self.result_desc.iter().position |desc| { desc.name.as_slice() == col } } } #[deriving(Eq)] pub struct ResultDescription { name: ~str, ty: PostgresType } impl ResultDescription { fn from_row_description_entry(row: RowDescriptionEntry) -> ResultDescription { let RowDescriptionEntry { name, type_oid, _ } = row; ResultDescription { name: name, ty: PostgresType::from_oid(type_oid) } } } pub struct TransactionalPostgresStatement<'self>(NormalPostgresStatement<'self>); impl<'self> PostgresStatement for TransactionalPostgresStatement<'self> { fn param_types<'a>(&'a self) -> &'a [PostgresType] { (**self).param_types() } fn result_descriptions<'a>(&'a self) -> &'a [ResultDescription] { (**self).result_descriptions() } fn update(&self, params: &[&ToSql]) -> uint { (**self).update(params) } fn try_update(&self, params: &[&ToSql]) -> Result { (**self).try_update(params) } fn query<'a>(&'a self, params: &[&ToSql]) -> PostgresResult<'a> { (**self).query(params) } fn try_query<'a>(&'a self, params: &[&ToSql]) -> Result, PostgresDbError> { (**self).try_query(params) } fn find_col_named(&self, col: &str) -> Option { (**self).find_col_named(col) } } impl<'self> TransactionalPostgresStatement<'self> { pub fn lazy_query<'a>(&'a self, row_limit: uint, params: &[&ToSql]) -> PostgresResult<'a> { (**self).lazy_query(row_limit, params) } pub fn try_lazy_query<'a>(&'a self, row_limit: uint, params: &[&ToSql]) -> Result, PostgresDbError> { (**self).try_lazy_query(row_limit, params) } } pub struct PostgresResult<'self> { priv stmt: &'self NormalPostgresStatement<'self>, priv name: ~str, priv data: RingBuf<~[Option<~[u8]>]>, priv row_limit: uint, priv more_rows: bool } #[unsafe_destructor] impl<'self> Drop for PostgresResult<'self> { fn drop(&mut self) { do io_error::cond.trap(|_| {}).inside { self.stmt.conn.write_messages([ &Close { variant: 'P' as u8, name: self.name.as_slice() }, &Sync]); loop { match self.stmt.conn.read_message() { ReadyForQuery {_} => break, _ => () } } } } } impl<'self> PostgresResult<'self> { fn read_rows(&mut self) { loop { match self.stmt.conn.read_message() { EmptyQueryResponse | CommandComplete {_} => { self.more_rows = false; break; }, PortalSuspended => { self.more_rows = true; break; }, DataRow { row } => self.data.push_back(row), _ => fail!() } } self.stmt.conn.wait_for_ready(); } fn execute(&mut self) { self.stmt.conn.write_messages([ &Execute { portal: self.name, max_rows: self.row_limit as i32 }, &Sync]); self.read_rows(); } } impl<'self> Iterator> for PostgresResult<'self> { fn next(&mut self) -> Option> { if self.data.is_empty() && self.more_rows { self.execute(); } do self.data.pop_front().map_move |row| { PostgresRow { stmt: self.stmt, data: row } } } } pub struct PostgresRow<'self> { priv stmt: &'self NormalPostgresStatement<'self>, priv data: ~[Option<~[u8]>] } impl<'self> Container for PostgresRow<'self> { fn len(&self) -> uint { self.data.len() } } impl<'self, I: RowIndex, T: FromSql> Index for PostgresRow<'self> { #[inline] fn index(&self, idx: &I) -> T { let idx = idx.idx(self.stmt); FromSql::from_sql(self.stmt.result_desc[idx].ty, &self.data[idx]) } } pub trait RowIndex { fn idx(&self, stmt: &NormalPostgresStatement) -> uint; } impl RowIndex for uint { #[inline] fn idx(&self, _stmt: &NormalPostgresStatement) -> uint { *self } } // This is a convenience as the 0 in get[0] resolves to int :( impl RowIndex for int { #[inline] fn idx(&self, _stmt: &NormalPostgresStatement) -> uint { assert!(*self >= 0); *self as uint } } impl<'self> RowIndex for &'self str { #[inline] fn idx(&self, stmt: &NormalPostgresStatement) -> uint { match stmt.find_col_named(*self) { Some(idx) => idx, None => fail2!("No column with name {}", *self) } } }