forked from rust-postgres/rust-postgres
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconnect.rs
More file actions
93 lines (81 loc) · 2.78 KB
/
Copy pathconnect.rs
File metadata and controls
93 lines (81 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use crate::config::{Host, TargetSessionAttrs};
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket};
use futures::{Stream, FutureExt};
use futures::future;
use pin_utils::pin_mut;
use std::io;
use std::future::Future;
use std::task::Poll;
use std::pin::Pin;
/// Connects to the first host listed in `config` that accepts a connection.
///
/// Hosts are tried in the order they appear in `config.host`. For each host a
/// fresh TLS connector is built from `tls`, keyed by the host name (Unix-socket
/// hosts use an empty name, since postgres does not do TLS over Unix sockets).
/// The host's position in the list is forwarded to `connect_once`.
///
/// # Errors
///
/// - A configuration error if `config.host` is empty.
/// - A configuration error if more than one port is configured and the number of
///   ports differs from the number of hosts.
/// - A TLS error, returned immediately without trying later hosts, if `tls`
///   cannot build a connector for a host.
/// - Otherwise, if every host fails, the error from the last host attempted.
pub async fn connect<T>(
    mut tls: T,
    config: &Config,
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
where
    T: MakeTlsConnect<Socket>,
{
    if config.host.is_empty() {
        return Err(Error::config("host missing".into()));
    }

    // Zero or one port is accepted for any number of hosts; beyond that there
    // must be exactly one port per host.
    if config.port.len() > 1 && config.port.len() != config.host.len() {
        return Err(Error::config("invalid number of ports".into()));
    }

    // Only the most recent failure is kept; it is what the caller sees if no
    // host succeeds.
    let mut error = None;
    for (i, host) in config.host.iter().enumerate() {
        let hostname = match host {
            Host::Tcp(host) => host.as_str(),
            // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
            #[cfg(unix)]
            Host::Unix(_) => "",
        };

        let tls = tls
            .make_tls_connect(hostname)
            .map_err(|e| Error::tls(e.into()))?;

        match connect_once(i, tls, config).await {
            Ok((client, connection)) => return Ok((client, connection)),
            Err(e) => error = Some(e),
        }
    }

    // `config.host` was checked to be non-empty, and every iteration above either
    // returns or stores an error, so `error` is always `Some` here.
    Err(error.expect("at least one host must have been attempted"))
}
/// Makes a single connection attempt to the host at position `idx` in `config`.
///
/// A socket is opened with `connect_socket` and handed to `connect_raw`
/// together with `tls`. When `config.target_session_attrs` is
/// `TargetSessionAttrs::ReadWrite`, the new session is probed with
/// `SHOW transaction_read_only` before it is returned, and a session reporting
/// `on` is rejected.
///
/// # Errors
///
/// Any error from `connect_socket` or `connect_raw` is propagated. For a
/// read-write target this also fails with:
/// - `Error::closed()` if the connection future completes during the probe;
/// - `Error::unexpected_message()` if the probe's stream ends before a row arrives;
/// - a `PermissionDenied` connect error if the server reports read-only mode.
async fn connect_once<T>(
    idx: usize,
    tls: T,
    config: &Config,
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
where
    T: TlsConnect<Socket>,
{
    let socket = connect_socket(idx, config).await?;
    let (mut client, mut connection) = connect_raw(socket, tls, config, Some(idx)).await?;

    // The probe is kept in its own scope so the query stream is dropped before
    // `client` is moved into the return value.
    if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
        let messages = client.simple_query("SHOW transaction_read_only");
        pin_mut!(messages);

        // Skip non-row messages until the first row shows up.
        let first_row = loop {
            // `connection` is polled alongside the query stream (presumably
            // because nothing else drives it yet; it has not been handed to the
            // caller). If it finishes, the probe is abandoned as a closed
            // connection; an error it yields is propagated by `?`.
            let message = future::poll_fn(|cx| {
                if connection.poll_unpin(cx)?.is_ready() {
                    return Poll::Ready(Some(Err(Error::closed())));
                }
                messages.as_mut().poll_next(cx)
            })
            .await
            .transpose()?;

            match message {
                Some(SimpleQueryMessage::Row(row)) => break row,
                Some(_) => continue,
                None => return Err(Error::unexpected_message()),
            }
        };

        if first_row.try_get(0)? == Some("on") {
            return Err(Error::connect(io::Error::new(
                io::ErrorKind::PermissionDenied,
                "database does not allow writes",
            )));
        }
    }

    Ok((client, connection))
}