Skip to content

Commit

Permalink
Added: signature management code (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
signebedi committed Mar 26, 2024
1 parent b58905d commit 007239b
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 70 deletions.
149 changes: 79 additions & 70 deletions libreforms_fastapi/app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ def write_api_call_to_transaction_log(api_key, endpoint, remote_addr=None, query
# Truncate to avoid unpredictable behavior
query_params = query_params[:max_length]

# logger.info(api_key, endpoint, remote_addr, query_params)

with SessionLocal() as session:
user = session.query(User).filter_by(api_key=api_key).first()
if user:
Expand All @@ -324,58 +326,58 @@ def write_api_call_to_transaction_log(api_key, endpoint, remote_addr=None, query
session.rollback()


if config.DEBUG:
### This is a dummy route to validate jinja2 templates
@app.get("/debug/items/{id}", response_class=HTMLResponse, include_in_schema=False)
async def read_item(request: Request, id: str):
return templates.TemplateResponse(
request=request,
name="item.html",
context={"id": id}
)

### These are dummy routes to validate the sqlalchemy-signing library in development
@app.get("/debug/create", include_in_schema=False)
async def create_key():
key = signatures.write_key(expiration=.5, scope="api_key")
return {"key": key}

@app.get("/debug/get", include_in_schema=False)
async def get_key_details(key: str):
key_details = signatures.get_key(key)
return {"key": key_details}

@app.get("/debug/verify", include_in_schema=False)
async def verify_key_details(key: str = Depends(X_API_KEY)):

try:
verify = signatures.verify_key(key, scope=[])

except RateLimitExceeded:
raise HTTPException(
status_code=429,
detail="Rate limit exceeded"
)

except KeyDoesNotExist:
raise HTTPException(
status_code=401,
detail="Invalid API key"
)

except ScopeMismatch:
raise HTTPException(
status_code=401,
detail="Invalid API key"
)

except KeyExpired:
raise HTTPException(
status_code=401,
detail="API key expired"
)

return {"valid": verify}
# if config.DEBUG:
# ### This is a dummy route to validate jinja2 templates
# @app.get("/debug/items/{id}", response_class=HTMLResponse, include_in_schema=False)
# async def read_item(request: Request, id: str):
# return templates.TemplateResponse(
# request=request,
# name="item.html",
# context={"id": id}
# )

# ### These are dummy routes to validate the sqlalchemy-signing library in development
# @app.get("/debug/create", include_in_schema=False)
# async def create_key():
# key = signatures.write_key(expiration=.5, scope="api_key")
# return {"key": key}

# @app.get("/debug/get", include_in_schema=False)
# async def get_key_details(key: str):
# key_details = signatures.get_key(key)
# return {"key": key_details}

# @app.get("/debug/verify", include_in_schema=False)
# async def verify_key_details(key: str = Depends(X_API_KEY)):

# try:
# verify = signatures.verify_key(key, scope=[])

# except RateLimitExceeded:
# raise HTTPException(
# status_code=429,
# detail="Rate limit exceeded"
# )

# except KeyDoesNotExist:
# raise HTTPException(
# status_code=401,
# detail="Invalid API key"
# )

# except ScopeMismatch:
# raise HTTPException(
# status_code=401,
# detail="Invalid API key"
# )

# except KeyExpired:
# raise HTTPException(
# status_code=401,
# detail="API key expired"
# )

# return {"valid": verify}


##########################
Expand Down Expand Up @@ -1004,13 +1006,18 @@ async def api_form_search_all(

# Sign form
# This is a metadata-only field. It should not impact the data, just the metadata - namely, to afix
# a digital signature to the form.
# @app.patch("/api/form/sign/{form_name}/{document_id}")
# async def api_form_sign():
# a digital signature to the form. See https://github.com/signebedi/libreforms-fastapi/issues/59.
@app.patch("/api/form/sign/{form_name}/{document_id}")
async def api_form_sign():

# The underlying principle is that the user can only sign their own form. The question is what
# part of the application decides: the API, or the document database?

pass

# Approve form
# This is a metadata-only field. It should not impact the data, just the metadata - namely, to afix
# an approval - in the format of a digital signature - to the form.
# an approval - in the format of a digital signature - to the form.
# @app.patch("/api/form/approve/{form_name}/{document_id}")
# async def api_form_approve():

Expand Down Expand Up @@ -1107,16 +1114,6 @@ async def api_auth_create(
"message": f"Successfully created new user {user_request.username}"
}

# Change user password / usermod
@app.patch("/api/auth/update")
async def api_auth_update(
user_request: CreateUserRequest,
background_tasks: BackgroundTasks,
request: Request,
session: SessionLocal = Depends(get_db)
):
pass

# Get User / id
@app.get("/api/auth/get/{id}", dependencies=[Depends(api_key_auth)])
async def api_auth_get(
Expand Down Expand Up @@ -1166,6 +1163,22 @@ async def api_auth_get(

return profile_data


# Update User - change user password / usermod
# @app.patch("/api/auth/update")
# async def api_auth_update(
# user_request: CreateUserRequest,
# background_tasks: BackgroundTasks,
# request: Request,
# session: SessionLocal = Depends(get_db)
# ):
# pass

# Rotate user API key
# @app.patch("/api/auth/rotate_key")
# async def api_auth_rotate_key():


# Request Password Reset - Forgot Password
# @app.patch("/api/auth/forgot_password")
# async def api_auth_forgot_password(user_request: CreateUserRequest, session: SessionLocal = Depends(get_db)):
Expand All @@ -1174,10 +1187,6 @@ async def api_auth_get(
# @app.patch("/api/auth/forgot_password/{single_use_token}")
# async def api_auth_forgot_password_confirm(user_request: CreateUserRequest, session: SessionLocal = Depends(get_db)):

# Rotate user API key
# @app.patch("/api/auth/rotate_key")
# async def api_auth_rotate_key():

##########################
### API Routes - Validators
##########################
Expand Down
159 changes: 159 additions & 0 deletions libreforms_fastapi/utils/certificates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""
# Example usage:
user_id = 'user123'
ds_manager = DigitalSignatureManager(user_id)
ds_manager.generate_rsa_key_pair()
data_to_sign = b"Important document content."
signature = ds_manager.sign_data(data_to_sign)
print("Signature:", signature)
verification_result = ds_manager.verify_signature(data_to_sign, signature)
print("Verification:", verification_result)
record_to_sign = {"data": {"text_input": "Sample text"}, "metadata": {"_signature": None}}
signed_record = sign_record(record_to_sign, ds_manager)
print("Signed Record:", signed_record)
verification_result = verify_record_signature(signed_record, ds_manager)
print("Verification Result:", verification_result)
"""
import os, json, copy
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives import serialization
from cryptography.exceptions import InvalidSignature

class DigitalSignatureManager:
def __init__(self, user_id, env="development", key_storage_path=os.path.join('instance', 'keys')):
self.user_id = user_id
self.env = env
self.key_storage_path = key_storage_path
self.ensure_key_storage()

def ensure_key_storage(self):
if not os.path.exists(self.key_storage_path):
os.makedirs(self.key_storage_path)

def get_private_key_file(self):
return os.path.join(self.key_storage_path, f"{self.env}_{self.user_id}_private.key")

def get_public_key_file(self):
return os.path.join(self.key_storage_path, f"{self.env}_{self.user_id}_public.key")

def generate_rsa_key_pair(self):
"""
Generates an RSA key pair and saves them to files.
"""
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend()
)
public_key = private_key.public_key()

# Save the private key
with open(self.get_private_key_file(), "wb") as private_file:
private_file.write(
private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
)
)

# Save the public key
with open(self.get_public_key_file(), "wb") as public_file:
public_file.write(
public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
)

def sign_data(self, data):
"""
Signs data using the private key.
"""
with open(self.get_private_key_file(), "rb") as key_file:
private_key = serialization.load_pem_private_key(
key_file.read(),
password=None,
backend=default_backend()
)

signature = private_key.sign(
data,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
return signature

def verify_signature(self, data, signature):
"""
Verifies the signature of the data using the public key.
"""
with open(self.get_public_key_file(), "rb") as key_file:
public_key = serialization.load_pem_public_key(
key_file.read(),
backend=default_backend()
)

try:
public_key.verify(
signature,
data,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
return True
except Exception as e:
return False


def serialize_record_for_signing(record):
"""
Serializes the record in a consistent, deterministic manner for signing.
Excludes the '_signature' field from the serialization.
"""
record_copy = dict(copy.deepcopy(record)) # Make a copy to avoid modifying the original
# The logic of selecting only the data field is that, while metadata is subject (and really
# expected) to change, eg. through the form approval process, we expect the data to remain
# the same.
select_data_fields = record_copy['data']
print(select_data_fields)
return json.dumps(select_data_fields, sort_keys=True)

def sign_record(record, ds_manager):
"""
Generates a signature for the given record and inserts it into the '_signature' field.
"""
serialized = serialize_record_for_signing(record)
signature = ds_manager.sign_data(serialized.encode())
record['metadata']['_signature'] = signature.hex() # Store the signature as a hex string
return record

def verify_record_signature(record, ds_manager):
"""
Verifies the signature of the given record.
Returns True if the signature is valid, False otherwise.
"""
if '_signature' not in record['metadata'] or record['metadata']['_signature'] is None:
return False # No signature to verify

record_copy = copy.deepcopy(record)
signature_bytes = bytes.fromhex(record['metadata']['_signature'])
serialized = serialize_record_for_signing(record_copy)

try:
return ds_manager.verify_signature(serialized.encode(), signature_bytes)
except InvalidSignature:
return False

1 change: 1 addition & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
cryptography<42.0.0
email_validator<3.0.0
fastapi<1.0.0
fuzzywuzzy<1.0.0
Expand Down
1 change: 1 addition & 0 deletions requirements/latest.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
cryptography
email_validator
fastapi
fuzzywuzzy
Expand Down

0 comments on commit 007239b

Please sign in to comment.