# Dataset: MNIST handwritten digit recognition
# Model: LeNet
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# Run on the GPU only when it is both requested and actually available;
# otherwise fall back to the CPU.
use_cuda = True
if use_cuda and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# LeNet-style model
class Net(nn.Module):
    """Small LeNet-style CNN that maps single-channel images to 10 class scores.

    Two 5x5 convolutions (each followed by 2x2 max-pooling and ReLU) feed two
    fully connected layers. The flatten step hard-codes 320 features per
    sample, which is what a 1x28x28 input yields (20 channels x 4 x 4).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1)."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten the feature maps to 320 values per sample for the linear layers.
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Dropout against overfitting. The functional form does not track the
        # module's mode by itself, so self.training must be passed explicitly.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
# Load the MNIST test split (downloaded into 'data' if missing), converting
# each image to a tensor. Batches hold a single sample and are shuffled.
_mnist_transform = transforms.Compose([
    transforms.ToTensor(),
])
_mnist_test_set = datasets.MNIST('data', train=False, download=True, transform=_mnist_transform)
test_loader = torch.utils.data.DataLoader(_mnist_test_set, batch_size=1, shuffle=True)

# Build the network on the selected device, then restore the saved weights.
pretrained_model = "lenet_mnist_model.pth"
model = Net().to(device)
_pretrained_state = torch.load(pretrained_model, map_location='cpu')
model.load_state_dict(_pretrained_state)

# Switch to evaluation mode before running the attack.
model.eval()
def fgsm_attack(image, epsilon, data_grad, clip_min=0, clip_max=1):
    """Create an adversarial image with the Fast Gradient Sign Method.

    Every element of ``image`` is moved by ``epsilon`` in the direction of the
    sign of the corresponding gradient element, then the result is clamped to
    ``[clip_min, clip_max]``.

    Args:
        image: The clean input tensor.
        epsilon: Perturbation magnitude; ``0`` adds no perturbation.
        data_grad: Gradient of the loss with respect to ``image``.
        clip_min: Lower bound of the valid value range (default ``0``).
        clip_max: Upper bound of the valid value range (default ``1``).

    Returns:
        The perturbed tensor, clamped to ``[clip_min, clip_max]``.
    """
    # Only the direction (sign) of each gradient element is used.
    sign_data_grad = data_grad.sign()
    # FGSM update: x_adv = x + epsilon * sign(grad_x loss)
    perturbed_image = image + epsilon * sign_data_grad
    # Clamp so the perturbed image stays inside the valid value range.
    perturbed_image = torch.clamp(perturbed_image, clip_min, clip_max)
    return perturbed_image
# Perturbation magnitudes to evaluate; the first entry, 0, adds no perturbation.
epsilons = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
def test(model, device, test_loader, epsilon):
    """Measure the accuracy of ``model`` on FGSM-perturbed inputs.

    For every batch the loss gradient with respect to the input is computed,
    the input is perturbed with ``fgsm_attack`` and the model is re-evaluated
    on the perturbed input. Predictions are read with ``.item()``, so each
    batch must contain exactly one sample.

    Args:
        model: The classifier under attack; it must return log-probabilities
            because the loss is ``F.nll_loss``.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(data, target)`` batches that supports
            ``len()``.
        epsilon: Perturbation magnitude passed to ``fgsm_attack``.

    Returns:
        A tuple ``(final_acc, adv_examples)``. ``final_acc`` is the fraction
        of batches still classified correctly after the attack (``0.0`` for
        an empty loader). ``adv_examples`` holds up to six
        ``(initial_prediction, final_prediction, image_array)`` tuples.
    """
    correct = 0
    adv_examples = []
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        # The attack needs the gradient of the loss w.r.t. the input itself.
        data.requires_grad = True
        output = model(data)
        # Index of the highest-scoring class before the attack.
        init_pred = output.max(1, keepdim=True)[1]
        loss = F.nll_loss(output, target)
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.detach()
        perturbed_data = fgsm_attack(data, epsilon, data_grad)
        # The second pass is prediction only, so skip building an autograd graph.
        with torch.no_grad():
            output = model(perturbed_data)
        final_pred = output.max(1, keepdim=True)[1]
        # Count the sample as correct when the attacked prediction matches the label.
        if final_pred.item() == target.item():
            correct += 1
        # NOTE(review): the original indentation was lost; example collection
        # is assumed to sit at loop level (first six samples regardless of
        # whether the attack succeeded) -- confirm against the intended version.
        if len(adv_examples) < 6:
            adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
            adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))
    total = len(test_loader)
    # Guard against an empty loader instead of dividing by zero.
    final_acc = correct / float(total) if total else 0.0
    print("Epsilon: {}\tTest Accuracy = {} / {} = {}".format(epsilon, correct, total, final_acc))
    return final_acc, adv_examples
accuracies = []
examples = []
# Run test for each epsilon, keeping the accuracy and the saved sample images.
for eps in epsilons:
    acc, ex = test(model, device, test_loader, eps)
    accuracies.append(acc)
    examples.append(ex)

# Accuracy as a function of the perturbation magnitude.
plt.plot(epsilons, accuracies)
plt.show()

# Grid of saved adversarial samples: one row per epsilon, one column per sample.
cnt = 0
plt.figure(figsize=(8, 10))
for eps, row in zip(epsilons, examples):
    for j, (orig, adv, ex) in enumerate(row):
        cnt += 1
        plt.subplot(len(epsilons), len(examples[0]), cnt)
        plt.xticks([], [])
        plt.yticks([], [])
        # Label each row once, on its first column.
        if j == 0:
            plt.ylabel("Eps: {}".format(eps), fontsize=14)
        # Title shows "prediction before attack -> prediction after attack".
        plt.title("{} -> {}".format(orig, adv))
        plt.imshow(ex, cmap="gray")
plt.tight_layout()
plt.show()

- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106

