diff --git a/netcontrol/main.py b/netcontrol/main.py index 8bf7952..d732a22 100644 --- a/netcontrol/main.py +++ b/netcontrol/main.py @@ -3,6 +3,7 @@ import logging from .nft import Nft from .arp import Arp +from .vpnrules import VpnRules logger = logging.getLogger('uvicorn.error') # for some reason, default loggers are not working with FastAPI @@ -19,11 +20,13 @@ async def lifespan(app: FastAPI): The part before the yield is executed before the app starts; The part after the yield is executed after the app stops. """ - + nft.setup_portail() - + VpnRules(logger) + yield - + + VpnRules(logger, "del") nft.remove_portail() app = FastAPI(lifespan=lifespan) @@ -31,7 +34,7 @@ async def lifespan(app: FastAPI): @app.get("/") def root(): return "netcontrol is running" - + @app.post("/connect_user") def connect_user(mac: str, mark: int, name: str): return nft.connect_user(mac, mark, name) @@ -50,4 +53,4 @@ def get_mac(ip: str): @app.get("/get_ip") def get_ip(mac: str): - return arp.get_ip(mac) \ No newline at end of file + return arp.get_ip(mac) diff --git a/netcontrol/nft.py b/netcontrol/nft.py index 56e3ace..48139b6 100644 --- a/netcontrol/nft.py +++ b/netcontrol/nft.py @@ -5,6 +5,7 @@ from fastapi import HTTPException variables = Variables() +import re class Nft: """ @@ -19,7 +20,7 @@ def check_nftables(self): data = self._execute_nft_cmd("list ruleset") metainfo = data[0]["metainfo"] self.logger.info(f"Found running nftables version {metainfo['version']} with {len(data)} ruleset entries.") - + def _execute_nft_cmd(self, cmd: str) -> dict: """ Executes an nft command, handles the exception properly and returns an object @@ -46,26 +47,26 @@ def setup_portail(self): """ Sets up the necessary nftables rules that block network access to unauthenticated devices, and marks packets based on the map """ - + # Set up table, set and map self._execute_nft_cmd("add table ip insalan") self._execute_nft_cmd("add set insalan netcontrol-auth { type ether_addr; }") self._execute_nft_cmd("add map insalan netcontrol-mac2mark { type ether_addr : mark; }") - + # Marks packets from authenticated users using the map self._execute_nft_cmd("add chain insalan netcontrol-filter { type filter hook prerouting priority 0; }") self._execute_nft_cmd("add rule insalan netcontrol-filter ip daddr != 172.16.1.0/24 ether saddr @netcontrol-auth meta mark set ether saddr map @netcontrol-mac2mark") - + # Allow traffic to port 80 from unauthenticated devices and redirect it to the network head, to allow access to the langate webpage self._execute_nft_cmd("add chain insalan netcontrol-nat { type nat hook prerouting priority 0; }") self._execute_nft_cmd("add rule insalan netcontrol-nat ip daddr != 172.16.1.0/24 ether saddr != @netcontrol-auth tcp dport 80 redirect to :80") - + # Block other traffic from users that are not authenticated self._execute_nft_cmd("add chain insalan netcontrol-forward { type filter hook forward priority 0; }") self._execute_nft_cmd(f"add rule insalan netcontrol-forward ip daddr != 172.16.1.1 ip saddr {variables.ip_range()} ether saddr != @netcontrol-auth reject") self.logger.info("Gate nftables set up") - + def remove_portail(self): """ Removes netcontrol-related chains, sets and maps from insalan table @@ -75,29 +76,29 @@ def remove_portail(self): self._execute_nft_cmd("delete chain insalan netcontrol-forward") self._execute_nft_cmd("delete set insalan netcontrol-auth") self._execute_nft_cmd("delete map insalan netcontrol-mac2mark") - + self.logger.info("Gate nftables removed") def set_mark(self, mac: str, mark: int): """ Changes mark of the given MAC address - + Args: mac (str): MAC address mark (int): mark to set """ - + self.delete_user(mac) self.connect_user(mac, mark, "previously_connected_device") def connect_user(self, mac: str, mark: int, name: str): """ Connects given device with given mark - + Args: mac (str): MAC address """ - + mac = mac.lower() try: self._execute_nft_cmd(f"add element insalan netcontrol-mac2mark {{ {mac} : {str(mark)} }}") @@ -105,17 +106,17 @@ def connect_user(self, mac: str, mark: int, name: str): except NftablesException: self.logger.error(f"Tried to add device {mac} (name: {name}), unexpected nftables error occurred") raise HTTPException(status_code=500, detail="Unexpected nftables error occurred") - + self.logger.info(f"Device {mac} (name: {name}) connected with mark {mark}") def delete_user(self, mac: str) -> None: """ Disconnects given device - + Args: mac (str): MAC address """ - + mac = mac.lower() try: self._execute_nft_cmd(f"delete element insalan netcontrol-mac2mark {{ {mac} }}") @@ -123,7 +124,7 @@ def delete_user(self, mac: str) -> None: except NftablesException: self.logger.error(f"Tried to delete device {mac} which was not previously connected") raise HTTPException(status_code=404, detail="Device was not previously connected") - + self.logger.info(f"Device {mac} disconnected") class NftablesException(Exception): diff --git a/netcontrol/vpnrules.py b/netcontrol/vpnrules.py new file mode 100644 index 0000000..f3626e5 --- /dev/null +++ b/netcontrol/vpnrules.py @@ -0,0 +1,45 @@ +import re +import os +import logging + +class VpnRules: + """Class which manages the VPN rules + + Args: + logger (logging.Logger): logger instance + command (str): command to execute (add or del), add by default + """ + def __init__(self, logger: logging.Logger, command: str="add") -> None: + self.logger = logger + self.command = command + if command not in ["add", "del"]: + raise ValueError("Invalid command") + self.vpns = self.get_vpns() + self.modify_rules() + + @staticmethod + def get_vpns() -> list: + """Get the list of vpns from /etc/hosts + + Returns: + list: list of vpn names + """ + vpns = [] + with open("/etc/hosts", "r") as f: + for line in f: + if match := re.search(r"(vpn\d+)$", line): + vpns.append(match.group()) + return vpns + + def modify_rules(self): + """Add or remove the VPN rules + """ + for i, vpn in enumerate(self.vpns): + try: + status = os.waitstatus_to_exitcode(os.system(f"ip rule {self.command} fwmark {i + 100} table {vpn}")) + if status == 0: + self.logger.info(f"Successfully {self.command} rule for {vpn}") + else: + self.logger.error(f"Failed to {self.command} rule for {vpn}") + except ValueError as e: + self.logger.error(f"Failed to {self.command} rule for {vpn}: {e}")