diff --git a/counter.go b/counter.go index 25d37d8..d18fc49 100644 --- a/counter.go +++ b/counter.go @@ -15,12 +15,11 @@ package nftables import ( - "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" ) -// CounterObj implements Obj. type CounterObj struct { Table *Table Name string // e.g. “fwded” @@ -41,6 +40,20 @@ func (c *CounterObj) unmarshal(ad *netlink.AttributeDecoder) error { return ad.Err() } +func (c *CounterObj) data() expr.Any { + return &expr.Counter{ + Bytes: c.Bytes, + Packets: c.Packets, + } +} + +func (c *CounterObj) name() string { + return c.Name +} +func (c *CounterObj) objType() ObjType { + return ObjTypeCounter +} + func (c *CounterObj) table() *Table { return c.Table } @@ -48,22 +61,3 @@ func (c *CounterObj) table() *Table { func (c *CounterObj) family() TableFamily { return c.Table.Family } - -func (c *CounterObj) marshal(data bool) ([]byte, error) { - obj, err := netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_COUNTER_BYTES, Data: binaryutil.BigEndian.PutUint64(c.Bytes)}, - {Type: unix.NFTA_COUNTER_PACKETS, Data: binaryutil.BigEndian.PutUint64(c.Packets)}, - }) - if err != nil { - return nil, err - } - attrs := []netlink.Attribute{ - {Type: unix.NFTA_OBJ_TABLE, Data: []byte(c.Table.Name + "\x00")}, - {Type: unix.NFTA_OBJ_NAME, Data: []byte(c.Name + "\x00")}, - {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(unix.NFT_OBJECT_COUNTER)}, - } - if data { - attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: obj}) - } - return netlink.MarshalAttributes(attrs) -} diff --git a/monitor.go b/monitor.go index 853d5fd..0400cc9 100644 --- a/monitor.go +++ b/monitor.go @@ -259,7 +259,7 @@ func (monitor *Monitor) monitor() { } monitor.eventCh <- event case unix.NFT_MSG_NEWOBJ, unix.NFT_MSG_DELOBJ: - obj, err := objFromMsg(msg) + obj, err := objFromMsg(msg, true) event := &MonitorEvent{ Type: MonitorEventType(msgType), Data: obj, diff --git a/nftables_test.go b/nftables_test.go index be8b83b..dca28a1 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1783,7 +1783,7 @@ func TestListChainByName(t *testing.T) { } func TestListChainByNameUsingLasting(t *testing.T) { - conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + _, newNS := nftest.OpenSystemConn(t, *enableSysTests) conn, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting()) if err != nil { t.Fatalf("nftables.New() failed: %v", err) @@ -1882,8 +1882,7 @@ func TestListTableByName(t *testing.T) { } // not specifying correct family should return err since no table in ipv4 - tr, err = conn.ListTable(table2.Name) - if err == nil { + if _, err = conn.ListTable(table2.Name); err == nil { t.Fatalf("conn.ListTable() should have failed") } @@ -2106,17 +2105,18 @@ func TestGetObjReset(t *testing.T) { } filter := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4} - obj, err := c.ResetObject(&nftables.CounterObj{ + obj, err := c.ResetObject(&nftables.NamedObj{ Table: filter, Name: "fwded", + Type: nftables.ObjTypeCounter, }) if err != nil { t.Fatal(err) } - co, ok := obj.(*nftables.CounterObj) + co, ok := obj.(*nftables.NamedObj) if !ok { - t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj) + t.Fatalf("unexpected type: got %T, want *nftables.ObjAttr", obj) } if got, want := co.Table.Name, filter.Name; got != want { t.Errorf("unexpected table name: got %q, want %q", got, want) @@ -2124,10 +2124,14 @@ func TestGetObjReset(t *testing.T) { if got, want := co.Table.Family, filter.Family; got != want { t.Errorf("unexpected table family: got %d, want %d", got, want) } - if got, want := co.Packets, uint64(9); got != want { + o, ok := co.Obj.(*expr.Counter) + if !ok { + t.Fatalf("unexpected type: got %T, want *expr.Counter", o) + } + if got, want := o.Packets, uint64(9); got != want { t.Errorf("unexpected number of packets: got %d, want %d", got, want) } - if got, want := co.Bytes, uint64(1121); got != want { + if got, want := o.Bytes, uint64(1121); got != want { t.Errorf("unexpected number of bytes: got %d, want %d", got, want) } } @@ -2164,6 +2168,366 @@ func TestObjAPI(t *testing.T) { Priority: nftables.ChainPriorityFilter, }) + counter1 := c.AddObj(&nftables.NamedObj{ + Table: table, + Name: "fwded1", + Type: nftables.ObjTypeCounter, + Obj: &expr.Counter{ + Bytes: 1, + Packets: 1, + }, + }) + + counter2 := c.AddObj(&nftables.NamedObj{ + Table: table, + Name: "fwded2", + Type: nftables.ObjTypeCounter, + Obj: &expr.Counter{ + Bytes: 1, + Packets: 1, + }, + }) + + c.AddObj(&nftables.NamedObj{ + Table: tableOther, + Name: "fwdedOther", + Type: nftables.ObjTypeCounter, + Obj: &expr.Counter{ + Bytes: 0, + Packets: 0, + }, + }) + + c.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Objref{ + Type: 1, + Name: "fwded1", + }, + }, + }) + + if err := c.Flush(); err != nil { + t.Fatalf(err.Error()) + } + + objs, err := c.GetObjects(table) + if err != nil { + t.Errorf("c.GetObjects(table) failed: %v failed", err) + } + + if got := len(objs); got != 2 { + t.Fatalf("unexpected number of objects: got %d, want %d", got, 2) + } + + objsOther, err := c.GetObjects(tableOther) + if err != nil { + t.Errorf("c.GetObjects(tableOther) failed: %v failed", err) + } + + if got := len(objsOther); got != 1 { + t.Fatalf("unexpected number of objects: got %d, want %d", got, 1) + } + + obj1, err := c.GetObject(counter1) + if err != nil { + t.Errorf("c.GetObject(counter1) failed: %v failed", err) + } + + rcounter1, ok := obj1.(*nftables.NamedObj) + if !ok { + t.Fatalf("unexpected type: got %T, want *nftables.ObjAttr", obj1) + } + + if rcounter1.Name != "fwded1" { + t.Fatalf("unexpected counter name: got %s, want %s", rcounter1.Name, "fwded1") + } + + obj2, err := c.GetObject(counter2) + if err != nil { + t.Errorf("c.GetObject(counter2) failed: %v failed", err) + } + + rcounter2, ok := obj2.(*nftables.NamedObj) + if !ok { + t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj2) + } + + if rcounter2.Name != "fwded2" { + t.Fatalf("unexpected counter name: got %s, want %s", rcounter2.Name, "fwded2") + } + + _, err = c.ResetObject(counter1) + + if err != nil { + t.Errorf("c.ResetObjects(table) failed: %v failed", err) + } + + obj1, err = c.GetObject(counter1) + + if err != nil { + t.Errorf("c.GetObject(counter1) failed: %v failed", err) + } + + if counter1 := obj1.(*nftables.NamedObj).Obj.(*expr.Counter); counter1.Packets > 0 { + t.Errorf("unexpected packets number: got %d, want %d", counter1.Packets, 0) + } + + obj2, err = c.GetObject(counter2) + + if err != nil { + t.Errorf("c.GetObject(counter2) failed: %v failed", err) + } + + if counter2 := obj2.(*nftables.NamedObj).Obj.(*expr.Counter); counter2.Packets != 1 { + t.Errorf("unexpected packets number: got %d, want %d", counter2.Packets, 1) + } + + legacy, err := c.GetObj(counter1) + if err != nil { + t.Errorf("c.GetObj(counter1) failed: %v failed", err) + } + + if len(legacy) != 2 { + t.Errorf("unexpected number of objects: got %d, want %d", len(legacy), 2) + } + + legacyReset, err := c.GetObjReset(counter1) + if err != nil { + t.Errorf("c.GetObjReset(counter1) failed: %v failed", err) + } + + if len(legacyReset) != 2 { + t.Errorf("unexpected number of objects: got %d, want %d", len(legacyReset), 2) + } +} + +func TestDeleteLegacyQuotaObj(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := &nftables.Table{ + Name: "quota_demo", + Family: nftables.TableFamilyIPv4, + } + tr := conn.AddTable(table) + + c := &nftables.Chain{ + Name: "filter", + Table: table, + } + conn.AddChain(c) + + o := &nftables.QuotaObj{ + Table: tr, + Name: "q_test", + Bytes: 0x06400000, + Consumed: 0, + Over: true, + } + conn.AddObj(o) + + if err := conn.Flush(); err != nil { + t.Fatalf("conn.Flush() failed: %v", err) + } + + obj, err := conn.GetObj(&nftables.QuotaObj{ + Table: table, + Name: "q_test", + }) + if err != nil { + t.Fatalf("conn.GetObj() failed: %v", err) + } + + if got, want := len(obj), 1; got != want { + t.Fatalf("unexpected number of objects: got %d, want %d", got, want) + } + + if got, want := obj[0], o; !reflect.DeepEqual(got, want) { + t.Errorf("got = %+v, want = %+v", got, want) + } + + conn.DeleteObject(&nftables.QuotaObj{ + Table: tr, + Name: "q_test", + }) + + if err := conn.Flush(); err != nil { + t.Fatalf("conn.Flush() failed: %v", err) + } + + obj, err = conn.GetObj(&nftables.QuotaObj{ + Table: table, + Name: "q_test", + }) + if err != nil { + t.Fatalf("conn.GetObj() failed: %v", err) + } + if got, want := len(obj), 0; got != want { + t.Fatalf("unexpected object list length: got %d, want %d", got, want) + } +} + +func TestAddLegacyQuotaObj(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := &nftables.Table{ + Name: "quota_demo", + Family: nftables.TableFamilyIPv4, + } + tr := conn.AddTable(table) + + c := &nftables.Chain{ + Name: "filter", + Table: table, + } + conn.AddChain(c) + + o := &nftables.QuotaObj{ + Table: tr, + Name: "q_test", + Bytes: 0x06400000, + Consumed: 0, + Over: true, + } + conn.AddObj(o) + + if err := conn.Flush(); err != nil { + t.Errorf("conn.Flush() failed: %v", err) + } + + obj, err := conn.GetObj(&nftables.QuotaObj{ + Table: table, + Name: "q_test", + }) + if err != nil { + t.Fatalf("conn.GetObj() failed: %v", err) + } + + if got, want := len(obj), 1; got != want { + t.Fatalf("unexpected object list length: got %d, want %d", got, want) + } + + o1, ok := obj[0].(*nftables.QuotaObj) + if !ok { + t.Fatalf("unexpected type: got %T, want *QuotaObj", obj[0]) + } + if got, want := o1.Name, o.Name; got != want { + t.Fatalf("quota name mismatch: got %s, want %s", got, want) + } + if got, want := o1.Bytes, o.Bytes; got != want { + t.Fatalf("quota bytes mismatch: got %d, want %d", got, want) + } + if got, want := o1.Consumed, o.Consumed; got != want { + t.Fatalf("quota consumed mismatch: got %d, want %d", got, want) + } + if got, want := o1.Over, o.Over; got != want { + t.Fatalf("quota over mismatch: got %v, want %v", got, want) + } +} + +func TestAddLegacyQuotaObjRef(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := &nftables.Table{ + Name: "quota_demo", + Family: nftables.TableFamilyIPv4, + } + tr := conn.AddTable(table) + + c := &nftables.Chain{ + Name: "filter", + Table: table, + } + conn.AddChain(c) + + o := &nftables.QuotaObj{ + Table: tr, + Name: "q_test", + Bytes: 0x06400000, + Consumed: 0, + Over: true, + } + conn.AddObj(o) + + r := &nftables.Rule{ + Table: table, + Chain: c, + Exprs: []expr.Any{ + &expr.Objref{ + Type: 2, + Name: "q_test", + }, + }, + } + conn.AddRule(r) + if err := conn.Flush(); err != nil { + t.Fatalf("failed to flush: %v", err) + } + + rules, err := conn.GetRules(table, c) + if err != nil { + t.Fatalf("failed to get rules: %v", err) + } + + if got, want := len(rules), 1; got != want { + t.Fatalf("unexpected number of rules: got %d, want %d", got, want) + } + if got, want := len(rules[0].Exprs), 1; got != want { + t.Fatalf("unexpected number of exprs: got %d, want %d", got, want) + } + + objref, ok := rules[0].Exprs[0].(*expr.Objref) + if !ok { + t.Fatalf("Exprs[0] is type %T, want *expr.Objref", rules[0].Exprs[0]) + } + if want := r.Exprs[0]; !reflect.DeepEqual(objref, want) { + t.Errorf("objref expr = %+v, wanted %+v", objref, want) + } +} + +func TestObjAPICounterLegacyType(t *testing.T) { + if os.Getenv("TRAVIS") == "true" { + t.SkipNow() + } + + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + table := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + + tableOther := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "foo", + }) + + chain := c.AddChain(&nftables.Chain{ + Name: "chain", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityFilter, + }) + counter1 := c.AddObj(&nftables.CounterObj{ Table: table, Name: "fwded1", @@ -2767,7 +3131,7 @@ func TestCreateUseAnonymousSet(t *testing.T) { } func TestCappedErrMsgOnSets(t *testing.T) { - c, newNS := nftest.OpenSystemConn(t, *enableSysTests) + _, newNS := nftest.OpenSystemConn(t, *enableSysTests) c, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting()) if err != nil { t.Fatalf("nftables.New() failed: %v", err) @@ -6285,6 +6649,84 @@ func TestGetRulesObjref(t *testing.T) { } } +func TestAddLimitObj(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := &nftables.Table{ + Name: "limit_demo", + Family: nftables.TableFamilyIPv4, + } + tr := conn.AddTable(table) + + c := &nftables.Chain{ + Name: "filter", + Table: table, + } + conn.AddChain(c) + + l := &expr.Limit{ + Type: expr.LimitTypePkts, + Rate: 400, + Unit: expr.LimitTimeMinute, + Burst: 5, + Over: false, + } + o := &nftables.NamedObj{ + Table: tr, + Name: "limit_test", + Type: nftables.ObjTypeLimit, + Obj: l, + } + conn.AddObj(o) + + if err := conn.Flush(); err != nil { + t.Errorf("conn.Flush() failed: %v", err) + } + + obj, err := conn.GetObj(&nftables.NamedObj{ + Table: table, + Name: "limit_test", + Type: nftables.ObjTypeLimit, + }) + if err != nil { + t.Fatalf("conn.GetObj() failed: %v", err) + } + + if got, want := len(obj), 1; got != want { + t.Fatalf("unexpected object list length: got %d, want %d", got, want) + } + + o1, ok := obj[0].(*nftables.NamedObj) + if !ok { + t.Fatalf("unexpected type: got %T, want *ObjAttr", obj[0]) + } + if got, want := o1.Name, o.Name; got != want { + t.Fatalf("limit name mismatch: got %s, want %s", got, want) + } + q, ok := o1.Obj.(*expr.Limit) + if !ok { + t.Fatalf("unexpected type: got %T, want *expr.Quota", o1.Obj) + } + if got, want := q.Burst, l.Burst; got != want { + t.Fatalf("limit burst mismatch: got %d, want %d", got, want) + } + if got, want := q.Unit, l.Unit; got != want { + t.Fatalf("limit unit mismatch: got %d, want %d", got, want) + } + if got, want := q.Rate, l.Rate; got != want { + t.Fatalf("limit rate mismatch: got %v, want %v", got, want) + } + if got, want := q.Over, l.Over; got != want { + t.Fatalf("limit over mismatch: got %v, want %v", got, want) + } + if got, want := q.Type, l.Type; got != want { + t.Fatalf("limit type mismatch: got %v, want %v", got, want) + } +} + func TestAddQuotaObj(t *testing.T) { conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) defer nftest.CleanupSystemConn(t, newNS) @@ -6303,22 +6745,26 @@ func TestAddQuotaObj(t *testing.T) { } conn.AddChain(c) - o := &nftables.QuotaObj{ - Table: tr, - Name: "q_test", - Bytes: 0x06400000, - Consumed: 0, - Over: true, + o := &nftables.NamedObj{ + Table: tr, + Name: "q_test", + Type: nftables.ObjTypeQuota, + Obj: &expr.Quota{ + Bytes: 0x06400000, + Consumed: 0, + Over: true, + }, } conn.AddObj(o) if err := conn.Flush(); err != nil { - t.Errorf("conn.Flush() failed: %v", err) + t.Fatalf("conn.Flush() failed: %v", err) } - obj, err := conn.GetObj(&nftables.QuotaObj{ + obj, err := conn.GetObj(&nftables.NamedObj{ Table: table, Name: "q_test", + Type: nftables.ObjTypeQuota, }) if err != nil { t.Fatalf("conn.GetObj() failed: %v", err) @@ -6328,20 +6774,25 @@ func TestAddQuotaObj(t *testing.T) { t.Fatalf("unexpected object list length: got %d, want %d", got, want) } - o1, ok := obj[0].(*nftables.QuotaObj) + o1, ok := obj[0].(*nftables.NamedObj) if !ok { - t.Fatalf("unexpected type: got %T, want *QuotaObj", obj[0]) + t.Fatalf("unexpected type: got %T, want *ObjAttr", obj[0]) } if got, want := o1.Name, o.Name; got != want { t.Fatalf("quota name mismatch: got %s, want %s", got, want) } - if got, want := o1.Bytes, o.Bytes; got != want { + q, ok := o1.Obj.(*expr.Quota) + if !ok { + t.Fatalf("unexpected type: got %T, want *expr.Quota", o1.Obj) + } + o2, _ := o.Obj.(*expr.Quota) + if got, want := q.Bytes, o2.Bytes; got != want { t.Fatalf("quota bytes mismatch: got %d, want %d", got, want) } - if got, want := o1.Consumed, o.Consumed; got != want { + if got, want := q.Consumed, o2.Consumed; got != want { t.Fatalf("quota consumed mismatch: got %d, want %d", got, want) } - if got, want := o1.Over, o.Over; got != want { + if got, want := q.Over, o2.Over; got != want { t.Fatalf("quota over mismatch: got %v, want %v", got, want) } } @@ -6409,7 +6860,7 @@ func TestAddQuotaObjRef(t *testing.T) { } } -func TestDeleteQuotaObj(t *testing.T) { +func TestDeleteQuotaObjMixedTypes(t *testing.T) { conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) defer nftest.CleanupSystemConn(t, newNS) conn.FlushRuleset() @@ -6427,12 +6878,15 @@ func TestDeleteQuotaObj(t *testing.T) { } conn.AddChain(c) - o := &nftables.QuotaObj{ - Table: tr, - Name: "q_test", - Bytes: 0x06400000, - Consumed: 0, - Over: true, + o := &nftables.NamedObj{ + Table: tr, + Name: "q_test", + Type: nftables.ObjTypeQuota, + Obj: &expr.Quota{ + Bytes: 0x06400000, + Consumed: 0, + Over: true, + }, } conn.AddObj(o) @@ -6440,9 +6894,10 @@ func TestDeleteQuotaObj(t *testing.T) { t.Fatalf("conn.Flush() failed: %v", err) } - obj, err := conn.GetObj(&nftables.QuotaObj{ - Table: table, + obj, err := conn.GetObj(&nftables.NamedObj{ + Table: tr, Name: "q_test", + Type: nftables.ObjTypeQuota, }) if err != nil { t.Fatalf("conn.GetObj() failed: %v", err) @@ -6452,7 +6907,18 @@ func TestDeleteQuotaObj(t *testing.T) { t.Fatalf("unexpected number of objects: got %d, want %d", got, want) } - if got, want := obj[0], o; !reflect.DeepEqual(got, want) { + o2, _ := o.Obj.(*expr.Quota) + want := &nftables.NamedObj{ + Table: tr, + Name: "q_test", + Type: nftables.ObjTypeQuota, + Obj: &expr.Quota{ + Bytes: o2.Bytes, + Consumed: o2.Consumed, + Over: o2.Over, + }, + } + if got, want := obj[0], want; !reflect.DeepEqual(got, want) { t.Errorf("got = %+v, want = %+v", got, want) } diff --git a/obj.go b/obj.go index c468a63..997c1c5 100644 --- a/obj.go +++ b/obj.go @@ -18,6 +18,9 @@ import ( "encoding/binary" "fmt" + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + "github.com/google/nftables/internal/parseexprfunc" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" ) @@ -27,13 +30,73 @@ var ( delObjHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ) ) +type ObjType uint32 + +// https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=be0bae0ad31b0adb506f96de083f52a2bd0d4fbf#n1612 +const ( + ObjTypeCounter ObjType = unix.NFT_OBJECT_COUNTER + ObjTypeQuota ObjType = unix.NFT_OBJECT_QUOTA + ObjTypeCtHelper ObjType = unix.NFT_OBJECT_CT_HELPER + ObjTypeLimit ObjType = unix.NFT_OBJECT_LIMIT + ObjTypeConnLimit ObjType = unix.NFT_OBJECT_CONNLIMIT + ObjTypeTunnel ObjType = unix.NFT_OBJECT_TUNNEL + ObjTypeCtTimeout ObjType = unix.NFT_OBJECT_CT_TIMEOUT + ObjTypeSecMark ObjType = unix.NFT_OBJECT_SECMARK + ObjTypeCtExpect ObjType = unix.NFT_OBJECT_CT_EXPECT + ObjTypeSynProxy ObjType = unix.NFT_OBJECT_SYNPROXY +) + +var objByObjTypeMagic = map[ObjType]string{ + ObjTypeCounter: "counter", + ObjTypeQuota: "quota", + ObjTypeLimit: "limit", + ObjTypeConnLimit: "connlimit", + ObjTypeCtHelper: "cthelper", // not implemented in expr + ObjTypeTunnel: "tunnel", // not implemented in expr + ObjTypeCtTimeout: "cttimeout", // not implemented in expr + ObjTypeSecMark: "secmark", // not implemented in expr + ObjTypeCtExpect: "ctexpect", // not implemented in expr + ObjTypeSynProxy: "synproxy", // not implemented in expr +} + // Obj represents a netfilter stateful object. See also // https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects type Obj interface { table() *Table family() TableFamily - unmarshal(*netlink.AttributeDecoder) error - marshal(data bool) ([]byte, error) + data() expr.Any + name() string + objType() ObjType +} + +// NamedObj represents nftables stateful object attributes +// Corresponds to netfilter nft_object_attributes as per +// https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=116e95aa7b6358c917de8c69f6f173874030b46b#n1626 +type NamedObj struct { + Table *Table + Name string + Type ObjType + Obj expr.Any +} + +func (o *NamedObj) table() *Table { + return o.Table +} + +func (o *NamedObj) family() TableFamily { + return o.Table.Family +} + +func (o *NamedObj) data() expr.Any { + return o.Obj +} + +func (o *NamedObj) name() string { + return o.Name +} + +func (o *NamedObj) objType() ObjType { + return o.Type } // AddObject adds the specified Obj. Alias of AddObj. @@ -46,18 +109,27 @@ func (cc *Conn) AddObject(o Obj) Obj { func (cc *Conn) AddObj(o Obj) Obj { cc.mu.Lock() defer cc.mu.Unlock() - data, err := o.marshal(true) + data, err := expr.MarshalExprData(byte(o.family()), o.data()) if err != nil { cc.setErr(err) return nil } + attrs := []netlink.Attribute{ + {Type: unix.NFTA_OBJ_TABLE, Data: []byte(o.table().Name + "\x00")}, + {Type: unix.NFTA_OBJ_NAME, Data: []byte(o.name() + "\x00")}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(o.objType()))}, + } + if len(data) > 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: data}) + } + cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, }, - Data: append(extraHeader(uint8(o.family()), 0), data...), + Data: append(extraHeader(uint8(o.family()), 0), cc.marshalAttr(attrs)...), }) return o } @@ -66,12 +138,12 @@ func (cc *Conn) AddObj(o Obj) Obj { func (cc *Conn) DeleteObject(o Obj) { cc.mu.Lock() defer cc.mu.Unlock() - data, err := o.marshal(false) - if err != nil { - cc.setErr(err) - return + attrs := []netlink.Attribute{ + {Type: unix.NFTA_OBJ_TABLE, Data: []byte(o.table().Name + "\x00")}, + {Type: unix.NFTA_OBJ_NAME, Data: []byte(o.name() + "\x00")}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(o.objType()))}, } - + data := cc.marshalAttr(attrs) data = append(data, cc.marshalAttr([]netlink.Attribute{{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA}})...) cc.messages = append(cc.messages, netlink.Message{ @@ -85,17 +157,26 @@ func (cc *Conn) DeleteObject(o Obj) { // GetObj is a legacy method that return all Obj that belongs // to the same table as the given one +// This function returns the same concrete type as passed, +// e.g. QuotaObj, CounterObj or NamedObj. Prefer using the more +// generic NamedObj over the legacy QuotaObj and CounterObj types. func (cc *Conn) GetObj(o Obj) ([]Obj, error) { - return cc.getObj(nil, o.table(), unix.NFT_MSG_GETOBJ) + return cc.getObjWithLegacyType(nil, o.table(), unix.NFT_MSG_GETOBJ, cc.useLegacyObjType(o)) } // GetObjReset is a legacy method that reset all Obj that belongs // the same table as the given one +// This function returns the same concrete type as passed, +// e.g. QuotaObj, CounterObj or NamedObj. Prefer using the more +// generic NamedObj over the legacy QuotaObj and CounterObj types. func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) { - return cc.getObj(nil, o.table(), unix.NFT_MSG_GETOBJ_RESET) + return cc.getObjWithLegacyType(nil, o.table(), unix.NFT_MSG_GETOBJ_RESET, cc.useLegacyObjType(o)) } // GetObject gets the specified Object +// This function returns the same concrete type as passed, +// e.g. QuotaObj, CounterObj or NamedObj. Prefer using the more +// generic NamedObj over the legacy QuotaObj and CounterObj types. func (cc *Conn) GetObject(o Obj) (Obj, error) { objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ) @@ -107,11 +188,16 @@ func (cc *Conn) GetObject(o Obj) (Obj, error) { } // GetObjects get all the Obj that belongs to the given table +// This function will always return legacy QuotaObj/CounterObj +// types for backwards compatibility func (cc *Conn) GetObjects(t *Table) ([]Obj, error) { return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ) } // ResetObject reset the given Obj +// This function returns the same concrete type as passed, +// e.g. QuotaObj, CounterObj or NamedObj. Prefer using the more +// generic NamedObj over the legacy QuotaObj and CounterObj types. func (cc *Conn) ResetObject(o Obj) (Obj, error) { objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ_RESET) @@ -123,11 +209,13 @@ func (cc *Conn) ResetObject(o Obj) (Obj, error) { } // ResetObjects reset all the Obj that belongs to the given table +// This function will always return legacy QuotaObj/CounterObj +// types for backwards compatibility func (cc *Conn) ResetObjects(t *Table) ([]Obj, error) { return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ_RESET) } -func objFromMsg(msg netlink.Message) (Obj, error) { +func objFromMsg(msg netlink.Message, returnLegacyType bool) (Obj, error) { if got, want1, want2 := msg.Header.Type, newObjHeaderType, delObjHeaderType; got != want1 && got != want2 { return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) } @@ -150,38 +238,30 @@ func objFromMsg(msg netlink.Message) (Obj, error) { case unix.NFTA_OBJ_TYPE: objectType = ad.Uint32() case unix.NFTA_OBJ_DATA: - switch objectType { - case unix.NFT_OBJECT_COUNTER: - o := CounterObj{ - Table: table, - Name: name, - } - - ad.Do(func(b []byte) error { - ad, err := netlink.NewAttributeDecoder(b) - if err != nil { - return err - } - ad.ByteOrder = binary.BigEndian - return o.unmarshal(ad) - }) - return &o, ad.Err() - case NFT_OBJECT_QUOTA: - o := QuotaObj{ - Table: table, - Name: name, - } - - ad.Do(func(b []byte) error { - ad, err := netlink.NewAttributeDecoder(b) - if err != nil { - return err - } - ad.ByteOrder = binary.BigEndian - return o.unmarshal(ad) - }) - return &o, ad.Err() + if returnLegacyType { + return objDataFromMsgLegacy(ad, table, name, objectType) + } + + o := NamedObj{ + Table: table, + Name: name, + Type: ObjType(objectType), + } + + objs, err := parseexprfunc.ParseExprBytesFromNameFunc(byte(o.family()), ad, objByObjTypeMagic[o.Type]) + if err != nil { + return nil, err + } + if len(objs) == 0 { + return nil, fmt.Errorf("objFromMsg: objs is empty for obj %v", o) } + exprs := make([]expr.Any, len(objs)) + for i := range exprs { + exprs[i] = objs[i].(expr.Any) + } + + o.Obj = exprs[0] + return &o, ad.Err() } } if err := ad.Err(); err != nil { @@ -190,7 +270,50 @@ func objFromMsg(msg netlink.Message) (Obj, error) { return nil, fmt.Errorf("malformed stateful object") } +func objDataFromMsgLegacy(ad *netlink.AttributeDecoder, table *Table, name string, objectType uint32) (Obj, error) { + switch objectType { + case unix.NFT_OBJECT_COUNTER: + o := CounterObj{ + Table: table, + Name: name, + } + + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + return o.unmarshal(ad) + }) + return &o, ad.Err() + case unix.NFT_OBJECT_QUOTA: + o := QuotaObj{ + Table: table, + Name: name, + } + + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + return o.unmarshal(ad) + }) + return &o, ad.Err() + } + if err := ad.Err(); err != nil { + return nil, err + } + return nil, fmt.Errorf("malformed stateful object") +} + func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { + return cc.getObjWithLegacyType(o, t, msgType, cc.useLegacyObjType(o)) +} + +func (cc *Conn) getObjWithLegacyType(o Obj, t *Table, msgType uint16, returnLegacyObjType bool) ([]Obj, error) { conn, closer, err := cc.netlinkConn() if err != nil { return nil, err @@ -201,7 +324,12 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { var flags netlink.HeaderFlags if o != nil { - data, err = o.marshal(false) + attrs := []netlink.Attribute{ + {Type: unix.NFTA_OBJ_TABLE, Data: []byte(o.table().Name + "\x00")}, + {Type: unix.NFTA_OBJ_NAME, Data: []byte(o.name() + "\x00")}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(o.objType()))}, + } + data = cc.marshalAttr(attrs) } else { flags = netlink.Dump data, err = netlink.MarshalAttributes([]netlink.Attribute{ @@ -230,7 +358,7 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { } var objs []Obj for _, msg := range reply { - o, err := objFromMsg(msg) + o, err := objFromMsg(msg, returnLegacyObjType) if err != nil { return nil, err } @@ -239,3 +367,14 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { return objs, nil } + +func (cc *Conn) useLegacyObjType(o Obj) bool { + useLegacyType := true + if o != nil { + switch o.(type) { + case *NamedObj: + useLegacyType = false + } + } + return useLegacyType +} diff --git a/quota.go b/quota.go index 71cb9bb..123c9da 100644 --- a/quota.go +++ b/quota.go @@ -15,16 +15,11 @@ package nftables import ( - "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" ) -const ( - NFTA_OBJ_USERDATA = 8 - NFT_OBJECT_QUOTA = 2 -) - type QuotaObj struct { Table *Table Name string @@ -47,30 +42,6 @@ func (q *QuotaObj) unmarshal(ad *netlink.AttributeDecoder) error { return nil } -func (q *QuotaObj) marshal(data bool) ([]byte, error) { - flags := uint32(0) - if q.Over { - flags = unix.NFT_QUOTA_F_INV - } - obj, err := netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_QUOTA_BYTES, Data: binaryutil.BigEndian.PutUint64(q.Bytes)}, - {Type: unix.NFTA_QUOTA_CONSUMED, Data: binaryutil.BigEndian.PutUint64(q.Consumed)}, - {Type: unix.NFTA_QUOTA_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}, - }) - if err != nil { - return nil, err - } - attrs := []netlink.Attribute{ - {Type: unix.NFTA_OBJ_TABLE, Data: []byte(q.Table.Name + "\x00")}, - {Type: unix.NFTA_OBJ_NAME, Data: []byte(q.Name + "\x00")}, - {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(NFT_OBJECT_QUOTA)}, - } - if data { - attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: obj}) - } - return netlink.MarshalAttributes(attrs) -} - func (q *QuotaObj) table() *Table { return q.Table } @@ -78,3 +49,19 @@ func (q *QuotaObj) table() *Table { func (q *QuotaObj) family() TableFamily { return q.Table.Family } + +func (q *QuotaObj) data() expr.Any { + return &expr.Quota{ + Bytes: q.Bytes, + Consumed: q.Consumed, + Over: q.Over, + } +} + +func (q *QuotaObj) name() string { + return q.Name +} + +func (q *QuotaObj) objType() ObjType { + return ObjTypeQuota +}