Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for COPY data transfers #279

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions spec/pg/connection_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,33 @@ describe PG, "#clear_time_zone_cache" do
end
end
end

describe PG, "COPY" do
it "properly handles partial reads and consumes data on early close" do
with_connection do |db|
io = db.exec_copy "COPY (VALUES (1), (333)) TO STDOUT"
io.read_char.should eq '1'
io.read_char.should eq '\n'
io.read_char.should eq '3'
io.read_char.should eq '3'
io.close
db.scalar("select 1").should eq(1)
end
end

if "survives a COPY FROM STDIN and COPY TO STDOUT round-trip"
with_connection do |db|
data = "123\tdata\n\\N\t\\N\n"
db.exec("CREATE TEMPORARY TABLE IF NOT EXISTS copy_test (a int, b text)")

wr = db.exec_copy "COPY copy_test FROM STDIN"
wr << data
wr.close

rd = db.exec_copy "COPY copy_test TO STDOUT"
rd.gets_to_end.should eq data

db.exec("DROP TABLE copy_test")
end
end
end
16 changes: 16 additions & 0 deletions src/pg/connection.cr
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,22 @@ module PG
nil
end

# Execute a "COPY" query and return an IO object to read from or write to,
# depending on the query.
#
# ```
# data = conn.exec_copy("COPY table TO STDOUT").gets_to_end
# ```
#
# ```
# writer = conn.exec_copy "COPY table FROM STDIN")
# writer << data
# writer.close
# ```
def exec_copy(query : String) : CopyResult
CopyResult.new connection, query
end

# Set the callback block for notices and errors.
def on_notice(&on_notice_proc : PQ::Notice ->)
@connection.notice_handler = on_notice_proc
Expand Down
90 changes: 90 additions & 0 deletions src/pg/copy_result.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# IO object obtained through PG::Connection.exec_copy.
class PG::CopyResult < IO
getter? closed : Bool

def initialize(@connection : PQ::Connection, query : String)
@connection.send_query_message query
response = @connection.expect_frame PQ::Frame::CopyOutResponse | PQ::Frame::CopyInResponse

@reading = response.is_a? PQ::Frame::CopyOutResponse
@frame_size = 0
@end = false
@closed = false
end

private def read_final(done)
return if @end
@end = true

unless done
@connection.skip_bytes @frame_size if @frame_size > 0

while @connection.read_next_copy_start
size = @connection.read_i32 - 4
@connection.skip_bytes size
end
end

@connection.expect_frame PQ::Frame::CommandComplete
@connection.expect_frame PQ::Frame::ReadyForQuery
end

# Returns the number of remaining bytes in the current row.
# Returns 0 the are no more rows to be read.
# This can be used to allocate the precise amount of memory to read a complete row.
#
# ```
# size = io.remaining_row_size
# if size != 0
# row = Bytes.new(size)
# io.read(row)
# # Process the row.
# end
# ```
def remaining_row_size : Int32
raise "Can't read from a write-only PG::CopyResult" unless @reading
check_open

return 0 if @end

if @frame_size == 0
if @connection.read_next_copy_start
@frame_size = @connection.read_i32 - 4
else
read_final true
return 0
end
end

@frame_size
end

def read(slice : Bytes) : Int32
return 0 if slice.empty?

remaining = remaining_row_size
return 0 if remaining == 0

max_bytes = slice.size > remaining ? remaining : slice.size
bytes = @connection.read_direct(slice[0..max_bytes - 1])
@frame_size -= bytes
bytes
end

def write(slice : Bytes) : Nil
raise "Can't write to a read-only PG::CopyResult" if @reading
@connection.send_copy_data_message slice
end

def close : Nil
return if @closed
if @reading
read_final false
else
@connection.send_copy_done_message
@connection.expect_frame PQ::Frame::CommandComplete
@connection.expect_frame PQ::Frame::ReadyForQuery
end
@closed = true
end
end
34 changes: 34 additions & 0 deletions src/pq/connection.cr
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ module PQ
data
end

def read_direct(slice)
soc.read(slice)
end

def skip_bytes(count)
soc.skip(count)
end
Expand Down Expand Up @@ -442,6 +446,24 @@ module PQ
expect_frame Frame::CommandComplete, type
end

def read_next_copy_start
type = soc.read_char

while type == 'N'
# NoticeResponse
frame = read_one_frame('N')
handle_async_frames(frame)
type = soc.read_char
end

if type == 'd'
true
else
expect_frame Frame::CopyDone, type
false
end
end

def expect_frame(frame_class, type = nil)
f = type ? read(type) : read
raise "Expected #{frame_class} but got #{f}" unless frame_class === f
Expand Down Expand Up @@ -528,5 +550,17 @@ module PQ
write_chr 'X'
write_i32 4
end

def send_copy_data_message(slice)
write_chr 'd'
write_i32 4 + slice.size
soc.write slice
end

def send_copy_done_message
write_chr 'c'
write_i32 4
soc.flush
end
end
end
12 changes: 12 additions & 0 deletions src/pq/frame.cr
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ module PQ
when 'S' then ParameterStatus
when 'K' then BackendKeyData
when 'R' then Authentication
when 'c' then CopyDone
when 'G' then CopyInResponse
when 'H' then CopyOutResponse
else nil
end
k ? k.new(bytes) : Unknown.new(type, bytes)
Expand Down Expand Up @@ -245,5 +248,14 @@ module PQ

struct EmptyQueryResponse < Frame
end

struct CopyDone < Frame
end

struct CopyInResponse < Frame
end

struct CopyOutResponse < Frame
end
end
end