• 【机器学习】Tensorflow.js:我在浏览器中实现了迁移学习


    ⭐️ 本文首发自 前端修罗场(点击加入),是一个由资深开发者独立运行的专业技术社区,我专注 Web 技术、答疑解惑、面试辅导以及职业发展现在加入,私聊我即可获取一次免费的模拟面试机会,帮你评估知识点的掌握程度,获得更全面的学习指导意见,交个朋友,不走弯路,少吃亏!

    最近公司在研发分布式高性能的云计算平台,其中涉及到了 AI 方面的处理。所以我也在自学 Machine Learning。不过在 AI 方面的知识却是需要花功夫花时间学习的。在学习的过程中我发现了一个不错的学习教程(点击查看),推荐给大伙😋,我个人觉得这个教程讲解的通俗易懂,帮我省去了自己苦苦专研的时间,能够得到快速的进步。下一阶段,我也会在这里和大家分享我的学习笔记。


    迁移学习是将预训练模型与自定义训练数据相结合的能力。 这意味着你可以利用模型的功能并添加自己的样本,而无需从头开始创建所有内容。

    例如,一种算法已经用数千张图像进行了训练以创建图像分类模型,而不是创建自己的图像分类模型,迁移学习允许你将新的自定义图像样本与预先训练的模型相结合以创建新的图像分类器。 这个特性使得拥有一个更加定制化的分类器变得非常快速和容易。

    为了提供代码中的示例,让我们重新利用之前的示例并对其进行修改,以便我们可以对新图像进行分类。
    请添加图片描述
    以下是此设置最重要部分的一些代码示例,但如果你需要查看整个代码,可以在本文的最后找到它。

    我们仍然需要从导入 Tensorflow.js 和 MobileNet 开始,但是这次我们还需要添加一个 KNN(k-nearest neighbor)分类器:

    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <!-- Load MobileNet -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
    <!-- Load KNN Classifier -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    我们需要分类器的原因是(不仅仅是使用 MobileNet 模块)我们正在添加以前从未见过的自定义样本,因此 KNN 分类器将允许我们将所有内容组合在一起并对组合的数据进行预测

    然后,我们可以用视频标签替换猫的图像,以使用来自摄像头的图像。

    <video autoplay id="webcam" width="227" height="227"></video>
    
    • 1

    最后,我们需要在页面上添加一些按钮,我们将用作标签来记录一些视频样本并开始预测。

    <section>
      <button class="button">Left</button>
    
      <button class="button">Right</button>
    
      <button class="test-predictions">Test</button>
    </section>
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    现在,让我们转到 JavaScript 文件,我们将从设置几个重要变量开始:

    //要分类的数量
    const NUM_CLASSES = 2;
    // 分类标签
    const classes = ["Left", "Right"];
    // Webcam Image size. Must be 227.
    const IMAGE_SIZE = 227;
    // KNN 的 K 值
    const TOPK = 10;
    
    const video = document.getElementById("webcam");
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    在这个特定的示例中,我们希望能够在我们的头部向左或向右倾斜之间对网络摄像头输入进行分类,因此我们需要两个标记为 leftright 的类。

    设置为 227 的图像大小是视频元素的大小(以像素为单位)。 根据 Tensorflow.js 示例,该值需要设置为 227 以匹配用于训练 MobileNet 模型的数据格式。 为了能够对我们的新数据进行分类,后者需要适应相同的格式。

    如果你真的需要它更大,这是可能的,但你必须在将数据提供给 KNN 分类器之前转换和调整数据大小。

    然后,我们将 K 的值设置为 10。KNN 算法中的 K 值很重要,因为它代表了我们在确定新输入的类别时考虑的实例数。

    在这种情况下,10 意味着,在预测一些新数据的标签时,我们将查看训练数据中的 10 个最近邻,以确定如何对新输入进行分类。

    最后,我们得到了视频元素。

    对于逻辑,让我们从加载模型和分类器开始:

    async load() {
        const knn = knnClassifier.create();
        const mobilenetModule = await mobilenet.load();
        console.log("model loaded");
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5

    然后,让我们访问视频源:

    navigator.mediaDevices
      .getUserMedia({ video: true, audio: false })
      .then(stream => {
        video.srcObject = stream;
        video.width = IMAGE_SIZE;
        video.height = IMAGE_SIZE;
      });
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    接下来,让我们设置一些按钮事件来记录我们的示例数据:

    setupButtonEvents() {
        for (let i = 0; i < NUM_CLASSES; i++) {
          let button = document.getElementsByClassName("button")[i];
    
          button.onmousedown = () => {
            this.training = i;
            this.recordSamples = true;
          };
          button.onmouseup = () => (this.training = -1);
        }
      }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    让我们编写我们的函数,它将获取网络摄像头图像样本,重新格式化它们并将它们与 MobileNet 模块结合起来:

    // 从视频元素中获取图像数据
    const image = tf.browser.fromPixels(video);
    
    let logits;
    // 'conv_preds' 是 MobileNet 的 logits 激活。
    const infer = () => this.mobilenetModule.infer(image, "conv_preds");
    
    // 如果按住其中一个按钮,则进行训练
    if (this.training != -1) {
      logits = infer();
    
      // 将当前图像添加到分类器
      this.knn.addExample(logits, this.training);
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    最后,一旦我们收集了一些网络摄像头图像,我们就可以使用以下代码测试我们的预测:

    logits = infer();
    const res = await this.knn.predictClass(logits, TOPK);
    const prediction = classes[res.classIndex];
    
    • 1
    • 2
    • 3

    最后,您可以处理我们不再需要的网络摄像头数据:

    // 完成后处理图像
    image.dispose();
    if (logits != null) {
      logits.dispose();
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5

    完整代码下载

    这里我们提供完整代码的下载,你可以通过 CSDN 专属的链接获取资源


    如果你觉得这篇文章还不错,请点击下方小红心 👍🏻 ❤️,鼓励一下!我会继续为你带来优质的内容~我是前端修罗场,一个独立运行的专业技术社区,感谢你关注我!

  • 相关阅读:
    LVS负载均衡集群
    小程序笔记2
    【Python百日进阶-Web开发-Peewee】Day278 - SQLite 扩展(三)
    Servlet(一):实现一个Servlet程序和使用Smart Tomcat部署Servlet程序
    【华为机试真题 JAVA】一种字符串压缩表示的解压-100
    【Linux操作系统教程】用户管理与权限管理你真的懂了吗(三)
    metrics.accuracy_score 和metrics.roc_auc_score的计算
    协程的创建
    Linux:无法接收组播数据
    【自用】VUE 获取登录用户名 显示在其他页面上
  • 原文地址:https://blog.csdn.net/ImagineCode/article/details/125631718