FastAPI + SQLAlchemy example
This example shows how to use Dependency Injector with FastAPI and SQLAlchemy.
The source code is available on GitHub.
Thanks to @ShvetsovYura for providing the initial example: FastAPI_DI_SqlAlchemy.
Application structure
The application has the following structure:
./
├── webapp/
│ ├── __init__.py
│ ├── application.py
│ ├── containers.py
│ ├── database.py
│ ├── endpoints.py
│ ├── models.py
│ ├── repositories.py
│ ├── services.py
│ └── tests.py
├── config.yml
├── docker-compose.yml
├── Dockerfile
└── requirements.txt
Application factory
The application factory creates the container, which wires the endpoints module through its
wiring configuration. It then creates the FastAPI app and sets up the routes.
The application factory also creates the database tables if they do not exist.
Listing of webapp/application.py:
"""Application module."""
from fastapi import FastAPI
from .containers import Container
from . import endpoints
def create_app() -> FastAPI:
    """Build and return the FastAPI application.

    Instantiates the DI container, asks its ``db`` provider to create any
    missing tables, attaches the container to the app (the tests use
    ``app.container`` to override providers) and registers the endpoint routes.
    """
    container = Container()
    container.db().create_database()

    application = FastAPI()
    application.container = container
    application.include_router(endpoints.router)
    return application


# Module-level application instance; the tests import it directly.
app = create_app()
Endpoints
The endpoints module contains example endpoints. The endpoints depend on the user service.
The user service is injected using the Wiring feature. See webapp/endpoints.py:
"""Endpoints module."""
from fastapi import APIRouter, Depends, Response, status
from dependency_injector.wiring import inject, Provide
from .containers import Container
from .services import UserService
from .repositories import NotFoundError
# Router collecting the endpoints below; create_app() registers it on the app.
router = APIRouter()
@router.get("/users")
@inject
def get_list(
    user_service: UserService = Depends(Provide[Container.user_service]),
):
    # List every user. (Comment rather than docstring: FastAPI would publish
    # a docstring as this route's OpenAPI description.)
    users = user_service.get_users()
    return users
@router.get("/users/{user_id}")
@inject
def get_by_id(
    user_id: int,
    user_service: UserService = Depends(Provide[Container.user_service]),
):
    # Return one user, or an empty 404 response when the lookup fails.
    try:
        user = user_service.get_user_by_id(user_id)
    except NotFoundError:
        return Response(status_code=status.HTTP_404_NOT_FOUND)
    return user
@router.post("/users", status_code=status.HTTP_201_CREATED)
@inject
def add(
    user_service: UserService = Depends(Provide[Container.user_service]),
):
    # Create a user; the route decorator sets the 201 status code.
    created = user_service.create_user()
    return created
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
@inject
def remove(
    user_id: int,
    user_service: UserService = Depends(Provide[Container.user_service]),
):
    # Delete a user: empty 204 on success, empty 404 if it does not exist.
    try:
        user_service.delete_user_by_id(user_id)
        code = status.HTTP_204_NO_CONTENT
    except NotFoundError:
        code = status.HTTP_404_NOT_FOUND
    return Response(status_code=code)
@router.get("/status")
def get_status():
    # Constant status payload; needs no injected dependencies.
    payload = {"status": "OK"}
    return payload
Container
The declarative container wires the example user service, the user repository, and the utility database class.
See webapp/containers.py:
"""Containers module."""
from dependency_injector import containers, providers
from .database import Database
from .repositories import UserRepository
from .services import UserService
class Container(containers.DeclarativeContainer):
    """Dependency-injection container for the web application.

    Declares the configuration, the database, the user repository and the
    user service, and how each one is built from the others.
    """

    # Modules whose ``Provide[...]`` markers are wired to this container.
    wiring_config = containers.WiringConfiguration(modules=[".endpoints"])

    # Settings loaded from config.yml. The path is relative, so it is opened
    # relative to the current working directory.
    # NOTE(review): confirm the app is always started from the project root.
    config = providers.Configuration(yaml_files=["config.yml"])

    # Singleton provider: one Database (engine + session registry) per container.
    db = providers.Singleton(Database, db_url=config.db.url)

    # Factory provider: a new repository per injection. ``db.provided.session``
    # hands over the bound ``Database.session`` method, which the repository
    # calls to open its own sessions.
    user_repository = providers.Factory(
        UserRepository,
        session_factory=db.provided.session,
    )

    # Factory provider: a new service per injection, built on a new repository.
    user_service = providers.Factory(
        UserService,
        user_repository=user_repository,
    )
Services
The services module contains the example user service. See webapp/services.py:
"""Services module."""
from uuid import uuid4
from typing import Iterator
from .repositories import UserRepository
from .models import User
class UserService:
    """Application-level user operations that delegate storage to a repository."""

    def __init__(self, user_repository: UserRepository) -> None:
        # Private handle on the injected repository.
        self._repository: UserRepository = user_repository

    def get_users(self) -> Iterator[User]:
        """Return all users exactly as the repository provides them."""
        repo = self._repository
        return repo.get_all()

    def get_user_by_id(self, user_id: int) -> User:
        """Return the user with ``user_id``; repository errors propagate unchanged."""
        repo = self._repository
        return repo.get_by_id(user_id)

    def create_user(self) -> User:
        """Create a user with a random UUID-based e-mail and a fixed demo password."""
        # ``uuid4`` is resolved from module globals at call time (tests patch it).
        email = f"{uuid4()}@email.com"
        return self._repository.add(email=email, password="pwd")

    def delete_user_by_id(self, user_id: int) -> None:
        """Delete the user with ``user_id``, passing through the repository's result."""
        repo = self._repository
        return repo.delete_by_id(user_id)
Repositories
The repositories module contains the example user repository. See webapp/repositories.py:
"""Repositories module."""
from contextlib import AbstractContextManager
from typing import Callable, Iterator
from sqlalchemy.orm import Session
from .models import User
class UserRepository:
    """Data access for ``User`` rows.

    Every method opens its own session by calling ``session_factory``, a
    callable returning a context manager that yields a ``Session``
    (in this app, the bound ``Database.session`` method).
    """

    def __init__(self, session_factory: Callable[..., AbstractContextManager[Session]]) -> None:
        self.session_factory = session_factory

    def get_all(self) -> Iterator[User]:
        """Return every user, loaded in a single session (a list at runtime)."""
        with self.session_factory() as session:
            return session.query(User).all()

    def get_by_id(self, user_id: int) -> User:
        """Return the user with ``user_id``.

        Raises:
            UserNotFoundError: if no user has that id.
        """
        with self.session_factory() as session:
            user = session.query(User).filter(User.id == user_id).first()
        # Raise only after the session scope has exited. A missing row is an
        # expected outcome; raising inside the ``with`` would push it through
        # the session's error path (rollback plus an ERROR log with traceback).
        if user is None:
            raise UserNotFoundError(user_id)
        return user

    def add(self, email: str, password: str, is_active: bool = True) -> User:
        """Insert a user, commit, and return it with database-generated fields loaded.

        ``password`` is stored as given in ``hashed_password``; no hashing happens here.
        """
        with self.session_factory() as session:
            user = User(email=email, hashed_password=password, is_active=is_active)
            session.add(user)
            session.commit()
            # Reload so fields such as ``id`` are populated before the session closes.
            session.refresh(user)
            return user

    def delete_by_id(self, user_id: int) -> None:
        """Delete the user with ``user_id`` and commit.

        Raises:
            UserNotFoundError: if no user has that id.
        """
        with self.session_factory() as session:
            entity = session.query(User).filter(User.id == user_id).first()
            if entity is not None:
                session.delete(entity)
                session.commit()
        # As in get_by_id: report "not found" outside the session's error path.
        if entity is None:
            raise UserNotFoundError(user_id)
class NotFoundError(Exception):
    """Raised when a requested entity does not exist.

    Subclasses set ``entity_name`` to name the entity in the message. The base
    class provides a generic default, so it can also be raised directly
    (previously ``entity_name`` was only annotated and raising the base class
    failed with AttributeError).
    """

    entity_name: str = "Entity"

    def __init__(self, entity_id):
        # Keep the id available to handlers without parsing the message.
        self.entity_id = entity_id
        super().__init__(f"{self.entity_name} not found, id: {entity_id}")


class UserNotFoundError(NotFoundError):
    """Raised when no user with the requested id exists."""

    entity_name: str = "User"
Models
The models module contains the example SQLAlchemy user model. See webapp/models.py:
"""Models module."""
from sqlalchemy import Column, String, Boolean, Integer
from .database import Base
class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    email = Column(String, unique=True)
    # Despite the name, UserRepository.add stores the password it receives unchanged.
    hashed_password = Column(String)
    is_active = Column(Boolean, default=True)

    def __repr__(self):
        fields = (
            f"id={self.id}, "
            f'email="{self.email}", '
            f'hashed_password="{self.hashed_password}", '
            f"is_active={self.is_active}"
        )
        return f"<User({fields})>"
Database
The database module defines the declarative base and a utility class that holds the engine and the session factory.
See webapp/database.py:
"""Database module."""
import logging
from contextlib import AbstractContextManager, contextmanager
from typing import Callable, Iterator

from sqlalchemy import create_engine, orm
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
# Module logger; Database.session() uses it to report rollbacks.
logger = logging.getLogger(__name__)
# Declarative base for the ORM models (models.User subclasses it); its metadata
# is what Database.create_database() passes to create_all().
Base = declarative_base()
class Database:
    """Own the SQLAlchemy engine and hand out sessions.

    Holds one engine built from ``db_url`` and a ``scoped_session`` registry;
    :meth:`session` turns that registry into a per-use context manager.
    """

    def __init__(self, db_url: str, echo: bool = True) -> None:
        """Create the engine and the session registry.

        Args:
            db_url: SQLAlchemy database URL passed to ``create_engine``.
            echo: forwarded to ``create_engine``. Defaults to ``True`` (the
                previously hard-coded value), which makes the engine log the SQL
                it emits; pass ``False`` to turn that off.
        """
        self._engine = create_engine(db_url, echo=echo)
        self._session_factory = orm.scoped_session(
            orm.sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self._engine,
            ),
        )

    def create_database(self) -> None:
        """Create the tables registered on ``Base.metadata`` that do not exist yet."""
        Base.metadata.create_all(self._engine)

    @contextmanager
    def session(self) -> Iterator[Session]:
        """Yield a session; on error log, roll back and re-raise; always close it.

        The generator is annotated ``Iterator[Session]`` because it yields a
        session. ``@contextmanager`` turns the method into one that returns a
        context manager, which is the shape ``UserRepository`` expects.
        """
        session: Session = self._session_factory()
        try:
            yield session
        except Exception:
            logger.exception("Session rollback because of exception")
            session.rollback()
            raise
        finally:
            session.close()
Tests
The tests use the Provider overriding feature to replace the repository with a mock. See webapp/tests.py:
"""Tests module."""
from unittest import mock
import pytest
from fastapi.testclient import TestClient
from .repositories import UserRepository, UserNotFoundError
from .models import User
from .application import app
@pytest.fixture
def client():
    """Provide a TestClient bound to the module-level application."""
    test_client = TestClient(app)
    yield test_client
def test_get_list(client):
    """GET /users serializes every user the repository returns."""
    stored_users = [
        User(id=1, email="test1@email.com", hashed_password="pwd", is_active=True),
        User(id=2, email="test2@email.com", hashed_password="pwd", is_active=False),
    ]
    repository_mock = mock.Mock(spec=UserRepository)
    repository_mock.get_all.return_value = stored_users

    with app.container.user_repository.override(repository_mock):
        response = client.get("/users")

    expected = [
        {"id": 1, "email": "test1@email.com", "hashed_password": "pwd", "is_active": True},
        {"id": 2, "email": "test2@email.com", "hashed_password": "pwd", "is_active": False},
    ]
    assert response.status_code == 200
    assert response.json() == expected
def test_get_by_id(client):
    """GET /users/{id} returns the user and looks it up by the path id."""
    stored_user = User(id=1, email="xyz@email.com", hashed_password="pwd", is_active=True)
    repository_mock = mock.Mock(spec=UserRepository)
    repository_mock.get_by_id.return_value = stored_user

    with app.container.user_repository.override(repository_mock):
        response = client.get("/users/1")

    assert response.status_code == 200
    assert response.json() == {
        "id": 1,
        "email": "xyz@email.com",
        "hashed_password": "pwd",
        "is_active": True,
    }
    repository_mock.get_by_id.assert_called_once_with(1)
def test_get_by_id_404(client):
    """A lookup that raises UserNotFoundError is answered with 404."""
    repository_mock = mock.Mock(spec=UserRepository)
    repository_mock.get_by_id.side_effect = UserNotFoundError(1)

    with app.container.user_repository.override(repository_mock):
        status_code = client.get("/users/1").status_code

    assert status_code == 404
@mock.patch("webapp.services.uuid4", return_value="xyz")
def test_add(_, client):
    """POST /users builds the e-mail from the (patched) UUID and returns 201."""
    created_user = User(id=1, email="xyz@email.com", hashed_password="pwd", is_active=True)
    repository_mock = mock.Mock(spec=UserRepository)
    repository_mock.add.return_value = created_user

    with app.container.user_repository.override(repository_mock):
        response = client.post("/users")

    assert response.status_code == 201
    assert response.json() == {
        "id": 1,
        "email": "xyz@email.com",
        "hashed_password": "pwd",
        "is_active": True,
    }
    repository_mock.add.assert_called_once_with(email="xyz@email.com", password="pwd")
def test_remove(client):
    """DELETE /users/{id} answers 204 and deletes by the path id."""
    repository_mock = mock.Mock(spec=UserRepository)

    with app.container.user_repository.override(repository_mock):
        status_code = client.delete("/users/1").status_code

    assert status_code == 204
    repository_mock.delete_by_id.assert_called_once_with(1)
def test_remove_404(client):
    """Deleting a user that does not exist is answered with 404."""
    repository_mock = mock.Mock(spec=UserRepository)
    repository_mock.delete_by_id.side_effect = UserNotFoundError(1)

    with app.container.user_repository.override(repository_mock):
        status_code = client.delete("/users/1").status_code

    assert status_code == 404
def test_status(client):
    """GET /status returns 200 with the constant OK payload."""
    response = client.get("/status")
    assert response.status_code == 200
    assert response.json() == {"status": "OK"}
Sources
The source code is available on GitHub.
Sponsor the project on GitHub.