Files
n3wt-school/Back-End/GestionMessagerie/middleware.py
2025-05-29 15:09:22 +02:00

109 lines
4.0 KiB
Python

import jwt
import logging
from urllib.parse import parse_qs
from django.conf import settings
from django.contrib.auth.models import AnonymousUser
from channels.middleware import BaseMiddleware
from channels.db import database_sync_to_async
from Auth.models import Profile
logger = logging.getLogger(__name__)
@database_sync_to_async
def get_user(user_id):
"""Récupérer l'utilisateur de manière asynchrone"""
try:
return Profile.objects.get(id=user_id)
except Profile.DoesNotExist:
return AnonymousUser()
class JWTAuthMiddleware(BaseMiddleware):
"""Middleware pour l'authentification JWT dans les WebSockets"""
def __init__(self, inner):
super().__init__(inner)
def _check_cors_origin(self, scope):
"""Vérifier si l'origine est autorisée pour les WebSockets"""
origin = None
# Récupérer l'origine depuis les headers
for name, value in scope.get('headers', []):
if name == b'origin':
origin = value.decode('latin1')
break
if not origin:
logger.warning("Aucune origine trouvée dans les headers WebSocket")
return False
# Récupérer les origines autorisées depuis la configuration CORS
allowed_origins = getattr(settings, 'CORS_ALLOWED_ORIGINS', [])
# Si CORS_ORIGIN_ALLOW_ALL est True, autoriser toutes les origines
if getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False):
logger.info(f"Origine WebSocket autorisée (CORS_ORIGIN_ALLOW_ALL): {origin}")
return True
# Vérifier si l'origine est dans la liste des origines autorisées
if origin in allowed_origins:
logger.info(f"Origine WebSocket autorisée: {origin}")
return True
logger.warning(f"Origine WebSocket non autorisée: {origin}. Origines autorisées: {allowed_origins}")
return False
async def __call__(self, scope, receive, send):
# Vérifier les CORS pour les WebSockets
if not self._check_cors_origin(scope):
logger.error("Connexion WebSocket refusée: origine non autorisée")
# Fermer la connexion WebSocket avec un code d'erreur
await send({
'type': 'websocket.close',
'code': 1008 # Policy Violation
})
return
# Extraire le token de l'URL
query_string = parse_qs(scope['query_string'].decode())
token = query_string.get('token')
if token:
token = token[0]
try:
# Décoder le token JWT
payload = jwt.decode(
token,
settings.SIMPLE_JWT['SIGNING_KEY'],
algorithms=[settings.SIMPLE_JWT['ALGORITHM']]
)
# Vérifier que c'est un token d'accès
if payload.get('type') != 'access':
logger.warning(f"Token type invalide: {payload.get('type')}")
scope['user'] = AnonymousUser()
else:
# Récupérer l'utilisateur
user_id = payload.get('user_id')
user = await get_user(user_id)
scope['user'] = user
logger.info(f"Utilisateur authentifié via JWT: {user.email if hasattr(user, 'email') else 'Unknown'}")
except jwt.ExpiredSignatureError:
logger.warning("Token JWT expiré")
scope['user'] = AnonymousUser()
except jwt.InvalidTokenError as e:
logger.warning(f"Token JWT invalide: {str(e)}")
scope['user'] = AnonymousUser()
except Exception as e:
logger.error(f"Erreur lors de l'authentification JWT: {str(e)}")
scope['user'] = AnonymousUser()
else:
scope['user'] = AnonymousUser()
return await super().__call__(scope, receive, send)
def JWTAuthMiddlewareStack(inner):
"""Stack middleware pour l'authentification JWT"""
return JWTAuthMiddleware(inner)