• News Topic Classification: Text Classification with the torchtext Library


    Introduction

    Build a news topic classifier with a shallow network.
    Given the text of a news report as input, the model predicts which type of news it most likely belongs to. This is a classic text classification problem; we assume here that the types are mutually exclusive, i.e. each text belongs to exactly one type. (The AG_NEWS dataset used below has four such classes: World, Sports, Business, Sci/Tec.)

    Import the required torch packages

    import time
    
    import torch
    import torch.nn as nn
    from torchtext.datasets import AG_NEWS
    from torchtext.data.utils import get_tokenizer
    from torchtext.vocab import build_vocab_from_iterator
    from torch.utils.data import DataLoader
    from torch.utils.data.dataset import random_split
    from torchtext.data.functional import to_map_style_dataset
    # the model class defined later in this article, saved in its own module file
    from TextClassificationModule import TextClassificationModule
    

    Access the raw dataset iterators

    The torchtext library provides raw dataset iterators that yield raw text strings. For example, the AG_NEWS dataset iterator yields the raw data as (label, text) tuples.

    # Device detection: prefer the GPU when one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Basic English tokenizer
    tokenizer = get_tokenizer('basic_english')
    # Training and test dataset iterators
    train_iter = AG_NEWS(split="train")
    test_iter = AG_NEWS(split="test")
    

    Test the loaded data. The dataset is downloaded automatically to a local cache; train_iter and test_iter are the training and test sets, and both are of iterator type.

    print('test:')
    train_data = iter(train_iter)
    test_data = iter(test_iter)
    print(next(train_data))
    print(next(test_data))
    

    Output

    test:
    (3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")
    (3, "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.")
    

    Build the vocabulary from the raw training dataset

    In the token generator below, "_" is an unused variable holding the category, and text is the news text, e.g.:
    _ = 3
    text = Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.
    In Python, yield turns a function into a generator. A function containing yield is no longer an ordinary function: the interpreter treats it as a generator, so calling e.g. yield_test(5) does not execute the function body but instead returns an iterable object.
    Example

    def yield_test(n):
        for i in range(n):
            yield call(i)
            print("i=", i)
        # do some other things after the loop
        print("do something.")
        print("end.")
    
    def call(i):
        return i * 2
    
    # consume the generator with a for loop
    for i in yield_test(5):
        print(i, ",")
    

    Output

    0 ,  
    i= 0  
    2 ,  
    i= 1  
    4 ,  
    i= 2  
    6 ,  
    i= 3  
    8 ,  
    i= 4  
    do something.  
    end.  
    

    Build the vocabulary from the raw training dataset

    # Token generator
    def yield_tokens(data_iter):
        for _, text in data_iter:
            yield tokenizer(text)
    
    
    # Build the vocabulary from the training data (torchtext.vocab)
    vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
    # Set the default index: when a word is not in the vocabulary (OOV), return the index of "<unk>"
    vocab.set_default_index(vocab["<unk>"])
    
    # The vocabulary maps tokens to their indices
    print(vocab(["here", "is", "an", "example"]))
    
    # Build the processing pipelines for the dataloader
    # text_pipeline converts a text string into a list of integers, one vocabulary index per token
    text_pipeline = lambda x: vocab(tokenizer(x))
    
    # label_pipeline converts a label to an integer; AG_NEWS labels are 1-4, so subtract 1 to get 0-based class indices
    label_pipeline = lambda x: int(x) - 1
    
    # pipeline example
    print(text_pipeline("hello world! I'am happy"))
    print(label_pipeline("10"))
    

    Output

    [475, 21, 30, 5297]
    [12544, 50, 764, 282, 16, 1913, 2734]
    9
    

    Generate data batches and the iterator

    def collate_batch(batch):
        label_list, text_list, offsets = [], [], [0]
        for (_label, _text) in batch:
            label_list.append(label_pipeline(_label))
            processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
            text_list.append(processed_text)
            # record each sequence's length; the cumulative sums become the start offsets
            offsets.append(processed_text.size(0))
        label_list = torch.tensor(label_list, dtype=torch.int64)
        offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
        # concatenate all sequences in the batch into one flat 1-D tensor (no padding needed)
        text_list = torch.cat(text_list)
        return label_list.to(device), text_list.to(device), offsets.to(device)
    
    
    # Build the dataloader, converting each batch to tensors
    dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)
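
    As a quick sanity check (a sketch; the exact values depend on the batch sampled), fetching one batch shows the three tensors collate_batch returns:

    labels, texts, offsets = next(iter(dataloader))
    print(labels.shape)   # torch.Size([8]): one label per example
    print(texts.dim())    # 1: all eight token sequences concatenated into one flat tensor
    print(offsets)        # start index of each example inside `texts`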
    

    Define the model

    The model consists of an nn.EmbeddingBag layer plus a linear layer for classification. With its default mode "mean", nn.EmbeddingBag computes the mean of a "bag" of embeddings. Although the text entries here have different lengths, nn.EmbeddingBag needs no padding, because the text lengths are stored in the offsets.
    nn.EmbeddingBag also improves performance and memory efficiency when processing a sequence of tensors.
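
    For intuition, here is a minimal sketch (hypothetical sizes, not the dimensions used below) of how the offsets mark where each sequence starts inside the flat token tensor:

    import torch
    import torch.nn as nn
    
    bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=3, mode="mean")
    text = torch.tensor([1, 2, 3, 4, 5])   # two sequences concatenated: [1, 2] and [3, 4, 5]
    offsets = torch.tensor([0, 2])         # each sequence's start index within `text`
    out = bag(text, offsets)               # shape (2, 3): one mean embedding per sequence
    
    # the first row equals the mean of the embeddings of tokens 1 and 2
    emb = nn.Embedding.from_pretrained(bag.weight)
    assert torch.allclose(out[0], emb(torch.tensor([1, 2])).mean(dim=0))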

    import torch.nn as nn
    
    
    class TextClassificationModule(nn.Module):
        def __init__(self, vocab_size, embed_dim, num_class):
            """
                Text classification model.
                description: class initializer
                :param vocab_size: number of distinct words in the whole corpus
                :param embed_dim: dimensionality of the word embeddings
                :param num_class: number of text categories
            """
            super(TextClassificationModule, self).__init__()
            # Instantiate the embedding layer; sparse=True means only part of this layer's
            # weights are updated on each gradient step
            self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
            # Instantiate the fully connected layer mapping embed_dim inputs to num_class outputs
            self.fc = nn.Linear(embed_dim, num_class)
            # Initialize the weights of all layers
            self.init_weights()
    
        def init_weights(self):
            """Initialize the weights."""
            # Range for the initial weights
            initrange = 0.5
            # Initialize each layer's weights from a uniform distribution
            self.embedding.weight.data.uniform_(-initrange, initrange)
            self.fc.weight.data.uniform_(-initrange, initrange)
            # Initialize the bias to zero
            self.fc.bias.data.zero_()
    
        def forward(self, text, offsets):
            """
                :param text: numericalized text (a flat tensor of token indices)
                :param offsets: start offset of each example within text
                :return: a tensor with one score per class, used to decide the text's category
            """
            embedded = self.embedding(text, offsets)
            return self.fc(embedded)
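
    A quick shape check of the module (a sketch with hypothetical sizes, not the real vocabulary):

    import torch
    
    m = TextClassificationModule(vocab_size=10, embed_dim=4, num_class=4)
    text = torch.tensor([1, 2, 3, 4, 5])   # two examples: tokens [1, 2] and [3, 4, 5]
    offsets = torch.tensor([0, 2])         # each example's start index in `text`
    print(m(text, offsets).shape)          # torch.Size([2, 4]): one score per class per example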
    

    Define functions to train the model and evaluate results

    def train(dataloader):
        model.train()
        total_acc, total_count = 0, 0
        log_interval = 500
        start_time = time.time()
        for idx, (label, text, offsets) in enumerate(dataloader):
            optimizer.zero_grad()
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            loss.backward()
            # clip gradients to stabilize SGD under the large learning rate
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
            if idx % log_interval == 0 and idx > 0:
                elapsed = time.time() - start_time
                print('| epoch {:3d} | {:5d}/{:5d} batches '
                      '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                                  total_acc / total_count))
                total_acc, total_count = 0, 0
                start_time = time.time()
    
    
    def evaluate(dataloader):
        model.eval()
        total_acc, total_count = 0, 0
    
        with torch.no_grad():
            for idx, (label, text, offsets) in enumerate(dataloader):
                predicted_label = model(text, offsets)
                loss = criterion(predicted_label, label)
                total_acc += (predicted_label.argmax(1) == label).sum().item()
                total_count += label.size(0)
        return total_acc / total_count
    

    Instantiate and run the model

    # Build the dataloader, converting the dataset to tensors
    dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)
    # A model with an embedding dimension of 64. The vocabulary size equals the length of the
    # vocab instance; the number of classes equals the number of distinct labels.
    # (With older torchtext versions the train_iter iterator may already be exhausted here;
    # if so, re-create it with AG_NEWS(split="train") first.)
    num_class = len(set([label for (label, text) in train_iter]))
    vocab_size = len(vocab)
    emsize = 64
    model = TextClassificationModule(vocab_size, emsize, num_class).to(device)
    
    # Number of training epochs
    EPOCHS = 10
    # Learning rate
    LR = 5
    # Batch size
    BATCH_SIZE = 64
    # Cross-entropy loss
    criterion = torch.nn.CrossEntropyLoss()
    # Optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    # Learning-rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
    total_accu = None
    train_dataset = to_map_style_dataset(train_iter)
    test_dataset = to_map_style_dataset(test_iter)
    # Hold out 5% of the training set as a validation set
    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])
    
    train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    
    for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train(train_dataloader)
        accu_val = evaluate(valid_dataloader)
        if total_accu is not None and total_accu > accu_val:
            scheduler.step()
        else:
            total_accu = accu_val
        print('-' * 59)
        print('| end of epoch {:3d} | time: {:5.2f}s | '
              'valid accuracy {:8.3f} '.format(epoch,
                                               time.time() - epoch_start_time,
                                               accu_val))
        print('-' * 59)
    

    Output

    | epoch   1 |   500/ 1782 batches | accuracy    0.689
    | epoch   1 |  1000/ 1782 batches | accuracy    0.856
    | epoch   1 |  1500/ 1782 batches | accuracy    0.873
    -----------------------------------------------------------
    | end of epoch   1 | time: 23.38s | valid accuracy    0.879 
    -----------------------------------------------------------
    | epoch   2 |   500/ 1782 batches | accuracy    0.896
    | epoch   2 |  1000/ 1782 batches | accuracy    0.904
    | epoch   2 |  1500/ 1782 batches | accuracy    0.900
    -----------------------------------------------------------
    | end of epoch   2 | time: 32.21s | valid accuracy    0.891 
    -----------------------------------------------------------
    | epoch   3 |   500/ 1782 batches | accuracy    0.915
    | epoch   3 |  1000/ 1782 batches | accuracy    0.916
    | epoch   3 |  1500/ 1782 batches | accuracy    0.915
    -----------------------------------------------------------
    | end of epoch   3 | time: 36.85s | valid accuracy    0.899 
    -----------------------------------------------------------
    | epoch   4 |   500/ 1782 batches | accuracy    0.925
    | epoch   4 |  1000/ 1782 batches | accuracy    0.925
    | epoch   4 |  1500/ 1782 batches | accuracy    0.922
    -----------------------------------------------------------
    | end of epoch   4 | time: 20.15s | valid accuracy    0.897 
    -----------------------------------------------------------
    | epoch   5 |   500/ 1782 batches | accuracy    0.937
    | epoch   5 |  1000/ 1782 batches | accuracy    0.938
    | epoch   5 |  1500/ 1782 batches | accuracy    0.936
    -----------------------------------------------------------
    | end of epoch   5 | time: 28.52s | valid accuracy    0.905 
    -----------------------------------------------------------
    | epoch   6 |   500/ 1782 batches | accuracy    0.939
    | epoch   6 |  1000/ 1782 batches | accuracy    0.938
    | epoch   6 |  1500/ 1782 batches | accuracy    0.941
    -----------------------------------------------------------
    | end of epoch   6 | time: 33.47s | valid accuracy    0.905 
    -----------------------------------------------------------
    | epoch   7 |   500/ 1782 batches | accuracy    0.940
    | epoch   7 |  1000/ 1782 batches | accuracy    0.941
    | epoch   7 |  1500/ 1782 batches | accuracy    0.939
    -----------------------------------------------------------
    | end of epoch   7 | time: 20.75s | valid accuracy    0.904 
    -----------------------------------------------------------
    | epoch   8 |   500/ 1782 batches | accuracy    0.941
    | epoch   8 |  1000/ 1782 batches | accuracy    0.941
    | epoch   8 |  1500/ 1782 batches | accuracy    0.940
    -----------------------------------------------------------
    | end of epoch   8 | time: 27.11s | valid accuracy    0.906 
    -----------------------------------------------------------
    | epoch   9 |   500/ 1782 batches | accuracy    0.942
    | epoch   9 |  1000/ 1782 batches | accuracy    0.942
    | epoch   9 |  1500/ 1782 batches | accuracy    0.942
    -----------------------------------------------------------
    | end of epoch   9 | time: 34.83s | valid accuracy    0.906 
    -----------------------------------------------------------
    | epoch  10 |   500/ 1782 batches | accuracy    0.942
    | epoch  10 |  1000/ 1782 batches | accuracy    0.942
    | epoch  10 |  1500/ 1782 batches | accuracy    0.940
    -----------------------------------------------------------
    | end of epoch  10 | time: 22.78s | valid accuracy    0.906 
    -----------------------------------------------------------
    

    Evaluate the model on the test dataset

    print('Checking the results of test dataset.')
    accu_test = evaluate(test_dataloader)
    print('test accuracy {:8.3f}'.format(accu_test))
    

    Output

    Checking the results of test dataset.
    test accuracy    0.906
    

    Test on a random news item

    # Test on a random news item
    # Use the best model so far and test it on a golf news story.
    ag_news_label = {1: "World",
                     2: "Sports",
                     3: "Business",
                     4: "Sci/Tec"}
    
    
    def predict(text, text_pipeline):
        with torch.no_grad():
            text = torch.tensor(text_pipeline(text))
            output = model(text, torch.tensor([0]))
            return output.argmax(1).item() + 1
    
    
    ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
        enduring the season’s worst weather conditions on Sunday at The \
        Open on his way to a closing 75 at Royal Portrush, which \
        considering the wind and the rain was a respectable showing. \
        Thursday’s first round at the WGC-FedEx St. Jude Invitational \
        was another story. With temperatures in the mid-80s and hardly any \
        wind, the Spaniard was 13 strokes better in a flawless round. \
        Thanks to his best putting performance on the PGA Tour, Rahm \
        finished with an 8-under 62 for a three-stroke lead, which \
        was even more impressive considering he’d never played the \
        front nine at TPC Southwind."
    
    model = model.to("cpu")
    
    print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])
    

    Output

    This is a Sports news
    

    Complete code

    import time
    
    import torch
    import torch.nn as nn
    from torchtext.datasets import AG_NEWS
    from torchtext.data.utils import get_tokenizer
    from torchtext.vocab import build_vocab_from_iterator
    from torch.utils.data import DataLoader
    from torch.utils.data.dataset import random_split
    from torchtext.data.functional import to_map_style_dataset
    
    
    # Device detection: prefer the GPU when one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Basic English tokenizer
    tokenizer = get_tokenizer('basic_english')
    # Training and test dataset iterators
    train_iter = AG_NEWS(split="train")
    test_iter = AG_NEWS(split="test")
    
    
    # print('test:')
    # train_data = iter(train_iter)
    # test_data = iter(test_iter)
    # print(next(train_data))
    # print(next(test_data))
    
    
    # Token generator
    def yield_tokens(data_iter):
        for _, text in data_iter:
            yield tokenizer(text)
    
    
    # Build the vocabulary from the training data
    vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
    # Set the default index: when a word is not in the vocabulary (OOV), return the index of "<unk>"
    vocab.set_default_index(vocab["<unk>"])
    
    # The vocabulary maps tokens to their indices
    # print(vocab(["here", "is", "an", "example"]))
    
    # Build the processing pipelines for the dataloader
    # text_pipeline converts a text string into a list of integers, one vocabulary index per token
    text_pipeline = lambda x: vocab(tokenizer(x))
    
    # label_pipeline converts a label to an integer; AG_NEWS labels are 1-4, so subtract 1 to get 0-based class indices
    label_pipeline = lambda x: int(x) - 1
    
    
    # pipeline example
    # print(text_pipeline("hello world! I'am happy"))
    # print(label_pipeline("10"))
    
    # Model
    class TextClassificationModule(nn.Module):
        def __init__(self, vocab_size, embed_dim, num_class):
            """
                Text classification model.
                description: class initializer
                :param vocab_size: number of distinct words in the whole corpus
                :param embed_dim: dimensionality of the word embeddings
                :param num_class: number of text categories
            """
            super(TextClassificationModule, self).__init__()
            # Instantiate the embedding layer; sparse=True means only part of this layer's
            # weights are updated on each gradient step
            self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
            # Instantiate the fully connected layer mapping embed_dim inputs to num_class outputs
            self.fc = nn.Linear(embed_dim, num_class)
            # Initialize the weights of all layers
            self.init_weights()
    
        def init_weights(self):
            """Initialize the weights."""
            # Range for the initial weights
            initrange = 0.5
            # Initialize each layer's weights from a uniform distribution
            self.embedding.weight.data.uniform_(-initrange, initrange)
            self.fc.weight.data.uniform_(-initrange, initrange)
            # Initialize the bias to zero
            self.fc.bias.data.zero_()
    
        def forward(self, text, offsets):
            """
                :param text: numericalized text (a flat tensor of token indices)
                :param offsets: start offset of each example within text
                :return: a tensor with one score per class, used to decide the text's category
            """
            embedded = self.embedding(text, offsets)
            return self.fc(embedded)
    
    def collate_batch(batch):
        label_list, text_list, offsets = [], [], [0]
        for (_label, _text) in batch:
            label_list.append(label_pipeline(_label))
            processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
            text_list.append(processed_text)
            # record each sequence's length; the cumulative sums become the start offsets
            offsets.append(processed_text.size(0))
        label_list = torch.tensor(label_list, dtype=torch.int64)
        offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
        # concatenate all sequences in the batch into one flat 1-D tensor (no padding needed)
        text_list = torch.cat(text_list)
        return label_list.to(device), text_list.to(device), offsets.to(device)
    
    
    def train(dataloader):
        model.train()
        total_acc, total_count = 0, 0
        log_interval = 500
        start_time = time.time()
        for idx, (label, text, offsets) in enumerate(dataloader):
            optimizer.zero_grad()
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            loss.backward()
            # clip gradients to stabilize SGD under the large learning rate
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
            if idx % log_interval == 0 and idx > 0:
                elapsed = time.time() - start_time
                print('| epoch {:3d} | {:5d}/{:5d} batches '
                      '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                                  total_acc / total_count))
                total_acc, total_count = 0, 0
                start_time = time.time()
    
    
    def evaluate(dataloader):
        model.eval()
        total_acc, total_count = 0, 0
    
        with torch.no_grad():
            for idx, (label, text, offsets) in enumerate(dataloader):
                predicted_label = model(text, offsets)
                loss = criterion(predicted_label, label)
                total_acc += (predicted_label.argmax(1) == label).sum().item()
                total_count += label.size(0)
        return total_acc / total_count
    
    
    # Build the dataloader, converting the dataset to tensors
    dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)
    # A model with an embedding dimension of 64. The vocabulary size equals the length of the
    # vocab instance; the number of classes equals the number of distinct labels.
    num_class = len(set([label for (label, text) in train_iter]))
    vocab_size = len(vocab)
    emsize = 64
    model = TextClassificationModule(vocab_size, emsize, num_class).to(device)
    
    # Number of training epochs
    EPOCHS = 10
    # Learning rate
    LR = 5
    # Batch size
    BATCH_SIZE = 64
    # Cross-entropy loss
    criterion = torch.nn.CrossEntropyLoss()
    # Optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    # Learning-rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
    total_accu = None
    train_dataset = to_map_style_dataset(train_iter)
    test_dataset = to_map_style_dataset(test_iter)
    
    # Hold out 5% of the training set as a validation set
    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])
    
    train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    
    for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train(train_dataloader)
        accu_val = evaluate(valid_dataloader)
        if total_accu is not None and total_accu > accu_val:
            scheduler.step()
        else:
            total_accu = accu_val
        print('-' * 59)
        print('| end of epoch {:3d} | time: {:5.2f}s | '
              'valid accuracy {:8.3f} '.format(epoch,
                                               time.time() - epoch_start_time,
                                               accu_val))
        print('-' * 59)
    
    '''Evaluate the model on the test dataset'''
    print('Checking the results of test dataset.')
    accu_test = evaluate(test_dataloader)
    print('test accuracy {:8.3f}'.format(accu_test))
    
    # Test on a random news item
    # Use the best model so far and test it on a golf news story.
    ag_news_label = {1: "World",
                     2: "Sports",
                     3: "Business",
                     4: "Sci/Tec"}
    
    
    def predict(text, text_pipeline):
        with torch.no_grad():
            text = torch.tensor(text_pipeline(text))
            output = model(text, torch.tensor([0]))
            return output.argmax(1).item() + 1
    
    
    ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
        enduring the season’s worst weather conditions on Sunday at The \
        Open on his way to a closing 75 at Royal Portrush, which \
        considering the wind and the rain was a respectable showing. \
        Thursday’s first round at the WGC-FedEx St. Jude Invitational \
        was another story. With temperatures in the mid-80s and hardly any \
        wind, the Spaniard was 13 strokes better in a flawless round. \
        Thanks to his best putting performance on the PGA Tour, Rahm \
        finished with an 8-under 62 for a three-stroke lead, which \
        was even more impressive considering he’d never played the \
        front nine at TPC Southwind."
    
    model = model.to("cpu")
    
    print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])
    
    

    Output

    | epoch   1 |   500/ 1782 batches | accuracy    0.689
    | epoch   1 |  1000/ 1782 batches | accuracy    0.856
    | epoch   1 |  1500/ 1782 batches | accuracy    0.873
    -----------------------------------------------------------
    | end of epoch   1 | time: 23.38s | valid accuracy    0.879 
    -----------------------------------------------------------
    | epoch   2 |   500/ 1782 batches | accuracy    0.896
    | epoch   2 |  1000/ 1782 batches | accuracy    0.904
    | epoch   2 |  1500/ 1782 batches | accuracy    0.900
    -----------------------------------------------------------
    | end of epoch   2 | time: 32.21s | valid accuracy    0.891 
    -----------------------------------------------------------
    | epoch   3 |   500/ 1782 batches | accuracy    0.915
    | epoch   3 |  1000/ 1782 batches | accuracy    0.916
    | epoch   3 |  1500/ 1782 batches | accuracy    0.915
    -----------------------------------------------------------
    | end of epoch   3 | time: 36.85s | valid accuracy    0.899 
    -----------------------------------------------------------
    | epoch   4 |   500/ 1782 batches | accuracy    0.925
    | epoch   4 |  1000/ 1782 batches | accuracy    0.925
    | epoch   4 |  1500/ 1782 batches | accuracy    0.922
    -----------------------------------------------------------
    | end of epoch   4 | time: 20.15s | valid accuracy    0.897 
    -----------------------------------------------------------
    | epoch   5 |   500/ 1782 batches | accuracy    0.937
    | epoch   5 |  1000/ 1782 batches | accuracy    0.938
    | epoch   5 |  1500/ 1782 batches | accuracy    0.936
    -----------------------------------------------------------
    | end of epoch   5 | time: 28.52s | valid accuracy    0.905 
    -----------------------------------------------------------
    | epoch   6 |   500/ 1782 batches | accuracy    0.939
    | epoch   6 |  1000/ 1782 batches | accuracy    0.938
    | epoch   6 |  1500/ 1782 batches | accuracy    0.941
    -----------------------------------------------------------
    | end of epoch   6 | time: 33.47s | valid accuracy    0.905 
    -----------------------------------------------------------
    | epoch   7 |   500/ 1782 batches | accuracy    0.940
    | epoch   7 |  1000/ 1782 batches | accuracy    0.941
    | epoch   7 |  1500/ 1782 batches | accuracy    0.939
    -----------------------------------------------------------
    | end of epoch   7 | time: 20.75s | valid accuracy    0.904 
    -----------------------------------------------------------
    | epoch   8 |   500/ 1782 batches | accuracy    0.941
    | epoch   8 |  1000/ 1782 batches | accuracy    0.941
    | epoch   8 |  1500/ 1782 batches | accuracy    0.940
    -----------------------------------------------------------
    | end of epoch   8 | time: 27.11s | valid accuracy    0.906 
    -----------------------------------------------------------
    | epoch   9 |   500/ 1782 batches | accuracy    0.942
    | epoch   9 |  1000/ 1782 batches | accuracy    0.942
    | epoch   9 |  1500/ 1782 batches | accuracy    0.942
    -----------------------------------------------------------
    | end of epoch   9 | time: 34.83s | valid accuracy    0.906 
    -----------------------------------------------------------
    | epoch  10 |   500/ 1782 batches | accuracy    0.942
    | epoch  10 |  1000/ 1782 batches | accuracy    0.942
    | epoch  10 |  1500/ 1782 batches | accuracy    0.940
    -----------------------------------------------------------
    | end of epoch  10 | time: 22.78s | valid accuracy    0.906 
    -----------------------------------------------------------
    Checking the results of test dataset.
    test accuracy    0.906
    This is a Sports news
    
    Process finished with exit code 0
    

    Reference

    https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html
