【文档智能】轻量级级表格识别算法模型-SLANet

大模型关系型数据库图像处理

前言

前面文档介绍了文档智能上多种思路及核心技术实现《【文档智能 & RAG】RAG增强之路:增强PDF解析并结构化技术路线方案及思路》,

picture.image

表格识别 作为文档智能的重要组成部分,面临着复杂结构和多样化格式的挑战。本文介绍的轻量级的表格识别算法模型——SLANet ,旨在在保证准确率的同时提升推理速度,方便生产落地。SLANet综合了PP-LCNet作为基础网络,采用CSP-PAN进行特征融合,并引入Attention机制以实现结构与位置信息的精确解码。通过这一框架,SLANet不仅有效减少了计算资源的消耗,还增强了模型在实际应用场景中的适用性与灵活性。

PP-LCNet

PP-LCNet是一种一种轻量级的CPU卷积神经网络,在图像分类的任务上表现良好,具有很高的落地意义。PP-LCNet的准确度显著优于具有相同推理时间的先前网络结构。

picture.image

模型细节

picture.image 网络架构

  • DepthSepConv块 : 使用MobileNetV1中的DepthSepConv作为基本块,该块没有快捷操作,减少了额外的拼接或逐元素相加操作,从而提高了推理速度。
  • 更好的激活函数 : 将BaseNet中的ReLU激活函数替换为H-Swish ,提升了网络性能,同时推理时间几乎没有变化。
  • SE模块的适当位置 : 在网络的尾部添加SE模块,以提高特征权重,从而实现更好的准确性和速度平衡。SE 模块是 SENet 提出的一种通道注意力机制,可以有效提升模型的精度。但是在 Intel CPU 端,该模块同样会带来较大的延时,如何平衡精度和速度是我们要解决的一个问题。虽然在 MobileNetV3 等基于 NAS 搜索的网络中对 SE 模块的位置进行了搜索,但是并没有得出一般的结论,我们通过实验发现,SE 模块越靠近网络的尾部对模型精度的提升越大。

picture.image PP-LCNet 中的 SE 模块的位置选用了表格中第三行的方案。

  • 更大的卷积核 : 在网络的尾部使用5x5卷积核替代3x3卷积核 ,以在低延迟和高准确性之间取得平衡。

picture.image实验表明,更大的卷积核放在网络的中后部即可达到放在所有位置的精度,与此同时,获得更快的推理速度。PP-LCNet 最终选用了表格中第三行的方案。

  • 更大的1x1卷积层 : 在全局平均池化(GAP)层后添加一个1280维的1x1卷积层,以增强模型的拟合能力,同时推理时间增加不多。在 GoogLeNet 之后,GAP(Global-Average-Pooling)后往往直接接分类层,但是在轻量级网络中,这样会导致 GAP 后提取的特征没有得到进一步的融合和加工。如果在此后使用一个更大的 1x1 卷积层(等同于 FC 层),GAP 后的特征便不会直接经过分类层,而是先进行了融合,并将融合的特征进行分类。这样可以在不影响模型推理速度的同时大大提升准确率。

picture.image

PP-LCNet系列效果

picture.image 图像分类

picture.image 与其他轻量级网络的性能对比

picture.image 目标检测

CSP-PAN

picture.image PP-PicoDet

PAN结构图 :相比于原始的FPN多了自下而上的特征金字塔。

picture.image PAN

CSPNet是一种处理的思想,可以和ResNet、ResNeXt和DenseNet结合。用 CSP 网络进行相邻 feature maps 之间的特征连接和融合。picture.image

CSP-PAN的引入主要有下面三个目的:

  1. 增强CNN的学习能力
  2. 减少计算量
  3. 降低内存占用

SLANet

picture.image SLANet结构

原理:

从上图看,SLANet主要由PP-LCNet + CSP-PAN + Attention组合得到。

  • PP-LCNet:CPU 友好型轻量级骨干网络
  • CSP-PAN:轻量级高低层特征融合模块
  • SLAHead:结构与位置信息对齐的特征解码模块,模型预测两个值,一是structure_pobs,表格结构的html代码 ,二是loc_preds,回归单元格四个点坐标

核心代码实现


        
          
import torch  
from torch import nn  
from torch.nn import functional as F  
  
  
class SLAHead(nn.Module):  
    def \_\_init\_\_(self, in\_channels=96, is\_train=False) -> None:  
        super().__init__()  
        self.max_text_length = 500  
        self.hidden_size = 256  
        self.loc_reg_num = 4  
        self.out_channels = 30  
        self.num_embeddings = self.out_channels  
        self.is_train = is_train  
  
        self.structure_attention_cell = AttentionGRUCell(in_channels,  
                                                         self.hidden_size,  
                                                         self.num_embeddings)  
  
        self.structure_generator = nn.Sequential(  
            nn.Linear(self.hidden_size, self.hidden_size),  
            nn.Linear(self.hidden_size, self.out_channels)  
        )  
  
        self.loc_generator = nn.Sequential(  
            nn.Linear(self.hidden_size, self.hidden_size),  
            nn.Linear(self.hidden_size, self.loc_reg_num)  
        )  
  
    def forward(self, fea):  
        batch_size = fea.shape[0]  
  
        # 1 x 96 x 16 x 16 → 1 x 96 x 256  
        fea = torch.reshape(fea, [fea.shape[0], fea.shape[1], -1])  
  
        # 1 x 256 x 96  
        fea = fea.permute(0, 2, 1)  
  
        # infer 1 x 501 x 30  
        structure_preds = torch.zeros(batch_size, self.max_text_length + 1,  
                                      self.num_embeddings)  
        # 1 x 501 x 4  
        loc_preds = torch.zeros(batch_size, self.max_text_length + 1,  
                                self.loc_reg_num)  
  
        hidden = torch.zeros(batch_size, self.hidden_size)  
        pre_chars = torch.zeros(batch_size, dtype=torch.int64)  
  
        loc_step, structure_step = None, None  
        for i in range(self.max_text_length + 1):  
            hidden, structure_step, loc_step = self._decode(pre_chars,  
                                                            fea, hidden)  
            pre_chars = structure_step.argmax(dim=1)  
            structure_preds[:, i, :] = structure_step  
            loc_preds[:, i, :] = loc_step  
  
        if not self.is_train:  
            structure_preds = F.softmax(structure_preds, dim=-1)  
        # structure\_preds: 1 x 501 x 30  
        # loc\_preds: 1 x 501 x 4  
        return structure_preds, loc_preds  
  
    def \_decode(self, pre\_chars, features, hidden):  
        emb_features = F.one_hot(pre_chars, num_classes=self.num_embeddings)  
        (output, hidden), alpha = self.structure_attention_cell(hidden,  
                                                                features,  
                                                                emb_features)  
        structure_step = self.structure_generator(output)  
        loc_step = self.loc_generator(output)  
        return hidden, structure_step, loc_step  
  
  
class AttentionGRUCell(nn.Module):  
    def \_\_init\_\_(self, input\_size, hidden\_size, num\_embedding) -> None:  
        super().__init__()  
  
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)  
        self.h2h = nn.Linear(hidden_size, hidden_size)  
        self.score = nn.Linear(hidden_size, 1, bias=False)  
  
        self.gru = nn.GRU(input_size=input_size + num_embedding,  
                          hidden_size=hidden_size,)  
        self.hidden_size = hidden_size  
  
    def forward(self, prev\_hidden, batch\_H, char\_onehots):  
        # 这里实现参考论文https://arxiv.org/pdf/1704.03549.pdf  
        batch_H_proj = self.i2h(batch_H)  
        prev_hidden_proj = torch.unsqueeze(self.h2h(prev_hidden), dim=1)  
  
        res = torch.add(batch_H_proj, prev_hidden_proj)  
        res = F.tanh(res)  
        e = self.score(res)  
  
        alpha = F.softmax(e, dim=1)  
        alpha = alpha.permute(0, 2, 1)  
        context = torch.squeeze(torch.matmul(alpha, batch_H), dim=1)  
        concat_context = torch.concat([context, char_onehots], 1)  
  
        cur_hidden = self.gru(concat_context, prev_hidden)  
        return cur_hidden, alpha  
  
  
class SLALoss(nn.Module):  
    def \_\_init\_\_(self) -> None:  
        super().__init__()  
        self.loss_func = nn.CrossEntropyLoss()  
        self.structure_weight = 1.0  
        self.loc_weight = 2.0  
        self.eps = 1e-12  
  
    def forward(self, pred):  
        structure_probs = pred[0]  
        structure_probs = structure_probs.permute(0, 2, 1)  
        # 1 x 30 x 501  
  
        # 1 x 501  
        structure_target = torch.empty(1, 501, dtype=torch.long).random_(30)  
        structure_loss = self.loss_func(structure_probs, structure_target)  
        structure_loss = structure_loss * self.structure_weight  
  
        loc_preds = pred[1]  # 1 x 501 x 4  
        loc_targets = torch.randn(1, 501, 4)  
        loc_target_mask = torch.randn(1, 501, 1)  
  
        loc_loss = F.smooth_l1_loss(loc_preds * loc_target_mask,  
                                    loc_targets * loc_target_mask,  
                                    reduction='mean')  
        loc_loss *= self.loc_weight  
        loc_loss = loc_loss / (loc_target_mask.sum() + self.eps)  
  
        total_loss = structure_loss + loc_loss  
        return total_loss  
  

      

参考文献

往期相关

【文档智能 & RAG】RAG增强之路:增强PDF解析并结构化技术路线方案及思路

【文档智能 & LLM】LayoutLLM:一种多模态文档布局模型和大模型结合的框架

【文档智能】再谈基于Transformer架构的文档智能理解方法论和相关数据集

【文档智能】多模态预训练模型及相关数据集汇总

【文档智能】:GeoLayoutLM:一种用于视觉信息提取(VIE)的多模态预训练模型

文档智能:ERNIE-Layout

【文档智能】符合人类阅读顺序的文档模型-LayoutReader及非官方权重开源

【文档智能】实践:基于Yolo三行代码极简的训练一个版式分析模型

【文档智能 & RAG】RAG增强之路-智能文档解析关键技术难点及PDF解析工具PDFlux

【文档智能】DLAFormer:端到端的解决版式分析、阅读顺序方法

【文档智能】LACE:帮你自动生成文档布局的方法浅尝

【文档智能 & RAG】浅看开源的同质化的文档解析框架-Docling

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论