13. PyTorch Pre-Flight

from IPython.display import Image as IPythonImage
%matplotlib inline

Performance Challenge

Specification     Intel Core i9-13900K                                NVIDIA RTX 4090
----------------  --------------------------------------------------  --------------------------------------
Architecture      Hybrid design (8 P-cores + 16 E-cores)              Ada Lovelace
Cores             24 total (8 P-cores + 16 E-cores), 32 threads       16,384 CUDA cores
Clock Frequency   ~3.0 GHz base (P-cores); ~2.2 GHz base (E-cores)    ~2.23 GHz (base), ~2.52 GHz (boost)
Memory Bandwidth  System DDR5 (~50–60+ GB/s)                          1,008 GB/s (24 GB GDDR6X, 384-bit bus)
FP32 Performance  ~500 GFLOPS                                         ~82.6 TFLOPS
Cost              $589–$600 USD                                       $1,599 USD (MSRP)

13.1. Tensors

In data science, a tensor is a multi-dimensional array. It generalizes scalars (0-dimensional), vectors (1-dimensional), and matrices (2-dimensional) to any number of dimensions: a 3-D tensor can hold a color image (height × width × color channels), and a 4-D tensor can hold a batch of such images.
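To see those dimensions concretely, here is a minimal sketch (assuming PyTorch is installed as described in Section 13.2) that builds a tensor of each rank and prints `ndim` (the number of dimensions) and `shape` (the size along each dimension). The variable names are illustrative.

```python
import torch

scalar = torch.tensor(3.0)                       # 0-D: a single number
vector = torch.tensor([1.0, 2.0, 3.0])           # 1-D: shape (3,)
matrix = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # 2-D: shape (2, 2)
cube = torch.zeros(2, 3, 4)                      # 3-D: e.g. two stacked 3x4 matrices

for name, t in [("scalar", scalar), ("vector", vector), ("matrix", matrix), ("cube", cube)]:
    # ndim is the number of dimensions (the rank); shape lists the size of each one
    print(f"{name}: ndim={t.ndim}, shape={tuple(t.shape)}")
```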


13.2. Installing PyTorch

#--- check the NVIDIA driver; nvidia-smi also reports the highest CUDA version the driver supports

!nvidia-smi
/bin/bash: line 1: nvidia-smi: command not found
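The `command not found` message above means this runtime has no NVIDIA driver installed, so the CPU build is the right choice here. A minimal sketch for making that decision programmatically, assuming a Linux-style environment; `suggested_install_type` is a hypothetical name chosen so it does not overwrite the `install_type` variable set in the next cell.

```python
import shutil
import subprocess

# nvidia-smi ships with the NVIDIA driver, so its absence usually means no usable NVIDIA GPU
if shutil.which("nvidia-smi") is not None:
    # Ask nvidia-smi for just the GPU name and driver version, one GPU per line
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    print(result.stdout.strip())
    suggested_install_type = "GPU"
else:
    suggested_install_type = "CPU"

print("Suggested install_type:", suggested_install_type)
```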
# choose between "CPU" or "GPU"

install_type = "CPU"
cuda_version = "cu124"  # CUDA wheel tag (cu124 = CUDA 12.4); must be supported by the driver ("CUDA Version" in nvidia-smi)
#--- torch==2.6.0 pins a specific release; remove the version pin to install the most recent one

#--- for CPU installation

if install_type == "CPU":

    %pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

#--- for GPU installation (the index URL hosts PyTorch wheels compiled against the chosen CUDA version)

elif install_type == "GPU":

    %pip install torch==2.6.0+{cuda_version} torchvision torchaudio --index-url https://download.pytorch.org/whl/{cuda_version}

Looking in indexes: https://download.pytorch.org/whl/cpu
Collecting torch==2.6.0
  Downloading https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl.metadata (26 kB)
Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.20.1+cu124)
Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.5.1+cu124)
Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0) (3.17.0)
Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0) (4.12.2)
Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0) (3.4.2)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0) (3.1.5)
Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0) (2024.10.0)
Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0) (1.13.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch==2.6.0) (1.3.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)
INFO: pip is looking at multiple versions of torchvision to determine which version is compatible with other requirements. This could take a while.
Collecting torchvision
  Downloading https://download.pytorch.org/whl/cpu/torchvision-0.21.0%2Bcpu-cp311-cp311-linux_x86_64.whl.metadata (6.1 kB)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.1.0)
INFO: pip is looking at multiple versions of torchaudio to determine which version is compatible with other requirements. This could take a while.
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/cpu/torchaudio-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl.metadata (6.6 kB)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch==2.6.0) (3.0.2)
Downloading https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl (178.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 178.7/178.7 MB 7.0 MB/s eta 0:00:00
Downloading https://download.pytorch.org/whl/cpu/torchvision-0.21.0%2Bcpu-cp311-cp311-linux_x86_64.whl (1.8 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 49.2 MB/s eta 0:00:00
Downloading https://download.pytorch.org/whl/cpu/torchaudio-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl (1.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 48.2 MB/s eta 0:00:00
Installing collected packages: torch, torchvision, torchaudio
  Attempting uninstall: torch
    Found existing installation: torch 2.5.1+cu124
    Uninstalling torch-2.5.1+cu124:
      Successfully uninstalled torch-2.5.1+cu124
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.20.1+cu124
    Uninstalling torchvision-0.20.1+cu124:
      Successfully uninstalled torchvision-0.20.1+cu124
  Attempting uninstall: torchaudio
    Found existing installation: torchaudio 2.5.1+cu124
    Uninstalling torchaudio-2.5.1+cu124:
      Successfully uninstalled torchaudio-2.5.1+cu124
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
fastai 2.7.18 requires torch<2.6,>=1.10, but you have torch 2.6.0+cpu which is incompatible.
Successfully installed torch-2.6.0+cpu torchaudio-2.6.0+cpu torchvision-0.21.0+cpu
%pip list
Package                            Version
---------------------------------- -------------------
absl-py                            1.4.0
accelerate                         1.3.0
aiohappyeyeballs                   2.4.6
aiohttp                            3.11.12
aiosignal                          1.3.2
alabaster                          1.0.0
albucore                           0.0.23
albumentations                     2.0.4
ale-py                             0.10.1
altair                             5.5.0
annotated-types                    0.7.0
anyio                              3.7.1
argon2-cffi                        23.1.0
argon2-cffi-bindings               21.2.0
array_record                       0.6.0
arviz                              0.20.0
astropy                            7.0.1
astropy-iers-data                  0.2025.2.10.0.33.26
astunparse                         1.6.3
atpublic                           4.1.0
attrs                              25.1.0
audioread                          3.0.1
autograd                           1.7.0
babel                              2.17.0
backcall                           0.2.0
beautifulsoup4                     4.13.3
betterproto                        2.0.0b6
bigframes                          1.36.0
bigquery-magics                    0.5.0
bleach                             6.2.0
blinker                            1.9.0
blis                               0.7.11
blosc2                             3.0.0
bokeh                              3.6.3
Bottleneck                         1.4.2
bqplot                             0.12.44
branca                             0.8.1
CacheControl                       0.14.2
cachetools                         5.5.1
catalogue                          2.0.10
certifi                            2025.1.31
cffi                               1.17.1
chardet                            5.2.0
charset-normalizer                 3.4.1
chex                               0.1.88
clarabel                           0.10.0
click                              8.1.8
cloudpathlib                       0.20.0
cloudpickle                        3.1.1
cmake                              3.31.4
cmdstanpy                          1.2.5
colorcet                           3.1.0
colorlover                         0.3.0
colour                             0.1.5
community                          1.0.0b1
confection                         0.1.5
cons                               0.4.6
contourpy                          1.3.1
cramjam                            2.9.1
cryptography                       43.0.3
cuda-python                        12.6.0
cudf-cu12                          24.12.0
cufflinks                          0.17.3
cupy-cuda12x                       13.3.0
cvxopt                             1.3.2
cvxpy                              1.6.0
cycler                             0.12.1
cyipopt                            1.5.0
cymem                              2.0.11
Cython                             3.0.12
dask                               2024.10.0
datascience                        0.17.6
db-dtypes                          1.4.1
dbus-python                        1.2.18
debugpy                            1.8.0
decorator                          4.4.2
defusedxml                         0.7.1
Deprecated                         1.2.18
diffusers                          0.32.2
distro                             1.9.0
dlib                               19.24.2
dm-tree                            0.1.9
dnspython                          2.7.0
docker-pycreds                     0.4.0
docstring_parser                   0.16
docutils                           0.21.2
dopamine_rl                        4.1.2
duckdb                             1.1.3
earthengine-api                    1.5.2
easydict                           1.13
editdistance                       0.8.1
eerepr                             0.1.0
einops                             0.8.1
email_validator                    2.2.0
en-core-web-sm                     3.7.1
entrypoints                        0.4
et_xmlfile                         2.0.0
etils                              1.12.0
etuples                            0.3.9
Farama-Notifications               0.0.4
fastai                             2.7.18
fastcore                           1.7.29
fastdownload                       0.0.7
fastjsonschema                     2.21.1
fastprogress                       1.0.3
fastrlock                          0.8.3
filelock                           3.17.0
firebase-admin                     6.6.0
Flask                              3.1.0
flatbuffers                        25.2.10
flax                               0.10.3
folium                             0.19.4
fonttools                          4.56.0
frozendict                         2.4.6
frozenlist                         1.5.0
fsspec                             2024.10.0
future                             1.0.0
gast                               0.6.0
gcsfs                              2024.10.0
GDAL                               3.6.4
gdown                              5.2.0
geemap                             0.35.1
gensim                             4.3.3
geocoder                           1.38.1
geographiclib                      2.0
geopandas                          1.0.1
geopy                              2.4.1
gin-config                         0.5.0
gitdb                              4.0.12
GitPython                          3.1.44
glob2                              0.7
google                             2.0.3
google-ai-generativelanguage       0.6.15
google-api-core                    2.19.2
google-api-python-client           2.160.0
google-auth                        2.27.0
google-auth-httplib2               0.2.0
google-auth-oauthlib               1.2.1
google-cloud-aiplatform            1.79.0
google-cloud-bigquery              3.25.0
google-cloud-bigquery-connection   1.17.0
google-cloud-bigquery-storage      2.28.0
google-cloud-bigtable              2.28.1
google-cloud-core                  2.4.1
google-cloud-dataproc              5.16.0
google-cloud-datastore             2.20.2
google-cloud-firestore             2.20.0
google-cloud-functions             1.19.0
google-cloud-iam                   2.17.0
google-cloud-language              2.16.0
google-cloud-pubsub                2.25.0
google-cloud-resource-manager      1.14.0
google-cloud-spanner               3.51.0
google-cloud-storage               2.19.0
google-cloud-translate             3.19.0
google-colab                       1.0.0
google-crc32c                      1.6.0
google-genai                       0.8.0
google-generativeai                0.8.4
google-pasta                       0.2.0
google-resumable-media             2.7.2
google-spark-connect               0.5.2
googleapis-common-protos           1.66.0
googledrivedownloader              1.1.0
graphviz                           0.20.3
greenlet                           3.1.1
grpc-google-iam-v1                 0.14.0
grpc-interceptor                   0.15.4
grpcio                             1.70.0
grpcio-status                      1.62.3
grpclib                            0.4.7
gspread                            6.1.4
gspread-dataframe                  4.0.0
gym                                0.25.2
gym-notices                        0.0.8
gymnasium                          1.0.0
h11                                0.14.0
h2                                 4.2.0
h5netcdf                           1.5.0
h5py                               3.12.1
highspy                            1.9.0
holidays                           0.66
holoviews                          1.20.0
hpack                              4.1.0
html5lib                           1.1
httpcore                           1.0.7
httpimport                         1.4.0
httplib2                           0.22.0
httpx                              0.28.1
huggingface-hub                    0.28.1
humanize                           4.11.0
hyperframe                         6.1.0
hyperopt                           0.2.7
ibis-framework                     9.2.0
id                                 1.5.0
idna                               3.10
imageio                            2.37.0
imageio-ffmpeg                     0.6.0
imagesize                          1.4.1
imbalanced-learn                   0.13.0
imgaug                             0.4.0
immutabledict                      4.2.1
importlib_metadata                 8.6.1
importlib_resources                6.5.2
imutils                            0.5.4
in-toto-attestation                0.9.3
inflect                            7.5.0
iniconfig                          2.0.0
intel-cmplr-lib-ur                 2025.0.4
intel-openmp                       2025.0.4
ipyevents                          2.0.2
ipyfilechooser                     0.6.0
ipykernel                          5.5.6
ipyleaflet                         0.19.2
ipyparallel                        8.8.0
ipython                            7.34.0
ipython-genutils                   0.2.0
ipython-sql                        0.5.0
ipytree                            0.2.2
ipywidgets                         7.7.1
itsdangerous                       2.2.0
jax                                0.4.33
jax-cuda12-pjrt                    0.4.33
jax-cuda12-plugin                  0.4.33
jaxlib                             0.4.33
jeepney                            0.7.1
jellyfish                          1.1.0
jieba                              0.42.1
Jinja2                             3.1.5
jiter                              0.8.2
joblib                             1.4.2
jsonpatch                          1.33
jsonpickle                         4.0.1
jsonpointer                        3.0.0
jsonschema                         4.23.0
jsonschema-specifications          2024.10.1
jupyter-client                     6.1.12
jupyter-console                    6.1.0
jupyter_core                       5.7.2
jupyter-leaflet                    0.19.2
jupyter-server                     1.24.0
jupyterlab_pygments                0.3.0
jupyterlab_widgets                 3.0.13
kaggle                             1.6.17
kagglehub                          0.3.7
keras                              3.8.0
keras-hub                          0.18.1
keras-nlp                          0.18.1
keyring                            23.5.0
kiwisolver                         1.4.8
langchain                          0.3.18
langchain-core                     0.3.35
langchain-text-splitters           0.3.6
langcodes                          3.5.0
langsmith                          0.3.8
language_data                      1.3.0
launchpadlib                       1.10.16
lazr.restfulclient                 0.14.4
lazr.uri                           1.0.6
lazy_loader                        0.4
libclang                           18.1.1
libcudf-cu12                       24.12.0
libkvikio-cu12                     24.12.1
librosa                            0.10.2.post1
lightgbm                           4.5.0
linkify-it-py                      2.0.3
llvmlite                           0.44.0
locket                             1.0.0
logical-unification                0.4.6
lxml                               5.3.1
marisa-trie                        1.2.1
Markdown                           3.7
markdown-it-py                     3.0.0
MarkupSafe                         3.0.2
matplotlib                         3.10.0
matplotlib-inline                  0.1.7
matplotlib-venn                    1.1.1
mdit-py-plugins                    0.4.2
mdurl                              0.1.2
miniKanren                         1.0.3
missingno                          0.5.2
mistune                            3.1.1
mizani                             0.13.1
mkl                                2025.0.1
ml-dtypes                          0.4.1
mlxtend                            0.23.4
model-signing                      0.2.0
more-itertools                     10.6.0
moviepy                            1.0.3
mpmath                             1.3.0
msgpack                            1.1.0
multidict                          6.1.0
multipledispatch                   1.0.0
multitasking                       0.0.11
murmurhash                         1.0.12
music21                            9.3.0
namex                              0.0.8
narwhals                           1.26.0
natsort                            8.4.0
nbclassic                          1.2.0
nbclient                           0.10.2
nbconvert                          7.16.6
nbformat                           5.10.4
ndindex                            1.9.2
nest-asyncio                       1.6.0
networkx                           3.4.2
nibabel                            5.3.2
nltk                               3.9.1
notebook                           6.5.5
notebook_shim                      0.2.4
numba                              0.61.0
numba-cuda                         0.0.17.1
numexpr                            2.10.2
numpy                              1.26.4
nvidia-cublas-cu12                 12.5.3.2
nvidia-cuda-cupti-cu12             12.5.82
nvidia-cuda-nvcc-cu12              12.5.82
nvidia-cuda-nvrtc-cu12             12.5.82
nvidia-cuda-runtime-cu12           12.5.82
nvidia-cudnn-cu12                  9.3.0.75
nvidia-cufft-cu12                  11.2.3.61
nvidia-curand-cu12                 10.3.6.82
nvidia-cusolver-cu12               11.6.3.83
nvidia-cusparse-cu12               12.5.1.3
nvidia-nccl-cu12                   2.21.5
nvidia-nvcomp-cu12                 4.1.0.6
nvidia-nvjitlink-cu12              12.5.82
nvidia-nvtx-cu12                   12.4.127
nvtx                               0.2.10
nx-cugraph-cu12                    24.12.0
oauth2client                       4.1.3
oauthlib                           3.2.2
openai                             1.61.1
opencv-contrib-python              4.11.0.86
opencv-python                      4.11.0.86
opencv-python-headless             4.11.0.86
openpyxl                           3.1.5
opentelemetry-api                  1.16.0
opentelemetry-sdk                  1.16.0
opentelemetry-semantic-conventions 0.37b0
opt_einsum                         3.4.0
optax                              0.2.4
optree                             0.14.0
orbax-checkpoint                   0.6.4
orjson                             3.10.15
osqp                               0.6.7.post3
packaging                          24.2
pandas                             2.2.2
pandas-datareader                  0.10.0
pandas-gbq                         0.26.1
pandas-stubs                       2.2.2.240909
pandocfilters                      1.5.1
panel                              1.6.0
param                              2.2.0
parso                              0.8.4
parsy                              2.1
partd                              1.4.2
pathlib                            1.0.1
patsy                              1.0.1
peewee                             3.17.9
peft                               0.14.0
pexpect                            4.9.0
pickleshare                        0.7.5
pillow                             11.1.0
pip                                24.1.2
platformdirs                       4.3.6
plotly                             5.24.1
plotnine                           0.14.5
pluggy                             1.5.0
ply                                3.11
polars                             1.9.0
pooch                              1.8.2
portpicker                         1.5.2
preshed                            3.0.9
prettytable                        3.14.0
proglog                            0.1.10
progressbar2                       4.5.0
prometheus_client                  0.21.1
promise                            2.3
prompt_toolkit                     3.0.50
propcache                          0.2.1
prophet                            1.1.6
proto-plus                         1.26.0
protobuf                           4.25.6
psutil                             5.9.5
psycopg2                           2.9.10
ptyprocess                         0.7.0
py-cpuinfo                         9.0.0
py4j                               0.10.9.7
pyarrow                            17.0.0
pyasn1                             0.6.1
pyasn1_modules                     0.4.1
pycocotools                        2.0.8
pycparser                          2.22
pydantic                           2.10.6
pydantic_core                      2.27.2
pydata-google-auth                 1.9.1
pydot                              3.0.4
pydotplus                          2.0.2
PyDrive                            1.3.1
PyDrive2                           1.21.3
pyerfa                             2.0.1.5
pygame                             2.6.1
pygit2                             1.17.0
Pygments                           2.18.0
PyGObject                          3.42.1
PyJWT                              2.10.1
pylibcudf-cu12                     24.12.0
pylibcugraph-cu12                  24.12.0
pylibraft-cu12                     24.12.0
pymc                               5.20.1
pymystem3                          0.2.0
pynvjitlink-cu12                   0.5.0
pyogrio                            0.10.0
Pyomo                              6.8.2
PyOpenGL                           3.1.9
pyOpenSSL                          24.2.1
pyparsing                          3.2.1
pyperclip                          1.9.0
pyproj                             3.7.0
pyshp                              2.3.1
PySocks                            1.7.1
pyspark                            3.5.4
pytensor                           2.27.1
pytest                             8.3.4
python-apt                         0.0.0
python-box                         7.3.2
python-dateutil                    2.8.2
python-louvain                     0.16
python-slugify                     8.0.4
python-snappy                      0.7.3
python-utils                       3.9.1
pytz                               2025.1
pyviz_comms                        3.0.4
PyYAML                             6.0.2
pyzmq                              24.0.1
qdldl                              0.1.7.post5
ratelim                            0.1.6
referencing                        0.36.2
regex                              2024.11.6
requests                           2.32.3
requests-oauthlib                  2.0.0
requests-toolbelt                  1.0.0
requirements-parser                0.9.0
rfc3161-client                     0.1.2
rfc8785                            0.1.4
rich                               13.9.4
rmm-cu12                           24.12.1
rpds-py                            0.22.3
rpy2                               3.4.2
rsa                                4.9
safetensors                        0.5.2
scikit-image                       0.25.1
scikit-learn                       1.6.1
scipy                              1.13.1
scooby                             0.10.0
scs                                3.2.7.post2
seaborn                            0.13.2
SecretStorage                      3.3.1
securesystemslib                   1.2.0
Send2Trash                         1.8.3
sentence-transformers              3.4.1
sentencepiece                      0.2.0
sentry-sdk                         2.21.0
setproctitle                       1.3.4
setuptools                         75.1.0
shap                               0.46.0
shapely                            2.0.7
shellingham                        1.5.4
sigstore                           3.6.1
sigstore-protobuf-specs            0.3.2
sigstore-rekor-types               0.0.18
simple-parsing                     0.1.7
simsimd                            6.2.1
six                                1.17.0
sklearn-compat                     0.1.3
sklearn-pandas                     2.2.0
slicer                             0.0.8
smart-open                         7.1.0
smmap                              5.0.2
sniffio                            1.3.1
snowballstemmer                    2.2.0
soundfile                          0.13.1
soupsieve                          2.6
soxr                               0.5.0.post1
spacy                              3.7.5
spacy-legacy                       3.0.12
spacy-loggers                      1.0.5
spanner-graph-notebook             1.0.9
Sphinx                             8.1.3
sphinxcontrib-applehelp            2.0.0
sphinxcontrib-devhelp              2.0.0
sphinxcontrib-htmlhelp             2.1.0
sphinxcontrib-jsmath               1.0.1
sphinxcontrib-qthelp               2.0.0
sphinxcontrib-serializinghtml      2.0.0
SQLAlchemy                         2.0.38
sqlglot                            25.6.1
sqlparse                           0.5.3
srsly                              2.5.1
stanio                             0.5.1
statsmodels                        0.14.4
stringzilla                        3.11.3
sympy                              1.13.1
tables                             3.10.2
tabulate                           0.9.0
tbb                                2022.0.0
tcmlib                             1.2.0
tenacity                           9.0.0
tensorboard                        2.18.0
tensorboard-data-server            0.7.2
tensorflow                         2.18.0
tensorflow-datasets                4.9.7
tensorflow-hub                     0.16.1
tensorflow-io-gcs-filesystem       0.37.1
tensorflow-metadata                1.16.1
tensorflow-probability             0.25.0
tensorflow-text                    2.18.1
tensorstore                        0.1.71
termcolor                          2.5.0
terminado                          0.18.1
text-unidecode                     1.3
textblob                           0.19.0
tf_keras                           2.18.0
tf-slim                            1.1.0
thinc                              8.2.5
threadpoolctl                      3.5.0
tifffile                           2025.1.10
timm                               1.0.14
tinycss2                           1.4.0
tokenizers                         0.21.0
toml                               0.10.2
toolz                              0.12.1
torch                              2.6.0+cpu
torchaudio                         2.6.0+cpu
torchsummary                       1.5.1
torchvision                        0.21.0+cpu
tornado                            6.4.2
tqdm                               4.67.1
traitlets                          5.7.1
traittypes                         0.2.1
transformers                       4.48.3
treescope                          0.1.8
triton                             3.1.0
tuf                                5.1.0
tweepy                             4.15.0
typeguard                          4.4.1
typer                              0.15.1
types-pytz                         2025.1.0.20250204
types-setuptools                   75.8.0.20250210
typing_extensions                  4.12.2
tzdata                             2025.1
tzlocal                            5.2
uc-micro-py                        1.0.3
umf                                0.9.1
uritemplate                        4.1.1
urllib3                            2.3.0
vega-datasets                      0.9.0
wadllib                            1.3.6
wandb                              0.19.6
wasabi                             1.1.3
wcwidth                            0.2.13
weasel                             0.4.1
webcolors                          24.11.1
webencodings                       0.5.1
websocket-client                   1.8.0
websockets                         14.2
Werkzeug                           3.1.3
wheel                              0.45.1
widgetsnbextension                 3.6.10
wordcloud                          1.9.4
wrapt                              1.17.2
xarray                             2025.1.2
xarray-einstats                    0.8.0
xgboost                            2.1.4
xlrd                               2.0.1
xyzservices                        2025.1.0
yarl                               1.18.3
yellowbrick                        1.5
yfinance                           0.2.52
zipp                               3.21.0
zstandard                          0.23.0
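The full listing is long. To check only the packages this chapter depends on, the standard library's `importlib.metadata` can query installed versions directly. A small sketch; the package list below is our own choice, not something the chapter requires.

```python
from importlib.metadata import PackageNotFoundError, version

packages = ["torch", "torchvision", "torchaudio", "numpy", "cuda-python"]
installed = {}
for pkg in packages:
    try:
        installed[pkg] = version(pkg)      # version string recorded in the package metadata
    except PackageNotFoundError:
        installed[pkg] = "not installed"

for pkg, ver in installed.items():
    print(f"{pkg:12s} {ver}")
```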
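# optional: pin NVIDIA's cuda-python bindings to CUDA 12.4 to match the cu124 wheels
# (PyTorch's pip wheels bring their own CUDA runtime libraries, so PyTorch itself does not need this)
%pip install cuda-python==12.4.0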
Collecting cuda-python==12.4.0
  Downloading cuda_python-12.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading cuda_python-12.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (25.4 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 25.4/25.4 MB 9.3 MB/s eta 0:00:00
Installing collected packages: cuda-python
  Attempting uninstall: cuda-python
    Found existing installation: cuda-python 12.6.0
    Uninstalling cuda-python-12.6.0:
      Successfully uninstalled cuda-python-12.6.0
Successfully installed cuda-python-12.4.0
#%pip show cuda-python
import torch
import numpy as np

print('PyTorch version:', torch.__version__)

np.set_printoptions(precision=3) # NumPy arrays print with three decimal places (torch tensors are unaffected)
PyTorch version: 2.6.0+cpu
import sys
print(sys.version)
3.11.11 (main, Dec  4 2024, 08:55:07) [GCC 11.4.0]
# Check if GPU is available
if torch.cuda.is_available():
    print("GPU is available!")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print("CUDA version PyTorch was built with:", torch.version.cuda)
    print("Number of visible GPUs:", torch.cuda.device_count())

else:
    print("GPU not available, using CPU.")
GPU not available, using CPU.
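Code that should run unchanged on CPU-only and GPU machines usually selects a device once and moves tensors to it. A minimal sketch of that common pattern; `device`, `x` and `y` are illustrative names.

```python
import torch

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.ones(2, 3)   # tensors are created on the CPU by default
x = x.to(device)       # returns a tensor on the target device (x itself if it is already there)
y = x * 2              # the operation runs on the device where the operands live

print(device, y.device)
print(y.cpu())         # move back to the CPU, e.g. before converting to NumPy
```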

13.3. Creating tensors in PyTorch

a = [1, 2, 3]
b = np.array([4, 5, 6], dtype=np.int32)

t_a = torch.tensor(a)
t_b = torch.from_numpy(b)

print(t_a)
print(t_b)
tensor([1, 2, 3])
tensor([4, 5, 6], dtype=torch.int32)
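The two constructors above differ in one important way: `torch.tensor()` copies its input, while `torch.from_numpy()` shares memory with the NumPy array. A short sketch demonstrating this, using a fresh array `arr` so that `b` and `t_b` above are left untouched:

```python
import numpy as np
import torch

arr = np.array([4, 5, 6], dtype=np.int32)

t_copy = torch.tensor(arr)        # copies the values into new tensor storage
t_shared = torch.from_numpy(arr)  # reuses the NumPy array's memory (no copy)

arr[0] = 100                      # modify the NumPy array in place

print(t_copy)    # still holds 4, 5, 6
print(t_shared)  # first element is now 100
```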
torch.is_tensor(a), torch.is_tensor(t_a)
(False, True)
print(type(t_a))
print(t_a.dtype)
<class 'torch.Tensor'>
torch.int64
print(type(t_b))
print(t_b.dtype)
<class 'torch.Tensor'>
torch.int32
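The dtypes above come from inference: `torch.tensor()` turns Python integers into `torch.int64`, while `torch.from_numpy()` keeps the NumPy dtype, which is why `t_b` is `torch.int32`. A minimal sketch of the rules for floats and mixed input, and of requesting a dtype explicitly (assuming the default float dtype has not been changed with `torch.set_default_dtype`):

```python
import torch

print(torch.tensor([1, 2, 3]).dtype)        # Python ints -> torch.int64
print(torch.tensor([1.0, 2.0, 3.0]).dtype)  # Python floats -> torch.float32 (the default float dtype)
print(torch.tensor([1, 2.5]).dtype)         # mixed ints and floats -> torch.float32

# Request a dtype explicitly instead of relying on inference
t_f16 = torch.tensor([1, 2, 3], dtype=torch.float16)
print(t_f16.dtype)                          # torch.float16
```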
t_ones = torch.ones(2, 3)

t_ones.shape
torch.Size([2, 3])
print(t_ones)
tensor([[1., 1., 1.],
        [1., 1., 1.]])
t_zeros = torch.zeros(2, 3)

print(t_zeros.shape)
print(t_zeros)
torch.Size([2, 3])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
rand_tensor = torch.rand(2,3)

print(rand_tensor)
tensor([[0.1406, 0.9446, 0.4206],
        [0.0613, 0.3332, 0.6908]])
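`torch.rand` draws from a uniform distribution on [0, 1), so the values above change on every run. A short sketch showing how `torch.manual_seed` makes the draws reproducible, plus a few other common factory functions; `r1` and `r2` are illustrative names.

```python
import torch

torch.manual_seed(0)            # fix the seed ...
r1 = torch.rand(2, 3)
torch.manual_seed(0)            # ... and re-seeding reproduces the same draws
r2 = torch.rand(2, 3)
print(torch.equal(r1, r2))      # True

# A few other factory functions
print(torch.full((2, 3), 7.0))        # every element set to 7.0
print(torch.arange(0, 10, 2))         # tensor([0, 2, 4, 6, 8])
print(torch.linspace(0, 1, steps=5))  # 5 evenly spaced values from 0 to 1
print(torch.zeros_like(r1))           # zeros with r1's shape and dtype
```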

13.4. Manipulating the data type and shape of tensors

print(t_b.dtype)

t_b_new = t_b.to(torch.int64)

print(t_b_new.dtype)
torch.int32
torch.int64
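Note that `.to()` does not modify `t_b` in place; it returns a new tensor. A minimal sketch (with a fresh tensor `t_int`) of this behavior and of the shorthand conversion methods:

```python
import torch

t_int = torch.tensor([4, 5, 6], dtype=torch.int32)

t_float = t_int.to(torch.float32)  # returns a NEW tensor; t_int keeps its dtype
t_half = t_int.half()              # shorthands: .float(), .double(), .half(), .long(), ...
t_same = t_int.to(torch.int32)     # nothing to convert, so .to() returns t_int itself

print(t_int.dtype, t_float.dtype, t_half.dtype)  # torch.int32 torch.float32 torch.float16
print(t_same is t_int)                           # True
```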
import gc #garbage collector
import torch

def get_all_tensors():
    # Iterate through all objects tracked by the garbage collector
    tensors = [obj for obj in gc.get_objects() if isinstance(obj, torch.Tensor)]
    return tensors

# Print all tensors
for t in get_all_tensors():
    print(t)
tensor([4, 5, 6])
tensor([4, 5, 6])
tensor([1, 2, 3])
tensor([4, 5, 6], dtype=torch.int32)
tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
tensor([[0.1406, 0.9446, 0.4206],
        [0.0613, 0.3332, 0.6908]])
<ipython-input-25-eb422b9ed847>:6: FutureWarning: `torch.distributed.reduce_op` is deprecated, please use `torch.distributed.ReduceOp` instead
  tensors = [obj for obj in gc.get_objects() if isinstance(obj, torch.Tensor)]
t = torch.rand(3, 5)

t_tr = torch.transpose(t, 1, 0)  # swap dims 1 and 0; for a 2-D tensor, (t, 0, 1) and (t, 1, 0) give the same result
print(t.shape, ' --> ', t_tr.shape)
torch.Size([3, 5])  -->  torch.Size([5, 3])
t
tensor([[0.8835, 0.1709, 0.3207, 0.4077, 0.8177],
        [0.0383, 0.1107, 0.7277, 0.1610, 0.7233],
        [0.6170, 0.0494, 0.4231, 0.8764, 0.3719]])
t_tr
tensor([[0.8835, 0.0383, 0.6170],
        [0.1709, 0.1107, 0.0494],
        [0.3207, 0.7277, 0.4231],
        [0.4077, 0.1610, 0.8764],
        [0.8177, 0.7233, 0.3719]])
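Two related points are worth knowing: for a 2-D tensor, t.T is shorthand for the same transpose, and the result is a view that shares storage with the original. For more than two dimensions, permute reorders all axes at once (this is what the image-plotting code later uses to go from (C, H, W) to (H, W, C)).

```python
import torch

t = torch.rand(3, 5)

t_tr = torch.transpose(t, 0, 1)  # swap dims 0 and 1 -> shape (5, 3)
print(torch.equal(t_tr, t.T))    # True: .T is a 2-D transpose

# transpose returns a view: it shares storage with t, so edits show up in both
t_tr[0, 0] = -1.0
print(t[0, 0])                   # tensor(-1.)

# permute reorders every axis in one call, e.g. (C, H, W) -> (H, W, C)
img = torch.rand(3, 4, 5)
print(img.permute(1, 2, 0).shape)  # torch.Size([4, 5, 3])
```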
t = torch.zeros(30)

t_reshape = t.reshape(5, 6)

print(t_reshape.shape)
torch.Size([5, 6])
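When reshaping, one dimension can be given as -1 and PyTorch infers it from the total number of elements. The sketch also contrasts reshape with view: view requires a memory layout compatible with the new shape (for example a contiguous tensor), while reshape falls back to copying when it has to.

```python
import torch

t = torch.arange(30)

a = t.reshape(5, -1)  # -1 is inferred: 30 / 5 = 6
b = t.view(-1, 10)    # works because t is contiguous

print(a.shape, b.shape)  # torch.Size([5, 6]) torch.Size([3, 10])

# A transpose is not contiguous, so view() may fail on it; reshape() still works
tt = a.T                     # shape (6, 5)
print(tt.is_contiguous())    # False
print(tt.reshape(-1).shape)  # torch.Size([30])
```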
t = torch.zeros(1, 2, 1, 4, 1)

# torch.squeeze() removes dimensions of size 1.
# Passing 2 restricts it to dimension 2 (dimensions are numbered 0, 1, 2, ...):
# that dimension is removed only if its size is 1; otherwise the tensor is returned unchanged.
t_sqz = torch.squeeze(t, 2)

print(t.shape, ' --> ', t_sqz.shape)

print(t,"\n\n")

print(t_sqz)
torch.Size([1, 2, 1, 4, 1])  -->  torch.Size([1, 2, 4, 1])
tensor([[[[[0.],
           [0.],
           [0.],
           [0.]]],


         [[[0.],
           [0.],
           [0.],
           [0.]]]]]) 


tensor([[[[0.],
          [0.],
          [0.],
          [0.]],

         [[0.],
          [0.],
          [0.],
          [0.]]]])
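Calling squeeze without a dimension removes every size-1 dimension, and asking it to squeeze a dimension whose size is not 1 is a no-op. The inverse operation, unsqueeze, inserts a new size-1 dimension at the given position. A quick check on the same shape as above:

```python
import torch

t = torch.zeros(1, 2, 1, 4, 1)

print(torch.squeeze(t).shape)     # no dim given: drop every size-1 dim -> torch.Size([2, 4])
print(torch.squeeze(t, 1).shape)  # dim 1 has size 2, so nothing is removed

v = torch.tensor([1.0, 2.0, 3.0])  # shape (3,)
print(v.unsqueeze(0).shape)        # new size-1 dim at position 0 -> torch.Size([1, 3])
print(v.unsqueeze(1).shape)        # -> torch.Size([3, 1])
```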

13.5. Applying mathematical operations to tensors

torch.manual_seed(1) #sets the seed for PyTorch's random number generator

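t1 = 2 * torch.rand(5, 2) - 1.  # rand draws from U[0, 1); scaling and shifting gives values in [-1, 1)
t2 = torch.normal(mean=0, std=1, size=(5, 2))  # samples from the standard normal N(0, 1)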
print(f"t1={t1}\n\n")
print(t2)
t1=tensor([[ 0.5153, -0.4414],
        [-0.1939,  0.4694],
        [-0.9414,  0.5997],
        [-0.2057,  0.5087],
        [ 0.1390, -0.1224]])


tensor([[ 0.8590,  0.7056],
        [-0.3406, -1.2720],
        [-1.1948,  0.0250],
        [-0.7627,  1.3969],
        [-0.3245,  0.2879]])
t3 = torch.multiply(t1, t2)
print(t3)
tensor([[ 0.4426, -0.3114],
        [ 0.0660, -0.5970],
        [ 1.1249,  0.0150],
        [ 0.1569,  0.7107],
        [-0.0451, -0.0352]])
t4 = torch.mean(t3, axis=0)
print(t4)
tensor([ 0.3491, -0.0436])
t4_b = torch.mean(t3, axis=1)
print(t4_b, t4_b.shape)
tensor([ 0.0656, -0.2655,  0.5699,  0.4338, -0.0402]) torch.Size([5])
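Reducing along a dimension removes it from the shape, as the torch.Size([5]) above shows. Passing keepdim=True keeps it with size 1 instead, which makes it easy to broadcast the result back against the original tensor. A minimal sketch that uses this to center each row:

```python
import torch

torch.manual_seed(1)
x = torch.rand(5, 2)

col_means = torch.mean(x, dim=0)                # shape (2,): one mean per column
row_means = torch.mean(x, dim=1, keepdim=True)  # shape (5, 1): reduced dim kept

# (5, 2) - (5, 1) broadcasts across columns, subtracting each row's mean
centered = x - row_means
print(col_means.shape, row_means.shape)
print(torch.allclose(centered.mean(dim=1), torch.zeros(5), atol=1e-6))  # True
```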

Note: These operations are implemented in PyTorch’s own tensor library, ATen, which is written in C++ (with CUDA kernels for GPUs). This native implementation makes operations like torch.mean fast on both CPU and GPU, and it does not depend on NumPy’s functions.

t5 = torch.matmul(t1, torch.transpose(t2, 0, 1))

print(t5)
tensor([[ 0.1312,  0.3860, -0.6267, -1.0096, -0.2943],
        [ 0.1647, -0.5310,  0.2434,  0.8035,  0.1980],
        [-0.3855, -0.4422,  1.1399,  1.5558,  0.4781],
        [ 0.1822, -0.5771,  0.2585,  0.8676,  0.2132],
        [ 0.0330,  0.1084, -0.1692, -0.2771, -0.0804]])
t5.shape
torch.Size([5, 5])
t6 = torch.matmul(torch.transpose(t1, 0, 1), t2)

print(t6)
tensor([[ 1.7453,  0.3392],
        [-1.6038, -0.2180]])
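The @ operator and .T give a more compact spelling of the two products above. The sketch regenerates t1 and t2 with the same seed and recipe as earlier, and also contrasts matmul, which needs the inner dimensions to agree, with element-wise multiplication, which needs matching (or broadcastable) shapes.

```python
import torch

torch.manual_seed(1)
t1 = 2 * torch.rand(5, 2) - 1
t2 = torch.normal(mean=0, std=1, size=(5, 2))

t5 = t1 @ t2.T  # (5, 2) @ (2, 5) -> (5, 5)
t6 = t1.T @ t2  # (2, 5) @ (5, 2) -> (2, 2)

print(t5.shape, t6.shape)
print(torch.allclose(t5, torch.matmul(t1, torch.transpose(t2, 0, 1))))  # True

# torch.multiply (or *) is element-wise: shapes (5, 2) and (5, 2) give (5, 2)
print((t1 * t2).shape)
```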
t1
tensor([[ 0.5153, -0.4414],
        [-0.1939,  0.4694],
        [-0.9414,  0.5997],
        [-0.2057,  0.5087],
        [ 0.1390, -0.1224]])
# Euclidean (L2) norm of each row: ord=2 selects the L2 norm, dim=1 reduces across the columns
norm_t1 = torch.linalg.norm(t1, ord=2, dim=1)

print(norm_t1)
tensor([0.6785, 0.5078, 1.1162, 0.5488, 0.1853])
torch.sqrt(t1[0][0]**2+t1[0][1]**2)
tensor(0.6785)
torch.sqrt(t1[0][0]**2+t1[0][1]**2).item()
0.6784621477127075
# verify the row norms with NumPy: square, sum along each row, then take the square root
np.sqrt(np.sum(np.square(t1.numpy()), axis=1))
array([0.678, 0.508, 1.116, 0.549, 0.185], dtype=float32)

13.6. Split, stack, and concatenate tensors

torch.manual_seed(1)

t = torch.rand(6)

print(t)

t_splits = torch.chunk(t, 3) #divides in 3 chunks

[item.numpy() for item in t_splits]
tensor([0.7576, 0.2793, 0.4031, 0.7347, 0.0293, 0.7999])
[array([0.758, 0.279], dtype=float32),
 array([0.403, 0.735], dtype=float32),
 array([0.029, 0.8  ], dtype=float32)]
# What happens if you ask for 4 chunks?
# torch.chunk makes each chunk ceil(6/4) = 2 elements long, so the 6 elements fill only 6/2 = 3 chunks.
# The result therefore has 3 chunks again, fewer than requested.

t_splits = torch.chunk(t, 4)  # requests 4 chunks, returns 3

[item.numpy() for item in t_splits]
[array([0.758, 0.279], dtype=float32),
 array([0.403, 0.735], dtype=float32),
 array([0.029, 0.8  ], dtype=float32)]

  • The NumPy arrays above show three decimal places because we called np.set_printoptions(precision=3) earlier.

  • PyTorch tensors, by contrast, print with four decimal places by default (equivalent to torch.set_printoptions(precision=4)).

  • In both cases the rounding affects only the display: the stored values and all computations keep full float32 precision.
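To see that print precision is purely cosmetic, the sketch below changes PyTorch's display precision, prints the same tensor, restores the default settings with profile="default", and then reads the stored value with .item().

```python
import torch

x = torch.tensor([1 / 3])  # stored as float32

torch.set_printoptions(precision=2)
print(x)                                  # tensor([0.33])  (display only)
torch.set_printoptions(profile="default")  # restore the default (4 decimals)
print(x)                                  # tensor([0.3333])

print(x.item())  # the stored value is unaffected by print settings
```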


torch.manual_seed(1)
t = torch.rand(5)

print(t)

t_splits = torch.split(t, split_size_or_sections=[3, 2])

[item.numpy() for item in t_splits]
tensor([0.7576, 0.2793, 0.4031, 0.7347, 0.0293])
[array([0.758, 0.279, 0.403], dtype=float32),
 array([0.735, 0.029], dtype=float32)]
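torch.split also accepts a single integer, which is interpreted as the size of each piece rather than the number of pieces. The sketch compares it with torch.chunk and with torch.tensor_split, which always returns exactly the requested number of pieces (the first few get one extra element when the length is not divisible).

```python
import torch

t = torch.arange(6)

# split with an int: every piece has that size, the last one may be smaller
print([p.tolist() for p in torch.split(t, 4)])  # [[0, 1, 2, 3], [4, 5]]

# chunk(t, 4): chunk size is ceil(6 / 4) = 2, so only 3 chunks come back
print([p.tolist() for p in torch.chunk(t, 4)])  # [[0, 1], [2, 3], [4, 5]]

# tensor_split(t, 4): exactly 4 pieces, sizes as equal as possible
print([p.tolist() for p in torch.tensor_split(t, 4)])  # [[0, 1], [2, 3], [4], [5]]
```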
A = torch.ones(2,2)
B = torch.zeros(3,2)

C = torch.cat([A, B], axis=0)
print(C)
print(C.shape)
tensor([[1., 1.],
        [1., 1.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])
torch.Size([5, 2])
A = torch.ones(3)
B = torch.zeros(3)

S = torch.stack([A, B], axis=0)
print(S)
print(S.shape)
tensor([[1., 1., 1.],
        [0., 0., 0.]])
torch.Size([2, 3])
A = A.reshape(1,3)
B = B.reshape(1,3)
C = torch.cat([A, B], axis=0)
print(C)
print(C.shape)
tensor([[1., 1., 1.],
        [0., 0., 0.]])
torch.Size([2, 3])
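The last two cells show the general relationship: stack inserts a new dimension and then concatenates along it, so it equals unsqueeze followed by cat. The sketch checks this for dim=1, which places A and B side by side as columns.

```python
import torch

A = torch.ones(3)
B = torch.zeros(3)

S1 = torch.stack([A, B], dim=1)                          # shape (3, 2)
C1 = torch.cat([A.unsqueeze(1), B.unsqueeze(1)], dim=1)  # same result via unsqueeze + cat

print(S1.shape, torch.equal(S1, C1))  # torch.Size([3, 2]) True
```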

EXERCISE

  1. Create a Tensor:
    Create a tensor A of shape [3, 4] filled with random values.

  2. Transpose:
    Compute the transpose of A and store it in B.

  3. Matrix Multiplication:
    Multiply A by its transpose B using matrix multiplication. The result is a new tensor C of shape [3, 3].

  4. Compute Means:
    Compute the mean of C along:

    • Dimension 0 (i.e., compute the mean of each column)

    • Dimension 1 (i.e., compute the mean of each row)

  5. Split:
    Split the tensor C into 2 chunks along dimension 0.
    Hint: C has 3 rows, which is not evenly divisible by 2; torch.chunk distributes the rows as evenly as possible (here, 2 rows and 1 row).

  6. Print:
    Print the following:

    • Tensor A and its shape.

    • Tensor B and its shape.

    • Tensor C and its shape.

    • The means computed in step 4.

    • The two chunks obtained in step 5.


13.7. Building input pipelines in PyTorch

Creating a PyTorch DataLoader from existing tensors

from torch.utils.data import DataLoader

t = torch.arange(6, dtype=torch.float32)
print(t.shape)
data_loader = DataLoader(t)
torch.Size([6])
for item in data_loader:
    print(item, item.shape)
tensor([0.]) torch.Size([1])
tensor([1.]) torch.Size([1])
tensor([2.]) torch.Size([1])
tensor([3.]) torch.Size([1])
tensor([4.]) torch.Size([1])
tensor([5.]) torch.Size([1])
data_loader = DataLoader(t, batch_size=3, drop_last=False) # The final batch will be included even if it has fewer than batch_size elements.

for i, batch in enumerate(data_loader, 1):
    print(f'batch {i}:', batch)
batch 1: tensor([0., 1., 2.])
batch 2: tensor([3., 4., 5.])
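With 6 elements and batch_size=3, drop_last makes no difference because the data divides evenly. The sketch uses 7 elements so the effect is visible: drop_last=False keeps the short final batch, while drop_last=True discards it.

```python
import torch
from torch.utils.data import DataLoader

t = torch.arange(7, dtype=torch.float32)  # 7 elements: not a multiple of 3

keep = [b.tolist() for b in DataLoader(t, batch_size=3, drop_last=False)]
drop = [b.tolist() for b in DataLoader(t, batch_size=3, drop_last=True)]

print(keep)  # [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0]]
print(drop)  # [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
```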

Combining two tensors into a joint dataset

from torch.utils.data import Dataset

class JointDataset(Dataset):

    def __init__(self, x, y):
        self.x = x
        self.y = y
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

    # __len__ and __getitem__ are "magic" (dunder) methods: Python calls them automatically, e.g. for len(dataset) and dataset[idx].
torch.manual_seed(1)

t_x = torch.rand([4, 3], dtype=torch.float32)
t_y = torch.arange(4)
#joint_dataset = JointDataset(t_x, t_y)

# Or use TensorDataset directly
from torch.utils.data import TensorDataset
joint_dataset = TensorDataset(t_x, t_y) # Returns a tuple containing the i-th sample from each tensor (in this case, a feature tensor and the corresponding label)

for example in joint_dataset:
    print('  x: ', example[0],
          '  y: ', example[1])
  x:  tensor([0.7576, 0.2793, 0.4031])   y:  tensor(0)
  x:  tensor([0.7347, 0.0293, 0.7999])   y:  tensor(1)
  x:  tensor([0.3971, 0.7544, 0.5695])   y:  tensor(2)
  x:  tensor([0.4388, 0.6387, 0.5247])   y:  tensor(3)
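The custom JointDataset class (left commented out above) and the built-in TensorDataset should yield identical (x, y) pairs for this data. A self-contained check is below. One difference: TensorDataset verifies that all tensors share the same first dimension, while this minimal JointDataset does not.

```python
import torch
from torch.utils.data import Dataset, TensorDataset

class JointDataset(Dataset):
    """Pairs the i-th feature row with the i-th label."""
    def __init__(self, x, y):
        self.x = x
        self.y = y
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

torch.manual_seed(1)
t_x = torch.rand([4, 3])
t_y = torch.arange(4)

custom = JointDataset(t_x, t_y)
builtin = TensorDataset(t_x, t_y)

# Both report the same length and return the same pair for every index
same = all(torch.equal(custom[i][0], builtin[i][0]) and torch.equal(custom[i][1], builtin[i][1])
           for i in range(len(custom)))
print(len(custom), len(builtin), same)  # 4 4 True
```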

Shuffle, batch, and repeat

torch.manual_seed(1)
data_loader = DataLoader(dataset=joint_dataset, batch_size=2, shuffle=True)

for i, batch in enumerate(data_loader, 1):
    print(f'batch {i}:', 'x:', batch[0],
          '\n         y:', batch[1])

print("\n\n-------------\n\n")

for epoch in range(2):
    print(f'***epoch #{epoch+1}')
    for i, batch in enumerate(data_loader, 1):
        print(f'batch {i}:', 'x:', batch[0],
              '\n         y:', batch[1])
batch 1: x: tensor([[0.3971, 0.7544, 0.5695],
        [0.7576, 0.2793, 0.4031]]) 
         y: tensor([2, 0])
batch 2: x: tensor([[0.7347, 0.0293, 0.7999],
        [0.4388, 0.6387, 0.5247]]) 
         y: tensor([1, 3])


-------------


***epoch #1
batch 1: x: tensor([[0.7576, 0.2793, 0.4031],
        [0.3971, 0.7544, 0.5695]]) 
         y: tensor([0, 2])
batch 2: x: tensor([[0.7347, 0.0293, 0.7999],
        [0.4388, 0.6387, 0.5247]]) 
         y: tensor([1, 3])
***epoch #2
batch 1: x: tensor([[0.4388, 0.6387, 0.5247],
        [0.3971, 0.7544, 0.5695]]) 
         y: tensor([3, 2])
batch 2: x: tensor([[0.7576, 0.2793, 0.4031],
        [0.7347, 0.0293, 0.7999]]) 
         y: tensor([0, 1])
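The shuffle order above depends on PyTorch's global random state, which is why each epoch (and each rerun of the first loop) can differ. If you want a shuffle that is reproducible independently of other random calls, DataLoader accepts a dedicated torch.Generator. A minimal sketch; the helper name epoch_order is illustrative only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8))

def epoch_order(seed):
    # A dedicated generator controls the shuffle, independent of the global RNG
    g = torch.Generator().manual_seed(seed)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=g)
    return [batch[0].tolist() for batch in loader]

print(epoch_order(42) == epoch_order(42))  # True: same seed, same order
print(sorted(sum(epoch_order(42), [])))    # each sample appears exactly once per epoch
```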

Creating a dataset from files on your local disk or from an online repository

#from google.colab import drive
#drive.mount('/content/drive')


#import pathlib

#imgdir_path = pathlib.Path('/content/drive/My Drive/W&M/Teaching/DATA621/cat_dog_images')

#file_list = sorted([str(path) for path in imgdir_path.glob('*.jpg')])
import torchvision.transforms as transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader

# Define any preprocessing transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# Download (if not already downloaded) and load the dataset.
# Setting download=True will automatically download the dataset to the specified root.
dataset = OxfordIIITPet(root='./data', download=True, transform=transform)

# Create a DataLoader for iterating through the dataset
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Print out some details
print(f"Total samples: {len(dataset)}")
100%|██████████| 792M/792M [00:38<00:00, 20.4MB/s]
100%|██████████| 19.2M/19.2M [00:01<00:00, 10.8MB/s]
Total samples: 3680
list_images, list_labels = map(list, zip(*dataset))  # iterates over the whole dataset, loading and transforming every image
# there are 37 class labels (pet breeds), numbered 0 to 36
np.unique(list_labels)
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36])
first6_images = []
first6_labels = []
for images, labels in data_loader:
    # images is a batch of image tensors, so iterate over the individual samples
    for img, label in zip(images, labels):
        first6_images.append(img)
        first6_labels.append(label)
        if len(first6_images) == 6:
            break
    if len(first6_images) == 6:
        break

# Now, first6_images is a list of 6 image tensors.
#print(first6_images)
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(10, 5))
for i, img_tensor in enumerate(first6_images):
    # Convert the tensor to a NumPy array and change from (C, H, W) to (H, W, C)
    npimg = img_tensor.permute(1, 2, 0).cpu().numpy()
    print('Image shape:', npimg.shape)
    print(first6_labels[i])

    ax = fig.add_subplot(2, 3, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(npimg)
    ax.set_title(f"Image {i+1}; label: {first6_labels[i]} ", size=15)

plt.tight_layout()
plt.show()
Image shape: (224, 224, 3)
tensor(7)
Image shape: (224, 224, 3)
tensor(34)
Image shape: (224, 224, 3)
tensor(0)
Image shape: (224, 224, 3)
tensor(35)
Image shape: (224, 224, 3)
tensor(1)
Image shape: (224, 224, 3)
tensor(35)
[Figure: the six sampled images in a 2×3 grid, each titled "Image i; label: ..."]
class ImageDataset(Dataset):
    def __init__(self, file_list, labels):
        self.file_list = file_list
        self.labels = labels

    def __getitem__(self, index):
        file = self.file_list[index]
        label = self.labels[index]
        return file, label

    def __len__(self):
        return len(self.labels)

image_dataset = ImageDataset(first6_images, first6_labels)  # here the "file list" holds already-loaded image tensors rather than file paths
for file, label in image_dataset:
    print(file, label)
tensor([[[0.1176, 0.1137, 0.1176,  ..., 0.9098, 0.9020, 0.8941],
         [0.1176, 0.1176, 0.1216,  ..., 0.9176, 0.9137, 0.9020],
         [0.1137, 0.1216, 0.1216,  ..., 0.9255, 0.9176, 0.9059],
         ...,
         [0.7843, 0.7686, 0.8353,  ..., 0.7529, 0.8039, 0.8314],
         [0.6863, 0.6706, 0.7529,  ..., 0.7059, 0.7373, 0.7961],
         [0.6784, 0.7373, 0.8431,  ..., 0.7137, 0.5608, 0.6196]],

        [[0.1333, 0.1294, 0.1373,  ..., 0.8824, 0.8706, 0.8588],
         [0.1373, 0.1333, 0.1373,  ..., 0.8941, 0.8863, 0.8706],
         [0.1333, 0.1333, 0.1294,  ..., 0.9059, 0.8902, 0.8824],
         ...,
         [0.7529, 0.7490, 0.8196,  ..., 0.6980, 0.7529, 0.7765],
         [0.6314, 0.6157, 0.7098,  ..., 0.6627, 0.7059, 0.7686],
         [0.6118, 0.6745, 0.7961,  ..., 0.7020, 0.5412, 0.5922]],

        [[0.1176, 0.1294, 0.1137,  ..., 0.8196, 0.8039, 0.7961],
         [0.1216, 0.1216, 0.1098,  ..., 0.8235, 0.8118, 0.8000],
         [0.1255, 0.1255, 0.1176,  ..., 0.8353, 0.8314, 0.8235],
         ...,
         [0.6863, 0.6784, 0.7569,  ..., 0.6275, 0.6863, 0.7333],
         [0.5412, 0.5216, 0.6235,  ..., 0.5843, 0.6392, 0.7255],
         [0.5255, 0.5961, 0.7216,  ..., 0.6431, 0.4471, 0.5176]]]) tensor(7)
tensor([[[0.1804, 0.1804, 0.1961,  ..., 0.8549, 0.3882, 0.1412],
         [0.1765, 0.1843, 0.2000,  ..., 0.8863, 0.4431, 0.1608],
         [0.1804, 0.1843, 0.2039,  ..., 0.9059, 0.4863, 0.1843],
         ...,
         [0.6863, 0.6353, 0.7333,  ..., 0.3412, 0.5451, 0.7255],
         [0.7804, 0.8000, 0.6980,  ..., 0.4784, 0.4000, 0.5882],
         [0.6784, 0.8275, 0.7373,  ..., 0.6706, 0.4275, 0.6314]],

        [[0.2118, 0.2235, 0.2392,  ..., 0.6745, 0.3216, 0.1804],
         [0.2157, 0.2275, 0.2431,  ..., 0.7020, 0.3569, 0.1804],
         [0.2196, 0.2314, 0.2471,  ..., 0.7216, 0.3843, 0.1882],
         ...,
         [0.6510, 0.6667, 0.7882,  ..., 0.3725, 0.5059, 0.7412],
         [0.8000, 0.7882, 0.6667,  ..., 0.4784, 0.3765, 0.6000],
         [0.6667, 0.7373, 0.6549,  ..., 0.6549, 0.4314, 0.6510]],

        [[0.1765, 0.1961, 0.2039,  ..., 0.5882, 0.2941, 0.2078],
         [0.1804, 0.2000, 0.2078,  ..., 0.6157, 0.3216, 0.2000],
         [0.1882, 0.2039, 0.2118,  ..., 0.6314, 0.3412, 0.2000],
         ...,
         [0.4745, 0.3843, 0.5412,  ..., 0.1725, 0.3020, 0.5686],
         [0.5882, 0.5882, 0.4863,  ..., 0.2706, 0.1765, 0.4078],
         [0.4510, 0.5961, 0.4902,  ..., 0.4039, 0.1686, 0.3725]]]) tensor(34)
tensor([[[0.3686, 0.3725, 0.3725,  ..., 0.8706, 0.8863, 0.8980],
         [0.3725, 0.3765, 0.3804,  ..., 0.8784, 0.8784, 0.8824],
         [0.3804, 0.3804, 0.3843,  ..., 0.8824, 0.8784, 0.8706],
         ...,
         [0.4353, 0.4392, 0.4431,  ..., 0.4863, 0.4902, 0.4863],
         [0.4353, 0.4392, 0.4431,  ..., 0.4863, 0.4902, 0.4863],
         [0.4353, 0.4392, 0.4431,  ..., 0.4863, 0.4902, 0.4863]],

        [[0.3333, 0.3373, 0.3373,  ..., 0.9882, 0.9882, 0.9922],
         [0.3333, 0.3373, 0.3412,  ..., 0.9922, 0.9804, 0.9804],
         [0.3373, 0.3373, 0.3412,  ..., 0.9882, 0.9843, 0.9843],
         ...,
         [0.3765, 0.3804, 0.3765,  ..., 0.4314, 0.4275, 0.4235],
         [0.3765, 0.3804, 0.3765,  ..., 0.4314, 0.4275, 0.4235],
         [0.3765, 0.3804, 0.3765,  ..., 0.4314, 0.4275, 0.4235]],

        [[0.2667, 0.2706, 0.2706,  ..., 0.8196, 0.8431, 0.8627],
         [0.2627, 0.2667, 0.2706,  ..., 0.8078, 0.8157, 0.8157],
         [0.2667, 0.2667, 0.2706,  ..., 0.8000, 0.7882, 0.7725],
         ...,
         [0.2941, 0.2980, 0.3059,  ..., 0.3294, 0.3255, 0.3216],
         [0.2941, 0.2980, 0.3059,  ..., 0.3294, 0.3255, 0.3216],
         [0.2941, 0.2980, 0.3059,  ..., 0.3255, 0.3255, 0.3216]]]) tensor(0)
tensor([[[0.4118, 0.4000, 0.5961,  ..., 0.8745, 0.8667, 0.8706],
         [0.3804, 0.3333, 0.5216,  ..., 0.8824, 0.8784, 0.8784],
         [0.5333, 0.4980, 0.5333,  ..., 0.8863, 0.8863, 0.8863],
         ...,
         [0.0627, 0.0902, 0.0941,  ..., 0.1529, 0.1686, 0.1098],
         [0.0745, 0.0784, 0.1020,  ..., 0.1765, 0.1059, 0.0980],
         [0.0902, 0.0588, 0.0902,  ..., 0.1333, 0.1059, 0.1098]],

        [[0.4275, 0.4314, 0.6431,  ..., 0.9765, 0.9804, 0.9804],
         [0.4157, 0.3804, 0.5765,  ..., 0.9804, 0.9765, 0.9765],
         [0.5882, 0.5490, 0.5922,  ..., 0.9765, 0.9765, 0.9765],
         ...,
         [0.1137, 0.1412, 0.1451,  ..., 0.2588, 0.2667, 0.2039],
         [0.1255, 0.1255, 0.1529,  ..., 0.2667, 0.2000, 0.1843],
         [0.1412, 0.1059, 0.1451,  ..., 0.2118, 0.1961, 0.1804]],

        [[0.2902, 0.3373, 0.5961,  ..., 0.9922, 0.9961, 0.9961],
         [0.2745, 0.2314, 0.4706,  ..., 1.0000, 0.9961, 0.9961],
         [0.4392, 0.3922, 0.5137,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [0.0902, 0.1137, 0.1176,  ..., 0.1255, 0.1490, 0.0902],
         [0.1137, 0.1020, 0.1294,  ..., 0.1529, 0.0902, 0.0784],
         [0.1412, 0.0863, 0.1216,  ..., 0.1020, 0.0863, 0.0784]]]) tensor(35)
tensor([[[0.2784, 0.2863, 0.2745,  ..., 0.2235, 0.2078, 0.2000],
         [0.2824, 0.2824, 0.2745,  ..., 0.2118, 0.1882, 0.1882],
         [0.2824, 0.2863, 0.2784,  ..., 0.1922, 0.1804, 0.1843],
         ...,
         [0.1608, 0.1647, 0.1569,  ..., 0.6902, 0.7020, 0.6941],
         [0.1608, 0.1608, 0.1608,  ..., 0.6824, 0.6863, 0.6745],
         [0.1569, 0.1608, 0.1608,  ..., 0.6667, 0.6745, 0.6667]],

        [[0.1804, 0.1804, 0.1804,  ..., 0.1294, 0.1176, 0.1176],
         [0.1804, 0.1765, 0.1765,  ..., 0.1255, 0.1098, 0.1098],
         [0.1804, 0.1765, 0.1765,  ..., 0.1098, 0.1020, 0.1059],
         ...,
         [0.0941, 0.0980, 0.0941,  ..., 0.6980, 0.7137, 0.7059],
         [0.0980, 0.0980, 0.0980,  ..., 0.6980, 0.7020, 0.6863],
         [0.0941, 0.0980, 0.0980,  ..., 0.6824, 0.6902, 0.6824]],

        [[0.0235, 0.0235, 0.0157,  ..., 0.1020, 0.0941, 0.0902],
         [0.0235, 0.0196, 0.0196,  ..., 0.0980, 0.0863, 0.0863],
         [0.0235, 0.0235, 0.0196,  ..., 0.0863, 0.0784, 0.0824],
         ...,
         [0.0392, 0.0392, 0.0392,  ..., 0.7412, 0.7569, 0.7490],
         [0.0471, 0.0471, 0.0431,  ..., 0.7412, 0.7490, 0.7294],
         [0.0431, 0.0471, 0.0510,  ..., 0.7255, 0.7333, 0.7255]]]) tensor(1)
tensor([[[0.5765, 0.5843, 0.5804,  ..., 0.5725, 0.6039, 0.6157],
         [0.5725, 0.5922, 0.5922,  ..., 0.5843, 0.6039, 0.6235],
         [0.5725, 0.5922, 0.5961,  ..., 0.5961, 0.6196, 0.6235],
         ...,
         [0.2039, 0.1647, 0.1608,  ..., 0.7451, 0.7490, 0.7490],
         [0.2118, 0.2000, 0.1804,  ..., 0.7294, 0.7333, 0.7255],
         [0.1882, 0.2157, 0.2039,  ..., 0.7098, 0.7059, 0.7059]],

        [[0.4549, 0.4549, 0.4471,  ..., 0.4588, 0.4745, 0.4902],
         [0.4510, 0.4588, 0.4549,  ..., 0.4706, 0.4745, 0.4980],
         [0.4392, 0.4549, 0.4588,  ..., 0.4824, 0.4863, 0.4941],
         ...,
         [0.2275, 0.0784, 0.0627,  ..., 0.0980, 0.1020, 0.1020],
         [0.3020, 0.1647, 0.0667,  ..., 0.1020, 0.1020, 0.0941],
         [0.3373, 0.2275, 0.0667,  ..., 0.1020, 0.0941, 0.0980]],

        [[0.3020, 0.2980, 0.2745,  ..., 0.3373, 0.3412, 0.3451],
         [0.3020, 0.3098, 0.3020,  ..., 0.3333, 0.3333, 0.3529],
         [0.3020, 0.3098, 0.3059,  ..., 0.3529, 0.3569, 0.3686],
         ...,
         [0.2275, 0.0902, 0.0588,  ..., 0.1647, 0.1647, 0.1647],
         [0.2980, 0.1647, 0.0627,  ..., 0.1647, 0.1647, 0.1569],
         [0.3216, 0.2196, 0.0706,  ..., 0.1608, 0.1569, 0.1569]]]) tensor(35)
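In practice, a class like ImageDataset usually stores only file paths and decodes each image inside __getitem__, so images are loaded one at a time as the DataLoader requests them. The sketch below is a hypothetical LazyImageDataset (not part of torchvision); it assumes Pillow is installed and mimics transforms.ToTensor() by hand. It writes two tiny solid-color images to a temporary folder only to exercise the class.

```python
import os
import tempfile

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

class LazyImageDataset(Dataset):
    """Hypothetical sketch: keeps file paths and decodes each image on demand."""
    def __init__(self, file_list, labels):
        self.file_list = file_list
        self.labels = labels

    def __getitem__(self, index):
        img = Image.open(self.file_list[index]).convert("RGB")  # decode only when requested
        arr = np.asarray(img, dtype=np.float32) / 255.0          # (H, W, C), values in [0, 1]
        tensor = torch.from_numpy(arr).permute(2, 0, 1)          # (C, H, W), like ToTensor()
        return tensor, self.labels[index]

    def __len__(self):
        return len(self.labels)

# Two 8x6 (width x height) solid-color PNGs in a temporary folder
tmpdir = tempfile.mkdtemp()
paths = []
for i, color in enumerate([(255, 0, 0), (0, 0, 255)]):
    path = os.path.join(tmpdir, f"img_{i}.png")
    Image.new("RGB", (8, 6), color).save(path)
    paths.append(path)

ds = LazyImageDataset(paths, [0, 1])
x, y = ds[1]
print(len(ds), x.shape, y)  # 2 torch.Size([3, 6, 8]) 1
```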
fig = plt.figure(figsize=(10, 6))
for i, example in enumerate(image_dataset):
    ax = fig.add_subplot(2, 3, i+1)
    ax.set_xticks([]); ax.set_yticks([])

    # imshow expects a NumPy array (or array-like) with the channel dimension last,
    # so the (C, H, W) tensor is transposed to (H, W, C)
    ax.imshow(example[0].numpy().transpose((1, 2, 0)))

    ax.set_title(f'{example[1]}', size=15)

plt.tight_layout()
plt.show()
[Figure: the same six images plotted from image_dataset in a 2×3 grid, each titled with its label tensor]
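The two plotting cells reorder the axes in different places: the first calls permute(1, 2, 0) on the tensor before converting, the second converts first and calls NumPy's transpose((1, 2, 0)). The sketch checks that both routes give the same (H, W, C) array, using a random tensor as a stand-in for a transformed image.

```python
import numpy as np
import torch

img = torch.rand(3, 224, 224)  # (C, H, W), the layout produced by transforms.ToTensor()

a = img.permute(1, 2, 0).numpy()      # reorder in PyTorch, then convert
b = img.numpy().transpose((1, 2, 0))  # convert, then reorder in NumPy

print(a.shape, np.array_equal(a, b))  # (224, 224, 3) True
```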