• peft模型微调--Prompt Tuning


    模型微调(Model Fine-Tuning)是指在预训练模型的基础上,针对特定任务进行进一步的训练以优化模型性能的过程。预训练模型通常是在大规模数据集上通过无监督或自监督学习方法预先训练好的,具有捕捉语言或数据特征的强大能力。

    PEFT(Parameter-Efficient Fine-Tuning)是一种针对大模型微调的技术,其核心思想是在保持大部分预训练模型参数不变的基础上,仅对一小部分额外参数进行微调,以实现高效的资源利用和性能优化。这种方法对于那些计算资源有限、但又需要针对特定任务调整大型语言模型(如LLM:Large Language Models)的行为时特别有用。

    在应用PEFT技术进行模型微调时,通常采用以下策略之一或组合:

    Adapter Layers: 在模型的各个层中插入适配器模块,这些适配器模块通常具有较低的维度,并且仅对这部分新增的参数进行微调,而不改变原模型主体的参数。

    Prefix Tuning / Prompt Tuning: 通过在输入序列前添加可学习的“提示”向量(即prefix或prompt),来影响模型的输出结果,从而达到微调的目的,而无需更改模型原有权重。

    LoRA (Low-Rank Adaptation): 使用低秩矩阵更新原始模型权重,这样可以大大减少要训练的参数数量,同时保持模型的表达能力。

    P-Tuning V1/V2: 清华大学提出的一种方法,它通过学习一个连续的prompt嵌入向量来指导模型生成特定任务相关的输出。

    冻结(Freezing)大部分模型参数: 只对模型的部分层或头部(如分类器层)进行微调,其余部分则保持预训练时的状态不变。

    下面简单介绍一个通过peft使用Prompt Tuning对模型进行微调训练的简单流程。

    # 基于peft使用prompt tuning对生成式对话模型进行微调 
    from datasets import Dataset
    from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
    
    • 1
    • 2
    • 3
    # 数据加载
    ds = Dataset.load_from_disk("/alpaca_data_zh")
    print(ds[:3])
    
    • 1
    • 2
    • 3
    # 数据处理
    tokenizer = AutoTokenizer.from_pretrained("../models/bloom-1b4-zh")
    # 数据处理函数
    def process_func(example):
        MAX_LENGTH = 256
        input_ids, attention_mask, labels = [], [], []
        instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
        response = tokenizer(example["output"] + tokenizer.eos_token)
        input_ids = instruction["input_ids"] + response["input_ids"]
        attention_mask = instruction["attention_mask"] + response["attention_mask"]
        labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
        if len(input_ids) > MAX_LENGTH:
            input_ids = input_ids[:MAX_LENGTH]
            attention_mask = attention_mask[:MAX_LENGTH]
            labels = labels[:MAX_LENGTH]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
    
    # 数据处理
    tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
    print(tokenized_ds)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    # 模型创建
    model = AutoModelForCausalLM.from_pretrained("../models/bloom-1b4-zh", low_cpu_mem_usage=True)
    
    • 1
    • 2
    # 套用peft对模型进行参数微调
    from peft import PromptTuningConfig, get_peft_model, TaskType, PromptTuningInit
    
    # 1、配置文件参数
    config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM,
                                prompt_tuning_init=PromptTuningInit.TEXT,
                                prompt_tuning_init_text="下面是一段人与机器人的对话。",
                                num_virtual_tokens=len(tokenizer("下面是一段人与机器人的对话。")["input_ids"]),
                                tokenizer_name_or_path="../models/bloom-1b4-zh")
    
    # 2、创建模型
    model = get_peft_model(model, config)
    # 查看模型的训练参数
    model.print_trainable_parameters()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    # 配置训练参数
    args = TrainingArguments(
        output_dir="./peft_model",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        logging_steps=10,
        num_train_epochs=1
    )
    
    # 创建训练器
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=tokenized_ds,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    # 模型训练
    trainer.train()
    
    • 1
    • 2
    # 模型推理
    peft_model = model.cuda()
    ipt = tokenizer("Human: {}\n{}".format("周末去重庆怎么玩?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(model.device)
    print(tokenizer.decode(peft_model.generate(**ipt, max_length=256, do_sample=True)[0], skip_special_tokens=True))
    
    • 1
    • 2
    • 3
    • 4
  • 相关阅读:
    1044 火星数字 (测试点2.4说明)
    Webmin--一个用于Linux基于Web的系统管理工具
    入门力扣自学笔记131 C++ (题目编号655)
    react实现一个搜索部门(input + tree)
    1060 Are They Equal
    python基础(三)
    架构师常用设计模型
    SpringBoot对Spring MVC都做了哪些事?
    求w=1+2的1次方+....+2的10次方
    Leetcode刷题解析——904. 水果成篮
  • 原文地址:https://blog.csdn.net/LLMUZI123456789/article/details/136652290