diff --git a/Makefile b/Makefile index 9ae4ae2..775b2ed 100644 --- a/Makefile +++ b/Makefile @@ -58,10 +58,10 @@ TARGET = graftcp $(GRAFTCP_LOCAL_BIN) all:: $(TARGET) -graftcp: main.o graftcp.o util.o string-set.o conf.o +graftcp: main.o graftcp.o util.o string-set.o cidr-trie.o conf.o $(CC) $^ -o $@ -libgraftcp.a: graftcp.o util.o string-set.o conf.o +libgraftcp.a: graftcp.o util.o string-set.o cidr-trie.o conf.o $(AR) rcs $@ $^ %.o: %.c diff --git a/cidr-trie.c b/cidr-trie.c new file mode 100644 index 0000000..bb3d2bd --- /dev/null +++ b/cidr-trie.c @@ -0,0 +1,223 @@ +/* + * graftcp + * Copyright (C) 2023 Hmgle + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + */ + +/* Inspired by nginx: ngx_radix_tree.c */ +#include +#include +#include + +#include "cidr-trie.h" + +static trie_node_t *node_callc() +{ + trie_node_t *node = calloc(1, sizeof(*node)); + node->left = NULL; + node->right = NULL; + node->value = TRIE_NO_VALUE; + return node; +} + +trie_t *trie_new() +{ + trie_t *trie = calloc(1, sizeof(*trie)); + trie->root = node_callc(); + return trie; +} + +void trie32_insert(trie_t *trie, struct cidr_s *cidr, int value) +{ + trie_node_t *node, *next; + uint32_t bit = 0x80000000; + + next = trie->root; + for (node = trie->root; bit & cidr->mask; bit >>= 1) { + next = cidr->addr & bit ? node->right : node->left; + if (next == NULL) + break; + node = next; + } + if (next) { + node->value = value; + return; + } + for (; bit & cidr->mask; bit >>= 1) { + next = node_callc(); + if (cidr->addr & bit) + node->right = next; + else + node->left = next; + node = next; + } + node->value = value; +} + +#define IPV4_MAX_TEXT_LENGTH 15 +#define IPV6_MAX_TEXT_LENGTH 45 + +static int parse_cidr(const char *line, struct cidr_s *cidr) +{ + char *p; + int shift; + char ipbuf[IPV4_MAX_TEXT_LENGTH + 1]; + + p = strchr(line, '/'); + if (p) { + if (p - line > IPV4_MAX_TEXT_LENGTH) + return -1; + strncpy(ipbuf, line, p - line); + ipbuf[p - line] = '\0'; + cidr->addr = ntohl(inet_addr(ipbuf)); + + shift = strtol(++p, NULL, 0); + if (shift < 0 || shift > 32) + return -1; + cidr->mask = shift ? (uint32_t)(0xffffffff << (32 - shift)) : 0; + } else { + cidr->mask = 0xffffffff; + cidr->addr = ntohl(inet_addr(line)); + } + return 0; +} + +int trie32_insert_str(trie_t *trie, const char *ipstr, int value) +{ + struct cidr_s cidr; + + if (parse_cidr(ipstr, &cidr)) + return -1; + trie32_insert(trie, &cidr, value); + return 0; +} + +int trie32_lookup(trie_t *trie, uint32_t ip) +{ + uint32_t bit = 0x80000000; + trie_node_t *node; + + for (node = trie->root; node;) { + if (node->value != TRIE_NO_VALUE) + return node->value; + node = ip & bit ? node->right : node->left; + bit >>= 1; + } + return TRIE_NO_VALUE; +} + +void trie128_insert(trie_t *trie, struct cidr6_s *cidr6, int value) +{ + trie_node_t *node, *next; + uint8_t bit = 0x80; + uint i = 0; + + next = trie->root; + for (node = trie->root; bit & cidr6->mask.s6_addr[i];) { + next = bit & cidr6->addr.s6_addr[i] ? node->right : node->left; + if (next == NULL) + break; + bit >>= 1; + node = next; + if (bit == 0) { + if (++i == 16) + break; + bit = 0x80; + } + } + if (next) { + node->value = value; + return; + } + for (; bit & cidr6->mask.s6_addr[i];) { + next = node_callc(); + if (bit & cidr6->addr.s6_addr[i]) + node->right = next; + else + node->left = next; + bit >>= 1; + node = next; + if (bit == 0) { + if (++i == 16) + break; + bit = 0x80; + } + } + node->value = value; +} + +static int parse_cidr6(const char *line, struct cidr6_s *cidr6) +{ + char *p; + int shift; + char ip6buf[IPV6_MAX_TEXT_LENGTH + 1]; + uint8_t *mask; + uint i, s; + + p = strchr(line, '/'); + if (p) { + if (p - line > IPV6_MAX_TEXT_LENGTH) + return -1; + strncpy(ip6buf, line, p - line); + ip6buf[p - line] = '\0'; + if (inet_pton(AF_INET6, ip6buf, &cidr6->addr) != 1) + return -1; + + shift = strtol(++p, NULL, 0); + if (shift < 0 || shift > 128) + return -1; + if (shift) { + mask = cidr6->mask.s6_addr; + for (i = 0; i < 16; i++) { + s = (shift > 8) ? 8 : shift; + shift -= s; + mask[i] = (u_char) (0xffu << (8 - s)); + } + } else { + memset(cidr6->mask.s6_addr, 0, 16); + } + } else { + if (inet_pton(AF_INET6, line, &cidr6->addr) != 1) + return -1; + memset(cidr6->mask.s6_addr, 0xff, 16); + } + return 0; +} + +int trie128_insert_str(trie_t *trie, const char *ipstr, int value) +{ + struct cidr6_s cidr6; + + if (parse_cidr6(ipstr, &cidr6)) + return -1; + trie128_insert(trie, &cidr6, value); + return 0; +} + +int trie128_lookup(trie_t *trie, uint8_t *ip) +{ + trie_node_t *node; + uint8_t bit = 0x80; + uint i = 0; + + for (node = trie->root; node;) { + if (node->value != TRIE_NO_VALUE) + return node->value; + node = bit & ip[i] ? node->right : node->left; + bit >>= 1; + if (bit == 0) { + i++; + bit = 0x80; + } + } + return TRIE_NO_VALUE; +} diff --git a/cidr-trie.h b/cidr-trie.h new file mode 100644 index 0000000..87ae6ec --- /dev/null +++ b/cidr-trie.h @@ -0,0 +1,55 @@ +/* + * graftcp + * Copyright (C) 2023 Hmgle + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + */ + +/* Inspired by nginx: ngx_radix_tree.h */ +#ifndef CIDR_TRIE_H +#define CIDR_TRIE_H + +#include +#include + +#define TRIE_NO_VALUE 0 + +typedef struct trie_node_s trie_node_t; + +struct trie_node_s { + trie_node_t *left; + trie_node_t *right; + int value; +}; + +typedef struct { + trie_node_t *root; +} trie_t; + +struct cidr_s { + uint32_t addr; + uint32_t mask; +}; + +struct cidr6_s { + struct in6_addr addr; + struct in6_addr mask; +}; + +trie_t *trie_new(); +void trie32_insert(trie_t *trie, struct cidr_s *cidr, int value); +int trie32_insert_str(trie_t *trie, const char *ipstr, int value); +int trie32_lookup(trie_t *trie, uint32_t ip); +void trie128_insert(trie_t *trie, struct cidr6_s *cidr6, int value); +int trie128_insert_str(trie_t *trie, const char *ipstr, int value); +int trie128_lookup(trie_t *trie, uint8_t *ip); + +#endif diff --git a/graftcp.c b/graftcp.c index 0bdb793..e86b046 100644 --- a/graftcp.c +++ b/graftcp.c @@ -27,7 +27,7 @@ #include "graftcp.h" #include "conf.h" -#include "string-set.h" +#include "cidr-trie.h" #ifndef VERSION #define VERSION "v0.6" @@ -43,12 +43,12 @@ char *DEFAULT_LOCAL_PIPE_PAHT = "/tmp/graftcplocal.fifo"; bool DEFAULT_IGNORE_LOCAL = true; int LOCAL_PIPE_FD; -struct str_set *BLACKLIST_IP = NULL; -struct str_set *WHITELACKLIST_IP = NULL; +trie_t *BLACKLIST_IP = NULL; +trie_t *WHITELACKLIST_IP = NULL; static int exit_code = 0; -static void load_ip_file(char *path, struct str_set **set) +static void load_ip_file(char *path, trie_t **trie) { FILE *f; char *line = NULL; @@ -60,14 +60,14 @@ static void load_ip_file(char *path, struct str_set **set) perror("fopen"); exit(1); } - if (*set == NULL) - *set = str_set_new(); + if (*trie == NULL) + *trie = trie_new(); while ((read = getline(&line, &len, f)) != -1) { /* 7 is the shortest ip: (x.x.x.x) */ if (read < 7) continue; line[read - 1] = '\0'; - str_set_put(*set, line); + trie32_insert_str(*trie, line, 1); line = NULL; } fclose(f); @@ -83,14 +83,14 @@ static void load_whiteip_file(char *path) load_ip_file(path, &WHITELACKLIST_IP); } -static bool is_ignore(const char *ip) +static bool is_ignore(uint32_t ip) { if (BLACKLIST_IP) { - if (is_str_set_member(BLACKLIST_IP, ip)) + if (trie32_lookup(BLACKLIST_IP, ip)) return true; } if (WHITELACKLIST_IP) { - if (!is_str_set_member(WHITELACKLIST_IP, ip)) + if (!trie32_lookup(WHITELACKLIST_IP, ip)) return true; } return false; @@ -197,16 +197,17 @@ void connect_pre_handle(struct proc_info *pinfp) dest_ip_port = SOCKPORT(dest_sa); dest_ip_addr.s_addr = SOCKADDR(dest_sa); dest_ip_addr_str = inet_ntoa(dest_ip_addr); + if (is_ignore(dest_ip_addr.s_addr)) + return; } else if (dest_sa.sin_family == AF_INET6) { /* IPv6 */ getdata(pinfp->pid, addr, (char *)&dest_sa6, sizeof(dest_sa6)); dest_ip_port = SOCKPORT6(dest_sa6); inet_ntop(AF_INET6, &dest_sa6.sin6_addr, dest_str, INET6_ADDRSTRLEN); dest_ip_addr_str = dest_str; + // TODO: is_ignore128() } else { return; } - if (is_ignore(dest_ip_addr_str)) - return; if (dest_sa.sin_family == AF_INET) /* IPv4 */ putdata(pinfp->pid, addr, (char *)&PROXY_SA, sizeof(PROXY_SA)); @@ -563,9 +564,9 @@ int client_main(int argc, char **argv) load_whiteip_file(conf.whiteip_file_path); if (*conf.ignore_local) { if (BLACKLIST_IP == NULL) - BLACKLIST_IP = str_set_new(); - str_set_put(BLACKLIST_IP, conf.local_addr); - str_set_put(BLACKLIST_IP, LOCAL_DEFAULT_ADDR); + BLACKLIST_IP = trie_new(); + trie32_insert_str(BLACKLIST_IP, conf.local_addr, 1); + trie32_insert_str(BLACKLIST_IP, LOCAL_DEFAULT_ADDR, 1); } PROXY_SA.sin_family = AF_INET; PROXY_SA.sin_port = htons(*conf.local_port);