A single global chat channel is fun, but real-world chat applications typically feature multiple “rooms” or “channels” where users can have separate conversations. This chapter will modify our ConnectionManager and WebSocket endpoint to support room-based messaging.

Purpose of this Chapter

By the end of this chapter, you will:

  • Modify the ConnectionManager to manage connections per room.
  • Update the WebSocket endpoint to allow clients to specify a chat room.
  • Implement broadcasting messages only to users within the same room.
  • Create an endpoint to list available rooms.

Concepts Explained: Room-Based Messaging

Instead of a flat list of all active WebSocket connections, we’ll use a dictionary where keys are room names (e.g., “general”, “python-dev”, “random”) and values are lists of WebSocket objects for users in that specific room.

This structure lets us broadcast a message only to the connections in a particular room. Clients no longer receive messages from conversations they haven't joined.
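Here is the routing idea on its own, without FastAPI. FakeSocket, join and broadcast are hypothetical stand-ins written for illustration. The only assumption is that a connection exposes an async send_text() method, as FastAPI's WebSocket does.

```python
import asyncio
from typing import Dict, List


class FakeSocket:
    """Stand-in for a WebSocket that records the messages it is sent (illustration only)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.received: List[str] = []

    async def send_text(self, message: str) -> None:
        self.received.append(message)


# Room name -> sockets currently in that room
rooms: Dict[str, List[FakeSocket]] = {}


def join(room: str, sock: FakeSocket) -> None:
    # setdefault creates the room's list the first time someone joins it
    rooms.setdefault(room, []).append(sock)


async def broadcast(room: str, message: str) -> None:
    # Only sockets registered under `room` are visited; other rooms are untouched
    for sock in rooms.get(room, []):
        await sock.send_text(message)


alice, bob, carol = FakeSocket("alice"), FakeSocket("bob"), FakeSocket("carol")
join("general", alice)
join("general", bob)
join("python", carol)

asyncio.run(broadcast("general", "hello general"))
print(alice.received, bob.received, carol.received)
# → ['hello general'] ['hello general'] []
```

broadcast() only looks up rooms[room], so the cost of a broadcast grows with the size of that room rather than with the total number of connected clients.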

Step-by-Step Tasks

1. Update app/connections.py for Room Management

Let’s refactor our ConnectionManager to handle multiple rooms.

# app/connections.py (updated)

from typing import Dict, List
from fastapi import WebSocket

class ConnectionManager:
    # active_connections: Dict[str, List[WebSocket]] where key is room_name
    def __init__(self):
        self.active_connections: Dict[str, List[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, room_name: str):
        await websocket.accept()
        if room_name not in self.active_connections:
            self.active_connections[room_name] = []
        self.active_connections[room_name].append(websocket)
        print(f"Client connected to room '{room_name}'. Total connections in room: {len(self.active_connections[room_name])}")


    def disconnect(self, websocket: WebSocket, room_name: str):
        if room_name in self.active_connections:
            try:
                self.active_connections[room_name].remove(websocket)
                if not self.active_connections[room_name]: # Remove room if empty
                    del self.active_connections[room_name]
                print(f"Client disconnected from room '{room_name}'. Remaining in room: {len(self.active_connections.get(room_name, []))}")
            except ValueError:
                print(f"WebSocket not found in room '{room_name}' during disconnect.")


    async def send_personal_message(self, message: str, websocket: WebSocket):
        await websocket.send_text(message)

    async def broadcast(self, message: str, room_name: str):
        if room_name in self.active_connections:
            # Iterate over a copy: removing items from a list while iterating over it
            # silently skips elements, and dead connections are removed below
            connections_in_room = list(self.active_connections[room_name])
            for connection in connections_in_room:
                try:
                    await connection.send_text(message)
                except Exception:
                    # The client went away mid-broadcast. Depending on the Starlette/uvicorn
                    # version this can surface as different exception types (for example
                    # RuntimeError or WebSocketDisconnect), so catch broadly and drop it.
                    print(f"Removing dead connection from room '{room_name}'.")
                    # disconnect() tolerates a connection or room that is already gone,
                    # unlike calling list.remove() directly
                    self.disconnect(connection, room_name)

Code Explanation (app/connections.py):

  • self.active_connections: Dict[str, List[WebSocket]]: Now stores a dictionary mapping room_name strings to lists of WebSocket objects.
  • connect(websocket, room_name): Takes an additional room_name argument. If the room doesn’t exist, it’s created. The WebSocket is then added to that room’s list.
  • disconnect(websocket, room_name): Removes the WebSocket from the specified room. If the room becomes empty, it’s removed from the active_connections dictionary.
  • broadcast(message, room_name): Only iterates through connections in the specified room_name to send the message. Includes improved error handling for dead connections during broadcast.
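Why does broadcast() iterate over list(...) instead of the live list? Removing items from a Python list while looping over it does not raise an exception. Instead, the loop silently skips elements. The stand-alone demo below shows the difference. visit_and_remove is a hypothetical helper and is not part of the app.

```python
def visit_and_remove(items: list, iterate_over_copy: bool) -> list:
    """Visit every item, removing each one from `items` as we go (hypothetical helper)."""
    visited = []
    source = list(items) if iterate_over_copy else items
    for item in source:
        visited.append(item)
        items.remove(item)  # mutates the list being iterated when no copy is made
    return visited


without_copy = [1, 2, 3, 4]
print(visit_and_remove(without_copy, iterate_over_copy=False), without_copy)
# → [1, 3] [2, 4]   (no error, but 2 and 4 were never visited)

with_copy = [1, 2, 3, 4]
print(visit_and_remove(with_copy, iterate_over_copy=True), with_copy)
# → [1, 2, 3, 4] []
```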

2. Update app/main.py for Room-based Chat

Now, let’s modify our app/main.py to use the room functionality. We’ll add a room_name path parameter to our WebSocket endpoint.

# app/main.py (further updated)

from datetime import datetime, timedelta
from typing import List

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, WebSocketException, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from jose import jwt, JWTError  # Used by get_current_active_user_ws
from sqlalchemy.orm import Session

from .auth import (
    Hasher,
    create_access_token,
    get_current_user_db,
    ACCESS_TOKEN_EXPIRE_MINUTES,
    Token,
    SECRET_KEY,  # Needed to verify WebSocket tokens
    ALGORITHM,
)
from .connections import ConnectionManager
from .database import get_db
from . import models, schemas


app = FastAPI()

manager = ConnectionManager()

# Temporary "available rooms" list for demonstration
# In a real app, rooms could be dynamic and stored in the DB
AVAILABLE_ROOMS = ["general", "python", "frontend", "random"]

# Dependency to authenticate the user for a WebSocket connection.
# OAuth2PasswordBearer (oauth2_scheme) reads the Authorization header of HTTP requests,
# and browsers cannot set that header on a WebSocket handshake. Instead, the client
# passes its JWT as a query parameter, e.g. ws://localhost:8000/ws/general?token=YOUR_JWT
async def get_current_active_user_ws(websocket: WebSocket, db: Session = Depends(get_db)):
    # In WebSocket dependencies, raise WebSocketException rather than HTTPException;
    # FastAPI then closes the connection with code 1008 (policy violation).
    credentials_exception = WebSocketException(
        code=status.WS_1008_POLICY_VIOLATION,
        reason="Could not validate credentials for WebSocket",
    )

    token_from_ws = websocket.query_params.get("token")
    if not token_from_ws:
        raise credentials_exception

    try:
        payload = jwt.decode(token_from_ws, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception

    username = payload.get("sub")
    if username is None:
        raise credentials_exception

    user = db.query(models.User).filter(models.User.username == username).first()
    if user is None:
        raise credentials_exception
    return user


@app.post("/register", response_model=schemas.UserResponse, status_code=status.HTTP_201_CREATED)
async def register_user(user: schemas.UserCreate, db: Session = Depends(get_db)):
    db_user = db.query(models.User).filter(models.User.username == user.username).first()
    if db_user:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already registered")

    hashed_password = Hasher.get_password_hash(user.password)
    db_user = models.User(username=user.username, hashed_password=hashed_password)
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user

@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    user = db.query(models.User).filter(models.User.username == form_data.username).first()
    if not user or not Hasher.verify_password(form_data.password, user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

@app.get("/users/me", response_model=schemas.UserResponse)
async def read_users_me(current_user: models.User = Depends(get_current_user_db)):
    return current_user

@app.get("/messages/{room_name}", response_model=List[schemas.MessageResponse])
async def get_chat_history_for_room(
    room_name: str,
    db: Session = Depends(get_db),
    current_user: models.User = Depends(get_current_user_db),
    skip: int = 0,
    limit: int = 100
):
    # For now, messages aren't room-specific in DB. We'll simulate fetching all
    # and filter on the client side if needed, or update the Message model later.
    # In a proper implementation, the Message model would have a 'room_name' field.
    messages = db.query(models.Message).order_by(models.Message.timestamp.desc()).offset(skip).limit(limit).all()
    # TODO: Filter messages by room_name once Message model supports it
    return messages


# New endpoint to get available chat rooms
@app.get("/rooms")
async def get_rooms():
    return {"rooms": AVAILABLE_ROOMS}

# Updated WebSocket endpoint for room-based chat
@app.websocket("/ws/{room_name}")
async def websocket_room_endpoint(
    websocket: WebSocket,
    room_name: str,
    db: Session = Depends(get_db),
    current_user: models.User = Depends(get_current_active_user_ws) # Authenticate user
):
    # Ensure room_name is valid/allowed
    if room_name not in AVAILABLE_ROOMS:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Invalid chat room")
        return

    # User is authenticated via current_user dependency
    await manager.connect(websocket, room_name)
    welcome_message = f"User '{current_user.username}' joined room '{room_name}'."
    await manager.broadcast(welcome_message, room_name)

    try:
        while True:
            data = await websocket.receive_text()
            full_message = f"[{room_name}] {current_user.username}: {data}"

            # Save message to database
            # Now with the actual authenticated user's ID
            new_message = models.Message(content=data, owner_id=current_user.id, timestamp=datetime.utcnow())
            db.add(new_message)
            db.commit()
            db.refresh(new_message)

            await manager.broadcast(full_message, room_name)
    except WebSocketDisconnect:
        manager.disconnect(websocket, room_name)
        await manager.broadcast(f"User '{current_user.username}' left room '{room_name}'.", room_name)
    except Exception as e:
        # Authentication failures never reach this block: get_current_active_user_ws
        # runs (and rejects the connection) before the endpoint body starts.
        print(f"Unexpected error in WebSocket for user '{current_user.username}' in room '{room_name}': {e}")
        # Attempt to gracefully disconnect on unexpected errors
        manager.disconnect(websocket, room_name)
        await manager.broadcast(f"User '{current_user.username}' left room '{room_name}' due to an error.", room_name)

Note on WebSocket Authentication (get_current_active_user_ws): get_current_active_user_ws looks for a JWT in the WebSocket's query parameters (e.g., ws://localhost:8000/ws/general?token=YOUR_JWT). This is a common workaround because client-side JavaScript has no way to set custom headers on the handshake request, including Authorization. A connection without a valid token is rejected before the endpoint body runs, so only authenticated users can join a room.
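The query-parameter handshake can be illustrated with the standard library alone. This sketch assumes the URL layout used in this chapter (/ws/{room_name}?token=...). build_ws_url and token_from_url are hypothetical helpers. The first mirrors what client.html does, and the second roughly approximates how the server reads the token query parameter.

```python
from typing import Optional
from urllib.parse import parse_qs, quote, urlencode, urlsplit


def build_ws_url(base: str, room_name: str, token: str) -> str:
    """Build the room URL a client connects to, e.g. ws://localhost:8000/ws/general?token=..."""
    # quote() percent-encodes the room name; urlencode() encodes the query string
    return f"{base}/ws/{quote(room_name)}?{urlencode({'token': token})}"


def token_from_url(url: str) -> Optional[str]:
    """Return the 'token' query parameter (first value if repeated), or None if absent."""
    values = parse_qs(urlsplit(url).query).get("token")
    return values[0] if values else None


url = build_ws_url("ws://localhost:8000", "general", "header.payload.signature")
print(url)                  # ws://localhost:8000/ws/general?token=header.payload.signature
print(token_from_url(url))  # header.payload.signature
print(token_from_url("ws://localhost:8000/ws/general"))  # None
```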

3. Update app/models.py for Room-based History

To make chat history truly room-based, let’s add a room_name column to our Message model.

# app/models.py (updated to include room_name in Message)

from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship
from datetime import datetime

from .database import Base

class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    username = Column(String, unique=True, index=True)
    hashed_password = Column(String)

    messages = relationship("Message", back_populates="owner")

class Message(Base):
    __tablename__ = "messages"

    id = Column(Integer, primary_key=True, index=True)
    content = Column(String, index=True)
    timestamp = Column(DateTime, default=datetime.utcnow)
    owner_id = Column(Integer, ForeignKey("users.id"))
    room_name = Column(String, index=True, default="general") # NEW: Add room_name

    owner = relationship("User", back_populates="messages")

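After updating app/models.py, you need to apply this schema change to your database. Base.metadata.create_all() only creates tables that don't exist yet. It won't add the new column to an existing messages table. For SQLite during development, this typically means: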

  1. Deleting your existing chat.db file (for development purposes, not production!).
  2. Rerunning the database creation command:
    pipenv shell
    python -c "from app.database import Base, engine; from app import models; Base.metadata.create_all(bind=engine)"
    

This will recreate the chat.db file with the new room_name column.
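Why not just rerun create_all() against the existing file? SQLAlchemy's create_all() checks for each table first and skips any that already exist, much like CREATE TABLE IF NOT EXISTS. As a result, it never adds columns to an existing table. This sqlite3 sketch reproduces that behavior with plain SQL on an in-memory database. The two-column starting schema is a simplified assumption, not the app's real table.

```python
import sqlite3


def column_names(conn: sqlite3.Connection, table: str) -> list:
    # PRAGMA table_info returns one row per column; index 1 is the column name
    return [row[1] for row in conn.execute(f"PRAGMA table_info({table})")]


conn = sqlite3.connect(":memory:")  # in-memory stand-in for chat.db
new_columns = "(id INTEGER PRIMARY KEY, content VARCHAR, room_name VARCHAR DEFAULT 'general')"

# The messages table as created before this chapter (simplified: no room_name yet)
conn.execute("CREATE TABLE IF NOT EXISTS messages (id INTEGER PRIMARY KEY, content VARCHAR)")

# "Create if missing" with the new schema is a no-op because the table already exists
conn.execute(f"CREATE TABLE IF NOT EXISTS messages {new_columns}")
columns_after_rerun = column_names(conn, "messages")
print(columns_after_rerun)  # ['id', 'content']  (room_name was not added)

# Dropping the table (the in-memory analogue of deleting chat.db) and recreating it works
conn.execute("DROP TABLE messages")
conn.execute(f"CREATE TABLE messages {new_columns}")
columns_after_recreate = column_names(conn, "messages")
print(columns_after_recreate)  # ['id', 'content', 'room_name']
```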

Also, update the get_chat_history_for_room endpoint in app/main.py to filter by room:

# app/main.py (updated get_chat_history_for_room)

# ... (imports and other code) ...

@app.get("/messages/{room_name}", response_model=List[schemas.MessageResponse])
async def get_chat_history_for_room(
    room_name: str,
    db: Session = Depends(get_db),
    current_user: models.User = Depends(get_current_user_db),
    skip: int = 0,
    limit: int = 100
):
    if room_name not in AVAILABLE_ROOMS:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Room not found")

    messages = (
        db.query(models.Message)
        .filter(models.Message.room_name == room_name)
        .order_by(models.Message.timestamp.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
    return messages

# ... (rest of the code) ...

And in the websocket_room_endpoint, when saving a message:

            # Save message to database
            new_message = models.Message(
                content=data,
                owner_id=current_user.id,
                timestamp=datetime.utcnow(),
                room_name=room_name # Now save the room_name
            )
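For reference, the ORM query in get_chat_history_for_room corresponds roughly to the SQL below: filter by room_name, sort newest first, then apply offset and limit. This sketch uses sqlite3 with a simplified messages table and made-up sample rows. history() is a hypothetical helper, not part of the app.

```python
import sqlite3
from datetime import datetime, timedelta

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE messages (id INTEGER PRIMARY KEY, content TEXT, "
    "timestamp TEXT, room_name TEXT DEFAULT 'general')"
)

start = datetime(2024, 1, 1, 12, 0, 0)  # arbitrary sample timestamps
sample = [
    ("hi all", start, "general"),
    ("anyone using asyncio?", start + timedelta(minutes=1), "python"),
    ("lunch?", start + timedelta(minutes=2), "general"),
]
# Each saved message carries its room_name, as in the updated websocket_room_endpoint
conn.executemany(
    "INSERT INTO messages (content, timestamp, room_name) VALUES (?, ?, ?)",
    [(content, ts.isoformat(), room) for content, ts, room in sample],
)


def history(room_name: str, skip: int = 0, limit: int = 100) -> list:
    # Filter by room, newest first, then skip/limit for pagination
    cur = conn.execute(
        "SELECT content FROM messages WHERE room_name = ? "
        "ORDER BY timestamp DESC LIMIT ? OFFSET ?",
        (room_name, limit, skip),
    )
    return [content for (content,) in cur]


print(history("general"))          # ['lunch?', 'hi all']
print(history("python"))           # ['anyone using asyncio?']
print(history("general", skip=1))  # ['hi all']
```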

4. Update client.html for Room Selection and Authentication

The client now needs to specify a room and pass an authentication token.

<!-- client.html (further updated) -->
<!DOCTYPE html>
<html>
<head>
    <title>FastAPI WebSocket Chat Client</title>
</head>
<body>
    <h1>WebSocket Chat Test</h1>

    <div>
        <label for="usernameInput">Username:</label>
        <input type="text" id="usernameInput" value="chatuser" placeholder="Enter username">
        <label for="passwordInput">Password:</label>
        <input type="password" id="passwordInput" value="chatpassword" placeholder="Enter password">
        <button onclick="login()">Login</button>
        <button onclick="register()">Register</button>
    </div>
    <p>Auth Token: <span id="authToken">No Token</span></p>

    <hr>

    <div>
        <label for="roomSelect">Join Room:</label>
        <select id="roomSelect" onchange="updateWebSocket()">
            <option value="general">general</option>
            <option value="python">python</option>
            <option value="frontend">frontend</option>
            <option value="random">random</option>
        </select>
        <button onclick="updateWebSocket()">Connect/Reconnect</button>
    </div>
    <input type="text" id="messageInput" placeholder="Type a message">
    <button onclick="sendMessage()">Send</button>
    <button onclick="fetchHistory()">Fetch History</button>
    <div id="messages"></div>

    <script>
        let ws;
        let authToken = '';
        const messagesDiv = document.getElementById("messages");
        const messageInput = document.getElementById("messageInput");
        const roomSelect = document.getElementById("roomSelect");
        const usernameInput = document.getElementById("usernameInput");
        const passwordInput = document.getElementById("passwordInput");
        const authTokenSpan = document.getElementById("authToken");

        async function login() {
            const username = usernameInput.value;
            const password = passwordInput.value;
            if (!username || !password) {
                alert("Please enter both username and password.");
                return;
            }

            const formData = new URLSearchParams();
            formData.append('username', username);
            formData.append('password', password);

            try {
                const response = await fetch('http://localhost:8000/token', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/x-www-form-urlencoded'
                    },
                    body: formData
                });
                if (!response.ok) {
                    throw new Error(`HTTP error! status: ${response.status}`);
                }
                const data = await response.json();
                authToken = data.access_token;
                authTokenSpan.textContent = authToken.substring(0, 30) + '...'; // Show truncated token
                messagesDiv.innerHTML += `<p>Logged in as ${username}. Token obtained.</p>`;
                updateWebSocket(); // Connect to WS after login
            } catch (error) {
                console.error("Login failed:", error);
                messagesDiv.innerHTML += `<p style='color:red;'>Login failed: ${error.message}</p>`;
            }
        }

        async function register() {
            const username = usernameInput.value;
            const password = passwordInput.value;
            if (!username || !password) {
                alert("Please enter both username and password.");
                return;
            }

            try {
                const response = await fetch('http://localhost:8000/register', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json'
                    },
                    body: JSON.stringify({ username, password })
                });
                if (!response.ok) {
                    const errorData = await response.json();
                    throw new Error(`HTTP error! status: ${response.status} - ${errorData.detail}`);
                }
                const data = await response.json();
                messagesDiv.innerHTML += `<p>User ${data.username} registered successfully!</p>`;
            } catch (error) {
                console.error("Registration failed:", error);
                messagesDiv.innerHTML += `<p style='color:red;'>Registration failed: ${error.message}</p>`;
            }
        }

        function connectWebSocket(roomName) {
            if (ws) {
                ws.close();
            }
            if (!authToken) {
                messagesDiv.innerHTML += "<p style='color:red;'>Please login first to get an authentication token.</p>";
                return;
            }
            ws = new WebSocket(`ws://localhost:8000/ws/${roomName}?token=${authToken}`);

            ws.onopen = (event) => {
                messagesDiv.innerHTML += `<p>Connected to room '${roomName}' as ${usernameInput.value}!</p>`;
                console.log("WebSocket opened:", event);
            };

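            ws.onmessage = (event) => {
                // Build the node with textContent rather than innerHTML so that
                // messages from other users cannot inject HTML or scripts
                const p = document.createElement("p");
                p.textContent = `Received: ${event.data}`;
                messagesDiv.appendChild(p);
                console.log("WebSocket message:", event.data);
            };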

            ws.onclose = (event) => {
                messagesDiv.innerHTML += `<p>Disconnected from room '${roomName}'.</p>`;
                console.log("WebSocket closed:", event);
            };

            ws.onerror = (event) => {
                messagesDiv.innerHTML += "<p style='color:red;'>WebSocket error!</p>";
                console.error("WebSocket error:", event);
            };
        }

        function updateWebSocket() {
            const roomName = roomSelect.value;
            if (authToken) { // Only connect if we have a token
                connectWebSocket(roomName);
            } else {
                messagesDiv.innerHTML += "<p style='color:orange;'>Login to connect to chat rooms.</p>";
            }
        }

        function sendMessage() {
            const message = messageInput.value;
            if (message && ws && ws.readyState === WebSocket.OPEN) {
                ws.send(message);
                messageInput.value = ""; // Clear input field
            } else if (!ws || ws.readyState !== WebSocket.OPEN) {
                messagesDiv.innerHTML += "<p style='color:orange;'>WebSocket is not open. Please connect to a room.</p>";
            }
        }

        async function fetchHistory() {
            if (!authToken) {
                messagesDiv.innerHTML += "<p style='color:red;'>Please login first to fetch history.</p>";
                return;
            }
            const roomName = roomSelect.value;
            try {
                const response = await fetch(`http://localhost:8000/messages/${roomName}`, {
                    headers: {
                        'Authorization': `Bearer ${authToken}`
                    }
                });
                if (!response.ok) {
                    throw new Error(`HTTP error! status: ${response.status}`);
                }
                const messages = await response.json();
                messagesDiv.innerHTML += `<h2>History for '${roomName}':</h2>`;
                if (messages.length === 0) {
                    messagesDiv.innerHTML += `<p>No history found for '${roomName}'.</p>`;
                } else {
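                    messages.reverse().forEach(msg => { // Display oldest first
                        // textContent keeps stored message text from being interpreted as HTML
                        const p = document.createElement("p");
                        p.textContent = `[${new Date(msg.timestamp).toLocaleTimeString()}] ${msg.owner.username}: ${msg.content}`;
                        messagesDiv.appendChild(p);
                    });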
                }
            } catch (error) {
                console.error("Failed to fetch chat history:", error);
                messagesDiv.innerHTML += `<p style='color:red;'>Failed to fetch history: ${error.message}</p>`;
            }
        }

        // Initial setup: the hardcoded <option> elements above are only a fallback;
        // loadRooms() replaces them with the list returned by GET /rooms.
        async function loadRooms() {
            try {
                const response = await fetch('http://localhost:8000/rooms');
                if (!response.ok) throw new Error('Could not fetch rooms');
                const data = await response.json();
                roomSelect.innerHTML = ''; // Clear existing options
                data.rooms.forEach(room => {
                    const option = document.createElement('option');
                    option.value = room;
                    option.textContent = room;
                    roomSelect.appendChild(option);
                });
            } catch (error) {
                console.error("Error loading rooms:", error);
                messagesDiv.innerHTML += `<p style='color:red;'>Error loading rooms: ${error.message}</p>`;
            }
        }

        loadRooms(); // Call this on page load to populate rooms
    </script>
</body>
</html>
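login() sends URL-encoded form fields rather than JSON because /token takes its input through OAuth2PasswordRequestForm, which reads form data. Here is the same request built in Python with urllib, to make its shape explicit. build_token_request is a hypothetical helper. The sketch only constructs the request, so it runs without a server.

```python
from urllib.parse import urlencode
from urllib.request import Request


def build_token_request(username: str, password: str, base: str = "http://localhost:8000") -> Request:
    """Build (but do not send) the same POST /token request that login() in client.html makes."""
    # OAuth2PasswordRequestForm reads form fields, so the body is URL-encoded, not JSON
    body = urlencode({"username": username, "password": password}).encode()
    return Request(
        f"{base}/token",
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )


req = build_token_request("chatuser", "chatpassword")
print(req.get_method(), req.full_url)  # POST http://localhost:8000/token
print(req.data)                        # b'username=chatuser&password=chatpassword'
# Against a running server you could send it with:
#   import json; from urllib.request import urlopen
#   token = json.load(urlopen(req))["access_token"]
```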

5. Run and Test Room-based Chat

  1. Make sure the database has been recreated with the new room_name column in the messages table. The room-filtered history endpoint depends on it.
  2. Start the server:
    pipenv shell
    uvicorn app.main:app --reload
    
  3. Open client.html in multiple browser tabs.
  4. Register a few users (e.g., user1/pass1, user2/pass2) using the Register button.
  5. Log in as a different user in each tab to obtain that user's JWT. A truncated token is displayed, and the client connects to the currently selected room automatically.
  6. Connect to different rooms:
    • In one tab, select “general” and click “Connect/Reconnect”.
    • In a second tab, select “general” and click “Connect/Reconnect”.
    • In a third tab, select “python” and click “Connect/Reconnect”.
  7. Send messages:
    • Send a message from a “general” tab. Both “general” tabs should receive it, but the “python” tab should not.
    • Send a message from the “python” tab. Only that tab should receive it (until another user joins “python”).
  8. Test chat history: Click “Fetch History” in each tab for their respective rooms.

Tips/Challenges/Errors

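  • WebSocket token in URL: Passing tokens in URLs (query parameters) is less secure than using HTTP headers, because full URLs are often recorded in server and proxy logs. This is a common workaround for browser WebSockets, since the browser API cannot set custom headers. For production, consider carrying the token in the Sec-WebSocket-Protocol header instead, or in custom headers when the client is not a browser.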
  • Database Schema Changes: Modifying models.py after chat.db is created requires a database migration. For development, deleting chat.db and recreating tables is acceptable. For production, use Alembic for proper schema migrations.
  • Front-end complexity: The client.html is growing. For a real front-end, you’d use a framework like React, Vue, or Angular to manage state and UI updates more effectively.
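Here is a rough sketch of the Sec-WebSocket-Protocol alternative. A browser client could pass the token as a second subprotocol, e.g. new WebSocket(url, ["bearer", token]), and the server would pull it out of that header. The "bearer" marker and token_from_subprotocols are a hypothetical convention, not a standard. On the server, the raw value would come from websocket.headers.get("sec-websocket-protocol"). The server would also need to accept with one of the offered subprotocols (for example websocket.accept(subprotocol="bearer")), because browsers generally reject a handshake whose response doesn't select an offered subprotocol.

```python
from typing import Optional


def token_from_subprotocols(header_value: Optional[str], marker: str = "bearer") -> Optional[str]:
    """Extract a token sent as the subprotocol list [marker, token] (hypothetical convention).

    A browser calling new WebSocket(url, ["bearer", token]) sends the header
    "Sec-WebSocket-Protocol: bearer, <token>".
    """
    if not header_value:
        return None
    # The header value is a comma-separated list of subprotocol names
    parts = [p.strip() for p in header_value.split(",")]
    if len(parts) == 2 and parts[0] == marker and parts[1]:
        return parts[1]
    return None


print(token_from_subprotocols("bearer, abc.def.ghi"))  # abc.def.ghi
print(token_from_subprotocols("chat"))                 # None
print(token_from_subprotocols(None))                   # None
```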

Summary/Key Takeaways

You’ve extended your real-time chat application to support multiple chat rooms. The ConnectionManager now tracks connections per room, messages are broadcast only to participants in the same room, and chat history can be fetched per room. You’ve also added token-based authentication to the WebSocket endpoint, so only logged-in users can join a room. The next step is to refine the user authentication process, ensuring smooth registration and login flows.