Django REST Framework(DRF)

深入掌握 DRF 核心组件:Serializer、视图、路由、认证、权限、分页、过滤、限流、测试等。以 DRF 3.15.x + Django 5.x 为基准。


目录


一、安装与基础配置

pip install djangorestframework
pip install djangorestframework-simplejwt
pip install django-filter
pip install drf-spectacular          # OpenAPI 文档(推荐替代 drf-yasg)
# settings.py
INSTALLED_APPS = [
    ...
    'rest_framework',
    'rest_framework_simplejwt',
    'django_filters',
    'drf_spectacular',
]

# Project-wide DRF defaults; individual views can override most of them.
REST_FRAMEWORK = {
    # Authentication
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework_simplejwt.authentication.JWTAuthentication',
        'rest_framework.authentication.SessionAuthentication',   # used by the Browsable API
    ],
    # Permissions (global default; a view may override it)
    'DEFAULT_PERMISSION_CLASSES': [
        'rest_framework.permissions.IsAuthenticated',
    ],
    # Pagination
    'DEFAULT_PAGINATION_CLASS': 'apps.common.pagination.StandardPagination',
    'PAGE_SIZE': 20,
    # Filtering
    'DEFAULT_FILTER_BACKENDS': [
        'django_filters.rest_framework.DjangoFilterBackend',
        'rest_framework.filters.SearchFilter',
        'rest_framework.filters.OrderingFilter',
    ],
    # Throttling
    'DEFAULT_THROTTLE_CLASSES': [
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle',
    ],
    'DEFAULT_THROTTLE_RATES': {
        'anon': '100/day',
        'user': '1000/day',
        # NOTE(review): no throttle class listed above uses the 'login' scope;
        # presumably a view applies it via a scoped throttle - confirm.
        'login': '10/minute',
    },
    # Exception handling
    'EXCEPTION_HANDLER': 'apps.common.exceptions.custom_exception_handler',
    # Renderers
    'DEFAULT_RENDERER_CLASSES': [
        'rest_framework.renderers.JSONRenderer',
        # Remove BrowsableAPIRenderer in production
        'rest_framework.renderers.BrowsableAPIRenderer',
    ],
    # Schema (OpenAPI documentation)
    'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
}

二、Serializer 序列化器

2.1 Serializer 基础

from rest_framework import serializers

class UserSerializer(serializers.Serializer):
    # Plain (non-model) serializer used to create and update users.
    id       = serializers.IntegerField(read_only=True)
    username = serializers.CharField(max_length=50)
    email    = serializers.EmailField()
    password = serializers.CharField(write_only=True, min_length=6)
    age      = serializers.IntegerField(required=False, allow_null=True)

    # Per-field validation hook (validate_<fieldname>)
    def validate_username(self, value):
        """Return the trimmed username, rejecting one owned by another user.

        Raises:
            serializers.ValidationError: if a different user already has it.
        """
        from apps.users.models import User
        # Normalise first so the uniqueness check sees the value that is stored.
        value = value.strip()
        clash = User.objects.filter(username=value)
        if self.instance is not None:
            # On update, the user keeping their own username is not a clash.
            clash = clash.exclude(pk=self.instance.pk)
        if clash.exists():
            raise serializers.ValidationError('用户名已存在')
        return value

    # Cross-field validation
    def validate(self, data):
        """Reject a negative age; a missing or null age is allowed."""
        age = data.get('age')
        if age is not None and age < 0:
            raise serializers.ValidationError({'age': '年龄不能为负数'})
        return data

    # Create
    def create(self, validated_data):
        """Create the user through the manager so the password is hashed."""
        from apps.users.models import User
        return User.objects.create_user(**validated_data)

    # Update
    def update(self, instance, validated_data):
        """Apply the validated fields to ``instance`` and save it.

        The password is routed through ``set_password`` so it is hashed;
        assigning it with ``setattr`` would store the raw text.
        """
        attrs = dict(validated_data)
        password = attrs.pop('password', None)
        for key, value in attrs.items():
            setattr(instance, key, value)
        if password is not None:
            instance.set_password(password)
        instance.save()
        return instance
# 使用
serializer = UserSerializer(data=request.data)
serializer.is_valid(raise_exception=True)   # raises on invalid input; DRF renders it as HTTP 400
user = serializer.save()                    # no instance was passed, so this calls create()

# Serialize a single object
serializer = UserSerializer(user)
serializer.data                             # dict

# Serialize a list
serializer = UserSerializer(users, many=True)
serializer.data                             # list of dict

# Partial update (PATCH)
serializer = UserSerializer(user, data=request.data, partial=True)

2.2 ModelSerializer

from rest_framework import serializers
from apps.users.models import User, Profile

class UserCreateSerializer(serializers.ModelSerializer):
    # Registration serializer: takes the password twice and creates the user.
    confirm_password = serializers.CharField(write_only=True)

    class Meta:
        model = User
        fields = [
            'id', 'username', 'email',
            'password', 'confirm_password', 'age',
        ]
        read_only_fields = ['id']
        extra_kwargs = {
            'password': {'write_only': True, 'min_length': 6},
            # Clear the default uniqueness validator; uniqueness is handled by hand.
            'email': {'validators': []},
        }

    def validate(self, data):
        """Check both passwords match and drop the confirmation field."""
        password = data['password']
        confirmation = data.pop('confirm_password')
        if password != confirmation:
            raise serializers.ValidationError({'confirm_password': '两次密码不一致'})
        return data

    def create(self, validated_data):
        """Create the user via ``create_user`` with the validated fields."""
        new_user = User.objects.create_user(**validated_data)
        return new_user


class UserListSerializer(serializers.ModelSerializer):
    """For list endpoints: a trimmed-down set of fields."""
    class Meta:
        model  = User
        fields = ['id', 'username', 'email', 'is_active', 'created_at']


class UserDetailSerializer(serializers.ModelSerializer):
    """For detail endpoints: the full set of fields."""
    # Nested, read-only. ProfileSerializer must already be defined/imported
    # when this class body is executed.
    profile = ProfileSerializer(read_only=True)   # nested

    class Meta:
        model  = User
        fields = ['id', 'username', 'email', 'age', 'is_active',
                  'profile', 'date_joined', 'last_login']
        read_only_fields = ['id', 'date_joined', 'last_login']

2.3 字段类型与参数

常用字段类型

字段

说明

CharField

字符串

IntegerField

整数

FloatField

浮点数

DecimalField(max_digits, decimal_places)

精确小数

BooleanField

布尔

DateField

日期

DateTimeField

日期时间

EmailField

邮箱

URLField

URL

UUIDField

UUID

JSONField

JSON

ChoiceField(choices)

枚举选项

ListField(child=...)

列表

DictField(child=...)

字典

SerializerMethodField

自定义只读字段

HiddenField(default=...)

隐藏字段(不接受输入)

PrimaryKeyRelatedField

关联主键

SlugRelatedField

关联指定字段

StringRelatedField

关联 __str__

字段通用参数

serializers.CharField(
    required=True,          # whether the field is mandatory (default True)
    allow_null=False,       # whether null is accepted
    allow_blank=False,      # whether an empty string is accepted
    default=None,           # default value
    read_only=False,        # serialization only; input is not accepted
    write_only=False,       # deserialization only; omitted from responses
    source='model_field',   # which model attribute the field maps to
    label='字段标签',        # display name
    help_text='说明',        # help text
    validators=[],          # list of validators
    error_messages={},      # custom error messages
)

2.4 验证器

from rest_framework.validators import UniqueValidator, UniqueTogetherValidator

class UserSerializer(serializers.ModelSerializer):
    # Field-level validator: the email must be unique across all users.
    email = serializers.EmailField(
        validators=[
            UniqueValidator(
                queryset=User.objects.all(),
                message='邮箱已注册'
            )
        ]
    )

    class Meta:
        model  = User
        fields = ['username', 'email']
        # Serializer-level validator: the (username, email) pair must be unique.
        validators = [
            UniqueTogetherValidator(
                queryset=User.objects.all(),
                fields=['username', 'email'],
                message='用户名和邮箱组合已存在'
            )
        ]


# 自定义验证器函数
def validate_phone(value):
    """Validate an 11-digit mobile number: '1', then 3-9, then nine digits.

    Returns:
        The value unchanged when it is well-formed.

    Raises:
        serializers.ValidationError: if the value does not match.
    """
    import re
    # fullmatch rather than match + '$': '$' also matches just before a
    # trailing newline. re.ASCII restricts \d to 0-9.
    if not re.fullmatch(r'1[3-9]\d{9}', value, flags=re.ASCII):
        raise serializers.ValidationError('手机号格式不正确')
    return value

class ProfileSerializer(serializers.Serializer):
    # Attach the function-style validator through the field's validators list.
    phone = serializers.CharField(validators=[validate_phone])


# 自定义验证器类
class PastDateValidator:
    """Callable validator that accepts only dates strictly before today."""

    def __call__(self, value):
        """Raise ``serializers.ValidationError`` unless ``value`` is before today."""
        from django.utils import timezone
        now = timezone.now()
        # With USE_TZ=True, now() is in UTC, so now().date() can be a day off
        # from the local date. Convert to the current time zone first; a naive
        # now() (USE_TZ=False) is already local time.
        today = timezone.localtime(now).date() if timezone.is_aware(now) else now.date()
        if value >= today:
            raise serializers.ValidationError('必须是过去的日期')

class EventSerializer(serializers.Serializer):
    # Class-based validators are passed as instances.
    end_date = serializers.DateField(validators=[PastDateValidator()])

2.5 嵌套序列化

# 只读嵌套(展示关联对象)
class OrderItemSerializer(serializers.ModelSerializer):
    # Serializes one order line; used nested inside the order serializers.
    class Meta:
        model  = OrderItem
        fields = ['id', 'product_name', 'quantity', 'price']

class OrderDetailSerializer(serializers.ModelSerializer):
    # Read-only nesting: related objects are rendered but not accepted as input.
    user  = UserListSerializer(read_only=True)           # one-to-one / many-to-one
    items = OrderItemSerializer(many=True, read_only=True)  # one-to-many

    class Meta:
        model  = Order
        fields = ['id', 'order_no', 'status', 'amount', 'user', 'items', 'created_at']


# 可写嵌套(同时创建关联对象)
class OrderCreateSerializer(serializers.ModelSerializer):
    # Writable nested serializer: creates/updates an order together with its items.
    items = OrderItemSerializer(many=True)

    class Meta:
        model  = Order
        fields = ['order_no', 'items']

    def create(self, validated_data):
        """Create the order and its items as one atomic unit.

        Without the transaction, a failure on any item would leave an order
        with only some of its items saved.
        """
        from django.db import transaction

        items_data = validated_data.pop('items')
        with transaction.atomic():
            order = Order.objects.create(**validated_data)
            for item_data in items_data:
                OrderItem.objects.create(order=order, **item_data)
        return order

    def update(self, instance, validated_data):
        """Update the order; if ``items`` was supplied, replace all items.

        Items are replaced by delete-then-recreate (the simple strategy). The
        transaction ensures a failed recreate rolls the delete back instead of
        losing the existing items.
        """
        from django.db import transaction

        items_data = validated_data.pop('items', None)
        with transaction.atomic():
            instance = super().update(instance, validated_data)
            if items_data is not None:
                instance.items.all().delete()
                for item_data in items_data:
                    OrderItem.objects.create(order=instance, **item_data)
        return instance

2.6 SerializerMethodField 与动态字段

class UserSerializer(serializers.ModelSerializer):
    # Computed, read-only fields; each is filled in by the matching get_<name>().
    full_name = serializers.SerializerMethodField()
    order_count = serializers.SerializerMethodField()
    is_owner = serializers.SerializerMethodField()

    class Meta:
        model = User
        fields = ['id', 'username', 'full_name', 'order_count', 'is_owner']

    def get_full_name(self, obj):
        """Join first and last name, trimming surrounding whitespace."""
        joined = f'{obj.first_name} {obj.last_name}'
        return joined.strip()

    def get_order_count(self, obj):
        """Return ``obj.orders.count()``."""
        related_orders = obj.orders
        return related_orders.count()

    def get_is_owner(self, obj):
        """True only when the authenticated requester is this very user."""
        request = self.context.get('request')
        if not request or not request.user.is_authenticated:
            return False
        return obj.id == request.user.id


# 动态字段(根据请求参数决定返回哪些字段)
class DynamicFieldsSerializer(serializers.ModelSerializer):
    # Lets the caller choose the output fields with `fields=` / `exclude=` kwargs.

    def __init__(self, *args, **kwargs):
        """Pop ``fields`` / ``exclude`` and prune ``self.fields`` accordingly."""
        wanted = kwargs.pop('fields', None)
        unwanted = kwargs.pop('exclude', None)
        super().__init__(*args, **kwargs)

        if wanted:
            keep = set(wanted)
            # Snapshot the names first so the mapping is not mutated mid-iteration.
            surplus = [name for name in self.fields if name not in keep]
            for name in surplus:
                del self.fields[name]

        for name in (unwanted or ()):
            # Names that are not present are ignored.
            self.fields.pop(name, None)

class UserSerializer(DynamicFieldsSerializer):
    # Meta.fields is the full set; per-call fields=/exclude= kwargs narrow it.
    class Meta:
        model  = User
        fields = ['id', 'username', 'email', 'age', 'is_active']

# 使用
UserSerializer(user, fields=['id', 'username'])
UserSerializer(users, many=True, exclude=['age'])

2.7 序列化器上下文

# 视图中传递 context
class UserViewSet(viewsets.ModelViewSet):
    def get_serializer_context(self):
        # Extend the default context with a flag the serializer uses for masking.
        return {
            **super().get_serializer_context(),
            'show_sensitive': self.request.user.is_staff,
        }

# 序列化器中使用
class UserSerializer(serializers.ModelSerializer):
    phone = serializers.SerializerMethodField()

    def get_phone(self, obj):
        """Return the real phone when the context allows it, else a mask."""
        may_reveal = self.context.get('show_sensitive')
        return obj.phone if may_reveal else '***'

三、视图

3.1 APIView

from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from django.shortcuts import get_object_or_404

class UserListView(APIView):
    """Implements each HTTP method by hand; the most flexible option."""

    permission_classes = [IsAuthenticated]   # overrides the global setting

    def get(self, request):
        # List active users inside the {'code', 'data'} envelope.
        active_users = User.objects.filter(is_active=True)
        rows = UserListSerializer(
            active_users, many=True, context={'request': request}
        ).data
        return Response({'code': 0, 'data': rows})

    def post(self, request):
        # Validate and create, then answer 201 with the detail representation.
        incoming = UserCreateSerializer(data=request.data, context={'request': request})
        incoming.is_valid(raise_exception=True)
        created = incoming.save()
        body = UserDetailSerializer(created).data
        return Response(body, status=status.HTTP_201_CREATED)


class UserDetailView(APIView):
    def get_object(self, pk):
        # 404 unless an active user with this primary key exists.
        return get_object_or_404(User, pk=pk, is_active=True)

    def _write(self, request, pk, partial):
        # Shared body of PUT (partial=False) and PATCH (partial=True).
        target = self.get_object(pk)
        writer = UserDetailSerializer(target, data=request.data, partial=partial)
        writer.is_valid(raise_exception=True)
        writer.save()
        return Response({'code': 0, 'data': writer.data})

    def get(self, request, pk):
        target = self.get_object(pk)
        rendered = UserDetailSerializer(target, context={'request': request}).data
        return Response({'code': 0, 'data': rendered})

    def put(self, request, pk):
        return self._write(request, pk, partial=False)

    def patch(self, request, pk):
        return self._write(request, pk, partial=True)

    def delete(self, request, pk):
        self.get_object(pk).delete()
        return Response(status=status.HTTP_204_NO_CONTENT)

3.2 GenericAPIView + Mixin

from rest_framework.generics import GenericAPIView
from rest_framework.mixins import (
    ListModelMixin, CreateModelMixin,
    RetrieveModelMixin, UpdateModelMixin, DestroyModelMixin
)

class UserListView(ListModelMixin, CreateModelMixin, GenericAPIView):
    # Each HTTP handler just delegates to the action supplied by a mixin.
    queryset = User.objects.filter(is_active=True).order_by('-date_joined')
    serializer_class = UserSerializer

    def get(self, request, *args, **kwargs):
        return self.list(request, *args, **kwargs)      # provided by ListModelMixin

    def post(self, request, *args, **kwargs):
        return self.create(request, *args, **kwargs)    # provided by CreateModelMixin


class UserDetailView(RetrieveModelMixin, UpdateModelMixin, DestroyModelMixin, GenericAPIView):
    # Single-object endpoint: each HTTP handler delegates to a mixin action.
    queryset = User.objects.all()
    serializer_class = UserSerializer

    def get(self, request, *args, **kwargs):
        return self.retrieve(request, *args, **kwargs)

    def put(self, request, *args, **kwargs):
        return self.update(request, *args, **kwargs)

    def patch(self, request, *args, **kwargs):
        return self.partial_update(request, *args, **kwargs)

    def delete(self, request, *args, **kwargs):
        return self.destroy(request, *args, **kwargs)

3.3 通用视图(最简洁)

from rest_framework import generics

# 列表 + 创建
class UserListCreateView(generics.ListCreateAPIView):
    # List + create; only the queryset and serializer need declaring.
    queryset = User.objects.all()
    serializer_class = UserSerializer

# 详情 + 更新 + 删除
class UserRetrieveUpdateDestroyView(generics.RetrieveUpdateDestroyAPIView):
    # Retrieve + update + delete for a single user.
    queryset = User.objects.all()
    serializer_class = UserSerializer

# 其他通用视图
# generics.ListAPIView              只读列表
# generics.CreateAPIView            只创建
# generics.RetrieveAPIView          只读详情
# generics.UpdateAPIView            只更新
# generics.DestroyAPIView           只删除
# generics.RetrieveUpdateAPIView    详情 + 更新
# generics.RetrieveDestroyAPIView   详情 + 删除


# 覆盖通用视图的行为
class UserListCreateView(generics.ListCreateAPIView):
    serializer_class = UserSerializer

    def get_queryset(self):
        """Dynamic queryset: active users, narrowed by ?role= when given."""
        role = self.request.query_params.get('role')
        active = User.objects.filter(is_active=True)
        return active.filter(role=role) if role else active

    def get_serializer_class(self):
        """Use a different serializer for creating than for listing."""
        is_create = self.request.method == 'POST'
        return UserCreateSerializer if is_create else UserListSerializer

    def perform_create(self, serializer):
        """Inject extra, request-derived data when creating."""
        request = self.request
        extra = {
            'created_by': request.user,
            'ip_address': request.META.get('REMOTE_ADDR'),
        }
        serializer.save(**extra)

    def list(self, request, *args, **kwargs):
        """Wrap the stock list payload in a custom response envelope."""
        stock = super().list(request, *args, **kwargs)
        envelope = {'code': 0, 'data': stock.data}
        return Response(envelope)

3.4 ViewSet

from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.response import Response

class UserViewSet(viewsets.ViewSet):
    """Hand-written actions; the most flexible kind of ViewSet."""

    def _write(self, request, pk, partial):
        # Shared body of update (partial=False) and partial_update (partial=True).
        target = get_object_or_404(User, pk=pk)
        writer = UserSerializer(target, data=request.data, partial=partial)
        writer.is_valid(raise_exception=True)
        writer.save()
        return Response(writer.data)

    def list(self, request):
        everyone = User.objects.all()
        return Response(UserSerializer(everyone, many=True).data)

    def create(self, request):
        incoming = UserCreateSerializer(data=request.data)
        incoming.is_valid(raise_exception=True)
        incoming.save()
        return Response(incoming.data, status=201)

    def retrieve(self, request, pk=None):
        target = get_object_or_404(User, pk=pk)
        return Response(UserSerializer(target).data)

    def update(self, request, pk=None):
        return self._write(request, pk, partial=False)

    def partial_update(self, request, pk=None):
        return self._write(request, pk, partial=True)

    def destroy(self, request, pk=None):
        get_object_or_404(User, pk=pk).delete()
        return Response(status=204)

3.5 ModelViewSet

class UserViewSet(viewsets.ModelViewSet):
    """CRUD implemented automatically; the least code to write."""
    queryset = User.objects.all().order_by('-date_joined')
    serializer_class = UserSerializer
    filterset_class = UserFilter
    search_fields   = ['username', 'email', 'phone']
    ordering_fields = ['date_joined', 'username', 'age']
    ordering        = ['-date_joined']

    def get_queryset(self):
        # Non-staff users can only see their own record.
        qs = super().get_queryset()
        if not self.request.user.is_staff:
            qs = qs.filter(id=self.request.user.id)
        return qs

    def get_serializer_class(self):
        # One serializer per action; fall back to UserSerializer otherwise.
        serializer_map = {
            'list':    UserListSerializer,
            'create':  UserCreateSerializer,
            'retrieve': UserDetailSerializer,
            'update':   UserUpdateSerializer,
            'partial_update': UserUpdateSerializer,
        }
        return serializer_map.get(self.action, UserSerializer)

    def get_permissions(self):
        # Dynamic permissions: reads need login, create is open, and every
        # other action needs ownership or admin rights.
        if self.action in ['list', 'retrieve']:
            return [IsAuthenticated()]
        elif self.action == 'create':
            return [AllowAny()]
        return [IsAuthenticated(), IsOwnerOrAdmin()]

    def perform_create(self, serializer):
        # 'create' is AllowAny, so the requester may be anonymous, and an
        # AnonymousUser cannot be assigned to a user relation. Record the
        # creator only when there is a real one.
        # NOTE(review): assumes created_by may be left unset - confirm on the model.
        requester = self.request.user
        if requester.is_authenticated:
            serializer.save(created_by=requester)
        else:
            serializer.save()

    def perform_update(self, serializer):
        serializer.save(updated_by=self.request.user)

    def perform_destroy(self, instance):
        # Soft delete: flag the row inactive instead of removing it.
        instance.is_active = False
        instance.save(update_fields=['is_active'])

3.6 自定义 Action

class UserViewSet(viewsets.ModelViewSet):
    queryset = User.objects.all()
    serializer_class = UserSerializer

    # detail=True  -> /users/{id}/change-password/
    @action(detail=True, methods=['post'], url_path='change-password',
            permission_classes=[IsAuthenticated])
    def change_password(self, request, pk=None):
        # Verify the target user's old password, then set the new one.
        user = self.get_object()
        serializer = ChangePasswordSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        if not user.check_password(serializer.validated_data['old_password']):
            return Response({'msg': '原密码错误'}, status=400)
        user.set_password(serializer.validated_data['new_password'])
        user.save()
        return Response({'msg': '密码修改成功'})

    # detail=False -> /users/me/
    @action(detail=False, methods=['get', 'patch'], url_path='me',
            permission_classes=[IsAuthenticated])
    def me(self, request):
        # GET returns the requester's profile; PATCH partially updates it.
        if request.method == 'GET':
            serializer = UserDetailSerializer(request.user, context={'request': request})
            return Response(serializer.data)
        serializer = UserUpdateSerializer(request.user, data=request.data, partial=True)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data)

    # detail=False -> /users/batch-delete/
    @action(detail=False, methods=['delete'], url_path='batch-delete',
            permission_classes=[IsAdminUser])
    def batch_delete(self, request):
        # Delete the users whose ids are listed in the request body.
        ids = request.data.get('ids', [])
        if not ids:
            return Response({'msg': 'ids 不能为空'}, status=400)
        if not isinstance(ids, (list, tuple)):
            # A bare string would be iterated character by character by id__in.
            return Response({'msg': 'ids 必须是数组'}, status=400)
        # delete() returns (total, {model_label: n}). The total also counts
        # cascade-deleted related rows, so read the User entry for the message.
        _, per_model = User.objects.filter(id__in=ids).delete()
        count = per_model.get(User._meta.label, 0)
        return Response({'msg': f'已删除 {count} 个用户'})

    # Custom URL and HTTP method
    @action(detail=True, methods=['post'], url_path='avatar',
            url_name='upload-avatar', parser_classes=[MultiPartParser])
    def upload_avatar(self, request, pk=None):
        # Multipart upload; responds with the stored avatar's URL.
        user = self.get_object()
        serializer = AvatarSerializer(user, data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response({'avatar_url': user.avatar.url})

四、路由(Router)

from rest_framework.routers import DefaultRouter, SimpleRouter

router = DefaultRouter()    # includes the API root page
# router = SimpleRouter()   # no API root page

# Arguments are: URL prefix, viewset, and the basename passed as a keyword.
router.register('users',   UserViewSet,   basename='user')
router.register('orders',  OrderViewSet,  basename='order')
router.register('products', ProductViewSet, basename='product')

# urls.py
from django.urls import path, include

urlpatterns = [
    path('api/', include(router.urls)),
    # Non-ViewSet routes can be listed alongside the router's URLs
    path('api/auth/login/',   TokenObtainPairView.as_view(),  name='login'),
    path('api/auth/refresh/', TokenRefreshView.as_view(),     name='token_refresh'),
]

# DefaultRouter 自动生成的路由
# GET    /api/users/                  list
# POST   /api/users/                  create
# GET    /api/users/{id}/             retrieve
# PUT    /api/users/{id}/             update
# PATCH  /api/users/{id}/             partial_update
# DELETE /api/users/{id}/             destroy
# POST   /api/users/{id}/change-password/    自定义 detail action
# GET    /api/users/me/                       自定义 list action
# DELETE /api/users/batch-delete/             自定义 list action


# 嵌套路由(pip install drf-nested-routers)
from rest_framework_nested import routers as nested_routers

router = nested_routers.DefaultRouter()
router.register('users', UserViewSet, basename='user')

# Child router hung under the 'users' prefix of the parent router.
users_router = nested_routers.NestedDefaultRouter(router, 'users', lookup='user')
users_router.register('orders', UserOrderViewSet, basename='user-order')
# GET /api/users/{user_pk}/orders/
# GET /api/users/{user_pk}/orders/{id}/

# NOTE(review): the URLs are appended without an 'api/' prefix here; the paths
# above presumably assume this module is included under 'api/' - confirm.
urlpatterns += router.urls + users_router.urls

五、认证(Authentication)

DRF 处理请求时,request.user 是懒加载的:只有第一次被访问时才会触发认证,此时会按顺序尝试 DEFAULT_AUTHENTICATION_CLASSES 中配置的认证类,直到其中一个认证成功。

5.1 内置认证方式

TokenAuthentication(DRF 内置 Token)这种方式生成的 token 是永不过期的,如需过期或轮换,需要自行实现。

# TokenAuthentication 使用
INSTALLED_APPS += [..., 'rest_framework.authtoken']

# 会话认证(默认,Browsable API 使用)
REST_FRAMEWORK = {
    # DRF accepts several authentication classes at once; an incoming
    # request tries them one by one, in this order.
    'DEFAULT_AUTHENTICATION_CLASSES': [
        # TokenAuthentication (DRF's built-in token). The trailing comma
        # matters: without it Python concatenates this string with the next
        # one into a single, invalid dotted path.
        'rest_framework.authentication.TokenAuthentication',
        'rest_framework.authentication.SessionAuthentication',
        'rest_framework.authentication.BasicAuthentication',
    ]
}
# 创建 token 表
python manage.py migrate
# 创建接口
from rest_framework.authtoken.models import Token
token, created = Token.objects.get_or_create(user=user)
# 请求头:Authorization: Token 9944b09199c62bcf9418ad846dd0e4bbdfc6ee4f

5.2 DRF自定义认证(了解即可)

from rest_framework.authentication import BaseAuthentication
from rest_framework.exceptions import AuthenticationFailed

class APIKeyAuthentication(BaseAuthentication):
    """API key authentication."""

    def authenticate(self, request):
        supplied = request.META.get('HTTP_X_API_KEY')
        if supplied:
            try:
                record = APIKey.objects.get(key=supplied, is_active=True)
            except APIKey.DoesNotExist:
                raise AuthenticationFailed('无效的 API Key')
            # (user, auth) pair
            return (record.user, record)
        # No key header: return None so other authentication classes get a turn.
        return None

    def authenticate_header(self, request):
        # Value for the WWW-Authenticate response header.
        return 'X-API-Key'

5.3 JWT 认证(Simple JWT)

pip install djangorestframework-simplejwt
# settings.py
from datetime import timedelta

REST_FRAMEWORK = {
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework_simplejwt.authentication.JWTAuthentication',
    ]
}

SIMPLE_JWT = {
    'ACCESS_TOKEN_LIFETIME':   timedelta(hours=2),
    'REFRESH_TOKEN_LIFETIME':  timedelta(days=30),
    'ROTATE_REFRESH_TOKENS':   True,        # issue a new refresh token on every refresh
    # NOTE(review): blacklisting presumably needs the
    # 'rest_framework_simplejwt.token_blacklist' app installed - confirm.
    'BLACKLIST_AFTER_ROTATION': True,       # blacklist the old refresh token
    'ALGORITHM': 'HS256',
    'SIGNING_KEY': SECRET_KEY,
    'AUTH_HEADER_TYPES': ('Bearer',),
    'AUTH_TOKEN_CLASSES': ('rest_framework_simplejwt.tokens.AccessToken',),
    # Custom payload
    'TOKEN_OBTAIN_SERIALIZER': 'apps.users.serializers.CustomTokenObtainPairSerializer',
}
# urls.py
from rest_framework_simplejwt.views import (
    TokenObtainPairView,
    TokenRefreshView,
    TokenVerifyView,
    TokenBlacklistView,
)

# Routes for the four Simple JWT views: obtain, refresh, verify, blacklist.
urlpatterns = [
    path('api/auth/login/',     TokenObtainPairView.as_view(),    name='token_obtain'),
    path('api/auth/refresh/',   TokenRefreshView.as_view(),       name='token_refresh'),
    path('api/auth/verify/',    TokenVerifyView.as_view(),        name='token_verify'),
    path('api/auth/logout/',    TokenBlacklistView.as_view(),     name='token_blacklist'),
]
# 自定义 Token Payload(添加额外信息)

# views.py
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
from rest_framework_simplejwt.views import TokenObtainPairView

class CustomTokenObtainPairSerializer(TokenObtainPairSerializer):
    @classmethod
    def get_token(cls, user):
        # Controls the generated token: add custom claims to its payload.
        token = super().get_token(user)
        for claim in ('username', 'email', 'role'):
            token[claim] = getattr(user, claim)
        return token

    def validate(self, attrs):
        # Controls the returned JSON: attach the user's info to the response.
        payload = super().validate(attrs)
        user_info = UserListSerializer(self.user).data
        payload['user'] = user_info
        return payload

class CustomTokenObtainPairView(TokenObtainPairView):
    # Login view that uses the custom serializer instead of the stock one.
    serializer_class = CustomTokenObtainPairSerializer

# urls.py
urlpatterns = [
    path('auth/login/', CustomTokenObtainPairView.as_view()),  # 换成自定义的
]

# 结果payload 变成:
{
    "token_type": "access",
    "exp": 1234567890,
    "user_id": 1,
    "username": "john",
    "email": "john@example.com",
    "role": "admin"
}
# 结果返回给前端变成:
{
    "access": "eyJ...",
    "refresh": "eyJ...",
    "user": {
        "id": 1,
        "username": "john",
        "email": "john@example.com",
        "role": "admin"
    }
}

5.4 dj-rest-auth状态管理

pip install dj-rest-auth
# settings.py
INSTALLED_APPS = [
    ...
    'rest_framework',
    'rest_framework.authtoken',
    'dj_rest_auth',
]
# urls.py
from django.urls import path, include

urlpatterns = [
    path('api/auth/', include('dj_rest_auth.urls')),
]

会自动包含以下接口
| URL                         | 方法 | 说明     |
| --------------------------- | ---- | -------- |
| `api/auth/login/`           | POST | 登录     |
| `api/auth/logout/`          | POST | 登出     |
| `api/auth/password/change/` | POST | 改密码   |
| `api/auth/password/reset/`  | POST | 重置密码 |

改密码请求体:

{
  "old_password": "oldPass123!",
  "new_password1": "newPass456!",
  "new_password2": "newPass456!"
}

5.5 自定义视图函数修改密码

当有额外业务逻辑时(如记录改密时间、发送通知、强制重新登录),需要自定义视图。

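dj_rest_auth 默认不一定触发 AUTH_PASSWORD_VALIDATORS,建议在自定义序列化器中手动调用 Django 的 validate_password。

另外,修改密码时默认不对旧密码进行验证,除非在 settings.py 中设置 OLD_PASSWORD_FIELD_ENABLED = True(dj-rest-auth 3.0 及以上版本写在 REST_AUTH = {'OLD_PASSWORD_FIELD_ENABLED': True} 中)。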

# serializers.py
from django.contrib.auth.password_validation import validate_password
from django.core.exceptions import ValidationError as DjangoValidationError
from rest_framework import serializers


class PasswordChangeSerializerCus(serializers.Serializer):
    # Change-password payload: the old password plus the new one typed twice.
    old_password = serializers.CharField(write_only=True)
    new_password1 = serializers.CharField(write_only=True)
    new_password2 = serializers.CharField(write_only=True)

    def _requesting_user(self):
        # The user attached to the request passed in through the context.
        return self.context['request'].user

    def validate_old_password(self, value):
        """Reject the request when the old password does not match."""
        if self._requesting_user().check_password(value):
            return value
        raise serializers.ValidationError('原密码错误')

    def validate(self, data):
        """Require both new passwords to match and to pass validate_password."""
        first, second = data['new_password1'], data['new_password2']
        if first != second:
            raise serializers.ValidationError({'new_password2': '两次密码不一致'})
        try:
            # Explicitly run every validator from AUTH_PASSWORD_VALIDATORS in
            # settings.py; without this step password strength is not checked.
            validate_password(first, user=self._requesting_user())
        except DjangoValidationError as exc:
            raise serializers.ValidationError({'new_password1': exc.messages})
        return data

    def save(self):
        """Set the new password on the requesting user and save the user."""
        account = self._requesting_user()
        account.set_password(self.validated_data['new_password1'])
        account.save()
# views.py
from django.utils import timezone
from django.utils.translation import gettext_lazy
from dj_rest_auth.views import PasswordChangeView
from rest_framework.response import Response
from .serializers import PasswordChangeSerializerCus


class PasswordChangeViewCus(PasswordChangeView):
    serializer_class = PasswordChangeSerializerCus

    def post(self, request, *args, **kwargs):
        # Validate and apply the password change through the serializer.
        form = self.get_serializer(data=request.data)
        form.is_valid(raise_exception=True)
        form.save()

        # Extra business logic: record when the password was changed.
        account = request.user
        setattr(account, 'pwd_change_date', timezone.now())
        account.save()

        confirmation = gettext_lazy('New password has been saved.')
        return Response({'detail': confirmation})
# urls.py
from django.urls import path
from .views import PasswordChangeViewCus

urlpatterns = [
    path('api/auth/password/change/', PasswordChangeViewCus.as_view()),
]

六、权限(Permission)

6.1 内置权限

权限类

说明

AllowAny

所有人可访问

IsAuthenticated

需要登录

IsAdminUser

需要 is_staff=True

IsAuthenticatedOrReadOnly

登录可写,未登录只读

DjangoModelPermissions

对应 Django 模型权限

DjangoObjectPermissions

对象级权限(需配合 guardian)


6.2 自定义权限

from rest_framework.permissions import BasePermission, SAFE_METHODS

class IsOwnerOrReadOnly(BasePermission):
    """The object's owner may modify it; everyone else is read-only."""

    def has_permission(self, request, view):
        # View level: only logged-in users get in.
        requester = request.user
        return requester and requester.is_authenticated

    def has_object_permission(self, request, view, obj):
        # Object level: writes are owner-only; safe methods (GET/HEAD/OPTIONS)
        # are open to any logged-in user.
        is_write = request.method not in SAFE_METHODS
        if is_write:
            return obj.user == request.user
        return True


class IsOwnerOrAdmin(BasePermission):
    """Grants access to the object's owner or to an admin."""
    message = '只有对象所有者或管理员可以操作'

    def has_object_permission(self, request, view, obj):
        """Return True for staff, or when ``obj`` belongs to the requester."""
        if request.user.is_staff:
            return True
        # Owned objects expose their owner as `.user`. A user object has no
        # such attribute (obj.user would raise AttributeError), so it is
        # treated as its own owner.
        owner = getattr(obj, 'user', obj)
        return owner == request.user


class IsVerifiedUser(BasePermission):
    """Only users who have verified their email address."""
    message = '请先验证邮箱'

    def has_permission(self, request, view):
        # All three must hold: a user is present, logged in, and email-verified.
        requester = request.user
        return requester and requester.is_authenticated and requester.email_verified


# 权限组合(AND 使用列表,OR 使用 | 运算符)
class UserViewSet(viewsets.ModelViewSet):
    # AND combination: every class in the list must grant access
    permission_classes = [IsAuthenticated, IsVerifiedUser]

    # OR combination (DRF 3.9+): join classes with the | operator
    # from rest_framework.permissions import IsAdminUser
    # (IsOwnerOrReadOnly is the custom class from this section, not part of rest_framework.permissions)
    # permission_classes = [IsAdminUser | IsOwnerOrReadOnly]

6.3 对象级权限

class UserViewSet(viewsets.ModelViewSet):
    # Object-level checks (has_object_permission) need no get_object()
    # override: the inherited GenericAPIView.get_object() already calls
    # self.check_object_permissions(self.request, obj) before returning.
    # Repeating the call only runs every permission class twice.
    #
    # Call self.check_object_permissions(request, obj) by hand only when the
    # object is fetched without get_object(), e.g. in a plain APIView or in a
    # get_object() override that does not call super().
    permission_classes = [IsAuthenticated, IsOwnerOrAdmin]

七、分页(Pagination)

7.1 基本配置

# settings.py

REST_FRAMEWORK = {
    'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.LimitOffsetPagination',
    'PAGE_SIZE': 100
}

7.2 PageNumberPagination

from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response

class StandardPagination(PageNumberPagination):
    # Page-number pagination with a custom {'code', 'data'} response envelope.
    page_size             = 20              # default items per page
    page_size_query_param = 'size'         # ?size=50
    max_page_size         = 100
    page_query_param      = 'page'         # ?page=2

    def get_paginated_response(self, data):
        """Wrap one page of serialized ``data`` in the envelope."""
        return Response({
            'code': 0,
            'data': {
                'list':     data,
                'total':    self.page.paginator.count,
                'page':     self.page.number,
                'pages':    self.page.paginator.num_pages,
                'size':     self.get_page_size(self.request),
                'has_next': self.page.has_next(),
            }
        })

    def get_paginated_response_schema(self, schema):
        """Describe the envelope for schema generation.

        Kept in step with get_paginated_response(): every key returned there,
        including 'size' and 'has_next', is declared here.
        """
        return {
            'type': 'object',
            'properties': {
                'code':     {'type': 'integer'},
                'data': {
                    'type': 'object',
                    'properties': {
                        'list':     schema,
                        'total':    {'type': 'integer'},
                        'page':     {'type': 'integer'},
                        'pages':    {'type': 'integer'},
                        'size':     {'type': 'integer'},
                        'has_next': {'type': 'boolean'},
                    }
                }
            }
        }

7.3 LimitOffsetPagination

from rest_framework.pagination import LimitOffsetPagination

class LargeResultPagination(LimitOffsetPagination):
    """Limit/offset pagination: 20 rows by default, at most 200 per request."""

    max_limit = 200        # upper bound for ?limit=
    default_limit = 20     # used when ?limit= is absent

# 请求:GET /api/users/?limit=20&offset=40

7.4 CursorPagination

from rest_framework.pagination import CursorPagination

class CreatedAtCursorPagination(CursorPagination):
    """Cursor pagination over ``created_at``, newest first, 20 items per page."""

    cursor_query_param = 'cursor'
    # The ordering field should be unique (or nearly so) and never change
    # after creation, otherwise the cursor position becomes ambiguous.
    ordering = '-created_at'
    page_size = 20

# 适合:实时Feed、无限滚动。不支持跳转到指定页,但防止数据重复

7.5 在视图中使用分页

class UserViewSet(viewsets.ModelViewSet):
    """User endpoints that use StandardPagination instead of the global default."""

    # View-level setting; takes precedence over DEFAULT_PAGINATION_CLASS.
    pagination_class = StandardPagination

# 某些视图禁用分页
class AllUsersView(generics.ListAPIView):
    """List view that returns the full, unpaginated result set."""

    # None switches pagination off for this view only.
    pagination_class = None

# 手动分页(在 APIView 中)
class UserListView(APIView):
    """Manual pagination inside a plain APIView."""

    def get(self, request):
        """Return one page of users in the StandardPagination envelope."""
        # Paginating an unordered queryset yields unstable pages (rows can be
        # repeated or skipped) and triggers Django's
        # UnorderedObjectListWarning, so impose a deterministic order first.
        users = User.objects.order_by('pk')
        paginator = StandardPagination()
        # Pass the view so the paginator has the same context it would get
        # from a generic view.
        page = paginator.paginate_queryset(users, request, view=self)
        serializer = UserSerializer(page, many=True)
        return paginator.get_paginated_response(serializer.data)

八、过滤、搜索、排序

8.1 配置

# settings.py
INSTALLED_APPS = [
    ...,
    'django_filters',
]

REST_FRAMEWORK = {
    # Filter backends applied to every view unless the view overrides
    # filter_backends.
    'DEFAULT_FILTER_BACKENDS': [
        'django_filters.rest_framework.DjangoFilterBackend',
        'rest_framework.filters.SearchFilter',
        'rest_framework.filters.OrderingFilter',
        'your.app.custom.filterclass',
    ],  # the list was previously left unclosed, which is a syntax error
}

8.2 使用

from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.filters import SearchFilter, OrderingFilter

class UserViewSet(viewsets.ModelViewSet):
    """User endpoints combining structured filtering, search and ordering.

    Example request:
    GET /api/users/?username=xiaoming&email=136@qq.com&search=18&ordering=-created_at,username
    """

    # Each backend only reads its own query parameters.
    filter_backends = [DjangoFilterBackend, SearchFilter, OrderingFilter]
    filterset_class = UserFilter  # the custom FilterSet for structured field filters

    # SearchFilter: a single ?search= parameter matched against several
    # fields; the parameter cannot be repeated with different meanings.
    search_fields = [
        'username',             # no prefix: case-insensitive partial match (icontains)
        'email',
        '=phone',               # "=" exact match
        '^username',            # "^" starts-with
        '@bio',                 # "@" full-text search (PostgreSQL backend only)
        'profile__bio',         # double underscore follows a relation
    ]

    # OrderingFilter: fields the client may sort by via ?ordering=
    ordering_fields = ['created_at', 'username', 'age', 'last_login']
    ordering        = ['-created_at']    # default ordering

8.3 自定义 Filter

import django_filters
from django_filters import rest_framework as filters
from django.db.models import Q

class UserFilter(filters.FilterSet):
    """FilterSet exposing the structured query parameters for users."""

    # --- exact matches ---------------------------------------------------
    is_active = filters.BooleanFilter()
    role = filters.CharFilter()

    # --- case-insensitive substring matches ------------------------------
    username = filters.CharFilter(lookup_expr='icontains')
    email = filters.CharFilter(lookup_expr='icontains')

    # --- numeric range on age --------------------------------------------
    min_age = filters.NumberFilter(field_name='age', lookup_expr='gte')
    max_age = filters.NumberFilter(field_name='age', lookup_expr='lte')

    # --- date ranges on created_at ---------------------------------------
    created_after = filters.DateFilter(field_name='created_at', lookup_expr='date__gte')
    created_before = filters.DateFilter(field_name='created_at', lookup_expr='date__lte')
    created_range = filters.DateFromToRangeFilter(field_name='created_at')

    # --- comma-separated IN lookup, e.g. ?ids=1,2,3 ----------------------
    ids = filters.BaseInFilter(field_name='id', lookup_expr='in')

    # --- delegated to filter_keyword() below -----------------------------
    keyword = filters.CharFilter(method='filter_keyword')

    def filter_keyword(self, queryset, name, value):
        """Keep rows whose username, email or phone contains ``value``."""
        matches_any = (
            Q(username__icontains=value)
            | Q(email__icontains=value)
            | Q(phone__icontains=value)
        )
        return queryset.filter(matches_any)

    class Meta:
        model = User
        fields = ['is_active', 'role', 'username', 'email']


class UserViewSet(viewsets.ModelViewSet):
    """User endpoints filtered through UserFilter.

    Example request:
    GET /api/users/?username=alice&min_age=18&is_active=true&created_range_after=2024-01-01
    """

    filterset_class = UserFilter

九、限流(Throttling)

DRF 的规则是:所有限流类都会检查,任意一个不通过就返回 429

9.1 基本使用

# 全局配置
REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_CLASSES': [
        # These two take effect on every view as soon as they are listed
        # here; nothing needs to be written on the view.
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle',

        # This one is only "registered" globally; a view must also set
        # throttle_scope for it to apply.
        # NOTE: the commas matter. Without them Python concatenates the
        # adjacent string literals into a single, invalid class path.
        'rest_framework.throttling.ScopedRateThrottle',
    ],
    'DEFAULT_THROTTLE_RATES': {  # the keys are just names (scopes)
        'anon': '100/day',          # anonymous clients: 100 requests per day
        'user': '1000/day',         # authenticated users: 1000 requests per day
        'cus_name': '10/minute',    # rate for the custom scope "cus_name"
    },
}

# 视图级限流
from rest_framework.throttling import UserRateThrottle, AnonRateThrottle, ScopedRateThrottle
from rest_framework.views import APIView

class MyView(APIView):
    """View with its own throttling configuration."""

    # Setting throttle_classes replaces the global DEFAULT_THROTTLE_CLASSES
    # for this view. ScopedRateThrottle must be listed here, otherwise
    # throttle_scope below is silently ignored.
    throttle_classes = [UserRateThrottle, ScopedRateThrottle]
    # Selects the 'cus_name' entry of DEFAULT_THROTTLE_RATES.
    throttle_scope = 'cus_name'

9.2 原理

核心:计数器 + 缓存。本质就是在缓存里记录"某个用户/IP 在某个时间窗口内请求了多少次"。

# 缓存 key 示例
"throttle_user_123"      # 用户ID为123的请求记录
"throttle_anon_1.2.3.4"  # IP为1.2.3.4的请求记录

# 缓存 value:该用户所有请求的时间戳列表
[1709001600.0, 1709001610.0, 1709001620.0, ...]


# 每次请求的处理流程:

新请求进来
    ↓
从缓存取出该用户的时间戳列表
    ↓
清除列表中"时间窗口之外"的过期时间戳
    ↓
剩余数量 >= 限制次数?
    ↓              ↓
   是              否
   ↓               ↓
拒绝请求        追加当前时间戳,存回缓存,放行

9.3 自定义限流

from rest_framework.throttling import (
    AnonRateThrottle, UserRateThrottle,
    ScopedRateThrottle, BaseThrottle
)

# 自定义限流(基于 IP)
class CustomIPThrottle(BaseThrottle):
    """In-memory throttle keyed by client identity: 60 requests per 60 seconds.

    DRF creates a new throttle instance for every request, so the request
    log has to live on the class; a dict created in __init__ would be empty
    on each request and nothing would ever be throttled.

    NOTE: the log lives in process memory. It is not shared between worker
    processes and entries for idle clients are never evicted, so this is
    only suitable for demos or single-process setups.
    """

    cache  = {}    # ident -> list of request timestamps, shared by all instances
    rate   = 60    # allowed requests per window
    period = 60    # window length in seconds

    def __init__(self):
        # State of the most recent allow_request() call, read by wait().
        self.history = []
        self.now = None

    def allow_request(self, request, view):
        """Return True and record the request if the client is under the limit."""
        import time

        ident = self.get_ident(request)
        self.now = time.time()

        # Drop timestamps that have fallen out of the window.
        recent = [t for t in self.cache.get(ident, []) if self.now - t < self.period]
        self.cache[ident] = recent
        self.history = recent

        if len(recent) >= self.rate:
            return False

        recent.append(self.now)
        return True

    def wait(self):
        """Return the seconds until the oldest recorded request leaves the window."""
        if not self.history or self.now is None:
            return self.period
        # Timestamps are appended chronologically, so index 0 is the oldest.
        remaining = self.period - (self.now - self.history[0])
        return max(remaining, 0)


# 基于 Redis 的自定义限流(生产推荐)
import redis
from django.core.cache import caches
from rest_framework.throttling import SimpleRateThrottle

class RedisRateThrottle(SimpleRateThrottle):
    """Sliding-window rate limit whose request log is stored in Redis.

    SimpleRateThrottle talks to ``self.cache`` through Django's cache API
    (``get(key, default)`` / ``set(key, value, timeout)``). A raw
    ``redis.Redis`` client has different signatures and cannot store a
    Python list, so the throttle must be given a Django cache alias backed
    by Redis instead, e.g. in settings.py:

        CACHES = {
            ...,
            'throttle': {
                'BACKEND': 'django.core.cache.backends.redis.RedisCache',
                'LOCATION': 'redis://localhost:6379/2',
            },
        }

    The rate itself comes from DEFAULT_THROTTLE_RATES['redis_throttle'].
    """
    scope = 'redis_throttle'
    cache = caches['throttle']

    def get_cache_key(self, request, view):
        """Return the cache key for this client, built from scope and identity."""
        ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}

十、解析器与渲染器

10.1 解析器

from rest_framework.parsers import JSONParser, MultiPartParser, FormParser, FileUploadParser

# 全局配置
# Parsers tried for every request body unless a view sets parser_classes.
REST_FRAMEWORK = dict(
    DEFAULT_PARSER_CLASSES=[
        'rest_framework.parsers.JSONParser',
        'rest_framework.parsers.MultiPartParser',
        'rest_framework.parsers.FormParser',
    ],
)

# 视图级配置
class FileUploadView(APIView):
    """Accepts a multipart/form upload in the ``file`` field."""

    parser_classes = [MultiPartParser, FormParser]

    def post(self, request, format=None):
        """Return the uploaded file's name, or 400 when no file was sent."""
        # Indexing request.FILES['file'] raises MultiValueDictKeyError when
        # the part is missing, which surfaces as a 500; report a client
        # error instead.
        upload = request.FILES.get('file')
        if upload is None:
            return Response({'detail': 'No file was submitted in the "file" field.'}, status=400)
        # Process the upload here...
        return Response({'filename': upload.name})

10.2 渲染器

from rest_framework.renderers import JSONRenderer, BrowsableAPIRenderer

# 全局
# JSON is the only renderer enabled globally.
REST_FRAMEWORK = dict(
    DEFAULT_RENDERER_CLASSES=['rest_framework.renderers.JSONRenderer'],
)

# 自定义 JSON 渲染器(统一响应格式)
class StandardJSONRenderer(JSONRenderer):
    """JSON renderer that wraps successful payloads as ``{code, msg, data}``."""

    def render(self, data, accepted_media_type=None, renderer_context=None):
        """Render ``data``, adding the standard envelope where appropriate.

        The payload is rendered unchanged when:
        - the response is an error (status >= 400);
        - the response is 204 No Content, which must not carry a body;
        - the payload already has the envelope (both ``code`` and ``data``
          keys), as produced by StandardPagination or APIResponse.
        """
        # renderer_context defaults to None, so guard before calling .get().
        response = (renderer_context or {}).get('response')
        status_code = response.status_code if response is not None else None

        is_error = status_code is not None and status_code >= 400
        is_empty = status_code == 204
        already_wrapped = isinstance(data, dict) and 'code' in data and 'data' in data

        if is_error or is_empty or already_wrapped:
            return super().render(data, accepted_media_type, renderer_context)

        wrapped = {
            'code': 0,
            'msg': 'ok',
            'data': data
        }
        return super().render(wrapped, accepted_media_type, renderer_context)

十一、异常处理

# apps/common/exceptions.py
from rest_framework.views import exception_handler
from rest_framework.exceptions import (
    APIException, ValidationError, NotAuthenticated,
    PermissionDenied, NotFound, MethodNotAllowed, Throttled
)
from rest_framework.response import Response

def custom_exception_handler(exc, context):
    """Convert any exception raised in a view into the ``{code, msg}`` format.

    Args:
        exc: the exception raised while handling the request.
        context: the context dict DRF passes to exception handlers.

    Returns:
        A Response. Exceptions DRF does not handle become a logged 500.
    """
    response = exception_handler(exc, context)

    if response is None:
        # Not handled by DRF (unexpected server error): log with traceback.
        import logging
        logger = logging.getLogger(__name__)
        logger.error('Unhandled exception: %s', exc, exc_info=exc)
        return Response({
            'code': 500,
            'msg': '服务器内部错误'
        }, status=500)

    # Normalise the body of every handled error.
    if isinstance(exc, ValidationError):
        # Keep the HTTP status in line with the body code, as every other
        # branch does (DRF's default for ValidationError is 400).
        response.status_code = 422
        response.data = {
            'code': 422,
            'msg': '参数验证失败',
            'errors': format_validation_errors(exc.detail)
        }
    elif isinstance(exc, NotAuthenticated):
        response.data = {'code': 401, 'msg': '请先登录'}
    elif isinstance(exc, PermissionDenied):
        response.data = {'code': 403, 'msg': str(exc.detail)}
    elif isinstance(exc, NotFound) or response.status_code == 404:
        # The status check also covers django.http.Http404, which DRF
        # converts internally while ``exc`` here is still the Django exception.
        response.data = {'code': 404, 'msg': '资源不存在'}
    elif isinstance(exc, MethodNotAllowed):
        response.data = {'code': 405, 'msg': '请求方法不允许'}
    elif isinstance(exc, Throttled):
        # exc.wait is None when the throttle cannot say how long to wait;
        # formatting None with ':.0f' would raise TypeError.
        if exc.wait is None:
            msg = '请求过于频繁,请稍后重试'
        else:
            msg = f'请求过于频繁,请 {exc.wait:.0f} 秒后重试'
        response.data = {'code': 429, 'msg': msg}
    else:
        # Django's own PermissionDenied has no ``detail``; fall back to the
        # detail DRF already put in the response body.
        detail = getattr(exc, 'detail', None)
        if detail is None and isinstance(response.data, dict):
            detail = response.data.get('detail')
        response.data = {
            'code': response.status_code,
            'msg': str(detail) if detail is not None else '请求失败'
        }

    return response


def format_validation_errors(detail):
    """Convert a validation error detail into plain dicts, lists and strings.

    The nesting of ``detail`` is preserved: dicts map field names to their
    converted errors, lists are converted item by item (so per-item error
    dicts from ``many=True`` serializers stay dicts), and any other value
    is turned into ``str``.
    """
    if isinstance(detail, dict):
        return {
            field: format_validation_errors(messages)
            for field, messages in detail.items()
        }
    if isinstance(detail, (list, tuple)):
        return [format_validation_errors(item) for item in detail]
    # A single message; previously this case was lost as an empty dict.
    return str(detail)


# 自定义业务异常
class BusinessException(APIException):
    """Business-rule failure carrying a message and an HTTP status.

    ``code`` is used as the HTTP status code of the response (400 when
    omitted); ``message`` becomes the exception detail.
    """

    status_code = 400

    def __init__(self, message: str, code: int = 400):
        super().__init__(detail=message)
        # Per-instance override of the class-level default status.
        self.status_code = code

# Usage (two independent examples; each would be raised from view or service code)
raise BusinessException('用户名已存在', code=409)
raise BusinessException('余额不足', code=400)

十二、Versioning 版本控制

# settings.py
# API versioning: the version is read from the URL path; v1 is the fallback.
REST_FRAMEWORK = dict(
    DEFAULT_VERSIONING_CLASS='rest_framework.versioning.URLPathVersioning',
    ALLOWED_VERSIONS=['v1', 'v2'],
    DEFAULT_VERSION='v1',
    VERSION_PARAM='version',
)

# URLPathVersioning 路由
# /api/v1/users/
# /api/v2/users/
# The <str:version> segment supplies the "version" URL keyword argument.
urlpatterns = [path('api/<str:version>/', include('apps.users.urls'))]

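# NamespaceVersioning(URL 命名空间:URL 形式与 URLPathVersioning 相同,版本取自 include(..., namespace='v1'))
# /v1/users/

# QueryParameterVersioning(查询参数)
# /api/users/?version=v2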

# AcceptHeaderVersioning(请求头)
# Accept: application/json; version=v1

# HostNameVersioning(子域名)
# v1.api.yourdomain.com

# 视图中根据版本切换逻辑
class UserViewSet(viewsets.ModelViewSet):
    """User endpoints whose behaviour depends on the requested API version."""

    def get_serializer_class(self):
        """Pick the v2 serializer for v2 requests, the default one otherwise."""
        return UserV2Serializer if self.request.version == 'v2' else UserSerializer

    def list(self, request, *args, **kwargs):
        """List users; the v2 branch is a placeholder for version-specific logic."""
        if request.version == 'v2':
            # v2-specific behaviour goes here
            pass
        return super().list(request, *args, **kwargs)

十三、测试

13.1 接口测试

from django.test import TestCase
from django.contrib.auth import get_user_model
from rest_framework.test import APITestCase, APIClient
from rest_framework import status

User = get_user_model()

class AuthAPITest(APITestCase):
    """Login endpoint tests for one pre-created user."""

    def setUp(self):
        self.user = User.objects.create_user(
            username='testuser',
            email='test@example.com',
            password='testpass123',
        )

    def _login(self, password):
        """POST the test user's e-mail with ``password`` to the login endpoint."""
        credentials = {'email': 'test@example.com', 'password': password}
        return self.client.post('/api/auth/login/', credentials, format='json')

    def get_token(self):
        """Return a JWT access token for the test user."""
        return self._login('testpass123').data['access']

    def test_login_success(self):
        response = self._login('testpass123')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        for token_name in ('access', 'refresh'):
            self.assertIn(token_name, response.data)

    def test_login_wrong_password(self):
        response = self._login('wrongpass')
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)


class UserAPITest(APITestCase):
    """CRUD and permission tests for the /api/users/ endpoints."""

    def setUp(self):
        self.user = User.objects.create_user(
            username='testuser',
            email='test@example.com',
            password='testpass123',
        )
        self.admin = User.objects.create_superuser(
            username='admin',
            email='admin@example.com',
            password='adminpass123',
        )
        self.client = APIClient()

    def authenticate(self, user=None):
        """Force-authenticate the client as ``user`` (the regular user by default)."""
        acting_user = user or self.user
        self.client.force_authenticate(user=acting_user)

    def test_list_users_unauthenticated(self):
        result = self.client.get('/api/users/')
        self.assertEqual(result.status_code, status.HTTP_401_UNAUTHORIZED)

    def test_list_users_authenticated(self):
        self.authenticate()
        result = self.client.get('/api/users/')
        self.assertEqual(result.status_code, status.HTTP_200_OK)
        self.assertIn('data', result.data)

    def test_create_user(self):
        payload = dict(
            username='newuser',
            email='new@example.com',
            password='newpass123',
            confirm_password='newpass123',
        )
        result = self.client.post('/api/users/', payload, format='json')
        self.assertEqual(result.status_code, status.HTTP_201_CREATED)
        created = User.objects.filter(username='newuser')
        self.assertTrue(created.exists())

    def test_create_user_duplicate_email(self):
        payload = dict(
            username='another',
            email='test@example.com',  # already taken by self.user
            password='pass123',
            confirm_password='pass123',
        )
        result = self.client.post('/api/users/', payload, format='json')
        self.assertEqual(result.status_code, status.HTTP_422_UNPROCESSABLE_ENTITY)

    def test_update_user_by_owner(self):
        self.authenticate()
        detail_url = f'/api/users/{self.user.id}/'
        result = self.client.patch(detail_url, {'username': 'updated_name'}, format='json')
        self.assertEqual(result.status_code, status.HTTP_200_OK)
        self.user.refresh_from_db()
        self.assertEqual(self.user.username, 'updated_name')

    def test_update_user_by_other(self):
        intruder = User.objects.create_user(username='other', email='other@a.com', password='pass')
        self.authenticate(user=intruder)
        detail_url = f'/api/users/{self.user.id}/'
        result = self.client.patch(detail_url, {'username': 'hacked'}, format='json')
        self.assertEqual(result.status_code, status.HTTP_403_FORBIDDEN)

    def test_custom_action_me(self):
        self.authenticate()
        result = self.client.get('/api/users/me/')
        self.assertEqual(result.status_code, status.HTTP_200_OK)
        self.assertEqual(result.data['username'], self.user.username)


class PaginationTest(APITestCase):
    """Checks the paginated envelope returned by the user list."""

    def setUp(self):
        self.user = User.objects.create_superuser('admin', 'admin@test.com', 'admin123')
        self.client.force_authenticate(user=self.user)
        # 25 regular users on top of the admin created above.
        for index in range(25):
            name = f'user{index}'
            User.objects.create_user(name, f'{name}@test.com', 'pass')

    def test_pagination(self):
        result = self.client.get('/api/users/?page=1&size=10')
        self.assertEqual(result.status_code, status.HTTP_200_OK)
        envelope = result.data['data']
        self.assertEqual(len(envelope['list']), 10)
        self.assertGreater(envelope['total'], 20)
        self.assertTrue(envelope['has_next'])

13.2 Serializer 测试

class UserSerializerTest(TestCase):
    """Validation tests for UserCreateSerializer."""

    def test_valid_data(self):
        payload = dict(
            username='alice',
            email='alice@example.com',
            password='pass1234',
            confirm_password='pass1234',
        )
        self.assertTrue(UserCreateSerializer(data=payload).is_valid())

    def test_password_mismatch(self):
        payload = dict(
            username='alice',
            email='alice@example.com',
            password='pass1234',
            confirm_password='different',
        )
        candidate = UserCreateSerializer(data=payload)
        self.assertFalse(candidate.is_valid())
        self.assertIn('non_field_errors', candidate.errors)

    def test_invalid_email(self):
        payload = dict(username='alice', email='not-an-email', password='pass1234')
        candidate = UserCreateSerializer(data=payload)
        self.assertFalse(candidate.is_valid())
        self.assertIn('email', candidate.errors)

十四、最佳实践与常见模式

14.1 统一响应格式

# utils/response.py
from rest_framework.response import Response

class APIResponse(Response):
    """Response whose body is always ``{'code': ..., 'msg': ..., 'data': ...}``."""

    def __init__(self, data=None, code=0, msg='ok', status=None, **kwargs):
        # Build the envelope first, then let Response handle everything else.
        envelope = dict(code=code, msg=msg, data=data)
        super().__init__(envelope, status=status, **kwargs)

# Usage (inside a view method): a success response and an error response
return APIResponse(data=serializer.data, status=201)
return APIResponse(data=None, code=404, msg='用户不存在', status=404)

14.2 Mixin 封装通用逻辑

class SoftDeleteMixin:
    """View mixin that flags rows as deleted instead of removing them."""

    def perform_destroy(self, instance):
        """Set the deletion flag and timestamp, saving only those two columns."""
        changed_columns = ['is_deleted', 'deleted_at']
        instance.is_deleted = True
        instance.deleted_at = timezone.now()
        instance.save(update_fields=changed_columns)

class AuditMixin:
    """View mixin that records which user created or updated an object."""

    def perform_create(self, serializer):
        """Save the new object with the requesting user as ``created_by``."""
        actor = self.request.user
        serializer.save(created_by=actor)

    def perform_update(self, serializer):
        """Save the changes with the requesting user as ``updated_by``."""
        actor = self.request.user
        serializer.save(updated_by=actor)

class UserViewSet(AuditMixin, SoftDeleteMixin, viewsets.ModelViewSet):
    """User endpoints with audit fields and soft deletion; hides deleted rows."""

    serializer_class = UserSerializer
    queryset = User.objects.filter(is_deleted=False)

14.3 ViewSet 动态 Serializer / QuerySet / Permission

class FlexibleViewSet(viewsets.ModelViewSet):
    """Keeps the per-action serializer, queryset and permissions in one place.

    Each map is keyed by action name; the queryset and permission maps fall
    back to their ``'default'`` entry, the serializer map to UserSerializer.
    """

    serializer_action_map = {
        'list':           UserListSerializer,
        'create':         UserCreateSerializer,
        'retrieve':       UserDetailSerializer,
        'update':         UserUpdateSerializer,
        'partial_update': UserUpdateSerializer,
    }

    queryset_action_map = {
        'list': User.objects.filter(is_active=True).select_related('profile'),
        'default': User.objects.all(),
    }

    permission_action_map = {
        'create':  [AllowAny],
        'destroy': [IsAdminUser],
        'default': [IsAuthenticated],
    }

    def get_serializer_class(self):
        """Return the serializer registered for the current action."""
        return self.serializer_action_map.get(self.action, UserSerializer)

    def get_queryset(self):
        """Return a fresh queryset for the current action.

        The querysets in queryset_action_map are created once, at class
        definition. A QuerySet caches its rows after the first evaluation,
        so handing out the shared object could serve stale data on later
        requests. ``.all()`` returns an unevaluated clone, which is what
        GenericAPIView.get_queryset() does for the ``queryset`` attribute.
        """
        base = self.queryset_action_map.get(self.action, self.queryset_action_map['default'])
        return base.all()

    def get_permissions(self):
        """Instantiate the permission classes registered for the current action."""
        classes = self.permission_action_map.get(
            self.action,
            self.permission_action_map['default']
        )
        return [c() for c in classes]

14.4 OpenAPI 文档(drf-spectacular)

# pip install drf-spectacular

# settings.py
# drf-spectacular metadata for the generated OpenAPI document.
SPECTACULAR_SETTINGS = dict(
    TITLE='My API',
    DESCRIPTION='API 文档',
    VERSION='1.0.0',
    SERVE_INCLUDE_SCHEMA=False,
)

# urls.py
from drf_spectacular.views import SpectacularAPIView, SpectacularSwaggerView, SpectacularRedocView

# Documentation routes: the raw schema plus the Swagger UI and ReDoc front ends.
urlpatterns += [
    path('api/schema/', SpectacularAPIView.as_view(), name='schema'),
    path('api/docs/', SpectacularSwaggerView.as_view(), name='swagger-ui'),
    path('api/redoc/', SpectacularRedocView.as_view(), name='redoc'),
]

# 给视图加文档注解
from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiExample

# Schema annotations for the user list: summary, query parameters, response, tag.
@extend_schema(
    tags=['用户管理'],
    summary='获取用户列表',
    description='支持分页、过滤、排序',
    responses={200: UserListSerializer(many=True)},
    parameters=[
        OpenApiParameter('page', int, description='页码'),
        OpenApiParameter('search', str, description='搜索关键词'),
    ],
)
class UserListView(generics.ListAPIView):
    ...

14.5 常见问题速查

# 1. Get the currently logged-in user
request.user

# 2. Pass context to a Serializer
serializer = UserSerializer(user, context={'request': request})

# 3. Trigger the object-level permission check in a view
obj = self.get_object()   # get_object() calls check_object_permissions internally

# 4. 关闭某个视图的分页
class AllDataView(generics.ListAPIView):
    """List view with pagination switched off."""

    pagination_class = None

# 5. Get a filtered queryset (inside an APIView); `self` is the view instance
from rest_framework.filters import SearchFilter
backend = SearchFilter()
queryset = backend.filter_queryset(request, User.objects.all(), self)

# 6. Validate a Serializer manually and return its errors
serializer = UserSerializer(data=request.data)
if not serializer.is_valid():
    return Response({'errors': serializer.errors}, status=400)

# 7. 序列化时排除某字段
class UserSerializer(serializers.ModelSerializer):
    """Serializes every User field except the ones listed in ``exclude``."""

    class Meta:
        # Sensitive fields are left out of the serialized output.
        exclude = ['password', 'is_superuser']
        model = User

# 8. Return 201 with a Location header
response = Response(serializer.data, status=status.HTTP_201_CREATED)
response['Location'] = request.build_absolute_uri(f'/api/users/{instance.id}/')
return response

# 9. Access the request object inside a serializer
user = self.context['request'].user

# 10. Bulk-create and return
# NOTE(review): bulk_create builds the models straight from validated_data and
# does not go through serializer.save(); confirm that any per-object logic
# (e.g. password hashing for a user model) is not needed here.
serializer = UserSerializer(data=request.data, many=True)
serializer.is_valid(raise_exception=True)
User.objects.bulk_create([User(**item) for item in serializer.validated_data])
return Response(serializer.data, status=201)

常用扩展汇总

| 包名 | 功能 |
| --- | --- |
| djangorestframework | DRF 核心 |
| djangorestframework-simplejwt | JWT 认证 |
| django-filter | 高级过滤 |
| drf-spectacular | OpenAPI 3.0 文档生成 |
| drf-nested-routers | 嵌套路由 |
| django-guardian | 对象级权限 |
| djangorestframework-camel-case | camelCase 字段命名 |
| drf-extensions | 缓存、etag 等扩展 |


参考资源