Paper: https://arxiv.org/pdf/1606.02147.pdf
The goal of ENet is fast semantic segmentation: it aims for real-time performance while still taking segmentation accuracy into account. The standard network architecture for semantic segmentation is an encoder-decoder: downsampling enables pixel-level classification, and upsampling localizes the objects in the image. To make the algorithm run in real time, the amount of computation in the upsampling stage must be reduced and the sampling made faster.
The filtering operations in the downsampling layers produce a larger receptive field, which lets the network gather more context about the objects. In image semantic segmentation, however, downsampling has two main drawbacks: (1) lowering the feature-map resolution implies a loss of spatial information; (2) pixel-wise segmentation of the whole image requires the input and output to have the same resolution, so every downsampling operation must later be matched by an upsampling operation. Both increase the model size and the amount of computation (the resolution round-trip is sketched below).
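As a minimal illustration of drawback (2) (a standalone PyTorch sketch with shapes chosen purely for illustration, not part of ENet itself): a stride-2 convolution halves the spatial resolution, and a matching transposed convolution is needed to bring the feature map back to the input resolution for pixel-wise prediction.
import torch
import torch.nn as nn

x = torch.randn(1, 3, 256, 256)  # (B, C, H, W) input batch
down = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)           # downsampling step
up = nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=1)   # matching upsampling step

y = down(x)                        # torch.Size([1, 16, 128, 128]): half the resolution
z = up(y, output_size=x.size())    # torch.Size([1, 16, 256, 256]): restored for dense prediction
print(y.shape, z.shape)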
For the first problem, FCN fuses the feature maps produced by different encoder layers, but this increases the number of network parameters and hurts real-time performance. For the second problem, SegNet stores the element indices from its max-pooling layers and looks them up in the decoder, so that decoding produces sparse upsampled maps. Even so, downsampling still degrades the spatial precision of the objects.
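The SegNet-style index trick can be sketched in a few lines of standalone PyTorch (this is only an illustration, not ENet's exact module): the argmax locations recorded by max pooling during encoding are reused by nn.MaxUnpool2d during decoding, which produces a sparse upsampled map.
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

x = torch.randn(1, 1, 4, 4)
pooled, indices = pool(x)         # (1, 1, 2, 2) plus the locations of each maximum
sparse = unpool(pooled, indices)  # (1, 1, 4, 4): maxima restored in place, everything else zero
print(pooled.shape, sparse.shape)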
ENet results on different datasets
In its initial block, ENet runs a pooling operation in parallel with a stride-2 convolution and concatenates the resulting feature maps. Using smaller and fewer feature maps early in the network greatly reduces the number of parameters and speeds up inference. During downsampling ENet uses dilated convolutions, which strike a good balance between image resolution and receptive field: the receptive field over the target is enlarged without lowering the feature-map resolution. Unlike SegNet, which is a very symmetric structure with encoder and decoder of equal size, ENet pairs a large encoder with a small decoder, which helps keep the parameter count down.
ENet network structure
import torch.nn as nn
import torch
class InitialBlock(nn.Module):
"""The initial block is composed of two branches:
1. a main branch which performs a regular convolution with stride 2;
2. an extension branch which performs max-pooling.
Doing both operations in parallel and concatenating their results
allows for efficient downsampling and expansion. The main branch
outputs 13 feature maps while the extension branch outputs 3, for a
total of 16 feature maps after concatenation.
Keyword arguments:
- in_channels (int): the number of input channels.
- out_channels (int): the number of output channels.
- bias (bool, optional): Adds a learnable bias to the output if
``True``. Default: False.
- relu (bool, optional): When ``True`` ReLU is used as the activation
function; otherwise, PReLU is used. Default: True.
"""
def __init__(self,
in_channels,
out_channels,
bias=False,
relu=True):
super().__init__()
if relu:
activation = nn.ReLU
else:
activation = nn.PReLU
# Main branch - As stated above the number of output channels for this
# branch is the total minus 3, since the remaining channels come from
# the extension branch
self.main_branch = nn.Conv2d(
in_channels,
out_channels - 3,
kernel_size=3,
stride=2,
padding=1,
bias=bias)
# Extension branch
self.ext_branch = nn.MaxPool2d(3, stride=2, padding=1)
# Initialize batch normalization to be used after concatenation
self.batch_norm = nn.BatchNorm2d(out_channels)
# PReLU layer to apply after concatenating the branches
self.out_activation = activation()
def forward(self, x):
main = self.main_branch(x)
ext = self.ext_branch(x)
# Concatenate branches
out = torch.cat((main, ext), 1)
# Apply batch normalization
out = self.batch_norm(out)
return self.out_activation(out)
class RegularBottleneck(nn.Module):
"""Regular bottlenecks are the main building block of ENet.
Main branch:
1. Shortcut connection.
Extension branch:
1. 1x1 convolution which decreases the number of channels by
``internal_ratio``, also called a projection;
2. regular, dilated or asymmetric convolution;
3. 1x1 convolution which increases the number of channels back to
``channels``, also called an expansion;
4. dropout as a regularizer.
Keyword arguments:
- channels (int): the number of input and output channels.
- internal_ratio (int, optional): a scale factor applied to
``channels`` used to compute the number of
channels after the projection. eg. given ``channels`` equal to 128 and
internal_ratio equal to 2 the number of channels after the projection
is 64. Default: 4.
- kernel_size (int, optional): the kernel size of the filters used in
the convolution layer described above in item 2 of the extension
branch. Default: 3.
- padding (int, optional): zero-padding added to both sides of the
input. Default: 0.
- dilation (int, optional): spacing between kernel elements for the
convolution described in item 2 of the extension branch. Default: 1.
- asymmetric (bool, optional): flags if the convolution described in
item 2 of the extension branch is asymmetric or not. Default: False.
- dropout_prob (float, optional): probability of an element to be
zeroed. Default: 0 (no dropout).
- bias (bool, optional): Adds a learnable bias to the output if
``True``. Default: False.
- relu (bool, optional): When ``True`` ReLU is used as the activation
function; otherwise, PReLU is used. Default: True.
"""
def __init__(self,
channels,
internal_ratio=4,
kernel_size=3,
padding=0,
dilation=1,
asymmetric=False,
dropout_prob=0,
bias=False,
relu=True):
super().__init__()
# Check if the internal_ratio parameter is within the expected range
# [1, channels]
if internal_ratio <= 1 or internal_ratio > channels:
raise RuntimeError("Value out of range. Expected value in the "
"interval [1, {0}], got internal\_scale={1}."
.format(channels, internal_ratio))
internal_channels = channels // internal_ratio
if relu:
activation = nn.ReLU
else:
activation = nn.PReLU
# Main branch - shortcut connection
# Extension branch - 1x1 convolution, followed by a regular, dilated or
# asymmetric convolution, followed by another 1x1 convolution, and,
# finally, a regularizer (spatial dropout). Number of channels is constant.
# 1x1 projection convolution
self.ext_conv1 = nn.Sequential(
nn.Conv2d(
channels,
internal_channels,
kernel_size=1,
stride=1,
bias=bias), nn.BatchNorm2d(internal_channels), activation())
# If the convolution is asymmetric we split the main convolution in
# two. Eg. for a 5x5 asymmetric convolution we have two convolution:
# the first is 5x1 and the second is 1x5.
if asymmetric:
self.ext_conv2 = nn.Sequential(
nn.Conv2d(
internal_channels,
internal_channels,
kernel_size=(kernel_size, 1),
stride=1,
padding=(padding, 0),
dilation=dilation,
bias=bias), nn.BatchNorm2d(internal_channels), activation(),
nn.Conv2d(
internal_channels,
internal_channels,
kernel_size=(1, kernel_size),
stride=1,
padding=(0, padding),
dilation=dilation,
bias=bias), nn.BatchNorm2d(internal_channels), activation())
else:
self.ext_conv2 = nn.Sequential(
nn.Conv2d(
internal_channels,
internal_channels,
kernel_size=kernel_size,
stride=1,
padding=padding,
dilation=dilation,
bias=bias), nn.BatchNorm2d(internal_channels), activation())
# 1x1 expansion convolution
self.ext_conv3 = nn.Sequential(
nn.Conv2d(
internal_channels,
channels,
kernel_size=1,
stride=1,
bias=bias), nn.BatchNorm2d(channels), activation())
self.ext_regul = nn.Dropout2d(p=dropout_prob)
# PReLU layer to apply after adding the branches
self.out_activation = activation()
def forward(self, x):
# Main branch shortcut
main = x
# Extension branch
ext = self.ext_conv1(x)
ext = self.ext_conv2(ext)
ext = self.ext_conv3(ext)
ext = self.ext_regul(ext)
# Add main and extension branches
out = main + ext
return self.out_activation(out)
class DownsamplingBottleneck(nn.Module):
"""Downsampling bottlenecks further downsample the feature map size.
Main branch:
1. max pooling with stride 2; indices are saved to be used for
unpooling later.
Extension branch:
1. 2x2 convolution with stride 2 that decreases the number of channels
by ``internal_ratio``, also called a projection;
2. regular convolution (by default, 3x3);
3. 1x1 convolution which increases the number of channels to
``out_channels``, also called an expansion;
4. dropout as a regularizer.
Keyword arguments:
- in_channels (int): the number of input channels.
- out_channels (int): the number of output channels.
- internal_ratio (int, optional): a scale factor applied to ``channels``
used to compute the number of channels after the projection. eg. given
``channels`` equal to 128 and internal_ratio equal to 2 the number of
channels after the projection is 64. Default: 4.
- return_indices (bool, optional): if ``True``, will return the max
indices along with the outputs. Useful when unpooling later.
- dropout_prob (float, optional): probability of an element to be
zeroed. Default: 0 (no dropout).
- bias (bool, optional): Adds a learnable bias to the output if
``True``. Default: False.
- relu (bool, optional): When ``True`` ReLU is used as the activation
function; otherwise, PReLU is used. Default: True.
"""
def __init__(self,
in_channels,
out_channels,
internal_ratio=4,
return_indices=False,
dropout_prob=0,
bias=False,
relu=True):
super().__init__()
# Store parameters that are needed later
self.return_indices = return_indices
# Check if the internal_ratio parameter is within the expected range
# [1, channels]
if internal_ratio <= 1 or internal_ratio > in_channels:
raise RuntimeError("Value out of range. Expected value in the "
"interval [1, {0}], got internal\_scale={1}. "
.format(in_channels, internal_ratio))
internal_channels = in_channels // internal_ratio
if relu:
activation = nn.ReLU
else:
activation = nn.PReLU
# Main branch - max pooling followed by feature map (channels) padding
self.main_max1 = nn.MaxPool2d(
2,
stride=2,
return_indices=return_indices)
# Extension branch - 2x2 projection convolution with stride 2, followed by
# a regular 3x3 convolution, followed by a 1x1 expansion convolution. The
# number of channels goes from in_channels to out_channels.
# 2x2 projection convolution with stride 2
self.ext_conv1 = nn.Sequential(
nn.Conv2d(
in_channels,
internal_channels,
kernel_size=2,
stride=2,
bias=bias), nn.BatchNorm2d(internal_channels), activation())
# Convolution
self.ext_conv2 = nn.Sequential(
nn.Conv2d(
internal_channels,
internal_channels,
kernel_size=3,
stride=1,
padding=1,
bias=bias), nn.BatchNorm2d(internal_channels), activation())
# 1x1 expansion convolution
self.ext_conv3 = nn.Sequential(
nn.Conv2d(
internal_channels,
out_channels,
kernel_size=1,
stride=1,
bias=bias), nn.BatchNorm2d(out_channels), activation())
self.ext_regul = nn.Dropout2d(p=dropout_prob)
# PReLU layer to apply after concatenating the branches
self.out_activation = activation()
def forward(self, x):
# Main branch shortcut
if self.return_indices:
main, max_indices = self.main_max1(x)
else:
main = self.main_max1(x)
# No pooling indices were requested, so return None in their place
max_indices = None
# Extension branch
ext = self.ext_conv1(x)
ext = self.ext_conv2(ext)
ext = self.ext_conv3(ext)
ext = self.ext_regul(ext)
# Main branch channel padding
n, ch_ext, h, w = ext.size()
ch_main = main.size()[1]
padding = torch.zeros(n, ch_ext - ch_main, h, w)
# Before concatenating, check if main is on the CPU or GPU and
# convert padding accordingly
if main.is_cuda:
padding = padding.cuda()
# Concatenate
main = torch.cat((main, padding), 1)
# Add main and extension branches
out = main + ext
return self.out_activation(out), max_indices
class UpsamplingBottleneck(nn.Module):
"""The upsampling bottlenecks upsample the feature map resolution using max
pooling indices stored from the corresponding downsampling bottleneck.
Main branch:
1. 1x1 convolution with stride 1 that changes the number of channels to
``out_channels``;
2. max unpool layer using the max pool indices from the corresponding
downsampling max pool layer.
Extension branch:
1. 1x1 convolution with stride 1 that decreases the number of channels by
``internal_ratio``, also called a projection;
2. 2x2 transposed convolution with stride 2;
3. 1x1 convolution which increases the number of channels to
``out_channels``, also called an expansion;
4. dropout as a regularizer.
Keyword arguments:
- in_channels (int): the number of input channels.
- out_channels (int): the number of output channels.
- internal_ratio (int, optional): a scale factor applied to ``in_channels``
used to compute the number of channels after the projection. eg. given
``in_channels`` equal to 128 and ``internal_ratio`` equal to 2 the number
of channels after the projection is 64. Default: 4.
- dropout_prob (float, optional): probability of an element to be zeroed.
Default: 0 (no dropout).
- bias (bool, optional): Adds a learnable bias to the output if ``True``.
Default: False.
- relu (bool, optional): When ``True`` ReLU is used as the activation
function; otherwise, PReLU is used. Default: True.
"""
def __init__(self,
in_channels,
out_channels,
internal_ratio=4,
dropout_prob=0,
bias=False,
relu=True):
super().__init__()
# Check if the internal_ratio parameter is within the expected range
# [1, channels]
if internal_ratio <= 1 or internal_ratio > in_channels:
raise RuntimeError("Value out of range. Expected value in the "
"interval [1, {0}], got internal\_scale={1}. "
.format(in_channels, internal_ratio))
internal_channels = in_channels // internal_ratio
if relu:
activation = nn.ReLU
else:
activation = nn.PReLU
# Main branch - max pooling followed by feature map (channels) padding
self.main_conv1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias),
nn.BatchNorm2d(out_channels))
# Remember that the stride is the same as the kernel_size, just like
# the max pooling layers
self.main_unpool1 = nn.MaxUnpool2d(kernel_size=2)
# Extension branch - 1x1 projection convolution, followed by a transposed
# convolution, followed by another 1x1 expansion convolution. The number
# of channels goes from in_channels to out_channels.
# 1x1 projection convolution with stride 1
self.ext_conv1 = nn.Sequential(
nn.Conv2d(
in_channels, internal_channels, kernel_size=1, bias=bias),
nn.BatchNorm2d(internal_channels), activation())
# Transposed convolution
self.ext_tconv1 = nn.ConvTranspose2d(
internal_channels,
internal_channels,
kernel_size=2,
stride=2,
bias=bias)
self.ext_tconv1_bnorm = nn.BatchNorm2d(internal_channels)
self.ext_tconv1_activation = activation()
# 1x1 expansion convolution
self.ext_conv2 = nn.Sequential(
nn.Conv2d(
internal_channels, out_channels, kernel_size=1, bias=bias),
nn.BatchNorm2d(out_channels))
self.ext_regul = nn.Dropout2d(p=dropout_prob)
# PReLU layer to apply after concatenating the branches
self.out_activation = activation()
def forward(self, x, max_indices, output_size):
# Main branch shortcut
main = self.main_conv1(x)
main = self.main_unpool1(
main, max_indices, output_size=output_size)
# Extension branch
ext = self.ext_conv1(x)
ext = self.ext_tconv1(ext, output_size=output_size)
ext = self.ext_tconv1_bnorm(ext)
ext = self.ext_tconv1_activation(ext)
ext = self.ext_conv2(ext)
ext = self.ext_regul(ext)
# Add main and extension branches
out = main + ext
return self.out_activation(out)
class ENet(nn.Module):
"""Generate the ENet model.
Keyword arguments:
- num_classes (int): the number of classes to segment.
- encoder_relu (bool, optional): When ``True`` ReLU is used as the
activation function in the encoder blocks/layers; otherwise, PReLU
is used. Default: False.
- decoder_relu (bool, optional): When ``True`` ReLU is used as the
activation function in the decoder blocks/layers; otherwise, PReLU
is used. Default: True.
"""
def __init__(self, num_classes, encoder_relu=False, decoder_relu=True):
super().__init__()
self.initial_block = InitialBlock(3, 16, relu=encoder_relu)
# Stage 1 - Encoder
self.downsample1_0 = DownsamplingBottleneck(
16,
64,
return_indices=True,
dropout_prob=0.01,
relu=encoder_relu)
self.regular1_1 = RegularBottleneck(
64, padding=1, dropout_prob=0.01, relu=encoder_relu)
self.regular1_2 = RegularBottleneck(
64, padding=1, dropout_prob=0.01, relu=encoder_relu)
self.regular1_3 = RegularBottleneck(
64, padding=1, dropout_prob=0.01, relu=encoder_relu)
self.regular1_4 = RegularBottleneck(
64, padding=1, dropout_prob=0.01, relu=encoder_relu)
# Stage 2 - Encoder
self.downsample2_0 = DownsamplingBottleneck(
64,
128,
return_indices=True,
dropout_prob=0.1,
relu=encoder_relu)
self.regular2_1 = RegularBottleneck(
128, padding=1, dropout_prob=0.1, relu=encoder_relu)
self.dilated2_2 = RegularBottleneck(
128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
self.asymmetric2_3 = RegularBottleneck(
128,
kernel_size=5,
padding=2,
asymmetric=True,
dropout_prob=0.1,
relu=encoder_relu)
self.dilated2_4 = RegularBottleneck(
128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
self.regular2_5 = RegularBottleneck(
128, padding=1, dropout_prob=0.1, relu=encoder_relu)
self.dilated2_6 = RegularBottleneck(
128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
self.asymmetric2_7 = RegularBottleneck(
128,
kernel_size=5,
asymmetric=True,
padding=2,
dropout_prob=0.1,
relu=encoder_relu)
self.dilated2_8 = RegularBottleneck(
128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)
# Stage 3 - Encoder
self.regular3_0 = RegularBottleneck(
128, padding=1, dropout_prob=0.1, relu=encoder_relu)
self.dilated3_1 = RegularBottleneck(
128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
self.asymmetric3_2 = RegularBottleneck(
128,
kernel_size=5,
padding=2,
asymmetric=True,
dropout_prob=0.1,
relu=encoder_relu)
self.dilated3_3 = RegularBottleneck(
128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
self.regular3_4 = RegularBottleneck(
128, padding=1, dropout_prob=0.1, relu=encoder_relu)
self.dilated3_5 = RegularBottleneck(
128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
self.asymmetric3_6 = RegularBottleneck(
128,
kernel_size=5,
asymmetric=True,
padding=2,
dropout_prob=0.1,
relu=encoder_relu)
self.dilated3_7 = RegularBottleneck(
128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)
# Stage 4 - Decoder
self.upsample4_0 = UpsamplingBottleneck(
128, 64, dropout_prob=0.1, relu=decoder_relu)
self.regular4_1 = RegularBottleneck(
64, padding=1, dropout_prob=0.1, relu=decoder_relu)
self.regular4_2 = RegularBottleneck(
64, padding=1, dropout_prob=0.1, relu=decoder_relu)
# Stage 5 - Decoder
self.upsample5_0 = UpsamplingBottleneck(
64, 16, dropout_prob=0.1, relu=decoder_relu)
self.regular5_1 = RegularBottleneck(
16, padding=1, dropout_prob=0.1, relu=decoder_relu)
self.transposed_conv = nn.ConvTranspose2d(
16,
num_classes,
kernel_size=3,
stride=2,
padding=1,
bias=False)
def forward(self, x):
# Initial block
input_size = x.size()
x = self.initial_block(x)
# Stage 1 - Encoder
stage1_input_size = x.size()
x, max_indices1_0 = self.downsample1_0(x)
x = self.regular1_1(x)
x = self.regular1_2(x)
x = self.regular1_3(x)
x = self.regular1_4(x)
# Stage 2 - Encoder
stage2_input_size = x.size()
x, max_indices2_0 = self.downsample2_0(x)
x = self.regular2_1(x)
x = self.dilated2_2(x)
x = self.asymmetric2_3(x)
x = self.dilated2_4(x)
x = self.regular2_5(x)
x = self.dilated2_6(x)
x = self.asymmetric2_7(x)
x = self.dilated2_8(x)
# Stage 3 - Encoder
x = self.regular3_0(x)
x = self.dilated3_1(x)
x = self.asymmetric3_2(x)
x = self.dilated3_3(x)
x = self.regular3_4(x)
x = self.dilated3_5(x)
x = self.asymmetric3_6(x)
x = self.dilated3_7(x)
# Stage 4 - Decoder
x = self.upsample4_0(x, max_indices2_0, output_size=stage2_input_size)
x = self.regular4_1(x)
x = self.regular4_2(x)
# Stage 5 - Decoder
x = self.upsample5_0(x, max_indices1_0, output_size=stage1_input_size)
x = self.regular5_1(x)
x = self.transposed_conv(x, output_size=input_size)
return x
if __name__ == '__main__':
x = torch.randn(1, 3, 256, 256)
out = ENet(13)(x)
print(out.shape)  # torch.Size([1, 13, 256, 256])
Because the weights of a standard convolution contain a fair amount of redundancy, an nxn convolutional layer can be factorized into two consecutive layers with smaller filters (one with an nx1 filter, the other with a 1xn filter), which removes some of that redundancy. This factorized convolution is also called an asymmetric convolution. ENet uses asymmetric convolutions with n=5; the two steps together cost about as much as a single 3x3 convolution, while increasing the variety of functions the model can learn and enlarging the receptive field.
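The cost argument is easy to verify with a standalone sketch (the channel count of 32 is chosen only for illustration): the 5x1 plus 1x5 factorization uses 10 weights per input/output channel pair instead of 25, close to the 9 weights of a 3x3 convolution.
import torch.nn as nn

def n_params(module):
    return sum(p.numel() for p in module.parameters())

c = 32  # illustrative channel count
full_5x5 = nn.Conv2d(c, c, kernel_size=5, padding=2, bias=False)
asym_5x5 = nn.Sequential(  # 5x1 followed by 1x5, as in ENet's asymmetric bottlenecks
    nn.Conv2d(c, c, kernel_size=(5, 1), padding=(2, 0), bias=False),
    nn.Conv2d(c, c, kernel_size=(1, 5), padding=(0, 2), bias=False))
conv_3x3 = nn.Conv2d(c, c, kernel_size=3, padding=1, bias=False)

print(n_params(full_5x5))  # 25600 = 32*32*25
print(n_params(asym_5x5))  # 10240 = 32*32*(5+5), close to the 3x3 cost below
print(n_params(conv_3x3))  #  9216 = 32*32*9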
To avoid over-downsampling the feature maps, ENet replaces several of the main convolutions in the encoder stages that operate at the lowest resolution with dilated convolutions, which yields a significant accuracy improvement.
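As a sanity check on this trade-off (a standalone sketch, channel count arbitrary): with padding equal to the dilation rate, a dilated 3x3 convolution keeps the feature-map resolution unchanged while the kernel spans 2*d+1 pixels per axis, which is how the dilated2_x/dilated3_x bottlenecks above are configured.
import torch
import torch.nn as nn

x = torch.randn(1, 128, 32, 32)
for d in (1, 2, 4, 8, 16):
    conv = nn.Conv2d(128, 128, kernel_size=3, padding=d, dilation=d)
    y = conv(x)
    # Resolution stays 32x32; the 3x3 kernel now covers a (2*d + 1) x (2*d + 1) window
    print(d, tuple(y.shape), 2 * d + 1)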
Most pixel-level segmentation datasets are fairly small, and a model as complex as a neural network easily overfits them, which hurts generalization. Regularization is equivalent to placing a prior distribution on the parameters: it constrains the amount of information the model is allowed to store, which lowers the effective model complexity and helps reduce overfitting.
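ENet's bottlenecks use spatial dropout (nn.Dropout2d) as this regularizer; the short sketch below (with illustrative values) shows its behavior: in training mode it zeroes entire feature maps rather than individual elements and rescales the survivors by 1/(1-p), and at inference it is the identity.
import torch
import torch.nn as nn

drop = nn.Dropout2d(p=0.5)
drop.train()                    # dropout is only active in training mode
x = torch.ones(1, 8, 4, 4)
y = drop(x)
# Roughly half of the 8 channels are zeroed as a whole; surviving channels are scaled by 2
print((y.view(8, -1).sum(dim=1) == 0).sum().item(), "channels dropped")
drop.eval()
print(torch.equal(drop(x), x))  # True: at inference dropout does nothing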
The download link for the ENet model trained on the CamVid dataset is below; the model is only 4.27 MB, which is extremely lightweight.
Reply "ENet" to the WeChat public account "笑傲算法江湖" to get the model (available for 7 days).
import os
import sys
import cv2
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import transforms as ext_transforms
from models.enet import ENet
#import utils
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = 'save/ENet_CamVid/ENet'
class_encoding = OrderedDict([
('sky', (128, 128, 128)),
('building', (128, 0, 0)),
('pole', (192, 192, 128)),
('road', (128, 64, 128)),
('pavement', (60, 40, 222)),
('tree', (128, 128, 0)),
('sign_symbol', (192, 128, 128)),
('fence', (64, 64, 128)),
('car', (64, 0, 128)),
('pedestrian', (64, 64, 0)),
('bicyclist', (0, 128, 192)),
('unlabeled', (0, 0, 0))
])
label_to_rgb = transforms.Compose([
ext_transforms.LongTensorToRGBPIL(class_encoding),
transforms.ToTensor()
])
def imshow_batch(images, labels):
"""Displays two grids of images. The top grid displays ``images``
and the bottom grid ``labels``
Keyword arguments:
- images (``Tensor``): a 4D mini-batch tensor of shape
(B, C, H, W)
- labels (``Tensor``): a 4D mini-batch tensor of shape
(B, C, H, W)
"""
# Make a grid with the images and labels and convert it to numpy
images = torchvision.utils.make_grid(images).numpy()
labels = torchvision.utils.make_grid(labels).numpy()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))
ax1.imshow(np.transpose(images, (1, 2, 0)))
ax1.set_xticks([])
ax1.set_yticks([])
ax2.imshow(np.transpose(labels, (1, 2, 0)))
ax2.set_xticks([])
ax2.set_yticks([])
plt.savefig('inference.png')
plt.show()
def batch_transform(batch, transform):
"""Applies a transform to a batch of samples.
Keyword arguments:
- batch (``Tensor``): a batch of samples
- transform (callable): A function/transform to apply to ``batch``
"""
# Convert the single channel label to RGB in tensor form
# 1. torch.unbind removes the 0-dimension of "labels" and returns a tuple of
# all slices along that dimension
# 2. the transform is applied to each slice
transf_slices = [transform(tensor) for tensor in torch.unbind(batch)]
return torch.stack(transf_slices)
# Run only if this module is being run directly
if __name__ == '__main__':
model = ENet(len(class_encoding)).to(device)
model.eval()
# Load the previously saved model state into the ENet model
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['state_dict'])
in_image = cv2.imread('test.png')
in_image = cv2.cvtColor(in_image, cv2.COLOR_BGR2RGB)
in_image = in_image.transpose(2, 0, 1)
in_image = torch.from_numpy(in_image).unsqueeze(0)
in_image = in_image.to(device).float() / 255.
out_image = model(in_image)
# The output has "num_classes" channels, one score per class.
# Convert it to a single class index per pixel by taking the argmax along dim 1
_, predictions = torch.max(out_image, 1)
color_predictions = batch_transform(predictions.cpu(), label_to_rgb)
imshow_batch(in_image.data.cpu(), color_predictions)