Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

iprules.sh into netcontrol #35

Draft
wants to merge 4 commits into
base: netcontrol
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions netcontrol/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from .nft import Nft
from .arp import Arp
from .vpnrules import VpnRules

logger = logging.getLogger('uvicorn.error')
# for some reason, default loggers are not working with FastAPI
Expand All @@ -19,19 +20,21 @@ async def lifespan(app: FastAPI):
The part before the yield is executed before the app starts;
The part after the yield is executed after the app stops.
"""

nft.setup_portail()

VpnRules(logger)

yield


VpnRules(logger, "del")
nft.remove_portail()

app = FastAPI(lifespan=lifespan)

@app.get("/")
def root():
return "netcontrol is running"

@app.post("/connect_user")
def connect_user(mac: str, mark: int, name: str):
return nft.connect_user(mac, mark, name)
Expand All @@ -50,4 +53,4 @@ def get_mac(ip: str):

@app.get("/get_ip")
def get_ip(mac: str):
return arp.get_ip(mac)
return arp.get_ip(mac)
31 changes: 16 additions & 15 deletions netcontrol/nft.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fastapi import HTTPException

variables = Variables()
import re

class Nft:
"""
Expand All @@ -19,7 +20,7 @@ def check_nftables(self):
data = self._execute_nft_cmd("list ruleset")
metainfo = data[0]["metainfo"]
self.logger.info(f"Found running nftables version {metainfo['version']} with {len(data)} ruleset entries.")

def _execute_nft_cmd(self, cmd: str) -> dict:
"""
Executes an nft command, handles the exception properly and returns an object
Expand All @@ -46,26 +47,26 @@ def setup_portail(self):
"""
Sets up the necessary nftables rules that block network access to unauthenticated devices, and marks packets based on the map
"""

# Set up table, set and map
self._execute_nft_cmd("add table ip insalan")
self._execute_nft_cmd("add set insalan netcontrol-auth { type ether_addr; }")
self._execute_nft_cmd("add map insalan netcontrol-mac2mark { type ether_addr : mark; }")

# Marks packets from authenticated users using the map
self._execute_nft_cmd("add chain insalan netcontrol-filter { type filter hook prerouting priority 0; }")
self._execute_nft_cmd("add rule insalan netcontrol-filter ip daddr != 172.16.1.0/24 ether saddr @netcontrol-auth meta mark set ether saddr map @netcontrol-mac2mark")

# Allow traffic to port 80 from unauthenticated devices and redirect it to the network head, to allow access to the langate webpage
self._execute_nft_cmd("add chain insalan netcontrol-nat { type nat hook prerouting priority 0; }")
self._execute_nft_cmd("add rule insalan netcontrol-nat ip daddr != 172.16.1.0/24 ether saddr != @netcontrol-auth tcp dport 80 redirect to :80")

# Block other traffic from users that are not authenticated
self._execute_nft_cmd("add chain insalan netcontrol-forward { type filter hook forward priority 0; }")
self._execute_nft_cmd(f"add rule insalan netcontrol-forward ip daddr != 172.16.1.1 ip saddr {variables.ip_range()} ether saddr != @netcontrol-auth reject")

self.logger.info("Gate nftables set up")

def remove_portail(self):
"""
Removes netcontrol-related chains, sets and maps from insalan table
Expand All @@ -75,55 +76,55 @@ def remove_portail(self):
self._execute_nft_cmd("delete chain insalan netcontrol-forward")
self._execute_nft_cmd("delete set insalan netcontrol-auth")
self._execute_nft_cmd("delete map insalan netcontrol-mac2mark")

self.logger.info("Gate nftables removed")

def set_mark(self, mac: str, mark: int):
"""
Changes mark of the given MAC address

Args:
mac (str): MAC address
mark (int): mark to set
"""

self.delete_user(mac)
self.connect_user(mac, mark, "previously_connected_device")

def connect_user(self, mac: str, mark: int, name: str):
"""
Connects given device with given mark

Args:
mac (str): MAC address
"""

mac = mac.lower()
try:
self._execute_nft_cmd(f"add element insalan netcontrol-mac2mark {{ {mac} : {str(mark)} }}")
self._execute_nft_cmd(f"add element insalan netcontrol-auth {{ {mac} }}")
except NftablesException:
self.logger.error(f"Tried to add device {mac} (name: {name}), unexpected nftables error occurred")
raise HTTPException(status_code=500, detail="Unexpected nftables error occurred")

self.logger.info(f"Device {mac} (name: {name}) connected with mark {mark}")

def delete_user(self, mac: str) -> None:
"""
Disconnects given device

Args:
mac (str): MAC address
"""

mac = mac.lower()
try:
self._execute_nft_cmd(f"delete element insalan netcontrol-mac2mark {{ {mac} }}")
self._execute_nft_cmd(f"delete element insalan netcontrol-auth {{ {mac} }}")
except NftablesException:
self.logger.error(f"Tried to delete device {mac} which was not previously connected")
raise HTTPException(status_code=404, detail="Device was not previously connected")

self.logger.info(f"Device {mac} disconnected")

class NftablesException(Exception):
Expand Down
45 changes: 45 additions & 0 deletions netcontrol/vpnrules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import re
import os
import logging

class VpnRules:
"""Class which manages the VPN rules

Args:
logger (logging.Logger): logger instance
command (str): command to execute (add or del), add by default
"""
def __init__(self, logger: logging.Logger, command: str="add") -> None:
self.logger = logger
self.command = command
if command not in ["add", "del"]:
raise ValueError("Invalid command")
self.vpns = self.get_vpns()
self.modify_rules()

@staticmethod
def get_vpns() -> list:
"""Get the list of vpns from /etc/hosts

Returns:
list: list of vpn names
"""
vpns = []
with open("/etc/hosts", "r") as f:
for line in f:
if match := re.search(r"(vpn\d+)$", line):
vpns.append(match.group())
return vpns

def modify_rules(self):
"""Add or remove the VPN rules
"""
for i, vpn in enumerate(self.vpns):
try:
status = os.waitstatus_to_exitcode(os.system(f"ip rule {self.command} fwmark {i + 100} table {vpn}"))
if status == 0:
self.logger.info(f"Successfully {self.command} rule for {vpn}")
else:
self.logger.error(f"Failed to {self.command} rule for {vpn}")
except ValueError as e:
self.logger.error(f"Failed to {self.command} rule for {vpn}: {e}")
Loading