diff --git a/conn.go b/conn.go index 25d88e0..fef9c2a 100644 --- a/conn.go +++ b/conn.go @@ -19,6 +19,7 @@ import ( "fmt" "os" "sync" + "syscall" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" @@ -266,8 +267,8 @@ func (cc *Conn) Flush() error { // Fetch the requested acknowledgement for each message we sent. for _, msg := range cc.messages { if _, err := receiveAckAware(conn, msg.Header.Flags); err != nil { - if errors.Is(err, os.ErrPermission) { - // Kernel will only send one permission error to user space. + if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) { + // Kernel will only send one error to user space. return err } errs = errors.Join(errs, err) diff --git a/nftables_test.go b/nftables_test.go index b241327..606911b 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -23,6 +23,7 @@ import ( "os" "reflect" "strings" + "syscall" "testing" "time" @@ -7666,3 +7667,90 @@ func TestNftablesCompat(t *testing.T) { t.Fatalf("compat policy should conflict and err should not be err") } } + +func TestNftablesDeadlock(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + readBufSize int + writeBufSize int + sendRules int + wantRules int + wantErr error + }{ + { + name: "recv", + readBufSize: 1024, + writeBufSize: 1 * 1024 * 1024, + sendRules: 2048, + wantRules: 2048, + wantErr: syscall.ENOBUFS, + }, + { + name: "send", + readBufSize: 1 * 1024 * 1024, + writeBufSize: 1024, + sendRules: 2048, + wantRules: 0, + wantErr: syscall.EMSGSIZE, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, newNS := nftest.OpenSystemConn(t, *enableSysTests) + conn, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.WithSockOptions(func(conn *netlink.Conn) error { + if err := conn.SetWriteBuffer(tt.writeBufSize); err != nil { + return err + } + if err := conn.SetReadBuffer(tt.readBufSize); err != nil { + return err + } + return nil + })) + if err != nil { + t.Fatalf("nftables.New() failed: %v", err) + } + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := conn.AddTable(&nftables.Table{ + Name: "test_deadlock", + Family: nftables.TableFamilyIPv4, + }) + + chain := conn.AddChain(&nftables.Chain{ + Name: "filter", + Table: table, + }) + + for i := 0; i < tt.sendRules; i++ { + conn.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + }) + } + + flushErr := conn.Flush() + rules, err := conn.GetRules(table, chain) + if err != nil { + t.Fatalf("conn.GetRules() failed: %v", err) + } + + if !errors.Is(flushErr, tt.wantErr) { + t.Errorf("conn.Flush() failed: %v", flushErr) + } + + if got, want := len(rules), tt.wantRules; got != want { + t.Fatalf("got rules %d, want rules %d", got, want) + } + }) + } +}