Skip to content

Commit

Permalink
Improve file read/write efficiency / robustness and add test cases
Browse files Browse the repository at this point in the history
Signed-off-by: Kimmo Lehto <[email protected]>
  • Loading branch information
kke committed Oct 5, 2023
1 parent 2b813ca commit 2839efd
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 65 deletions.
84 changes: 45 additions & 39 deletions cmd/rigtest/rigtest.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"bytes"
"crypto/rand"
"crypto/sha256"
"errors"
Expand Down Expand Up @@ -286,64 +285,71 @@ func main() {
require.NoError(t, h.Configurer.DeleteFile(h, fn))
require.False(t, h.Configurer.FileExist(h, fn))

testFileSize := int64(1 << (10 * 2)) // 1MB
fsyses := []rigfs.Fsys{h.Fsys(), h.SudoFsys()}

for idx, fsys := range fsyses {
t.Run("fsys functions (%d) on %s", idx+1, h)
for _, testFileSize := range []int64{
int64(500), // less than one block on most filesystems
int64(1 << (10 * 2)), // exactly 1MB
int64(4096), // exactly one block on most filesystems
int64(4097), // plus 1
} {
t.Run("fsys (%d) functions for file size %d on %s", idx+1, testFileSize, h)

origin := io.LimitReader(rand.Reader, testFileSize)
shasum := sha256.New()
reader := io.TeeReader(origin, shasum)
origin := io.LimitReader(rand.Reader, testFileSize)
shasum := sha256.New()
reader := io.TeeReader(origin, shasum)

destf, err := fsys.OpenFile(fn, rigfs.ModeCreate, 0644)
require.NoError(t, err, "open file")
destf, err := fsys.OpenFile(fn, rigfs.ModeCreate, 0644)
require.NoError(t, err, "open file")

n, err := io.Copy(destf, reader)
require.NoError(t, err, "io.copy file from local to remote")
require.Equal(t, testFileSize, n, "file size not as expected after copy")
n, err := io.Copy(destf, reader)
require.NoError(t, err, "io.copy file from local to remote")
require.Equal(t, testFileSize, n, "file size not as expected after copy")

require.NoError(t, destf.Close(), "error while closing file")
require.NoError(t, destf.Close(), "error while closing file")

fstat, err := fsys.Stat(fn)
require.NoError(t, err, "stat error")
require.Equal(t, testFileSize, fstat.Size(), "file size not as expected in stat result")
fstat, err := fsys.Stat(fn)
require.NoError(t, err, "stat error")
require.Equal(t, testFileSize, fstat.Size(), "file size not as expected in stat result")

destSum, err := fsys.Sha256(fn)
require.NoError(t, err, "sha256 error")
destSum, err := fsys.Sha256(fn)
require.NoError(t, err, "sha256 error")

require.Equal(t, fmt.Sprintf("%x", shasum.Sum(nil)), destSum, "sha256 mismatch after io.copy from local to remote")
require.Equal(t, fmt.Sprintf("%x", shasum.Sum(nil)), destSum, "sha256 mismatch after io.copy from local to remote")

destf, err = fsys.OpenFile(fn, rigfs.ModeRead, 0644)
require.NoError(t, err, "open file for read")
destf, err = fsys.OpenFile(fn, rigfs.ModeRead, 0644)
require.NoError(t, err, "open file for read")

readSha := sha256.New()
n, err = io.Copy(readSha, destf)
require.NoError(t, err, "io.copy file from remote to local")
readSha := sha256.New()
n, err = io.Copy(readSha, destf)
require.NoError(t, err, "io.copy file from remote to local")

require.Equal(t, testFileSize, n, "file size not as expected after copy from remote to local")
require.Equal(t, testFileSize, n, "file size not as expected after copy from remote to local")

fstat, err = destf.Stat()
require.NoError(t, err, "stat error after read")
require.Equal(t, testFileSize, fstat.Size(), "file size not as expected in stat result after read")
require.True(t, bytes.Equal(readSha.Sum(nil), shasum.Sum(nil)), "sha256 mismatch after io.copy from remote to local")
fstat, err = destf.Stat()
require.NoError(t, err, "stat error after read")
require.Equal(t, testFileSize, fstat.Size(), "file size not as expected in stat result after read")
require.Equal(t, readSha.Sum(nil), shasum.Sum(nil), "sha256 mismatch after io.copy from remote to local")

_, err = destf.Seek(0, 0)
require.NoError(t, err, "seek")
_, err = destf.Seek(0, 0)
require.NoError(t, err, "seek")

readSha.Reset()
readSha.Reset()

n, err = io.Copy(readSha, destf)
require.NoError(t, err, "io.copy file from remote to local after seek")
n, err = io.Copy(readSha, destf)
require.NoError(t, err, "io.copy file from remote to local after seek")

require.Equal(t, testFileSize, n, "file size not as expected after copy from remote to local after seek")
require.Equal(t, testFileSize, n, "file size not as expected after copy from remote to local after seek")

require.True(t, bytes.Equal(readSha.Sum(nil), shasum.Sum(nil)), "sha256 mismatch after io.copy from remote to local after seek")
require.Equal(t, readSha.Sum(nil), shasum.Sum(nil), "sha256 mismatch after io.copy from remote to local after seek")

require.NoError(t, destf.Close(), "close after seek + read")
require.NoError(t, fsys.Remove(fn), "remove file")
_, err = destf.Stat()
require.ErrorIs(t, err, fs.ErrNotExist, "file still exists")
require.NoError(t, destf.Close(), "close after seek + read")
require.NoError(t, fsys.Remove(fn), "remove file")
_, err = destf.Stat()
require.ErrorIs(t, err, fs.ErrNotExist, "file still exists")
}
t.Run("fsys (%d) dir ops on %s", idx+1, h)

// fsys dirops
require.NoError(t, fsys.MkDirAll("tmpdir/nested", 0644), "make nested dir")
Expand Down
64 changes: 51 additions & 13 deletions pkg/rigfs/posixfsys.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"io"
"io/fs"
"os"
"path/filepath"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -38,6 +40,8 @@ func NewPosixFsys(conn connection, opts ...exec.Option) *PosixFsys {
return &PosixFsys{conn: conn, opts: opts}
}

const defaultBlockSize = 4096

// PosixFile implements fs.File for a remote file
type PosixFile struct {
fsys *PosixFsys
Expand All @@ -47,6 +51,8 @@ type PosixFile struct {
pos int64
size int64
mode FileMode

blockSize int
}

// PosixDir implements fs.ReadDirFile for a remote directory
Expand Down Expand Up @@ -83,6 +89,22 @@ func (f *PosixDir) ReadDir(n int) ([]fs.DirEntry, error) {
return f.entries[old:f.hw], nil
}

func (f *PosixFile) fsBlockSize() int {
if f.blockSize > 0 {
return f.blockSize
}

out, err := f.fsys.conn.ExecOutput(fmt.Sprintf(`stat -c "%%s" %[1]s 2> /dev/null || stat -f "%%k" %[1]s`, shellescape.Quote(filepath.Dir(f.path))), f.fsys.opts...)
if err != nil {
// fall back to default
f.blockSize = defaultBlockSize
} else if bs, err := strconv.Atoi(strings.TrimSpace(out)); err == nil {
f.blockSize = bs
}

return f.blockSize
}

func (f *PosixFile) isReadable() bool {
return f.mode&ModeRead != 0
}
Expand All @@ -91,13 +113,18 @@ func (f *PosixFile) isWritable() bool {
return f.mode&ModeWrite != 0
}

const blockSize = 4096

func (f *PosixFile) ddParams(offset int64, numBytes int) (int, int64, int) {
skip := offset / int64(blockSize)
count := (numBytes + blockSize - 1) / blockSize
bs := f.fsBlockSize()

return blockSize, skip, count
if numBytes < bs {
bs = numBytes
skip := offset / int64(bs)
return bs, skip, 1
}

skip := offset / int64(bs)
count := numBytes / bs
return bs, skip, count
}

// Stat returns a FileInfo describing the named file
Expand All @@ -113,42 +140,53 @@ func (f *PosixFile) Read(p []byte) (int, error) {
if !f.isReadable() {
return 0, fmt.Errorf("%w: file %s is not open for reading", ErrCommandFailed, f.path)
}
bs, skip, count := f.ddParams(f.pos, len(p))
errbuf := bytes.NewBuffer(nil)

bs, skip, count := f.ddParams(f.pos, len(p))
toRead := bs * count
buf := bytes.NewBuffer(nil)
errbuf.Reset()

cmd, err := f.fsys.conn.ExecStreams(fmt.Sprintf("dd if=%s bs=%d skip=%d count=%d", shellescape.Quote(f.path), bs, skip, count), nil, buf, errbuf, f.fsys.opts...)
if err != nil {
return 0, fmt.Errorf("%w: failed to execute dd: %w (%s)", ErrCommandFailed, err, errbuf.String())
}
if err := cmd.Wait(); err != nil {
return 0, fmt.Errorf("%w: read (dd): %w (%s)", ErrCommandFailed, err, errbuf.String())
}
f.pos += int64(buf.Len())
if buf.Len() < len(p) {

readBytes := copy(p, buf.Bytes())
f.pos += int64(readBytes)
if readBytes < len(p) || readBytes < toRead {
f.isEOF = true
return readBytes, io.EOF
}
return copy(p, buf.Bytes()), nil
return readBytes, nil
}

func (f *PosixFile) Write(p []byte) (int, error) {
if !f.isWritable() {
return 0, fmt.Errorf("%w: file %s is not open for writing", ErrCommandFailed, f.path)
}

bs, skip, count := f.ddParams(f.pos, len(p))
toWrite := bs * count

errbuf := bytes.NewBuffer(nil)
cmd, err := f.fsys.conn.ExecStreams(fmt.Sprintf("dd if=/dev/stdin of=%s bs=%d count=%d seek=%d conv=notrunc", f.path, bs, count, skip), io.NopCloser(bytes.NewReader(p)), io.Discard, errbuf, f.fsys.opts...)
limitedReader := bytes.NewReader(p[:toWrite])
cmd, err := f.fsys.conn.ExecStreams(fmt.Sprintf("dd if=/dev/stdin of=%s bs=%d count=%d seek=%d conv=notrunc", f.path, bs, count, skip), io.NopCloser(limitedReader), io.Discard, errbuf, f.fsys.opts...)
if err != nil {
return 0, fmt.Errorf("%w: write (dd): %w", ErrCommandFailed, err)
}
if err := cmd.Wait(); err != nil {
return 0, fmt.Errorf("%w: write (dd): %w (%s)", ErrCommandFailed, err, errbuf.String())
}
written := len(p)
f.pos += int64(written)
f.pos += int64(toWrite)

if f.pos > f.size {
f.size = f.pos
}
return written, nil
return toWrite, nil
}

// CopyFromN copies n bytes from the remote file. The alt writer can be used for progress
Expand Down
24 changes: 11 additions & 13 deletions test/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ sanity_check() {
RET=$exit_code
}

rig_test_key_from_path() {
color_echo "- Testing regular keypath and host functions"
make create-host
mv .ssh/identity .ssh/identity2
set +e
./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity2
local exit_code=$?
set -e
RET=$exit_code
}

rig_test_agent_with_public_key() {
color_echo "- Testing connection using agent and providing a path to public key"
Expand Down Expand Up @@ -159,18 +169,6 @@ rig_test_ssh_config_no_strict() {
RET=$exit_code
}


rig_test_key_from_path() {
color_echo "- Testing regular keypath and host functions"
make create-host
mv .ssh/identity .ssh/identity2
set +e
./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity2
local exit_code=$?
set -e
RET=$exit_code
}

rig_test_key_from_memory() {
color_echo "- Testing connecting using a key from string"
make create-host
Expand All @@ -187,7 +185,7 @@ rig_test_key_from_default_location() {
make create-host
mv .ssh/identity .ssh/id_ecdsa
set +e
HOME=$(pwd) ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root
HOME=$(pwd) ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -connect
local exit_code=$?
set -e
RET=$exit_code
Expand Down

0 comments on commit 2839efd

Please sign in to comment.