• Writing login and authentication views with rest_framework's ModelViewSet class


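    Background: I read blogger 一抹浅笑's rest_framework authentication template and noticed that its login view is built on the APIView class.
    Improvement: write the login logic with the ModelViewSet class instead, by overriding its create method.
    Environment: if you are already using rest_framework, your environment is presumably set up.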

    1 Create the models

    Write the model classes

    # models.py
    from django.db import models


    class User(models.Model):
        username = models.CharField(verbose_name='username', unique=True, max_length=16)
        # Stored in plain text to keep the example short; hash passwords in real projects.
        password = models.CharField(verbose_name='login password', max_length=16)


    class Token(models.Model):
        username = models.CharField(verbose_name='username', unique=True, max_length=16)
        # An MD5 hex digest is exactly 32 characters long.
        token = models.CharField(verbose_name='auth token', max_length=32)
    

    Generate the migration files

    python manage.py makemigrations
    

    Apply the migrations

    python manage.py migrate
    

    2 Identify the method to override

    Read the ModelViewSet source

    '''
    class ModelViewSet(mixins.CreateModelMixin,
                       mixins.RetrieveModelMixin,
                       mixins.UpdateModelMixin,
                       mixins.DestroyModelMixin,
                       mixins.ListModelMixin,
                       GenericViewSet):
        """
        A viewset that provides default `create()`, `retrieve()`, `update()`,
        `partial_update()`, `destroy()` and `list()` actions.
        """
        pass
    '''
    

    The end goal is to add rows to the table behind the Token model, so the class to inspect is CreateModelMixin.

    '''
    class CreateModelMixin:
        """
        Create a model instance.
        """
        def create(self, request, *args, **kwargs):
            serializer = self.get_serializer(data=request.data)
            serializer.is_valid(raise_exception=True)
            self.perform_create(serializer)
            headers = self.get_success_headers(serializer.data)
            return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
    
        def perform_create(self, serializer):
            serializer.save()
    
        def get_success_headers(self, data):
            try:
                return {'Location': str(data[api_settings.URL_FIELD_NAME])}
            except (TypeError, KeyError):
                return {}
    '''
    

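    The source shows that CreateModelMixin.create() calls perform_create(), which in turn calls the serializer's save() method to create the data. Next, look at save().
    Following serializers.ModelSerializer leads to serializers.py; searching it for 'def save(' finds the following: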

    '''
        def save(self, **kwargs):
            assert hasattr(self, '_errors'), (
                'You must call `.is_valid()` before calling `.save()`.'
            )
    
            assert not self.errors, (
                'You cannot call `.save()` on a serializer with invalid data.'
            )
    
            # Guard against incorrect use of `serializer.save(commit=False)`
            assert 'commit' not in kwargs, (
                "'commit' is not a valid keyword argument to the 'save()' method. "
                "If you need to access data before committing to the database then "
                "inspect 'serializer.validated_data' instead. "
                "You can also pass additional keyword arguments to 'save()' if you "
                "need to set extra attributes on the saved model instance. "
                "For example: 'serializer.save(owner=request.user)'.'"
            )
    
            assert not hasattr(self, '_data'), (
                "You cannot call `.save()` after accessing `serializer.data`."
                "If you need to access data before committing to the database then "
                "inspect 'serializer.validated_data' instead. "
            )
    
            validated_data = {**self.validated_data, **kwargs}
    
            if self.instance is not None:
                self.instance = self.update(self.instance, validated_data)
                assert self.instance is not None, (
                    '`update()` did not return an object instance.'
                )
            else:
                self.instance = self.create(validated_data)
                assert self.instance is not None, (
                    '`create()` did not return an object instance.'
                )
    '''
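    To make the chain view.create() -> perform_create() -> serializer.save() -> serializer.create()/update() concrete, here is a framework-free sketch. It is a simplified model, not DRF itself: the Sketch* and UsernameSerializer classes are hypothetical, validation is reduced to copying the input, and only the dispatch logic mirrors the source above. It shows why overriding the serializer's create() is enough to change what the view stores.

```python
# Simplified, framework-free model of the DRF call chain (hypothetical classes,
# not DRF): view.create() -> perform_create() -> save() -> create()/update().


class SketchSerializer:
    def __init__(self, instance=None, data=None):
        self.instance = instance
        self.initial_data = data or {}

    def is_valid(self):
        # Stand-in for real field validation: just copy the input.
        self.validated_data = dict(self.initial_data)
        return True

    def save(self, **kwargs):
        # Same dispatch as DRF's save(): update() if an instance exists, else create().
        validated_data = {**self.validated_data, **kwargs}
        if self.instance is not None:
            self.instance = self.update(self.instance, validated_data)
        else:
            self.instance = self.create(validated_data)
        return self.instance

    def create(self, validated_data):
        return {"created": validated_data}

    def update(self, instance, validated_data):
        instance.update(validated_data)
        return instance


class SketchCreateMixin:
    serializer_class = SketchSerializer

    def create(self, data):
        serializer = self.serializer_class(data=data)
        serializer.is_valid()
        self.perform_create(serializer)
        return serializer.instance

    def perform_create(self, serializer):
        serializer.save()


class UsernameSerializer(SketchSerializer):
    # Overriding only create() changes what the unchanged view ends up storing.
    def create(self, validated_data):
        return {"custom": validated_data["username"]}


class SketchView(SketchCreateMixin):
    serializer_class = UsernameSerializer
```

    The view class never changes; swapping serializer_class is what redirects the write, which is the same hook the login view below relies on.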
    

    In the final if…else of save(), note self.instance = self.create(validated_data).
    This calls the serializer's create() method, which returns a model instance, so the next step is ModelSerializer.create().

    '''
        def create(self, validated_data):
            """
            We have a bit of extra checking around this in order to provide
            descriptive messages when something goes wrong, but this method is
            essentially just:
    
                return ExampleModel.objects.create(**validated_data)
    
            If there are many to many fields present on the instance then they
            cannot be set until the model is instantiated, in which case the
            implementation is like so:
    
                example_relationship = validated_data.pop('example_relationship')
                instance = ExampleModel.objects.create(**validated_data)
                instance.example_relationship = example_relationship
                return instance
    
            The default implementation also does not handle nested relationships.
            If you want to support writable nested relationships you'll need
            to write an explicit `.create()` method.
            """
            raise_errors_on_nested_writes('create', self, validated_data)
    
            ModelClass = self.Meta.model
    
            # Remove many-to-many relationships from validated_data.
            # They are not valid arguments to the default `.create()` method,
            # as they require that the instance has already been saved.
            info = model_meta.get_field_info(ModelClass)
            many_to_many = {}
            for field_name, relation_info in info.relations.items():
                if relation_info.to_many and (field_name in validated_data):
                    many_to_many[field_name] = validated_data.pop(field_name)
    
            try:
                instance = ModelClass._default_manager.create(**validated_data)
            except TypeError:
                tb = traceback.format_exc()
                msg = (
                    'Got a `TypeError` when calling `%s.%s.create()`. '
                    'This may be because you have a writable field on the '
                    'serializer class that is not a valid argument to '
                    '`%s.%s.create()`. You may need to make the field '
                    'read-only, or override the %s.create() method to handle '
                    'this correctly.\nOriginal exception was:\n %s' %
                    (
                        ModelClass.__name__,
                        ModelClass._default_manager.name,
                        ModelClass.__name__,
                        ModelClass._default_manager.name,
                        self.__class__.__name__,
                        tb
                    )
                )
                raise TypeError(msg)
    
            # Save many-to-many relationships after the instance is created.
            if many_to_many:
                for field_name, value in many_to_many.items():
                    field = getattr(instance, field_name)
                    field.set(value)
    
            return instance
    '''
    

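    I did not fully follow this logic, but print(), type() and dir() confirm that
    the incoming validated_data is a dict,
    and the returned instance is a model instance.
    That is enough to copy the code over, override create(), and run a quick test to see whether the approach works.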

    from rest_framework import status
    from rest_framework import serializers
    from rest_framework.response import Response
    from rest_framework.viewsets import ModelViewSet

    from myapp import models as myapp_models


    class TokenSerializer(serializers.ModelSerializer):
        class Meta:
            model = myapp_models.Token
            fields = '__all__'
            # username is unique=True, so ModelSerializer adds a UniqueValidator;
            # drop it so a repeat login reaches create() instead of failing validation.
            extra_kwargs = {'username': {'validators': []}}

        def create(self, validated_data):
            ######################################
            # update_or_create() returns an (object, created) tuple; keep the object.
            query_obj, _created = myapp_models.Token.objects.update_or_create(
                username=validated_data['username'],
                defaults={'token': validated_data['token']})
            print(query_obj)
            return query_obj
            #------------------------------------#


    class LoginView(ModelViewSet):
        queryset = myapp_models.Token.objects.all()
        serializer_class = TokenSerializer

        def create(self, request, *args, **kwargs):
            # Same body as CreateModelMixin.create(), copied here as the starting point.
            serializer = self.get_serializer(data=request.data)
            serializer.is_valid(raise_exception=True)
            self.perform_create(serializer)
            headers = self.get_success_headers(serializer.data)
            return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
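    Django's update_or_create() returns an (object, created) tuple, which is why the serializer keeps only the object. The in-memory class below is a hypothetical stand-in (InMemoryTokens and TokenRow are not Django APIs) that mimics those semantics so the create-then-update behavior can be tried without a database: the first call creates a row, later calls for the same username update it in place.

```python
from dataclasses import dataclass


@dataclass
class TokenRow:
    # Hypothetical stand-in for one row of the Token table.
    username: str
    token: str


class InMemoryTokens:
    """Hypothetical stand-in mimicking QuerySet.update_or_create() semantics."""

    def __init__(self):
        self.rows = {}

    def update_or_create(self, username, defaults=None):
        row = self.rows.get(username)
        created = row is None
        if created:
            # No row matches the lookup: create one.
            row = TokenRow(username=username, token="")
            self.rows[username] = row
        # Apply the defaults whether the row is new or already existed.
        for field, value in (defaults or {}).items():
            setattr(row, field, value)
        return row, created
```

    Because the lookup key is the unique username, there is never more than one row per user, so a later login simply overwrites the previous token.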
    

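    3 Override the create method

    3.1 Write the login logic

    TokenSerializer
    1. Read username and password (the password arrives in the token field, since the serializer is bound to the Token model).
    2. Check whether username and password match a User row.
    3. No match: update or create the Token row for that username with an empty-string token, and return the model instance.
    4. Match: generate a token with MD5, then update or create the Token row for that username with that key.
    ModelViewSet
    1. Look up the token value for the username.
    2. Store username and token in the session.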
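    The steps above can be exercised without Django. In this framework-free sketch, the USERS and TOKENS dictionaries are hypothetical stand-ins for the User and Token tables, a plain dict plays the session, and make_token follows the MD5-over-time-plus-credentials idea. One detail worth checking: str.format needs one placeholder per value, because ''.format(a, b) silently returns '' and every login would then hash the same empty string.

```python
import hashlib
import time

# Hypothetical in-memory tables standing in for the User and Token models.
USERS = {"alice": "secret"}  # username -> password (plain text, as in the post)
TOKENS = {}                  # username -> token ('' marks a failed login)


def make_token(username, password, now=None):
    """MD5 over the current time plus the credentials: 32 hex characters."""
    now = time.time() if now is None else now
    raw = "{}{}{}".format(now, username, password)
    return hashlib.md5(raw.encode()).hexdigest()


def login(username, password, session):
    # Serializer side (steps 1-4): empty token on mismatch, fresh MD5 token on match.
    if USERS.get(username) != password:
        TOKENS[username] = ""
    else:
        TOKENS[username] = make_token(username, password)
    # View side (steps 1-2): read the stored token back and put both into the session.
    session["username"] = username
    session["token"] = TOKENS[username]
    return TOKENS[username] != ""
```

    Including the timestamp means each successful login produces a different token, so an old session stops authenticating once the user logs in again.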

    import time
    import hashlib

    from rest_framework import status
    from rest_framework import serializers
    from rest_framework.response import Response
    from rest_framework.viewsets import ModelViewSet

    from myapp import models as myapp_models


    class TokenSerializer(serializers.ModelSerializer):
        class Meta:
            model = myapp_models.Token
            fields = '__all__'
            # Drop the auto-generated UniqueValidator on username so repeat
            # logins reach create() and update the existing row.
            extra_kwargs = {'username': {'validators': []}}

        def create(self, validated_data):
            ######################################
            username = validated_data['username']
            # The client posts the password in the 'token' field.
            password = validated_data['token']
            user_exists = myapp_models.User.objects.filter(
                username=username, password=password).exists()
            if not user_exists:
                # Wrong credentials: store an empty token for this username.
                token = ''
            else:
                # Correct credentials: MD5 over time + username + password (32 hex chars).
                token = hashlib.md5(
                    '{}{}{}'.format(time.time(), username, password).encode()).hexdigest()
            query_obj, _created = myapp_models.Token.objects.update_or_create(
                username=username, defaults={'token': token})
            return query_obj
            #------------------------------------#


    class LoginView(ModelViewSet):
        queryset = myapp_models.Token.objects.all()
        serializer_class = TokenSerializer

        def create(self, request, *args, **kwargs):
            serializer = self.get_serializer(data=request.data)
            serializer.is_valid(raise_exception=True)
            self.perform_create(serializer)
            headers = self.get_success_headers(serializer.data)
            ######################################
            # serializer.instance is the Token row for this username saved above.
            token_obj = serializer.instance
            request.session['username'] = token_obj.username
            request.session['token'] = token_obj.token
            if token_obj.token == '':
                return Response('Check the username and password you entered',
                                status=status.HTTP_400_BAD_REQUEST)
            #------------------------------------#
            return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
    

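    3.2 Write the authentication logic

    1. Read username and token from the session.
    2. Check whether no Token row matches that username and token, or whether the token is an empty string.
    3. If so: raise an exception.
    4. Otherwise: return a tuple of the username and the Token model instance.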
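    These checks reduce to a small decision function. In this framework-free sketch, AuthenticationFailed is a local stand-in for rest_framework.exceptions.AuthenticationFailed, and the tokens dict is a hypothetical stand-in for the Token table; it returns the token string where the DRF version returns the model instance.

```python
class AuthenticationFailed(Exception):
    """Local stand-in for rest_framework.exceptions.AuthenticationFailed."""


def authenticate(session, tokens):
    # tokens: hypothetical dict mapping username -> stored token (the Token table).
    username = session.get("username", "")
    token = session.get("token", "")
    # Reject missing values, the empty "failed login" token, and stale/unknown tokens.
    if not username or not token or tokens.get(username) != token:
        raise AuthenticationFailed("authentication failed")
    return (username, token)
```

    Checking for an empty token first matters: a failed login stores '' for that username, so a session holding '' would otherwise match the row.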

    from rest_framework import exceptions
    from rest_framework.authentication import BaseAuthentication

    from myapp import models as myapp_models


    class Authentication(BaseAuthentication):
        def authenticate(self, request):
            ######################################
            # DRF's Request exposes the underlying Django session as request.session.
            username = request.session.get('username', '')
            token = request.session.get('token', '')
            # An empty token marks a failed login, so reject it outright.
            if not username or not token:
                raise exceptions.AuthenticationFailed('Authentication failed')
            token_obj = myapp_models.Token.objects.filter(
                username=username, token=token).first()
            if token_obj is None:
                raise exceptions.AuthenticationFailed('Authentication failed')
            # DRF sets request.user / request.auth from this tuple
            # (request.user is the username string here, not a User instance).
            return (token_obj.username, token_obj)
            #------------------------------------#
    

    3.3 Add the route

    from django.urls import path
    from myapp import views as myapp_views

    urlpatterns = [path('login/', myapp_views.LoginView.as_view({'post': 'create'}), name='login')]
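    The post does not show how the Authentication class gets attached to views. One option, sketched here under an assumption, is DRF's global DEFAULT_AUTHENTICATION_CLASSES setting; the module path myapp.authentication is hypothetical and should point at wherever the class actually lives.

```python
# settings.py (fragment). 'myapp.authentication' is a hypothetical module path.
REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES": [
        "myapp.authentication.Authentication",
    ],
}
```

    With a global default, LoginView itself has to opt out by setting authentication_classes = [], otherwise an anonymous login request is rejected before create() runs. The alternative is to leave the setting alone and set authentication_classes = [Authentication] only on the views that need protection.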
    
  • Original article: https://www.cnblogs.com/mlcode/p/17969584/rest_framework