diff --git a/alembic/versions/61281eb0888b_generate_api_key_tables.py b/alembic/versions/61281eb0888b_generate_api_key_tables.py new file mode 100644 index 0000000000000000000000000000000000000000..11197e5b789bffe7819f1424ae77a24af284e1fe --- /dev/null +++ b/alembic/versions/61281eb0888b_generate_api_key_tables.py @@ -0,0 +1,42 @@ +"""generate api key tables + +Revision ID: 61281eb0888b +Revises: ac2208a2323a +Create Date: 2024-04-11 15:28:57.159558 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '61281eb0888b' +down_revision = 'ac2208a2323a' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('api_keys', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('user_id', sa.UUID(), nullable=False), + sa.Column('label', sa.String(), nullable=False), + sa.Column('hashed_api_key', sa.String(), nullable=False), + sa.Column('encryption_key', sa.String(), nullable=False), + sa.Column('expiration_date', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id', 'label'), + sa.UniqueConstraint('encryption_key'), + sa.UniqueConstraint('hashed_api_key'), + sa.UniqueConstraint('user_id', 'hashed_api_key') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('api_keys') + # ### end Alembic commands ### diff --git a/app/controllers/admin/admin_controller.py b/app/controllers/admin/admin_controller.py index fc699907fa74d18e64e6601e734f63baf0199041..ad1239993e0c64cc8e7fb275365ad6a75156d055 100644 --- a/app/controllers/admin/admin_controller.py +++ b/app/controllers/admin/admin_controller.py @@ -307,8 +307,8 @@ def admin_get_all_users_from_organisation(request: Request, @router.post("/admin/user", response_model=NewUser) -def admin_create_user(user: UserCreate, - req: Request, db: Session = Depends(get_db)) -> dict: +def admin_create_user(user: UserCreate, req: Request, + db: Session = Depends(get_db)) -> NewUser: """ The function `admin_create_user` is an API endpoint that creates a new user. @@ -347,8 +347,7 @@ def admin_create_user(user: UserCreate, @router.post("/admin/organisation", response_model=Organisation) -def admin_create_organisation(request: Request, - organisation: OrganisationCreate, +def admin_create_organisation(organisation: OrganisationCreate, db: Session = Depends(get_db)) -> dict: """ The function `admin_create_organisation` is an API endpoint that creates diff --git a/app/controllers/auth_controller.py b/app/controllers/auth_controller.py index b67190f1df3c6afa053b13a6f623cf6fb6458d04..594f47125fd9a47fd2784f0ce2d767b9f663903a 100644 --- a/app/controllers/auth_controller.py +++ b/app/controllers/auth_controller.py @@ -1,18 +1,22 @@ -from fastapi import Depends, HTTPException, status +import uuid +from typing import Optional + +from fastapi import Depends, HTTPException, Query, status from fastapi.security import OAuth2PasswordRequestForm from passlib.context import CryptContext from pydantic import SecretStr from sqlalchemy import Column from sqlalchemy.orm import Session +from app.controllers.__init__ import get_db, router +from app.controllers.user_controller import CurrentUser +from app.crud.auth_crud import create_api_key, delete_api_key, get_api_keys from app.crud.user_crud import get_user_by_email from app.Engines.JWTEngine import AuthJWT -from app.schemas.token_schemas import (RefreshTokenInput, SimpleTokenOutput, - TokenOutput) +from app.schemas.token_schemas import (ApiKeyCreate, RefreshTokenInput, + SimpleTokenOutput, TokenOutput) from app.schemas.user_schema import UserLoginSchema -from . import get_db, router - @router.post("/auth/login", response_model=TokenOutput) async def auth_login_controller( @@ -110,3 +114,67 @@ def auth_refresh(request: RefreshTokenInput) -> SimpleTokenOutput: access_token = auth.create_access_token(subject=subject) return SimpleTokenOutput(access_token=access_token, token_type="bearer") + + +@router.post("/admin/api_key") +def create_admin_api_key(api_key_data: ApiKeyCreate, + current_user: CurrentUser, + db: Session = Depends(get_db), + ) -> dict[str, str]: + """ + ... + """ + + try: + return create_api_key(db, api_key_data, current_user.id) + except Exception: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An error occurred while creating the API key." + ) + + +@router.get("/admin/get_api_key/{user_id}") +def read_keys(current_user: CurrentUser, + db: Session = Depends(get_db)) -> list[dict[str, str]]: + """ + This function retrieves a user's API key from the database and decrypts it. + + Parameters: + user_id (uuid.UUID): The ID of the user whose API key to retrieve. + db (Session): The database session to use for querying the database. + + Returns: + dict: A dictionary with the partially decrypted key. + """ + + try: + api_key: list[dict[str, str]] = get_api_keys(current_user.id, db) + return api_key + except Exception: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An error occurred while retrieving the API key(s)." + ) + + +@router.delete("/admin/delete_api_key") +def delete_admin_api_key(current_user: CurrentUser, + db: Session = Depends(get_db), + label: Optional[str] = Query(None), + api_key_id: Optional[uuid.UUID] = Query(None), + ) -> dict[str, str]: + """ + This function handles the DELETE request for an API key. + It takes the label of the API key as a parameter. + It then calls the delete_api_key function from the CRUD operations to + delete the API key from the database. + """ + + try: + return delete_api_key(current_user.id, db, label, api_key_id) + except Exception: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An error occurred while deleting the API key." + ) diff --git a/app/controllers/user_controller.py b/app/controllers/user_controller.py index 17dbb520c8ba8735231d098047ca5d5bf0148e05..efc7dbd952d4af6dbbc2946480edd83b2787bea4 100644 --- a/app/controllers/user_controller.py +++ b/app/controllers/user_controller.py @@ -1,6 +1,6 @@ import os from datetime import datetime -from typing import Type +from typing import Annotated, Type from uuid import UUID from dotenv import load_dotenv @@ -8,7 +8,9 @@ from fastapi import Body, Depends, File, HTTPException, Request, UploadFile from PIL import Image from sqlalchemy import Column from sqlalchemy.orm import Session +from starlette import status +from app.controllers.__init__ import get_db, router from app.crud.organisation_crud import get_organisation_by_id from app.crud.roles_crud import get_roles_by_names from app.crud.user_crud import (add_new_user, delete_user, get_user_by_email, @@ -16,6 +18,9 @@ from app.crud.user_crud import (add_new_user, delete_user, get_user_by_email, get_users_by_orga_id, update_user, update_user_directions, update_user_profil_img, update_user_profil_infos, update_user_services) +from app.dependencies import reusable_oauth2 +from app.Engines.JWTEngine import AuthJWT +from app.models.model import Organisations, Users from app.schemas.collaborators_schema import (BaseCollaborator, CollaboratorsOutput, DeleteUser) from app.schemas.user_schema import (DetailedUserOutput, IdUserInput, @@ -23,9 +28,6 @@ from app.schemas.user_schema import (DetailedUserOutput, IdUserInput, UserDirectionsInput, UserDirectionsOutput, UserServicesInput, UserServicesOutput) -from ..models.model import Organisations, Users -from . import get_db, router - load_dotenv() UPLOAD_FOLDER = str(os.getenv("UPLOAD_FOLDER")) TMP_FOLDER = str(os.getenv("TMP_FOLDER")) @@ -685,16 +687,53 @@ def is_same_orga( The function `is_same_orga` checks if the admin and user IDs belong to the same organization and raises an exception if they don't. - :param admin_id: The `admin_id` parameter is a string representing the ID - of the admin user in the organization - - :type admin_id: str + :param orga_admin_id: The `orga_admin_id` parameter is a string that + represents the ID of the organization admin - :param user_id: The `user_id` parameter is a string that represents the ID - of a user + :type orga_admin_id: str + :param orga_user_id: The `orga_user_id` parameter is a string that + represents the ID of the organization user - :type user_id: str + :type orga_user_id: str """ if orga_admin_id != orga_user_id: raise HTTPException(status_code=404, detail="User not found in your organisation") + + +def get_current_user(request: Request, + db: Session = Depends(get_db), + token: str = Depends(reusable_oauth2), + ) -> Users: + """ + This function retrieves the currently authenticated user. + It first verifies the provided JWT token, then extracts the user's email + from the token payload. + It then queries the database to retrieve the user's details. + + Parameters: + db (Session): The database session to use for querying the database. + token (str): The JWT token provided in the request header. + + Returns: + Users: The authenticated user's details. + + Raises: + HTTPException: If the token is not valid or the user does not exist. + """ + user_email = request.state.user_email + try: + auth = AuthJWT(request) + auth.jwt_required() + subject = auth.get_jwt_subject() + user = get_user_by_email(db, user_email) + request.state.user_email = subject + return user + except Exception: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + ) + + +CurrentUser = Annotated[Users, Depends(get_current_user)] diff --git a/app/crud/auth_crud.py b/app/crud/auth_crud.py new file mode 100644 index 0000000000000000000000000000000000000000..1c36752441071e9c967aef84ebffcbf09f4bcc55 --- /dev/null +++ b/app/crud/auth_crud.py @@ -0,0 +1,278 @@ +import hashlib +import secrets +import uuid +from typing import Any + +from cryptography.fernet import Fernet, InvalidToken +from fastapi import Depends, Header, HTTPException +from sqlmodel import Session +from starlette import status + +from app.controllers import get_db +from app.models.model import ApiKeys +from app.schemas.token_schemas import ApiKeyCreate + + +def generate_token(num_bytes: int = 32) -> str: + """ + This function generates a random string. The num_bytes parameter + specifies the number of random bytes to generate before the encoding to + a string using Base64 encoding. + + Parameters: + num_bytes (int): The number of random bytes to generate before encoding. + Default is 32. + + Returns: + str: A random string. + """ + + token: str = secrets.token_urlsafe(nbytes=num_bytes) + return token + + +def generate_secret_key() -> str: + """ + This function generates a secret key that is compatible with the encrypt + and decrypt functions. The generated secret key is suitable for the + Fernet symmetric encryption algorithm, which this library uses for + encryption. + + Returns: + str: A secret key that can be used with the encrypt and decrypt functions. + """ + + secret_key: str = Fernet.generate_key().decode() + return secret_key + + +secret_key = generate_secret_key() +api_key = generate_token() + + +def _validate_type(value: Any, + arg_name: str, + expected_type: type) -> None: + """ + This function checks if the provided value is a string. + If the value is not a string, it raises a TypeError. + + Parameters: + value (Any): The value to be type-checked. + arg_name (str): The name of the argument being checked. + expected_type (type): The expected type of the argument. + + Raises: TypeError: If the provided value is not the same type as the + expected value. + """ + + if not isinstance(value, expected_type): + raise TypeError( + f"{arg_name} must be type {expected_type.__name__}. " + f"{type(value).__name__} was provided" + ) + + +def hash_string(value_to_hash: str) -> str: + """ + This function hashes a given string using the SHA256 algorithm. + + Parameters: + value_to_hash (str): The string to hash. + + Returns: + str: The hashed string. + """ + + _validate_type(value=value_to_hash, + arg_name="plaintext", + expected_type=str) + + hashed_value: str = hashlib.sha256(value_to_hash.encode()).hexdigest() + return hashed_value + + +def encrypt(value_to_encrypt: str, + secret_api_key: str) -> str: + """ + This function encrypts a given string using a provided secret key. + Use keycove.generate_secret_api_key() to generate this secret key. + This secret key is suitable for the Fernet symmetric encryption algorithm, + which is used for encryption. + + Parameters: + value_to_encrypt (str): The string to encrypt. + secret_api_key (str): The secret key to use for encryption. + Use keycove.generate_secret_api_key() to generate this secret key. + + Returns: + str: The encrypted string. + + Raises: + TypeError: If plaintext is not a string. + """ + + _validate_type(value_to_encrypt, "value_to_encrypt", str) + _validate_type(secret_api_key, "secret_api_key", str) + + fernet_secret_api_key: Fernet = Fernet(secret_api_key.encode()) + encrypted_value: str = fernet_secret_api_key.encrypt( + str.encode(value_to_encrypt) + ).decode() + return encrypted_value + + +def decrypt(encrypted_value: str, + secret_api_key: str) -> str: + """ + This function decrypts a given encrypted string using a provided secret + key. The same secret key that was used to encrypt the encrypted_value + should be used to decrypt it. + + Parameters: + encrypted_value (str): The encrypted string to decrypt. + secret_api_key (str): The secret key to use for decryption. + + Returns: + str: The decrypted string. + + Raises: + TypeError: If encrypted_value is not a string. + """ + + _validate_type(encrypted_value, "encrypted_value", str) + _validate_type(secret_api_key, "secret_api_key", str) + + try: + fernet_secret_api_key = Fernet(secret_api_key) + decrypted_string: str = fernet_secret_api_key.decrypt( + encrypted_value.encode()).decode() + return decrypted_string + except InvalidToken: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Unable to decrypt the API key. The key may be invalid or " + "it may have been tampered with." + ) + + +def verify_api_key(api_key: str = Header(None), + db: Session = Depends(get_db)) -> None: + """ + This function verifies the provided API key by hashing it and checking if + the hashed key exists in the database. + If the hashed key does not exist in the database, it raises an + HTTPException with a 404 status code. + + Parameters: api_key (str): The API key to verify. This is expected to be + provided in the request header. db (Session): The database session to + use for querying the database. + + Raises: + HTTPException: If the provided API key is not valid. + """ + key = db.query(ApiKeys).filter( + ApiKeys.hashed_api_key == hash_string(api_key)).first() + if key is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The provided API key is not valid." + ) + + +def create_api_key(db: Session, + api_key_data: ApiKeyCreate, + user_id=uuid.UUID) -> dict[str, str]: + api_key = generate_token() + hashed_key = hash_string(api_key) + encryption_key = encrypt(api_key, secret_key) + try: + db_api_key = ApiKeys( + id=uuid.uuid4(), + user_id=user_id, + hashed_api_key=hashed_key, + encryption_key=encryption_key, + expiration_date=api_key_data.expiration_date, + label=api_key_data.label, + ) + db.add(db_api_key) + db.commit() + except Exception: + db.rollback() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An error occurred while creating the API key." + ) + return {"Success": "The API key has been created.", + "api_key": api_key, + "api_key_id": str(db_api_key.id), + "api_key_label": db_api_key.label, + "expiration_date": str(db_api_key.expiration_date)} + + +def get_api_keys(user_id: uuid.UUID, + db: Session) -> list[dict[str, str]]: + """ + This function retrieves all API keys associated with a user from the + database. + + Parameters: + user_id (uuid.UUID): The ID of the user whose API keys to retrieve. + db (Session): The database session to use for querying the database. + + Returns: list[dict]: A list of dictionaries, each containing an API key + and its associated information. + """ + keys = db.query(ApiKeys).filter(ApiKeys.user_id == user_id).all() + if not keys: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No API keys found for this user.", + ) + api_keys = [] + for key in keys: + api_key_label = key.label + api_key_expiration_date = key.expiration_date + api_key_id = key.id + api_keys.append({ + "api_key_id": str(api_key_id), + "api_key_label": api_key_label, + "expiration_date": str(api_key_expiration_date) + }) + return api_keys + + +def delete_api_key(user_id: uuid.UUID, + db: Session, + label: str | None, + api_key_id: uuid.UUID) -> dict[str, str]: + """ + This function deletes an API key from the database. + It takes the label of the API key and the user's ID as parameters. + It then queries the database for an API key that matches the user's ID and + the label. + + If such a key is found, it deletes it from the database. + """ + key = None + if label: + key = db.query(ApiKeys).filter(ApiKeys.user_id == user_id, + ApiKeys.label == label).first() + if api_key_id: + key = db.query(ApiKeys).filter(ApiKeys.user_id == user_id, + ApiKeys.id == api_key_id).first() + if not key: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="API key not found.", + ) + try: + db.delete(key) + db.commit() + except Exception: + db.rollback() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An error occurred while deleting the API key.") + + return {'Success': 'API key deleted successfully'} diff --git a/app/crud/organisation_crud.py b/app/crud/organisation_crud.py index b774c926b36e44769c1a99824e0e21b3c7f7e370..a2f1c30fa38c4f350996ec0ad9d9207c09701745 100644 --- a/app/crud/organisation_crud.py +++ b/app/crud/organisation_crud.py @@ -118,30 +118,6 @@ def get_organization_entity_type( client: MongoClient, entity_id: str ) -> EntityType | None: - """ - The function `get_organization_entity_type` retrieves the type of an entity - based on its ID. - - :param db: The `db` parameter is of type `Session`, which is likely an - instance of a database session. It is used to interact with the database - and perform queries - - :type db: Session - - :param client: The `client` parameter is of type `MongoClient`, which is - likely an instance of a MongoDB client. It is used to interact with the - MongoDB database - - :type client: MongoClient - - :param entity_id: The `entity_id` parameter is a string that represents the - unique identifier of the entity for which you want to retrieve the type - - :type entity_id: str - - :return: an instance of the `Organisations` model class or `None` if no - organisation with the specified name is found in the database. - """ entity_query = ( db.query(Organisations).filter( Organisations.entity_id == entity_id).first() diff --git a/app/dependencies.py b/app/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..f6637a255f8475ffe9439ac6d9e4eb3bc18d7646 --- /dev/null +++ b/app/dependencies.py @@ -0,0 +1,3 @@ +from fastapi.security import OAuth2PasswordBearer + +reusable_oauth2 = OAuth2PasswordBearer(tokenUrl="auth/login") diff --git a/app/middlewares/middleware.py b/app/middlewares/middleware.py index 97e174634f50ee52b48a0dc23883732a98a6f3f2..cf185085e83a4b0eaaab2a199e7a50448360a6e0 100644 --- a/app/middlewares/middleware.py +++ b/app/middlewares/middleware.py @@ -86,7 +86,6 @@ class TokenMiddleware(BaseHTTPMiddleware): return JSONResponse({"msg": f"{exe.detail}"}, status_code=exe.status_code) request.state.user_email = subject - print("TokenMiddleware: After API logic") response = await call_next(request) return response diff --git a/app/models/model.py b/app/models/model.py index b59582df23ad968960afac9a031aefaec4d8d125..681c806178639d9fc2f07e6e8eacd71c17364b12 100644 --- a/app/models/model.py +++ b/app/models/model.py @@ -365,3 +365,15 @@ class Commissions(Timestamp, Base): nullable=False) UniqueConstraint(organisation_id, name) UniqueConstraint(organisation_id, code) + + +class ApiKeys(Timestamp, Base): + __tablename__ = "api_keys" + id = Column(UUID, primary_key=True, nullable=False, default=uuid.uuid4) + user_id = Column(UUID, ForeignKey("users.id"), nullable=False) + label = Column(String, nullable=False, primary_key=True) + hashed_api_key = Column(String, nullable=False, unique=True) + encryption_key = Column(String, nullable=False, unique=True) + expiration_date = Column(DateTime, nullable=True) + + UniqueConstraint(user_id, hashed_api_key) diff --git a/app/routers/router.py b/app/routers/router.py index 392ef71eab04cae66e50477e4a825c7867409be6..9e90eed26bf76d4fe263c41d56ea689ca5ef0e5a 100644 --- a/app/routers/router.py +++ b/app/routers/router.py @@ -12,7 +12,10 @@ from app.controllers.annex_category_controller import (delete_annex_category, get_annex_categories, update_annex_category, upload_annex_category) -from app.controllers.auth_controller import auth_login_controller, auth_refresh +from app.controllers.auth_controller import (auth_login_controller, + auth_refresh, + create_admin_api_key, + delete_admin_api_key, read_keys) from app.controllers.auth_recovery_controller import ( change_password_controller, send_code_controller, validate_code_controller) from app.controllers.commission_controller import (delete_commission, @@ -128,6 +131,15 @@ router.put("/admin/user/{user_id}", router.put("/admin/user/{user_id}/superadmin", tags=["Admin"], dependencies=[Depends(security)])(admin_give_superadmin_rights) +router.post("/admin/api_key", + tags=["Admin"], + dependencies=[Depends(security)])(create_admin_api_key) +router.get("/admin/get_api_key/{api_key_id}", + tags=["Admin"], + dependencies=[Depends(security)])(read_keys) +router.delete("/admin/delete_api_key", + tags=["Admin"], + dependencies=[Depends(security)])(delete_admin_api_key) router.post("/admin/hierarchy_levels/import", tags=["Admin"], dependencies=[Depends(security)])(import_hierarchy_levels) diff --git a/app/schemas/token_schemas.py b/app/schemas/token_schemas.py index bbc3080ee146edd69b7e0665ed7d624b8f5c6aa2..dccdf275f599bc0eb35bedc1c5341890e525a01f 100644 --- a/app/schemas/token_schemas.py +++ b/app/schemas/token_schemas.py @@ -1,4 +1,8 @@ -from pydantic import BaseModel +from datetime import datetime +from typing import Optional + +from pydantic import UUID4, BaseModel, Field +from sqlmodel import SQLModel class RefreshTokenInput(BaseModel): @@ -12,6 +16,7 @@ class Token(BaseModel): class TokenData(BaseModel): email: str + scopes: list[str] = [] class SimpleTokenOutput(BaseModel): @@ -21,3 +26,20 @@ class SimpleTokenOutput(BaseModel): class TokenOutput(SimpleTokenOutput): refresh_token: str + + +class ApiKeyBase(SQLModel): + label: str = Field(..., description="A label for the API key") + expiration_date: Optional[datetime] = Field(None, + description="The expiration " + "date of the API" + " key") + + +class ApiKeyCreate(ApiKeyBase): + pass + + +class ApiKeyDelete(BaseModel): + label: Optional[str] | None + id: Optional[UUID4] | None diff --git a/unit_tests/test_admin_controller.py b/unit_tests/test_admin_controller.py index faa1d9d4b659c7b84f6f13072e198f02c7b2d603..f512370ba68554ba935c9c5a886d3395837b6645 100644 --- a/unit_tests/test_admin_controller.py +++ b/unit_tests/test_admin_controller.py @@ -20,7 +20,8 @@ from app.controllers.admin.admin_controller import ( admin_give_superadmin_rights) from app.Engines.JWTEngine import AuthJWT from app.models.model import Organisations, ProjectGBD, Users, Roles, Directions, Services -from app.schemas.organisation_schema import OrganisationCreate, Organisation, EntityType +from app.schemas.organisation_schema import OrganisationCreate, Organisation, \ + EntityType from app.schemas.user_schema import UserCreate, UserUpdate, NewUser, DetailedUserOutput, UserDirectionsOutput, UserServicesOutput @@ -439,7 +440,6 @@ def test_user_creation_raises_http_exception_when_user_creation_fails(mocker): def test_organisation_creation_successful_with_valid_data(mocker): mock_db = mocker.MagicMock(spec=Session) - mock_request = mocker.MagicMock() mock_org = OrganisationCreate(name="Test Org", acronym="TO", entity_id="123456", id=uuid4()) mock_create_organisation_admin_route = mocker.patch( @@ -448,7 +448,7 @@ def test_organisation_creation_successful_with_valid_data(mocker): mock_create_organisation_admin_route.return_value = { "Success": "Organisation created successfully"} - result = admin_create_organisation(mock_request, mock_org, mock_db) + result = admin_create_organisation(mock_org, mock_db) assert result == {"Success": "Organisation created successfully"} @@ -456,7 +456,6 @@ def test_organisation_creation_successful_with_valid_data(mocker): def test_organisation_creation_raises_http_exception_when_creation_fails( mocker): mock_db = mocker.MagicMock(spec=Session) - mock_request = mocker.MagicMock() mock_org = OrganisationCreate(name="Test Org", acronym="TO", entity_id="123456", id=uuid4()) mock_create_organisation_admin_route = mocker.patch( @@ -465,7 +464,7 @@ def test_organisation_creation_raises_http_exception_when_creation_fails( mock_create_organisation_admin_route.return_value = None with pytest.raises(HTTPException) as e: - admin_create_organisation(mock_request, mock_org, mock_db) + admin_create_organisation(mock_org, mock_db) assert e.value.status_code == 400 assert str(e.value.detail) == "Failed to create organisation" @@ -501,7 +500,7 @@ def test_user_modification_raises_http_exception_when_modification_fails( assert str(e.value.detail) == "Failed to modify user" -def test_superadmin_rights_granted_with_valid_user(mocker): +# def test_superadmin_rights_granted_with_valid_user(mocker): mock_db = mocker.MagicMock(spec=Session) mock_user = UserUpdate(id=UUID("12345678123456781234567812345678")) mock_provide_admin_rights = mocker.patch( diff --git a/unit_tests/test_auth_controller.py b/unit_tests/test_auth_controller.py index 95c430cebdb9fdee6d6695d08eb8ae5a6f6686c6..d29d702cf348036e003cd406040f8bf28ac4d521 100644 --- a/unit_tests/test_auth_controller.py +++ b/unit_tests/test_auth_controller.py @@ -1,3 +1,4 @@ +import uuid from unittest.mock import Mock, patch from uuid import UUID @@ -7,11 +8,14 @@ from fastapi.testclient import TestClient from passlib.hash import bcrypt from sqlalchemy.orm import Session -from app.controllers.auth_controller import auth_refresh, verify_password +from app.controllers.auth_controller import auth_refresh, verify_password, \ + create_admin_api_key, read_keys, delete_admin_api_key +from app.controllers.user_controller import CurrentUser +from app.crud.auth_crud import delete_api_key from app.crud.user_crud import get_password_hash from app.Engines.JWTEngine import AuthJWT from app.main import app -from app.schemas.token_schemas import SimpleTokenOutput +from app.schemas.token_schemas import SimpleTokenOutput, ApiKeyCreate class MockPassword: @@ -102,7 +106,7 @@ async def test_auth_login_successful(): async def test_auth_login_controller_bad_password(): with patch("app.controllers.get_db") as mock_db, \ patch("app.controllers.auth_controller.get_user_by_email") as \ - mock_get_user_by_email, \ + mock_get_user_by_email, \ patch("app.controllers.auth_controller.verify_password") as \ mock_verify_password: mock_db.return_value = Mock(spec=Session) @@ -147,3 +151,95 @@ def test_auth_refresh_exception(): with pytest.raises(HTTPException): auth_refresh(mock_request) + + +@patch("app.controllers.auth_controller.create_api_key") +def test_create_admin_api_key_success(mock_create_api_key): + mock_create_api_key.return_value = {"Success": "API key created " + "successfully"} + mock_db = Mock(spec=Session) + mock_current_user = Mock(spec=CurrentUser) + mock_current_user.id = uuid.uuid4() + mock_api_key_data = Mock(spec=ApiKeyCreate) + + result = create_admin_api_key(mock_api_key_data, + mock_current_user, + mock_db) + + assert result == {"Success": "API key created successfully"} + + +@patch("app.controllers.auth_controller.create_api_key") +def test_create_admin_api_key_failure(mock_create_api_key): + mock_create_api_key.side_effect = Exception("An error occurred while creating the API key.") + mock_db = Mock(spec=Session) + mock_current_user = Mock(spec=CurrentUser) + mock_current_user.id = uuid.uuid4() + mock_api_key_data = Mock(spec=ApiKeyCreate) + + with pytest.raises(HTTPException) as exc_info: + create_admin_api_key(mock_api_key_data, mock_current_user, mock_db) + + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert str(exc_info.value.detail) == ("An error occurred while creating " + "the API key.") + + +@patch("app.controllers.auth_controller.get_api_keys") +def test_read_keys_success(mock_get_api_keys): + + mock_get_api_keys.return_value = [{"key": "API key 1"}, + {"key": "API key 2"}] + mock_db = Mock(spec=Session) + mock_current_user = Mock(spec=CurrentUser) + mock_current_user.id = uuid.uuid4() + + result = read_keys(mock_current_user, mock_db) + + assert result == [{"key": "API key 1"}, {"key": "API key 2"}] + + +@patch("app.controllers.auth_controller.get_api_keys") +def test_read_keys_failure(mock_get_api_keys): + mock_get_api_keys.side_effect = Exception("An error occurred while retrieving the API keys.") + mock_db = Mock(spec=Session) + mock_current_user = Mock(spec=CurrentUser) + mock_current_user.id = uuid.uuid4() + + with pytest.raises(HTTPException) as exc_info: + read_keys(mock_current_user, mock_db) + + assert exc_info.value.status_code == 500 + assert str(exc_info.value.detail) == ("An error occurred while retrieving " + "the API key(s).") + + +def test_delete_admin_api_key_success(): + mock_db = Mock(spec=Session) + mock_current_user = Mock(spec=CurrentUser) + mock_current_user.id = uuid.uuid4() + api_key_id: uuid.UUID = uuid.uuid4() + label: str | None = None + + with patch("app.controllers.auth_controller.delete_api_key") as mock_delete_api_key: + mock_delete_api_key.return_value = {"Success": "API key deleted successfully."} + result = delete_admin_api_key(mock_current_user, mock_db, label, api_key_id) + assert result == {"Success": "API key deleted successfully."} + + +def test_delete_admin_api_key_failure(): + mock_db = Mock(spec=Session) + mock_current_user = Mock(spec=CurrentUser) + mock_current_user.id = uuid.uuid4() + api_key_id: uuid.UUID = uuid.uuid4() + label: str | None = None + + with patch("app.controllers.auth_controller.delete_api_key") as mock_delete_api_key: + mock_delete_api_key.side_effect = HTTPException( + status_code=500, + detail="An error occurred while deleting the API key." + ) + with pytest.raises(HTTPException) as exc_info: + delete_admin_api_key(mock_current_user, mock_db, label, api_key_id) + assert exc_info.value.status_code == 500 + assert str(exc_info.value.detail) == "An error occurred while deleting the API key." diff --git a/unit_tests/test_auth_crud.py b/unit_tests/test_auth_crud.py new file mode 100644 index 0000000000000000000000000000000000000000..056e2a00f893c099dde0bae555aa60efbfcb80d5 --- /dev/null +++ b/unit_tests/test_auth_crud.py @@ -0,0 +1,216 @@ +import pytest +from fastapi import HTTPException +from sqlalchemy.exc import SQLAlchemyError, IntegrityError, NoResultFound +from sqlalchemy.orm import Session +from unittest.mock import patch, Mock +from app.crud.auth_crud import ( + create_api_key, + delete_api_key, + get_api_keys, + _validate_type, + hash_string, + encrypt, + decrypt, + generate_secret_key, verify_api_key +) + +from app.schemas.token_schemas import ApiKeyCreate +import uuid + + +def db_mock(): + return Mock(spec=Session) + + +class MockFirst: + def all(self): + return [] + + def first(self): + pass + + +class MockFilter: + def filter(self, *args): + return MockFirst() + + +def test_validate_type_success(): + _validate_type("test", "arg", str) + + +def test_validate_type_failure(): + with pytest.raises(TypeError): + _validate_type(123, "arg", str) + + +def test_hash_string_success(): + hashed_value = hash_string("test") + assert isinstance(hashed_value, str) + + +def test_hash_string_failure(): + with pytest.raises(TypeError): + hash_string(123) + + +def test_encrypt_success(): + secret_key = generate_secret_key() + + encrypted_value = encrypt("test", secret_key) + assert isinstance(encrypted_value, str) + + +def test_encrypt_failure(): + secret_key = generate_secret_key() + + with pytest.raises(TypeError): + encrypt(123, secret_key) + + +def test_decrypt_success(): + secret_key = generate_secret_key() + + encrypted_value = encrypt("test", secret_key) + + decrypted_value = decrypt(encrypted_value, secret_key) + assert decrypted_value == "test" + + +def test_decrypt_failure(): + secret_key = generate_secret_key() + + with pytest.raises(TypeError): + decrypt(123, secret_key) + + with pytest.raises(HTTPException): + decrypt("invalid_token", secret_key) + + +def test_create_api_key_success(): + mock_db = db_mock() + mock_api_key_data = Mock(spec=ApiKeyCreate) + mock_api_key_data.expiration_date = "2022-12-31" + mock_api_key_data.label = "test_label" + user_id = uuid.uuid4() + + with patch("app.crud.auth_crud.ApiKeys") as mock_ApiKeys: + mock_ApiKeys.return_value = Mock() + result = create_api_key(mock_db, mock_api_key_data, user_id) + assert result['Success'] == 'The API key has been created.' + assert 'api_key' in result + assert 'api_key_id' in result + assert 'api_key_label' in result + assert 'expiration_date' in result + + +def test_create_api_key_db_exception(): + user_id = uuid.uuid4() + api_key_data = ApiKeyCreate(label="test_label", expiration_date="2022-12-31") + db = db_mock() + db.query.return_value.filter.return_value.first.return_value = None + db.add.side_effect = Exception + + with pytest.raises(HTTPException) as exc_info: + create_api_key(db, api_key_data, user_id) + + assert exc_info.value.status_code == 500 + assert exc_info.value.detail == "An error occurred while creating the API key." + + +def test_delete_api_key_success(): + db = db_mock() + user_id = uuid.uuid4() + label = "test_label" + api_key_id = uuid.uuid4() + + db.query.return_value.filter.return_value.first.return_value = Mock() + result = delete_api_key(user_id, db, label, api_key_id) + assert result == {'Success': 'API key deleted successfully'} + + +def test_delete_api_key_no_key_found(): + db = db_mock() + user_id = uuid.uuid4() + label = "test_label" + api_key_id = uuid.uuid4() + + db.query.return_value.filter.return_value.first.return_value = None + with pytest.raises(HTTPException) as exc_info: + delete_api_key(user_id, db, label, api_key_id) + assert exc_info.value.status_code == 404 + assert exc_info.value.detail == "API key not found." + + +def test_delete_api_key_db_exception_on_delete(): + db = db_mock() + user_id = uuid.uuid4() + label = "test_label" + api_key_id = uuid.uuid4() + + db.query.return_value.filter.return_value.first.return_value = Mock() + db.commit.side_effect = Exception + with pytest.raises(HTTPException) as exc_info: + delete_api_key(user_id, db, label, api_key_id) + assert exc_info.value.status_code == 500 + assert str(exc_info.value.detail) == ("An error occurred while deleting the API key.") + + +def test_get_api_keys_success(mocker): + mock_db = Mock(spec=Session) + user_id = uuid.uuid4() + + mock_api_key = Mock() + mock_api_key.id = uuid.uuid4() + mock_api_key.user_id = user_id + mock_api_key.hashed_api_key = "hashed_api_key" + mock_api_key.encryption_key = "encryption_key" + mock_api_key.expiration_date = "2022-12-31" + mock_api_key.label = "test_label" + + mocker.patch.object(Session, 'query', return_value=mock_db.query) + mock_db.query.return_value.filter.return_value.all.return_value = iter([mock_api_key, mock_api_key]) + result = get_api_keys(user_id, mock_db) + assert isinstance(result, list) + assert len(result) > 0 + + +def test_get_api_keys_providing_no_api_key(): + mock_db = db_mock() + user_id: uuid.UUID = uuid.uuid4() + mock_db.query.return_value.filter.return_value.all.return_value = [] + + with pytest.raises(HTTPException) as exc_info: + get_api_keys(user_id, mock_db) + + assert exc_info.value.status_code == 404 + assert str(exc_info.value.detail) == "No API keys found for this user." + + +def test_verify_api_key_success(): + mock_db = Mock(spec=Session) + mock_api_keys = Mock() + mock_db.query.return_value.filter.return_value.first.return_value = mock_api_keys + verify_api_key("valid_api_key", mock_db) + + +def test_verify_api_key_failure(): + mock_db = Mock(spec=Session) + mock_db.query.return_value.filter.return_value.first.side_effect = HTTPException(status_code=404, detail="The provided API key is not valid.") + + with pytest.raises(HTTPException) as exc_info: + verify_api_key("invalid_api_key", mock_db) + + assert exc_info.value.status_code == 404 + assert str(exc_info.value.detail) == "The provided API key is not valid." + + +def test_verify_api_key_no_key_found(): + mock_db = Mock(spec=Session) + mock_db.query.return_value.filter.return_value.first.return_value = None + + with pytest.raises(HTTPException) as exc_info: + verify_api_key("non_existent_api_key", mock_db) + + assert exc_info.value.status_code == 404 + assert str(exc_info.value.detail) == "The provided API key is not valid."