Commit c6f52aa1 authored by Vaanathi Sundaresan's avatar Vaanathi Sundaresan
Browse files

Update truenet_tumseg_model_utils.py

parent 92490297
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
#=========================================================================================
# Truenet model utility functions
# Vaanathi Sundaresan
# 09-03-2021, Oxford
#=========================================================================================
class SingleConv(nn.Module):
"""(convolution => [BN] => ReLU)"""
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, kernelsize, name, mid_channels=None):
def __init__(self, in_channels, out_channels, kernelsize, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.single_conv = nn.Sequential(
OrderedDict([(
name+"conv", nn.Conv2d(in_channels, mid_channels, kernel_size=kernelsize, padding=1)),
(name+"bn", nn.BatchNorm2d(mid_channels)),
(name+"relu", nn.ReLU(inplace=True)),])
nn.Conv2d(in_channels, mid_channels, kernel_size=kernelsize, padding=1),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
)
def forward(self, x):
......@@ -34,19 +18,18 @@ class SingleConv(nn.Module):
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, kernelsize, name, mid_channels=None):
def __init__(self, in_channels, out_channels, kernelsize, mid_channels=None):
super().__init__()
pad = (kernelsize - 1)//2
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
OrderedDict([(
name+"conv1", nn.Conv2d(in_channels, mid_channels, kernel_size=kernelsize, padding=pad)),
(name+"bn1", nn.BatchNorm2d(mid_channels)),
(name+"relu1", nn.ReLU(inplace=True)),
(name+"conv2", nn.Conv2d(mid_channels, out_channels, kernel_size=kernelsize, padding=pad)),
(name+"bn2", nn.BatchNorm2d(out_channels)),
(name+"relu2", nn.ReLU(inplace=True)),])
nn.Conv2d(in_channels, mid_channels, kernel_size=kernelsize, padding=pad),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=kernelsize, padding=pad),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
......@@ -56,19 +39,11 @@ class DoubleConv(nn.Module):
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels, kernel_size, name):
def __init__(self, in_channels, out_channels, kernel_size):
super().__init__()
pad = (kernel_size - 1)//2
mid_channels = out_channels
self.maxpool_conv = nn.Sequential(
OrderedDict([
(name+"maxpool", nn.MaxPool2d(2)),
(name+"conv1", nn.Conv2d(in_channels, mid_channels, kernel_size=kernel_size, padding=pad)),
(name+"bn1", nn.BatchNorm2d(mid_channels)),
(name+"relu1", nn.ReLU(inplace=True)),
(name+"conv2", nn.Conv2d(mid_channels, out_channels, kernel_size=kernel_size, padding=pad)),
(name+"bn2", nn.BatchNorm2d(out_channels)),
(name+"relu2", nn.ReLU(inplace=True)),])
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels, kernel_size)
)
def forward(self, x):
......@@ -78,16 +53,16 @@ class Down(nn.Module):
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, kernel_size, name, bilinear=True):
def __init__(self, in_channels, out_channels, kernel_size, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2, name)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=kernel_size, stride=2)
self.conv = DoubleConv(in_channels, out_channels, 3, name)
self.conv = DoubleConv(in_channels, out_channels, 3)
def forward(self, x1, x2):
......@@ -106,21 +81,14 @@ class Up(nn.Module):
class OutConv(nn.Module):
"""convolution"""
def __init__(self, in_channels, out_channels, name):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Sequential(
OrderedDict([(
name+"conv", nn.Conv2d(in_channels, out_channels, kernel_size=1)), ])
)
self.conv1 = nn.Sequential(
OrderedDict([(
name + "conv1", nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=1)), ])
)
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=1)
def forward(self, x):
try:
return self.conv(x)
except:
return self.conv1(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment