Skip to content

Commit

Permalink
Add CIDR trie
Browse files Browse the repository at this point in the history
  • Loading branch information
hmgle committed Dec 6, 2023
1 parent 191dac2 commit 5cf97c1
Show file tree
Hide file tree
Showing 4 changed files with 296 additions and 17 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ TARGET = graftcp $(GRAFTCP_LOCAL_BIN)
all:: $(TARGET)


graftcp: main.o graftcp.o util.o string-set.o conf.o
graftcp: main.o graftcp.o util.o string-set.o cidr-trie.o conf.o
$(CC) $^ -o $@

libgraftcp.a: graftcp.o util.o string-set.o conf.o
libgraftcp.a: graftcp.o util.o string-set.o cidr-trie.o conf.o
$(AR) rcs $@ $^

%.o: %.c
Expand Down
223 changes: 223 additions & 0 deletions cidr-trie.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
/*
* graftcp
* Copyright (C) 2023 Hmgle <[email protected]>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*/

/* Inspired by nginx: ngx_radix_tree.c */
#include <stdlib.h>
#include <string.h>
#include <arpa/inet.h>

#include "cidr-trie.h"

static trie_node_t *node_callc()
{
trie_node_t *node = calloc(1, sizeof(*node));
node->left = NULL;
node->right = NULL;
node->value = TRIE_NO_VALUE;
return node;
}

trie_t *trie_new()
{
trie_t *trie = calloc(1, sizeof(*trie));
trie->root = node_callc();
return trie;
}

void trie32_insert(trie_t *trie, struct cidr_s *cidr, int value)
{
trie_node_t *node, *next;
uint32_t bit = 0x80000000;

next = trie->root;
for (node = trie->root; bit & cidr->mask; bit >>= 1) {
next = cidr->addr & bit ? node->right : node->left;
if (next == NULL)
break;
node = next;
}
if (next) {
node->value = value;
return;
}
for (; bit & cidr->mask; bit >>= 1) {
next = node_callc();
if (cidr->addr & bit)
node->right = next;
else
node->left = next;
node = next;
}
node->value = value;
}

#define IPV4_MAX_TEXT_LENGTH 15
#define IPV6_MAX_TEXT_LENGTH 45

static int parse_cidr(const char *line, struct cidr_s *cidr)
{
char *p;
int shift;
char ipbuf[IPV4_MAX_TEXT_LENGTH + 1];

p = strchr(line, '/');
if (p) {
if (p - line > IPV4_MAX_TEXT_LENGTH)
return -1;
strncpy(ipbuf, line, p - line);
ipbuf[p - line] = '\0';
cidr->addr = ntohl(inet_addr(ipbuf));

shift = strtol(++p, NULL, 0);
if (shift < 0 || shift > 32)
return -1;
cidr->mask = shift ? (uint32_t)(0xffffffff << (32 - shift)) : 0;
} else {
cidr->mask = 0xffffffff;
cidr->addr = ntohl(inet_addr(line));
}
return 0;
}

int trie32_insert_str(trie_t *trie, const char *ipstr, int value)
{
struct cidr_s cidr;

if (parse_cidr(ipstr, &cidr))
return -1;
trie32_insert(trie, &cidr, value);
return 0;
}

int trie32_lookup(trie_t *trie, uint32_t ip)
{
uint32_t bit = 0x80000000;
trie_node_t *node;

for (node = trie->root; node;) {
if (node->value != TRIE_NO_VALUE)
return node->value;
node = ip & bit ? node->right : node->left;
bit >>= 1;
}
return TRIE_NO_VALUE;
}

void trie128_insert(trie_t *trie, struct cidr6_s *cidr6, int value)
{
trie_node_t *node, *next;
uint8_t bit = 0x80;
uint i = 0;

next = trie->root;
for (node = trie->root; bit & cidr6->mask.s6_addr[i];) {
next = bit & cidr6->addr.s6_addr[i] ? node->right : node->left;
if (next == NULL)
break;
bit >>= 1;
node = next;
if (bit == 0) {
if (++i == 16)
break;
bit = 0x80;
}
}
if (next) {
node->value = value;
return;
}
for (; bit & cidr6->mask.s6_addr[i];) {
next = node_callc();
if (bit & cidr6->addr.s6_addr[i])
node->right = next;
else
node->left = next;
bit >>= 1;
node = next;
if (bit == 0) {
if (++i == 16)
break;
bit = 0x80;
}
}
node->value = value;
}

static int parse_cidr6(const char *line, struct cidr6_s *cidr6)
{
char *p;
int shift;
char ip6buf[IPV6_MAX_TEXT_LENGTH + 1];
uint8_t *mask;
uint i, s;

p = strchr(line, '/');
if (p) {
if (p - line > IPV6_MAX_TEXT_LENGTH)
return -1;
strncpy(ip6buf, line, p - line);
ip6buf[p - line] = '\0';
if (inet_pton(AF_INET6, ip6buf, &cidr6->addr) != 1)
return -1;

shift = strtol(++p, NULL, 0);
if (shift < 0 || shift > 128)
return -1;
if (shift) {
mask = cidr6->mask.s6_addr;
for (i = 0; i < 16; i++) {
s = (shift > 8) ? 8 : shift;
shift -= s;
mask[i] = (u_char) (0xffu << (8 - s));
}
} else {
memset(cidr6->mask.s6_addr, 0, 16);
}
} else {
if (inet_pton(AF_INET6, line, &cidr6->addr) != 1)
return -1;
memset(cidr6->mask.s6_addr, 0xff, 16);
}
return 0;
}

int trie128_insert_str(trie_t *trie, const char *ipstr, int value)
{
struct cidr6_s cidr6;

if (parse_cidr6(ipstr, &cidr6))
return -1;
trie128_insert(trie, &cidr6, value);
return 0;
}

int trie128_lookup(trie_t *trie, uint8_t *ip)
{
trie_node_t *node;
uint8_t bit = 0x80;
uint i = 0;

for (node = trie->root; node;) {
if (node->value != TRIE_NO_VALUE)
return node->value;
node = bit & ip[i] ? node->right : node->left;
bit >>= 1;
if (bit == 0) {
i++;
bit = 0x80;
}
}
return TRIE_NO_VALUE;
}
55 changes: 55 additions & 0 deletions cidr-trie.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* graftcp
* Copyright (C) 2023 Hmgle <[email protected]>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*/

/* Inspired by nginx: ngx_radix_tree.h */
#ifndef CIDR_TRIE_H
#define CIDR_TRIE_H

#include <netinet/in.h>
#include <stdint.h>

#define TRIE_NO_VALUE 0

typedef struct trie_node_s trie_node_t;

struct trie_node_s {
trie_node_t *left;
trie_node_t *right;
int value;
};

typedef struct {
trie_node_t *root;
} trie_t;

struct cidr_s {
uint32_t addr;
uint32_t mask;
};

struct cidr6_s {
struct in6_addr addr;
struct in6_addr mask;
};

trie_t *trie_new();
void trie32_insert(trie_t *trie, struct cidr_s *cidr, int value);
int trie32_insert_str(trie_t *trie, const char *ipstr, int value);
int trie32_lookup(trie_t *trie, uint32_t ip);
void trie128_insert(trie_t *trie, struct cidr6_s *cidr6, int value);
int trie128_insert_str(trie_t *trie, const char *ipstr, int value);
int trie128_lookup(trie_t *trie, uint8_t *ip);

#endif
31 changes: 16 additions & 15 deletions graftcp.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

#include "graftcp.h"
#include "conf.h"
#include "string-set.h"
#include "cidr-trie.h"

#ifndef VERSION
#define VERSION "v0.6"
Expand All @@ -43,12 +43,12 @@ char *DEFAULT_LOCAL_PIPE_PAHT = "/tmp/graftcplocal.fifo";
bool DEFAULT_IGNORE_LOCAL = true;
int LOCAL_PIPE_FD;

struct str_set *BLACKLIST_IP = NULL;
struct str_set *WHITELACKLIST_IP = NULL;
trie_t *BLACKLIST_IP = NULL;
trie_t *WHITELACKLIST_IP = NULL;

static int exit_code = 0;

static void load_ip_file(char *path, struct str_set **set)
static void load_ip_file(char *path, trie_t **trie)
{
FILE *f;
char *line = NULL;
Expand All @@ -60,14 +60,14 @@ static void load_ip_file(char *path, struct str_set **set)
perror("fopen");
exit(1);
}
if (*set == NULL)
*set = str_set_new();
if (*trie == NULL)
*trie = trie_new();
while ((read = getline(&line, &len, f)) != -1) {
/* 7 is the shortest ip: (x.x.x.x) */
if (read < 7)
continue;
line[read - 1] = '\0';
str_set_put(*set, line);
trie32_insert_str(*trie, line, 1);
line = NULL;
}
fclose(f);
Expand All @@ -83,14 +83,14 @@ static void load_whiteip_file(char *path)
load_ip_file(path, &WHITELACKLIST_IP);
}

static bool is_ignore(const char *ip)
static bool is_ignore(uint32_t ip)
{
if (BLACKLIST_IP) {
if (is_str_set_member(BLACKLIST_IP, ip))
if (trie32_lookup(BLACKLIST_IP, ip))
return true;
}
if (WHITELACKLIST_IP) {
if (!is_str_set_member(WHITELACKLIST_IP, ip))
if (!trie32_lookup(WHITELACKLIST_IP, ip))
return true;
}
return false;
Expand Down Expand Up @@ -197,16 +197,17 @@ void connect_pre_handle(struct proc_info *pinfp)
dest_ip_port = SOCKPORT(dest_sa);
dest_ip_addr.s_addr = SOCKADDR(dest_sa);
dest_ip_addr_str = inet_ntoa(dest_ip_addr);
if (is_ignore(dest_ip_addr.s_addr))
return;
} else if (dest_sa.sin_family == AF_INET6) { /* IPv6 */
getdata(pinfp->pid, addr, (char *)&dest_sa6, sizeof(dest_sa6));
dest_ip_port = SOCKPORT6(dest_sa6);
inet_ntop(AF_INET6, &dest_sa6.sin6_addr, dest_str, INET6_ADDRSTRLEN);
dest_ip_addr_str = dest_str;
// TODO: is_ignore128()
} else {
return;
}
if (is_ignore(dest_ip_addr_str))
return;

if (dest_sa.sin_family == AF_INET) /* IPv4 */
putdata(pinfp->pid, addr, (char *)&PROXY_SA, sizeof(PROXY_SA));
Expand Down Expand Up @@ -563,9 +564,9 @@ int client_main(int argc, char **argv)
load_whiteip_file(conf.whiteip_file_path);
if (*conf.ignore_local) {
if (BLACKLIST_IP == NULL)
BLACKLIST_IP = str_set_new();
str_set_put(BLACKLIST_IP, conf.local_addr);
str_set_put(BLACKLIST_IP, LOCAL_DEFAULT_ADDR);
BLACKLIST_IP = trie_new();
trie32_insert_str(BLACKLIST_IP, conf.local_addr, 1);
trie32_insert_str(BLACKLIST_IP, LOCAL_DEFAULT_ADDR, 1);
}
PROXY_SA.sin_family = AF_INET;
PROXY_SA.sin_port = htons(*conf.local_port);
Expand Down

0 comments on commit 5cf97c1

Please sign in to comment.