Credits: https://git.lcsr.jhu.edu/cjones96/basic-unet-segmentation
18. U-Net Convolutional Networks for Image Segmentation
18.1. U-Net
U-Net is a widely used deep learning architecture introduced in the paper “U-Net: Convolutional Networks for Biomedical Image Segmentation” (Ronneberger, Fischer, and Brox, 2015). It was designed to address the scarcity of annotated data in the medical field: the network is meant to learn effectively from a small number of annotated images while remaining fast and accurate.
The architecture of U-Net is unique in that it consists of a contracting path and an expansive path. The contracting path contains encoder layers that capture contextual information and reduce the spatial resolution of the input, while the expansive path contains decoder layers that decode the encoded data and use the information from the contracting path via skip connections to generate a segmentation map.
The contracting path in U-Net is responsible for identifying the relevant features in the input image. The encoder layers apply convolutions and pooling that reduce the spatial resolution of the feature maps while increasing their depth, thereby capturing increasingly abstract representations of the input. This contracting path is similar to the feedforward layers in other convolutional neural networks. The expansive path, on the other hand, decodes the encoded data and localizes the features while restoring spatial resolution. The decoder layers in the expansive path upsample the feature maps and then apply further convolutions. The skip connections carry high-resolution feature maps from the contracting path across to the decoder, supplying the spatial detail that was lost during downsampling and helping the decoder layers to locate the features more accurately.
Source: https://arxiv.org/pdf/1505.04597
18.2. Explanation
Credits: see blog by A. Ito Armendia
The Contracting Path of U-Net
Block 1
An input image with dimensions 572² is fed into the U-Net. This input image has only 1 channel (grayscale). Two unpadded 3x3 convolution layers are then applied to the input image, each followed by a ReLU layer. Because the convolutions are unpadded, each one trims 2 pixels from the height and width (572 → 570 → 568). At the same time the number of channels is increased to 64 in order to capture higher-level features. A 2x2 max pooling layer with a stride of 2 is then applied. This downsamples the 568² feature map to half its size, 284².
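Block 1 can be sketched in PyTorch as follows. This is a minimal illustration rather than the reference implementation; the helper name `double_conv` and the variable names are assumptions introduced here, and the input is random noise used only to check the shapes.

```python
import torch
import torch.nn as nn


def double_conv(in_channels: int, out_channels: int) -> nn.Sequential:
    """Two unpadded 3x3 convolutions, each followed by a ReLU (hypothetical helper)."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3),  # padding defaults to 0
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3),
        nn.ReLU(inplace=True),
    )


block1 = double_conv(1, 64)                    # 1 grayscale channel -> 64 feature channels
pool = nn.MaxPool2d(kernel_size=2, stride=2)   # halves height and width

x = torch.randn(1, 1, 572, 572)                # (batch, channels, height, width)
with torch.no_grad():
    features = block1(x)                       # spatial size 572 -> 570 -> 568
    pooled = pool(features)                    # spatial size 568 -> 284

print(features.shape)  # torch.Size([1, 64, 568, 568])
print(pooled.shape)    # torch.Size([1, 64, 284, 284])
```

The un-pooled `features` tensor is the one that a skip connection would later hand to the expanding path.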
Block 2
Just like in block 1, two unpadded 3x3 convolution layers are applied to the output of block 1, each followed again by a ReLU layer (284 → 282 → 280). At each new block the number of feature channels is doubled, now to 128. Next, a 2x2 max pooling layer is again applied to the resulting feature map, reducing the spatial dimensions by half to 140².
Block 3
Block 3 repeats the procedure of blocks 1 and 2, now with 256 feature channels, so it will not be described again.
Block 4
Same as block 3, now with 512 feature channels.
Block 5
In the final block of the contracting path, the number of feature channels reaches 1024 after being doubled at each block. This block also contains two unpadded 3x3 convolution layers, each followed by a ReLU layer. However, for symmetry, only the first of these layers is counted here; the second is described as part of the expanding path. No max pooling follows this block. After complex features and patterns have been extracted, the feature map moves on to the expanding path.
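The sizes quoted for the contracting path can be checked with a few lines of plain Python. This is only an arithmetic sketch; the function name `contracting_path_shapes` and its parameters are assumptions introduced for illustration, and both convolutions of block 5 are counted together here.

```python
def contracting_path_shapes(size=572, base_channels=64, blocks=5):
    """Return one (block, channels, size_after_convs, size_after_pool) tuple per block."""
    shapes = []
    channels = base_channels
    for block in range(1, blocks + 1):
        size -= 4  # two unpadded 3x3 convolutions remove 2 pixels each
        # The last block is not followed by max pooling.
        pooled = size // 2 if block < blocks else None
        shapes.append((block, channels, size, pooled))
        if pooled is not None:
            size = pooled
            channels *= 2  # channels double at each new block
    return shapes


for block, channels, conv_size, pooled in contracting_path_shapes():
    print(f"block {block}: {channels} channels, {conv_size}x{conv_size} after convs, pooled to {pooled}")
```

For a 572² input this gives 568² → 284², 280² → 140², 136² → 68², 64² → 32², and finally a 28² feature map with 1024 channels at the bottom of the network.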
The Expanding Path
Block 5
Continuing on from the contracting path, a second 3x3 convolution (unpadded) is applied with a ReLU layer after it.
Then a 2x2 up-convolution layer (commonly implemented as a transposed convolution with a stride of 2) is applied, doubling the spatial dimensions and halving the number of channels to 512.
Block 4
Using skip connections, the corresponding feature map from the contracting path is then concatenated, doubling the feature channels to 1024. Note that the feature map from the contracting path is larger than the upsampled one, so it must be cropped to match the expanding path’s dimensions before concatenation.
Two 3x3 convolution layers (unpadded) are applied, each with a ReLU layer following, reducing the channels to 512.
Afterwards, a 2x2 up-convolution layer is applied, doubling the spatial dimensions and halving the number of channels to 256.
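The up-convolution, crop, and concatenation steps of blocks 5 and 4 could look roughly like this in PyTorch. It is a sketch under assumptions: the up-convolution is implemented with `nn.ConvTranspose2d`, the crop is taken from the center, and `center_crop` and the tensor names are hypothetical helpers introduced here. The inputs are random tensors with the sizes a 572² input would produce.

```python
import torch
import torch.nn as nn


def center_crop(skip: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Crop the central height x width window from a (N, C, H, W) tensor."""
    top = (skip.shape[2] - height) // 2
    left = (skip.shape[3] - width) // 2
    return skip[:, :, top:top + height, left:left + width]


# 2x2 up-convolution: doubles height/width and halves the channels.
up_conv = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
# Two unpadded 3x3 convolutions with ReLU, reducing 1024 channels to 512.
decoder_convs = nn.Sequential(
    nn.Conv2d(1024, 512, kernel_size=3), nn.ReLU(inplace=True),
    nn.Conv2d(512, 512, kernel_size=3), nn.ReLU(inplace=True),
)

bottleneck = torch.randn(1, 1024, 28, 28)  # output of block 5
skip = torch.randn(1, 512, 64, 64)         # contracting block 4, before pooling

with torch.no_grad():
    upsampled = up_conv(bottleneck)                     # (1, 512, 56, 56)
    cropped = center_crop(skip, *upsampled.shape[2:])   # (1, 512, 56, 56)
    merged = torch.cat([cropped, upsampled], dim=1)     # (1, 1024, 56, 56)
    decoded = decoder_convs(merged)                     # (1, 512, 52, 52)

print(upsampled.shape, cropped.shape, merged.shape, decoded.shape)
```

Concatenating along `dim=1` stacks the channels, which is why the channel count doubles back to 1024 before the two convolutions reduce it again.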
Block 3
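Block 3 follows the same procedure as block 4 (concatenate the cropped skip connection, apply two unpadded 3x3 convolutions with ReLU, then a 2x2 up-convolution), so it will not be described again.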
Block 2
Same as block 3.
Block 1
In the final block of the expanding path, there are 128 channels after concatenating the skip connection.
Next, two unpadded 3x3 convolution layers are applied to the feature map, each followed by a ReLU layer, reducing the number of feature channels to 64.
Finally, a 1x1 convolution layer reduces the number of channels to the desired number of classes, followed by an activation layer (a softmax across the class channels, or a sigmoid on a single output channel for binary segmentation). In this case there are 2 classes, as binary segmentation is often used in medical imaging.
The result of the expanding path is a segmentation map in which each pixel is classified individually.
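As an illustration of the final step, the 1x1 convolution and per-pixel classification might be written like this in PyTorch. The names `head`, `logits`, `probs`, and `mask` are assumptions introduced here, the softmax variant with 2 output channels is shown, and the 64-channel input is a random stand-in for the output of the last decoder block.

```python
import torch
import torch.nn as nn

num_classes = 2
# 1x1 convolution: mixes the 64 feature channels into one score per class,
# independently at every pixel, without changing the spatial size.
head = nn.Conv2d(64, num_classes, kernel_size=1)

features = torch.randn(1, 64, 388, 388)    # stand-in for the last decoder output
with torch.no_grad():
    logits = head(features)                # (1, 2, 388, 388) raw class scores
    probs = torch.softmax(logits, dim=1)   # class probabilities per pixel
    mask = probs.argmax(dim=1)             # (1, 388, 388) predicted class index

print(logits.shape, mask.shape)
```

During training, the raw `logits` are usually passed directly to a loss such as `nn.CrossEntropyLoss`, which applies the softmax internally.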
N.b.
i. The skip connections include cropping.
ii. The up-convolutions reduce the number of channels.
iii. The final output has a lower resolution than the input as a result of the unpadded convolutions and cropping. The final number of channels is 2, because the problem at hand has two classes.
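Notes i and iii can be made concrete by tracing the spatial size through the whole network. The following pure-Python sketch assumes the layout described above (unpadded 3x3 convolutions, 2x2 pooling, 2x2 up-convolutions, four skip connections); the function name `unet_output_size` is hypothetical.

```python
def unet_output_size(input_size=572, depth=4):
    """Trace the spatial size through an unpadded U-Net.

    Returns (output_size, crops), where each entry of crops is
    (skip_size, upsampled_size, margin_cropped_per_side).
    """
    size = input_size
    skip_sizes = []
    for _ in range(depth):
        size -= 4                  # two unpadded 3x3 convolutions
        skip_sizes.append(size)    # feature map handed to the skip connection
        if size % 2:
            raise ValueError("feature map size must be even before 2x2 pooling")
        size //= 2                 # 2x2 max pooling, stride 2
    size -= 4                      # the two convolutions at the bottom of the network

    crops = []
    for skip in reversed(skip_sizes):
        size *= 2                  # 2x2 up-convolution doubles the size
        crops.append((skip, size, (skip - size) // 2))
        size -= 4                  # two unpadded 3x3 convolutions
    return size, crops


output_size, crops = unet_output_size(572)
print(output_size)  # 388
for skip, upsampled, margin in crops:
    print(f"skip {skip}x{skip} cropped to {upsampled}x{upsampled} ({margin} px per side)")
```

For a 572² input the output is 388², and the skip connections are cropped by 4, 16, 40, and 88 pixels per side, from the deepest level to the shallowest.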
!pip install numpy matplotlib torch torchvision Pillow
!wget -O train.zip "https://git.lcsr.jhu.edu/cjones96/basic-unet-segmentation/-/raw/master/train.zip"
!ls
sample_data train.zip
!file train.zip
train.zip: Zip archive data, at least v1.0 to extract, compression method=store
!(cd /content && unzip -o train.zip)
!wget -O train_masks.zip "https://git.lcsr.jhu.edu/cjones96/basic-unet-segmentation/-/raw/master/train_masks.zip"
!(cd /content && unzip -o train_masks.zip)
import glob
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms, models
Set Hyperparameters
batch_size = 16
epochs = 5
#
# Select the GPU when one is available, otherwise fall back to the CPU (works on any machine)
#
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Device is {device}')
if device == 'cuda':
    print(f'CUDA device {torch.cuda.device(0)}')
    print(f'Number of devices: {torch.cuda.device_count()}')
    print(f'Device name: {torch.cuda.get_device_name(0)}')
Device is cuda
CUDA device <torch.cuda.device object at 0x7e18ce907640>
Number of devices: 1
Device name: Tesla T4
18.3. Load the Data#
train_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    # Converts 0-255 values to 0-1 and (row, col, channel) layout to (channel, row, col)
    transforms.ToTensor(),
])
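To make the comment on `ToTensor` concrete, here is a NumPy-only sketch of the same rescaling and axis reordering. The helper `to_tensor_like` is a hypothetical name introduced for illustration; it emulates the behavior for 8-bit images and is not torchvision's implementation.

```python
import numpy as np

def to_tensor_like(image_hwc_uint8):
    """Emulate ToTensor for an 8-bit image: scale to [0, 1] and move channels first."""
    arr = np.asarray(image_hwc_uint8, dtype=np.float32) / 255.0
    # (row, col, channel) -> (channel, row, col)
    return arr.transpose(2, 0, 1)

# A synthetic 64x64 RGB image whose red channel is saturated
rgb = np.zeros((64, 64, 3), dtype=np.uint8)
rgb[..., 0] = 255

chw = to_tensor_like(rgb)
print(chw.shape)  # (3, 64, 64)
```

After the conversion, the red channel holds only ones and the other two channels hold only zeros, which is the layout the network expects as input.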
Dataset
class CarDataset(Dataset):
    def __init__(self, filename_cars, filename_masks, transform):
        """
        Store the file lists and the transform.
        """
        super().__init__()
        # Store variables we are interested in...
        self._filename_cars = filename_cars
        self._filename_masks = filename_masks
        self._transforms = transform

    def __getitem__(self, index):
        """
        Get a single image / label pair.
        """
        #
        # Read in the image
        #
        # Note: the +1 offset skips the first file in the list; __len__ subtracts 1 to match (double-check)
        name = self._filename_cars[index+1]
        image = Image.open(name)
        #
        # Read in the mask (its path is derived from the image path)
        #
        name = name.replace('train/', 'train_masks/').replace('.jpg', '_mask.gif')
        mask = Image.open(name)
        #
        # Can do further processing here or anything else
        #
        # image = clahe(image)
        #
        # Apply the transformations (typically data augmentation)
        #
        if self._transforms is not None:
            image = self._transforms(image)
            mask = self._transforms(mask)
        #
        # Return the image / mask pair
        #
        return image, mask[0]>0  # extract channel 0; pixels > 0 become True

    def __len__(self):
        """
        Return length of the dataset
        """
        return len(self._filename_cars)-1  # one less than the file count, because of the +1 offset above (double-check)
18.4. Instantiate the Dataset and DataLoader#
filenames_train = glob.glob('/content/train/*.jpg')
filenames_train_mask = glob.glob('/content/train_masks/*.gif')
train_dataset = CarDataset(filenames_train, filenames_train_mask, transform=train_transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8)
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
warnings.warn(_create_warning_msg(
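The warning above says that 8 workers were requested while the system suggests at most 2. One way to avoid it is to cap the requested count at the number of available CPU cores before building the `DataLoader`. The helper `pick_num_workers` below is a hypothetical name used only for this sketch.

```python
import os

def pick_num_workers(requested, available=None):
    """Cap the requested worker count at the number of available CPU cores."""
    if available is None:
        # os.cpu_count() can return None, so fall back to a single core
        available = os.cpu_count() or 1
    return max(0, min(requested, available))

# On the machine from the warning (2 suggested workers), 8 is reduced to 2
print(pick_num_workers(8, available=2))  # 2
```

The result could then be passed as `num_workers=pick_num_workers(8)` when creating `train_dataloader`.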
Show an Example
len(train_dataset)
499
image, mask = train_dataset[2]
plt.figure()
plt.subplot(1,2,1)
plt.imshow(np.array(image).transpose((1,2,0)))
plt.subplot(1,2,2)
plt.imshow(np.array(mask).squeeze())
plt.show()
# Convert to numpy arrays
image_np = np.array(image) # Image should already be in a 3D array (C, H, W)
mask_np = np.array(mask) # Mask should already be in a 2D array (H, W)
print(np.shape(image_np), np.shape(mask_np))
(3, 64, 64) (64, 64)
18.5. The U-Net Network#
class contracting(nn.Module):
    def __init__(self):
        super().__init__()
        # Conv2d arguments: in_channels, out_channels, kernel_size
        # U-Net (like many other CNN architectures) applies two consecutive convolutional layers before each downsampling step (max pooling).
        # Stacking two convolutions enriches the feature representation at each resolution.
        self.layer1 = nn.Sequential(nn.Conv2d(3, 64, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU()) # input: 3x64x64; output: 64x64x64
        self.layer2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.ReLU()) # input: 64x32x32; output: 128x32x32
        self.layer3 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.ReLU()) # input: 128x16x16; output: 256x16x16
        self.layer4 = nn.Sequential(nn.Conv2d(256, 512, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(512, 512, 3, stride=1, padding=1), nn.ReLU()) # input: 256x8x8; output: 512x8x8
        self.layer5 = nn.Sequential(nn.Conv2d(512, 1024, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(1024, 1024, 3, stride=1, padding=1), nn.ReLU()) # input: 512x4x4; output: 1024x4x4
        self.down_sample = nn.MaxPool2d(2, stride=2)  # halves the height and width

    def forward(self, X):
        X1 = self.layer1(X)
        X2 = self.layer2(self.down_sample(X1))
        X3 = self.layer3(self.down_sample(X2))
        X4 = self.layer4(self.down_sample(X3))
        X5 = self.layer5(self.down_sample(X4))
        # Return the bottleneck first, followed by the feature maps used by the skip connections
        return X5, X4, X3, X2, X1
class expansive(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(64, 2, 3, stride=1, padding=1)  # maps 64 channels to 2 class scores
        self.layer2 = nn.Sequential(nn.Conv2d(128, 64, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU())
        self.layer3 = nn.Sequential(nn.Conv2d(256, 128, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.ReLU())
        self.layer4 = nn.Sequential(nn.Conv2d(512, 256, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.ReLU())
        self.layer5 = nn.Sequential(nn.Conv2d(1024, 512, 3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(512, 512, 3, stride=1, padding=1), nn.ReLU())
        # N.b.: for ConvTranspose2d:
        # 1. ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0)
        # 2. output_dim = (n - 1) * s - 2 * p + m + output_padding
        #    with n = input size, s = stride, p = padding, m = kernel size (dilation 1)
        self.up_sample_54 = nn.ConvTranspose2d(1024, 512, 2, stride=2) # input: 1024x4x4; output: 512x8x8
        self.up_sample_43 = nn.ConvTranspose2d(512, 256, 2, stride=2) # input: 512x8x8; output: 256x16x16
        self.up_sample_32 = nn.ConvTranspose2d(256, 128, 2, stride=2) # input: 256x16x16; output: 128x32x32
        self.up_sample_21 = nn.ConvTranspose2d(128, 64, 2, stride=2) # input: 128x32x32; output: 64x64x64

    def forward(self, X5, X4, X3, X2, X1):
        X = self.up_sample_54(X5) # input: 1024x4x4; output: 512x8x8
        X4 = torch.cat([X, X4], dim=1) # Concatenate 512x8x8 with 512x8x8 to give 1024x8x8
        X4 = self.layer5(X4) # Reduces the channels to 512x8x8
        X = self.up_sample_43(X4)
        X3 = torch.cat([X, X3], dim=1)
        X3 = self.layer4(X3)
        X = self.up_sample_32(X3)
        X2 = torch.cat([X, X2], dim=1)
        X2 = self.layer3(X2)
        X = self.up_sample_21(X2)
        X1 = torch.cat([X, X1], dim=1)
        X1 = self.layer2(X1)
        X = self.layer1(X1) # final output should be 2x64x64
        return X
class unet(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder
        self.down = contracting()
        # Decoder
        self.up = expansive()

    def forward(self, X):
        # Encoder
        X5, X4, X3, X2, X1 = self.down(X)
        # Decoder
        X = self.up(X5, X4, X3, X2, X1)
        return X
# check
model = unet()
tmpx = torch.ones((1,3, 64, 64))
model(tmpx).shape
torch.Size([1, 2, 64, 64])
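The spatial sizes quoted in the code comments can be checked by hand with the standard output-size formulas. The sketch below is plain Python and assumes a dilation of 1; the helper names `conv_out` and `conv_transpose_out` are hypothetical and only mirror the formula given in the `ConvTranspose2d` comment.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Output size of a convolution or max pooling along one dimension."""
    return (n + 2 * padding - kernel) // stride + 1

def conv_transpose_out(n, kernel, stride=1, padding=0, output_padding=0):
    """Output size of a transposed convolution along one dimension."""
    return (n - 1) * stride - 2 * padding + kernel + output_padding

# A 3x3 convolution with padding=1 keeps the size unchanged
print(conv_out(64, kernel=3, stride=1, padding=1))  # 64

# Contracting path: four 2x2 max poolings with stride 2
size = 64
sizes = [size]
for _ in range(4):
    size = conv_out(size, kernel=2, stride=2)
    sizes.append(size)
print(sizes)  # [64, 32, 16, 8, 4]

# Expansive path: four 2x2 transposed convolutions with stride 2
for _ in range(4):
    size = conv_transpose_out(size, kernel=2, stride=2)
print(size)  # 64
```

Because the padded 3x3 convolutions preserve the size and each upsampling exactly undoes one pooling step, the skip connections can be concatenated without cropping, and the output mask has the same 64x64 size as the input.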
18.6. Train the Network#
# Create the network, optimizer, and loss function
model = unet()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
loss = torch.nn.CrossEntropyLoss()
epochs = 10
epoch_loss = []
for epoch in range(epochs):
    print('='*30)
    print('Epoch {} / {}'.format(epoch, epochs))
    # Reset the per-epoch accumulators
    correct = 0
    overlap = 0
    union = 0
    _len = 0
    l = 0
    count = 0
    # Loop over the batches
    for index, (X, Y) in enumerate(train_dataloader):
        print(f'\tBatch {index}')
        if device is not None:
            X = X.to(device)
            Y = Y.to(device)
        # Call the model (image to mask)
        R = model(X)
        # Compute the loss (CrossEntropyLoss expects integer class labels)
        L = loss(R, Y.long())
        # Backpropagate and update the weights
        optimizer.zero_grad()
        L.backward()
        optimizer.step()
        # Compute stats: the predicted class is the channel with the larger score
        pred = R.data.max(1)[1]
        pred_sum, label_sum, overlap_sum = (pred==1).sum(), (Y==1).sum(), (pred*Y==1).sum()
        print(f'\t label_sum {label_sum} pred_sum {pred_sum} overlap_sum {overlap_sum}')
        # plt.figure(1)
        # plt.subplot(1,2,1)
        # plt.imshow(Y[0].cpu())
        # plt.clim((0,1))
        # plt.subplot(1,2,2)
        # plt.imshow(pred[0].cpu())
        # plt.clim((0,1))
        # plt.show()
        union_sum = pred_sum+label_sum-overlap_sum
        # Accumulate overlap and union: the "accuracy" reported below is the IoU (intersection over union)
        overlap = overlap+overlap_sum.item()
        union = union+union_sum.item()
        l = l+L.item()
        count = count+1
    # Per-epoch averages
    _loss = l/count
    _accuracy = overlap/union
    string = "epoch: {}, accuracy: {}, loss: {}".format(epoch, _accuracy, _loss)
    print(string)
    epoch_loss.append(_loss)
==============================
Epoch 0 / 10
Batch 0
label_sum 19901 pred_sum 16 overlap_sum 0
Batch 1
label_sum 19453 pred_sum 16 overlap_sum 0
Batch 2
label_sum 19286 pred_sum 0 overlap_sum 0
Batch 3
label_sum 20708 pred_sum 0 overlap_sum 0
Batch 4
label_sum 21073 pred_sum 0 overlap_sum 0
Batch 5
label_sum 19582 pred_sum 0 overlap_sum 0
Batch 6
label_sum 19161 pred_sum 0 overlap_sum 0
Batch 7
label_sum 20005 pred_sum 0 overlap_sum 0
Batch 8
label_sum 18707 pred_sum 0 overlap_sum 0
Batch 9
label_sum 19435 pred_sum 0 overlap_sum 0
Batch 10
label_sum 18486 pred_sum 0 overlap_sum 0
Batch 11
label_sum 19524 pred_sum 0 overlap_sum 0
Batch 12
label_sum 19655 pred_sum 0 overlap_sum 0
Batch 13
label_sum 20346 pred_sum 0 overlap_sum 0
Batch 14
label_sum 19963 pred_sum 0 overlap_sum 0
Batch 15
label_sum 20673 pred_sum 0 overlap_sum 0
Batch 16
label_sum 18831 pred_sum 0 overlap_sum 0
Batch 17
label_sum 18553 pred_sum 0 overlap_sum 0
Batch 18
label_sum 21036 pred_sum 0 overlap_sum 0
Batch 19
label_sum 19219 pred_sum 0 overlap_sum 0
Batch 20
label_sum 18653 pred_sum 0 overlap_sum 0
Batch 21
label_sum 19450 pred_sum 0 overlap_sum 0
Batch 22
label_sum 20742 pred_sum 0 overlap_sum 0
Batch 23
label_sum 18806 pred_sum 266 overlap_sum 265
Batch 24
label_sum 19907 pred_sum 1526 overlap_sum 1403
Batch 25
label_sum 18737 pred_sum 7302 overlap_sum 6142
Batch 26
label_sum 20334 pred_sum 32548 overlap_sum 18777
Batch 27
label_sum 20472 pred_sum 9054 overlap_sum 7211
Batch 28
label_sum 19728 pred_sum 3238 overlap_sum 2783
Batch 29
label_sum 20615 pred_sum 2773 overlap_sum 2392
Batch 30
label_sum 19567 pred_sum 4261 overlap_sum 3626
Batch 31
label_sum 4216 pred_sum 3831 overlap_sum 3257
epoch: 0, accuracy: 0.07235101349165902, loss: 0.5556896096095443
==============================
Epoch 1 / 10
Batch 0
label_sum 19901 pred_sum 29146 overlap_sum 18353
Batch 1
label_sum 19453 pred_sum 21029 overlap_sum 16304
Batch 2
label_sum 19286 pred_sum 10758 overlap_sum 9505
Batch 3
label_sum 20708 pred_sum 11144 overlap_sum 9861
Batch 4
label_sum 21073 pred_sum 12725 overlap_sum 10707
Batch 5
label_sum 19582 pred_sum 28609 overlap_sum 18142
Batch 6
label_sum 19161 pred_sum 21538 overlap_sum 15481
Batch 7
label_sum 20005 pred_sum 10629 overlap_sum 8982
Batch 8
label_sum 18707 pred_sum 15723 overlap_sum 13386
Batch 9
label_sum 19435 pred_sum 20889 overlap_sum 16155
Batch 10
label_sum 18486 pred_sum 22068 overlap_sum 16139
Batch 11
label_sum 19524 pred_sum 15868 overlap_sum 13676
Batch 12
label_sum 19655 pred_sum 16774 overlap_sum 14319
Batch 13
label_sum 20346 pred_sum 23159 overlap_sum 17565
Batch 14
label_sum 19963 pred_sum 20095 overlap_sum 15965
Batch 15
label_sum 20673 pred_sum 16624 overlap_sum 14877
Batch 16
label_sum 18831 pred_sum 18962 overlap_sum 14278
Batch 17
label_sum 18553 pred_sum 21839 overlap_sum 16579
Batch 18
label_sum 21036 pred_sum 16080 overlap_sum 13778
Batch 19
label_sum 19219 pred_sum 19611 overlap_sum 16002
Batch 20
label_sum 18653 pred_sum 21800 overlap_sum 16548
Batch 21
label_sum 19450 pred_sum 18788 overlap_sum 15974
Batch 22
label_sum 20742 pred_sum 15188 overlap_sum 13781
Batch 23
label_sum 18806 pred_sum 18781 overlap_sum 15668
Batch 24
label_sum 19907 pred_sum 22217 overlap_sum 17898
Batch 25
label_sum 18737 pred_sum 21804 overlap_sum 16960
Batch 26
label_sum 20334 pred_sum 18253 overlap_sum 15771
Batch 27
label_sum 20472 pred_sum 18250 overlap_sum 15439
Batch 28
label_sum 19728 pred_sum 21485 overlap_sum 16982
Batch 29
label_sum 20615 pred_sum 22933 overlap_sum 17895
Batch 30
label_sum 19567 pred_sum 20096 overlap_sum 16175
Batch 31
label_sum 4216 pred_sum 4026 overlap_sum 3520
epoch: 1, accuracy: 0.6395575400852446, loss: 0.3033404001034796
==============================
Epoch 2 / 10
Batch 0
label_sum 19901 pred_sum 17079 overlap_sum 15078
Batch 1
label_sum 19453 pred_sum 19785 overlap_sum 16534
Batch 2
label_sum 19286 pred_sum 22282 overlap_sum 17354
Batch 3
label_sum 20708 pred_sum 21810 overlap_sum 18179
Batch 4
label_sum 21073 pred_sum 18735 overlap_sum 16269
Batch 5
label_sum 19582 pred_sum 19617 overlap_sum 16536
Batch 6
label_sum 19161 pred_sum 20054 overlap_sum 15789
Batch 7
label_sum 20005 pred_sum 20432 overlap_sum 17099
Batch 8
label_sum 18707 pred_sum 20846 overlap_sum 16289
Batch 9
label_sum 19435 pred_sum 18201 overlap_sum 15664
Batch 10
label_sum 18486 pred_sum 16435 overlap_sum 14199
Batch 11
label_sum 19524 pred_sum 18662 overlap_sum 15963
Batch 12
label_sum 19655 pred_sum 21832 overlap_sum 17395
Batch 13
label_sum 20346 pred_sum 21396 overlap_sum 17679
Batch 14
label_sum 19963 pred_sum 19312 overlap_sum 16187
Batch 15
label_sum 20673 pred_sum 19247 overlap_sum 17042
Batch 16
label_sum 18831 pred_sum 18496 overlap_sum 14390
Batch 17
label_sum 18553 pred_sum 20289 overlap_sum 16466
Batch 18
label_sum 21036 pred_sum 19743 overlap_sum 17089
Batch 19
label_sum 19219 pred_sum 20019 overlap_sum 16693
Batch 20
label_sum 18653 pred_sum 19754 overlap_sum 16478
Batch 21
label_sum 19450 pred_sum 19337 overlap_sum 16460
Batch 22
label_sum 20742 pred_sum 19033 overlap_sum 16950
Batch 23
label_sum 18806 pred_sum 19532 overlap_sum 16441
Batch 24
label_sum 19907 pred_sum 20577 overlap_sum 17690
Batch 25
label_sum 18737 pred_sum 20237 overlap_sum 16665
Batch 26
label_sum 20334 pred_sum 18704 overlap_sum 16287
Batch 27
label_sum 20472 pred_sum 20029 overlap_sum 16860
Batch 28
label_sum 19728 pred_sum 21178 overlap_sum 17303
Batch 29
label_sum 20615 pred_sum 21401 overlap_sum 17912
Batch 30
label_sum 19567 pred_sum 18793 overlap_sum 16020
Batch 31
label_sum 4216 pred_sum 4415 overlap_sum 3784
epoch: 2, accuracy: 0.7223733542836853, loss: 0.22777789225801826
==============================
Epoch 3 / 10
Batch 0
label_sum 19901 pred_sum 19658 overlap_sum 16884
Batch 1
label_sum 19453 pred_sum 20159 overlap_sum 17205
Batch 2
label_sum 19286 pred_sum 20232 overlap_sum 17062
Batch 3
label_sum 20708 pred_sum 20966 overlap_sum 18202
Batch 4
label_sum 21073 pred_sum 19088 overlap_sum 16816
Batch 5
label_sum 19582 pred_sum 21167 overlap_sum 17476
Batch 6
label_sum 19161 pred_sum 20445 overlap_sum 16352
Batch 7
label_sum 20005 pred_sum 18055 overlap_sum 15942
Batch 8
label_sum 18707 pred_sum 20661 overlap_sum 16640
Batch 9
label_sum 19435 pred_sum 19786 overlap_sum 16880
Batch 10
label_sum 18486 pred_sum 17775 overlap_sum 15277
Batch 11
label_sum 19524 pred_sum 19027 overlap_sum 16574
Batch 12
label_sum 19655 pred_sum 21407 overlap_sum 17631
Batch 13
label_sum 20346 pred_sum 20860 overlap_sum 17867
Batch 14
label_sum 19963 pred_sum 19984 overlap_sum 16959
Batch 15
label_sum 20673 pred_sum 20839 overlap_sum 18217
Batch 16
label_sum 18831 pred_sum 18840 overlap_sum 14951
Batch 17
label_sum 18553 pred_sum 19587 overlap_sum 16526
Batch 18
label_sum 21036 pred_sum 19829 overlap_sum 17501
Batch 19
label_sum 19219 pred_sum 20964 overlap_sum 17390
Batch 20
label_sum 18653 pred_sum 19992 overlap_sum 16942
Batch 21
label_sum 19450 pred_sum 18968 overlap_sum 16541
Batch 22
label_sum 20742 pred_sum 19494 overlap_sum 17546
Batch 23
label_sum 18806 pred_sum 20196 overlap_sum 17085
Batch 24
label_sum 19907 pred_sum 21018 overlap_sum 18205
Batch 25
label_sum 18737 pred_sum 20535 overlap_sum 17190
Batch 26
label_sum 20334 pred_sum 18884 overlap_sum 16763
Batch 27
label_sum 20472 pred_sum 20133 overlap_sum 17302
Batch 28
label_sum 19728 pred_sum 21958 overlap_sum 17898
Batch 29
label_sum 20615 pred_sum 20901 overlap_sum 17992
Batch 30
label_sum 19567 pred_sum 18310 overlap_sum 16143
Batch 31
label_sum 4216 pred_sum 4609 overlap_sum 3921
epoch: 3, accuracy: 0.7520172607105339, loss: 0.2053607814013958
==============================
Epoch 4 / 10
Batch 0
label_sum 19901 pred_sum 21267 overlap_sum 18045
Batch 1
label_sum 19453 pred_sum 19796 overlap_sum 17338
Batch 2
label_sum 19286 pred_sum 19555 overlap_sum 17228
Batch 3
label_sum 20708 pred_sum 21759 overlap_sum 18937
Batch 4
label_sum 21073 pred_sum 20789 overlap_sum 18232
Batch 5
label_sum 19582 pred_sum 20443 overlap_sum 17542
Batch 6
label_sum 19161 pred_sum 20040 overlap_sum 16533
Batch 7
label_sum 20005 pred_sum 19059 overlap_sum 17108
Batch 8
label_sum 18707 pred_sum 20821 overlap_sum 17025
Batch 9
label_sum 19435 pred_sum 20006 overlap_sum 17390
Batch 10
label_sum 18486 pred_sum 17045 overlap_sum 15275
Batch 11
label_sum 19524 pred_sum 19553 overlap_sum 17370
Batch 12
label_sum 19655 pred_sum 22955 overlap_sum 18670
Batch 13
label_sum 20346 pred_sum 20382 overlap_sum 18131
Batch 14
label_sum 19963 pred_sum 18563 overlap_sum 16785
Batch 15
label_sum 20673 pred_sum 21113 overlap_sum 18679
Batch 16
label_sum 18831 pred_sum 20848 overlap_sum 16520
Batch 17
label_sum 18553 pred_sum 19243 overlap_sum 16872
Batch 18
label_sum 21036 pred_sum 18654 overlap_sum 17072
Batch 19
label_sum 19219 pred_sum 20977 overlap_sum 17796
Batch 20
label_sum 18653 pred_sum 20530 overlap_sum 17587
Batch 21
label_sum 19450 pred_sum 20105 overlap_sum 17750
Batch 22
label_sum 20742 pred_sum 20503 overlap_sum 18577
Batch 23
label_sum 18806 pred_sum 19156 overlap_sum 17086
Batch 24
label_sum 19907 pred_sum 20356 overlap_sum 18306
Batch 25
label_sum 18737 pred_sum 20730 overlap_sum 17702
Batch 26
label_sum 20334 pred_sum 20507 overlap_sum 18241
Batch 27
label_sum 20472 pred_sum 19908 overlap_sum 17886
Batch 28
label_sum 19728 pred_sum 20593 overlap_sum 17921
Batch 29
label_sum 20615 pred_sum 20807 overlap_sum 18601
Batch 30
label_sum 19567 pred_sum 20288 overlap_sum 17851
Batch 31
label_sum 4216 pred_sum 4503 overlap_sum 3966
epoch: 4, accuracy: 0.7906522764124797, loss: 0.17321861116215587
==============================
Epoch 5 / 10
Batch 0
label_sum 19901 pred_sum 20541 overlap_sum 18227
Batch 1
label_sum 19453 pred_sum 18644 overlap_sum 17199
Batch 2
label_sum 19286 pred_sum 21961 overlap_sum 18382
Batch 3
label_sum 20708 pred_sum 21000 overlap_sum 19091
Batch 4
label_sum 21073 pred_sum 17842 overlap_sum 16775
Batch 5
label_sum 19582 pred_sum 21278 overlap_sum 18383
Batch 6
label_sum 19161 pred_sum 24410 overlap_sum 18417
Batch 7
label_sum 20005 pred_sum 17109 overlap_sum 16149
Batch 8
label_sum 18707 pred_sum 16233 overlap_sum 15175
Batch 9
label_sum 19435 pred_sum 19775 overlap_sum 17921
Batch 10
label_sum 18486 pred_sum 20807 overlap_sum 17781
Batch 11
label_sum 19524 pred_sum 21682 overlap_sum 18788
Batch 12
label_sum 19655 pred_sum 22209 overlap_sum 18915
Batch 13
label_sum 20346 pred_sum 19113 overlap_sum 17930
Batch 14
label_sum 19963 pred_sum 17738 overlap_sum 16822
Batch 15
label_sum 20673 pred_sum 19800 overlap_sum 18444
Batch 16
label_sum 18831 pred_sum 19921 overlap_sum 17140
Batch 17
label_sum 18553 pred_sum 21102 overlap_sum 17942
Batch 18
label_sum 21036 pred_sum 21441 overlap_sum 19325
Batch 19
label_sum 19219 pred_sum 19879 overlap_sum 17759
Batch 20
label_sum 18653 pred_sum 17803 overlap_sum 16719
Batch 21
label_sum 19450 pred_sum 19189 overlap_sum 17730
Batch 22
label_sum 20742 pred_sum 21738 overlap_sum 19630
Batch 23
label_sum 18806 pred_sum 20152 overlap_sum 17839
Batch 24
label_sum 19907 pred_sum 20370 overlap_sum 18711
Batch 25
label_sum 18737 pred_sum 18866 overlap_sum 17270
Batch 26
label_sum 20334 pred_sum 20139 overlap_sum 18565
Batch 27
label_sum 20472 pred_sum 21946 overlap_sum 19499
Batch 28
label_sum 19728 pred_sum 20825 overlap_sum 18444
Batch 29
label_sum 20615 pred_sum 19757 overlap_sum 18493
Batch 30
label_sum 19567 pred_sum 18810 overlap_sum 17611
Batch 31
label_sum 4216 pred_sum 4624 overlap_sum 4072
epoch: 5, accuracy: 0.8247567535788823, loss: 0.14396312390454113
==============================
Epoch 6 / 10
Batch 0
label_sum 19901 pred_sum 22198 overlap_sum 19054
Batch 1
label_sum 19453 pred_sum 18375 overlap_sum 17325
Batch 2
label_sum 19286 pred_sum 18180 overlap_sum 17183
Batch 3
label_sum 20708 pred_sum 21474 overlap_sum 19498
Batch 4
label_sum 21073 pred_sum 22868 overlap_sum 20010
Batch 5
label_sum 19582 pred_sum 20096 overlap_sum 18379
Batch 6
label_sum 19161 pred_sum 18730 overlap_sum 17079
Batch 7
label_sum 20005 pred_sum 19474 overlap_sum 18161
Batch 8
label_sum 18707 pred_sum 19737 overlap_sum 17634
Batch 9
label_sum 19435 pred_sum 20626 overlap_sum 18469
Batch 10
label_sum 18486 pred_sum 18633 overlap_sum 17066
Batch 11
label_sum 19524 pred_sum 18075 overlap_sum 17377
Batch 12
label_sum 19655 pred_sum 19949 overlap_sum 18237
Batch 13
label_sum 20346 pred_sum 19536 overlap_sum 18449
Batch 14
label_sum 19963 pred_sum 19629 overlap_sum 18404
Batch 15
label_sum 20673 pred_sum 21540 overlap_sum 19641
Batch 16
label_sum 18831 pred_sum 19335 overlap_sum 17164
Batch 17
label_sum 18553 pred_sum 19795 overlap_sum 17636
Batch 18
label_sum 21036 pred_sum 21121 overlap_sum 19379
Batch 19
label_sum 19219 pred_sum 20063 overlap_sum 18079
Batch 20
label_sum 18653 pred_sum 18473 overlap_sum 17272
Batch 21
label_sum 19450 pred_sum 19347 overlap_sum 17950
Batch 22
label_sum 20742 pred_sum 21029 overlap_sum 19441
Batch 23
label_sum 18806 pred_sum 18759 overlap_sum 17476
Batch 24
label_sum 19907 pred_sum 19390 overlap_sum 18420
Batch 25
label_sum 18737 pred_sum 18485 overlap_sum 17290
Batch 26
label_sum 20334 pred_sum 20620 overlap_sum 19128
Batch 27
label_sum 20472 pred_sum 22065 overlap_sum 19786
Batch 28
label_sum 19728 pred_sum 19594 overlap_sum 18187
Batch 29
label_sum 20615 pred_sum 18937 overlap_sum 18111
Batch 30
label_sum 19567 pred_sum 19889 overlap_sum 18298
Batch 31
label_sum 4216 pred_sum 4636 overlap_sum 4104
epoch: 6, accuracy: 0.8556492614092926, loss: 0.11515915510244668
==============================
Epoch 7 / 10
Batch 0
label_sum 19901 pred_sum 21851 overlap_sum 19191
Batch 1
label_sum 19453 pred_sum 18204 overlap_sum 17409
Batch 2
label_sum 19286 pred_sum 17647 overlap_sum 17022
Batch 3
label_sum 20708 pred_sum 21249 overlap_sum 19512
Batch 4
label_sum 21073 pred_sum 23439 overlap_sum 20375
Batch 5
label_sum 19582 pred_sum 19451 overlap_sum 18198
Batch 6
label_sum 19161 pred_sum 17541 overlap_sum 16599
Batch 7
label_sum 20005 pred_sum 18957 overlap_sum 18146
Batch 8
label_sum 18707 pred_sum 20275 overlap_sum 18045
Batch 9
label_sum 19435 pred_sum 21007 overlap_sum 18733
Batch 10
label_sum 18486 pred_sum 18695 overlap_sum 17244
Batch 11
label_sum 19524 pred_sum 17811 overlap_sum 17155
Batch 12
label_sum 19655 pred_sum 20042 overlap_sum 18522
Batch 13
label_sum 20346 pred_sum 19989 overlap_sum 18825
Batch 14
label_sum 19963 pred_sum 19745 overlap_sum 18618
Batch 15
label_sum 20673 pred_sum 21015 overlap_sum 19566
Batch 16
label_sum 18831 pred_sum 18697 overlap_sum 17196
Batch 17
label_sum 18553 pred_sum 18738 overlap_sum 17337
Batch 18
label_sum 21036 pred_sum 20951 overlap_sum 19510
Batch 19
label_sum 19219 pred_sum 20031 overlap_sum 18234
Batch 20
label_sum 18653 pred_sum 18876 overlap_sum 17526
Batch 21
label_sum 19450 pred_sum 19334 overlap_sum 17994
Batch 22
label_sum 20742 pred_sum 20829 overlap_sum 19499
Batch 23
label_sum 18806 pred_sum 18532 overlap_sum 17452
Batch 24
label_sum 19907 pred_sum 19506 overlap_sum 18579
Batch 25
label_sum 18737 pred_sum 18597 overlap_sum 17496
Batch 26
label_sum 20334 pred_sum 20316 overlap_sum 19098
Batch 27
label_sum 20472 pred_sum 21437 overlap_sum 19702
Batch 28
label_sum 19728 pred_sum 19691 overlap_sum 18303
Batch 29
label_sum 20615 pred_sum 19248 overlap_sum 18494
Batch 30
label_sum 19567 pred_sum 19729 overlap_sum 18364
Batch 31
label_sum 4216 pred_sum 4546 overlap_sum 4096
epoch: 7, accuracy: 0.8683587345922642, loss: 0.10284013347700238
==============================
Epoch 8 / 10
Batch 0
label_sum 19901 pred_sum 20897 overlap_sum 18988
Batch 1
label_sum 19453 pred_sum 18544 overlap_sum 17731
Batch 2
label_sum 19286 pred_sum 18005 overlap_sum 17292
Batch 3
label_sum 20708 pred_sum 21193 overlap_sum 19598
Batch 4
label_sum 21073 pred_sum 22296 overlap_sum 20147
Batch 5
label_sum 19582 pred_sum 18795 overlap_sum 17916
Batch 6
label_sum 19161 pred_sum 17842 overlap_sum 16982
Batch 7
label_sum 20005 pred_sum 19569 overlap_sum 18557
Batch 8
label_sum 18707 pred_sum 20205 overlap_sum 18127
Batch 9
label_sum 19435 pred_sum 20220 overlap_sum 18658
Batch 10
label_sum 18486 pred_sum 18128 overlap_sum 17153
Batch 11
label_sum 19524 pred_sum 18059 overlap_sum 17460
Batch 12
label_sum 19655 pred_sum 20620 overlap_sum 18819
Batch 13
label_sum 20346 pred_sum 20315 overlap_sum 19048
Batch 14
label_sum 19963 pred_sum 19213 overlap_sum 18506
Batch 15
label_sum 20673 pred_sum 20262 overlap_sum 19315
Batch 16
label_sum 18831 pred_sum 18630 overlap_sum 17512
Batch 17
label_sum 18553 pred_sum 19412 overlap_sum 17680
Batch 18
label_sum 21036 pred_sum 21001 overlap_sum 19698
Batch 19
label_sum 19219 pred_sum 19319 overlap_sum 17989
Batch 20
label_sum 18653 pred_sum 18490 overlap_sum 17471
Batch 21
label_sum 19450 pred_sum 19471 overlap_sum 18187
Batch 22
label_sum 20742 pred_sum 21030 overlap_sum 19691
Batch 23
label_sum 18806 pred_sum 18681 overlap_sum 17666
Batch 24
label_sum 19907 pred_sum 19603 overlap_sum 18756
Batch 25
label_sum 18737 pred_sum 18351 overlap_sum 17433
Batch 26
label_sum 20334 pred_sum 20257 overlap_sum 19173
Batch 27
label_sum 20472 pred_sum 21517 overlap_sum 19790
Batch 28
label_sum 19728 pred_sum 19348 overlap_sum 18226
Batch 29
label_sum 20615 pred_sum 18927 overlap_sum 18362
Batch 30
label_sum 19567 pred_sum 20109 overlap_sum 18554
Batch 31
label_sum 4216 pred_sum 4484 overlap_sum 4093
epoch: 8, accuracy: 0.8798525049805601, loss: 0.09189422405324876
==============================
Epoch 9 / 10
Batch 0
label_sum 19901 pred_sum 20457 overlap_sum 19008
Batch 1
label_sum 19453 pred_sum 19078 overlap_sum 18156
Batch 2
label_sum 19286 pred_sum 18491 overlap_sum 17731
Batch 3
label_sum 20708 pred_sum 21211 overlap_sum 19710
Batch 4
label_sum 21073 pred_sum 21993 overlap_sum 20155
Batch 5
label_sum 19582 pred_sum 18494 overlap_sum 17808
Batch 6
label_sum 19161 pred_sum 18121 overlap_sum 17421
Batch 7
label_sum 20005 pred_sum 19430 overlap_sum 18581
Batch 8
label_sum 18707 pred_sum 19942 overlap_sum 18124
Batch 9
label_sum 19435 pred_sum 20090 overlap_sum 18722
Batch 10
label_sum 18486 pred_sum 18147 overlap_sum 17255
Batch 11
label_sum 19524 pred_sum 18357 overlap_sum 17806
Batch 12
label_sum 19655 pred_sum 20748 overlap_sum 18975
Batch 13
label_sum 20346 pred_sum 20347 overlap_sum 19210
Batch 14
label_sum 19963 pred_sum 19189 overlap_sum 18535
Batch 15
label_sum 20673 pred_sum 20267 overlap_sum 19374
Batch 16
label_sum 18831 pred_sum 18856 overlap_sum 17784
Batch 17
label_sum 18553 pred_sum 19130 overlap_sum 17713
Batch 18
label_sum 21036 pred_sum 20854 overlap_sum 19700
Batch 19
label_sum 19219 pred_sum 19128 overlap_sum 17992
Batch 20
label_sum 18653 pred_sum 18563 overlap_sum 17594
Batch 21
label_sum 19450 pred_sum 19442 overlap_sum 18242
Batch 22
label_sum 20742 pred_sum 20969 overlap_sum 19743
Batch 23
label_sum 18806 pred_sum 18503 overlap_sum 17661
Batch 24
label_sum 19907 pred_sum 19577 overlap_sum 18844
Batch 25
label_sum 18737 pred_sum 18557 overlap_sum 17636
Batch 26
label_sum 20334 pred_sum 20558 overlap_sum 19424
Batch 27
label_sum 20472 pred_sum 21335 overlap_sum 19812
Batch 28
label_sum 19728 pred_sum 19086 overlap_sum 18199
Batch 29
label_sum 20615 pred_sum 19231 overlap_sum 18733
Batch 30
label_sum 19567 pred_sum 20383 overlap_sum 18650
Batch 31
label_sum 4216 pred_sum 4361 overlap_sum 4057
epoch: 9, accuracy: 0.8906483882691372, loss: 0.0831722195725888
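The "accuracy" printed at the end of each epoch is the ratio of accumulated overlap to accumulated union, that is, the intersection over union (IoU) of the predicted and true car pixels. A minimal sketch of the same computation on a toy pair of masks follows; the helper `iou_counts` is a hypothetical name, and it uses boolean operations rather than the `pred_sum + label_sum - overlap_sum` identity from the training loop.

```python
import torch

def iou_counts(pred, target):
    """Return (overlap, union) pixel counts for two binary masks."""
    pred = pred.bool()
    target = target.bool()
    overlap = (pred & target).sum().item()
    union = (pred | target).sum().item()
    return overlap, union

pred = torch.tensor([[1, 1, 0],
                     [0, 1, 0]])
target = torch.tensor([[1, 0, 0],
                       [0, 1, 1]])

overlap, union = iou_counts(pred, target)
# Guard against empty masks, for which the IoU is undefined
iou = overlap / union if union > 0 else float('nan')
print(overlap, union, iou)  # 2 4 0.5
```

Both formulations give the same union here: 3 predicted pixels plus 3 labelled pixels minus 2 shared pixels equals 4.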
18.7. Plot the loss curve#
plt.figure()
plt.plot(epoch_loss)
plt.title('Epoch Loss')
Text(0.5, 1.0, 'Epoch Loss')
18.8. Plot an Example#
plt.figure(1)
# Plot the actual label
plt.subplot(1,2,1)
plt.imshow(Y[0].cpu())
plt.clim((0,1))
plt.title('True Mask')
# Show the predicted label
plt.subplot(1,2,2)
plt.imshow(pred[0].cpu())
plt.clim((0,1))
plt.title('Predicted Mask')
plt.show()
18.9. Exercise#
Complete this notebook by dividing your dataset into training, validation, and test sets.
Show the learning curves for the training and validation sets.
Retrain your network and run the segmentation on the test set.
How do you assess the quality of your results?
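As a possible starting point for the first step, the sketch below splits a dataset of 499 items (the length reported for `train_dataset`) into three disjoint subsets with `torch.utils.data.random_split`. The 70/15/15 proportions, the fixed seed, and the `TensorDataset` stand-in are assumptions made for illustration, not part of the original notebook.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the real dataset: 499 dummy items
dataset = TensorDataset(torch.zeros(499, 1))

n_total = len(dataset)
n_val = int(0.15 * n_total)
n_test = int(0.15 * n_total)
n_train = n_total - n_val - n_test  # the remainder goes to training

# A seeded generator makes the split reproducible across runs
generator = torch.Generator().manual_seed(0)
train_set, val_set, test_set = random_split(
    dataset, [n_train, n_val, n_test], generator=generator
)
print(len(train_set), len(val_set), len(test_set))  # 351 74 74
```

Each subset can be wrapped in its own `DataLoader`. The mask filenames suggest several views per car id, so splitting by car id rather than by image may give a more honest test score, since views of the same car would otherwise appear in both training and test data.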