当一个前端学了很久的神经网络...

前端

前言

picture.image

最近在学习神经网络相关的知识,并做了一个简单的猫狗识别的神经网络,结果如图。

虽然有点绷不住,但这其实是少数情况,整体的猫狗分类正确率已经来到 90% 了。

本篇文章是给大家介绍一下我是如何利用前端如何做神经网络-猫狗训练的。

步骤概览

还是掏出之前那个步骤流程,我们只需要按照这个步骤就可以训练出自己的神经网络

  1. 处理数据集

  2. 定义模型

    1. 神经网络层数
    2. 每层节点数
    3. 每层的激活函数
  3. 编译模型

  4. 训练模型

  5. 使用模型

最终的页面是这样的

picture.image

处理数据集

  1. 首先得找到数据集,本次使用的是这个 www.kaggle.com/datasets/li… 2000 个猫图,2000 个狗图,足够我们使用(其实我只用了其中 500 个,电脑跑太慢了)
  2. 由于这些图片大小不一致,首先我们需要将其处理为大小一致。这一步可以使用 canvas 来做,我统一处理成了 128 * 128 像素大小。

const preprocessImage = (img: HTMLImageElement): HTMLCanvasElement => {
  const canvas = document.createElement("canvas");
  canvas.width = 128;
  canvas.height = 128;

  const ctx = canvas.getContext("2d");
  if (!ctx) return canvas;

  // 保持比例缩放并居中裁剪
  const ratio = Math.min(128 / img.width, 128 / img.height);
  const newWidth = img.width * ratio;
  const newHeight = img.height * ratio;

  ctx.drawImage(
    img,
    (128 - newWidth) / 2,
    (128 - newHeight) / 2,
    newWidth,
    newHeight
  );
  return canvas;
};

这里可能就有同学要问了:imooimoo,你怎么返回了 canvas,不应该返回它 getImageData 的数据点吗。我一开始也是这样想的,结果 ai 告诉我,tfjs 是可以直接读取 canvas 的,牛。

tf.browser.fromPixels() // 可以接受 canvas 作为参数

  1. 将其处理为 tfjs 可用的对象

  // 加载单个图片并处理为 tfjs 对应格式
  const loadImage = async (category: "cat" | "dog", index: number): Promise<ImageData> => {
    const imgPath = `src/pages/cat-dog/image/${category}/${category}.${index}.jpg`;
    const img = new Image();
    img.src = imgPath;

    await new Promise((resolve, reject) => {
      img.onload = () => resolve(img);
      img.onerror = reject;
    });

    return {
      path: imgPath,
      element: img,
      tensor: tf.browser.fromPixels(preprocessImage(img)).div(255), // 归一化
      label: category === "cat" ? 0 : 1,
    };
  };
  
  // 加载全部图片
  const loadDataset = async () => {
    const images: ImageData[] = [];
    
    for (const category of ["cat", "dog"]) {
      for (let i = 1000; i < 1500; i++) { // 这里只使用了后 500 张,电脑跑不动
        try {
          const imgData = await loadImage(category, i);
          images.push(imgData);
        } catch (error) {
          console.error(`加载${category === "cat" ? "猫" : "狗"}图片失败: ${category}.${i}.jpg`, error);
        }
      }
    }
    return images;
  };

<需要看新机会的>
顺便吆喝一句,技术大厂,待遇给的还可以,就是偶尔有加班(放心,加班有加班费)
前、后端/测试,多地有位置,感兴趣的可以来共事~

定义模型 & 编译模型

由于我们的主题是图片识别,图片识别一般会需要用到几个常用的层

  1. 最大池化层:用于缩小图片,节约算力。但也不能太小,否则很糊会提取不出东西。
  2. 卷积层:用于提取图片特征
  3. 展平层:将多维的结果转为一维

有同学可能想问为什么会有多维。首先是三维的颜色,输入就有三维;卷积层的每一个卷积核,都会使结果增加维度,所以后续的维度会很高。这张图比较形象,最后就只会剩下一维,方便机器进行计算。

picture.image


  // 创建卷积神经网络模型
  const createCNNModel = () => {
    const model = tf.sequential({
      layers: [
        // 最大池化层:降低特征图尺寸,增强特征鲁棒性
        tf.layers.maxPooling2d({
          inputShape: [128, 128, 3], // 输入形状 [高度, 宽度, 通道数]
          poolSize: 2, // 池化窗口尺寸 2x2
          strides: 2, // 滑动步长:每次移动 n 像素,使输出尺寸减小到原先的 1/n
        }),

        // 卷积层:用于提取图像局部特征
        tf.layers.conv2d({
          filters: 32, // 卷积核数量,决定输出特征图的深度
          kernelSize: 3, // 卷积核尺寸 3x3
          activation: "relu", // 激活函数:修正线性单元,解决梯度消失问题
          padding: "same", // 边缘填充方式:保持输出尺寸与输入相同
        }),
        
        // 展平层:将多维特征图转换为一维向量
        tf.layers.flatten(),

        // 全连接层(输出层):进行最终分类
        tf.layers.dense({
          units: 2, // 输出单元数:对应猫/狗两个类别
          activation: "softmax", // 激活函数:将输出转换为概率分布
        }),
      ],
    });

    // 编译模型,参数基本写死这几个就对了
    model.compile({
      optimizer: "adam",
      loss: "categoricalCrossentropy",
      metrics: ["accuracy"],
    });

    console.log("模型架构:");
    model.summary();

    return model;
  };

这里实际上只需要额外注意两点:

  1. 卷积层的激活函数 activation: "relu",这里理论上是个非线性激活函数就行。但是我个人更喜欢 relu,函数好记,速度和效果又不错。
  2. 输出层的激活函数 activation: "softmax",由于我们做的是分类,最后必须是这个。

训练模型

训练模型可以说的就不多了,也就是提供一下你的模型、训练集就可以开始了。这里有俩参数可以注意下

  • epochs: 训练轮次
  • validationSplit: 验证集比例,用于测算训练好的模型准确程度并优化下一轮的模型

  // 训练模型
  const trainModel = async (
    model: tf.Sequential,
    xData: tf.Tensor4D,
    yData: tf.Tensor2D
  ) => {
    setTrainingLogs([]); // 清空之前的训练日志

    await model.fit(xData, yData, {
      epochs: 10, // 训练轮数
      batchSize: 4,
      validationSplit: 0.4,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          if (!logs) return;
          setTrainingLogs((prev) => [
            ...prev,
            {
              epoch: epoch + 1,
              loss: Number(logs.loss.toFixed(4)),
              accuracy: Number(logs.acc.toFixed(4)),
            },
          ]);
        },
      },
    });
  };

整体页面

基本就是这样了,稍微写一下页面,基本就完工了

picture.image

picture.image

总结

别慌,神经网络没那么可怕,核心步骤就那几步,冲冲冲。

源码:github.com/imoo666/neu…

——转载自作者:imoo

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动 EB 级湖仓一体分析服务 LAS 的实践与展望
火山引擎湖仓一体分析服务 LAS 是面向湖仓一体架构的 Serverless 数据处理分析服务,提供一站式的海量数据存储计算和交互分析能力,完全兼容 Spark、Presto、Flink 生态,在字节跳动内部有着广泛的应用。本次演讲将介绍 LAS 在字节跳动内部的发展历程和大规模应用实践,同时介绍 LAS 在火山引擎上的发展规划。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论