Credits: https://git.lcsr.jhu.edu/cjones96/basic-unet-segmentation

18. U-Net Convolutional Networks for Image Segmentation#

18.1. U-Net#

U-Net is a widely used deep learning architecture that was first introduced in the “U-Net: Convolutional Networks for Biomedical Image Segmentation” paper. The primary purpose of this architecture was to address the challenge of limited annotated data in the medical field. This network was designed to effectively leverage a smaller amount of data while maintaining speed and accuracy.

The architecture of U-Net is unique in that it consists of a contracting path and an expansive path. The contracting path contains encoder layers that capture contextual information and reduce the spatial resolution of the input, while the expansive path contains decoder layers that decode the encoded data and use the information from the contracting path via skip connections to generate a segmentation map.

The contracting path in U-Net is responsible for identifying the relevant features in the input image. The encoder layers perform convolutional operations that reduce the spatial resolution of the feature maps while increasing their depth, thereby capturing increasingly abstract representations of the input; in this respect the contracting path resembles a conventional convolutional feature extractor. The expansive path, on the other hand, decodes the encoded data and localizes the features while progressively restoring spatial resolution. The decoder layers in the expansive path upsample the feature maps and also perform convolutional operations. The skip connections from the contracting path reintroduce the spatial detail that is lost during downsampling, which helps the decoder layers localize the features more accurately.
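To make the interplay between the contracting path, the expansive path, and the skip connections concrete, below is a minimal two-level sketch in PyTorch. It is not the paper's exact configuration (the channel counts are small and the convolutions are padded for brevity); it only illustrates how an encoder feature map is carried across to the decoder and concatenated.

import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    # A minimal two-level encoder/decoder with a single skip connection.
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(nn.Conv2d(1, 16, 3, padding=1), nn.ReLU())   # contracting: extract features
        self.pool = nn.MaxPool2d(2)                                           # halve the spatial resolution
        self.bottom = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU())
        self.up = nn.ConvTranspose2d(32, 16, 2, stride=2)                     # expansive: upsample
        self.dec = nn.Sequential(nn.Conv2d(32, 16, 3, padding=1), nn.ReLU())  # 32 = 16 (skip) + 16 (upsampled)
        self.head = nn.Conv2d(16, 2, 1)                                       # per-pixel class scores

    def forward(self, x):
        e = self.enc(x)                   # high-resolution encoder features
        b = self.bottom(self.pool(e))     # lower-resolution, more abstract features
        u = self.up(b)                    # back to the encoder's resolution
        u = torch.cat([u, e], dim=1)      # skip connection: reuse encoder detail
        return self.head(self.dec(u))

print(TinyUNet()(torch.zeros(1, 1, 64, 64)).shape)  # torch.Size([1, 2, 64, 64])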

[Figure: the U-Net architecture, with the contracting path on the left and the expansive path on the right]

Source: https://arxiv.org/pdf/1505.04597

18.2. Explanation#

Credits: see blog by A. Ito Armendia

The Contracting Path of U-Net

[Figure: the contracting path of U-Net]

Block 1

An input image with dimensions 572² is fed into the U-Net. This input image consists of only 1 channel, likely a grayscale channel. Two 3x3 convolution layers (unpadded) are then applied to the input image, each followed by a ReLU layer, shrinking the spatial dimensions from 572² to 568². At the same time the number of channels is increased to 64 in order to capture higher-level features. A 2x2 max pooling layer with a stride of 2 is then applied, downsampling the feature map to half its size, 284².
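These sizes can be verified directly. A minimal sketch in PyTorch, assuming the paper's unpadded 3x3 convolutions and a single-channel input:

import torch
import torch.nn as nn

# Block 1 of the contracting path: two unpadded 3x3 convolutions (572 -> 570 -> 568),
# then 2x2 max pooling with stride 2 (568 -> 284), with the channels increased to 64.
block1 = nn.Sequential(
    nn.Conv2d(1, 64, 3), nn.ReLU(),
    nn.Conv2d(64, 64, 3), nn.ReLU(),
    nn.MaxPool2d(2, stride=2),
)
print(block1(torch.zeros(1, 1, 572, 572)).shape)  # torch.Size([1, 64, 284, 284])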

Block 2

Just like in block 1, two 3x3 convolution layers (unpadded) are applied to the output of block 1, each followed again by a ReLU layer, shrinking the feature map to 280². At each new block the number of feature channels is doubled, here to 128. Next a 2x2 max pooling layer is again applied, halving the spatial dimensions to 140².

Block 3

Block 3 follows the same procedure as blocks 1 and 2, so it is not repeated here.

Block 4

Same as block 3.

Block 5

In the final block of the contracting path, the number of feature channels reaches 1024 after being doubled at each block. This block also contains two 3x3 convolution layers (unpadded), each followed by a ReLU layer. However, for symmetry, I have included only the first of these layers here and placed the second in the expanding path. After complex features and patterns have been extracted, the feature map moves on to the expanding path.
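The size arithmetic across all five blocks can be sketched with a few lines of Python (counting both bottleneck convolutions in block 5 for simplicity): each unpadded 3x3 convolution removes 2 pixels per spatial dimension, each 2x2 max pooling halves the resolution, and the bottleneck block is not pooled.

size = 572
for block, channels in enumerate([64, 128, 256, 512, 1024], start=1):
    size -= 2 * 2                 # two unpadded 3x3 convolutions, each removing 2 pixels
    print(f"block {block}: {channels} x {size} x {size}")
    if block < 5:
        size //= 2                # 2x2 max pooling halves the resolution
# block 1: 64 x 568 x 568 ... block 5: 1024 x 28 x 28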

The Expanding Path

[Figure: the expanding path of U-Net]

Block 5

  1. Continuing on from the contracting path, a second 3x3 convolution (unpadded) is applied with a ReLU layer after it.

  2. Then a 2x2 convolution (up-convolution) layer is applied, upsampling the spatial dimensions twofold and also halving the number of channels to 512.
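A quick check of the up-convolution's effect, assuming PyTorch's ConvTranspose2d and the paper's bottleneck size of 1024 x 28 x 28:

import torch
import torch.nn as nn

up = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)  # 2x2 up-convolution
print(up(torch.zeros(1, 1024, 28, 28)).shape)                # torch.Size([1, 512, 56, 56])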

Block 4

  1. Using skip connections, the corresponding feature map from the contracting path is concatenated, doubling the feature channels to 1024. Note that the contracting-path feature map must be center-cropped to match the expanding path's spatial dimensions (see the cropping sketch after this list).

  2. Two 3x3 convolution layers (unpadded) are applied, each with a ReLU layer following, reducing the channels to 512.

  3. After, a 2x2 convolution (up-convolution) layer is applied, upsampling the spatial dimensions twofold and also halving the number of channels to 256.
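Because the unpadded convolutions make the contracting-path feature maps larger than the upsampled ones, the skip connection has to be center-cropped before concatenation. A minimal sketch of this step, using the paper's block 4 sizes as an example (the center_crop helper is hypothetical, not part of the original notebook):

import torch

def center_crop(feat, target_h, target_w):
    # Center-crop a (N, C, H, W) feature map to the target spatial size.
    _, _, h, w = feat.shape
    top, left = (h - target_h) // 2, (w - target_w) // 2
    return feat[:, :, top:top + target_h, left:left + target_w]

skip = torch.zeros(1, 512, 64, 64)   # block 4 feature map from the contracting path
up = torch.zeros(1, 512, 56, 56)     # upsampled feature map from the expanding path
merged = torch.cat([center_crop(skip, 56, 56), up], dim=1)
print(merged.shape)                  # torch.Size([1, 1024, 56, 56])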

Block 3

Block 3 follows the same procedure as blocks 5 and 4, so it is not repeated here.

Block 2

Same as block 3.

Block 1

  1. In the final block of the expanding path, there are 128 channels after concatenating the skip connection.

  2. Next, two 3x3 convolution layers (unpadded) are applied to the feature map, each followed by a ReLU layer, reducing the number of feature channels to 64.

  3. Finally, a 1x1 convolution layer, followed by an activation layer (sigmoid for binary classification), is used to reduce the number of channels to the desired number of classes; in this case 2, since binary classification is common in medical imaging.

After upsampling the feature map in the expanding path, a segmentation map should be generated, with each pixel classified individually.
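Producing that map is a per-pixel classification over the output channels. A small sketch, assuming a 2-channel output at the paper's final resolution of 388²:

import torch

logits = torch.randn(1, 2, 388, 388)     # raw per-pixel class scores from the final 1x1 convolution
probs = torch.softmax(logits, dim=1)     # per-pixel class probabilities (the two channels sum to 1)
mask = logits.argmax(dim=1)              # per-pixel predicted class label
print(mask.shape)                        # torch.Size([1, 388, 388])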


N.b.

i. Skip connections include cropping.

ii. The up-convolutions halve the number of channels.

iii. The final output has a slightly lower resolution, i.e., reduced dimensionality, compared to the input (388² versus 572² in the original paper), as a result of the unpadded convolutions and cropping. The final number of channels is 2, since the problem at hand has two classes.
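For completeness, the matching bookkeeping for the expanding path, again using the paper's sizes: each 2x2 up-convolution doubles the resolution and each pair of unpadded 3x3 convolutions removes 4 pixels.

size = 28                        # bottleneck resolution in the paper
for block in [4, 3, 2, 1]:
    size = size * 2 - 2 * 2      # up-convolution doubles, two unpadded 3x3 convolutions remove 4
    print(f"expanding block {block}: {size} x {size}")
# expanding block 1 ends at 388 x 388, the reduced output resolution mentioned in (iii)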


!pip install numpy matplotlib torch torchvision Pillow
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.26.4)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (3.7.1)
Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.4.1+cu121)
Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.19.1+cu121)
Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (10.4.0)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.3.0)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (4.54.1)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.4.7)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (24.1)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (3.1.4)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (2.8.2)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)
Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.3)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)
!wget https://git.lcsr.jhu.edu/cjones96/basic-unet-segmentation/-/raw/master/train.zip -O train.zip
--2024-10-08 17:12:31--  https://git.lcsr.jhu.edu/cjones96/basic-unet-segmentation/-/raw/master/train.zip
Resolving git.lcsr.jhu.edu (git.lcsr.jhu.edu)... 128.220.253.212
Connecting to git.lcsr.jhu.edu (git.lcsr.jhu.edu)|128.220.253.212|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 43273606 (41M) [application/octet-stream]
Saving to: ‘train.zip.1’

train.zip.1         100%[===================>]  41.27M  9.59MB/s    in 4.7s    

2024-10-08 17:12:37 (8.73 MB/s) - ‘train.zip.1’ saved [43273606/43273606]
!ls
sample_data  train.zip
!file train.zip
train.zip: Zip archive data, at least v1.0 to extract, compression method=store
!(cd /content && unzip train.zip)
Archive:  train.zip
replace train/eefc0d8c94f0_08.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: 
!wget https://git.lcsr.jhu.edu/cjones96/basic-unet-segmentation/-/raw/master/train_masks.zip -O train_masks.zip
--2024-10-08 16:01:53--  https://git.lcsr.jhu.edu/cjones96/basic-unet-segmentation/-/raw/master/train_masks.zip
Resolving git.lcsr.jhu.edu (git.lcsr.jhu.edu)... 128.220.253.212
Connecting to git.lcsr.jhu.edu (git.lcsr.jhu.edu)|128.220.253.212|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3119864 (3.0M) [application/octet-stream]
Saving to: ‘train_masks.zip’

train_masks.zip     100%[===================>]   2.97M  1.72MB/s    in 1.7s    

2024-10-08 16:01:56 (1.72 MB/s) - ‘train_masks.zip’ saved [3119864/3119864]
!(cd /content && unzip train_masks.zip)
Archive:  train_masks.zip
   creating: train_masks/
 extracting: train_masks/efaef69e148d_13_mask.gif  
 extracting: train_masks/f707d6fbc0cd_09_mask.gif  
 extracting: train_masks/efaef69e148d_12_mask.gif  
 extracting: train_masks/f707d6fbc0cd_08_mask.gif  
 extracting: train_masks/fecea3036c59_14_mask.gif  
 extracting: train_masks/fecea3036c59_15_mask.gif  
 extracting: train_masks/eeb7eeca738e_16_mask.gif  
 extracting: train_masks/ef5567efd904_14_mask.gif  
 extracting: train_masks/ef5567efd904_15_mask.gif  
 extracting: train_masks/eb91b1c659a0_07_mask.gif  
 extracting: train_masks/eb91b1c659a0_06_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_14_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_15_mask.gif  
 extracting: train_masks/feaf59172a01_14_mask.gif  
  inflating: train_masks/feaf59172a01_15_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_16_mask.gif  
  inflating: train_masks/eb07e3f63ad2_01_mask.gif  
  inflating: train_masks/f4cd1286d5f4_13_mask.gif  
 extracting: train_masks/f4cd1286d5f4_12_mask.gif  
 extracting: train_masks/fdc2c87853ce_05_mask.gif  
 extracting: train_masks/fdc2c87853ce_04_mask.gif  
 extracting: train_masks/f00905abd3d7_07_mask.gif  
 extracting: train_masks/f00905abd3d7_06_mask.gif  
 extracting: train_masks/f8b6f4c39204_09_mask.gif  
 extracting: train_masks/f8b6f4c39204_08_mask.gif  
 extracting: train_masks/fa613ac8eac5_10_mask.gif  
 extracting: train_masks/fa613ac8eac5_11_mask.gif  
 extracting: train_masks/f70052627830_03_mask.gif  
 extracting: train_masks/f70052627830_02_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_07_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_06_mask.gif  
 extracting: train_masks/fc237174b128_01_mask.gif  
 extracting: train_masks/f3eee6348205_01_mask.gif  
  inflating: train_masks/f7ad86e13ed7_05_mask.gif  
 extracting: train_masks/f7ad86e13ed7_04_mask.gif  
 extracting: train_masks/f1eb080c7182_01_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_04_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_05_mask.gif  
 extracting: train_masks/ed8472086df8_05_mask.gif  
 extracting: train_masks/ed8472086df8_04_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_12_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_13_mask.gif  
  inflating: train_masks/fff9b3a5373f_01_mask.gif  
 extracting: train_masks/f707d6fbc0cd_03_mask.gif  
 extracting: train_masks/f707d6fbc0cd_02_mask.gif  
 extracting: train_masks/f70052627830_09_mask.gif  
 extracting: train_masks/f70052627830_08_mask.gif  
 extracting: train_masks/f98dbe8a5ee2_01_mask.gif  
 extracting: train_masks/f591b4f2e006_10_mask.gif  
  inflating: train_masks/f591b4f2e006_11_mask.gif  
 extracting: train_masks/f3b482e091c0_05_mask.gif  
 extracting: train_masks/f8b6f4c39204_03_mask.gif  
 extracting: train_masks/f3b482e091c0_04_mask.gif  
 extracting: train_masks/f8b6f4c39204_02_mask.gif  
 extracting: train_masks/fb1b923dd978_07_mask.gif  
  inflating: train_masks/eefc0d8c94f0_11_mask.gif  
 extracting: train_masks/fb1b923dd978_06_mask.gif  
  inflating: train_masks/eefc0d8c94f0_10_mask.gif  
 extracting: train_masks/fa006be8b6d9_14_mask.gif  
 extracting: train_masks/eaf9eb0b2293_14_mask.gif  
 extracting: train_masks/fa006be8b6d9_15_mask.gif  
 extracting: train_masks/eaf9eb0b2293_15_mask.gif  
  inflating: train_masks/feaf59172a01_03_mask.gif  
 extracting: train_masks/feaf59172a01_02_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_03_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_02_mask.gif  
 extracting: train_masks/ef5567efd904_03_mask.gif  
 extracting: train_masks/ef5567efd904_02_mask.gif  
 extracting: train_masks/eb91b1c659a0_10_mask.gif  
 extracting: train_masks/eb91b1c659a0_11_mask.gif  
 extracting: train_masks/fecea3036c59_03_mask.gif  
 extracting: train_masks/fecea3036c59_02_mask.gif  
 extracting: train_masks/eeb7eeca738e_01_mask.gif  
 extracting: train_masks/efaef69e148d_04_mask.gif  
 extracting: train_masks/efaef69e148d_05_mask.gif  
 extracting: train_masks/f70052627830_14_mask.gif  
 extracting: train_masks/f70052627830_15_mask.gif  
 extracting: train_masks/fa613ac8eac5_07_mask.gif  
 extracting: train_masks/fa613ac8eac5_06_mask.gif  
 extracting: train_masks/f00905abd3d7_10_mask.gif  
 extracting: train_masks/f00905abd3d7_11_mask.gif  
 extracting: train_masks/fa006be8b6d9_09_mask.gif  
 extracting: train_masks/fa006be8b6d9_08_mask.gif  
  inflating: train_masks/eb07e3f63ad2_16_mask.gif  
 extracting: train_masks/f4cd1286d5f4_04_mask.gif  
 extracting: train_masks/f4cd1286d5f4_05_mask.gif  
 extracting: train_masks/fdc2c87853ce_12_mask.gif  
 extracting: train_masks/fdc2c87853ce_13_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_01_mask.gif  
 extracting: train_masks/fff9b3a5373f_16_mask.gif  
 extracting: train_masks/f707d6fbc0cd_14_mask.gif  
 extracting: train_masks/f707d6fbc0cd_15_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_13_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_12_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_05_mask.gif  
 extracting: train_masks/fecea3036c59_09_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_04_mask.gif  
 extracting: train_masks/fecea3036c59_08_mask.gif  
 extracting: train_masks/ed8472086df8_12_mask.gif  
 extracting: train_masks/ed8472086df8_13_mask.gif  
 extracting: train_masks/f7ad86e13ed7_12_mask.gif  
 extracting: train_masks/f7ad86e13ed7_13_mask.gif  
 extracting: train_masks/f3eee6348205_16_mask.gif  
 extracting: train_masks/ef5567efd904_09_mask.gif  
 extracting: train_masks/ef5567efd904_08_mask.gif  
 extracting: train_masks/f1eb080c7182_16_mask.gif  
 extracting: train_masks/fc237174b128_16_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_10_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_11_mask.gif  
 extracting: train_masks/feaf59172a01_09_mask.gif  
 extracting: train_masks/feaf59172a01_08_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_09_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_08_mask.gif  
  inflating: train_masks/eefc0d8c94f0_06_mask.gif  
 extracting: train_masks/fb1b923dd978_10_mask.gif  
  inflating: train_masks/eefc0d8c94f0_07_mask.gif  
  inflating: train_masks/fb1b923dd978_11_mask.gif  
 extracting: train_masks/fa006be8b6d9_03_mask.gif  
 extracting: train_masks/fa006be8b6d9_02_mask.gif  
 extracting: train_masks/f98dbe8a5ee2_16_mask.gif  
  inflating: train_masks/f591b4f2e006_07_mask.gif  
 extracting: train_masks/f591b4f2e006_06_mask.gif  
 extracting: train_masks/f3b482e091c0_12_mask.gif  
 extracting: train_masks/f8b6f4c39204_14_mask.gif  
 extracting: train_masks/f3b482e091c0_13_mask.gif  
 extracting: train_masks/f8b6f4c39204_15_mask.gif  
 extracting: train_masks/ef5567efd904_04_mask.gif  
 extracting: train_masks/ef5567efd904_05_mask.gif  
 extracting: train_masks/eb91b1c659a0_16_mask.gif  
  inflating: train_masks/feaf59172a01_04_mask.gif  
 extracting: train_masks/feaf59172a01_05_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_04_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_05_mask.gif  
 extracting: train_masks/efaef69e148d_03_mask.gif  
 extracting: train_masks/efaef69e148d_02_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_08_mask.gif  
 extracting: train_masks/fecea3036c59_04_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_09_mask.gif  
 extracting: train_masks/fecea3036c59_05_mask.gif  
 extracting: train_masks/eeb7eeca738e_06_mask.gif  
 extracting: train_masks/eeb7eeca738e_07_mask.gif  
 extracting: train_masks/f00905abd3d7_16_mask.gif  
  inflating: train_masks/f70052627830_13_mask.gif  
  inflating: train_masks/f70052627830_12_mask.gif  
 extracting: train_masks/fa613ac8eac5_01_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_06_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_07_mask.gif  
  inflating: train_masks/eb07e3f63ad2_11_mask.gif  
 extracting: train_masks/f4cd1286d5f4_03_mask.gif  
  inflating: train_masks/eb07e3f63ad2_10_mask.gif  
 extracting: train_masks/f4cd1286d5f4_02_mask.gif  
 extracting: train_masks/fdc2c87853ce_15_mask.gif  
 extracting: train_masks/fdc2c87853ce_14_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_14_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_15_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_02_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_03_mask.gif  
 extracting: train_masks/ed8472086df8_15_mask.gif  
 extracting: train_masks/ed8472086df8_14_mask.gif  
 extracting: train_masks/fff9b3a5373f_11_mask.gif  
 extracting: train_masks/fff9b3a5373f_10_mask.gif  
 extracting: train_masks/efaef69e148d_09_mask.gif  
 extracting: train_masks/f707d6fbc0cd_13_mask.gif  
 extracting: train_masks/efaef69e148d_08_mask.gif  
 extracting: train_masks/f707d6fbc0cd_12_mask.gif  
 extracting: train_masks/fc237174b128_11_mask.gif  
 extracting: train_masks/fc237174b128_10_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_16_mask.gif  
 extracting: train_masks/f7ad86e13ed7_15_mask.gif  
 extracting: train_masks/f7ad86e13ed7_14_mask.gif  
 extracting: train_masks/f3eee6348205_11_mask.gif  
 extracting: train_masks/f3eee6348205_10_mask.gif  
 extracting: train_masks/f1eb080c7182_10_mask.gif  
 extracting: train_masks/f1eb080c7182_11_mask.gif  
  inflating: train_masks/eefc0d8c94f0_01_mask.gif  
 extracting: train_masks/fb1b923dd978_16_mask.gif  
 extracting: train_masks/fa006be8b6d9_04_mask.gif  
  inflating: train_masks/fa006be8b6d9_05_mask.gif  
 extracting: train_masks/f4cd1286d5f4_09_mask.gif  
 extracting: train_masks/f4cd1286d5f4_08_mask.gif  
 extracting: train_masks/f98dbe8a5ee2_11_mask.gif  
  inflating: train_masks/f98dbe8a5ee2_10_mask.gif  
 extracting: train_masks/f591b4f2e006_01_mask.gif  
 extracting: train_masks/f3b482e091c0_15_mask.gif  
 extracting: train_masks/f8b6f4c39204_13_mask.gif  
 extracting: train_masks/f3b482e091c0_14_mask.gif  
  inflating: train_masks/f8b6f4c39204_12_mask.gif  
 extracting: train_masks/ed8472086df8_08_mask.gif  
 extracting: train_masks/ed8472086df8_09_mask.gif  
 extracting: train_masks/fecea3036c59_13_mask.gif  
 extracting: train_masks/fecea3036c59_12_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_09_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_08_mask.gif  
  inflating: train_masks/eeb7eeca738e_11_mask.gif  
 extracting: train_masks/eeb7eeca738e_10_mask.gif  
 extracting: train_masks/efaef69e148d_14_mask.gif  
 extracting: train_masks/efaef69e148d_15_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_13_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_12_mask.gif  
 extracting: train_masks/feaf59172a01_13_mask.gif  
 extracting: train_masks/feaf59172a01_12_mask.gif  
 extracting: train_masks/ef5567efd904_13_mask.gif  
 extracting: train_masks/ef5567efd904_12_mask.gif  
 extracting: train_masks/f7ad86e13ed7_08_mask.gif  
 extracting: train_masks/eb91b1c659a0_01_mask.gif  
 extracting: train_masks/f7ad86e13ed7_09_mask.gif  
  inflating: train_masks/eb07e3f63ad2_06_mask.gif  
 extracting: train_masks/f4cd1286d5f4_14_mask.gif  
  inflating: train_masks/eb07e3f63ad2_07_mask.gif  
 extracting: train_masks/f4cd1286d5f4_15_mask.gif  
 extracting: train_masks/fdc2c87853ce_02_mask.gif  
 extracting: train_masks/fdc2c87853ce_03_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_11_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_10_mask.gif  
 extracting: train_masks/fa613ac8eac5_16_mask.gif  
 extracting: train_masks/f70052627830_04_mask.gif  
  inflating: train_masks/f70052627830_05_mask.gif  
 extracting: train_masks/f00905abd3d7_01_mask.gif  
 extracting: train_masks/f3b482e091c0_08_mask.gif  
 extracting: train_masks/f3b482e091c0_09_mask.gif  
 extracting: train_masks/f3eee6348205_06_mask.gif  
 extracting: train_masks/f3eee6348205_07_mask.gif  
 extracting: train_masks/f7ad86e13ed7_02_mask.gif  
  inflating: train_masks/f7ad86e13ed7_03_mask.gif  
 extracting: train_masks/f1eb080c7182_07_mask.gif  
 extracting: train_masks/f1eb080c7182_06_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_01_mask.gif  
 extracting: train_masks/fc237174b128_06_mask.gif  
  inflating: train_masks/fc237174b128_07_mask.gif  
 extracting: train_masks/fff9b3a5373f_06_mask.gif  
 extracting: train_masks/fff9b3a5373f_07_mask.gif  
 extracting: train_masks/f707d6fbc0cd_04_mask.gif  
 extracting: train_masks/f707d6fbc0cd_05_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_03_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_02_mask.gif  
 extracting: train_masks/ed8472086df8_02_mask.gif  
 extracting: train_masks/ed8472086df8_03_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_15_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_14_mask.gif  
 extracting: train_masks/f98dbe8a5ee2_06_mask.gif  
 extracting: train_masks/f98dbe8a5ee2_07_mask.gif  
 extracting: train_masks/f591b4f2e006_16_mask.gif  
 extracting: train_masks/f3b482e091c0_02_mask.gif  
 extracting: train_masks/f8b6f4c39204_04_mask.gif  
 extracting: train_masks/f3b482e091c0_03_mask.gif  
  inflating: train_masks/f8b6f4c39204_05_mask.gif  
  inflating: train_masks/eefc0d8c94f0_16_mask.gif  
 extracting: train_masks/fb1b923dd978_01_mask.gif  
 extracting: train_masks/fdc2c87853ce_08_mask.gif  
 extracting: train_masks/fdc2c87853ce_09_mask.gif  
 extracting: train_masks/eaf9eb0b2293_13_mask.gif  
  inflating: train_masks/fa006be8b6d9_13_mask.gif  
  inflating: train_masks/fa006be8b6d9_12_mask.gif  
 extracting: train_masks/f707d6fbc0cd_16_mask.gif  
 extracting: train_masks/fff9b3a5373f_15_mask.gif  
 extracting: train_masks/fff9b3a5373f_14_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_06_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_07_mask.gif  
 extracting: train_masks/ed8472086df8_11_mask.gif  
 extracting: train_masks/ed8472086df8_10_mask.gif  
 extracting: train_masks/eeb7eeca738e_08_mask.gif  
 extracting: train_masks/eeb7eeca738e_09_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_10_mask.gif  
  inflating: train_masks/fce0ba5b8ed7_11_mask.gif  
 extracting: train_masks/f1eb080c7182_14_mask.gif  
 extracting: train_masks/f1eb080c7182_15_mask.gif  
 extracting: train_masks/f7ad86e13ed7_11_mask.gif  
 extracting: train_masks/f7ad86e13ed7_10_mask.gif  
 extracting: train_masks/f3eee6348205_15_mask.gif  
 extracting: train_masks/f3eee6348205_14_mask.gif  
 extracting: train_masks/fc237174b128_15_mask.gif  
 extracting: train_masks/fc237174b128_14_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_13_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_12_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_08_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_09_mask.gif  
 extracting: train_masks/fa006be8b6d9_01_mask.gif  
  inflating: train_masks/eefc0d8c94f0_05_mask.gif  
 extracting: train_masks/fb1b923dd978_13_mask.gif  
  inflating: train_masks/eefc0d8c94f0_04_mask.gif  
 extracting: train_masks/fb1b923dd978_12_mask.gif  
 extracting: train_masks/f3b482e091c0_11_mask.gif  
 extracting: train_masks/f8b6f4c39204_16_mask.gif  
 extracting: train_masks/f3b482e091c0_10_mask.gif  
 extracting: train_masks/f98dbe8a5ee2_15_mask.gif  
 extracting: train_masks/f591b4f2e006_04_mask.gif  
 extracting: train_masks/f591b4f2e006_05_mask.gif  
 extracting: train_masks/f98dbe8a5ee2_14_mask.gif  
 extracting: train_masks/feaf59172a01_01_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_01_mask.gif  
 extracting: train_masks/eb91b1c659a0_13_mask.gif  
 extracting: train_masks/eb91b1c659a0_12_mask.gif  
 extracting: train_masks/ef5567efd904_01_mask.gif  
 extracting: train_masks/eeb7eeca738e_02_mask.gif  
 extracting: train_masks/eeb7eeca738e_03_mask.gif  
 extracting: train_masks/fecea3036c59_01_mask.gif  
 extracting: train_masks/efaef69e148d_07_mask.gif  
 extracting: train_masks/efaef69e148d_06_mask.gif  
 extracting: train_masks/f70052627830_16_mask.gif  
 extracting: train_masks/fa613ac8eac5_04_mask.gif  
 extracting: train_masks/fa613ac8eac5_05_mask.gif  
 extracting: train_masks/f00905abd3d7_13_mask.gif  
 extracting: train_masks/f00905abd3d7_12_mask.gif  
 extracting: train_masks/fdc2c87853ce_11_mask.gif  
 extracting: train_masks/fdc2c87853ce_10_mask.gif  
 extracting: train_masks/f4cd1286d5f4_07_mask.gif  
  inflating: train_masks/eb07e3f63ad2_15_mask.gif  
 extracting: train_masks/f4cd1286d5f4_06_mask.gif  
  inflating: train_masks/eb07e3f63ad2_14_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_02_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_03_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_04_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_05_mask.gif  
  inflating: train_masks/fc237174b128_02_mask.gif  
 extracting: train_masks/fc237174b128_03_mask.gif  
 extracting: train_masks/f1eb080c7182_03_mask.gif  
 extracting: train_masks/f1eb080c7182_02_mask.gif  
 extracting: train_masks/f3eee6348205_02_mask.gif  
 extracting: train_masks/f3eee6348205_03_mask.gif  
 extracting: train_masks/f7ad86e13ed7_06_mask.gif  
 extracting: train_masks/f7ad86e13ed7_07_mask.gif  
 extracting: train_masks/ed8472086df8_06_mask.gif  
 extracting: train_masks/ed8472086df8_07_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_11_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_10_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_07_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_06_mask.gif  
 extracting: train_masks/f707d6fbc0cd_01_mask.gif  
 extracting: train_masks/fff9b3a5373f_02_mask.gif  
 extracting: train_masks/fff9b3a5373f_03_mask.gif  
 extracting: train_masks/f3b482e091c0_06_mask.gif  
  inflating: train_masks/f8b6f4c39204_01_mask.gif  
 extracting: train_masks/f3b482e091c0_07_mask.gif  
 extracting: train_masks/f98dbe8a5ee2_02_mask.gif  
  inflating: train_masks/f591b4f2e006_13_mask.gif  
 extracting: train_masks/f591b4f2e006_12_mask.gif  
 extracting: train_masks/f98dbe8a5ee2_03_mask.gif  
  inflating: train_masks/eb07e3f63ad2_08_mask.gif  
  inflating: train_masks/eb07e3f63ad2_09_mask.gif  
 extracting: train_masks/fa006be8b6d9_16_mask.gif  
 extracting: train_masks/eaf9eb0b2293_16_mask.gif  
 extracting: train_masks/fb1b923dd978_04_mask.gif  
  inflating: train_masks/eefc0d8c94f0_12_mask.gif  
 extracting: train_masks/fb1b923dd978_05_mask.gif  
  inflating: train_masks/eefc0d8c94f0_13_mask.gif  
 extracting: train_masks/fff9b3a5373f_08_mask.gif  
 extracting: train_masks/fff9b3a5373f_09_mask.gif  
 extracting: train_masks/efaef69e148d_10_mask.gif  
 extracting: train_masks/efaef69e148d_11_mask.gif  
 extracting: train_masks/eeb7eeca738e_15_mask.gif  
 extracting: train_masks/eeb7eeca738e_14_mask.gif  
 extracting: train_masks/fecea3036c59_16_mask.gif  
 extracting: train_masks/f3eee6348205_08_mask.gif  
 extracting: train_masks/f3eee6348205_09_mask.gif  
 extracting: train_masks/eb91b1c659a0_04_mask.gif  
 extracting: train_masks/eb91b1c659a0_05_mask.gif  
 extracting: train_masks/f1eb080c7182_09_mask.gif  
 extracting: train_masks/f1eb080c7182_08_mask.gif  
 extracting: train_masks/ef5567efd904_16_mask.gif  
 extracting: train_masks/fc237174b128_08_mask.gif  
 extracting: train_masks/fc237174b128_09_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_16_mask.gif  
 extracting: train_masks/feaf59172a01_16_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_15_mask.gif  
  inflating: train_masks/fc5f1a3a66cf_14_mask.gif  
 extracting: train_masks/fdc2c87853ce_06_mask.gif  
 extracting: train_masks/fdc2c87853ce_07_mask.gif  
 extracting: train_masks/f4cd1286d5f4_10_mask.gif  
  inflating: train_masks/eb07e3f63ad2_02_mask.gif  
 extracting: train_masks/f4cd1286d5f4_11_mask.gif  
  inflating: train_masks/eb07e3f63ad2_03_mask.gif  
  inflating: train_masks/f98dbe8a5ee2_08_mask.gif  
 extracting: train_masks/f98dbe8a5ee2_09_mask.gif  
 extracting: train_masks/f00905abd3d7_04_mask.gif  
 extracting: train_masks/f00905abd3d7_05_mask.gif  
 extracting: train_masks/fa613ac8eac5_13_mask.gif  
 extracting: train_masks/fa613ac8eac5_12_mask.gif  
 extracting: train_masks/f70052627830_01_mask.gif  
 extracting: train_masks/f1eb080c7182_04_mask.gif  
 extracting: train_masks/f1eb080c7182_05_mask.gif  
 extracting: train_masks/f3eee6348205_05_mask.gif  
 extracting: train_masks/f3eee6348205_04_mask.gif  
 extracting: train_masks/eb91b1c659a0_09_mask.gif  
 extracting: train_masks/f7ad86e13ed7_01_mask.gif  
 extracting: train_masks/eb91b1c659a0_08_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_03_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_02_mask.gif  
  inflating: train_masks/fc237174b128_05_mask.gif  
 extracting: train_masks/fc237174b128_04_mask.gif  
 extracting: train_masks/f707d6fbc0cd_07_mask.gif  
 extracting: train_masks/f707d6fbc0cd_06_mask.gif  
 extracting: train_masks/fff9b3a5373f_05_mask.gif  
 extracting: train_masks/fff9b3a5373f_04_mask.gif  
 extracting: train_masks/ed8472086df8_01_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_16_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_01_mask.gif  
 extracting: train_masks/f00905abd3d7_09_mask.gif  
 extracting: train_masks/f00905abd3d7_08_mask.gif  
  inflating: train_masks/f8b6f4c39204_07_mask.gif  
 extracting: train_masks/f3b482e091c0_01_mask.gif  
 extracting: train_masks/f8b6f4c39204_06_mask.gif  
 extracting: train_masks/f591b4f2e006_14_mask.gif  
 extracting: train_masks/f98dbe8a5ee2_05_mask.gif  
 extracting: train_masks/f98dbe8a5ee2_04_mask.gif  
 extracting: train_masks/f591b4f2e006_15_mask.gif  
 extracting: train_masks/fa006be8b6d9_10_mask.gif  
 extracting: train_masks/fa006be8b6d9_11_mask.gif  
  inflating: train_masks/eefc0d8c94f0_15_mask.gif  
 extracting: train_masks/fb1b923dd978_03_mask.gif  
  inflating: train_masks/eefc0d8c94f0_14_mask.gif  
 extracting: train_masks/fb1b923dd978_02_mask.gif  
 extracting: train_masks/eeb7eeca738e_12_mask.gif  
 extracting: train_masks/eeb7eeca738e_13_mask.gif  
 extracting: train_masks/fecea3036c59_10_mask.gif  
 extracting: train_masks/fecea3036c59_11_mask.gif  
 extracting: train_masks/efaef69e148d_16_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_09_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_08_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_10_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_11_mask.gif  
 extracting: train_masks/feaf59172a01_10_mask.gif  
 extracting: train_masks/feaf59172a01_11_mask.gif  
 extracting: train_masks/eb91b1c659a0_03_mask.gif  
 extracting: train_masks/eb91b1c659a0_02_mask.gif  
 extracting: train_masks/ef5567efd904_10_mask.gif  
 extracting: train_masks/ef5567efd904_11_mask.gif  
 extracting: train_masks/fb1b923dd978_09_mask.gif  
 extracting: train_masks/fb1b923dd978_08_mask.gif  
 extracting: train_masks/fdc2c87853ce_01_mask.gif  
  inflating: train_masks/eb07e3f63ad2_05_mask.gif  
 extracting: train_masks/f4cd1286d5f4_16_mask.gif  
  inflating: train_masks/eb07e3f63ad2_04_mask.gif  
  inflating: train_masks/fc5f1a3a66cf_12_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_13_mask.gif  
 extracting: train_masks/fa613ac8eac5_14_mask.gif  
 extracting: train_masks/fa613ac8eac5_15_mask.gif  
 extracting: train_masks/f70052627830_07_mask.gif  
 extracting: train_masks/f70052627830_06_mask.gif  
 extracting: train_masks/f00905abd3d7_03_mask.gif  
 extracting: train_masks/f00905abd3d7_02_mask.gif  
  inflating: train_masks/ebfdf6ec7ede_01_mask.gif  
 extracting: train_masks/ed8472086df8_16_mask.gif  
 extracting: train_masks/fce0ba5b8ed7_16_mask.gif  
 extracting: train_masks/f707d6fbc0cd_10_mask.gif  
 extracting: train_masks/f707d6fbc0cd_11_mask.gif  
 extracting: train_masks/fff9b3a5373f_12_mask.gif  
 extracting: train_masks/fff9b3a5373f_13_mask.gif  
 extracting: train_masks/fc237174b128_12_mask.gif  
  inflating: train_masks/fc237174b128_13_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_14_mask.gif  
 extracting: train_masks/fd9da5d0bb6f_15_mask.gif  
 extracting: train_masks/f1eb080c7182_13_mask.gif  
 extracting: train_masks/f1eb080c7182_12_mask.gif  
 extracting: train_masks/f7ad86e13ed7_16_mask.gif  
 extracting: train_masks/f3eee6348205_12_mask.gif  
 extracting: train_masks/f3eee6348205_13_mask.gif  
 extracting: train_masks/fa006be8b6d9_07_mask.gif  
 extracting: train_masks/fa006be8b6d9_06_mask.gif  
 extracting: train_masks/fb1b923dd978_14_mask.gif  
  inflating: train_masks/eefc0d8c94f0_02_mask.gif  
 extracting: train_masks/fb1b923dd978_15_mask.gif  
  inflating: train_masks/eefc0d8c94f0_03_mask.gif  
 extracting: train_masks/fa613ac8eac5_09_mask.gif  
 extracting: train_masks/fa613ac8eac5_08_mask.gif  
 extracting: train_masks/f8b6f4c39204_10_mask.gif  
 extracting: train_masks/f3b482e091c0_16_mask.gif  
 extracting: train_masks/f8b6f4c39204_11_mask.gif  
 extracting: train_masks/f591b4f2e006_03_mask.gif  
  inflating: train_masks/f98dbe8a5ee2_12_mask.gif  
 extracting: train_masks/f98dbe8a5ee2_13_mask.gif  
 extracting: train_masks/f591b4f2e006_02_mask.gif  
 extracting: train_masks/eb91b1c659a0_14_mask.gif  
 extracting: train_masks/eb91b1c659a0_15_mask.gif  
 extracting: train_masks/ef5567efd904_07_mask.gif  
 extracting: train_masks/ef5567efd904_06_mask.gif  
  inflating: train_masks/feaf59172a01_07_mask.gif  
 extracting: train_masks/feaf59172a01_06_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_07_mask.gif  
 extracting: train_masks/ed13cbcdd5d8_06_mask.gif  
 extracting: train_masks/efaef69e148d_01_mask.gif  
 extracting: train_masks/eeb7eeca738e_05_mask.gif  
 extracting: train_masks/eeb7eeca738e_04_mask.gif  
 extracting: train_masks/fecea3036c59_07_mask.gif  
 extracting: train_masks/fecea3036c59_06_mask.gif  
 extracting: train_masks/f591b4f2e006_09_mask.gif  
 extracting: train_masks/f591b4f2e006_08_mask.gif  
 extracting: train_masks/f00905abd3d7_14_mask.gif  
 extracting: train_masks/f00905abd3d7_15_mask.gif  
 extracting: train_masks/f70052627830_10_mask.gif  
 extracting: train_masks/f70052627830_11_mask.gif  
 extracting: train_masks/fa613ac8eac5_03_mask.gif  
 extracting: train_masks/fa613ac8eac5_02_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_05_mask.gif  
 extracting: train_masks/fc5f1a3a66cf_04_mask.gif  
 extracting: train_masks/fdc2c87853ce_16_mask.gif  
  inflating: train_masks/eefc0d8c94f0_08_mask.gif  
  inflating: train_masks/eefc0d8c94f0_09_mask.gif  
  inflating: train_masks/eb07e3f63ad2_12_mask.gif  
 extracting: train_masks/f4cd1286d5f4_01_mask.gif  
  inflating: train_masks/eb07e3f63ad2_13_mask.gif  
import glob
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms, models

Set Hyperparameters

batch_size = 16
epochs = 5

#
# GPU/CPU - a convenient way to set up the device; it works on any machine
#
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f'Device is {device}')

if device == 'cuda':
    print(f'CUDA device {torch.cuda.device(0)}')
    print(f'Number of devices: {torch.cuda.device_count()}')
    print(f'Device name: {torch.cuda.get_device_name(0)}')
Device is cuda
CUDA device <torch.cuda.device object at 0x7e18ce907640>
Number of devices: 1
Device name: Tesla T4

18.3. Load the Data#

train_transform = transforms.Compose([
                                transforms.Resize(64),
                                transforms.CenterCrop(64),

                                # converts 0-255 to 0-1 and rowxcolxchan to chanxrowxcol
                                transforms.ToTensor(),
])
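As a quick sanity check, the transform can be applied to a dummy PIL image (an arbitrary size is used here purely for illustration); ToTensor scales the pixel values to [0, 1] and reorders the dimensions to channels-first:

from PIL import Image
import numpy as np

dummy = Image.fromarray(np.zeros((256, 384, 3), dtype=np.uint8))  # an arbitrary RGB image
tensor = train_transform(dummy)
print(tensor.shape, float(tensor.min()), float(tensor.max()))     # torch.Size([3, 64, 64]) 0.0 0.0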

Dataset

class CarDataset(Dataset):

    def __init__(self, filename_cars, filename_masks, transform):
        """
        Initialize the dataset with lists of image and mask filenames and a transform.
        """
        super().__init__()

        # Store variables we are interested in...

        self._filename_cars = filename_cars
        self._filename_masks = filename_masks

        self._transforms = transform

    def __getitem__(self, index):
        """
        Get a single image / label pair.
        """

        #
        # Read in the image
        #
        name = self._filename_cars[index + 1]  # offset by one (the first file is skipped); __len__ is reduced to match
        image = Image.open(name)

        #
        # Read in the mask
        #
        name = name.replace('train/', 'train_masks/').replace('.jpg', '_mask.gif')
        mask = Image.open(name)

        #
        #  Can do further processing here or anything else
        #

        # image = clahe(image)

        #
        # Do transformations on it (typically data augmentation)
        #
        if self._transforms is not None:
            image = self._transforms(image)
            mask = self._transforms(mask)

        #
        # Return the image mask pair
        #
        return image, mask[0]>0 #extract channel 0; if pixel >0 set to True

    def __len__(self):
        """
        Return length of the dataset
        """
        return len(self._filename_cars) - 1  # one less than the file count, matching the index + 1 offset above

18.4. Instantiate the Dataset and DataLoader#

filenames_train = glob.glob('/content/train/*.jpg')
filenames_train_mask = glob.glob('/content/train_masks/*.gif')

train_dataset = CarDataset(filenames_train, filenames_train_mask, transform=train_transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8)
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(

Show an Example

len(train_dataset)
499
image, mask = train_dataset[2]

plt.figure()
plt.subplot(1,2,1)
plt.imshow(np.array(image).transpose((1,2,0)))
plt.subplot(1,2,2)
plt.imshow(np.array(mask).squeeze())
plt.show()
[Figure: a training image (left) and its segmentation mask (right)]
# Convert to numpy arrays
image_np = np.array(image)  # Image should already be in a 3D array (C, H, W)
mask_np = np.array(mask)    # Mask should already be in a 2D array (H, W)

print(np.shape(image_np), np.shape(mask_np))
(3, 64, 64) (64, 64)

18.5. The U-Net Network#

class contracting(nn.Module):
    def __init__(self):
        super().__init__()
        # Conv2d n_channels, out_channels, kernel_size
        # In U-Net (and many other CNN architectures), it’s common to use two consecutive convolutional layers before downsampling (with max pooling)
        # This helps enrich the feature representation before each downsampling step
        self.layer1 = nn.Sequential(nn.Conv2d(3, 64, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU()) # input: (3, 64, 64); output: 64x64x64

        self.layer2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.ReLU()) # input 64x32x32; output: 128x32x32

        self.layer3 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.ReLU()) # input 128x16x16; output: 256x16x16

        self.layer4 = nn.Sequential(nn.Conv2d(256, 512, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(512, 512, 3, stride=1, padding=1), nn.ReLU()) # input 256x8x8; output: 512x8x8

        self.layer5 = nn.Sequential(nn.Conv2d(512, 1024, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(1024, 1024, 3, stride=1, padding=1), nn.ReLU()) # input 512x4x4; output: 1024x4x4

        self.down_sample = nn.MaxPool2d(2, stride=2)


    def forward(self, X):
        X1 = self.layer1(X)
        X2 = self.layer2(self.down_sample(X1))
        X3 = self.layer3(self.down_sample(X2))
        X4 = self.layer4(self.down_sample(X3))
        X5 = self.layer5(self.down_sample(X4))
        return X5, X4, X3, X2, X1


class expansive(nn.Module):
    def __init__(self):
        super().__init__()

        self.layer1 = nn.Conv2d(64, 2, 3, stride=1, padding=1)

        self.layer2 = nn.Sequential(nn.Conv2d(128, 64, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU())

        self.layer3 = nn.Sequential(nn.Conv2d(256, 128, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.ReLU())

        self.layer4 = nn.Sequential(nn.Conv2d(512, 256, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.ReLU())

        self.layer5 = nn.Sequential(nn.Conv2d(1024, 512, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(512, 512, 3, stride=1, padding=1), nn.ReLU())

        # N.b.: for ConvTranspose2d:
        # 1. ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0)
        # 2. output_dim=(n−1)×s−2p+m+output_padding

        self.up_sample_54 = nn.ConvTranspose2d(1024, 512, 2, stride=2) # input: 1024x4x4; output: 512x8x8

        self.up_sample_43 = nn.ConvTranspose2d(512, 256, 2, stride=2)

        self.up_sample_32 = nn.ConvTranspose2d(256, 128, 2, stride=2)

        self.up_sample_21 = nn.ConvTranspose2d(128, 64, 2, stride=2)


    def forward(self, X5, X4, X3, X2, X1):
        X = self.up_sample_54(X5) # input: 1024x4x4; output: 512x8x8
        X4 = torch.cat([X, X4], dim=1) # Concatenate 512x8x8 with 512x8x8 to give 1024x8x8
        X4 = self.layer5(X4)   # Reduces the channels to 512x8x8

        X = self.up_sample_43(X4)
        X3 = torch.cat([X, X3], dim=1)
        X3 = self.layer4(X3)

        X = self.up_sample_32(X3)
        X2 = torch.cat([X, X2], dim=1)
        X2 = self.layer3(X2)

        X = self.up_sample_21(X2)
        X1 = torch.cat([X, X1], dim=1)
        X1 = self.layer2(X1)

        X = self.layer1(X1)         # final output: 2x64x64

        return X


class unet(nn.Module):
    def __init__(self):
        super().__init__()

        # Encoder
        self.down = contracting()

        # Decoder
        self.up = expansive()

    def forward(self, X):
        # Encoder
        X5, X4, X3, X2, X1 = self.down(X)

        # Decoder
        X = self.up(X5, X4, X3, X2, X1)
        return X
# check
model = unet()
tmpx = torch.ones((1,3, 64, 64))
model(tmpx).shape
torch.Size([1, 2, 64, 64])

18.6. Train the Network#

# Create network optimizer and loss function

model = unet()
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
loss = torch.nn.CrossEntropyLoss()
epochs = 10  # override the earlier setting of 5 epochs for a longer run

epoch_loss = []

for epoch in range(epochs):
    print('='*30)
    print('Epoch {} / {}'.format(epoch, epochs))

    # Set variables
    correct = 0
    overlap = 0
    union = 0
    _len = 0
    l = 0
    count = 0

    # Loop over the batches
    for index, (X, Y) in enumerate(train_dataloader):
        print(f'\tBatch {index}')

        if device is not None:
            X = X.to(device)
            Y = Y.to(device)

        # Call the model (image to mask)
        R = model(X)

        # Compute the loss
        L = loss(R, Y.long())

        # Do PyTorch stuff
        optimizer.zero_grad()
        L.backward()
        optimizer.step()

        # Compute stats: the per-pixel prediction is the argmax over the two class channels
        pred = R.data.max(1)[1]
        pred_sum, label_sum, overlap_sum = (pred==1).sum(), (Y==1).sum(), (pred*Y==1).sum()
        print(f'\t label_sum {label_sum}  pred_sum {pred_sum}  overlap_sum {overlap_sum}')

#         plt.figure(1)
#         plt.subplot(1,2,1)
#         plt.imshow(Y[0].cpu())
#         plt.clim((0,1))
#         plt.subplot(1,2,2)
#         plt.imshow(pred[0].cpu())
#         plt.clim((0,1))
#         plt.show()

        union_sum = pred_sum+label_sum-overlap_sum

        # IoU for accuracy: accumulate intersection (overlap) and union over the epoch
        overlap = overlap+overlap_sum.data.item()
        union = union+union_sum.data.item()
        l = l+L.data.item()
        count = count+1

    _loss = l/count
    _accuracy = overlap/union
    string = "epoch: {}, accuracy: {}, loss: {}".format(epoch, _accuracy, _loss)
    print(string)

    epoch_loss.append(_loss)
==============================
Epoch 0 / 10
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
	Batch 0
	 label_sum 19901  pred_sum 16  overlap_sum 0
	Batch 1
	 label_sum 19453  pred_sum 16  overlap_sum 0
	Batch 2
	 label_sum 19286  pred_sum 0  overlap_sum 0
	Batch 3
	 label_sum 20708  pred_sum 0  overlap_sum 0
	Batch 4
	 label_sum 21073  pred_sum 0  overlap_sum 0
	Batch 5
	 label_sum 19582  pred_sum 0  overlap_sum 0
	Batch 6
	 label_sum 19161  pred_sum 0  overlap_sum 0
	Batch 7
	 label_sum 20005  pred_sum 0  overlap_sum 0
	Batch 8
	 label_sum 18707  pred_sum 0  overlap_sum 0
	Batch 9
	 label_sum 19435  pred_sum 0  overlap_sum 0
	Batch 10
	 label_sum 18486  pred_sum 0  overlap_sum 0
	Batch 11
	 label_sum 19524  pred_sum 0  overlap_sum 0
	Batch 12
	 label_sum 19655  pred_sum 0  overlap_sum 0
	Batch 13
	 label_sum 20346  pred_sum 0  overlap_sum 0
	Batch 14
	 label_sum 19963  pred_sum 0  overlap_sum 0
	Batch 15
	 label_sum 20673  pred_sum 0  overlap_sum 0
	Batch 16
	 label_sum 18831  pred_sum 0  overlap_sum 0
	Batch 17
	 label_sum 18553  pred_sum 0  overlap_sum 0
	Batch 18
	 label_sum 21036  pred_sum 0  overlap_sum 0
	Batch 19
	 label_sum 19219  pred_sum 0  overlap_sum 0
	Batch 20
	 label_sum 18653  pred_sum 0  overlap_sum 0
	Batch 21
	 label_sum 19450  pred_sum 0  overlap_sum 0
	Batch 22
	 label_sum 20742  pred_sum 0  overlap_sum 0
	Batch 23
	 label_sum 18806  pred_sum 266  overlap_sum 265
	Batch 24
	 label_sum 19907  pred_sum 1526  overlap_sum 1403
	Batch 25
	 label_sum 18737  pred_sum 7302  overlap_sum 6142
	Batch 26
	 label_sum 20334  pred_sum 32548  overlap_sum 18777
	Batch 27
	 label_sum 20472  pred_sum 9054  overlap_sum 7211
	Batch 28
	 label_sum 19728  pred_sum 3238  overlap_sum 2783
	Batch 29
	 label_sum 20615  pred_sum 2773  overlap_sum 2392
	Batch 30
	 label_sum 19567  pred_sum 4261  overlap_sum 3626
	Batch 31
	 label_sum 4216  pred_sum 3831  overlap_sum 3257
epoch: 0, accuracy: 0.07235101349165902, loss: 0.5556896096095443
==============================
Epoch 1 / 10
	Batch 0
	 label_sum 19901  pred_sum 29146  overlap_sum 18353
	Batch 1
	 label_sum 19453  pred_sum 21029  overlap_sum 16304
	Batch 2
	 label_sum 19286  pred_sum 10758  overlap_sum 9505
	Batch 3
	 label_sum 20708  pred_sum 11144  overlap_sum 9861
	Batch 4
	 label_sum 21073  pred_sum 12725  overlap_sum 10707
	Batch 5
	 label_sum 19582  pred_sum 28609  overlap_sum 18142
	Batch 6
	 label_sum 19161  pred_sum 21538  overlap_sum 15481
	Batch 7
	 label_sum 20005  pred_sum 10629  overlap_sum 8982
	Batch 8
	 label_sum 18707  pred_sum 15723  overlap_sum 13386
	Batch 9
	 label_sum 19435  pred_sum 20889  overlap_sum 16155
	Batch 10
	 label_sum 18486  pred_sum 22068  overlap_sum 16139
	Batch 11
	 label_sum 19524  pred_sum 15868  overlap_sum 13676
	Batch 12
	 label_sum 19655  pred_sum 16774  overlap_sum 14319
	Batch 13
	 label_sum 20346  pred_sum 23159  overlap_sum 17565
	Batch 14
	 label_sum 19963  pred_sum 20095  overlap_sum 15965
	Batch 15
	 label_sum 20673  pred_sum 16624  overlap_sum 14877
	Batch 16
	 label_sum 18831  pred_sum 18962  overlap_sum 14278
	Batch 17
	 label_sum 18553  pred_sum 21839  overlap_sum 16579
	Batch 18
	 label_sum 21036  pred_sum 16080  overlap_sum 13778
	Batch 19
	 label_sum 19219  pred_sum 19611  overlap_sum 16002
	Batch 20
	 label_sum 18653  pred_sum 21800  overlap_sum 16548
	Batch 21
	 label_sum 19450  pred_sum 18788  overlap_sum 15974
	Batch 22
	 label_sum 20742  pred_sum 15188  overlap_sum 13781
	Batch 23
	 label_sum 18806  pred_sum 18781  overlap_sum 15668
	Batch 24
	 label_sum 19907  pred_sum 22217  overlap_sum 17898
	Batch 25
	 label_sum 18737  pred_sum 21804  overlap_sum 16960
	Batch 26
	 label_sum 20334  pred_sum 18253  overlap_sum 15771
	Batch 27
	 label_sum 20472  pred_sum 18250  overlap_sum 15439
	Batch 28
	 label_sum 19728  pred_sum 21485  overlap_sum 16982
	Batch 29
	 label_sum 20615  pred_sum 22933  overlap_sum 17895
	Batch 30
	 label_sum 19567  pred_sum 20096  overlap_sum 16175
	Batch 31
	 label_sum 4216  pred_sum 4026  overlap_sum 3520
epoch: 1, accuracy: 0.6395575400852446, loss: 0.3033404001034796
==============================
Epoch 2 / 10
	Batch 0
	 label_sum 19901  pred_sum 17079  overlap_sum 15078
	Batch 1
	 label_sum 19453  pred_sum 19785  overlap_sum 16534
	Batch 2
	 label_sum 19286  pred_sum 22282  overlap_sum 17354
	Batch 3
	 label_sum 20708  pred_sum 21810  overlap_sum 18179
	Batch 4
	 label_sum 21073  pred_sum 18735  overlap_sum 16269
	Batch 5
	 label_sum 19582  pred_sum 19617  overlap_sum 16536
	Batch 6
	 label_sum 19161  pred_sum 20054  overlap_sum 15789
	Batch 7
	 label_sum 20005  pred_sum 20432  overlap_sum 17099
	Batch 8
	 label_sum 18707  pred_sum 20846  overlap_sum 16289
	Batch 9
	 label_sum 19435  pred_sum 18201  overlap_sum 15664
	Batch 10
	 label_sum 18486  pred_sum 16435  overlap_sum 14199
	Batch 11
	 label_sum 19524  pred_sum 18662  overlap_sum 15963
	Batch 12
	 label_sum 19655  pred_sum 21832  overlap_sum 17395
	Batch 13
	 label_sum 20346  pred_sum 21396  overlap_sum 17679
	Batch 14
	 label_sum 19963  pred_sum 19312  overlap_sum 16187
	Batch 15
	 label_sum 20673  pred_sum 19247  overlap_sum 17042
	Batch 16
	 label_sum 18831  pred_sum 18496  overlap_sum 14390
	Batch 17
	 label_sum 18553  pred_sum 20289  overlap_sum 16466
	Batch 18
	 label_sum 21036  pred_sum 19743  overlap_sum 17089
	Batch 19
	 label_sum 19219  pred_sum 20019  overlap_sum 16693
	Batch 20
	 label_sum 18653  pred_sum 19754  overlap_sum 16478
	Batch 21
	 label_sum 19450  pred_sum 19337  overlap_sum 16460
	Batch 22
	 label_sum 20742  pred_sum 19033  overlap_sum 16950
	Batch 23
	 label_sum 18806  pred_sum 19532  overlap_sum 16441
	Batch 24
	 label_sum 19907  pred_sum 20577  overlap_sum 17690
	Batch 25
	 label_sum 18737  pred_sum 20237  overlap_sum 16665
	Batch 26
	 label_sum 20334  pred_sum 18704  overlap_sum 16287
	Batch 27
	 label_sum 20472  pred_sum 20029  overlap_sum 16860
	Batch 28
	 label_sum 19728  pred_sum 21178  overlap_sum 17303
	Batch 29
	 label_sum 20615  pred_sum 21401  overlap_sum 17912
	Batch 30
	 label_sum 19567  pred_sum 18793  overlap_sum 16020
	Batch 31
	 label_sum 4216  pred_sum 4415  overlap_sum 3784
epoch: 2, accuracy: 0.7223733542836853, loss: 0.22777789225801826
==============================
Epoch 3 / 10
	Batch 0
	 label_sum 19901  pred_sum 19658  overlap_sum 16884
	Batch 1
	 label_sum 19453  pred_sum 20159  overlap_sum 17205
	Batch 2
	 label_sum 19286  pred_sum 20232  overlap_sum 17062
	Batch 3
	 label_sum 20708  pred_sum 20966  overlap_sum 18202
	Batch 4
	 label_sum 21073  pred_sum 19088  overlap_sum 16816
	Batch 5
	 label_sum 19582  pred_sum 21167  overlap_sum 17476
	Batch 6
	 label_sum 19161  pred_sum 20445  overlap_sum 16352
	Batch 7
	 label_sum 20005  pred_sum 18055  overlap_sum 15942
	Batch 8
	 label_sum 18707  pred_sum 20661  overlap_sum 16640
	Batch 9
	 label_sum 19435  pred_sum 19786  overlap_sum 16880
	Batch 10
	 label_sum 18486  pred_sum 17775  overlap_sum 15277
	Batch 11
	 label_sum 19524  pred_sum 19027  overlap_sum 16574
	Batch 12
	 label_sum 19655  pred_sum 21407  overlap_sum 17631
	Batch 13
	 label_sum 20346  pred_sum 20860  overlap_sum 17867
	Batch 14
	 label_sum 19963  pred_sum 19984  overlap_sum 16959
	Batch 15
	 label_sum 20673  pred_sum 20839  overlap_sum 18217
	Batch 16
	 label_sum 18831  pred_sum 18840  overlap_sum 14951
	Batch 17
	 label_sum 18553  pred_sum 19587  overlap_sum 16526
	Batch 18
	 label_sum 21036  pred_sum 19829  overlap_sum 17501
	Batch 19
	 label_sum 19219  pred_sum 20964  overlap_sum 17390
	Batch 20
	 label_sum 18653  pred_sum 19992  overlap_sum 16942
	Batch 21
	 label_sum 19450  pred_sum 18968  overlap_sum 16541
	Batch 22
	 label_sum 20742  pred_sum 19494  overlap_sum 17546
	Batch 23
	 label_sum 18806  pred_sum 20196  overlap_sum 17085
	Batch 24
	 label_sum 19907  pred_sum 21018  overlap_sum 18205
	Batch 25
	 label_sum 18737  pred_sum 20535  overlap_sum 17190
	Batch 26
	 label_sum 20334  pred_sum 18884  overlap_sum 16763
	Batch 27
	 label_sum 20472  pred_sum 20133  overlap_sum 17302
	Batch 28
	 label_sum 19728  pred_sum 21958  overlap_sum 17898
	Batch 29
	 label_sum 20615  pred_sum 20901  overlap_sum 17992
	Batch 30
	 label_sum 19567  pred_sum 18310  overlap_sum 16143
	Batch 31
	 label_sum 4216  pred_sum 4609  overlap_sum 3921
epoch: 3, accuracy: 0.7520172607105339, loss: 0.2053607814013958
==============================
Epoch 4 / 10
	Batch 0
	 label_sum 19901  pred_sum 21267  overlap_sum 18045
	Batch 1
	 label_sum 19453  pred_sum 19796  overlap_sum 17338
	Batch 2
	 label_sum 19286  pred_sum 19555  overlap_sum 17228
	Batch 3
	 label_sum 20708  pred_sum 21759  overlap_sum 18937
	Batch 4
	 label_sum 21073  pred_sum 20789  overlap_sum 18232
	Batch 5
	 label_sum 19582  pred_sum 20443  overlap_sum 17542
	Batch 6
	 label_sum 19161  pred_sum 20040  overlap_sum 16533
	Batch 7
	 label_sum 20005  pred_sum 19059  overlap_sum 17108
	Batch 8
	 label_sum 18707  pred_sum 20821  overlap_sum 17025
	Batch 9
	 label_sum 19435  pred_sum 20006  overlap_sum 17390
	Batch 10
	 label_sum 18486  pred_sum 17045  overlap_sum 15275
	Batch 11
	 label_sum 19524  pred_sum 19553  overlap_sum 17370
	Batch 12
	 label_sum 19655  pred_sum 22955  overlap_sum 18670
	Batch 13
	 label_sum 20346  pred_sum 20382  overlap_sum 18131
	Batch 14
	 label_sum 19963  pred_sum 18563  overlap_sum 16785
	Batch 15
	 label_sum 20673  pred_sum 21113  overlap_sum 18679
	Batch 16
	 label_sum 18831  pred_sum 20848  overlap_sum 16520
	Batch 17
	 label_sum 18553  pred_sum 19243  overlap_sum 16872
	Batch 18
	 label_sum 21036  pred_sum 18654  overlap_sum 17072
	Batch 19
	 label_sum 19219  pred_sum 20977  overlap_sum 17796
	Batch 20
	 label_sum 18653  pred_sum 20530  overlap_sum 17587
	Batch 21
	 label_sum 19450  pred_sum 20105  overlap_sum 17750
	Batch 22
	 label_sum 20742  pred_sum 20503  overlap_sum 18577
	Batch 23
	 label_sum 18806  pred_sum 19156  overlap_sum 17086
	Batch 24
	 label_sum 19907  pred_sum 20356  overlap_sum 18306
	Batch 25
	 label_sum 18737  pred_sum 20730  overlap_sum 17702
	Batch 26
	 label_sum 20334  pred_sum 20507  overlap_sum 18241
	Batch 27
	 label_sum 20472  pred_sum 19908  overlap_sum 17886
	Batch 28
	 label_sum 19728  pred_sum 20593  overlap_sum 17921
	Batch 29
	 label_sum 20615  pred_sum 20807  overlap_sum 18601
	Batch 30
	 label_sum 19567  pred_sum 20288  overlap_sum 17851
	Batch 31
	 label_sum 4216  pred_sum 4503  overlap_sum 3966
epoch: 4, accuracy: 0.7906522764124797, loss: 0.17321861116215587
==============================
Epoch 5 / 10
	Batch 0
	 label_sum 19901  pred_sum 20541  overlap_sum 18227
	Batch 1
	 label_sum 19453  pred_sum 18644  overlap_sum 17199
	Batch 2
	 label_sum 19286  pred_sum 21961  overlap_sum 18382
	Batch 3
	 label_sum 20708  pred_sum 21000  overlap_sum 19091
	Batch 4
	 label_sum 21073  pred_sum 17842  overlap_sum 16775
	Batch 5
	 label_sum 19582  pred_sum 21278  overlap_sum 18383
	Batch 6
	 label_sum 19161  pred_sum 24410  overlap_sum 18417
	Batch 7
	 label_sum 20005  pred_sum 17109  overlap_sum 16149
	Batch 8
	 label_sum 18707  pred_sum 16233  overlap_sum 15175
	Batch 9
	 label_sum 19435  pred_sum 19775  overlap_sum 17921
	Batch 10
	 label_sum 18486  pred_sum 20807  overlap_sum 17781
	Batch 11
	 label_sum 19524  pred_sum 21682  overlap_sum 18788
	Batch 12
	 label_sum 19655  pred_sum 22209  overlap_sum 18915
	Batch 13
	 label_sum 20346  pred_sum 19113  overlap_sum 17930
	Batch 14
	 label_sum 19963  pred_sum 17738  overlap_sum 16822
	Batch 15
	 label_sum 20673  pred_sum 19800  overlap_sum 18444
	Batch 16
	 label_sum 18831  pred_sum 19921  overlap_sum 17140
	Batch 17
	 label_sum 18553  pred_sum 21102  overlap_sum 17942
	Batch 18
	 label_sum 21036  pred_sum 21441  overlap_sum 19325
	Batch 19
	 label_sum 19219  pred_sum 19879  overlap_sum 17759
	Batch 20
	 label_sum 18653  pred_sum 17803  overlap_sum 16719
	Batch 21
	 label_sum 19450  pred_sum 19189  overlap_sum 17730
	Batch 22
	 label_sum 20742  pred_sum 21738  overlap_sum 19630
	Batch 23
	 label_sum 18806  pred_sum 20152  overlap_sum 17839
	Batch 24
	 label_sum 19907  pred_sum 20370  overlap_sum 18711
	Batch 25
	 label_sum 18737  pred_sum 18866  overlap_sum 17270
	Batch 26
	 label_sum 20334  pred_sum 20139  overlap_sum 18565
	Batch 27
	 label_sum 20472  pred_sum 21946  overlap_sum 19499
	Batch 28
	 label_sum 19728  pred_sum 20825  overlap_sum 18444
	Batch 29
	 label_sum 20615  pred_sum 19757  overlap_sum 18493
	Batch 30
	 label_sum 19567  pred_sum 18810  overlap_sum 17611
	Batch 31
	 label_sum 4216  pred_sum 4624  overlap_sum 4072
epoch: 5, accuracy: 0.8247567535788823, loss: 0.14396312390454113
==============================
Epoch 6 / 10
	Batch 0
	 label_sum 19901  pred_sum 22198  overlap_sum 19054
	Batch 1
	 label_sum 19453  pred_sum 18375  overlap_sum 17325
	Batch 2
	 label_sum 19286  pred_sum 18180  overlap_sum 17183
	Batch 3
	 label_sum 20708  pred_sum 21474  overlap_sum 19498
	Batch 4
	 label_sum 21073  pred_sum 22868  overlap_sum 20010
	Batch 5
	 label_sum 19582  pred_sum 20096  overlap_sum 18379
	Batch 6
	 label_sum 19161  pred_sum 18730  overlap_sum 17079
	Batch 7
	 label_sum 20005  pred_sum 19474  overlap_sum 18161
	Batch 8
	 label_sum 18707  pred_sum 19737  overlap_sum 17634
	Batch 9
	 label_sum 19435  pred_sum 20626  overlap_sum 18469
	Batch 10
	 label_sum 18486  pred_sum 18633  overlap_sum 17066
	Batch 11
	 label_sum 19524  pred_sum 18075  overlap_sum 17377
	Batch 12
	 label_sum 19655  pred_sum 19949  overlap_sum 18237
	Batch 13
	 label_sum 20346  pred_sum 19536  overlap_sum 18449
	Batch 14
	 label_sum 19963  pred_sum 19629  overlap_sum 18404
	Batch 15
	 label_sum 20673  pred_sum 21540  overlap_sum 19641
	Batch 16
	 label_sum 18831  pred_sum 19335  overlap_sum 17164
	Batch 17
	 label_sum 18553  pred_sum 19795  overlap_sum 17636
	Batch 18
	 label_sum 21036  pred_sum 21121  overlap_sum 19379
	Batch 19
	 label_sum 19219  pred_sum 20063  overlap_sum 18079
	Batch 20
	 label_sum 18653  pred_sum 18473  overlap_sum 17272
	Batch 21
	 label_sum 19450  pred_sum 19347  overlap_sum 17950
	Batch 22
	 label_sum 20742  pred_sum 21029  overlap_sum 19441
	Batch 23
	 label_sum 18806  pred_sum 18759  overlap_sum 17476
	Batch 24
	 label_sum 19907  pred_sum 19390  overlap_sum 18420
	Batch 25
	 label_sum 18737  pred_sum 18485  overlap_sum 17290
	Batch 26
	 label_sum 20334  pred_sum 20620  overlap_sum 19128
	Batch 27
	 label_sum 20472  pred_sum 22065  overlap_sum 19786
	Batch 28
	 label_sum 19728  pred_sum 19594  overlap_sum 18187
	Batch 29
	 label_sum 20615  pred_sum 18937  overlap_sum 18111
	Batch 30
	 label_sum 19567  pred_sum 19889  overlap_sum 18298
	Batch 31
	 label_sum 4216  pred_sum 4636  overlap_sum 4104
epoch: 6, accuracy: 0.8556492614092926, loss: 0.11515915510244668
==============================
Epoch 7 / 10
	Batch 0
	 label_sum 19901  pred_sum 21851  overlap_sum 19191
	Batch 1
	 label_sum 19453  pred_sum 18204  overlap_sum 17409
	Batch 2
	 label_sum 19286  pred_sum 17647  overlap_sum 17022
	Batch 3
	 label_sum 20708  pred_sum 21249  overlap_sum 19512
	Batch 4
	 label_sum 21073  pred_sum 23439  overlap_sum 20375
	Batch 5
	 label_sum 19582  pred_sum 19451  overlap_sum 18198
	Batch 6
	 label_sum 19161  pred_sum 17541  overlap_sum 16599
	Batch 7
	 label_sum 20005  pred_sum 18957  overlap_sum 18146
	Batch 8
	 label_sum 18707  pred_sum 20275  overlap_sum 18045
	Batch 9
	 label_sum 19435  pred_sum 21007  overlap_sum 18733
	Batch 10
	 label_sum 18486  pred_sum 18695  overlap_sum 17244
	Batch 11
	 label_sum 19524  pred_sum 17811  overlap_sum 17155
	Batch 12
	 label_sum 19655  pred_sum 20042  overlap_sum 18522
	Batch 13
	 label_sum 20346  pred_sum 19989  overlap_sum 18825
	Batch 14
	 label_sum 19963  pred_sum 19745  overlap_sum 18618
	Batch 15
	 label_sum 20673  pred_sum 21015  overlap_sum 19566
	Batch 16
	 label_sum 18831  pred_sum 18697  overlap_sum 17196
	Batch 17
	 label_sum 18553  pred_sum 18738  overlap_sum 17337
	Batch 18
	 label_sum 21036  pred_sum 20951  overlap_sum 19510
	Batch 19
	 label_sum 19219  pred_sum 20031  overlap_sum 18234
	Batch 20
	 label_sum 18653  pred_sum 18876  overlap_sum 17526
	Batch 21
	 label_sum 19450  pred_sum 19334  overlap_sum 17994
	Batch 22
	 label_sum 20742  pred_sum 20829  overlap_sum 19499
	Batch 23
	 label_sum 18806  pred_sum 18532  overlap_sum 17452
	Batch 24
	 label_sum 19907  pred_sum 19506  overlap_sum 18579
	Batch 25
	 label_sum 18737  pred_sum 18597  overlap_sum 17496
	Batch 26
	 label_sum 20334  pred_sum 20316  overlap_sum 19098
	Batch 27
	 label_sum 20472  pred_sum 21437  overlap_sum 19702
	Batch 28
	 label_sum 19728  pred_sum 19691  overlap_sum 18303
	Batch 29
	 label_sum 20615  pred_sum 19248  overlap_sum 18494
	Batch 30
	 label_sum 19567  pred_sum 19729  overlap_sum 18364
	Batch 31
	 label_sum 4216  pred_sum 4546  overlap_sum 4096
epoch: 7, accuracy: 0.8683587345922642, loss: 0.10284013347700238
==============================
Epoch 8 / 10
	Batch 0
	 label_sum 19901  pred_sum 20897  overlap_sum 18988
	Batch 1
	 label_sum 19453  pred_sum 18544  overlap_sum 17731
	Batch 2
	 label_sum 19286  pred_sum 18005  overlap_sum 17292
	Batch 3
	 label_sum 20708  pred_sum 21193  overlap_sum 19598
	Batch 4
	 label_sum 21073  pred_sum 22296  overlap_sum 20147
	Batch 5
	 label_sum 19582  pred_sum 18795  overlap_sum 17916
	Batch 6
	 label_sum 19161  pred_sum 17842  overlap_sum 16982
	Batch 7
	 label_sum 20005  pred_sum 19569  overlap_sum 18557
	Batch 8
	 label_sum 18707  pred_sum 20205  overlap_sum 18127
	Batch 9
	 label_sum 19435  pred_sum 20220  overlap_sum 18658
	Batch 10
	 label_sum 18486  pred_sum 18128  overlap_sum 17153
	Batch 11
	 label_sum 19524  pred_sum 18059  overlap_sum 17460
	Batch 12
	 label_sum 19655  pred_sum 20620  overlap_sum 18819
	Batch 13
	 label_sum 20346  pred_sum 20315  overlap_sum 19048
	Batch 14
	 label_sum 19963  pred_sum 19213  overlap_sum 18506
	Batch 15
	 label_sum 20673  pred_sum 20262  overlap_sum 19315
	Batch 16
	 label_sum 18831  pred_sum 18630  overlap_sum 17512
	Batch 17
	 label_sum 18553  pred_sum 19412  overlap_sum 17680
	Batch 18
	 label_sum 21036  pred_sum 21001  overlap_sum 19698
	Batch 19
	 label_sum 19219  pred_sum 19319  overlap_sum 17989
	Batch 20
	 label_sum 18653  pred_sum 18490  overlap_sum 17471
	Batch 21
	 label_sum 19450  pred_sum 19471  overlap_sum 18187
	Batch 22
	 label_sum 20742  pred_sum 21030  overlap_sum 19691
	Batch 23
	 label_sum 18806  pred_sum 18681  overlap_sum 17666
	Batch 24
	 label_sum 19907  pred_sum 19603  overlap_sum 18756
	Batch 25
	 label_sum 18737  pred_sum 18351  overlap_sum 17433
	Batch 26
	 label_sum 20334  pred_sum 20257  overlap_sum 19173
	Batch 27
	 label_sum 20472  pred_sum 21517  overlap_sum 19790
	Batch 28
	 label_sum 19728  pred_sum 19348  overlap_sum 18226
	Batch 29
	 label_sum 20615  pred_sum 18927  overlap_sum 18362
	Batch 30
	 label_sum 19567  pred_sum 20109  overlap_sum 18554
	Batch 31
	 label_sum 4216  pred_sum 4484  overlap_sum 4093
epoch: 8, accuracy: 0.8798525049805601, loss: 0.09189422405324876
==============================
Epoch 9 / 10
	Batch 0
	 label_sum 19901  pred_sum 20457  overlap_sum 19008
	Batch 1
	 label_sum 19453  pred_sum 19078  overlap_sum 18156
	Batch 2
	 label_sum 19286  pred_sum 18491  overlap_sum 17731
	Batch 3
	 label_sum 20708  pred_sum 21211  overlap_sum 19710
	Batch 4
	 label_sum 21073  pred_sum 21993  overlap_sum 20155
	Batch 5
	 label_sum 19582  pred_sum 18494  overlap_sum 17808
	Batch 6
	 label_sum 19161  pred_sum 18121  overlap_sum 17421
	Batch 7
	 label_sum 20005  pred_sum 19430  overlap_sum 18581
	Batch 8
	 label_sum 18707  pred_sum 19942  overlap_sum 18124
	Batch 9
	 label_sum 19435  pred_sum 20090  overlap_sum 18722
	Batch 10
	 label_sum 18486  pred_sum 18147  overlap_sum 17255
	Batch 11
	 label_sum 19524  pred_sum 18357  overlap_sum 17806
	Batch 12
	 label_sum 19655  pred_sum 20748  overlap_sum 18975
	Batch 13
	 label_sum 20346  pred_sum 20347  overlap_sum 19210
	Batch 14
	 label_sum 19963  pred_sum 19189  overlap_sum 18535
	Batch 15
	 label_sum 20673  pred_sum 20267  overlap_sum 19374
	Batch 16
	 label_sum 18831  pred_sum 18856  overlap_sum 17784
	Batch 17
	 label_sum 18553  pred_sum 19130  overlap_sum 17713
	Batch 18
	 label_sum 21036  pred_sum 20854  overlap_sum 19700
	Batch 19
	 label_sum 19219  pred_sum 19128  overlap_sum 17992
	Batch 20
	 label_sum 18653  pred_sum 18563  overlap_sum 17594
	Batch 21
	 label_sum 19450  pred_sum 19442  overlap_sum 18242
	Batch 22
	 label_sum 20742  pred_sum 20969  overlap_sum 19743
	Batch 23
	 label_sum 18806  pred_sum 18503  overlap_sum 17661
	Batch 24
	 label_sum 19907  pred_sum 19577  overlap_sum 18844
	Batch 25
	 label_sum 18737  pred_sum 18557  overlap_sum 17636
	Batch 26
	 label_sum 20334  pred_sum 20558  overlap_sum 19424
	Batch 27
	 label_sum 20472  pred_sum 21335  overlap_sum 19812
	Batch 28
	 label_sum 19728  pred_sum 19086  overlap_sum 18199
	Batch 29
	 label_sum 20615  pred_sum 19231  overlap_sum 18733
	Batch 30
	 label_sum 19567  pred_sum 20383  overlap_sum 18650
	Batch 31
	 label_sum 4216  pred_sum 4361  overlap_sum 4057
epoch: 9, accuracy: 0.8906483882691372, loss: 0.0831722195725888
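
The accuracy reported after each epoch is computed from the per-batch counts logged above (label_sum, pred_sum, overlap_sum). Whatever exact formula the training loop uses, one natural reading of these three counts is the Dice coefficient, 2·|A∩B| / (|A| + |B|). A minimal sketch of that interpretation (dice_from_counts is a hypothetical helper, not part of the notebook):

def dice_from_counts(label_sum, pred_sum, overlap_sum):
    # Dice coefficient from pixel counts: 2 * |A ∩ B| / (|A| + |B|)
    return 2.0 * overlap_sum / (label_sum + pred_sum)

# Last batch of epoch 9 above: 2 * 4057 / (4216 + 4361) ≈ 0.946
print(dice_from_counts(4216, 4361, 4057))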

18.7. Plot the loss curve#

# epoch_loss: the per-epoch training loss values collected during training
plt.figure()
plt.plot(epoch_loss)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Epoch Loss')
(Figure: training loss per epoch.)

18.8. Plot an Example#

plt.figure(1)

# Plot the actual label
plt.subplot(1,2,1)
plt.imshow(Y[0].cpu())
plt.clim((0,1))
plt.title('True Mask')

# Show the predicted label
plt.subplot(1,2,2)
plt.imshow(pred[0].cpu())
plt.clim((0,1))
plt.title('Predicted Mask')
plt.show()
(Figure: the true mask and the predicted mask shown side by side.)
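
If pred stores raw per-pixel probabilities from a sigmoid rather than a hard {0, 1} mask, the right-hand panel shows soft values, and thresholding makes the comparison with the true mask more direct. A minimal sketch, assuming pred holds probabilities in [0, 1] (the 0.5 cutoff is an arbitrary illustrative choice):

# Binarize the prediction before display (skip if pred is already a binary mask)
pred_mask = (pred[0] > 0.5).float()
plt.figure()
plt.imshow(pred_mask.cpu())
plt.clim((0, 1))
plt.title('Thresholded Predicted Mask')
plt.show()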

18.9. Exercise#

  • Complete this notebook by dividing your dataset into training, validation, and test sets (a minimal splitting sketch follows this list).

  • Show learning curves.

  • Retrain your network and run segmentation on the test dataset.

  • How do you assess the quality of your results?
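
A possible starting point for the first item, assuming the images and masks are wrapped in a single PyTorch Dataset (full_dataset below is a hypothetical name standing in for whatever the notebook's dataset object is called):

import torch
from torch.utils.data import DataLoader, random_split

# full_dataset: hypothetical handle to the Dataset of image/mask pairs
# 70 / 15 / 15 split; the exact ratios are a common but arbitrary choice
n_total = len(full_dataset)
n_train = int(0.7 * n_total)
n_val = int(0.15 * n_total)
n_test = n_total - n_train - n_val

train_set, val_set, test_set = random_split(
    full_dataset, [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(0))  # fixed seed for reproducibility

train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
val_loader = DataLoader(val_set, batch_size=4)
test_loader = DataLoader(test_set, batch_size=4)

Training then runs on train_loader only, the learning curves come from evaluating on val_loader after each epoch, and the final segmentation quality (for example a mean Dice score) is reported once on test_loader.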