• NLP - 使用 transformers 翻译


    from transformers import AutoTokenizer
    
    # Load the tokenizer that belongs to the pretrained en->ro Marian checkpoint.
    # NOTE(review): use_fast=True is requested, yet the recorded output of this
    # cell reports is_fast=False — presumably no fast tokenizer is available
    # for this checkpoint; confirm.
    tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-ro',
                                              use_fast=True)
    
    print(tokenizer)
    
    # Trial encoding: the batch holds one item, which is itself a pair of
    # sentences; the recorded output shows a single input_ids row for it.
    tokenizer.batch_encode_plus(
        [['Hello, this one sentence!', 'This is another sentence.']])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    PreTrainedTokenizer(name_or_path='Helsinki-NLP/opus-mt-en-ro', vocab_size=59543, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'})
    {'input_ids': [[125, 778, 3, 63, 141, 9191, 23, 187, 32, 716, 9191, 2, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
    
    • 1
    • 2

    from datasets import load_dataset, load_from_disk
    
    # Load the WMT16 Romanian-English corpus.
    # Offline alternative: dataset = load_from_disk('datas/wmt16/ro-en')
    dataset = load_dataset(path='wmt16', name='ro-en')
    
    # The full corpus is too large to train on here, so keep a fixed-seed
    # shuffled sample of every split.
    for split_name, sample_size in (('train', 20000), ('validation', 200),
                                    ('test', 200)):
        dataset[split_name] = dataset[split_name].shuffle(1).select(
            range(sample_size))
    
    
    # Data preprocessing
    def preprocess_function(data):
        """Tokenize one batch of translation pairs.

        Args:
            data: A batch whose 'translation' entry is a sequence of dicts
                carrying an 'en' (source) and a 'ro' (target) string.

        Returns:
            The tokenizer output for the English sentences with an extra
            'labels' entry holding the token ids of the Romanian sentences.
            Both sides are truncated to at most 128 tokens.
        """
        pairs = data['translation']
        source_texts = [pair['en'] for pair in pairs]
        target_texts = [pair['ro'] for pair in pairs]
    
        # The source language is encoded directly.
        encoded = tokenizer.batch_encode_plus(source_texts,
                                              max_length=128,
                                              truncation=True)
    
        # The target language is encoded inside the target-tokenizer context.
        with tokenizer.as_target_tokenizer():
            target_encoded = tokenizer.batch_encode_plus(target_texts,
                                                         max_length=128,
                                                         truncation=True)
    
        encoded['labels'] = target_encoded['input_ids']
        return encoded
    
    
    # Tokenize every split with preprocess_function, 1000 examples per call
    # across 4 worker processes, and drop the raw 'translation' column so only
    # the encoded fields remain.
    dataset = dataset.map(function=preprocess_function,
                          batched=True,
                          batch_size=1000,
                          num_proc=4,
                          remove_columns=['translation'])
    
    # Show one encoded training example.
    print(dataset['train'][0])
    
    # Notebook display of the resulting splits and their columns.
    dataset
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38

    {'input_ids': [460, 354, 3794, 12, 10677, 20, 5046, 14, 4, 2546, 37, 8, 397, 5551, 30, 10113, 37, 3501, 19814, 18, 8465, 20, 4, 44690, 782, 2, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [902, 576, 2946, 76, 10815, 17, 5098, 14997, 5, 559, 1140, 43, 2434, 6624, 27, 50, 337, 19216, 46, 22174, 17, 2317, 121, 16825, 2, 0]}
    DatasetDict({
        train: Dataset({
            features: ['input_ids', 'attention_mask', 'labels'],
            num_rows: 20000
        })
        validation: Dataset({
            features: ['input_ids', 'attention_mask', 'labels'],
            num_rows: 200
        })
        test: Dataset({
            features: ['input_ids', 'attention_mask', 'labels'],
            num_rows: 200
        })
    })
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    #这个函数和下面这个工具类等价,但我也是模仿实现的,不确定有没有出入
    #from transformers import DataCollatorForSeq2Seq
    #DataCollatorForSeq2Seq(tokenizer, model=model)
    
    import torch
    
    
    # Batch collation function
    def collate_fn(data):
        """Collate a list of encoded examples into one padded tensor batch.

        Args:
            data: List of dicts, each with 'input_ids', 'attention_mask' and
                'labels' given as lists of ints. The dicts are not modified.

        Returns:
            The padded batch produced by ``tokenizer.pad`` (PyTorch tensors)
            with an added 'decoder_input_ids' tensor: the labels shifted one
            position to the right, starting with the pad token id.
        """
        # Look the pad id up once; it is used both as the first decoder input
        # token and as the replacement for the -100 label padding.
        pad_id = tokenizer.pad_token_id
    
        # Longest label sequence in the batch.
        max_length = max(len(item['labels']) for item in data)
    
        # Pad every label sequence to that length with -100. Work on shallow
        # copies so the caller's examples stay untouched.
        padded = [{
            **item,
            'labels':
            item['labels'] + [-100] * (max_length - len(item['labels'])),
        } for item in data]
    
        # Merge the examples into tensors; the labels are already rectangular.
        batch = tokenizer.pad(
            encoded_inputs=padded,
            padding=True,
            max_length=None,
            pad_to_multiple_of=None,
            return_tensors='pt',
        )
    
        # Build decoder_input_ids: [pad, labels[0], ..., labels[-2]].
        decoder_input_ids = torch.full_like(batch['labels'],
                                            pad_id,
                                            dtype=torch.long)
        decoder_input_ids[:, 1:] = batch['labels'][:, :-1]
        # -100 is only meaningful to the loss; it is not a valid token id.
        decoder_input_ids[decoder_input_ids == -100] = pad_id
        batch['decoder_input_ids'] = decoder_input_ids
    
        return batch
    
    
    # Hand-made batch of two examples whose labels differ in length (3 vs 4),
    # used to try out collate_fn.
    data = [{
        'input_ids': [21603, 10, 37, 3719, 13],
        'attention_mask': [1, 1, 1, 1, 1],
        'labels': [10455, 120, 80]
    }, {
        'input_ids': [21603, 10, 7086, 8408, 563],
        'attention_mask': [1, 1, 1, 1, 1],
        'labels': [301, 53, 4074, 1669]
    }]
    
    # Notebook display of the decoder inputs built for that batch.
    collate_fn(data)['decoder_input_ids']
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48

    tensor([[59542, 10455,   120,    80],
            [59542,   301,    53,  4074]])
    
    • 1
    • 2

    import torch
    
    # Training data loader: shuffled batches of 8 built with collate_fn; the
    # last incomplete batch is dropped.
    loader = torch.utils.data.DataLoader(
        dataset=dataset['train'],
        batch_size=8,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=True,
    )
    
    # Grab only the first batch for inspection.
    for i, data in enumerate(loader):
        break
    
    # Print each field's shape and its first two rows.
    for k, v in data.items():
        print(k, v.shape, v[:2])
    
    # Notebook display of the number of batches in the loader.
    len(loader)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    from transformers import AutoModelForSeq2SeqLM, MarianModel
    
    #加载模型
    #model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-en-ro')
    
    
    # Downstream-task model
    class Model(torch.nn.Module):
        """Pretrained Marian en->ro backbone with an explicit output head.

        The backbone (``MarianModel``) yields decoder hidden states; ``fc``
        projects them to vocabulary logits and ``final_logits_bias`` is added
        on top. Both the head weights and the bias are initialised from the
        pretrained seq2seq checkpoint.
        """

        def __init__(self):
            super().__init__()
            self.pretrained = MarianModel.from_pretrained(
                'Helsinki-NLP/opus-mt-en-ro')
    
            # Load the full seq2seq checkpoint once; only its output head and
            # logits bias are copied out of it.
            seq2seq = AutoModelForSeq2SeqLM.from_pretrained(
                'Helsinki-NLP/opus-mt-en-ro')
    
            # Start from the pretrained bias instead of zeros so the head
            # matches the pretrained seq2seq model.
            self.register_buffer('final_logits_bias',
                                 seq2seq.final_logits_bias.detach().clone())
    
            # Hidden size is taken from the backbone config rather than being
            # hard-coded.
            self.fc = torch.nn.Linear(self.pretrained.config.d_model,
                                      tokenizer.vocab_size,
                                      bias=False)
            self.fc.load_state_dict(seq2seq.lm_head.state_dict())
    
            self.criterion = torch.nn.CrossEntropyLoss()
    
        def forward(self, input_ids, attention_mask, labels, decoder_input_ids):
            """Run the model and compute the cross-entropy loss.

            Args:
                input_ids: Source token ids passed to the backbone.
                attention_mask: Mask passed to the backbone with input_ids.
                labels: Target token ids; flattened and compared against the
                    logits. Positions set to -100 are skipped by
                    CrossEntropyLoss's default ignore_index.
                decoder_input_ids: Decoder inputs passed to the backbone.

            Returns:
                Dict with 'loss' (scalar tensor) and 'logits' (the output of
                the head, one score per vocabulary entry and position).
            """
            hidden = self.pretrained(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids).last_hidden_state
    
            logits = self.fc(hidden) + self.final_logits_bias
    
            # Merge the first two dimensions so each position is one sample.
            loss = self.criterion(logits.flatten(end_dim=1), labels.flatten())
    
            return {'loss': loss, 'logits': logits}
    
    
    # Build the downstream model (loads the pretrained weights).
    model = Model()
    
    # Report the parameter count in units of 10,000.
    print(sum(i.numel() for i in model.parameters()) / 10000)
    
    # Optional smoke test on the batch fetched earlier:
    #out = model(**data)
    #out['loss'], out['logits'].shape
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45

    from datasets import load_metric
    
    # Load the sacreBLEU evaluation metric.
    metric = load_metric(path='sacrebleu')
    
    # Trial computation: predictions are plain strings and every prediction
    # gets a list of reference strings (here exactly one each).
    metric.compute(predictions=['hello there', 'general kenobi'],
                   references=[['hello there'], ['general kenobi']])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    
    
    • 1

    测试

    # Evaluation
    def test():
        """Evaluate the global ``model`` on a sample of the test split.

        Runs teacher-forced prediction on up to 11 shuffled batches of 8,
        decodes predictions and reference labels to text, prints a sample every
        second batch and finally prints the sacreBLEU result.

        References are decoded from ``labels`` (not from the right-shifted
        ``decoder_input_ids``), positions that are padding in the labels are
        blanked out in the predictions too, and special tokens are skipped so
        that neither padding nor start/end markers influence the score.
        """
        model.eval()
    
        # Test data loader
        loader_test = torch.utils.data.DataLoader(
            dataset=dataset['test'],
            batch_size=8,
            collate_fn=collate_fn,
            shuffle=True,
            drop_last=True,
        )
    
        pad_id = tokenizer.pad_token_id
        predictions = []
        references = []
        for i, data in enumerate(loader_test):
            # Forward pass without gradient tracking.
            with torch.no_grad():
                out = model(**data)
    
            # -100 marks label padding; it is not a decodable token id, and the
            # model output at those positions carries no meaning.
            ignored = data['labels'] == -100
            pred_ids = out['logits'].argmax(dim=2).masked_fill(ignored, pad_id)
            label_ids = data['labels'].masked_fill(ignored, pad_id)
    
            pred = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            label = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
            predictions.extend(pred)
            references.extend(label)
    
            if i % 2 == 0:
                print(i)
                input_ids = tokenizer.decode(data['input_ids'][0])
    
                print('input_ids=', input_ids)
                print('pred=', pred[0])
                print('label=', label[0])
    
            # Only a subset of the batches is evaluated to keep this quick.
            if i == 10:
                break
    
        # sacreBLEU expects a list of references per prediction.
        references = [[j] for j in references]
        metric_out = metric.compute(predictions=predictions, references=references)
        print(metric_out)
    
    
    # Evaluate the model as loaded, before any fine-tuning has been run.
    test()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42

    
    
    • 1

    from transformers import AdamW
    from transformers.optimization import get_scheduler
    
    
    # Training
    def train():
        """Fine-tune the global ``model`` for one pass over ``loader``.

        Uses AdamW (lr=2e-5) with a linear schedule without warmup and gradient
        clipping at norm 1.0. Every 50 steps it prints the step index, loss,
        token accuracy, sacreBLEU result and current learning rate for the
        batch. The whole model object is saved to 'models/7.翻译.model' at the
        end.

        Accuracy and BLEU are measured against ``labels`` (the tokens the
        logits are trained to predict), counting only non-padding positions.
        """
        import os
    
        optimizer = AdamW(model.parameters(), lr=2e-5)
        scheduler = get_scheduler(name='linear',
                                  num_warmup_steps=0,
                                  num_training_steps=len(loader),
                                  optimizer=optimizer)
        pad_id = tokenizer.pad_token_id
    
        model.train()
        for i, data in enumerate(loader):
            out = model(**data)
            loss = out['loss']
    
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    
            optimizer.step()
            scheduler.step()
    
            # The optimizer holds every model parameter, so this clears all
            # gradients.
            optimizer.zero_grad()
    
            if i % 50 == 0:
                labels = data['labels']
                # -100 marks label padding and must not count as a target.
                valid = labels != -100
                pred_ids = out['logits'].argmax(dim=2)
    
                correct = ((pred_ids == labels) & valid).sum().item()
                total = valid.sum().item()
                accuracy = correct / max(total, 1)
    
                # Blank out padded positions before decoding; -100 is not a
                # token id and predictions there are meaningless.
                predictions = tokenizer.batch_decode(
                    pred_ids.masked_fill(~valid, pad_id),
                    skip_special_tokens=True)
                label_texts = tokenizer.batch_decode(
                    labels.masked_fill(~valid, pad_id),
                    skip_special_tokens=True)
                references = [[text] for text in label_texts]
    
                metric_out = metric.compute(predictions=predictions,
                                            references=references)
    
                lr = optimizer.param_groups[0]['lr']
    
                print(i, loss.item(), accuracy, metric_out, lr)
    
        # Make sure the target directory exists so the finished run is not
        # lost at save time.
        os.makedirs('models', exist_ok=True)
        torch.save(model, 'models/7.翻译.model')
    
    
    # Run the fine-tuning pass; this also writes the model file to disk.
    train()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53

    
    
    • 1

    # Reload the model object written by train() and evaluate it again.
    # NOTE(review): torch.load unpickles a complete model object here; only
    # load files you trust.
    model = torch.load('models/7.翻译.model')
    test()
    
    • 1
    • 2

    
    
    • 1

    
    
    • 1

    
    
    • 1

    
    
    • 1

    
    
    • 1

  • 相关阅读:
    2011年下半年 系统架构设计师 下午试卷 II
    Servlet小项目 | 基于纯Servlet手写一个单表的CRUD操作
    前后端分离毕设项目之springboot同城上门喂遛宠物系统(内含文档+源码+教程)
    尝试用Unity还原蔚蓝(Celeste)—— 真·操控、移动、手感篇
    MySQL报错:Row size too large (> 8126)
    Docker
    Java List 过滤重复数据
    【重学前端】004-JavaScript:我们真的需要模拟类吗
    吉时利2600A系列/2611A数字源表
    【Python零基础入门篇 · 10】:集合的相关操作
  • 原文地址:https://blog.csdn.net/lovechris00/article/details/128230696