diff --git a/spec/pg/connection_spec.cr b/spec/pg/connection_spec.cr index cf089d05..7bb45052 100644 --- a/spec/pg/connection_spec.cr +++ b/spec/pg/connection_spec.cr @@ -141,3 +141,33 @@ describe PG, "#clear_time_zone_cache" do end end end + +describe PG, "COPY" do + it "properly handles partial reads and consumes data on early close" do + with_connection do |db| + io = db.exec_copy "COPY (VALUES (1), (333)) TO STDOUT" + io.read_char.should eq '1' + io.read_char.should eq '\n' + io.read_char.should eq '3' + io.read_char.should eq '3' + io.close + db.scalar("select 1").should eq(1) + end + end + + if "survives a COPY FROM STDIN and COPY TO STDOUT round-trip" + with_connection do |db| + data = "123\tdata\n\\N\t\\N\n" + db.exec("CREATE TEMPORARY TABLE IF NOT EXISTS copy_test (a int, b text)") + + wr = db.exec_copy "COPY copy_test FROM STDIN" + wr << data + wr.close + + rd = db.exec_copy "COPY copy_test TO STDOUT" + rd.gets_to_end.should eq data + + db.exec("DROP TABLE copy_test") + end + end +end diff --git a/src/pg/connection.cr b/src/pg/connection.cr index e6e15015..1f22202a 100644 --- a/src/pg/connection.cr +++ b/src/pg/connection.cr @@ -37,6 +37,22 @@ module PG nil end + # Execute a "COPY" query and return an IO object to read from or write to, + # depending on the query. + # + # ``` + # data = conn.exec_copy("COPY table TO STDOUT").gets_to_end + # ``` + # + # ``` + # writer = conn.exec_copy "COPY table FROM STDIN") + # writer << data + # writer.close + # ``` + def exec_copy(query : String) : CopyResult + CopyResult.new connection, query + end + # Set the callback block for notices and errors. def on_notice(&on_notice_proc : PQ::Notice ->) @connection.notice_handler = on_notice_proc diff --git a/src/pg/copy_result.cr b/src/pg/copy_result.cr new file mode 100644 index 00000000..eb6fe153 --- /dev/null +++ b/src/pg/copy_result.cr @@ -0,0 +1,90 @@ +# IO object obtained through PG::Connection.exec_copy. +class PG::CopyResult < IO + getter? closed : Bool + + def initialize(@connection : PQ::Connection, query : String) + @connection.send_query_message query + response = @connection.expect_frame PQ::Frame::CopyOutResponse | PQ::Frame::CopyInResponse + + @reading = response.is_a? PQ::Frame::CopyOutResponse + @frame_size = 0 + @end = false + @closed = false + end + + private def read_final(done) + return if @end + @end = true + + unless done + @connection.skip_bytes @frame_size if @frame_size > 0 + + while @connection.read_next_copy_start + size = @connection.read_i32 - 4 + @connection.skip_bytes size + end + end + + @connection.expect_frame PQ::Frame::CommandComplete + @connection.expect_frame PQ::Frame::ReadyForQuery + end + + # Returns the number of remaining bytes in the current row. + # Returns 0 the are no more rows to be read. + # This can be used to allocate the precise amount of memory to read a complete row. + # + # ``` + # size = io.remaining_row_size + # if size != 0 + # row = Bytes.new(size) + # io.read(row) + # # Process the row. + # end + # ``` + def remaining_row_size : Int32 + raise "Can't read from a write-only PG::CopyResult" unless @reading + check_open + + return 0 if @end + + if @frame_size == 0 + if @connection.read_next_copy_start + @frame_size = @connection.read_i32 - 4 + else + read_final true + return 0 + end + end + + @frame_size + end + + def read(slice : Bytes) : Int32 + return 0 if slice.empty? + + remaining = remaining_row_size + return 0 if remaining == 0 + + max_bytes = slice.size > remaining ? remaining : slice.size + bytes = @connection.read_direct(slice[0..max_bytes - 1]) + @frame_size -= bytes + bytes + end + + def write(slice : Bytes) : Nil + raise "Can't write to a read-only PG::CopyResult" if @reading + @connection.send_copy_data_message slice + end + + def close : Nil + return if @closed + if @reading + read_final false + else + @connection.send_copy_done_message + @connection.expect_frame PQ::Frame::CommandComplete + @connection.expect_frame PQ::Frame::ReadyForQuery + end + @closed = true + end +end diff --git a/src/pq/connection.cr b/src/pq/connection.cr index 47a45b56..8d6ad888 100644 --- a/src/pq/connection.cr +++ b/src/pq/connection.cr @@ -136,6 +136,10 @@ module PQ data end + def read_direct(slice) + soc.read(slice) + end + def skip_bytes(count) soc.skip(count) end @@ -442,6 +446,24 @@ module PQ expect_frame Frame::CommandComplete, type end + def read_next_copy_start + type = soc.read_char + + while type == 'N' + # NoticeResponse + frame = read_one_frame('N') + handle_async_frames(frame) + type = soc.read_char + end + + if type == 'd' + true + else + expect_frame Frame::CopyDone, type + false + end + end + def expect_frame(frame_class, type = nil) f = type ? read(type) : read raise "Expected #{frame_class} but got #{f}" unless frame_class === f @@ -528,5 +550,17 @@ module PQ write_chr 'X' write_i32 4 end + + def send_copy_data_message(slice) + write_chr 'd' + write_i32 4 + slice.size + soc.write slice + end + + def send_copy_done_message + write_chr 'c' + write_i32 4 + soc.flush + end end end diff --git a/src/pq/frame.cr b/src/pq/frame.cr index 52e0cc52..4b02b899 100644 --- a/src/pq/frame.cr +++ b/src/pq/frame.cr @@ -21,6 +21,9 @@ module PQ when 'S' then ParameterStatus when 'K' then BackendKeyData when 'R' then Authentication + when 'c' then CopyDone + when 'G' then CopyInResponse + when 'H' then CopyOutResponse else nil end k ? k.new(bytes) : Unknown.new(type, bytes) @@ -245,5 +248,14 @@ module PQ struct EmptyQueryResponse < Frame end + + struct CopyDone < Frame + end + + struct CopyInResponse < Frame + end + + struct CopyOutResponse < Frame + end end end