Django REST Framework(DRF)¶
深入掌握 DRF 核心组件:Serializer、视图、路由、认证、权限、分页、过滤、限流、测试等。以 DRF 3.15.x + Django 5.x 为基准。
目录¶
一、安装与基础配置¶
pip install djangorestframework
pip install djangorestframework-simplejwt
pip install django-filter
pip install drf-spectacular # OpenAPI 文档(推荐替代 drf-yasg)
# settings.py
INSTALLED_APPS = [
    # ... Django's built-in apps and your own apps go here ...
    'rest_framework',
    'rest_framework_simplejwt',
    'django_filters',
    'drf_spectacular',
]
REST_FRAMEWORK = {
    # Authentication
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework_simplejwt.authentication.JWTAuthentication',
        'rest_framework.authentication.SessionAuthentication',  # used by the Browsable API
    ],
    # Permissions (global default; individual views may override)
    'DEFAULT_PERMISSION_CLASSES': [
        'rest_framework.permissions.IsAuthenticated',
    ],
    # Pagination
    'DEFAULT_PAGINATION_CLASS': 'apps.common.pagination.StandardPagination',
    'PAGE_SIZE': 20,
    # Filtering / search / ordering
    'DEFAULT_FILTER_BACKENDS': [
        'django_filters.rest_framework.DjangoFilterBackend',
        'rest_framework.filters.SearchFilter',
        'rest_framework.filters.OrderingFilter',
    ],
    # Throttling
    'DEFAULT_THROTTLE_CLASSES': [
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle',
    ],
    'DEFAULT_THROTTLE_RATES': {
        'anon': '100/day',
        'user': '1000/day',
        # Scope rate: only used by views that set throttle_scope = 'login' and enable
        # ScopedRateThrottle (it is not in DEFAULT_THROTTLE_CLASSES above).
        'login': '10/minute',
    },
    # Exception handling
    'EXCEPTION_HANDLER': 'apps.common.exceptions.custom_exception_handler',
    # Renderers
    'DEFAULT_RENDERER_CLASSES': [
        'rest_framework.renderers.JSONRenderer',
        # Remove BrowsableAPIRenderer in production
        'rest_framework.renderers.BrowsableAPIRenderer',
    ],
    # Schema (OpenAPI docs)
    'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
}
二、Serializer 序列化器¶
2.1 Serializer 基础¶
from rest_framework import serializers
class UserSerializer(serializers.Serializer):
    """Plain (non-model) user serializer with hand-written validation, create and update."""

    id = serializers.IntegerField(read_only=True)
    username = serializers.CharField(max_length=50)
    email = serializers.EmailField()
    password = serializers.CharField(write_only=True, min_length=6)
    age = serializers.IntegerField(required=False, allow_null=True)

    # Single-field validation (validate_<fieldname>)
    def validate_username(self, value):
        from apps.users.models import User
        # Normalise first so the uniqueness check sees exactly the value that gets stored.
        value = value.strip()
        qs = User.objects.filter(username=value)
        # On update (serializer bound to an instance), the user's own current
        # username must not be reported as a duplicate.
        if self.instance is not None:
            qs = qs.exclude(pk=self.instance.pk)
        if qs.exists():
            raise serializers.ValidationError('用户名已存在')
        return value

    # Cross-field validation
    def validate(self, data):
        age = data.get('age')
        if age is not None and age < 0:
            raise serializers.ValidationError({'age': '年龄不能为负数'})
        return data

    # Create: create_user() hashes the password
    def create(self, validated_data):
        from apps.users.models import User
        return User.objects.create_user(**validated_data)

    # Update
    def update(self, instance, validated_data):
        # Never setattr the raw password: it would be stored in plain text.
        password = validated_data.pop('password', None)
        for key, value in validated_data.items():
            setattr(instance, key, value)
        if password is not None:
            instance.set_password(password)
        instance.save()
        return instance
# Usage
serializer = UserSerializer(data=request.data)
serializer.is_valid(raise_exception=True)  # invalid input raises ValidationError -> HTTP 400
user = serializer.save()  # no instance was passed, so save() calls create()
# Serialize one object
serializer = UserSerializer(user)
serializer.data  # dict
# Serialize a list
serializer = UserSerializer(users, many=True)
serializer.data  # list of dict
# Partial update (PATCH): fields missing from request.data are left unchanged
serializer = UserSerializer(user, data=request.data, partial=True)
serializer.is_valid(raise_exception=True)
serializer.save()  # an instance was passed, so save() calls update()
2.2 ModelSerializer¶
from rest_framework import serializers
from apps.users.models import User, Profile
class UserCreateSerializer(serializers.ModelSerializer):
    """Registration serializer: password must be confirmed, email must be unused."""

    # Input-only helper field; not a model field, removed again in validate().
    confirm_password = serializers.CharField(write_only=True)

    class Meta:
        model = User
        fields = ['id', 'username', 'email', 'password', 'confirm_password', 'age']
        read_only_fields = ['id']
        extra_kwargs = {
            'password': {'write_only': True, 'min_length': 6},
            # Drop the auto-generated UniqueValidator; uniqueness is checked
            # case-insensitively in validate_email() instead.
            'email': {'validators': []},
        }

    def validate_email(self, value):
        if User.objects.filter(email__iexact=value).exists():
            raise serializers.ValidationError('邮箱已注册')
        return value

    def validate(self, data):
        # pop() so create_user() does not receive an unexpected keyword argument.
        if data['password'] != data.pop('confirm_password'):
            raise serializers.ValidationError({'confirm_password': '两次密码不一致'})
        return data

    def create(self, validated_data):
        # create_user() hashes the password.
        return User.objects.create_user(**validated_data)
class UserListSerializer(serializers.ModelSerializer):
    """Compact representation for list endpoints."""

    class Meta:
        model = User
        fields = ('id', 'username', 'email', 'is_active', 'created_at')


class UserDetailSerializer(serializers.ModelSerializer):
    """Complete representation for detail endpoints, including the nested profile."""

    # Nested and read-only: rendered in responses, ignored on input.
    profile = ProfileSerializer(read_only=True)

    class Meta:
        model = User
        fields = (
            'id', 'username', 'email', 'age', 'is_active',
            'profile', 'date_joined', 'last_login',
        )
        read_only_fields = ('id', 'date_joined', 'last_login')
2.3 字段类型与参数¶
常用字段类型
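| 字段 | 说明 |
|---|---|
| `CharField` | 字符串 |
| `IntegerField` | 整数 |
| `FloatField` | 浮点数 |
| `DecimalField` | 精确小数 |
| `BooleanField` | 布尔 |
| `DateField` | 日期 |
| `DateTimeField` | 日期时间 |
| `EmailField` | 邮箱 |
| `URLField` | URL |
| `UUIDField` | UUID |
| `JSONField` | JSON |
| `ChoiceField` | 枚举选项 |
| `ListField` | 列表 |
| `DictField` | 字典 |
| `SerializerMethodField` | 自定义只读字段 |
| `HiddenField` | 隐藏字段(不接受输入) |
| `PrimaryKeyRelatedField` | 关联主键 |
| `SlugRelatedField` | 关联指定字段 |
| `StringRelatedField` | 关联对象的 `__str__`(只读) |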
字段通用参数
from rest_framework.fields import empty

serializers.CharField(
    required=True,         # must be present on input (default True unless a default is set)
    allow_null=False,      # whether null/None is accepted
    allow_blank=False,     # whether an empty string is accepted
    default=empty,         # `empty` = no default; any real default (even None) needs required=False
    read_only=False,       # output only, never accepted as input
    write_only=False,      # input only, never rendered in responses
    source='model_field',  # model attribute this field maps to
    label='字段标签',       # display name
    help_text='说明',       # help text
    validators=[],         # list of validators
    error_messages={},     # custom error messages
)
2.4 验证器¶
from rest_framework.validators import UniqueValidator, UniqueTogetherValidator
class UserSerializer(serializers.ModelSerializer):
    """Shows a field-level UniqueValidator next to a Meta-level UniqueTogetherValidator."""

    # Field level: the email must not exist yet.
    email = serializers.EmailField(
        validators=[UniqueValidator(queryset=User.objects.all(), message='邮箱已注册')],
    )

    class Meta:
        model = User
        fields = ['username', 'email']
        # Object level: the (username, email) pair must be unique.
        validators = [
            UniqueTogetherValidator(
                queryset=User.objects.all(),
                fields=['username', 'email'],
                message='用户名和邮箱组合已存在',
            ),
        ]
# 自定义验证器函数
def validate_phone(value):
    """Return *value* if it is an 11-digit number of the form 1[3-9]XXXXXXXXX, else raise.

    re.fullmatch is used because `$` in re.match also matches right before a trailing
    newline ('13800138000\\n' would pass), and [0-9] is used instead of \\d because
    \\d also matches non-ASCII digits such as full-width ones.

    Raises:
        serializers.ValidationError: if the format is wrong.
    """
    import re
    if not re.fullmatch(r'1[3-9][0-9]{9}', value):
        raise serializers.ValidationError('手机号格式不正确')
    return value
class ProfileSerializer(serializers.Serializer):
    """Profile payload; the phone number is checked by validate_phone."""

    phone = serializers.CharField(validators=(validate_phone,))
# 自定义验证器类
class PastDateValidator:
    """Class-based validator: the date must be strictly before today."""

    def __call__(self, value):
        from django.utils import timezone
        now = timezone.now()
        # With USE_TZ=True, now() is an aware UTC datetime, so now().date() would be the
        # UTC date; convert to the current time zone first. With USE_TZ=False it is
        # already naive local time (and localdate() would reject it).
        today = timezone.localdate(now) if timezone.is_aware(now) else now.date()
        if value >= today:
            raise serializers.ValidationError('必须是过去的日期')


class EventSerializer(serializers.Serializer):
    end_date = serializers.DateField(validators=[PastDateValidator()])
2.5 嵌套序列化¶
# 只读嵌套(展示关联对象)
class OrderItemSerializer(serializers.ModelSerializer):
    """One order line."""

    class Meta:
        model = OrderItem
        fields = ['id', 'product_name', 'quantity', 'price']


class OrderDetailSerializer(serializers.ModelSerializer):
    """Order with its user and items rendered inline (read-only nesting)."""

    user = UserListSerializer(read_only=True)               # single related object
    items = OrderItemSerializer(many=True, read_only=True)  # related collection -> list
    # NOTE: for list views, load relations up front in the queryset, e.g.
    # select_related('user').prefetch_related('items'), to avoid per-order queries.

    class Meta:
        model = Order
        fields = ['id', 'order_no', 'status', 'amount', 'user', 'items', 'created_at']
# 可写嵌套(同时创建关联对象)
class OrderCreateSerializer(serializers.ModelSerializer):
    """Writable nested serializer: creates or replaces an order together with its items."""

    items = OrderItemSerializer(many=True)

    class Meta:
        model = Order
        fields = ['order_no', 'items']

    def create(self, validated_data):
        from django.db import transaction
        items_data = validated_data.pop('items')
        # One transaction: if any item fails, no half-built order is left behind.
        with transaction.atomic():
            order = Order.objects.create(**validated_data)
            for item_data in items_data:
                OrderItem.objects.create(order=order, **item_data)
        return order

    def update(self, instance, validated_data):
        from django.db import transaction
        items_data = validated_data.pop('items', None)
        with transaction.atomic():
            instance = super().update(instance, validated_data)
            if items_data is not None:
                # Simple strategy: delete existing items, then recreate them from input.
                instance.items.all().delete()
                for item_data in items_data:
                    OrderItem.objects.create(order=instance, **item_data)
        return instance
2.6 SerializerMethodField 与动态字段¶
class UserSerializer(serializers.ModelSerializer):
    """User with three computed read-only fields (each SerializerMethodField calls get_<name>)."""

    full_name = serializers.SerializerMethodField()
    order_count = serializers.SerializerMethodField()
    is_owner = serializers.SerializerMethodField()

    class Meta:
        model = User
        fields = ['id', 'username', 'full_name', 'order_count', 'is_owner']

    def get_full_name(self, obj):
        joined = f'{obj.first_name} {obj.last_name}'
        return joined.strip()

    def get_order_count(self, obj):
        # Issues one COUNT query per serialized user.
        return obj.orders.count()

    def get_is_owner(self, obj):
        # Needs the view to pass the request in the serializer context.
        request = self.context.get('request')
        if not request or not request.user.is_authenticated:
            return False
        return obj.id == request.user.id
# 动态字段(根据请求参数决定返回哪些字段)
class DynamicFieldsSerializer(serializers.ModelSerializer):
    """ModelSerializer that accepts extra `fields` / `exclude` kwargs to trim its output."""

    def __init__(self, *args, **kwargs):
        # Remove our own kwargs before the base class sees them.
        only = kwargs.pop('fields', None)
        drop = kwargs.pop('exclude', None)
        super().__init__(*args, **kwargs)
        if only:
            for name in set(self.fields) - set(only):
                self.fields.pop(name)
        for name in drop or ():
            self.fields.pop(name, None)


class UserSerializer(DynamicFieldsSerializer):
    class Meta:
        model = User
        fields = ['id', 'username', 'email', 'age', 'is_active']


# Usage
UserSerializer(user, fields=('id', 'username'))
UserSerializer(users, many=True, exclude=('age',))
2.7 序列化器上下文¶
# 视图中传递 context
class UserViewSet(viewsets.ModelViewSet):
    def get_serializer_context(self):
        """Extend the default context with a flag saying whether sensitive data may be shown."""
        context = super().get_serializer_context()
        context.update(show_sensitive=self.request.user.is_staff)
        return context
# 序列化器中使用
class UserSerializer(serializers.ModelSerializer):
    phone = serializers.SerializerMethodField()

    def get_phone(self, obj):
        """Real phone when the context allows sensitive data, otherwise a fixed mask."""
        return obj.phone if self.context.get('show_sensitive') else '***'
三、视图¶
3.1 APIView¶
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from django.shortcuts import get_object_or_404
class UserListView(APIView):
    """Every HTTP method written by hand: most flexible, most code."""

    permission_classes = [IsAuthenticated]  # overrides the global default

    def get(self, request):
        users = User.objects.filter(is_active=True)
        serializer = UserListSerializer(users, many=True, context={'request': request})
        return Response({'code': 0, 'data': serializer.data})

    def post(self, request):
        serializer = UserCreateSerializer(data=request.data, context={'request': request})
        serializer.is_valid(raise_exception=True)
        user = serializer.save()
        # Same envelope and serializer context as get().
        detail = UserDetailSerializer(user, context={'request': request})
        return Response({'code': 0, 'data': detail.data}, status=status.HTTP_201_CREATED)
class UserDetailView(APIView):
    """Retrieve / replace / partially update / delete one active user."""

    def get_object(self, pk):
        obj = get_object_or_404(User, pk=pk, is_active=True)
        # A plain APIView never runs object-level permissions by itself; without this
        # call, has_object_permission() of the permission classes is silently skipped.
        self.check_object_permissions(self.request, obj)
        return obj

    def get(self, request, pk):
        user = self.get_object(pk)
        serializer = UserDetailSerializer(user, context={'request': request})
        return Response({'code': 0, 'data': serializer.data})

    def put(self, request, pk):
        return self._update(request, pk, partial=False)

    def patch(self, request, pk):
        return self._update(request, pk, partial=True)

    def delete(self, request, pk):
        user = self.get_object(pk)
        user.delete()
        return Response(status=status.HTTP_204_NO_CONTENT)

    def _update(self, request, pk, partial):
        """Shared body of put()/patch(); partial=True tolerates missing fields."""
        user = self.get_object(pk)
        serializer = UserDetailSerializer(
            user, data=request.data, partial=partial, context={'request': request}
        )
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response({'code': 0, 'data': serializer.data})
3.2 GenericAPIView + Mixin¶
from rest_framework.generics import GenericAPIView
from rest_framework.mixins import (
ListModelMixin, CreateModelMixin,
RetrieveModelMixin, UpdateModelMixin, DestroyModelMixin
)
class UserListView(ListModelMixin, CreateModelMixin, GenericAPIView):
    """Collection endpoint: GET lists, POST creates; the mixins supply the logic."""

    serializer_class = UserSerializer
    queryset = User.objects.filter(is_active=True).order_by('-date_joined')

    def post(self, request, *args, **kwargs):
        return self.create(request, *args, **kwargs)  # from CreateModelMixin

    def get(self, request, *args, **kwargs):
        return self.list(request, *args, **kwargs)  # from ListModelMixin


class UserDetailView(RetrieveModelMixin, UpdateModelMixin, DestroyModelMixin, GenericAPIView):
    """Item endpoint: each HTTP method delegates to the matching mixin action."""

    serializer_class = UserSerializer
    queryset = User.objects.all()

    def get(self, request, *args, **kwargs):
        return self.retrieve(request, *args, **kwargs)

    def put(self, request, *args, **kwargs):
        return self.update(request, *args, **kwargs)

    def patch(self, request, *args, **kwargs):
        return self.partial_update(request, *args, **kwargs)  # update() with partial=True

    def delete(self, request, *args, **kwargs):
        return self.destroy(request, *args, **kwargs)
3.3 通用视图(最简洁)¶
from rest_framework import generics
# 列表 + 创建
class UserListCreateView(generics.ListCreateAPIView):
    """GET -> list, POST -> create."""

    serializer_class = UserSerializer
    queryset = User.objects.all()


# Retrieve + update + delete
class UserRetrieveUpdateDestroyView(generics.RetrieveUpdateDestroyAPIView):
    """GET -> retrieve, PUT/PATCH -> update, DELETE -> destroy."""

    serializer_class = UserSerializer
    queryset = User.objects.all()


# Other concrete generic views:
#   generics.ListAPIView              read-only list
#   generics.CreateAPIView            create only
#   generics.RetrieveAPIView          read-only detail
#   generics.UpdateAPIView            update only
#   generics.DestroyAPIView           delete only
#   generics.RetrieveUpdateAPIView    detail + update
#   generics.RetrieveDestroyAPIView   detail + delete
# 覆盖通用视图的行为
class UserListCreateView(generics.ListCreateAPIView):
    serializer_class = UserSerializer

    def get_queryset(self):
        """Dynamic queryset: active users, optionally narrowed by ?role=."""
        qs = User.objects.filter(is_active=True)
        role = self.request.query_params.get('role')
        if role:
            qs = qs.filter(role=role)
        return qs

    def get_serializer_class(self):
        """Different serializers for writes (POST) and reads."""
        if self.request.method == 'POST':
            return UserCreateSerializer
        return UserListSerializer

    def perform_create(self, serializer):
        """Inject server-side values that are not part of the request body."""
        serializer.save(
            created_by=self.request.user,
            # Behind a reverse proxy REMOTE_ADDR is the proxy's address.
            ip_address=self.request.META.get('REMOTE_ADDR'),
        )

    def list(self, request, *args, **kwargs):
        """Wrap the list in the {'code': 0, 'data': ...} envelope.

        A paginator such as StandardPagination already returns that envelope from
        get_paginated_response(), so an already-wrapped payload is left as is. The
        original Response object is reused, so its status and headers survive.
        """
        response = super().list(request, *args, **kwargs)
        if not (isinstance(response.data, dict) and 'code' in response.data):
            response.data = {'code': 0, 'data': response.data}
        return response
3.4 ViewSet¶
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.response import Response
class UserViewSet(viewsets.ViewSet):
    """Plain ViewSet: every action is written by hand (most flexible)."""

    def list(self, request):
        return Response(UserSerializer(User.objects.all(), many=True).data)

    def create(self, request):
        serializer = UserCreateSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data, status=201)

    def retrieve(self, request, pk=None):
        return Response(UserSerializer(get_object_or_404(User, pk=pk)).data)

    def update(self, request, pk=None):
        return self._save(request, pk)

    def partial_update(self, request, pk=None):
        return self._save(request, pk, partial=True)

    def destroy(self, request, pk=None):
        get_object_or_404(User, pk=pk).delete()
        return Response(status=204)

    def _save(self, request, pk, partial=False):
        """Shared body of update() and partial_update()."""
        instance = get_object_or_404(User, pk=pk)
        serializer = UserSerializer(instance, data=request.data, partial=partial)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data)
3.5 ModelViewSet¶
class UserViewSet(viewsets.ModelViewSet):
    """CRUD generated by ModelViewSet; only the customizations are written here."""

    queryset = User.objects.all().order_by('-date_joined')
    serializer_class = UserSerializer
    filterset_class = UserFilter
    search_fields = ['username', 'email', 'phone']
    ordering_fields = ['date_joined', 'username', 'age']
    ordering = ['-date_joined']

    def get_queryset(self):
        qs = super().get_queryset()
        # Non-staff users only see their own record.
        if not self.request.user.is_staff:
            qs = qs.filter(id=self.request.user.id)
        return qs

    def get_serializer_class(self):
        # Per-action serializer; any other action (e.g. custom ones) gets UserSerializer.
        serializer_map = {
            'list': UserListSerializer,
            'create': UserCreateSerializer,
            'retrieve': UserDetailSerializer,
            'update': UserUpdateSerializer,
            'partial_update': UserUpdateSerializer,
        }
        return serializer_map.get(self.action, UserSerializer)

    def get_permissions(self):
        """Per-action permissions."""
        if self.action in ['list', 'retrieve']:
            return [IsAuthenticated()]
        elif self.action == 'create':
            return [AllowAny()]
        return [IsAuthenticated(), IsOwnerOrAdmin()]

    def perform_create(self, serializer):
        # 'create' is AllowAny, so request.user may be an AnonymousUser, which Django
        # refuses to assign to a ForeignKey (ValueError). Record the creator only
        # when the request is authenticated.
        user = self.request.user
        extra = {'created_by': user} if user.is_authenticated else {}
        serializer.save(**extra)

    def perform_update(self, serializer):
        serializer.save(updated_by=self.request.user)

    def perform_destroy(self, instance):
        # Soft delete: deactivate instead of removing the row.
        instance.is_active = False
        instance.save(update_fields=['is_active'])
3.6 自定义 Action¶
class UserViewSet(viewsets.ModelViewSet):
    queryset = User.objects.all()
    serializer_class = UserSerializer

    # detail=True -> /users/{id}/change-password/
    @action(detail=True, methods=['post'], url_path='change-password',
            permission_classes=[IsAuthenticated])
    def change_password(self, request, pk=None):
        # NOTE(review): with queryset=User.objects.all() and only IsAuthenticated, any
        # logged-in user can target any {id} if they know its old password; consider
        # restricting to request.user (or staff).
        user = self.get_object()
        serializer = ChangePasswordSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        if not user.check_password(serializer.validated_data['old_password']):
            return Response({'msg': '原密码错误'}, status=400)
        user.set_password(serializer.validated_data['new_password'])
        user.save(update_fields=['password'])  # only the hashed password changed
        return Response({'msg': '密码修改成功'})

    # detail=False -> /users/me/
    @action(detail=False, methods=['get', 'patch'], url_path='me',
            permission_classes=[IsAuthenticated])
    def me(self, request):
        if request.method == 'GET':
            serializer = UserDetailSerializer(request.user, context={'request': request})
            return Response(serializer.data)
        serializer = UserUpdateSerializer(
            request.user, data=request.data, partial=True, context={'request': request}
        )
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data)

    # detail=False -> /users/batch-delete/
    @action(detail=False, methods=['delete'], url_path='batch-delete',
            permission_classes=[IsAdminUser])
    def batch_delete(self, request):
        ids = request.data.get('ids', [])
        if not ids:
            return Response({'msg': 'ids 不能为空'}, status=400)
        if not isinstance(ids, list):
            return Response({'msg': 'ids 必须是列表'}, status=400)
        # QuerySet.delete() returns (rows deleted across ALL models, including
        # cascades, {model label: rows}); report only the User rows.
        _, per_model = User.objects.filter(id__in=ids).delete()
        count = per_model.get(User._meta.label, 0)
        return Response({'msg': f'已删除 {count} 个用户'})

    # Custom URL name and parser for a single action
    @action(detail=True, methods=['post'], url_path='avatar',
            url_name='upload-avatar', parser_classes=[MultiPartParser])
    def upload_avatar(self, request, pk=None):
        user = self.get_object()
        serializer = AvatarSerializer(user, data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response({'avatar_url': user.avatar.url})
四、路由(Router)¶
from rest_framework.routers import DefaultRouter, SimpleRouter

router = DefaultRouter()    # also generates an API root view
# router = SimpleRouter()   # same routes without the API root view
router.register('users', UserViewSet, basename='user')
router.register('orders', OrderViewSet, basename='order')
router.register('products', ProductViewSet, basename='product')

# urls.py
from django.urls import path, include

urlpatterns = [
    path('api/', include(router.urls)),
    # Views that are not ViewSets can be listed next to the router.
    path('api/auth/login/', TokenObtainPairView.as_view(), name='login'),
    path('api/auth/refresh/', TokenRefreshView.as_view(), name='token_refresh'),
]

# Routes DefaultRouter generates for 'users':
# GET     /api/users/                        list
# POST    /api/users/                        create
# GET     /api/users/{id}/                   retrieve
# PUT     /api/users/{id}/                   update
# PATCH   /api/users/{id}/                   partial_update
# DELETE  /api/users/{id}/                   destroy
# POST    /api/users/{id}/change-password/   custom detail action
# GET     /api/users/me/                     custom list action
# DELETE  /api/users/batch-delete/           custom list action
# 嵌套路由(pip install drf-nested-routers)
from rest_framework_nested import routers as nested_routers

router = nested_routers.DefaultRouter()
router.register('users', UserViewSet, basename='user')
users_router = nested_routers.NestedDefaultRouter(router, 'users', lookup='user')
users_router.register('orders', UserOrderViewSet, basename='user-order')
# GET /api/users/{user_pk}/orders/
# GET /api/users/{user_pk}/orders/{id}/
# Mount both routers under 'api/' so the URLs match the paths above; appending
# the bare router.urls would serve them at /users/... instead.
urlpatterns += [
    path('api/', include(router.urls)),
    path('api/', include(users_router.urls)),
]
五、认证(Authentication)¶
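在 DRF 中,request.user 是惰性属性:第一次被访问时,才会按顺序尝试 DEFAULT_AUTHENTICATION_CLASSES(或视图的 authentication_classes)中配置的认证类,并把第一个认证成功的结果赋给 request.user。不过 APIView.initial() 里的 perform_authentication() 会在进入视图方法之前主动访问 request.user,所以默认情况下,认证在每个请求一开始就会执行。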
5.1 内置认证方式¶
DRF 内置的 TokenAuthentication 生成的 token 永不过期,也不会自动轮换。如果需要过期机制,要么自行实现,要么改用下文的 JWT。
# Using TokenAuthentication: also install the app that provides the Token model
INSTALLED_APPS += ['rest_framework.authtoken']
REST_FRAMEWORK = {
    # Several schemes may be configured; DRF tries them in order and the first one
    # that authenticates the request wins.
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework.authentication.TokenAuthentication',    # DRF built-in token
        'rest_framework.authentication.SessionAuthentication',  # session; used by the Browsable API
        'rest_framework.authentication.BasicAuthentication',
    ],
}
# 创建 token 表
python manage.py migrate
# 创建接口
from rest_framework.authtoken.models import Token
token, _ = Token.objects.get_or_create(user=user)  # reuse the user's token or create one
# Request header: Authorization: Token 9944b09199c62bcf9418ad846dd0e4bbdfc6ee4f
5.2 DRF自定义认证(了解即可)¶
from rest_framework.authentication import BaseAuthentication
from rest_framework.exceptions import AuthenticationFailed


class APIKeyAuthentication(BaseAuthentication):
    """Authenticate requests that carry an `X-API-Key` header."""

    def authenticate(self, request):
        api_key = request.META.get('HTTP_X_API_KEY')
        if not api_key:
            return None  # no key: let the next authentication class try
        try:
            key_obj = APIKey.objects.select_related('user').get(key=api_key, is_active=True)
        except APIKey.DoesNotExist:
            raise AuthenticationFailed('无效的 API Key') from None
        # Same rule as DRF's TokenAuthentication: a credential must not log in a disabled account.
        if not key_obj.user.is_active:
            raise AuthenticationFailed('用户已被禁用')
        return (key_obj.user, key_obj)  # becomes (request.user, request.auth)

    def authenticate_header(self, request):
        # Value of the WWW-Authenticate header; defining it makes DRF answer
        # unauthenticated requests with 401 instead of 403.
        return 'X-API-Key'
5.3 JWT 认证(Simple JWT)¶
pip install djangorestframework-simplejwt
# settings.py
from datetime import timedelta

REST_FRAMEWORK = {
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework_simplejwt.authentication.JWTAuthentication',
    ]
}
# Needed for BLACKLIST_AFTER_ROTATION and TokenBlacklistView; without this app,
# rotated refresh tokens are silently NOT blacklisted. Run `migrate` after adding it.
INSTALLED_APPS += ['rest_framework_simplejwt.token_blacklist']
SIMPLE_JWT = {
    'ACCESS_TOKEN_LIFETIME': timedelta(hours=2),
    'REFRESH_TOKEN_LIFETIME': timedelta(days=30),
    'ROTATE_REFRESH_TOKENS': True,      # a refresh call also issues a new refresh token
    'BLACKLIST_AFTER_ROTATION': True,   # the old refresh token is blacklisted
    'ALGORITHM': 'HS256',
    'SIGNING_KEY': SECRET_KEY,
    'AUTH_HEADER_TYPES': ('Bearer',),
    'AUTH_TOKEN_CLASSES': ('rest_framework_simplejwt.tokens.AccessToken',),
    # Custom payload
    'TOKEN_OBTAIN_SERIALIZER': 'apps.users.serializers.CustomTokenObtainPairSerializer',
}
# urls.py
from rest_framework_simplejwt.views import (
    TokenBlacklistView,
    TokenObtainPairView,
    TokenRefreshView,
    TokenVerifyView,
)

urlpatterns = [
    # credentials -> access + refresh token
    path('api/auth/login/', TokenObtainPairView.as_view(), name='token_obtain'),
    # refresh token -> new access token
    path('api/auth/refresh/', TokenRefreshView.as_view(), name='token_refresh'),
    # check whether a token is valid
    path('api/auth/verify/', TokenVerifyView.as_view(), name='token_verify'),
    # blacklist a refresh token (requires the token_blacklist app)
    path('api/auth/logout/', TokenBlacklistView.as_view(), name='token_blacklist'),
]
# 自定义 Token Payload(添加额外信息)
# views.py
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
from rest_framework_simplejwt.views import TokenObtainPairView


class CustomTokenObtainPairSerializer(TokenObtainPairSerializer):
    """Login serializer: extra claims in the token, extra user info in the response."""

    @classmethod
    def get_token(cls, user):
        """Controls the token itself: claims set here end up in its payload."""
        token = super().get_token(user)
        for claim in ('username', 'email', 'role'):
            token[claim] = getattr(user, claim)
        return token

    def validate(self, attrs):
        """Controls the JSON returned to the client: adds a user summary."""
        data = super().validate(attrs)
        data['user'] = UserListSerializer(self.user).data
        return data


class CustomTokenObtainPairView(TokenObtainPairView):
    serializer_class = CustomTokenObtainPairSerializer


# urls.py
urlpatterns = [
    path('auth/login/', CustomTokenObtainPairView.as_view()),  # the customized login view
]
# 结果payload 变成:
{
"token_type": "access",
"exp": 1234567890,
"user_id": 1,
"username": "john",
"email": "john@example.com",
"role": "admin"
}
# 结果返回给前端变成:
{
"access": "eyJ...",
"refresh": "eyJ...",
"user": {
"id": 1,
"username": "john",
"email": "john@example.com",
"role": "admin"
}
}
5.4 dj-rest-auth状态管理¶
pip install dj-rest-auth
# settings.py
INSTALLED_APPS = [
    # ... Django's built-in apps and your own apps ...
    'rest_framework',
    'rest_framework.authtoken',  # dj-rest-auth's default token login uses DRF's Token model
    'dj_rest_auth',
]
# urls.py
from django.urls import path, include

urlpatterns = [
    # Mounts all dj-rest-auth endpoints under api/auth/.
    path('api/auth/', include('dj_rest_auth.urls')),
]
会自动包含以下接口:
| URL | 方法 | 说明 |
| --------------------------- | ---- | -------- |
| `api/auth/login/` | POST | 登录 |
| `api/auth/logout/` | POST | 登出 |
| `api/auth/password/change/` | POST | 改密码 |
| `api/auth/password/reset/` | POST | 重置密码 |
改密码请求体:
{
"old_password": "oldPass123!",
"new_password1": "newPass456!",
"new_password2": "newPass456!"
}
5.5 自定义视图函数修改密码¶
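当有额外业务逻辑时(如记录改密时间、发送通知、强制重新登录),需要自定义视图。
dj-rest-auth 自带的 PasswordChangeSerializer 基于 Django 的 SetPasswordForm,会执行 AUTH_PASSWORD_VALIDATORS。但自定义序列化器时不会自动执行,需要手动调用 validate_password。
dj-rest-auth 的改密接口默认不校验旧密码,需要开启 OLD_PASSWORD_FIELD_ENABLED = True(dj-rest-auth 3.x 起写在 settings.py 的 REST_AUTH 字典中)。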
# serializers.py
from django.contrib.auth.password_validation import validate_password
from django.core.exceptions import ValidationError as DjangoValidationError
from rest_framework import serializers
class PasswordChangeSerializerCus(serializers.Serializer):
    """Change the requesting user's password: old password plus the new one entered twice."""

    old_password = serializers.CharField(write_only=True)
    new_password1 = serializers.CharField(write_only=True)
    new_password2 = serializers.CharField(write_only=True)

    def validate_old_password(self, value):
        user = self.context['request'].user
        if not user.check_password(value):
            raise serializers.ValidationError('原密码错误')
        return value

    def validate(self, data):
        if data['new_password1'] != data['new_password2']:
            raise serializers.ValidationError({'new_password2': '两次密码不一致'})
        try:
            # Run every validator in settings.AUTH_PASSWORD_VALIDATORS; without this
            # step the new password's strength is never checked.
            validate_password(data['new_password1'], user=self.context['request'].user)
        except DjangoValidationError as e:
            raise serializers.ValidationError({'new_password1': e.messages})
        return data

    def save(self, **kwargs):
        """Hash and store the new password, keep the current session logged in, return the user."""
        from django.contrib.auth import update_session_auth_hash
        request = self.context['request']
        user = request.user
        user.set_password(self.validated_data['new_password1'])
        user.save(update_fields=['password'])
        # Changing the password changes the session auth hash, which logs out the
        # session that made this request; refresh it when a session is in use.
        session = getattr(request, 'session', None)
        if session is not None and session.session_key:
            update_session_auth_hash(request, user)
        return user
# views.py
from django.utils import timezone
from django.utils.translation import gettext_lazy
from dj_rest_auth.views import PasswordChangeView
from rest_framework.response import Response
from .serializers import PasswordChangeSerializerCus
class PasswordChangeViewCus(PasswordChangeView):
    """dj-rest-auth password-change endpoint with extra bookkeeping after success."""

    serializer_class = PasswordChangeSerializerCus

    def post(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        # Extra business logic: record when the password was changed.
        user = request.user
        user.pwd_change_date = timezone.now()
        # The serializer already saved the password; write only the column touched here.
        user.save(update_fields=['pwd_change_date'])
        return Response({'detail': gettext_lazy('New password has been saved.')})
# urls.py
from django.urls import path
from .views import PasswordChangeViewCus

urlpatterns = [
    # Same URL as dj-rest-auth's built-in view: list this before include('dj_rest_auth.urls')
    # so Django matches it first.
    path('api/auth/password/change/', PasswordChangeViewCus.as_view()),
]
六、权限(Permission)¶
6.1 内置权限¶
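| 权限类 | 说明 |
|---|---|
| `AllowAny` | 所有人可访问 |
| `IsAuthenticated` | 需要登录 |
| `IsAdminUser` | 需要 `is_staff=True` |
| `IsAuthenticatedOrReadOnly` | 登录可写,未登录只读 |
| `DjangoModelPermissions` | 对应 Django 模型权限 |
| `DjangoObjectPermissions` | 对象级权限(需配合 django-guardian) |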
6.2 自定义权限¶
from rest_framework.permissions import BasePermission, SAFE_METHODS
class IsOwnerOrReadOnly(BasePermission):
    """Logged-in users may read; only the object's owner (obj.user) may modify it."""

    def has_permission(self, request, view):
        # View level: login required.
        return request.user and request.user.is_authenticated

    def has_object_permission(self, request, view, obj):
        # Object level: safe methods (GET/HEAD/OPTIONS) for everyone, writes for the owner only.
        return request.method in SAFE_METHODS or obj.user == request.user
class IsOwnerOrAdmin(BasePermission):
    """Allow staff users, or the owner of the object."""

    message = '只有对象所有者或管理员可以操作'

    def has_object_permission(self, request, view, obj):
        if request.user.is_staff:
            return True
        # Objects with a `user` attribute belong to that user. Objects without one
        # (e.g. a User instance itself, as in UserViewSet) are owned by themselves;
        # previously those raised AttributeError.
        owner = getattr(obj, 'user', obj)
        return owner == request.user
class IsVerifiedUser(BasePermission):
    """Only authenticated users whose email_verified flag is set."""

    message = '请先验证邮箱'

    def has_permission(self, request, view):
        user = request.user
        return user and user.is_authenticated and user.email_verified
# 权限组合(AND 使用列表,OR 使用 | 运算符)
class UserViewSet(viewsets.ModelViewSet):
    # AND: every class in the list must grant access.
    permission_classes = [IsAuthenticated, IsVerifiedUser]

    # OR (DRF 3.9+): combine classes with | (also & and ~).
    # IsAdminUser comes from rest_framework.permissions; IsOwnerOrReadOnly is the
    # custom class defined earlier in this document, not part of DRF.
    # from rest_framework.permissions import IsAdminUser
    # permission_classes = [IsAdminUser | IsOwnerOrReadOnly]
6.3 对象级权限¶
class UserViewSet(viewsets.ModelViewSet):
    permission_classes = [IsAuthenticated, IsOwnerOrAdmin]

    def get_object(self):
        """Custom lookup that still enforces object-level permissions.

        GenericAPIView.get_object() already calls check_object_permissions(), so an
        override that just delegates to super() needs no extra call (it would run the
        checks twice). An override that performs its own lookup, like this one, must
        call it explicitly, otherwise has_object_permission() is never consulted.
        """
        queryset = self.filter_queryset(self.get_queryset())
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
        obj = get_object_or_404(queryset, **{self.lookup_field: self.kwargs[lookup_url_kwarg]})
        self.check_object_permissions(self.request, obj)
        return obj
七、分页(Pagination)¶
7.1 基本配置¶
# settings.py
REST_FRAMEWORK = {
    # Project-wide default; a view can override it with pagination_class.
    'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.LimitOffsetPagination',
    'PAGE_SIZE': 100,
}
7.2 PageNumberPagination¶
from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response


class StandardPagination(PageNumberPagination):
    """Page-number pagination returning the project's {'code', 'data'} envelope."""

    page_size = 20                  # default page size
    page_size_query_param = 'size'  # ?size=50 lets the client choose ...
    max_page_size = 100             # ... up to this cap
    page_query_param = 'page'       # ?page=2

    def get_paginated_response(self, data):
        return Response({
            'code': 0,
            'data': {
                'list': data,
                'total': self.page.paginator.count,
                'page': self.page.number,
                'pages': self.page.paginator.num_pages,
                'size': self.get_page_size(self.request),
                'has_next': self.page.has_next(),
            }
        })

    def get_paginated_response_schema(self, schema):
        """OpenAPI schema of get_paginated_response(); keep both in sync."""
        return {
            'type': 'object',
            'properties': {
                'code': {'type': 'integer'},
                'data': {
                    'type': 'object',
                    'properties': {
                        'list': schema,
                        'total': {'type': 'integer'},
                        'page': {'type': 'integer'},
                        'pages': {'type': 'integer'},
                        'size': {'type': 'integer'},
                        'has_next': {'type': 'boolean'},
                    },
                },
            },
        }
7.3 LimitOffsetPagination¶
from rest_framework.pagination import LimitOffsetPagination


class LargeResultPagination(LimitOffsetPagination):
    """?limit=<n>&offset=<m> pagination capped at 200 items per request."""

    max_limit = 200     # upper bound for ?limit=
    default_limit = 20  # used when the client sends no ?limit=
# Request: GET /api/users/?limit=20&offset=40
7.4 CursorPagination¶
from rest_framework.pagination import CursorPagination


class CreatedAtCursorPagination(CursorPagination):
    """Opaque-cursor pagination, newest first (?cursor=...)."""

    cursor_query_param = 'cursor'
    ordering = '-created_at'  # should be an unchanging, unique (or nearly unique) field
    page_size = 20
# Good for live feeds / infinite scroll: no jumping to an arbitrary page, but no
# duplicated rows when new data arrives between requests.
7.5 在视图中使用分页¶
class UserViewSet(viewsets.ModelViewSet):
    pagination_class = StandardPagination  # view-level pagination (overrides the global default)


# Disable pagination for one view
class AllUsersView(generics.ListAPIView):
    pagination_class = None


# Manual pagination inside a plain APIView
class UserListView(APIView):
    def get(self, request):
        # Paginating an unordered queryset gives unstable pages (rows can repeat or be
        # skipped between pages) and triggers Django's UnorderedObjectListWarning.
        users = User.objects.order_by('-date_joined')
        paginator = StandardPagination()
        page = paginator.paginate_queryset(users, request, view=self)
        serializer = UserSerializer(page, many=True)
        return paginator.get_paginated_response(serializer.data)
八、过滤、搜索、排序¶
8.1 配置¶
# settings.py
INSTALLED_APPS = [
    # ... other apps ...
    'django_filters',
]
REST_FRAMEWORK = {
    # Filter backends used by every generic view unless it sets filter_backends itself
    'DEFAULT_FILTER_BACKENDS': [
        'django_filters.rest_framework.DjangoFilterBackend',
        'rest_framework.filters.SearchFilter',
        'rest_framework.filters.OrderingFilter',
        'your.app.custom.filterclass',  # placeholder: dotted path of your own backend
    ],
}
8.2 使用¶
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.filters import SearchFilter, OrderingFilter


class UserViewSet(viewsets.ModelViewSet):
    # Example: GET /api/users/?username=xiaoming&email=136@qq.com&search=18&ordering=-created_at,username
    # Each backend reads only its own query parameters.
    filter_backends = [DjangoFilterBackend, SearchFilter, OrderingFilter]

    # DjangoFilterBackend: exact / structured filtering by field.
    filterset_class = UserFilter  # the custom FilterSet (see 8.3)

    # SearchFilter: reads only ?search=. Several terms may be given in it; each term
    # must match at least one of the fields below.
    search_fields = [
        'username',      # no prefix: case-insensitive contains (icontains)
        'email',
        '=phone',        # '=': case-insensitive exact match (iexact)
        '@bio',          # '@': full-text search (PostgreSQL only)
        'profile__bio',  # spans a relation
        # '^field' means "starts with" (istartswith); '^username' is omitted because
        # 'username' (contains) already matches everything it would.
    ]

    # OrderingFilter: ?ordering=
    ordering_fields = ['created_at', 'username', 'age', 'last_login']
    ordering = ['-created_at']  # default ordering
8.3 自定义 Filter¶
import django_filters
from django_filters import rest_framework as filters
from django.db.models import Q


class NumberInFilter(filters.BaseInFilter, filters.NumberFilter):
    """Comma-separated numbers (?ids=1,2,3); every item is validated as a number."""


class UserFilter(filters.FilterSet):
    # Exact match
    is_active = filters.BooleanFilter()
    role = filters.CharFilter()
    # Case-insensitive partial match
    username = filters.CharFilter(lookup_expr='icontains')
    email = filters.CharFilter(lookup_expr='icontains')
    # Numeric range
    min_age = filters.NumberFilter(field_name='age', lookup_expr='gte')
    max_age = filters.NumberFilter(field_name='age', lookup_expr='lte')
    # Date range on a datetime column (compares its date part)
    created_after = filters.DateFilter(field_name='created_at', lookup_expr='date__gte')
    created_before = filters.DateFilter(field_name='created_at', lookup_expr='date__lte')
    # ?created_range_after=2024-01-01&created_range_before=2024-12-31
    created_range = filters.DateFromToRangeFilter(field_name='created_at')
    # Multi-value IN filter: ?ids=1,2,3
    # BaseInFilter is meant to be combined with a concrete filter type; with
    # NumberInFilter a non-numeric item is rejected as a validation error (400)
    # instead of failing inside the database query.
    ids = NumberInFilter(field_name='id', lookup_expr='in')
    # Custom filter method
    keyword = filters.CharFilter(method='filter_keyword')

    def filter_keyword(self, queryset, name, value):
        """Case-insensitive partial match of value against username, email or phone."""
        return queryset.filter(
            Q(username__icontains=value) |
            Q(email__icontains=value) |
            Q(phone__icontains=value)
        )

    class Meta:
        model = User
        fields = ['is_active', 'role', 'username', 'email']


class UserViewSet(viewsets.ModelViewSet):
    filterset_class = UserFilter
# Request: GET /api/users/?username=alice&min_age=18&is_active=true&created_range_after=2024-01-01
九、限流(Throttling)¶
DRF 的规则是:所有限流类都会检查,任意一个不通过就返回 429
9.1 基本使用¶
# 全局配置
REST_FRAMEWORK = {
'DEFAULT_THROTTLE_CLASSES': [
# 这两个:全局配置后,所有视图立即生效,视图层什么都不用写
'rest_framework.throttling.AnonRateThrottle'
'rest_framework.throttling.UserRateThrottle'
# 这个:全局配置只是"注册"了它,还必须在视图上写 throttle_scope 才生效
'rest_framework.throttling.ScopedRateThrottle'
],
'DEFAULT_THROTTLE_RATES': { # key 只是一个名字
'anon': '100/day', # 未登录用户:每天100次
'user': '1000/day', # 登录用户:每天1000次
'cus_name': '10/minute', # 作用域限流
}
}
# 视图级限流
from rest_framework.throttling import UserRateThrottle, AnonRateThrottle
from rest_framework.views import APIView


class MyView(APIView):
    # Replaces DEFAULT_THROTTLE_CLASSES for this view only.
    throttle_classes = [UserRateThrottle]
    # Alternative: limit by the 'cus_name' entry of DEFAULT_THROTTLE_RATES. This only
    # takes effect when ScopedRateThrottle is among this view's throttle classes,
    # which is not the case with the throttle_classes line above.
    throttle_scope = 'cus_name'
9.2 原理¶
核心:计数器 + 缓存 本质就是在缓存里记录"某个用户/IP 在某个时间窗口内请求了多少次"。
# 缓存 key 示例
"throttle_user_123" # 用户ID为123的请求记录
"throttle_anon_1.2.3.4" # IP为1.2.3.4的请求记录
# 缓存 value:该用户所有请求的时间戳列表
[1709001600.0, 1709001610.0, 1709001620.0, ...]
# 每次请求的处理流程:
新请求进来
↓
从缓存取出该用户的时间戳列表
↓
清除列表中"时间窗口之外"的过期时间戳
↓
剩余数量 >= 限制次数?
↓ ↓
是 否
↓ ↓
拒绝请求 追加当前时间戳,存回缓存,放行
9.3 自定义限流¶
import threading
import time

from rest_framework.throttling import (
    AnonRateThrottle, UserRateThrottle,
    ScopedRateThrottle, BaseThrottle
)


# Custom throttle (per IP)
class CustomIPThrottle(BaseThrottle):
    """Sliding-window limit of `rate` requests per `period` seconds per client IP.

    DRF builds a fresh throttle instance for every request, so the request history
    must live on the class: a dict created in __init__ starts empty each time and
    never throttles anything. The history is per process (each worker keeps its own
    counts) and keeps one entry per IP ever seen; use a shared cache such as
    SimpleRateThrottle for a global, bounded limit.
    """

    rate = 60    # max requests per window
    period = 60  # window length in seconds

    cache = {}   # ip -> list of request timestamps, shared by all instances
    _lock = threading.Lock()

    def __init__(self):
        self._wait = None

    def allow_request(self, request, view):
        ip = self.get_ident(request)
        now = time.time()
        with self._lock:
            # Drop timestamps that fell out of the window.
            recent = [t for t in self.cache.get(ip, []) if now - t < self.period]
            if len(recent) >= self.rate:
                self.cache[ip] = recent
                # Seconds until the oldest request in the window expires.
                self._wait = self.period - (now - recent[0])
                return False
            recent.append(now)
            self.cache[ip] = recent
            return True

    def wait(self):
        # Used by DRF for the Retry-After header of the 429 response.
        return self._wait if self._wait is not None else self.period
# Redis-backed custom throttle (recommended for production)
from django.core.cache import caches
from rest_framework.throttling import SimpleRateThrottle


class RedisRateThrottle(SimpleRateThrottle):
    """IP-based sliding-window throttle whose request history lives in Redis.

    SimpleRateThrottle talks to ``self.cache`` through Django's cache API:
    ``get(key, default)`` and ``set(key, list_of_timestamps, timeout)``. A raw
    ``redis.Redis`` client has different method signatures and cannot store a
    Python list, so point it at a Django cache alias backed by Redis instead::

        CACHES = {
            'throttle': {
                'BACKEND': 'django.core.cache.backends.redis.RedisCache',
                'LOCATION': 'redis://localhost:6379/2',
            },
        }

    The rate is read from ``REST_FRAMEWORK['DEFAULT_THROTTLE_RATES']['redis_throttle']``,
    so that key must exist. As with DRF's built-in throttles, the read-then-write of
    the history is not atomic, so concurrent requests can slightly overshoot the limit.
    """
    scope = 'redis_throttle'

    @property
    def cache(self):
        # caches[alias] hands out a per-thread cache instance, so resolve it on access
        return caches['throttle']

    def get_cache_key(self, request, view):
        ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}
十、解析器与渲染器¶
10.1 解析器¶
from rest_framework.parsers import JSONParser, MultiPartParser, FormParser, FileUploadParser
# Global configuration: parser classes used by every view that does not set its own
# parser_classes (see the view-level example below)
REST_FRAMEWORK = {
    'DEFAULT_PARSER_CLASSES': [
        'rest_framework.parsers.JSONParser',
        'rest_framework.parsers.MultiPartParser',
        'rest_framework.parsers.FormParser',
    ]
}
# View-level configuration
class FileUploadView(APIView):
    """Accept an upload in the multipart/form field ``file`` and echo its filename."""

    # Only multipart and urlencoded bodies are parsed for this view
    parser_classes = [MultiPartParser, FormParser]

    def post(self, request, format=None):
        upload = request.FILES.get('file')
        if upload is None:
            # A missing field is a client error. request.FILES['file'] would raise
            # MultiValueDictKeyError and surface as a 500.
            return Response({'code': 400, 'msg': '缺少上传文件'}, status=400)
        # Process the upload...
        return Response({'filename': upload.name})
10.2 渲染器¶
from rest_framework.renderers import JSONRenderer, BrowsableAPIRenderer
# Global: JSON only. BrowsableAPIRenderer is left out, as recommended for production.
REST_FRAMEWORK = {
    'DEFAULT_RENDERER_CLASSES': [
        'rest_framework.renderers.JSONRenderer',
    ]
}
# Custom JSON renderer (uniform response envelope)
class StandardJSONRenderer(JSONRenderer):
    """Wrap successful payloads as {'code': 0, 'msg': 'ok', 'data': ...}.

    Error responses (status >= 400) are rendered unchanged. Responses that must not
    carry a body (204 No Content, 304 Not Modified) are also passed through, so an
    empty payload stays empty instead of turning into an envelope.
    """

    # Statuses whose responses must not have a body
    _NO_BODY_STATUSES = (204, 304)

    def render(self, data, accepted_media_type=None, renderer_context=None):
        # renderer_context defaults to None; normalise it before calling .get()
        renderer_context = renderer_context or {}
        response = renderer_context.get('response')
        if response is not None and (
            response.status_code >= 400
            or response.status_code in self._NO_BODY_STATUSES
        ):
            # Error / bodiless responses are not wrapped
            return super().render(data, accepted_media_type, renderer_context)
        wrapped = {
            'code': 0,
            'msg': 'ok',
            'data': data
        }
        return super().render(wrapped, accepted_media_type, renderer_context)
十一、异常处理¶
# apps/common/exceptions.py
import logging

from django.core.exceptions import PermissionDenied as DjangoPermissionDenied
from django.http import Http404
from rest_framework.views import exception_handler, set_rollback
from rest_framework.exceptions import (
    APIException, ValidationError, NotAuthenticated,
    PermissionDenied, NotFound, MethodNotAllowed, Throttled
)
from rest_framework.response import Response

logger = logging.getLogger(__name__)


def custom_exception_handler(exc, context):
    """Render every API error as {'code': ..., 'msg': ...} (plus 'errors' for validation).

    Args:
        exc: the exception raised while handling the request.
        context: DRF's handler context (view, args, kwargs, request).

    Returns:
        A Response. Never returns None, so unexpected exceptions also become JSON 500s.
    """
    # DRF maps Django's Http404 / PermissionDenied to its own exceptions only inside
    # exception_handler. Convert them here too, so the isinstance checks below match
    # (get_object() raises Http404 for missing rows).
    if isinstance(exc, Http404):
        exc = NotFound()
    elif isinstance(exc, DjangoPermissionDenied):
        exc = PermissionDenied()

    response = exception_handler(exc, context)

    if response is None:
        # Exceptions DRF does not handle (e.g. bugs -> Django 500): log the traceback
        logger.exception('Unhandled exception: %s', exc)
        # Because the error is turned into a Response here, mark any ATOMIC_REQUESTS
        # transaction for rollback so partial writes are not committed.
        set_rollback()
        return Response({
            'code': 500,
            'msg': '服务器内部错误'
        }, status=500)

    # Normalise every error response into the same shape
    if isinstance(exc, ValidationError):
        # Use 422 for the HTTP status as well as the body code (DRF defaults to 400)
        response.status_code = 422
        response.data = {
            'code': 422,
            'msg': '参数验证失败',
            'errors': format_validation_errors(exc.detail)
        }
    elif isinstance(exc, NotAuthenticated):
        response.data = {'code': 401, 'msg': '请先登录'}
    elif isinstance(exc, PermissionDenied):
        response.data = {'code': 403, 'msg': str(exc.detail)}
    elif isinstance(exc, NotFound):
        response.data = {'code': 404, 'msg': '资源不存在'}
    elif isinstance(exc, MethodNotAllowed):
        response.data = {'code': 405, 'msg': '请求方法不允许'}
    elif isinstance(exc, Throttled):
        # exc.wait is None when no throttle could estimate a wait time;
        # formatting None with :.0f would raise TypeError inside the handler.
        if exc.wait is None:
            msg = '请求过于频繁,请稍后重试'
        else:
            msg = f'请求过于频繁,请 {exc.wait:.0f} 秒后重试'
        response.data = {'code': 429, 'msg': msg}
    else:
        response.data = {
            'code': response.status_code,
            'msg': str(exc.detail) if hasattr(exc, 'detail') else '请求失败'
        }
    return response
def format_validation_errors(detail):
    """Convert a ValidationError ``detail`` into plain JSON-friendly dicts/lists/strings.

    The nesting of the serializer is preserved:
    - a dict maps field -> errors;
    - a field's errors become a list of strings (a bare message is wrapped in a
      one-item list);
    - nested serializers stay dicts;
    - a list of items (``many=True``) stays a list of per-item error dicts.

    A top-level scalar is returned as a one-item list.
    """
    if isinstance(detail, dict):
        return {
            field: (
                format_validation_errors(messages)
                if isinstance(messages, (dict, list))
                else [str(messages)]
            )
            for field, messages in detail.items()
        }
    if isinstance(detail, list):
        return [
            format_validation_errors(item) if isinstance(item, (dict, list)) else str(item)
            for item in detail
        ]
    return [str(detail)]
# Custom business exception
class BusinessException(APIException):
    """Domain-level error carrying a message and an HTTP status (400 by default).

    Args:
        message: human-readable error text, exposed as ``detail``.
        code: HTTP status code of the resulting response (not DRF's string error code).
    """
    status_code = 400

    def __init__(self, message: str, code: int = 400):
        super().__init__(detail=message)
        # Per-instance override of the class-level default status
        self.status_code = code
# 使用
raise BusinessException('用户名已存在', code=409)
raise BusinessException('余额不足', code=400)
十二、Versioning 版本控制¶
# settings.py
REST_FRAMEWORK = {
    'DEFAULT_VERSIONING_CLASS': 'rest_framework.versioning.URLPathVersioning',
    'ALLOWED_VERSIONS': ['v1', 'v2'],
    'DEFAULT_VERSION': 'v1',
    'VERSION_PARAM': 'version',
}
# URLPathVersioning routes: the version is part of the URL path, e.g.
# /api/v1/users/
# /api/v2/users/
# The captured URL kwarg name matches VERSION_PARAM ('version').
urlpatterns = [
    path('api/<str:version>/', include('apps.users.urls')),
]
# Other versioning schemes:
# QueryParameterVersioning (query string)
# /api/users/?version=v2
# NamespaceVersioning (URL namespaces, e.g. include('apps.users.urls', namespace='v1'))
# AcceptHeaderVersioning (request header)
# Accept: application/json; version=v1
# HostNameVersioning (subdomain)
# v1.api.yourdomain.com
# Switch logic inside the view based on the API version
class UserViewSet(viewsets.ModelViewSet):
    """User endpoints whose serializer and list behaviour depend on request.version."""

    def get_serializer_class(self):
        use_v2 = self.request.version == 'v2'
        return UserV2Serializer if use_v2 else UserSerializer

    def list(self, request, *args, **kwargs):
        is_v2 = request.version == 'v2'
        if is_v2:
            # v2-specific logic goes here
            pass
        return super().list(request, *args, **kwargs)
十三、测试¶
13.1 接口测试¶
from django.test import TestCase
from django.contrib.auth import get_user_model
from rest_framework.test import APITestCase, APIClient
from rest_framework import status
User = get_user_model()
class AuthAPITest(APITestCase):
    """Login endpoint: valid credentials yield a JWT pair, a wrong password yields 401."""

    login_url = '/api/auth/login/'

    def setUp(self):
        self.user = User.objects.create_user(
            username='testuser',
            email='test@example.com',
            password='testpass123'
        )

    def _post_login(self, password):
        """POST the fixture user's email together with *password* as JSON."""
        credentials = {'email': 'test@example.com', 'password': password}
        return self.client.post(self.login_url, credentials, format='json')

    def get_token(self):
        """Log in as the fixture user and return the JWT access token."""
        return self._post_login('testpass123').data['access']

    def test_login_success(self):
        response = self._post_login('testpass123')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        for token_key in ('access', 'refresh'):
            self.assertIn(token_key, response.data)

    def test_login_wrong_password(self):
        response = self._post_login('wrongpass')
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
class UserAPITest(APITestCase):
    """Listing, registration, update permissions and the /me action of /api/users/."""

    def setUp(self):
        self.user = User.objects.create_user(
            username='testuser', email='test@example.com', password='testpass123'
        )
        self.admin = User.objects.create_superuser(
            username='admin', email='admin@example.com', password='adminpass123'
        )
        self.client = APIClient()

    def authenticate(self, user=None):
        """Force-authenticate the client as *user* (falls back to self.user)."""
        chosen = user or self.user
        self.client.force_authenticate(user=chosen)

    def _signup(self, username, email, password):
        """POST a registration payload with a matching confirm_password."""
        payload = {
            'username': username,
            'email': email,
            'password': password,
            'confirm_password': password,
        }
        return self.client.post('/api/users/', payload, format='json')

    def _rename(self, target, new_name):
        """PATCH *target*'s username via its detail endpoint."""
        return self.client.patch(
            f'/api/users/{target.id}/', {'username': new_name}, format='json'
        )

    def test_list_users_unauthenticated(self):
        response = self.client.get('/api/users/')
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

    def test_list_users_authenticated(self):
        self.authenticate()
        response = self.client.get('/api/users/')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertIn('data', response.data)

    def test_create_user(self):
        response = self._signup('newuser', 'new@example.com', 'newpass123')
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertTrue(User.objects.filter(username='newuser').exists())

    def test_create_user_duplicate_email(self):
        # test@example.com already belongs to self.user
        response = self._signup('another', 'test@example.com', 'pass123')
        self.assertEqual(response.status_code, status.HTTP_422_UNPROCESSABLE_ENTITY)

    def test_update_user_by_owner(self):
        self.authenticate()
        response = self._rename(self.user, 'updated_name')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.user.refresh_from_db()
        self.assertEqual(self.user.username, 'updated_name')

    def test_update_user_by_other(self):
        intruder = User.objects.create_user(username='other', email='other@a.com', password='pass')
        self.authenticate(user=intruder)
        response = self._rename(self.user, 'hacked')
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_custom_action_me(self):
        self.authenticate()
        response = self.client.get('/api/users/me/')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data['username'], self.user.username)
class PaginationTest(APITestCase):
    """The user list is paginated via ?page= and ?size=."""

    def setUp(self):
        self.user = User.objects.create_superuser('admin', 'admin@test.com', 'admin123')
        self.client.force_authenticate(user=self.user)
        # 25 extra users, so a 10-item page is not the last one
        for idx in range(25):
            User.objects.create_user(f'user{idx}', f'user{idx}@test.com', 'pass')

    def test_pagination(self):
        response = self.client.get('/api/users/?page=1&size=10')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        page = response.data['data']
        self.assertEqual(len(page['list']), 10)
        self.assertGreater(page['total'], 20)
        self.assertTrue(page['has_next'])
13.2 Serializer 测试¶
class UserSerializerTest(TestCase):
    """Validation rules of UserCreateSerializer."""

    def test_valid_data(self):
        data = {
            'username': 'alice',
            'email': 'alice@example.com',
            'password': 'pass1234',
            'confirm_password': 'pass1234'
        }
        serializer = UserCreateSerializer(data=data)
        # Pass the errors as the failure message so an unexpected rejection is explained
        self.assertTrue(serializer.is_valid(), serializer.errors)

    def test_password_mismatch(self):
        data = {
            'username': 'alice',
            'email': 'alice@example.com',
            'password': 'pass1234',
            'confirm_password': 'different'
        }
        serializer = UserCreateSerializer(data=data)
        self.assertFalse(serializer.is_valid())
        self.assertIn('non_field_errors', serializer.errors)

    def test_invalid_email(self):
        # Every other field is valid, so the email is the only reason validation fails
        data = {
            'username': 'alice',
            'email': 'not-an-email',
            'password': 'pass1234',
            'confirm_password': 'pass1234'
        }
        serializer = UserCreateSerializer(data=data)
        self.assertFalse(serializer.is_valid())
        self.assertIn('email', serializer.errors)
十四、最佳实践与常见模式¶
14.1 统一响应格式¶
# utils/response.py
from rest_framework.response import Response


class APIResponse(Response):
    """Response whose body is always the {'code', 'msg', 'data'} envelope.

    Args:
        data: payload placed under 'data'.
        code: business code (0 means success).
        msg: human-readable message.
        status: HTTP status; extra kwargs go to Response unchanged.
    """

    def __init__(self, data=None, code=0, msg='ok', status=None, **kwargs):
        super().__init__(
            data={'code': code, 'msg': msg, 'data': data},
            status=status,
            **kwargs,
        )
# 使用
return APIResponse(data=serializer.data, status=201)
return APIResponse(data=None, code=404, msg='用户不存在', status=404)
14.2 Mixin 封装通用逻辑¶
class SoftDeleteMixin:
    """Soft delete: DELETE flags the row as deleted instead of removing it."""

    def perform_destroy(self, instance):
        # Only these two columns are written back
        changed_fields = ['is_deleted', 'deleted_at']
        instance.is_deleted = True
        instance.deleted_at = timezone.now()
        instance.save(update_fields=changed_fields)


class AuditMixin:
    """Audit mixin: record the acting user as created_by / updated_by."""

    def perform_create(self, serializer):
        actor = self.request.user
        serializer.save(created_by=actor)

    def perform_update(self, serializer):
        actor = self.request.user
        serializer.save(updated_by=actor)


class UserViewSet(AuditMixin, SoftDeleteMixin, viewsets.ModelViewSet):
    # Soft-deleted rows are excluded from the base queryset
    queryset = User.objects.filter(is_deleted=False)
    serializer_class = UserSerializer
14.3 ViewSet 动态 Serializer / QuerySet / Permission¶
class FlexibleViewSet(viewsets.ModelViewSet):
    """Pick the Serializer, QuerySet and permissions per action from lookup tables."""

    serializer_action_map = {
        'list': UserListSerializer,
        'create': UserCreateSerializer,
        'retrieve': UserDetailSerializer,
        'update': UserUpdateSerializer,
        'partial_update': UserUpdateSerializer,
    }
    # Class-level QuerySets are shared by every request; get_queryset() clones them.
    queryset_action_map = {
        'list': User.objects.filter(is_active=True).select_related('profile'),
        'default': User.objects.all(),
    }
    permission_action_map = {
        'create': [AllowAny],
        'destroy': [IsAdminUser],
        'default': [IsAuthenticated],
    }

    def get_serializer_class(self):
        # Unmapped actions (custom @action routes, etc.) fall back to UserSerializer
        return self.serializer_action_map.get(self.action, UserSerializer)

    def get_queryset(self):
        queryset = self.queryset_action_map.get(
            self.action,
            self.queryset_action_map['default']
        )
        # .all() returns a fresh clone. Without it, once a request evaluates the shared
        # class-level QuerySet, its cached rows would be served to later requests.
        return queryset.all()

    def get_permissions(self):
        classes = self.permission_action_map.get(
            self.action,
            self.permission_action_map['default']
        )
        return [c() for c in classes]
14.4 OpenAPI 文档(drf-spectacular)¶
# pip install drf-spectacular
# settings.py
SPECTACULAR_SETTINGS = {
    'TITLE': 'My API',
    'DESCRIPTION': 'API 文档',
    'VERSION': '1.0.0',
    'SERVE_INCLUDE_SCHEMA': False,
}
# urls.py: the raw OpenAPI schema plus Swagger UI and ReDoc pages
# NOTE(review): the UI views look up the schema by the URL name 'schema' by default;
# keep that name or pass url_name= to them.
from drf_spectacular.views import SpectacularAPIView, SpectacularSwaggerView, SpectacularRedocView
urlpatterns += [
    path('api/schema/', SpectacularAPIView.as_view(), name='schema'),
    path('api/docs/', SpectacularSwaggerView.as_view(), name='swagger-ui'),
    path('api/redoc/', SpectacularRedocView.as_view(), name='redoc'),
]
# Add documentation annotations to a view
# (OpenApiExample is imported for illustration but not used below.)
from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiExample
@extend_schema(
    summary='获取用户列表',
    description='支持分页、过滤、排序',
    parameters=[
        OpenApiParameter('page', int, description='页码'),
        OpenApiParameter('search', str, description='搜索关键词'),
    ],
    responses={200: UserListSerializer(many=True)},
    tags=['用户管理']
)
class UserListView(generics.ListAPIView):
    ...
14.5 常见问题速查¶
# 1. Get the currently logged-in user
request.user
# 2. Pass context to a Serializer
serializer = UserSerializer(user, context={'request': request})
# 3. Trigger object-level permission checks manually in a view
obj = self.get_object()  # get_object() calls check_object_permissions internally
# 4. Disable pagination for a single view
class AllDataView(generics.ListAPIView):
    pagination_class = None
# 5. Get a filtered queryset (inside an APIView)
from rest_framework.filters import SearchFilter
backend = SearchFilter()
queryset = backend.filter_queryset(request, User.objects.all(), self)
# 6. Validate a Serializer manually and return its errors
serializer = UserSerializer(data=request.data)
if not serializer.is_valid():
    return Response({'errors': serializer.errors}, status=400)
# 7. Exclude fields from serialization
class UserSerializer(serializers.ModelSerializer):
    class Meta:
        model = User
        exclude = ['password', 'is_superuser']
# 8. Return 201 with a Location header
# NOTE(review): the path is hard-coded; building it with reverse() would survive URL changes.
response = Response(serializer.data, status=status.HTTP_201_CREATED)
response['Location'] = request.build_absolute_uri(f'/api/users/{instance.id}/')
return response
# 9. Access the request object inside a serializer
user = self.context['request'].user
# 10. 批量创建并返回
serializer = UserSerializer(data=request.data, many=True)
serializer.is_valid(raise_exception=True)
User.objects.bulk_create([User(**item) for item in serializer.validated_data])
return Response(serializer.data, status=201)
常用扩展汇总¶
| 库 | 功能 |
|---|---|
| djangorestframework | DRF 核心 |
| djangorestframework-simplejwt | JWT 认证 |
| django-filter | 高级过滤 |
| drf-spectacular | OpenAPI 3.0 文档生成 |
| drf-nested-routers | 嵌套路由 |
| django-guardian | 对象级权限 |
| djangorestframework-camel-case | camelCase 字段命名 |
| drf-extensions | 缓存、etag 等扩展 |
参考资源¶
SimpleJWT 文档:https://django-rest-framework-simplejwt.readthedocs.io/
django-filter 文档:https://django-filter.readthedocs.io/
drf-spectacular 文档:https://drf-spectacular.readthedocs.io/