diff --git a/cmd/rigtest/rigtest.go b/cmd/rigtest/rigtest.go index a9f72265..7c29be78 100644 --- a/cmd/rigtest/rigtest.go +++ b/cmd/rigtest/rigtest.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "crypto/rand" "crypto/sha256" "errors" @@ -286,64 +285,71 @@ func main() { require.NoError(t, h.Configurer.DeleteFile(h, fn)) require.False(t, h.Configurer.FileExist(h, fn)) - testFileSize := int64(1 << (10 * 2)) // 1MB fsyses := []rigfs.Fsys{h.Fsys(), h.SudoFsys()} for idx, fsys := range fsyses { - t.Run("fsys functions (%d) on %s", idx+1, h) + for _, testFileSize := range []int64{ + int64(500), // less than one block on most filesystems + int64(1 << (10 * 2)), // exactly 1MB + int64(4096), // exactly one block on most filesystems + int64(4097), // plus 1 + } { + t.Run("fsys (%d) functions for file size %d on %s", idx+1, testFileSize, h) - origin := io.LimitReader(rand.Reader, testFileSize) - shasum := sha256.New() - reader := io.TeeReader(origin, shasum) + origin := io.LimitReader(rand.Reader, testFileSize) + shasum := sha256.New() + reader := io.TeeReader(origin, shasum) - destf, err := fsys.OpenFile(fn, rigfs.ModeCreate, 0644) - require.NoError(t, err, "open file") + destf, err := fsys.OpenFile(fn, rigfs.ModeCreate, 0644) + require.NoError(t, err, "open file") - n, err := io.Copy(destf, reader) - require.NoError(t, err, "io.copy file from local to remote") - require.Equal(t, testFileSize, n, "file size not as expected after copy") + n, err := io.Copy(destf, reader) + require.NoError(t, err, "io.copy file from local to remote") + require.Equal(t, testFileSize, n, "file size not as expected after copy") - require.NoError(t, destf.Close(), "error while closing file") + require.NoError(t, destf.Close(), "error while closing file") - fstat, err := fsys.Stat(fn) - require.NoError(t, err, "stat error") - require.Equal(t, testFileSize, fstat.Size(), "file size not as expected in stat result") + fstat, err := fsys.Stat(fn) + require.NoError(t, err, "stat error") + require.Equal(t, testFileSize, fstat.Size(), "file size not as expected in stat result") - destSum, err := fsys.Sha256(fn) - require.NoError(t, err, "sha256 error") + destSum, err := fsys.Sha256(fn) + require.NoError(t, err, "sha256 error") - require.Equal(t, fmt.Sprintf("%x", shasum.Sum(nil)), destSum, "sha256 mismatch after io.copy from local to remote") + require.Equal(t, fmt.Sprintf("%x", shasum.Sum(nil)), destSum, "sha256 mismatch after io.copy from local to remote") - destf, err = fsys.OpenFile(fn, rigfs.ModeRead, 0644) - require.NoError(t, err, "open file for read") + destf, err = fsys.OpenFile(fn, rigfs.ModeRead, 0644) + require.NoError(t, err, "open file for read") - readSha := sha256.New() - n, err = io.Copy(readSha, destf) - require.NoError(t, err, "io.copy file from remote to local") + readSha := sha256.New() + n, err = io.Copy(readSha, destf) + require.NoError(t, err, "io.copy file from remote to local") - require.Equal(t, testFileSize, n, "file size not as expected after copy from remote to local") + require.Equal(t, testFileSize, n, "file size not as expected after copy from remote to local") - fstat, err = destf.Stat() - require.NoError(t, err, "stat error after read") - require.Equal(t, testFileSize, fstat.Size(), "file size not as expected in stat result after read") - require.True(t, bytes.Equal(readSha.Sum(nil), shasum.Sum(nil)), "sha256 mismatch after io.copy from remote to local") + fstat, err = destf.Stat() + require.NoError(t, err, "stat error after read") + require.Equal(t, testFileSize, fstat.Size(), "file size not as expected in stat result after read") + require.Equal(t, readSha.Sum(nil), shasum.Sum(nil), "sha256 mismatch after io.copy from remote to local") - _, err = destf.Seek(0, 0) - require.NoError(t, err, "seek") + _, err = destf.Seek(0, 0) + require.NoError(t, err, "seek") - readSha.Reset() + readSha.Reset() - n, err = io.Copy(readSha, destf) - require.NoError(t, err, "io.copy file from remote to local after seek") + n, err = io.Copy(readSha, destf) + require.NoError(t, err, "io.copy file from remote to local after seek") - require.Equal(t, testFileSize, n, "file size not as expected after copy from remote to local after seek") + require.Equal(t, testFileSize, n, "file size not as expected after copy from remote to local after seek") - require.True(t, bytes.Equal(readSha.Sum(nil), shasum.Sum(nil)), "sha256 mismatch after io.copy from remote to local after seek") + require.Equal(t, readSha.Sum(nil), shasum.Sum(nil), "sha256 mismatch after io.copy from remote to local after seek") - require.NoError(t, destf.Close(), "close after seek + read") - require.NoError(t, fsys.Remove(fn), "remove file") - _, err = destf.Stat() - require.ErrorIs(t, err, fs.ErrNotExist, "file still exists") + require.NoError(t, destf.Close(), "close after seek + read") + require.NoError(t, fsys.Remove(fn), "remove file") + _, err = destf.Stat() + require.ErrorIs(t, err, fs.ErrNotExist, "file still exists") + } + t.Run("fsys (%d) dir ops on %s", idx+1, h) // fsys dirops require.NoError(t, fsys.MkDirAll("tmpdir/nested", 0644), "make nested dir") diff --git a/pkg/rigfs/posixfsys.go b/pkg/rigfs/posixfsys.go index ddba60dd..37c1f78d 100644 --- a/pkg/rigfs/posixfsys.go +++ b/pkg/rigfs/posixfsys.go @@ -8,6 +8,8 @@ import ( "io" "io/fs" "os" + "path/filepath" + "strconv" "strings" "time" @@ -38,6 +40,8 @@ func NewPosixFsys(conn connection, opts ...exec.Option) *PosixFsys { return &PosixFsys{conn: conn, opts: opts} } +const defaultBlockSize = 4096 + // PosixFile implements fs.File for a remote file type PosixFile struct { fsys *PosixFsys @@ -47,6 +51,8 @@ type PosixFile struct { pos int64 size int64 mode FileMode + + blockSize int } // PosixDir implements fs.ReadDirFile for a remote directory @@ -83,6 +89,22 @@ func (f *PosixDir) ReadDir(n int) ([]fs.DirEntry, error) { return f.entries[old:f.hw], nil } +func (f *PosixFile) fsBlockSize() int { + if f.blockSize > 0 { + return f.blockSize + } + + out, err := f.fsys.conn.ExecOutput(fmt.Sprintf(`stat -c "%%s" %[1]s 2> /dev/null || stat -f "%%k" %[1]s`, shellescape.Quote(filepath.Dir(f.path))), f.fsys.opts...) + if err != nil { + // fall back to default + f.blockSize = defaultBlockSize + } else if bs, err := strconv.Atoi(strings.TrimSpace(out)); err == nil { + f.blockSize = bs + } + + return f.blockSize +} + func (f *PosixFile) isReadable() bool { return f.mode&ModeRead != 0 } @@ -91,13 +113,18 @@ func (f *PosixFile) isWritable() bool { return f.mode&ModeWrite != 0 } -const blockSize = 4096 - func (f *PosixFile) ddParams(offset int64, numBytes int) (int, int64, int) { - skip := offset / int64(blockSize) - count := (numBytes + blockSize - 1) / blockSize + bs := f.fsBlockSize() - return blockSize, skip, count + if numBytes < bs { + bs = numBytes + skip := offset / int64(bs) + return bs, skip, 1 + } + + skip := offset / int64(bs) + count := numBytes / bs + return bs, skip, count } // Stat returns a FileInfo describing the named file @@ -113,9 +140,13 @@ func (f *PosixFile) Read(p []byte) (int, error) { if !f.isReadable() { return 0, fmt.Errorf("%w: file %s is not open for reading", ErrCommandFailed, f.path) } - bs, skip, count := f.ddParams(f.pos, len(p)) errbuf := bytes.NewBuffer(nil) + + bs, skip, count := f.ddParams(f.pos, len(p)) + toRead := bs * count buf := bytes.NewBuffer(nil) + errbuf.Reset() + cmd, err := f.fsys.conn.ExecStreams(fmt.Sprintf("dd if=%s bs=%d skip=%d count=%d", shellescape.Quote(f.path), bs, skip, count), nil, buf, errbuf, f.fsys.opts...) if err != nil { return 0, fmt.Errorf("%w: failed to execute dd: %w (%s)", ErrCommandFailed, err, errbuf.String()) @@ -123,32 +154,39 @@ func (f *PosixFile) Read(p []byte) (int, error) { if err := cmd.Wait(); err != nil { return 0, fmt.Errorf("%w: read (dd): %w (%s)", ErrCommandFailed, err, errbuf.String()) } - f.pos += int64(buf.Len()) - if buf.Len() < len(p) { + + readBytes := copy(p, buf.Bytes()) + f.pos += int64(readBytes) + if readBytes < len(p) || readBytes < toRead { f.isEOF = true + return readBytes, io.EOF } - return copy(p, buf.Bytes()), nil + return readBytes, nil } func (f *PosixFile) Write(p []byte) (int, error) { if !f.isWritable() { return 0, fmt.Errorf("%w: file %s is not open for writing", ErrCommandFailed, f.path) } + bs, skip, count := f.ddParams(f.pos, len(p)) + toWrite := bs * count + errbuf := bytes.NewBuffer(nil) - cmd, err := f.fsys.conn.ExecStreams(fmt.Sprintf("dd if=/dev/stdin of=%s bs=%d count=%d seek=%d conv=notrunc", f.path, bs, count, skip), io.NopCloser(bytes.NewReader(p)), io.Discard, errbuf, f.fsys.opts...) + limitedReader := bytes.NewReader(p[:toWrite]) + cmd, err := f.fsys.conn.ExecStreams(fmt.Sprintf("dd if=/dev/stdin of=%s bs=%d count=%d seek=%d conv=notrunc", f.path, bs, count, skip), io.NopCloser(limitedReader), io.Discard, errbuf, f.fsys.opts...) if err != nil { return 0, fmt.Errorf("%w: write (dd): %w", ErrCommandFailed, err) } if err := cmd.Wait(); err != nil { return 0, fmt.Errorf("%w: write (dd): %w (%s)", ErrCommandFailed, err, errbuf.String()) } - written := len(p) - f.pos += int64(written) + f.pos += int64(toWrite) + if f.pos > f.size { f.size = f.pos } - return written, nil + return toWrite, nil } // CopyFromN copies n bytes from the remote file. The alt writer can be used for progress diff --git a/test/test.sh b/test/test.sh index ff9af77a..61c15a2a 100755 --- a/test/test.sh +++ b/test/test.sh @@ -30,6 +30,16 @@ sanity_check() { RET=$exit_code } +rig_test_key_from_path() { + color_echo "- Testing regular keypath and host functions" + make create-host + mv .ssh/identity .ssh/identity2 + set +e + ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity2 + local exit_code=$? + set -e + RET=$exit_code +} rig_test_agent_with_public_key() { color_echo "- Testing connection using agent and providing a path to public key" @@ -159,18 +169,6 @@ rig_test_ssh_config_no_strict() { RET=$exit_code } - -rig_test_key_from_path() { - color_echo "- Testing regular keypath and host functions" - make create-host - mv .ssh/identity .ssh/identity2 - set +e - ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity2 - local exit_code=$? - set -e - RET=$exit_code -} - rig_test_key_from_memory() { color_echo "- Testing connecting using a key from string" make create-host @@ -187,7 +185,7 @@ rig_test_key_from_default_location() { make create-host mv .ssh/identity .ssh/id_ecdsa set +e - HOME=$(pwd) ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root + HOME=$(pwd) ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -connect local exit_code=$? set -e RET=$exit_code