From 8d95238dccecc3474a0889ecfa316898d9bb389b Mon Sep 17 00:00:00 2001 From: Yorhel Date: Wed, 20 Dec 2023 19:04:06 +0100 Subject: [PATCH 1/2] Add support for COPY out data transfers --- spec/pg/connection_spec.cr | 38 ++++++++++++++++++++++++++++ src/pg/connection.cr | 12 +++++++++ src/pg/copy_out.cr | 52 ++++++++++++++++++++++++++++++++++++++ src/pq/connection.cr | 22 ++++++++++++++++ src/pq/frame.cr | 8 ++++++ 5 files changed, 132 insertions(+) create mode 100644 src/pg/copy_out.cr diff --git a/spec/pg/connection_spec.cr b/spec/pg/connection_spec.cr index cf089d05..d0e27769 100644 --- a/spec/pg/connection_spec.cr +++ b/spec/pg/connection_spec.cr @@ -141,3 +141,41 @@ describe PG, "#clear_time_zone_cache" do end end end + +describe PG, "COPY out" do + it "supports COPY TO STDOUT data transfer" do + with_connection do |db| + io = db.copy_out "COPY (SELECT 'text', NULL, 1) TO STDOUT" + io.gets_to_end.should eq "text\t\\N\t1\n" + io.close + db.scalar("select 1").should eq(1) + end + end + + it "propely consumes data on early close" do + with_connection do |db| + io = db.copy_out "COPY (SELECT * FROM generate_series(1, 100) x) TO STDOUT" + io.gets.should eq "1" + io.gets.should eq "2" + io.gets.should eq "3" + io.close + db.scalar("select 1").should eq(1) + end + end + + it "properly handles partial reads" do + with_connection do |db| + io = db.copy_out "COPY (VALUES (1), (333)) TO STDOUT" + io.read_char.should eq '1' + io.read_char.should eq '\n' + io.read_char.should eq '3' + io.read_char.should eq '3' + io.read_char.should eq '3' + io.read_char.should eq '\n' + io.read_char.should eq nil + io.read_char.should eq nil + io.close + db.scalar("select 1").should eq(1) + end + end +end diff --git a/src/pg/connection.cr b/src/pg/connection.cr index e6e15015..476633ba 100644 --- a/src/pg/connection.cr +++ b/src/pg/connection.cr @@ -37,6 +37,18 @@ module PG nil end + # Execute a "COPY .. TO STDOUT" query and return an IO object to read from. + # The IO *must* be closed before using the connection again. + # + # ``` + # io = conn.copy_out "COPY table TO STDOUT" + # data = io.gets_to_end + # io.close + # ``` + def copy_out(query : String) : CopyOut + CopyOut.new connection, query + end + # Set the callback block for notices and errors. def on_notice(&on_notice_proc : PQ::Notice ->) @connection.notice_handler = on_notice_proc diff --git a/src/pg/copy_out.cr b/src/pg/copy_out.cr new file mode 100644 index 00000000..21b36f19 --- /dev/null +++ b/src/pg/copy_out.cr @@ -0,0 +1,52 @@ +class PG::CopyOut < IO + getter? closed : Bool + + def initialize(@connection : PQ::Connection, query : String) + @connection.send_query_message query + @connection.expect_frame PQ::Frame::CopyOutResponse + + @frame_size = 0 # Remaining bytes in the current frame + @end = false + @closed = false + end + + def read(slice : Bytes) : Int32 + check_open + + return 0 if slice.empty? + return 0 if @end + + if @frame_size == 0 + if @connection.read_next_copy_start + @frame_size = @connection.read_i32 - 4 + else + @end = true + return 0 + end + end + + max_bytes = slice.size > @frame_size ? @frame_size : slice.size + bytes = @connection.read_direct(slice[0..max_bytes - 1]) + @frame_size -= bytes + bytes + end + + def write(slice : Bytes) : NoReturn + raise "Can't write to PG::CopyOut" + end + + def close : Nil + return if @closed + @closed = true + + unless @end + while @connection.read_next_copy_start + size = @connection.read_i32 - 4 + @connection.skip_bytes size + end + end + + @connection.expect_frame PQ::Frame::CommandComplete + @connection.expect_frame PQ::Frame::ReadyForQuery + end +end diff --git a/src/pq/connection.cr b/src/pq/connection.cr index 47a45b56..0e666772 100644 --- a/src/pq/connection.cr +++ b/src/pq/connection.cr @@ -136,6 +136,10 @@ module PQ data end + def read_direct(slice) + soc.read(slice) + end + def skip_bytes(count) soc.skip(count) end @@ -442,6 +446,24 @@ module PQ expect_frame Frame::CommandComplete, type end + def read_next_copy_start + type = soc.read_char + + while type == 'N' + # NoticeResponse + frame = read_one_frame('N') + handle_async_frames(frame) + type = soc.read_char + end + + if type == 'd' + true + else + expect_frame Frame::CopyDone, type + false + end + end + def expect_frame(frame_class, type = nil) f = type ? read(type) : read raise "Expected #{frame_class} but got #{f}" unless frame_class === f diff --git a/src/pq/frame.cr b/src/pq/frame.cr index 52e0cc52..7df9173c 100644 --- a/src/pq/frame.cr +++ b/src/pq/frame.cr @@ -21,6 +21,8 @@ module PQ when 'S' then ParameterStatus when 'K' then BackendKeyData when 'R' then Authentication + when 'c' then CopyDone + when 'H' then CopyOutResponse else nil end k ? k.new(bytes) : Unknown.new(type, bytes) @@ -245,5 +247,11 @@ module PQ struct EmptyQueryResponse < Frame end + + struct CopyDone < Frame + end + + struct CopyOutResponse < Frame + end end end From ed26146a226f31acf77c1949dea6d3e7c6b5674e Mon Sep 17 00:00:00 2001 From: Yorhel Date: Thu, 21 Dec 2023 07:32:07 +0100 Subject: [PATCH 2/2] Also add support for COPY in + fix close after partial read + remaining_row_size --- spec/pg/connection_spec.cr | 44 ++++++++----------- src/pg/connection.cr | 18 +++++--- src/pg/copy_out.cr | 52 ---------------------- src/pg/copy_result.cr | 90 ++++++++++++++++++++++++++++++++++++++ src/pq/connection.cr | 12 +++++ src/pq/frame.cr | 4 ++ 6 files changed, 135 insertions(+), 85 deletions(-) delete mode 100644 src/pg/copy_out.cr create mode 100644 src/pg/copy_result.cr diff --git a/spec/pg/connection_spec.cr b/spec/pg/connection_spec.cr index d0e27769..7bb45052 100644 --- a/spec/pg/connection_spec.cr +++ b/spec/pg/connection_spec.cr @@ -142,40 +142,32 @@ describe PG, "#clear_time_zone_cache" do end end -describe PG, "COPY out" do - it "supports COPY TO STDOUT data transfer" do +describe PG, "COPY" do + it "properly handles partial reads and consumes data on early close" do with_connection do |db| - io = db.copy_out "COPY (SELECT 'text', NULL, 1) TO STDOUT" - io.gets_to_end.should eq "text\t\\N\t1\n" + io = db.exec_copy "COPY (VALUES (1), (333)) TO STDOUT" + io.read_char.should eq '1' + io.read_char.should eq '\n' + io.read_char.should eq '3' + io.read_char.should eq '3' io.close db.scalar("select 1").should eq(1) end end - it "propely consumes data on early close" do + if "survives a COPY FROM STDIN and COPY TO STDOUT round-trip" with_connection do |db| - io = db.copy_out "COPY (SELECT * FROM generate_series(1, 100) x) TO STDOUT" - io.gets.should eq "1" - io.gets.should eq "2" - io.gets.should eq "3" - io.close - db.scalar("select 1").should eq(1) - end - end + data = "123\tdata\n\\N\t\\N\n" + db.exec("CREATE TEMPORARY TABLE IF NOT EXISTS copy_test (a int, b text)") - it "properly handles partial reads" do - with_connection do |db| - io = db.copy_out "COPY (VALUES (1), (333)) TO STDOUT" - io.read_char.should eq '1' - io.read_char.should eq '\n' - io.read_char.should eq '3' - io.read_char.should eq '3' - io.read_char.should eq '3' - io.read_char.should eq '\n' - io.read_char.should eq nil - io.read_char.should eq nil - io.close - db.scalar("select 1").should eq(1) + wr = db.exec_copy "COPY copy_test FROM STDIN" + wr << data + wr.close + + rd = db.exec_copy "COPY copy_test TO STDOUT" + rd.gets_to_end.should eq data + + db.exec("DROP TABLE copy_test") end end end diff --git a/src/pg/connection.cr b/src/pg/connection.cr index 476633ba..1f22202a 100644 --- a/src/pg/connection.cr +++ b/src/pg/connection.cr @@ -37,16 +37,20 @@ module PG nil end - # Execute a "COPY .. TO STDOUT" query and return an IO object to read from. - # The IO *must* be closed before using the connection again. + # Execute a "COPY" query and return an IO object to read from or write to, + # depending on the query. # # ``` - # io = conn.copy_out "COPY table TO STDOUT" - # data = io.gets_to_end - # io.close + # data = conn.exec_copy("COPY table TO STDOUT").gets_to_end # ``` - def copy_out(query : String) : CopyOut - CopyOut.new connection, query + # + # ``` + # writer = conn.exec_copy "COPY table FROM STDIN") + # writer << data + # writer.close + # ``` + def exec_copy(query : String) : CopyResult + CopyResult.new connection, query end # Set the callback block for notices and errors. diff --git a/src/pg/copy_out.cr b/src/pg/copy_out.cr deleted file mode 100644 index 21b36f19..00000000 --- a/src/pg/copy_out.cr +++ /dev/null @@ -1,52 +0,0 @@ -class PG::CopyOut < IO - getter? closed : Bool - - def initialize(@connection : PQ::Connection, query : String) - @connection.send_query_message query - @connection.expect_frame PQ::Frame::CopyOutResponse - - @frame_size = 0 # Remaining bytes in the current frame - @end = false - @closed = false - end - - def read(slice : Bytes) : Int32 - check_open - - return 0 if slice.empty? - return 0 if @end - - if @frame_size == 0 - if @connection.read_next_copy_start - @frame_size = @connection.read_i32 - 4 - else - @end = true - return 0 - end - end - - max_bytes = slice.size > @frame_size ? @frame_size : slice.size - bytes = @connection.read_direct(slice[0..max_bytes - 1]) - @frame_size -= bytes - bytes - end - - def write(slice : Bytes) : NoReturn - raise "Can't write to PG::CopyOut" - end - - def close : Nil - return if @closed - @closed = true - - unless @end - while @connection.read_next_copy_start - size = @connection.read_i32 - 4 - @connection.skip_bytes size - end - end - - @connection.expect_frame PQ::Frame::CommandComplete - @connection.expect_frame PQ::Frame::ReadyForQuery - end -end diff --git a/src/pg/copy_result.cr b/src/pg/copy_result.cr new file mode 100644 index 00000000..eb6fe153 --- /dev/null +++ b/src/pg/copy_result.cr @@ -0,0 +1,90 @@ +# IO object obtained through PG::Connection.exec_copy. +class PG::CopyResult < IO + getter? closed : Bool + + def initialize(@connection : PQ::Connection, query : String) + @connection.send_query_message query + response = @connection.expect_frame PQ::Frame::CopyOutResponse | PQ::Frame::CopyInResponse + + @reading = response.is_a? PQ::Frame::CopyOutResponse + @frame_size = 0 + @end = false + @closed = false + end + + private def read_final(done) + return if @end + @end = true + + unless done + @connection.skip_bytes @frame_size if @frame_size > 0 + + while @connection.read_next_copy_start + size = @connection.read_i32 - 4 + @connection.skip_bytes size + end + end + + @connection.expect_frame PQ::Frame::CommandComplete + @connection.expect_frame PQ::Frame::ReadyForQuery + end + + # Returns the number of remaining bytes in the current row. + # Returns 0 the are no more rows to be read. + # This can be used to allocate the precise amount of memory to read a complete row. + # + # ``` + # size = io.remaining_row_size + # if size != 0 + # row = Bytes.new(size) + # io.read(row) + # # Process the row. + # end + # ``` + def remaining_row_size : Int32 + raise "Can't read from a write-only PG::CopyResult" unless @reading + check_open + + return 0 if @end + + if @frame_size == 0 + if @connection.read_next_copy_start + @frame_size = @connection.read_i32 - 4 + else + read_final true + return 0 + end + end + + @frame_size + end + + def read(slice : Bytes) : Int32 + return 0 if slice.empty? + + remaining = remaining_row_size + return 0 if remaining == 0 + + max_bytes = slice.size > remaining ? remaining : slice.size + bytes = @connection.read_direct(slice[0..max_bytes - 1]) + @frame_size -= bytes + bytes + end + + def write(slice : Bytes) : Nil + raise "Can't write to a read-only PG::CopyResult" if @reading + @connection.send_copy_data_message slice + end + + def close : Nil + return if @closed + if @reading + read_final false + else + @connection.send_copy_done_message + @connection.expect_frame PQ::Frame::CommandComplete + @connection.expect_frame PQ::Frame::ReadyForQuery + end + @closed = true + end +end diff --git a/src/pq/connection.cr b/src/pq/connection.cr index 0e666772..8d6ad888 100644 --- a/src/pq/connection.cr +++ b/src/pq/connection.cr @@ -550,5 +550,17 @@ module PQ write_chr 'X' write_i32 4 end + + def send_copy_data_message(slice) + write_chr 'd' + write_i32 4 + slice.size + soc.write slice + end + + def send_copy_done_message + write_chr 'c' + write_i32 4 + soc.flush + end end end diff --git a/src/pq/frame.cr b/src/pq/frame.cr index 7df9173c..4b02b899 100644 --- a/src/pq/frame.cr +++ b/src/pq/frame.cr @@ -22,6 +22,7 @@ module PQ when 'K' then BackendKeyData when 'R' then Authentication when 'c' then CopyDone + when 'G' then CopyInResponse when 'H' then CopyOutResponse else nil end @@ -251,6 +252,9 @@ module PQ struct CopyDone < Frame end + struct CopyInResponse < Frame + end + struct CopyOutResponse < Frame end end