diff --git a/internal/sysfs/poll_windows.go b/internal/sysfs/poll_windows.go index af7cdd42946..872177a092b 100644 --- a/internal/sysfs/poll_windows.go +++ b/internal/sysfs/poll_windows.go @@ -45,6 +45,10 @@ const ( ) func poll(fds []PollFd, timeout int) (int, sys.Errno) { + if fds == nil { + return -1, sys.ENOSYS + } + regular, pipes, sockets, errno := ftypes(fds) nregular := len(regular) if errno != 0 { diff --git a/internal/sysfs/poll_windows_test.go b/internal/sysfs/poll_windows_test.go new file mode 100644 index 00000000000..b2c5c00188b --- /dev/null +++ b/internal/sysfs/poll_windows_test.go @@ -0,0 +1,294 @@ +package sysfs + +import ( + "context" + "net" + "os" + "syscall" + "testing" + "time" + + "github.com/tetratelabs/wazero/experimental/sys" + "github.com/tetratelabs/wazero/internal/testing/require" +) + +func TestPoll_Windows(t *testing.T) { + type result struct { + n int + err sys.Errno + } + + testCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + _ = testCtx + + pollToChannel := func(readHandle syscall.Handle, duration *time.Duration, ch chan result) { + r := result{} + fds := []PollFd{{fd: int32(readHandle), events: _POLLIN}} + r.n, r.err = Poll(fds, duration) + ch <- r + close(ch) + } + + t.Run("Poll returns sys.ENOSYS when n == 0 and duration is nil", func(t *testing.T) { + n, errno := Poll(nil, nil) + require.Equal(t, -1, n) + require.EqualErrno(t, sys.ENOSYS, errno) + }) + + t.Run("peekNamedPipe should report the correct state of incoming data in the pipe", func(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + wh := syscall.Handle(w.Fd()) + + // Ensure the pipe has no data. + n, err := peekNamedPipe(rh) + require.Zero(t, err) + require.Zero(t, n) + + // Write to the channel. + msg, err := syscall.ByteSliceFromString("test\n") + require.NoError(t, err) + _, err = syscall.Write(wh, msg) + require.NoError(t, err) + + // Ensure the pipe has data. + n, err = peekNamedPipe(rh) + require.Zero(t, err) + require.Equal(t, 6, int(n)) + }) + + t.Run("peekPipes should return an error on invalid handle", func(t *testing.T) { + fds := []PollFd{{fd: -1}} + _, err := peekPipes(fds) + require.EqualErrno(t, sys.EBADF, err) + }) + + t.Run("peekAll should return an error on invalid handle", func(t *testing.T) { + fds := []PollFd{{fd: -1}} + npipes, nsockets, err := peekAll(fds, nil) + require.EqualErrno(t, sys.EBADF, err) + require.Equal(t, 0, npipes) + require.Equal(t, 0, nsockets) + }) + + t.Run("poll should return successfully with a regular file", func(t *testing.T) { + f, err := os.CreateTemp(t.TempDir(), "test") + require.NoError(t, err) + defer f.Close() + + fds := []PollFd{{fd: int32(f.Fd())}} + + n, errno := poll(fds, 0) + require.Zero(t, errno) + require.Equal(t, 1, n) + }) + + t.Run("peekAll should return successfully with a pipe", func(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, err) + defer r.Close() + defer w.Close() + + fds := []PollFd{{fd: int32(r.Fd())}} + + npipes, nsockets, errno := peekAll(fds, nil) + require.Zero(t, errno) + require.Equal(t, 0, npipes) + require.Equal(t, 0, nsockets) + + w.Write([]byte("wazero")) + npipes, nsockets, errno = peekAll(fds, nil) + require.Zero(t, errno) + require.Equal(t, 1, npipes) + require.Equal(t, 0, nsockets) + }) + + t.Run("peekAll should return successfully with a socket", func(t *testing.T) { + listen, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listen.Close() + + conn, err := listen.(*net.TCPListener).SyscallConn() + require.NoError(t, err) + + fds := []PollFd{} + conn.Control(func(fd uintptr) { + fds = append(fds, PollFd{fd: int32(fd)}) + }) + + npipes, nsockets, errno := peekAll(nil, fds) + require.Zero(t, errno) + require.Equal(t, 0, npipes) + require.Equal(t, 1, nsockets) + + tcpAddr, err := net.ResolveTCPAddr("tcp", listen.Addr().String()) + require.NoError(t, err) + tcp, err := net.DialTCP("tcp", nil, tcpAddr) + require.NoError(t, err) + tcp.Write([]byte("wazero")) + + conn.Control(func(fd uintptr) { + fds[0].fd = int32(fd) + }) + npipes, nsockets, errno = peekAll(nil, fds) + require.Zero(t, errno) + require.Equal(t, 0, npipes) + require.Equal(t, 1, nsockets) + }) + + t.Run("Poll should return immediately when duration is zero (no data)", func(t *testing.T) { + r, _, err := os.Pipe() + require.NoError(t, err) + d := time.Duration(0) + fds := []PollFd{{fd: int32(r.Fd()), events: _POLLIN}} + n, err := Poll(fds, &d) + require.Zero(t, err) + require.Zero(t, n) + }) + + t.Run("Poll should return immediately when duration is zero (data)", func(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, err) + fds := []PollFd{{fd: int32(r.Fd()), events: _POLLIN}} + wh := syscall.Handle(w.Fd()) + + // Write to the channel immediately. + msg, err := syscall.ByteSliceFromString("test\n") + require.NoError(t, err) + _, err = syscall.Write(wh, msg) + require.NoError(t, err) + + // Verify that the write is reported. + d := time.Duration(0) + n, err := Poll(fds, &d) + require.Zero(t, err) + require.Equal(t, 1, n) + }) + + t.Run("Poll should wait forever when duration is nil (no writes)", func(t *testing.T) { + r, _, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + + ch := make(chan result, 1) + go pollToChannel(rh, nil, ch) + + // Wait a little, then ensure no writes occurred. + <-time.After(500 * time.Millisecond) + require.Equal(t, 0, len(ch)) + }) + + t.Run("Poll should wait forever when duration is nil", func(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + wh := syscall.Handle(w.Fd()) + + ch := make(chan result, 1) + go pollToChannel(rh, nil, ch) + + // Wait a little, then ensure no writes occurred. + <-time.After(100 * time.Millisecond) + require.Equal(t, 0, len(ch)) + + // Write a message to the pipe. + msg, err := syscall.ByteSliceFromString("test\n") + require.NoError(t, err) + _, err = syscall.Write(wh, msg) + require.NoError(t, err) + + // Ensure that the write occurs (panic after an arbitrary timeout). + select { + case <-time.After(500 * time.Millisecond): + t.Fatal("unreachable!") + case r := <-ch: + require.Zero(t, r.err) + require.NotEqual(t, 0, r.n) + } + }) + + t.Run("Poll should wait for the given duration", func(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + wh := syscall.Handle(w.Fd()) + + d := 500 * time.Millisecond + ch := make(chan result, 1) + go pollToChannel(rh, &d, ch) + + // Wait a little, then ensure no writes occurred. + <-time.After(100 * time.Millisecond) + require.Equal(t, 0, len(ch)) + + // Write a message to the pipe. + msg, err := syscall.ByteSliceFromString("test\n") + require.NoError(t, err) + _, err = syscall.Write(wh, msg) + require.NoError(t, err) + + // Ensure that the write occurs before the timer expires. + select { + case <-time.After(500 * time.Millisecond): + panic("no data!") + case r := <-ch: + require.Zero(t, r.err) + require.Equal(t, 1, r.n) + } + }) + + t.Run("Poll should timeout after the given duration", func(t *testing.T) { + r, _, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + + d := 200 * time.Millisecond + ch := make(chan result, 1) + go pollToChannel(rh, &d, ch) + + // Wait a little, then ensure a message has been written to the channel. + <-time.After(300 * time.Millisecond) + require.Equal(t, 1, len(ch)) + + // Ensure that the timer has expired. + res := <-ch + require.Zero(t, res.err) + require.Zero(t, res.n) + }) + + t.Run("Poll should return when a write occurs before the given duration", func(t *testing.T) { + r, w, err := os.Pipe() + require.NoError(t, err) + rh := syscall.Handle(r.Fd()) + wh := syscall.Handle(w.Fd()) + + d := 600 * time.Millisecond + ch := make(chan result, 1) + go pollToChannel(rh, &d, ch) + + <-time.After(300 * time.Millisecond) + require.Equal(t, 0, len(ch)) + + msg, err := syscall.ByteSliceFromString("test\n") + require.NoError(t, err) + _, err = syscall.Write(wh, msg) + require.NoError(t, err) + + res := <-ch + require.Zero(t, res.err) + require.Equal(t, 1, res.n) + }) + + t.Run("Poll should return when a regular file is given", func(t *testing.T) { + f, err := os.CreateTemp(t.TempDir(), "ex") + defer f.Close() + require.NoError(t, err) + fds := []PollFd{{fd: int32(f.Fd()), events: _POLLIN}} + d := time.Duration(0) + n, errno := Poll(fds, &d) + require.Zero(t, errno) + require.Equal(t, 1, n) + }) +}