本文主要介绍 torchtext.transforms 的一些常用API。
VocabTransform 主要用于将输入的词元映射成它们在词表中的索引。输入的词元可以是 List[str] 或 List[List[str]] 型。
请看下面的例子
from collections import OrderedDict
from torchtext.transforms import VocabTransform
ordered_dict = OrderedDict([('a', 1), ('b', 1), ('c', 1), ('d', 1)])
vocab_transform = VocabTransform(vocab(ordered_dict))
print(vocab_transform(['b', 'd']))
# [1, 3]
print(vocab_transform([['b', 'd'], ['a', 'c', 'b']]))
# [[1, 3], [0, 2, 1]]
Truncate 用于将输入进行截断,输入的类型可以是 List[str] 或 List[List[str]]。
from torchtext.transforms import Truncate
sentences = [
['a', 'b', 'c', 'd', 'e'],
['a', 'b', 'c'],
['a', 'b', 'c', 'd'],
]
print(Truncate(4)(sentences)) # 每个句子长度不超过4
# [['a', 'b', 'c', 'd'], ['a', 'b', 'c'], ['a', 'b', 'c', 'd']]
print(Truncate(3)(sentences)) # 每个句子长度不超过3
# [['a', 'b', 'c'], ['a', 'b', 'c'], ['a', 'b', 'c']]
AddToken 用于在输入序列的起始/末尾处添加词元,例如
from torchtext.transforms import AddToken
sentences = [
['a', 'b', 'c', 'd', 'e'],
['a', 'b', 'c'],
['a', 'b', 'c', 'd'],
]
print(AddToken('' , begin=True)(sentences))
# [['', 'a', 'b', 'c', 'd', 'e'], ['', 'a', 'b', 'c'], ['', 'a', 'b', 'c', 'd']]
print(AddToken('' , begin=False)(sentences))
# [['a', 'b', 'c', 'd', 'e', ''], ['a', 'b', 'c', ''], ['a', 'b', 'c', 'd', '']]
类似于 torch.nn.Sequential,这里不作过多介绍。
ToTensor 通常将输入的一系列句子填充到最长句子的长度。
from torchtext.transforms import ToTensor
sentences = [
[1, 2, 3, 4, 5],
[1, 2, 3],
[1, 2, 3, 4],
]
t = ToTensor(padding_value=0)
print(t(sentences))
# tensor([[1, 2, 3, 4, 5],
# [1, 2, 3, 0, 0],
# [1, 2, 3, 4, 0]])
需要注意的是,ToTensor 的填充长度无法人为自定义,它是由输入句子的最大长度来决定的。如果需要自定义填充长度,可以使用下面的 PadTransform。
我们仍使用上面的例子,假设每个句子都要填充到长度8,则
import torchtext.transforms as T
sentences = [
[1, 2, 3, 4, 5],
[1, 2, 3],
[1, 2, 3, 4],
]
text_transform = T.Sequential(
T.ToTensor(padding_value=0),
T.PadTransform(max_length=8, pad_value=0)
)
print(text_transform(sentences))
# tensor([[1, 2, 3, 4, 5, 0, 0, 0],
# [1, 2, 3, 0, 0, 0, 0, 0],
# [1, 2, 3, 4, 0, 0, 0, 0]])