Skip to content

Commit 6ae93a0

Browse files
committed
Add a convenience connect free function
1 parent af41875 commit 6ae93a0

5 files changed

Lines changed: 28 additions & 19 deletions

File tree

tokio-postgres-openssl/src/test.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,7 @@ fn runtime() {
7777
builder.set_ca_file("../test/server.crt").unwrap();
7878
let connector = MakeTlsConnector::new(builder.build());
7979

80-
let connect = "host=localhost port=5433 user=postgres"
81-
.parse::<tokio_postgres::Builder>()
82-
.unwrap()
83-
.connect(RequireTls(connector));
80+
let connect = tokio_postgres::connect("host=localhost port=5433 user=postgres", RequireTls(connector));
8481
let (mut client, connection) = runtime.block_on(connect).unwrap();
8582
let connection = connection.map_err(|e| panic!("{}", e));
8683
runtime.spawn(connection);

tokio-postgres/src/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ impl Builder {
128128
where
129129
T: MakeTlsMode<Socket>,
130130
{
131-
Connect(ConnectFuture::new(make_tls_mode, self.clone()))
131+
Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone())))
132132
}
133133
}
134134

tokio-postgres/src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ fn next_portal() -> String {
3636
format!("p{}", ID.fetch_add(1, Ordering::SeqCst))
3737
}
3838

39+
#[cfg(feature = "runtime")]
40+
pub fn connect<T>(config: &str, tls_mode: T) -> Connect<T>
41+
where
42+
T: MakeTlsMode<Socket>,
43+
{
44+
Connect(proto::ConnectFuture::new(tls_mode, config.parse()))
45+
}
46+
3947
pub fn cancel_query<S, T>(stream: S, tls_mode: T, cancel_data: CancelData) -> CancelQuery<S, T>
4048
where
4149
S: AsyncRead + AsyncWrite,

tokio-postgres/src/proto/connect.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ where
1010
T: MakeTlsMode<Socket>,
1111
{
1212
#[state_machine_future(start, transitions(MakingTlsMode))]
13-
Start { make_tls_mode: T, config: Builder },
13+
Start {
14+
make_tls_mode: T,
15+
config: Result<Builder, Error>,
16+
},
1417
#[state_machine_future(transitions(Connecting))]
1518
MakingTlsMode {
1619
future: T::Future,
@@ -38,15 +41,17 @@ where
3841
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
3942
let mut state = state.take();
4043

41-
if state.config.0.host.is_empty() {
44+
let config = state.config?;
45+
46+
if config.0.host.is_empty() {
4247
return Err(Error::missing_host());
4348
}
4449

45-
if state.config.0.port.len() > 1 && state.config.0.port.len() != state.config.0.host.len() {
50+
if config.0.port.len() > 1 && config.0.port.len() != config.0.host.len() {
4651
return Err(Error::invalid_port_count());
4752
}
4853

49-
let hostname = match &state.config.0.host[0] {
54+
let hostname = match &config.0.host[0] {
5055
Host::Tcp(host) => &**host,
5156
// postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
5257
#[cfg(unix)]
@@ -58,7 +63,7 @@ where
5863
future,
5964
idx: 0,
6065
make_tls_mode: state.make_tls_mode,
61-
config: state.config,
66+
config,
6267
})
6368
}
6469

@@ -113,7 +118,7 @@ impl<T> ConnectFuture<T>
113118
where
114119
T: MakeTlsMode<Socket>,
115120
{
116-
pub fn new(make_tls_mode: T, config: Builder) -> ConnectFuture<T> {
121+
pub fn new(make_tls_mode: T, config: Result<Builder, Error>) -> ConnectFuture<T> {
117122
Connect::start(make_tls_mode, config)
118123
}
119124
}

tokio-postgres/tests/test/runtime.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
use futures::Future;
22
use tokio::runtime::current_thread::Runtime;
3-
use tokio_postgres::{Client, Connection, Error, NoTls, Socket};
4-
5-
fn connect(s: &str) -> impl Future<Item = (Client, Connection<Socket>), Error = Error> {
6-
s.parse::<tokio_postgres::Builder>().unwrap().connect(NoTls)
7-
}
3+
use tokio_postgres::NoTls;
84

95
fn smoke_test(s: &str) {
106
let mut runtime = Runtime::new().unwrap();
11-
let connect = connect(s);
7+
let connect = tokio_postgres::connect(s, NoTls);
128
let (mut client, connection) = runtime.block_on(connect).unwrap();
139
let connection = connection.map_err(|e| panic!("{}", e));
1410
runtime.spawn(connection);
@@ -41,9 +37,12 @@ fn multiple_hosts_multiple_ports() {
4137
#[test]
4238
fn wrong_port_count() {
4339
let mut runtime = Runtime::new().unwrap();
44-
let f = connect("host=localhost port=5433,5433 user=postgres");
40+
let f = tokio_postgres::connect("host=localhost port=5433,5433 user=postgres", NoTls);
4541
runtime.block_on(f).err().unwrap();
4642

47-
let f = connect("host=localhost,localhost,localhost port=5433,5433 user=postgres");
43+
let f = tokio_postgres::connect(
44+
"host=localhost,localhost,localhost port=5433,5433 user=postgres",
45+
NoTls,
46+
);
4847
runtime.block_on(f).err().unwrap();
4948
}

0 commit comments

Comments
 (0)