import jwt import logging from urllib.parse import parse_qs from django.conf import settings from django.contrib.auth.models import AnonymousUser from channels.middleware import BaseMiddleware from channels.db import database_sync_to_async from Auth.models import Profile logger = logging.getLogger(__name__) @database_sync_to_async def get_user(user_id): """Récupérer l'utilisateur de manière asynchrone""" try: return Profile.objects.get(id=user_id) except Profile.DoesNotExist: return AnonymousUser() class JWTAuthMiddleware(BaseMiddleware): """Middleware pour l'authentification JWT dans les WebSockets""" def __init__(self, inner): super().__init__(inner) def _check_cors_origin(self, scope): """Vérifier si l'origine est autorisée pour les WebSockets""" origin = None # Récupérer l'origine depuis les headers for name, value in scope.get('headers', []): if name == b'origin': origin = value.decode('latin1') break if not origin: logger.warning("Aucune origine trouvée dans les headers WebSocket") return False # Récupérer les origines autorisées depuis la configuration CORS allowed_origins = getattr(settings, 'CORS_ALLOWED_ORIGINS', []) # Si CORS_ORIGIN_ALLOW_ALL est True, autoriser toutes les origines if getattr(settings, 'CORS_ORIGIN_ALLOW_ALL', False): logger.info(f"Origine WebSocket autorisée (CORS_ORIGIN_ALLOW_ALL): {origin}") return True # Vérifier si l'origine est dans la liste des origines autorisées if origin in allowed_origins: logger.info(f"Origine WebSocket autorisée: {origin}") return True logger.warning(f"Origine WebSocket non autorisée: {origin}. Origines autorisées: {allowed_origins}") return False async def __call__(self, scope, receive, send): # Vérifier les CORS pour les WebSockets if not self._check_cors_origin(scope): logger.error("Connexion WebSocket refusée: origine non autorisée") # Fermer la connexion WebSocket avec un code d'erreur await send({ 'type': 'websocket.close', 'code': 1008 # Policy Violation }) return # Extraire le token de l'URL query_string = parse_qs(scope['query_string'].decode()) token = query_string.get('token') if token: token = token[0] try: # Décoder le token JWT payload = jwt.decode( token, settings.SIMPLE_JWT['SIGNING_KEY'], algorithms=[settings.SIMPLE_JWT['ALGORITHM']] ) # Vérifier que c'est un token d'accès if payload.get('type') != 'access': logger.warning(f"Token type invalide: {payload.get('type')}") scope['user'] = AnonymousUser() else: # Récupérer l'utilisateur user_id = payload.get('user_id') user = await get_user(user_id) scope['user'] = user logger.info(f"Utilisateur authentifié via JWT: {user.email if hasattr(user, 'email') else 'Unknown'}") except jwt.ExpiredSignatureError: logger.warning("Token JWT expiré") scope['user'] = AnonymousUser() except jwt.InvalidTokenError as e: logger.warning(f"Token JWT invalide: {str(e)}") scope['user'] = AnonymousUser() except Exception as e: logger.error(f"Erreur lors de l'authentification JWT: {str(e)}") scope['user'] = AnonymousUser() else: scope['user'] = AnonymousUser() return await super().__call__(scope, receive, send) def JWTAuthMiddlewareStack(inner): """Stack middleware pour l'authentification JWT""" return JWTAuthMiddleware(inner)