Skip to content

Commit b566bbd

Browse files
committed
Merge pull request diesel-rs#252 from weiznich/fix_211
Fix diesel-rs#211: Allow other primary keys than id for infer_schema
2 parents 1af8d56 + 7377bb8 commit b566bbd

4 files changed

Lines changed: 100 additions & 16 deletions

File tree

diesel_codegen/src/schema_inference/data_structures.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ impl<ST> Queryable<ST, Pg> for ColumnInformation where
3535
#[cfg(feature = "sqlite")]
3636
impl<ST> Queryable<ST, Sqlite> for ColumnInformation where
3737
Sqlite: HasSqlType<ST>,
38-
(i32, String, String, bool, Option<String>, i32): FromSqlRow<ST, Sqlite>,
38+
(i32, String, String, bool, Option<String>, bool): FromSqlRow<ST, Sqlite>,
3939
{
40-
type Row = (i32, String, String, bool, Option<String>, i32);
40+
type Row = (i32, String, String, bool, Option<String>, bool);
4141

4242
fn build(row: Self::Row) -> Self {
4343
ColumnInformation {

diesel_codegen/src/schema_inference/mod.rs

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,30 @@ fn table_macro_call(
9292
Err(DummyResult::any(sp))
9393
}
9494
Ok(data) => {
95-
let tokens = data.iter().map(|a| column_def_tokens(cx, a, &connection))
96-
.collect::<Vec<_>>();
97-
let table_name = str_to_ident(table_name);
98-
let item = quote_item!(cx, table! {
99-
$table_name {
100-
$tokens
95+
let primary_keys = match get_primary_keys(connection, table_name) {
96+
Ok(keys) => keys,
97+
Err(_) =>{
98+
cx.span_err(sp, "error loading schema");
99+
return Err(DummyResult::any(sp));
101100
}
102-
}).unwrap();
103-
Ok(item)
101+
};
102+
if primary_keys.len() != 1 {
103+
cx.span_err(sp,
104+
&format!("table {} has {} primary keys, only one is currently supported",
105+
table_name, primary_keys.len()));
106+
Err(DummyResult::any(sp))
107+
} else {
108+
let tokens = data.iter().map(|a| column_def_tokens(cx, a, &connection))
109+
.collect::<Vec<_>>();
110+
let table_name = str_to_ident(table_name);
111+
let primary_key = str_to_ident(&primary_keys[0]);
112+
let item = quote_item!(cx, table! {
113+
$table_name ($primary_key) {
114+
$tokens
115+
}
116+
}).unwrap();
117+
Ok(item)
118+
}
104119
}
105120
}
106121
}
@@ -201,3 +216,12 @@ fn determine_column_type(cx: &mut ExtCtxt, attr: &ColumnInformation, conn: &Infe
201216
InferConnection::Pg(_) => pg::determine_column_type(cx, attr),
202217
}
203218
}
219+
220+
fn get_primary_keys(conn: &InferConnection, table_name: &str) -> QueryResult<Vec<String>> {
221+
match *conn {
222+
#[cfg(feature = "sqlite")]
223+
InferConnection::Sqlite(ref c) => sqlite::get_primary_keys(c, table_name),
224+
#[cfg(feature = "postgres")]
225+
InferConnection::Pg(ref c) => pg::get_primary_keys(c, table_name),
226+
}
227+
}

diesel_codegen/src/schema_inference/pg.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ joinable!(pg_attribute -> pg_type (atttypid));
3030
select_column_workaround!(pg_attribute -> pg_type (attrelid, attname, atttypid, attnotnull, attnum, attisdropped));
3131
select_column_workaround!(pg_type -> pg_attribute (oid, typname));
3232

33+
table! {
34+
pg_index (indrelid) {
35+
indrelid -> Oid,
36+
indexrelid -> Oid,
37+
indkey -> Array<SmallInt>,
38+
indisprimary -> Bool,
39+
}
40+
}
41+
3342
table! {
3443
pg_class (oid) {
3544
oid -> Oid,
@@ -74,17 +83,32 @@ pub fn load_table_names(
7483
pub fn get_table_data(conn: &PgConnection, table_name: &str) -> QueryResult<Vec<ColumnInformation>> {
7584
use self::pg_attribute::dsl::*;
7685
use self::pg_type::dsl::{pg_type, typname};
77-
let t_oid = try!(table_oid(conn, table_name));
86+
use self::pg_class::dsl::*;
87+
88+
let table_oid = pg_class.select(oid).filter(relname.eq(table_name)).limit(1);
7889

7990
pg_attribute.inner_join(pg_type)
8091
.select((attname, typname, attnotnull))
81-
.filter(attrelid.eq(t_oid))
92+
.filter(attrelid.eq_any(table_oid))
8293
.filter(attnum.gt(0).and(attisdropped.ne(true)))
8394
.order(attnum)
8495
.load(conn)
8596
}
8697

87-
fn table_oid(conn: &PgConnection, table_name: &str) -> QueryResult<u32> {
98+
99+
pub fn get_primary_keys(conn: &PgConnection, table_name: &str) -> QueryResult<Vec<String>> {
100+
use self::pg_attribute::dsl::*;
101+
use self::pg_index::dsl::{pg_index, indisprimary, indexrelid, indrelid};
88102
use self::pg_class::dsl::*;
89-
pg_class.select(oid).filter(relname.eq(table_name)).first(conn)
103+
104+
let table_oid = pg_class.select(oid).filter(relname.eq(table_name)).limit(1);
105+
106+
let pk_query = pg_index.select(indexrelid)
107+
.filter(indrelid.eq_any(table_oid))
108+
.filter(indisprimary.eq(true));
109+
110+
pg_attribute.select(attname)
111+
.filter(attrelid.eq_any(pk_query))
112+
.order(attnum)
113+
.load(conn)
90114
}

diesel_codegen/src/schema_inference/sqlite.rs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use diesel::*;
2-
use diesel::sqlite::SqliteConnection;
2+
use diesel::sqlite::{SqliteConnection, Sqlite};
3+
use diesel::types::{HasSqlType, FromSqlRow};
34
use syntax::ast;
45
use syntax::codemap::Span;
56
use syntax::ext::base::*;
@@ -14,7 +15,7 @@ table!{
1415
type_name -> VarChar,
1516
notnull -> Bool,
1617
dflt_value -> Nullable<VarChar>,
17-
pk -> Integer,
18+
pk -> Bool,
1819
}
1920
}
2021

@@ -62,3 +63,38 @@ pub fn load_table_names(
6263
.filter(sql::<types::Bool>("type='table' AND name NOT LIKE '\\_\\_%'"));
6364
query.load(connection)
6465
}
66+
67+
struct FullTableInfo {
68+
_cid: i32,
69+
name: String,
70+
_type_name: String,
71+
_not_null: bool,
72+
_dflt_value: Option<String>,
73+
primary_key: bool,
74+
}
75+
76+
impl<ST> Queryable<ST, Sqlite> for FullTableInfo where
77+
Sqlite: HasSqlType<ST>,
78+
(i32, String, String, bool, Option<String>, bool): FromSqlRow<ST, Sqlite>,
79+
{
80+
type Row = (i32, String, String, bool, Option<String>, bool);
81+
82+
fn build(row: Self::Row) -> Self {
83+
FullTableInfo {
84+
_cid: row.0,
85+
name: row.1,
86+
_type_name: row.2,
87+
_not_null: row.3,
88+
_dflt_value: row.4,
89+
primary_key: row.5,
90+
}
91+
}
92+
}
93+
94+
pub fn get_primary_keys(conn: &SqliteConnection, table_name: &str) -> QueryResult<Vec<String>> {
95+
conn.execute_pragma::<pragma_table_info::SqlType, FullTableInfo>(
96+
&format!("PRAGMA TABLE_INFO('{}')", table_name))
97+
.map( |i| i.iter()
98+
.filter_map(|i| if i.primary_key { Some(i.name.clone()) } else { None })
99+
.collect())
100+
}

0 commit comments

Comments
 (0)