import torch
import torch.nn as nn
import torch.nn.functional as F
from .layer_factory import CRPBlock, batchnorm, conv1x1, conv3x3, sepconv_bn
from ..misc.utils import make_list
class DLv3plus(nn.Module):
    """DeepLab-v3+ for Semantic Image Segmentation.

    ASPP with decoder. Allows to have multiple skip-connections.
    More information about the model: https://arxiv.org/abs/1802.02611

    Args:
        input_sizes (int, or list): number of channels for each input.
            Last value represents the input to ASPP,
            other values are for skip-connections.
        num_classes (int): number of output channels.
        skip_size (int): common filter size for skip-connections.
        agg_size (int): common filter size.
        rates (iterable of ints): dilation rates in the ASPP module. One ASPP
            branch is created per rate; any number of rates is accepted.
        **kwargs: accepted for call-site compatibility and not used.
    """

    def __init__(
        self,
        input_sizes,
        num_classes,
        skip_size=48,
        agg_size=256,
        rates=(6, 12, 18),
        **kwargs,
    ):
        super().__init__()
        skip_convs = nn.ModuleList()
        aspp = nn.ModuleList()
        input_sizes = make_list(input_sizes)
        # Materialise once so the rates can be both iterated and counted.
        rates = tuple(rates)
        aspp_in = input_sizes[-1]

        # One 1x1 projection per skip-connection input.
        for size in input_sizes[:-1]:
            skip_convs.append(nn.Sequential(*self._conv_bn_relu(size, skip_size)))

        # ASPP branch 0: global average pooling followed by a 1x1 projection.
        aspp.append(
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1), *self._conv_bn_relu(aspp_in, agg_size)
            )
        )
        # ASPP branch 1: plain 1x1 projection.
        aspp.append(nn.Sequential(*self._conv_bn_relu(aspp_in, agg_size)))
        # ASPP branches 2..: one `sepconv_bn` per dilation rate.
        for rate in rates:
            aspp.append(sepconv_bn(aspp_in, agg_size, rate=rate, depth_activation=True))
        # Last ASPP entry fuses the concatenated branches. Its input width
        # follows the number of branches instead of assuming exactly three
        # dilation rates.
        aspp.append(
            nn.Sequential(
                *self._conv_bn_relu(
                    self._aspp_concat_channels(agg_size, rates), agg_size
                ),
                nn.Dropout(p=0.1),
            )
        )

        self.skip_convs = skip_convs
        self.aspp = aspp
        # The decoder consumes the ASPP output concatenated with every skip.
        self.dec = nn.Sequential(
            sepconv_bn(
                agg_size + len(skip_convs) * skip_size, agg_size, depth_activation=True
            ),
            sepconv_bn(agg_size, agg_size, depth_activation=True),
        )
        self.clf = conv1x1(agg_size, num_classes, bias=True)

    @staticmethod
    def _aspp_concat_channels(agg_size, rates):
        """Return the channel count of the concatenated ASPP branch outputs."""
        # GAP branch + 1x1 branch + one branch per dilation rate,
        # each producing `agg_size` channels.
        return agg_size * (len(rates) + 2)

    @staticmethod
    def _conv_bn_relu(in_size, out_size):
        """Return the layers of a bias-free 1x1 conv -> batchnorm -> ReLU stack."""
        return [
            conv1x1(in_size, out_size, bias=False),
            batchnorm(out_size),
            nn.ReLU(inplace=False),
        ]

    def forward(self, xs):
        """Run the decoder on one or several feature maps.

        Args:
            xs (tensor, or list of tensors): inputs in the same order as
                `input_sizes`; the last one feeds ASPP, the others are
                skip-connections.

        Returns:
            Output of the final 1x1 classifier. When skip-connections are
            present, features are resized to the first skip's spatial size
            before the decoder is applied.
        """
        xs = make_list(xs)
        aspp_input = xs[-1]
        skips = [conv(x) for conv, x in zip(self.skip_convs, xs[:-1])]
        # Every ASPP entry except the last (fusion) one is a parallel branch.
        aspp = [branch(aspp_input) for branch in self.aspp[:-1]]
        # Upsample GAP: the pooled branch is 1x1 and has to match the others.
        aspp[0] = F.interpolate(
            aspp[0], size=aspp_input.size()[2:], mode="bilinear", align_corners=True
        )
        aspp = torch.cat(aspp, dim=1)
        # Apply last conv in ASPP
        aspp = self.aspp[-1](aspp)
        # Connect with skip-connections
        if skips:
            # The first skip defines the decoder resolution.
            dec = [skips[0]]
            target_size = skips[0].size()[2:]
            for x in skips[1:] + [aspp]:
                dec.append(
                    F.interpolate(
                        x, size=target_size, mode="bilinear", align_corners=True
                    )
                )
        else:
            dec = [aspp]
        dec = torch.cat(dec, dim=1)
        dec = self.dec(dec)
        out = self.clf(dec)
        return out
class MTDLv3plus(nn.Module):
    """Multi-Task DeepLab-v3+ for Semantic Image Segmentation.

    ASPP with decoder. Allows to have multiple skip-connections.
    More information about the model: https://arxiv.org/abs/1802.02611

    Args:
        input_sizes (int, or list): number of channels for each input.
            Last value represents the input to ASPP,
            other values are for skip-connections.
        num_classes (int, or list): number of output channels per head;
            one 1x1 classifier is created for each entry.
        skip_size (int): common filter size for skip-connections.
        agg_size (int): common filter size.
        rates (iterable of ints): dilation rates in the ASPP module. One ASPP
            branch is created per rate; any number of rates is accepted.
        **kwargs: accepted for call-site compatibility and not used.
    """

    def __init__(
        self,
        input_sizes,
        num_classes,
        skip_size=48,
        agg_size=256,
        rates=(6, 12, 18),
        **kwargs,
    ):
        super().__init__()
        skip_convs = nn.ModuleList()
        aspp = nn.ModuleList()
        input_sizes = make_list(input_sizes)
        num_classes = make_list(num_classes)
        # Materialise once so the rates can be both iterated and counted.
        rates = tuple(rates)
        aspp_in = input_sizes[-1]

        # One 1x1 projection per skip-connection input.
        for size in input_sizes[:-1]:
            skip_convs.append(nn.Sequential(*self._conv_bn_relu(size, skip_size)))

        # ASPP branch 0: global average pooling followed by a 1x1 projection.
        aspp.append(
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1), *self._conv_bn_relu(aspp_in, agg_size)
            )
        )
        # ASPP branch 1: plain 1x1 projection.
        aspp.append(nn.Sequential(*self._conv_bn_relu(aspp_in, agg_size)))
        # ASPP branches 2..: one `sepconv_bn` per dilation rate.
        for rate in rates:
            aspp.append(sepconv_bn(aspp_in, agg_size, rate=rate, depth_activation=True))
        # Last ASPP entry fuses the concatenated branches. Its input width
        # follows the number of branches instead of assuming exactly three
        # dilation rates.
        aspp.append(
            nn.Sequential(
                *self._conv_bn_relu(
                    self._aspp_concat_channels(agg_size, rates), agg_size
                ),
                nn.Dropout(p=0.1),
            )
        )

        self.skip_convs = skip_convs
        self.aspp = aspp
        # The decoder consumes the ASPP output concatenated with every skip.
        self.dec = nn.Sequential(
            sepconv_bn(
                agg_size + len(skip_convs) * skip_size, agg_size, depth_activation=True
            ),
            sepconv_bn(agg_size, agg_size, depth_activation=True),
        )
        # One classifier per task, all sharing the decoder features.
        self.clfs = nn.ModuleList(
            [
                conv1x1(agg_size, output_classes, bias=True)
                for output_classes in num_classes
            ]
        )

    @staticmethod
    def _aspp_concat_channels(agg_size, rates):
        """Return the channel count of the concatenated ASPP branch outputs."""
        # GAP branch + 1x1 branch + one branch per dilation rate,
        # each producing `agg_size` channels.
        return agg_size * (len(rates) + 2)

    @staticmethod
    def _conv_bn_relu(in_size, out_size):
        """Return the layers of a bias-free 1x1 conv -> batchnorm -> ReLU stack."""
        return [
            conv1x1(in_size, out_size, bias=False),
            batchnorm(out_size),
            nn.ReLU(inplace=False),
        ]

    def forward(self, xs):
        """Run the shared decoder and every task head.

        Args:
            xs (tensor, or list of tensors): inputs in the same order as
                `input_sizes`; the last one feeds ASPP, the others are
                skip-connections.

        Returns:
            list with one output per entry of `num_classes`, in the same
            order, each computed from the same decoder features.
        """
        xs = make_list(xs)
        aspp_input = xs[-1]
        skips = [conv(x) for conv, x in zip(self.skip_convs, xs[:-1])]
        # Every ASPP entry except the last (fusion) one is a parallel branch.
        aspp = [branch(aspp_input) for branch in self.aspp[:-1]]
        # Upsample GAP: the pooled branch is 1x1 and has to match the others.
        aspp[0] = F.interpolate(
            aspp[0], size=aspp_input.size()[2:], mode="bilinear", align_corners=True
        )
        aspp = torch.cat(aspp, dim=1)
        # Apply last conv in ASPP
        aspp = self.aspp[-1](aspp)
        # Connect with skip-connections
        if skips:
            # The first skip defines the decoder resolution.
            dec = [skips[0]]
            target_size = skips[0].size()[2:]
            for x in skips[1:] + [aspp]:
                dec.append(
                    F.interpolate(
                        x, size=target_size, mode="bilinear", align_corners=True
                    )
                )
        else:
            dec = [aspp]
        dec = torch.cat(dec, dim=1)
        dec = self.dec(dec)
        return [clf(dec) for clf in self.clfs]
class LWRefineNet(nn.Module):
    """Light-Weight RefineNet for Semantic Image Segmentation.

    More information about the model: https://arxiv.org/abs/1810.03272

    Args:
        input_sizes (int, or list): number of channels for each input.
        combine_layers (list): which input layers should be united together
            (via element-wise summation) before CRP. Indices refer to the
            inputs after they have been reversed.
        num_classes (int): number of output channels.
        agg_size (int): common filter size.
        n_crp (int): number of CRP layers in a single CRP block.
    """

    def __init__(self, input_sizes, combine_layers, num_classes, agg_size=256, n_crp=4):
        super().__init__()
        # Information is recovered starting from the last input, hence the flip.
        reversed_sizes = list(reversed(make_list(input_sizes)))
        # The combine indices are stored as given (no flip).
        self.combine_layers = make_list(combine_layers)

        # One 1x1 projection to `agg_size` channels per input.
        self.stem_convs = nn.ModuleList(
            [conv1x1(channels, agg_size, bias=False) for channels in reversed_sizes]
        )

        # One CRP block and one adaptation conv per combined group.
        crps, adapts = nn.ModuleList(), nn.ModuleList()
        for _ in self.combine_layers:
            crps.append(self._make_crp(agg_size, agg_size, n_crp))
            adapts.append(conv1x1(agg_size, agg_size, bias=False))
        self.crp_blocks = crps
        # The final group is never adapted, so its conv is left out.
        self.adapt_convs = adapts[:-1]

        self.segm = conv3x3(agg_size, num_classes, bias=True)
        self.relu = nn.ReLU6(inplace=True)

    def forward(self, xs):
        """Refine the inputs group by group and predict the output map.

        Args:
            xs (tensor, or list of tensors): inputs in the same order as
                `input_sizes`.

        Returns:
            Output of the `segm` convolution applied to the features of the
            last combined group.
        """
        feats = list(reversed(make_list(xs)))
        # Project the inputs that have a matching stem conv, in place.
        projected = [stem(feat) for stem, feat in zip(self.stem_convs, feats)]
        feats[: len(projected)] = projected

        # Collapse layers: element-wise sum of the inputs in every group.
        groups = [
            sum(feats[member] for member in make_list(members))
            for members in self.combine_layers
        ]

        last = len(groups) - 1
        for step, (crp_block, group) in enumerate(zip(self.crp_blocks, groups)):
            # From the second group on, merge with the previous stage's output.
            merged = group + y if step else group
            y = crp_block(self.relu(merged))
            if step < last:
                # Adapt and resize to match the next group before merging.
                y = F.interpolate(
                    self.adapt_convs[step](y),
                    size=groups[step + 1].size()[2:],
                    mode="bilinear",
                    align_corners=True,
                )
        return self.segm(y)

    @staticmethod
    def _make_crp(in_planes, out_planes, stages):
        """Creating Light-Weight Chained Residual Pooling (CRP) block.

        Args:
            in_planes (int): number of input channels.
            out_planes (int): number of output channels.
            stages (int): number of times the design is repeated
                (with new weights)

        Returns:
            `nn.Sequential` wrapping a single `CRPBlock`.
        """
        return nn.Sequential(CRPBlock(in_planes, out_planes, stages))
class MTLWRefineNet(nn.Module):
    """Multi-Task Light-Weight RefineNet for Dense per-pixel tasks.

    More information about the model: https://arxiv.org/abs/1809.04766

    Args:
        input_sizes (int, or list): number of channels for each input.
        combine_layers (list): which input layers should be united together
            (via element-wise summation) before CRP. Indices refer to the
            inputs after they have been reversed.
        num_classes (int or list): number of output channels per each head.
        agg_size (int): common filter size.
        n_crp (int): number of CRP layers in a single CRP block.
        **kwargs: accepted and not used.
    """

    def __init__(
        self, input_sizes, combine_layers, num_classes, agg_size=256, n_crp=4, **kwargs
    ):
        super().__init__()
        # Information is recovered starting from the last input, hence the flip.
        reversed_sizes = list(reversed(make_list(input_sizes)))
        # The combine indices are stored as given (no flip).
        self.combine_layers = make_list(combine_layers)
        # Only the last CRP block is built with the `groups` flag switched on.
        grouped_flags = [False for _ in self.combine_layers]
        grouped_flags[-1] = True

        # One 1x1 projection to `agg_size` channels per input.
        self.stem_convs = nn.ModuleList(
            [conv1x1(channels, agg_size, bias=False) for channels in reversed_sizes]
        )

        # One CRP block and one adaptation conv per combined group.
        crps, adapts = nn.ModuleList(), nn.ModuleList()
        for use_groups in grouped_flags:
            crps.append(self._make_crp(agg_size, agg_size, n_crp, use_groups))
            adapts.append(conv1x1(agg_size, agg_size, bias=False))
        self.crp_blocks = crps
        # The final group is never adapted, so its conv is left out.
        self.adapt_convs = adapts[:-1]

        # One head per task: channel-wise 1x1 conv -> ReLU6 -> 3x3 conv.
        self.heads = nn.ModuleList(
            [
                nn.Sequential(
                    conv1x1(agg_size, agg_size, groups=agg_size, bias=False),
                    nn.ReLU6(inplace=False),
                    conv3x3(agg_size, n_out, bias=True),
                )
                for n_out in make_list(num_classes)
            ]
        )
        self.relu = nn.ReLU6(inplace=True)

    def forward(self, xs):
        """Refine the inputs group by group and run every task head.

        Args:
            xs (tensor, or list of tensors): inputs in the same order as
                `input_sizes`.

        Returns:
            list with one output per head, all computed from the features of
            the last combined group.
        """
        feats = list(reversed(make_list(xs)))
        # Project the inputs that have a matching stem conv, in place.
        projected = [stem(feat) for stem, feat in zip(self.stem_convs, feats)]
        feats[: len(projected)] = projected

        # Collapse layers: element-wise sum of the inputs in every group.
        groups = [
            sum(feats[member] for member in make_list(members))
            for members in self.combine_layers
        ]

        last = len(groups) - 1
        for step, (crp_block, group) in enumerate(zip(self.crp_blocks, groups)):
            # From the second group on, merge with the previous stage's output.
            merged = group + y if step else group
            y = crp_block(self.relu(merged))
            if step < last:
                # Adapt and resize to match the next group before merging.
                y = F.interpolate(
                    self.adapt_convs[step](y),
                    size=groups[step + 1].size()[2:],
                    mode="bilinear",
                    align_corners=True,
                )
        return [head(y) for head in self.heads]

    @staticmethod
    def _make_crp(in_planes, out_planes, stages, groups):
        """Creating Light-Weight Chained Residual Pooling (CRP) block.

        Args:
            in_planes (int): number of input channels.
            out_planes (int): number of output channels.
            stages (int): number of times the design is repeated
                (with new weights)
            groups (bool): whether to do groupwise convolution inside CRP.

        Returns:
            `nn.Sequential` wrapping a single `CRPBlock`.
        """
        return nn.Sequential(CRPBlock(in_planes, out_planes, stages, groups))