Skip to content

Commit

Permalink
Merge pull request #9 from readysettech/REA-3121
Browse files Browse the repository at this point in the history
Transactions aren't executed with "psql -c"
  • Loading branch information
tbjuhasz authored Aug 31, 2023
2 parents 13206f6 + a1eaceb commit 9732087
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tokio-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::generic_result::GenericResult;
pub use crate::portal::Portal;
pub use crate::query::{RowStream, ResultStream};
pub use crate::query::{ResultStream, RowStream};
pub use crate::row::{Row, SimpleQueryRow};
pub use crate::simple_query::SimpleQueryStream;
#[cfg(feature = "runtime")]
Expand Down
6 changes: 5 additions & 1 deletion tokio-postgres/src/simple_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ impl Stream for SimpleQueryStream {
.parse()
.unwrap_or(0);
let fields = if *this.include_fields_in_complete {
this.fields.clone()
if body.tag().expect("Failed to get tag").starts_with("SELECT") {
this.fields.clone()
} else {
None
}
} else {
// Reset bool for next grouping
*this.include_fields_in_complete = true;
Expand Down
81 changes: 81 additions & 0 deletions tokio-postgres/tests/test/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,88 @@ async fn custom_range() {
assert_eq!("floatrange", ty.name());
assert_eq!(&Kind::Range(Type::FLOAT8), ty.kind());
}
/// This test check to make sure that empty responses for select queries include the header but not
/// for other query types.
#[tokio::test]
async fn simple_query_select_transaction() {
let client = connect("user=postgres").await;

let _ = client.simple_query("DROP TABLE sbtest1").await.unwrap();
let _ = client.simple_query("DROP TABLE sbtest2").await.unwrap();
let _ = client
.simple_query("CREATE TABLE sbtest1 (id INTEGER, k INTEGER);")
.await
.unwrap();
let _ = client
.simple_query("CREATE TABLE sbtest2 (id INTEGER, k INTEGER);")
.await
.unwrap();

let messages = client
.simple_query(
"INSERT INTO sbtest1 VALUES (1, 2);
INSERT INTO sbtest2 VALUES (1, 2);
SELECT * FROM sbtest1 ORDER BY id;
SELECT k FROM sbtest1 WHERE id = 999;
BEGIN;
UPDATE sbtest1 SET k=id;
UPDATE sbtest2 SET k=id;
END;",
)
.await
.unwrap();

match messages[0] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { rows: 1, .. }) => {}
_ => panic!("unexpected message or too many rows"),
}
match messages[1] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { rows: 1, .. }) => {}
_ => panic!("unexpected message or too many rows"),
}
match &messages[2] {
SimpleQueryMessage::Row(row) => {
assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id"));
assert_eq!(row.columns().get(1).map(|c| c.name()), Some("k"));
assert_eq!(row.get(0), Some("1"));
assert_eq!(row.get(1), Some("2"));
}
_ => panic!("unexpected message"),
}
match &messages[3] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {}
_ => panic!("unexpected message or fields are not empty "),
}
match &messages[4] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields, .. }) => {
if let Some(field_vec) = &fields {
assert_eq!((&**field_vec).len(), 1);
assert_eq!("k", (&**field_vec)[0].name());
} else {
panic!("No data found");
}
}
_ => panic!("unexpected message"),
}
match &messages[5] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {}
_ => panic!("unexpected message or fields are not empty"),
}
match &messages[6] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {}
_ => panic!("unexpected message or fields are not empty"),
}
match &messages[7] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {}
_ => panic!("unexpected message or fields are not empty"),
}
match &messages[8] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {}
_ => panic!("unexpected message or fields are not empty"),
}

assert_eq!(messages.len(), 9);
}
#[tokio::test]
async fn simple_query() {
let client = connect("user=postgres").await;
Expand Down

0 comments on commit 9732087

Please sign in to comment.