【项目实践】多粒度网络MGN-ReID之跨境追踪实践

1、跨境追踪的简介

跨镜追踪(Person Re-Identification,简称ReID)技术是现在计算机视觉研究的热门方向,主要解决跨摄像头跨场景下行人的识别与检索。该技术能够根据行人的穿着、体态、发型等信息认知行人,与人脸识别结合能够适用于更多新的应用场景,将人工智能的认知水平提高到一个新阶段。

ReID是行人智能认知的其中一个研究方向,行人智能认知是人脸识别之后比较重要的一个研究方向,特别是计算机视觉行业里面,我们首先简单介绍ReID里比较热门的几项内容:








1、行人检测。任务是在给定图片中检测出行人位置的矩形框,这个跟之前的人脸检测、汽车检测比较类似,是较为基础的技术,也是很多行人技术的一个前置技术。








2、行人分割以及背景替换。行人分割比行人检测更精准,预估每个行人在图片里的像素概率,把这个像素分割出来是人或是背景,这时用到很多P图的场景,比如背景替换。举一个例子,一些网红在做直播时,可以把直播的背景替换成外景,让体验得到提升。








3、骨架关键点检测及姿态识别。一般识别出人体的几个关键点,比如头部、肩部、手掌、脚掌,用到行人姿态识别的任务中,这些技术可以应用在互动娱乐的场景中,类似于Kinnect人机互动方面,关键点检测技术是非常有价值的。








4、行人跟踪“ MOT ”的技术。主要是研究人在单个摄像头里行进的轨迹,每个人后面拖了一根线,这根线表示这个人在摄像头里行进的轨迹,和 ReID技术结合在一起可以形成跨镜头的细粒度的轨迹跟踪。








5、动作识别。动作识别是基于视频的内容理解做的,技术更加复杂一点,但是它与人类的认知更加接近,应用场景会更多,这个技术目前并不成熟。动作识别可以有非常多的应用,比如闯红灯,还有公共场合突发事件的智能认知,像偷窃、聚众斗殴,摄像头识别出这样的行为之后可以采取智能措施,比如自动报警,这有非常大的社会价值。








6、行人属性结构化。把行人的属性提炼出来,比如他衣服的颜色、裤子的类型、背包的颜色。








7、跨境追踪及行人再识别ReID技术。

2、ReID数据采集的特点

1、必须跨摄像头采集,给数据采集的研发团队和公司提出了比较高的要求;


2、公开数据集的数据规模非常小;


3、影响因素复杂多样;


4、数据一般都是视频的连续截图;


5、同一个人最好有多张全身照片;


6、互联网提供的照片基本无法用在ReID;


7、监控大规模搜集涉及到数据,涉及到用户的隐私问题。

3、ReID面临的困难

1、遮挡问题;


2、无正脸问题;


3、配饰问题;


4、姿态问题;


5、相机拍摄角度差异大;


6、监控图片模糊不清;


7、室内室外环境变化;


8、行人更换服装配饰,如之前穿了一件小外套,过一会儿把外套脱掉了;


9、季节性穿衣风格,冬季、夏季穿衣风格差别非常大,但从行人认知来讲他很可能是同一个人;


10、白天晚上的光线差异等。

4、ReID 实现思路与常见方案

ReID 从完整的过程分三个步骤:

第一步:从摄像头的监控视频获得原始图片;


第二步:基于这些原始图片把行人的位置检测出来;


第三步:基于检测出来的行人图片,用ReID技术计算图片的距离,但是我们现在做研究是基于常用数据集,把前面图像的采集以及行人检测的两个工作做过了,我们ReID的课题主要研究第三个阶段。

picture.image

5、多粒度网络(MGN)

picture.image

5.1、多粒度网络(MGN)设计思路

picture.image

设计思想是这样子的,一开始是全局特征,把整张图片输入,我们提取它的特征,用这种特征比较Loss或比较图片距离。但这时我们发现有一些不显著的细节,还有出现频率比较低的特征会被忽略。比如衣服上有个LOGO,但不是所有衣服上有LOGO,只有部分人衣服上有LOGO。全局特征会做特征均匀化,LOGO的细节被忽略掉了。

5.2、网络结构原理

picture.image

MGN的做法,主要是结合全局特征和局部特征,获得更加具有描述力的特征,整个网络框架主要分为3个支路,论文框架如下图所示:

picture.image

   该框架主干网络是ResNet-50,框架共有3个分支,一个是全局分支,两个是局部分支,在ResNet-50的res\_conv4\_2产生出3个分支。








(1)对于全局分支,在res\_conv5\_1上下采样,步长为2的卷积层。经过 global max-pooling,以及一个用来降维的1*1卷积层,特征从2048维降到256维。2048维特征去做softmax loss,256维特征去做triplet loss。








(2)对于两个局部分支,操作方法相同,只是第一个局部分支分成两块,第二个局部分支分成3块,在res\_conv5\_1为了保留原来的局部特征信息,并没有进行下采样,在水平方向均等的分成几块,后面的处理和全局分支一样,2048维依然去做softmax loss,256维特征去做triplet loss。








(3)在测试阶段,特征是所有256维特征的级联。








(4)损失函数主要是分类的softmax损失和用于度量学习的三元组损失。

算法流程:

首先,输入图的尺寸是384×128,我们用的是Resnet50,如果在不做任何改变的情况下,它的特征图谱输出尺寸,从右下角表格可以看到,global 这个地方就相当于对Resnet 50不做任何的改变,特征图谱输出是12×4。

下面有一个part-2 跟 part-3,这是在Res4\_1的位置,本来是有一个stride 等于2的下采样的操作,我们把2改成 1,没有下采样,这个地方的尺寸就不会缩小2,所以part-2跟part-3global大一倍的尺寸,它的尺寸是24×8。为什么要这么操作?因为我们会强制分配 part-2跟part-3去学习细粒度特征,如果把特征尺寸做得大一点,相当于信息更多一点,更利于网络学到更细节的特征。








网络结构从左到右,先是两个人的图片输入,这边有3个模块。3个模块的意思是表示3个分支共享网络,前三层这三个分支是共享的,到第四层时分成三个支路,第一个支路是global的分支,第二个是 part-2 的分支,第三个是part-3的分支。在global的地方有两块,右边这个方块比左边的方块大概缩小了一倍,因为做了个下采样,下面两个分支没有做下采样,所以第四层和第五层特征图是一样大小的。








接下来我们对part-2跟part-3做一个从上到下的纵向分割,part-2在第五层特征图谱分成两块,part-3对特征图谱从上到下分成三块。在分割完成后,我们做一个 pooling,相当于求一个最值,我们用的是Max-pooling,得到一个2048的向量,这个是长条形的、横向的、黄色区域这个地方。








但是part-2跟part-3的操作跟global是不一样的,part-2有两个pooling,第一个是蓝色的,两个part合在一起做一个global-pooling,我们强制part-2去学习细节的联合信息,part-2有两个细的长条形,就是我们刚才引导它去学细节型的信息。淡蓝色这个地方变成小方体一样,是做降维,从2048维做成256维,这个主要方便特征计算,因为可以降维,更快更有效。我们在测试的时候会在淡蓝色的地方,小方块从上到下应该是8个,我们把这8256维的特征串连一个2048的特征,用这个特征替代前面输入的图片。

5.3、损失函数的设计

picture.image

Loss说简单也简单,说复杂也复杂也复杂,为什么?简单是因为整个模型里只用了两种Loss,是机器学习里最常见的,一个是SoftmaxLoss一个是TripletLoss。复杂是因为分支比较多,包括global的,包括刚才local的分支,而且在各个分支的Loss设计上不是完全均等的。我们当时做了些实验和思考去想Loss的设计。现在这个方案,第一,从实践上证明是比较好的,第二,从理解上也是容易理解的。

Loss = Triplet Loss + Softmax Loss:

picture.image

首先,看一下global分支。上面第一块的Loss设计。这个地方对2048维做了SoftmaxLoss,对256维做了一个TripletLoss,这是对global信息通用的方法。下面两个部分global的处理方式也是一样的,都是对2048做一个SoftmaxLoss,对256维做一个TripletLoss。中间part-2地方有一个全局信息,有global特征,做 SoftmaxLoss+TripletLoss。








但是,下面两个Local特征看不到TripletLoss,只用了SoftmaxLoss,这个在文章里也有讨论,我们当时做了实验,如果对细节当和分支做TripletLoss,效果会变差。为什么效果会变差?








一张图片分成从上到下两部分的时候,最完美的情况当然是上面部分是上半身,下面部分是下半身,但是在实际的图片中,有可能整个人都在上半部分,下半部分全是背景,这种情况用上、下部分来区分,假设下半部分都是背景,把这个背景放到TripletLoss三元损失里去算这个Loss,就会使得这个模型学到莫名其妙的特征。








比如背景图是个树,另外一张图是某个人的下半身,比如一个女生的下半身是一个裙子,你让裙子跟另外图的树去算距离,无论是同类还是不同类,算出来的距离是没有任何物理意义或实际意义的。从模型的角度来讲,它属于污点数据,这个污点数据会引导整个模型崩溃掉或者学到错误信息,使得预测的时候引起错误。所以以后有同学想复现我们方法的时候要注意一下, Part-2、part-3Local特征千万不要加TripletLoss。








对全局特征使用tripletloss,对局部特征使用softmax。此设置的灵感源于粗略到精细的机制,将非归约特征作为粗略信息来学习分类,将归约特征作为具有学习度量的精细信息。








将局部与整体在同一个分支中学习得到的效果并不显著。可能是这个分支共享同一个网络结构。








网络体系结构中的三个分支实际上学习了表示具有不同偏好的信息。具有较大接收场和全局最大池的全局分支从行人图像中捕获了完整但粗糙的特征,并且由第2部分和第3部分分支学到的特征却没有大幅度的卷积和条带的分割部分,但这些特征往往是局部的但很好。具有更多分区的分支将学习行人图像的更好表示。学习不同偏好的分支机构可以将底层的区分信息协作地补充到公共骨干网部分,这就是在任何单个分支机构中提高性能的原因。

6、ReID 的技术展望

第一个,ReID的数据比较难获取,如果用应用无监督学习去提高ReID效果,可以降低数据采集的依赖性,这也是一个研究方向。右边可以看到,GAN生成数据来帮助ReID数据增强,现在也是一个很大的分支,但这只是应用无监督学习的一个方向。








第二个,基于视频的ReID。因为刚才几个数据集是基于对视频切好的单个图片而已,但实际应用场景中还存在着视频的连续帧,连续帧可以获取更多信息,跟实际应用更贴近,很多研究者也在进行基于视频 ReID 的技术。








第三个,跨模态的ReID。刚才讲到白天和黑夜的问题,黑夜时可以用红外的摄像头拍出来的跟白色采样摄像头做匹配。








第四个,跨场景的迁移学习。就是在一个场景比如market1501上学到的ReID,怎样在Duke数据集上提高效果。








第五个,应用系统设计。相当于设计一套系统让ReID这个技术实际应用到行人检索等技术上去。

7、基于MGN-ReID方法项目实践

本项目基于以上说明的论文进行实践,数据集时Market1501数据集。实践的Baseline网络为ResNet50模型。

7.1、数据集处理和Dateloder输出

picture.image

1、M数据集的处理:


              
from data.common import list_pictures
              
from torch.utils.data import dataset
              
from torchvision.datasets.folder import default_loader
              

              

              
class Market1501(dataset.Dataset):
              
    def __init__(self, args, transform, dtype):
              
        self.transform = transform
              
        self.loader = default_loader
              
        data_path = args.datadir
              
        if dtype == 'train':
              
            data_path += '/bounding_box_train'
              
        elif dtype == 'test':
              
            data_path += '/bounding_box_test'
              
        else:
              
            data_path += '/query'
              
        self.imgs = [path for path in list_pictures(data_path) if self.id(path) != -1]
              
        self._id2label = {_id: idx for idx, _id in enumerate(self.unique_ids)}
              

              
    def __getitem__(self, index):
              
        path = self.imgs[index]
              
        target = self._id2label[self.id(path)]
              
        img = self.loader(path)
              
        if self.transform is not None:
              
            img = self.transform(img)
              
        return img, target
              

              
    def __len__(self):
              
        return len(self.imgs)
              

              
    @staticmethod
              
    def id(file_path):
              
        """
              
        :param file_path: unix style file path
              
        :return: person id
              
        """
              
        return int(file_path.split('/')[-1].split('_')[0])
              

              
    @staticmethod
              
    def camera(file_path):
              
        """
              
        :param file_path: unix style file path
              
        :return: camera id
              
        """
              
        return int(file_path.split('/')[-1].split('_')[1][1])
              

              
    @property
              
    def ids(self):
              
        """
              
        :return: person id list corresponding to dataset image paths
              
        """
              
        return [self.id(path) for path in self.imgs]
              

              
    @property
              
    def unique_ids(self):
              
        """
              
        :return: unique person ids in ascending order
              
        """
              
        return sorted(set(self.ids))
              

              
    @property
              
    def cameras(self):
              
        """
              
        :return: camera id list corresponding to dataset image paths
              
        """
              
        return [self.camera(path) for path in self.imgs]
          

2、DataLoder的制作


              
from importlib import import_module
              
from torchvision import transforms
              
from utils.random_erasing import RandomErasing
              
from data.sampler import RandomSampler
              
from torch.utils.data import dataloader
              

              
class Data:
              
    def __init__(self, args):
              

              
        train_list = [
              
            transforms.Resize((args.height, args.width), interpolation=3),
              
            transforms.RandomHorizontalFlip(),
              
            transforms.ToTensor(),
              
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
              
        ]
              
        if args.random_erasing:
              
            train_list.append(RandomErasing(probability=args.probability, mean=[0.0, 0.0, 0.0]))
              
        train_transform = transforms.Compose(train_list)
              
        test_transform = transforms.Compose([
              
            transforms.Resize((args.height, args.width), interpolation=3),
              
            transforms.ToTensor(),
              
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
              
        ])
              
        if not args.test_only:
              
            module_train = import_module('data.' + args.data_train.lower())
              
            self.trainset = getattr(module_train, args.data_train)(args, train_transform, 'train')
              
            self.train_loader = dataloader.DataLoader(self.trainset,
              
                            sampler=RandomSampler(self.trainset,args.batchid,batch_image=args.batchimage),
              
                            #shuffle=True,
              
                            batch_size=args.batchid * args.batchimage,
              
                            num_workers=args.nThread)
              
        else:
              
            self.train_loader = None
              
        if args.data_test in ['Market1501']:
              
            module = import_module('data.' + args.data_train.lower())
              
            self.testset = getattr(module, args.data_test)(args, test_transform, 'test')
              
            self.queryset = getattr(module, args.data_test)(args, test_transform, 'query')
              
        else:
              
            raise Exception()
              
        self.test_loader = dataloader.DataLoader(self.testset, batch_size=args.batchtest, num_workers=args.nThread)
              
        self.query_loader = dataloader.DataLoader(self.queryset, batch_size=args.batchtest, num_workers=args.nThread)
          

7.2、 数据增强操作

picture.image

1、随机擦除操作——Random Erasing


              
from __future__ import absolute_import
              
from torchvision.transforms import *
              
from PIL import Image
              
import random
              
import math
              
import numpy as np
              
import torch
              

              
class RandomErasing(object):
              
    def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]):
              
        self.probability = probability
              
        self.mean = mean
              
        self.sl = sl
              
        self.sh = sh
              
        self.r1 = r1
              
       
              
    def __call__(self, img):
              
        if random.uniform(0, 1) > self.probability:
              
            return img
              
        for attempt in range(100):
              
            area = img.size()[1] * img.size()[2]
              
            target_area = random.uniform(self.sl, self.sh) * area
              
            aspect_ratio = random.uniform(self.r1, 1/self.r1)
              
            h = int(round(math.sqrt(target_area * aspect_ratio)))
              
            w = int(round(math.sqrt(target_area / aspect_ratio)))
              

              
            if w < img.size()[2] and h < img.size()[1]:
              
                x1 = random.randint(0, img.size()[1] - h)
              
                y1 = random.randint(0, img.size()[2] - w)
              
                if img.size()[0] == 3:
              
                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]
              
                    img[1, x1:x1+h, y1:y1+w] = self.mean[1]
              
                    img[2, x1:x1+h, y1:y1+w] = self.mean[2]
              
                else:
              
                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]
              
                return img
              
        return img
          

2、其他torch自带数据处理操作


              
    def __init__(self, args):
              

              
        train_list = [
              
            transforms.Resize((args.height, args.width), interpolation=3),
              
            transforms.RandomHorizontalFlip(),
              
            transforms.ToTensor(),
              
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
              
        ]
              
        if args.random_erasing:
              
            train_list.append(RandomErasing(probability=args.probability, mean=[0.0, 0.0, 0.0]))
              

              
        train_transform = transforms.Compose(train_list)
              

              
        test_transform = transforms.Compose([
              
            transforms.Resize((args.height, args.width), interpolation=3),
              
            transforms.ToTensor(),
              
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
              
        ])
          

7.3、TripletSemihard Loss与Triplet Loss

picture.image

picture.image


              
#!/usr/bin/env python
              
# -*- coding: utf-8 -*-
              
import torch
              
from torch import nn
              
from torch.nn import functional as F
              

              
class TripletSemihardLoss(nn.Module):
              
    """
              
    Shape:
              
        - Input: :math:`(N, C)` where `C = number of channels`
              
        - Target: :math:`(N)`
              
        - Output: scalar.
              
    """
              

              
    def __init__(self, device, margin=0, size_average=True):
              
        super(TripletSemihardLoss, self).__init__()
              
        self.margin = margin
              
        self.size_average = size_average
              
        self.device = device
              

              
    def forward(self, input, target):
              
        y_true = target.int().unsqueeze(-1)
              
        same_id = torch.eq(y_true, y_true.t()).type_as(input)
              

              
        pos_mask = same_id
              
        neg_mask = 1 - same_id
              

              
        def _mask_max(input_tensor, mask, axis=None, keepdims=False):
              
            input_tensor = input_tensor - 1e6 * (1 - mask)
              
            _max, _idx = torch.max(input_tensor, dim=axis, keepdim=keepdims)
              
            return _max, _idx
              

              
        def _mask_min(input_tensor, mask, axis=None, keepdims=False):
              
            input_tensor = input_tensor + 1e6 * (1 - mask)
              
            _min, _idx = torch.min(input_tensor, dim=axis, keepdim=keepdims)
              
            return _min, _idx
              

              
        # output[i, j] = || feature[i, :] - feature[j, :] ||_2
              
        dist_squared = torch.sum(input ** 2, dim=1, keepdim=True) + \
              
                       torch.sum(input.t() ** 2, dim=0, keepdim=True) - \
              
                       2.0 * torch.matmul(input, input.t())
              
        dist = dist_squared.clamp(min=1e-16).sqrt()
              

              
        pos_max, pos_idx = _mask_max(dist, pos_mask, axis=-1)
              
        neg_min, neg_idx = _mask_min(dist, neg_mask, axis=-1)
              

              
        # loss(x, y) = max(0, -y * (x1 - x2) + margin)
              
        y = torch.ones(same_id.size()[0]).to(self.device)
              
        return F.margin_ranking_loss(neg_min.float(),
              
                                     pos_max.float(),
              
                                     y,
              
                                     self.margin,
              
                                     self.size_average)
              

              
class TripletLoss(nn.Module):
              
    """Triplet loss with hard positive/negative mining.
              

              
    Reference:
              
    Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
              

              
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
              

              
    Args:
              
        margin (float): margin for triplet.
              
    """
              
    def __init__(self, margin=0.3, mutual_flag = False):
              
        super(TripletLoss, self).__init__()
              
        self.margin = margin
              
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)
              
        self.mutual = mutual_flag
              

              
    def forward(self, inputs, targets):
              
        """
              
        Args:
              
            inputs: feature matrix with shape (batch_size, feat_dim)
              
            targets: ground truth labels with shape (num_classes)
              
        """
              
        n = inputs.size(0)
              
        #inputs = 1. * inputs / (torch.norm(inputs, 2, dim=-1, keepdim=True).expand_as(inputs) + 1e-12)
              
        # Compute pairwise distance, replace by the official when merged
              
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
              
        dist = dist + dist.t()
              
        dist.addmm_(1, -2, inputs, inputs.t())
              
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
              
        # For each anchor, find the hardest positive and negative
              
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
              
        dist_ap, dist_an = [], []
              
        for i in range(n):
              
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
              
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
              
        dist_ap = torch.cat(dist_ap)
              
        dist_an = torch.cat(dist_an)
              
        # Compute ranking hinge loss
              
        y = torch.ones_like(dist_an)
              
        loss = self.ranking_loss(dist_an, dist_ap, y)
              
        if self.mutual:
              
            return loss, dist
              
        return loss
          

7.4、MGN网络模型

picture.image

picture.image


              
import copy
              

              
import torch
              
from torch import nn
              
import torch.nn.functional as F
              

              
from torchvision.models.resnet import resnet50, Bottleneck
              

              
def make_model(args):
              
    return MGN(args)
              

              
class MGN(nn.Module):
              
    def __init__(self, args):
              
        super(MGN, self).__init__()
              
        num_classes = args.num_classes
              

              
        resnet = resnet50(pretrained=True)
              

              
        self.backone = nn.Sequential(
              
            resnet.conv1,
              
            resnet.bn1,
              
            resnet.relu,
              
            resnet.maxpool,
              
            resnet.layer1,
              
            resnet.layer2,
              
            resnet.layer3[0],
              
        )
              

              
        res_conv4 = nn.Sequential(*resnet.layer3[1:])
              

              
        res_g_conv5 = resnet.layer4
              

              
        res_p_conv5 = nn.Sequential(
              
            Bottleneck(1024, 512, downsample=nn.Sequential(nn.Conv2d(1024, 2048, 1, bias=False), nn.BatchNorm2d(2048))),
              
            Bottleneck(2048, 512),
              
            Bottleneck(2048, 512))
              
        res_p_conv5.load_state_dict(resnet.layer4.state_dict())
              

              
        self.p1 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5))
              
        self.p2 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
              
        self.p3 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
              
        
              
        if args.pool == 'max':
              
            pool2d = nn.MaxPool2d
              
        elif args.pool == 'avg':
              
            pool2d = nn.AvgPool2d
              
        else:
              
            raise Exception()
              

              
        self.maxpool_zg_p1 = pool2d(kernel_size=(12, 4))
              
        self.maxpool_zg_p2 = pool2d(kernel_size=(24, 8))
              
        self.maxpool_zg_p3 = pool2d(kernel_size=(24, 8))
              
        self.maxpool_zp2 = pool2d(kernel_size=(12, 8))
              
        self.maxpool_zp3 = pool2d(kernel_size=(8, 8))
              

              
        reduction = nn.Sequential(nn.Conv2d(2048, args.feats, 1, bias=False), nn.BatchNorm2d(args.feats), nn.ReLU())
              

              
        self._init_reduction(reduction)
              
        self.reduction_0 = copy.deepcopy(reduction)
              
        self.reduction_1 = copy.deepcopy(reduction)
              
        self.reduction_2 = copy.deepcopy(reduction)
              
        self.reduction_3 = copy.deepcopy(reduction)
              
        self.reduction_4 = copy.deepcopy(reduction)
              
        self.reduction_5 = copy.deepcopy(reduction)
              
        self.reduction_6 = copy.deepcopy(reduction)
              
        self.reduction_7 = copy.deepcopy(reduction)
              

              
        #self.fc_id_2048_0 = nn.Linear(2048, num_classes)
              
        self.fc_id_2048_0 = nn.Linear(args.feats, num_classes)
              
        self.fc_id_2048_1 = nn.Linear(args.feats, num_classes)
              
        self.fc_id_2048_2 = nn.Linear(args.feats, num_classes)
              

              
        self.fc_id_256_1_0 = nn.Linear(args.feats, num_classes)
              
        self.fc_id_256_1_1 = nn.Linear(args.feats, num_classes)
              
        self.fc_id_256_2_0 = nn.Linear(args.feats, num_classes)
              
        self.fc_id_256_2_1 = nn.Linear(args.feats, num_classes)
              
        self.fc_id_256_2_2 = nn.Linear(args.feats, num_classes)
              

              
        self._init_fc(self.fc_id_2048_0)
              
        self._init_fc(self.fc_id_2048_1)
              
        self._init_fc(self.fc_id_2048_2)
              

              
        self._init_fc(self.fc_id_256_1_0)
              
        self._init_fc(self.fc_id_256_1_1)
              
        self._init_fc(self.fc_id_256_2_0)
              
        self._init_fc(self.fc_id_256_2_1)
              
        self._init_fc(self.fc_id_256_2_2)
              

              
    @staticmethod
              
    def _init_reduction(reduction):
              
        # conv
              
        nn.init.kaiming_normal_(reduction[0].weight, mode='fan_in')
              
        #nn.init.constant_(reduction[0].bias, 0.)
              

              
        # bn
              
        nn.init.normal_(reduction[1].weight, mean=1., std=0.02)
              
        nn.init.constant_(reduction[1].bias, 0.)
              

              
    @staticmethod
              
    def _init_fc(fc):
              
        nn.init.kaiming_normal_(fc.weight, mode='fan_out')
              
        #nn.init.normal_(fc.weight, std=0.001)
              
        nn.init.constant_(fc.bias, 0.)
              

              
    def forward(self, x):
              

              
        x = self.backone(x)
              

              
        p1 = self.p1(x)
              
        p2 = self.p2(x)
              
        p3 = self.p3(x)
              

              
        zg_p1 = self.maxpool_zg_p1(p1)
              
        zg_p2 = self.maxpool_zg_p2(p2)
              
        zg_p3 = self.maxpool_zg_p3(p3)
              

              
        zp2 = self.maxpool_zp2(p2)
              
        z0_p2 = zp2[:, :, 0:1, :]
              
        z1_p2 = zp2[:, :, 1:2, :]
              

              
        zp3 = self.maxpool_zp3(p3)
              
        z0_p3 = zp3[:, :, 0:1, :]
              
        z1_p3 = zp3[:, :, 1:2, :]
              
        z2_p3 = zp3[:, :, 2:3, :]
              
        
              
        fg_p1 = self.reduction_0(zg_p1).squeeze(dim=3).squeeze(dim=2)
              
        fg_p2 = self.reduction_1(zg_p2).squeeze(dim=3).squeeze(dim=2)
              
        fg_p3 = self.reduction_2(zg_p3).squeeze(dim=3).squeeze(dim=2)
              
        f0_p2 = self.reduction_3(z0_p2).squeeze(dim=3).squeeze(dim=2)
              
        f1_p2 = self.reduction_4(z1_p2).squeeze(dim=3).squeeze(dim=2)
              
        f0_p3 = self.reduction_5(z0_p3).squeeze(dim=3).squeeze(dim=2)
              
        f1_p3 = self.reduction_6(z1_p3).squeeze(dim=3).squeeze(dim=2)
              
        f2_p3 = self.reduction_7(z2_p3).squeeze(dim=3).squeeze(dim=2)
              

              
        '''
              
        l_p1 = self.fc_id_2048_0(zg_p1.squeeze(dim=3).squeeze(dim=2))
              
        l_p2 = self.fc_id_2048_1(zg_p2.squeeze(dim=3).squeeze(dim=2))
              
        l_p3 = self.fc_id_2048_2(zg_p3.squeeze(dim=3).squeeze(dim=2))
              
        '''
              
        l_p1 = self.fc_id_2048_0(fg_p1)
              
        l_p2 = self.fc_id_2048_1(fg_p2)
              
        l_p3 = self.fc_id_2048_2(fg_p3)
              
        
              
        l0_p2 = self.fc_id_256_1_0(f0_p2)
              
        l1_p2 = self.fc_id_256_1_1(f1_p2)
              
        l0_p3 = self.fc_id_256_2_0(f0_p3)
              
        l1_p3 = self.fc_id_256_2_1(f1_p3)
              
        l2_p3 = self.fc_id_256_2_2(f2_p3)
              
        predict = torch.cat([fg_p1, fg_p2, fg_p3, f0_p2, f1_p2, f0_p3, f1_p3, f2_p3], dim=1)
              
        return predict, fg_p1, fg_p2, fg_p3, l_p1, l_p2, l_p3, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3
          

主函数:


              
import data
              
import loss
              
import torch
              
import model
              
from trainer import Trainer
              

              
from option import args
              
import utils.utility as utility
              

              
ckpt = utility.checkpoint(args)
              

              
loader = data.Data(args)
              
model = model.Model(args, ckpt)
              
loss = loss.Loss(args, ckpt) if not args.test_only else None
              
trainer = Trainer(args, model, loss, loader, ckpt)
              

              
n = 0
              

              
if __name__ == '__main__':
              
  while not trainer.terminate():
              
    n += 1
              
    trainer.train()
              
    if args.test_every!=0 and n%args.test_every==0:
              
      trainer.test()
          

7.5、MGN-ReID模型的训练过程与测试结果

picture.image

picture.image

picture.image

picture.image

参考:

https://blog.csdn.net/guleileo/article/details/80837332

https://zhuanlan.zhihu.com/p/58238545

关注公众号,回复【MGN 】即可获得完整的项目代码以及文档说明。

一个人能走多远,在于与谁同行

希望您可以扫描下方二维码关注公众号与我们一同前行

非常期待您的打赏

声明:转载请说明出处

下方为小生公众号,还望包容接纳和关注,非常期待与您的美好相遇,让我们以梦为马,砥砺前行。

希望技术与灵魂可以一路同行

长按识别二维码关注一下

更多精彩内容可回复关键词

每篇文章的主题即可

picture.image

picture.image

0
0
0
0
评论
未登录
暂无评论