• 【Datawhale】扩散模型学习笔记 第一次打卡


    扩散模型学习笔记

    1. 扩散模型库Diffusers

    1.1 安装

    由于diffusers库更新较快,所以建议时常upgrade

    # pip
    pip install --upgrade diffusers[torch]
    # conda
    conda install -c conda-forge diffusers
    
    • 1
    • 2
    • 3
    • 4

    1.2 使用

    from diffusers import DiffusionPipeline
    
    generator = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
    generator.to("cuda")
    image = generator("An image of a squirrel in Picasso style").images[0]
    image.save("image_of_squirrel_painting.png")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    2. 从零开始搭建扩散模型

    2.1 数据准备

    在这个示例中,我们将使用经典的MNIST数据集作为示范。MNIST数据集包含28x28像素的手写数字图像,每个像素值的范围从0到1。

    2.2 损坏过程

    我们希望能够控制输入数据的损坏程度,因此引入了一个参数 amount,该参数控制了噪声的程度。你可以使用以下方法来添加噪声:

    noise = torch.rand_like(x)
    noisy_x = (1 - amount) * x + amount * noise
    
    • 1
    • 2

    如果 amount 为0,则输入数据保持不变。如果 amount 为1,输入数据将变为纯粹的噪声。通过混合输入数据和噪声,我们可以确保输出数据的范围仍在0到1之间。

    2.3 模型构建

    我们将使用UNet模型来处理噪声图像。UNet是一种用于图像分割的常见架构,由压缩路径和扩展路径组成。在这个示范中,我们将构建一个简化版本的UNet,它接收单通道图像,并通过卷积层在下行路径(down_layers)和上行路径(up_layers)之间具有残差连接。我们将使用最大池化进行下采样和 nn.Upsample 进行上采样。

    2.4 模型训练

    在模型训练过程中,模型的任务是将损坏的输入 noisy_x 转换为对原始图像 x 的最佳估计。我们使用均方误差(MSE)来比较模型的预测与真实值,然后使用反向传播算法来更新模型的参数。

    2.5 采样

    如果模型在高噪声水平下的预测不够理想,可以进行采样以生成更好的图像。你可以从完全随机的噪声图像开始,然后逐渐接近模型的预测。这意味着你可以检查模型的预测结果,然后只向预测的方向移动一小步,比如向预测值移动20%。这将生成一个具有较少噪声的图像,其中可能包含一些关于输入数据的结构提示。将这个新图像输入模型,希望得到比第一个预测更好的结果。这个过程可以迭代多次,以逐渐减小噪声并生成更好的图像。

    这是一个简化的扩散模型搭建和训练的概述。你可以根据具体的问题和数据进行修改和优化,以获得更好的结果。希望这些步骤能帮助你理解如何搭建扩散模型并训练它。

    from diffusers import DDPMScheduler, UNet2DModel
    from PIL import Image
    import torch
    import numpy as np
    
    scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
    model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda")
    scheduler.set_timesteps(50)
    
    sample_size = model.config.sample_size
    noise = torch.randn((1, 3, sample_size, sample_size)).to("cuda")
    input = noise
    
    for t in scheduler.timesteps:
        with torch.no_grad():
            noisy_residual = model(input, t).sample
            prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
            input = prev_noisy_sample
    
    image = (input / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    image
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    3. webui

    参考我的另一篇博客:https://blog.csdn.net/qq_44824148/article/details/130389357

  • 相关阅读:
    【ARC 自动引用计数 Objective-C语言】
    力扣第 387 场周赛第四题 将元素分配到两个数组中 II 二分查找,离散化,线段树
    想查询自己的职称情况,竟然没办法
    Revit中墙体的连接方式创建,快速改变墙连接状态
    Maven简介
    java 企业工程管理系统软件源码 自主研发 工程行业适用
    MATLAB与Python:优势与挑战
    基于java语言+springboot技术架构开发的 互联网智能3D导诊系统源码支持微信小程序、APP 医院AI智能导诊系统源码
    hive表中导入数据 多种方法详细说明
    Python-Flask入门,静态文件、页面跳转、错误信息、动态网页模板
  • 原文地址:https://blog.csdn.net/qq_44824148/article/details/133954423