13. PyTorch Pre-Flight#
from IPython.display import Image as IPythonImage
%matplotlib inline
Performance Challenge
| Specification | Intel Core i9-13900K | NVIDIA RTX 4090 |
|---|---|---|
| Architecture | Hybrid design (8 P‑cores + 16 E‑cores) | Ada Lovelace (16,384 CUDA cores) |
| Cores | 24 total (8 Performance cores, 16 Efficient cores, 32 threads) | 16,384 CUDA cores |
| Base clock frequency | ~3.0 GHz (P‑cores); ~2.2 GHz (E‑cores) | ~2.23 GHz (base), ~2.52 GHz (boost) |
| Memory bandwidth | System DDR5 (~50–60+ GB/s) | 1,008 GB/s (24 GB GDDR6X, 384-bit bus) |
| Floating-point performance | ~500 GFLOPS (FP32) | ~82.6 TFLOPS (FP32) |
| Cost | $589–$600 USD | $1,599 USD (MSRP) |
13.1. Tensors#
In data science, a tensor is essentially a multi-dimensional array: it generalizes scalars (0-dimensional), vectors (1-dimensional), and matrices (2-dimensional) to higher dimensions (3-D, 4-D, and beyond).
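For example, a quick sketch (PyTorch itself is installed in the next section):
import torch

scalar = torch.tensor(3.14)          # 0-dimensional tensor (a scalar)
vector = torch.tensor([1., 2., 3.])  # 1-dimensional tensor (a vector)
matrix = torch.ones(2, 3)            # 2-dimensional tensor (a matrix)
cube = torch.zeros(2, 3, 4)          # 3-dimensional tensor

for t in (scalar, vector, matrix, cube):
    print(t.ndim, tuple(t.shape))    # 0 (), 1 (3,), 2 (2, 3), 3 (2, 3, 4)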
13.2. Installing PyTorch#
#--- check CUDA Driver
!nvidia-smi
/bin/bash: line 1: nvidia-smi: command not found
# choose between "CPU" or "GPU"
install_type = "CPU"
cuda_version = "cu124" # see above
#--- you can install the most recent release of PyTorch
#--- for CPU installation
if install_type == "CPU":
    %pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
#--- for GPU installation (the URL below hosts PyTorch wheels compiled with CUDA 12.4)
elif install_type == "GPU":
    %pip install torch==2.6.0+cu124 torchvision torchaudio \
        --index-url https://download.pytorch.org/whl/cu124
Looking in indexes: https://download.pytorch.org/whl/cpu
Collecting torch==2.6.0
Downloading https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl.metadata (26 kB)
Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.20.1+cu124)
Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.5.1+cu124)
Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0) (3.17.0)
Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0) (4.12.2)
Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0) (3.4.2)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0) (3.1.5)
Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0) (2024.10.0)
Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0) (1.13.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch==2.6.0) (1.3.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)
INFO: pip is looking at multiple versions of torchvision to determine which version is compatible with other requirements. This could take a while.
Collecting torchvision
Downloading https://download.pytorch.org/whl/cpu/torchvision-0.21.0%2Bcpu-cp311-cp311-linux_x86_64.whl.metadata (6.1 kB)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.1.0)
INFO: pip is looking at multiple versions of torchaudio to determine which version is compatible with other requirements. This could take a while.
Collecting torchaudio
Downloading https://download.pytorch.org/whl/cpu/torchaudio-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl.metadata (6.6 kB)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch==2.6.0) (3.0.2)
Downloading https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl (178.7 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 178.7/178.7 MB 7.0 MB/s eta 0:00:00
Downloading https://download.pytorch.org/whl/cpu/torchvision-0.21.0%2Bcpu-cp311-cp311-linux_x86_64.whl (1.8 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 49.2 MB/s eta 0:00:00
Downloading https://download.pytorch.org/whl/cpu/torchaudio-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl (1.7 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 48.2 MB/s eta 0:00:00
Installing collected packages: torch, torchvision, torchaudio
Attempting uninstall: torch
Found existing installation: torch 2.5.1+cu124
Uninstalling torch-2.5.1+cu124:
Successfully uninstalled torch-2.5.1+cu124
Attempting uninstall: torchvision
Found existing installation: torchvision 0.20.1+cu124
Uninstalling torchvision-0.20.1+cu124:
Successfully uninstalled torchvision-0.20.1+cu124
Attempting uninstall: torchaudio
Found existing installation: torchaudio 2.5.1+cu124
Uninstalling torchaudio-2.5.1+cu124:
Successfully uninstalled torchaudio-2.5.1+cu124
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
fastai 2.7.18 requires torch<2.6,>=1.10, but you have torch 2.6.0+cpu which is incompatible.
Successfully installed torch-2.6.0+cpu torchaudio-2.6.0+cpu torchvision-0.21.0+cpu
%pip list
Package Version
---------------------------------- -------------------
absl-py 1.4.0
accelerate 1.3.0
aiohappyeyeballs 2.4.6
aiohttp 3.11.12
aiosignal 1.3.2
alabaster 1.0.0
albucore 0.0.23
albumentations 2.0.4
ale-py 0.10.1
altair 5.5.0
annotated-types 0.7.0
anyio 3.7.1
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
array_record 0.6.0
arviz 0.20.0
astropy 7.0.1
astropy-iers-data 0.2025.2.10.0.33.26
astunparse 1.6.3
atpublic 4.1.0
attrs 25.1.0
audioread 3.0.1
autograd 1.7.0
babel 2.17.0
backcall 0.2.0
beautifulsoup4 4.13.3
betterproto 2.0.0b6
bigframes 1.36.0
bigquery-magics 0.5.0
bleach 6.2.0
blinker 1.9.0
blis 0.7.11
blosc2 3.0.0
bokeh 3.6.3
Bottleneck 1.4.2
bqplot 0.12.44
branca 0.8.1
CacheControl 0.14.2
cachetools 5.5.1
catalogue 2.0.10
certifi 2025.1.31
cffi 1.17.1
chardet 5.2.0
charset-normalizer 3.4.1
chex 0.1.88
clarabel 0.10.0
click 8.1.8
cloudpathlib 0.20.0
cloudpickle 3.1.1
cmake 3.31.4
cmdstanpy 1.2.5
colorcet 3.1.0
colorlover 0.3.0
colour 0.1.5
community 1.0.0b1
confection 0.1.5
cons 0.4.6
contourpy 1.3.1
cramjam 2.9.1
cryptography 43.0.3
cuda-python 12.6.0
cudf-cu12 24.12.0
cufflinks 0.17.3
cupy-cuda12x 13.3.0
cvxopt 1.3.2
cvxpy 1.6.0
cycler 0.12.1
cyipopt 1.5.0
cymem 2.0.11
Cython 3.0.12
dask 2024.10.0
datascience 0.17.6
db-dtypes 1.4.1
dbus-python 1.2.18
debugpy 1.8.0
decorator 4.4.2
defusedxml 0.7.1
Deprecated 1.2.18
diffusers 0.32.2
distro 1.9.0
dlib 19.24.2
dm-tree 0.1.9
dnspython 2.7.0
docker-pycreds 0.4.0
docstring_parser 0.16
docutils 0.21.2
dopamine_rl 4.1.2
duckdb 1.1.3
earthengine-api 1.5.2
easydict 1.13
editdistance 0.8.1
eerepr 0.1.0
einops 0.8.1
email_validator 2.2.0
en-core-web-sm 3.7.1
entrypoints 0.4
et_xmlfile 2.0.0
etils 1.12.0
etuples 0.3.9
Farama-Notifications 0.0.4
fastai 2.7.18
fastcore 1.7.29
fastdownload 0.0.7
fastjsonschema 2.21.1
fastprogress 1.0.3
fastrlock 0.8.3
filelock 3.17.0
firebase-admin 6.6.0
Flask 3.1.0
flatbuffers 25.2.10
flax 0.10.3
folium 0.19.4
fonttools 4.56.0
frozendict 2.4.6
frozenlist 1.5.0
fsspec 2024.10.0
future 1.0.0
gast 0.6.0
gcsfs 2024.10.0
GDAL 3.6.4
gdown 5.2.0
geemap 0.35.1
gensim 4.3.3
geocoder 1.38.1
geographiclib 2.0
geopandas 1.0.1
geopy 2.4.1
gin-config 0.5.0
gitdb 4.0.12
GitPython 3.1.44
glob2 0.7
google 2.0.3
google-ai-generativelanguage 0.6.15
google-api-core 2.19.2
google-api-python-client 2.160.0
google-auth 2.27.0
google-auth-httplib2 0.2.0
google-auth-oauthlib 1.2.1
google-cloud-aiplatform 1.79.0
google-cloud-bigquery 3.25.0
google-cloud-bigquery-connection 1.17.0
google-cloud-bigquery-storage 2.28.0
google-cloud-bigtable 2.28.1
google-cloud-core 2.4.1
google-cloud-dataproc 5.16.0
google-cloud-datastore 2.20.2
google-cloud-firestore 2.20.0
google-cloud-functions 1.19.0
google-cloud-iam 2.17.0
google-cloud-language 2.16.0
google-cloud-pubsub 2.25.0
google-cloud-resource-manager 1.14.0
google-cloud-spanner 3.51.0
google-cloud-storage 2.19.0
google-cloud-translate 3.19.0
google-colab 1.0.0
google-crc32c 1.6.0
google-genai 0.8.0
google-generativeai 0.8.4
google-pasta 0.2.0
google-resumable-media 2.7.2
google-spark-connect 0.5.2
googleapis-common-protos 1.66.0
googledrivedownloader 1.1.0
graphviz 0.20.3
greenlet 3.1.1
grpc-google-iam-v1 0.14.0
grpc-interceptor 0.15.4
grpcio 1.70.0
grpcio-status 1.62.3
grpclib 0.4.7
gspread 6.1.4
gspread-dataframe 4.0.0
gym 0.25.2
gym-notices 0.0.8
gymnasium 1.0.0
h11 0.14.0
h2 4.2.0
h5netcdf 1.5.0
h5py 3.12.1
highspy 1.9.0
holidays 0.66
holoviews 1.20.0
hpack 4.1.0
html5lib 1.1
httpcore 1.0.7
httpimport 1.4.0
httplib2 0.22.0
httpx 0.28.1
huggingface-hub 0.28.1
humanize 4.11.0
hyperframe 6.1.0
hyperopt 0.2.7
ibis-framework 9.2.0
id 1.5.0
idna 3.10
imageio 2.37.0
imageio-ffmpeg 0.6.0
imagesize 1.4.1
imbalanced-learn 0.13.0
imgaug 0.4.0
immutabledict 4.2.1
importlib_metadata 8.6.1
importlib_resources 6.5.2
imutils 0.5.4
in-toto-attestation 0.9.3
inflect 7.5.0
iniconfig 2.0.0
intel-cmplr-lib-ur 2025.0.4
intel-openmp 2025.0.4
ipyevents 2.0.2
ipyfilechooser 0.6.0
ipykernel 5.5.6
ipyleaflet 0.19.2
ipyparallel 8.8.0
ipython 7.34.0
ipython-genutils 0.2.0
ipython-sql 0.5.0
ipytree 0.2.2
ipywidgets 7.7.1
itsdangerous 2.2.0
jax 0.4.33
jax-cuda12-pjrt 0.4.33
jax-cuda12-plugin 0.4.33
jaxlib 0.4.33
jeepney 0.7.1
jellyfish 1.1.0
jieba 0.42.1
Jinja2 3.1.5
jiter 0.8.2
joblib 1.4.2
jsonpatch 1.33
jsonpickle 4.0.1
jsonpointer 3.0.0
jsonschema 4.23.0
jsonschema-specifications 2024.10.1
jupyter-client 6.1.12
jupyter-console 6.1.0
jupyter_core 5.7.2
jupyter-leaflet 0.19.2
jupyter-server 1.24.0
jupyterlab_pygments 0.3.0
jupyterlab_widgets 3.0.13
kaggle 1.6.17
kagglehub 0.3.7
keras 3.8.0
keras-hub 0.18.1
keras-nlp 0.18.1
keyring 23.5.0
kiwisolver 1.4.8
langchain 0.3.18
langchain-core 0.3.35
langchain-text-splitters 0.3.6
langcodes 3.5.0
langsmith 0.3.8
language_data 1.3.0
launchpadlib 1.10.16
lazr.restfulclient 0.14.4
lazr.uri 1.0.6
lazy_loader 0.4
libclang 18.1.1
libcudf-cu12 24.12.0
libkvikio-cu12 24.12.1
librosa 0.10.2.post1
lightgbm 4.5.0
linkify-it-py 2.0.3
llvmlite 0.44.0
locket 1.0.0
logical-unification 0.4.6
lxml 5.3.1
marisa-trie 1.2.1
Markdown 3.7
markdown-it-py 3.0.0
MarkupSafe 3.0.2
matplotlib 3.10.0
matplotlib-inline 0.1.7
matplotlib-venn 1.1.1
mdit-py-plugins 0.4.2
mdurl 0.1.2
miniKanren 1.0.3
missingno 0.5.2
mistune 3.1.1
mizani 0.13.1
mkl 2025.0.1
ml-dtypes 0.4.1
mlxtend 0.23.4
model-signing 0.2.0
more-itertools 10.6.0
moviepy 1.0.3
mpmath 1.3.0
msgpack 1.1.0
multidict 6.1.0
multipledispatch 1.0.0
multitasking 0.0.11
murmurhash 1.0.12
music21 9.3.0
namex 0.0.8
narwhals 1.26.0
natsort 8.4.0
nbclassic 1.2.0
nbclient 0.10.2
nbconvert 7.16.6
nbformat 5.10.4
ndindex 1.9.2
nest-asyncio 1.6.0
networkx 3.4.2
nibabel 5.3.2
nltk 3.9.1
notebook 6.5.5
notebook_shim 0.2.4
numba 0.61.0
numba-cuda 0.0.17.1
numexpr 2.10.2
numpy 1.26.4
nvidia-cublas-cu12 12.5.3.2
nvidia-cuda-cupti-cu12 12.5.82
nvidia-cuda-nvcc-cu12 12.5.82
nvidia-cuda-nvrtc-cu12 12.5.82
nvidia-cuda-runtime-cu12 12.5.82
nvidia-cudnn-cu12 9.3.0.75
nvidia-cufft-cu12 11.2.3.61
nvidia-curand-cu12 10.3.6.82
nvidia-cusolver-cu12 11.6.3.83
nvidia-cusparse-cu12 12.5.1.3
nvidia-nccl-cu12 2.21.5
nvidia-nvcomp-cu12 4.1.0.6
nvidia-nvjitlink-cu12 12.5.82
nvidia-nvtx-cu12 12.4.127
nvtx 0.2.10
nx-cugraph-cu12 24.12.0
oauth2client 4.1.3
oauthlib 3.2.2
openai 1.61.1
opencv-contrib-python 4.11.0.86
opencv-python 4.11.0.86
opencv-python-headless 4.11.0.86
openpyxl 3.1.5
opentelemetry-api 1.16.0
opentelemetry-sdk 1.16.0
opentelemetry-semantic-conventions 0.37b0
opt_einsum 3.4.0
optax 0.2.4
optree 0.14.0
orbax-checkpoint 0.6.4
orjson 3.10.15
osqp 0.6.7.post3
packaging 24.2
pandas 2.2.2
pandas-datareader 0.10.0
pandas-gbq 0.26.1
pandas-stubs 2.2.2.240909
pandocfilters 1.5.1
panel 1.6.0
param 2.2.0
parso 0.8.4
parsy 2.1
partd 1.4.2
pathlib 1.0.1
patsy 1.0.1
peewee 3.17.9
peft 0.14.0
pexpect 4.9.0
pickleshare 0.7.5
pillow 11.1.0
pip 24.1.2
platformdirs 4.3.6
plotly 5.24.1
plotnine 0.14.5
pluggy 1.5.0
ply 3.11
polars 1.9.0
pooch 1.8.2
portpicker 1.5.2
preshed 3.0.9
prettytable 3.14.0
proglog 0.1.10
progressbar2 4.5.0
prometheus_client 0.21.1
promise 2.3
prompt_toolkit 3.0.50
propcache 0.2.1
prophet 1.1.6
proto-plus 1.26.0
protobuf 4.25.6
psutil 5.9.5
psycopg2 2.9.10
ptyprocess 0.7.0
py-cpuinfo 9.0.0
py4j 0.10.9.7
pyarrow 17.0.0
pyasn1 0.6.1
pyasn1_modules 0.4.1
pycocotools 2.0.8
pycparser 2.22
pydantic 2.10.6
pydantic_core 2.27.2
pydata-google-auth 1.9.1
pydot 3.0.4
pydotplus 2.0.2
PyDrive 1.3.1
PyDrive2 1.21.3
pyerfa 2.0.1.5
pygame 2.6.1
pygit2 1.17.0
Pygments 2.18.0
PyGObject 3.42.1
PyJWT 2.10.1
pylibcudf-cu12 24.12.0
pylibcugraph-cu12 24.12.0
pylibraft-cu12 24.12.0
pymc 5.20.1
pymystem3 0.2.0
pynvjitlink-cu12 0.5.0
pyogrio 0.10.0
Pyomo 6.8.2
PyOpenGL 3.1.9
pyOpenSSL 24.2.1
pyparsing 3.2.1
pyperclip 1.9.0
pyproj 3.7.0
pyshp 2.3.1
PySocks 1.7.1
pyspark 3.5.4
pytensor 2.27.1
pytest 8.3.4
python-apt 0.0.0
python-box 7.3.2
python-dateutil 2.8.2
python-louvain 0.16
python-slugify 8.0.4
python-snappy 0.7.3
python-utils 3.9.1
pytz 2025.1
pyviz_comms 3.0.4
PyYAML 6.0.2
pyzmq 24.0.1
qdldl 0.1.7.post5
ratelim 0.1.6
referencing 0.36.2
regex 2024.11.6
requests 2.32.3
requests-oauthlib 2.0.0
requests-toolbelt 1.0.0
requirements-parser 0.9.0
rfc3161-client 0.1.2
rfc8785 0.1.4
rich 13.9.4
rmm-cu12 24.12.1
rpds-py 0.22.3
rpy2 3.4.2
rsa 4.9
safetensors 0.5.2
scikit-image 0.25.1
scikit-learn 1.6.1
scipy 1.13.1
scooby 0.10.0
scs 3.2.7.post2
seaborn 0.13.2
SecretStorage 3.3.1
securesystemslib 1.2.0
Send2Trash 1.8.3
sentence-transformers 3.4.1
sentencepiece 0.2.0
sentry-sdk 2.21.0
setproctitle 1.3.4
setuptools 75.1.0
shap 0.46.0
shapely 2.0.7
shellingham 1.5.4
sigstore 3.6.1
sigstore-protobuf-specs 0.3.2
sigstore-rekor-types 0.0.18
simple-parsing 0.1.7
simsimd 6.2.1
six 1.17.0
sklearn-compat 0.1.3
sklearn-pandas 2.2.0
slicer 0.0.8
smart-open 7.1.0
smmap 5.0.2
sniffio 1.3.1
snowballstemmer 2.2.0
soundfile 0.13.1
soupsieve 2.6
soxr 0.5.0.post1
spacy 3.7.5
spacy-legacy 3.0.12
spacy-loggers 1.0.5
spanner-graph-notebook 1.0.9
Sphinx 8.1.3
sphinxcontrib-applehelp 2.0.0
sphinxcontrib-devhelp 2.0.0
sphinxcontrib-htmlhelp 2.1.0
sphinxcontrib-jsmath 1.0.1
sphinxcontrib-qthelp 2.0.0
sphinxcontrib-serializinghtml 2.0.0
SQLAlchemy 2.0.38
sqlglot 25.6.1
sqlparse 0.5.3
srsly 2.5.1
stanio 0.5.1
statsmodels 0.14.4
stringzilla 3.11.3
sympy 1.13.1
tables 3.10.2
tabulate 0.9.0
tbb 2022.0.0
tcmlib 1.2.0
tenacity 9.0.0
tensorboard 2.18.0
tensorboard-data-server 0.7.2
tensorflow 2.18.0
tensorflow-datasets 4.9.7
tensorflow-hub 0.16.1
tensorflow-io-gcs-filesystem 0.37.1
tensorflow-metadata 1.16.1
tensorflow-probability 0.25.0
tensorflow-text 2.18.1
tensorstore 0.1.71
termcolor 2.5.0
terminado 0.18.1
text-unidecode 1.3
textblob 0.19.0
tf_keras 2.18.0
tf-slim 1.1.0
thinc 8.2.5
threadpoolctl 3.5.0
tifffile 2025.1.10
timm 1.0.14
tinycss2 1.4.0
tokenizers 0.21.0
toml 0.10.2
toolz 0.12.1
torch 2.6.0+cpu
torchaudio 2.6.0+cpu
torchsummary 1.5.1
torchvision 0.21.0+cpu
tornado 6.4.2
tqdm 4.67.1
traitlets 5.7.1
traittypes 0.2.1
transformers 4.48.3
treescope 0.1.8
triton 3.1.0
tuf 5.1.0
tweepy 4.15.0
typeguard 4.4.1
typer 0.15.1
types-pytz 2025.1.0.20250204
types-setuptools 75.8.0.20250210
typing_extensions 4.12.2
tzdata 2025.1
tzlocal 5.2
uc-micro-py 1.0.3
umf 0.9.1
uritemplate 4.1.1
urllib3 2.3.0
vega-datasets 0.9.0
wadllib 1.3.6
wandb 0.19.6
wasabi 1.1.3
wcwidth 0.2.13
weasel 0.4.1
webcolors 24.11.1
webencodings 0.5.1
websocket-client 1.8.0
websockets 14.2
Werkzeug 3.1.3
wheel 0.45.1
widgetsnbextension 3.6.10
wordcloud 1.9.4
wrapt 1.17.2
xarray 2025.1.2
xarray-einstats 0.8.0
xgboost 2.1.4
xlrd 2.0.1
xyzservices 2025.1.0
yarl 1.18.3
yellowbrick 1.5
yfinance 0.2.52
zipp 3.21.0
zstandard 0.23.0
%pip install cuda-python==12.4.0
Collecting cuda-python==12.4.0
Downloading cuda_python-12.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading cuda_python-12.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (25.4 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 25.4/25.4 MB 9.3 MB/s eta 0:00:00
Installing collected packages: cuda-python
Attempting uninstall: cuda-python
Found existing installation: cuda-python 12.6.0
Uninstalling cuda-python-12.6.0:
Successfully uninstalled cuda-python-12.6.0
Successfully installed cuda-python-12.4.0
#%pip show cuda-python
import torch
import numpy as np
print('PyTorch version:', torch.__version__)
np.set_printoptions(precision=3) # sets the precision of the printed numbers to three decimal places
PyTorch version: 2.6.0+cpu
import sys
print(sys.version)
3.11.11 (main, Dec 4 2024, 08:55:07) [GCC 11.4.0]
# Check if GPU is available
if torch.cuda.is_available():
    print("GPU is available!")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print(torch.version.cuda)
    print(torch.cuda.device_count())
else:
    print("GPU not available, using CPU.")
GPU not available, using CPU.
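A common device-agnostic pattern (a minimal sketch; everything below runs on CPU here) is to select the device once and move tensors to it explicitly:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.rand(3, 3).to(device)  # the same code runs unchanged on CPU or GPU
print(x.device)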
13.3. Creating tensors in PyTorch#
a = [1, 2, 3]
b = np.array([4, 5, 6], dtype=np.int32)
t_a = torch.tensor(a)
t_b = torch.from_numpy(b)
print(t_a)
print(t_b)
tensor([1, 2, 3])
tensor([4, 5, 6], dtype=torch.int32)
torch.is_tensor(a), torch.is_tensor(t_a)
(False, True)
print(type(t_a))
print(t_a.dtype)
<class 'torch.Tensor'>
torch.int64
print(type(t_b))
print(t_b.dtype)
<class 'torch.Tensor'>
torch.int32
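Note that torch.from_numpy() shares memory with the source NumPy array, whereas torch.tensor() copies its input. A small sketch continuing the example above:
b[0] = 99   # modify the NumPy array in place
print(t_b)  # tensor([99,  5,  6], dtype=torch.int32): the tensor sees the change
b[0] = 4    # restore the original value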
t_ones = torch.ones(2, 3)
t_ones.shape
torch.Size([2, 3])
print(t_ones)
tensor([[1., 1., 1.],
[1., 1., 1.]])
t_zeros = torch.zeros(2, 3)
print(t_zeros.shape)
print(t_zeros)
torch.Size([2, 3])
tensor([[0., 0., 0.],
[0., 0., 0.]])
rand_tensor = torch.rand(2,3)
print(rand_tensor)
tensor([[0.1406, 0.9446, 0.4206],
[0.0613, 0.3332, 0.6908]])
13.4. Manipulating the data type and shape of tensors#
print(t_b.dtype)
t_b_new = t_b.to(torch.int64)
print(t_b_new.dtype)
torch.int32
torch.int64
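Equivalent shorthand casts exist, for example (a sketch):
print(t_b.long().dtype)   # torch.int64, same as .to(torch.int64)
print(t_b.float().dtype)  # torch.float32, same as .to(torch.float32)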
import gc #garbage collector
import torch
def get_all_tensors():
    # Iterate through all objects tracked by the garbage collector
    # and keep those that are PyTorch tensors
    tensors = [obj for obj in gc.get_objects() if isinstance(obj, torch.Tensor)]
    return tensors

# Print all tensors
for t in get_all_tensors():
    print(t)
tensor([4, 5, 6])
tensor([4, 5, 6])
tensor([1, 2, 3])
tensor([4, 5, 6], dtype=torch.int32)
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[0., 0., 0.],
[0., 0., 0.]])
tensor([[0.1406, 0.9446, 0.4206],
[0.0613, 0.3332, 0.6908]])
<ipython-input-25-eb422b9ed847>:6: FutureWarning: `torch.distributed.reduce_op` is deprecated, please use `torch.distributed.ReduceOp` instead
tensors = [obj for obj in gc.get_objects() if isinstance(obj, torch.Tensor)]
t = torch.rand(3, 5)
t_tr = torch.transpose(t, 1, 0)  # swaps dim 0 and dim 1; transpose(t, 0, 1) gives the same result
print(t.shape, ' --> ', t_tr.shape)
torch.Size([3, 5]) --> torch.Size([5, 3])
t
tensor([[0.8835, 0.1709, 0.3207, 0.4077, 0.8177],
[0.0383, 0.1107, 0.7277, 0.1610, 0.7233],
[0.6170, 0.0494, 0.4231, 0.8764, 0.3719]])
t_tr
tensor([[0.8835, 0.0383, 0.6170],
        [0.1709, 0.1107, 0.0494],
        [0.3207, 0.7277, 0.4231],
        [0.4077, 0.1610, 0.8764],
        [0.8177, 0.7233, 0.3719]])
t = torch.zeros(30)
t_reshape = t.reshape(5, 6)
print(t_reshape.shape)
torch.Size([5, 6])
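reshape can also infer one dimension automatically if you pass -1 (a sketch):
t_inferred = t.reshape(5, -1)  # -1 is inferred as 30 // 5 = 6
print(t_inferred.shape)        # torch.Size([5, 6])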
t = torch.zeros(1, 2, 1, 4, 1)
# The torch.squeeze() function removes dimensions of size 1.
# By specifying 2 as the dimension, you’re telling PyTorch to remove only dimension 2 if it is of size 1.
# where dim0, dim1, dim2, etc...
t_sqz = torch.squeeze(t, 2)
print(t.shape, ' --> ', t_sqz.shape)
print(t,"\n\n")
print(t_sqz)
torch.Size([1, 2, 1, 4, 1]) --> torch.Size([1, 2, 4, 1])
tensor([[[[[0.],
[0.],
[0.],
[0.]]],
[[[0.],
[0.],
[0.],
[0.]]]]])
tensor([[[[0.],
[0.],
[0.],
[0.]],
[[0.],
[0.],
[0.],
[0.]]]])
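The inverse operation is torch.unsqueeze(), which inserts a new dimension of size 1 at the given position (a sketch):
t = torch.zeros(2, 4)
t_unsq = torch.unsqueeze(t, 0)         # insert a new dim at position 0
print(t.shape, ' --> ', t_unsq.shape)  # torch.Size([2, 4])  -->  torch.Size([1, 2, 4])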
13.5. Applying mathematical operations to tensors#
torch.manual_seed(1) #sets the seed for PyTorch's random number generator
t1 = 2 * torch.rand(5, 2) - 1.
t2 = torch.normal(mean=0, std=1, size=(5, 2))
print(f"t1={t1}\n\n")
print(t2)
t1=tensor([[ 0.5153, -0.4414],
[-0.1939, 0.4694],
[-0.9414, 0.5997],
[-0.2057, 0.5087],
[ 0.1390, -0.1224]])
tensor([[ 0.8590, 0.7056],
[-0.3406, -1.2720],
[-1.1948, 0.0250],
[-0.7627, 1.3969],
[-0.3245, 0.2879]])
t3 = torch.multiply(t1, t2)
print(t3)
tensor([[ 0.4426, -0.3114],
[ 0.0660, -0.5970],
[ 1.1249, 0.0150],
[ 0.1569, 0.7107],
[-0.0451, -0.0352]])
t4 = torch.mean(t3, axis=0)
print(t4)
tensor([ 0.3491, -0.0436])
t4_b = torch.mean(t3, axis=1)
print(t4_b, t4_b.shape)
tensor([ 0.0656, -0.2655, 0.5699, 0.4338, -0.0402]) torch.Size([5])
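Passing keepdim=True keeps the reduced axis as a dimension of size 1, which is often convenient for broadcasting (a sketch):
t4_c = torch.mean(t3, dim=1, keepdim=True)  # dim is equivalent to axis here
print(t4_c.shape)  # torch.Size([5, 1]) instead of torch.Size([5])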
Note: These operations are implemented in PyTorch’s own backend (specifically the ATen library) in C/C++ to provide efficient and optimized tensor operations. This native implementation ensures that operations like torch.mean run fast on both CPU and GPU, and they are independent of NumPy’s functions.
t5 = torch.matmul(t1, torch.transpose(t2, 0, 1))
print(t5)
tensor([[ 0.1312, 0.3860, -0.6267, -1.0096, -0.2943],
[ 0.1647, -0.5310, 0.2434, 0.8035, 0.1980],
[-0.3855, -0.4422, 1.1399, 1.5558, 0.4781],
[ 0.1822, -0.5771, 0.2585, 0.8676, 0.2132],
[ 0.0330, 0.1084, -0.1692, -0.2771, -0.0804]])
t5.shape
torch.Size([5, 5])
t6 = torch.matmul(torch.transpose(t1, 0, 1), t2)
print(t6)
tensor([[ 1.7453, 0.3392],
[-1.6038, -0.2180]])
t1
tensor([[ 0.5153, -0.4414],
[-0.1939, 0.4694],
[-0.9414, 0.5997],
[-0.2057, 0.5087],
[ 0.1390, -0.1224]])
# calculate the norm of each row (ord=2 is the Euclidean norm, computed along dim=1)
norm_t1 = torch.linalg.norm(t1, ord=2, dim=1)
print(norm_t1)
tensor([0.6785, 0.5078, 1.1162, 0.5488, 0.1853])
torch.sqrt(t1[0][0]**2+t1[0][1]**2)
tensor(0.6785)
torch.sqrt(t1[0][0]**2+t1[0][1]**2).item()
0.6784621477127075
# to verify the norm computed above, we can do
np.sqrt(np.sum(np.square(t1.numpy()), axis=1))
array([0.678, 0.508, 1.116, 0.549, 0.185], dtype=float32)
13.6. Split, stack, and concatenate tensors#
torch.manual_seed(1)
t = torch.rand(6)
print(t)
t_splits = torch.chunk(t, 3)  # divides t into 3 chunks
[item.numpy() for item in t_splits]
tensor([0.7576, 0.2793, 0.4031, 0.7347, 0.0293, 0.7999])
[array([0.758, 0.279], dtype=float32),
array([0.403, 0.735], dtype=float32),
array([0.029, 0.8 ], dtype=float32)]
# what happens if you divide into 4 chunks?
# The chunk size becomes ceil(6/4) = 2, so only 3 chunks of 2 elements each are produced: the same result as before.
t_splits = torch.chunk(t, 4)  # requests 4 chunks, but only 3 are returned
[item.numpy() for item in t_splits]
[array([0.758, 0.279], dtype=float32),
array([0.403, 0.735], dtype=float32),
array([0.029, 0.8 ], dtype=float32)]
See the output above. Remember that we called np.set_printoptions(precision=3), which prints NumPy numbers with three decimal places.
PyTorch, by default, prints with torch.set_printoptions(precision=4).
Also note that the numbers are only rounded for display; the stored values and all computations keep their full precision.
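You can change PyTorch's print precision the same way; this only affects display, not the stored values (a sketch):
torch.set_printoptions(precision=3)
print(torch.tensor([1/3]))           # tensor([0.333])
torch.set_printoptions(precision=4)  # restore PyTorch's default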
torch.manual_seed(1)
t = torch.rand(5)
print(t)
t_splits = torch.split(t, split_size_or_sections=[3, 2])
[item.numpy() for item in t_splits]
tensor([0.7576, 0.2793, 0.4031, 0.7347, 0.0293])
[array([0.758, 0.279, 0.403], dtype=float32),
array([0.735, 0.029], dtype=float32)]
A = torch.ones(2,2)
B = torch.zeros(3,2)
C = torch.cat([A, B], axis=0)
print(C)
print(C.shape)
tensor([[1., 1.],
[1., 1.],
[0., 0.],
[0., 0.],
[0., 0.]])
torch.Size([5, 2])
A = torch.ones(3)
B = torch.zeros(3)
S = torch.stack([A, B], axis=0)
print(S)
print(S.shape)
tensor([[1., 1., 1.],
[0., 0., 0.]])
torch.Size([2, 3])
A = A.reshape(1,3)
B = B.reshape(1,3)
C = torch.cat([A, B], axis=0)
print(C)
print(C.shape)
tensor([[1., 1., 1.],
[0., 0., 0.]])
torch.Size([2, 3])
EXERCISE

1. Create a tensor: create a tensor A of shape [3, 4] filled with random values.
2. Transpose: compute the transpose of A and store it in B.
3. Matrix multiplication: multiply A by its transpose B using matrix multiplication. This will result in a new tensor C.
4. Compute means: compute the mean of C along:
   - dimension 0 (i.e., the mean of each column)
   - dimension 1 (i.e., the mean of each row)
5. Split: split the tensor C into 2 equal chunks along dimension 0.
   Hint: if the number of rows is not evenly divisible, use torch.chunk, which will distribute the rows as evenly as possible.
6. Print: print the following:
   - tensor A and its shape
   - tensor B and its shape
   - tensor C and its shape
   - the means computed in step 4
   - the two chunks obtained in step 5
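One possible solution sketch (your numbers will differ, since A is random):
A = torch.rand(3, 4)
B = torch.transpose(A, 0, 1)       # shape [4, 3]
C = torch.matmul(A, B)             # shape [3, 3]
mean_dim0 = torch.mean(C, axis=0)  # mean of each column
mean_dim1 = torch.mean(C, axis=1)  # mean of each row
chunks = torch.chunk(C, 2, dim=0)  # 3 rows -> chunks of 2 and 1 rows

print('A:', A, A.shape)
print('B:', B, B.shape)
print('C:', C, C.shape)
print('means along dim 0:', mean_dim0)
print('means along dim 1:', mean_dim1)
for i, chunk in enumerate(chunks, 1):
    print(f'chunk {i}:', chunk)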
13.7. Building input pipelines in PyTorch#
Creating a PyTorch DataLoader from existing tensors
from torch.utils.data import DataLoader
t = torch.arange(6, dtype=torch.float32)
print(t.shape)
data_loader = DataLoader(t)
torch.Size([6])
for item in data_loader:
    print(item, item.shape)
tensor([0.]) torch.Size([1])
tensor([1.]) torch.Size([1])
tensor([2.]) torch.Size([1])
tensor([3.]) torch.Size([1])
tensor([4.]) torch.Size([1])
tensor([5.]) torch.Size([1])
data_loader = DataLoader(t, batch_size=3, drop_last=False) # The final batch will be included even if it has fewer than batch_size elements.
for i, batch in enumerate(data_loader, 1):
    print(f'batch {i}:', batch)
batch 1: tensor([0., 1., 2.])
batch 2: tensor([3., 4., 5.])
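For comparison, with drop_last=True an incomplete final batch is discarded instead (a sketch):
data_loader = DataLoader(t, batch_size=4, drop_last=True)
for i, batch in enumerate(data_loader, 1):
    print(f'batch {i}:', batch)  # only batch 1: tensor([0., 1., 2., 3.]); the 2 leftover elements are dropped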
Combining two tensors into a joint dataset
from torch.utils.data import Dataset
class JointDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # a "magic" method: invoked automatically when using the indexing operator
        return self.x[idx], self.y[idx]
torch.manual_seed(1)
t_x = torch.rand([4, 3], dtype=torch.float32)
t_y = torch.arange(4)
#joint_dataset = JointDataset(t_x, t_y)
# Or use TensorDataset directly
from torch.utils.data import TensorDataset
joint_dataset = TensorDataset(t_x, t_y) # Returns a tuple containing the i-th sample from each tensor (in this case, a feature tensor and the corresponding label)
for example in joint_dataset:
    print(' x: ', example[0],
          ' y: ', example[1])
x: tensor([0.7576, 0.2793, 0.4031]) y: tensor(0)
x: tensor([0.7347, 0.0293, 0.7999]) y: tensor(1)
x: tensor([0.3971, 0.7544, 0.5695]) y: tensor(2)
x: tensor([0.4388, 0.6387, 0.5247]) y: tensor(3)
Shuffle, batch, and repeat
torch.manual_seed(1)
data_loader = DataLoader(dataset=joint_dataset, batch_size=2, shuffle=True)
for i, batch in enumerate(data_loader, 1):
    print(f'batch {i}:', 'x:', batch[0],
          '\n y:', batch[1])

print("\n\n-------------\n\n")

for epoch in range(2):
    print(f'***epoch #{epoch+1}')
    for i, batch in enumerate(data_loader, 1):
        print(f'batch {i}:', 'x:', batch[0],
              '\n y:', batch[1])
batch 1: x: tensor([[0.3971, 0.7544, 0.5695],
[0.7576, 0.2793, 0.4031]])
y: tensor([2, 0])
batch 2: x: tensor([[0.7347, 0.0293, 0.7999],
[0.4388, 0.6387, 0.5247]])
y: tensor([1, 3])
-------------
***epoch #1
batch 1: x: tensor([[0.7576, 0.2793, 0.4031],
[0.3971, 0.7544, 0.5695]])
y: tensor([0, 2])
batch 2: x: tensor([[0.7347, 0.0293, 0.7999],
[0.4388, 0.6387, 0.5247]])
y: tensor([1, 3])
***epoch #2
batch 1: x: tensor([[0.4388, 0.6387, 0.5247],
[0.3971, 0.7544, 0.5695]])
y: tensor([3, 2])
batch 2: x: tensor([[0.7576, 0.2793, 0.4031],
[0.7347, 0.0293, 0.7999]])
y: tensor([0, 1])
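If the shuffling order itself needs to be reproducible across runs, you can pass an explicitly seeded generator to the DataLoader (a sketch):
g = torch.Generator()
g.manual_seed(1)
data_loader = DataLoader(dataset=joint_dataset, batch_size=2, shuffle=True, generator=g)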
Creating a dataset from files on your local storage disk or from a repository
#from google.colab import drive
#drive.mount('/content/drive')
#import pathlib
#imgdir_path = pathlib.Path('/content/drive/My Drive/W&M/Teaching/DATA621/cat_dog_images')
#file_list = sorted([str(path) for path in imgdir_path.glob('*.jpg')])
import torchvision.transforms as transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader
# Define any preprocessing transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# Download (if not already downloaded) and load the dataset.
# Setting download=True will automatically download the dataset to the specified root.
dataset = OxfordIIITPet(root='./data', download=True, transform=transform)
# Create a DataLoader for iterating through the dataset
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# Print out some details
print(f"Total samples: {len(dataset)}")
100%|██████████| 792M/792M [00:38<00:00, 20.4MB/s]
100%|██████████| 19.2M/19.2M [00:01<00:00, 10.8MB/s]
Total samples: 3680
list_images, list_labels = map(list, zip(*dataset))
# there are 37 class labels (pet breeds)
np.unique(list_labels)
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36])
first6_images = []
first6_labels = []
for images, labels in data_loader:
    # images is a batch of images, so iterate over the batch
    for img, idx in zip(images, labels):
        first6_images.append(img)
        first6_labels.append(idx)
        if len(first6_images) == 6:
            break
    if len(first6_images) == 6:
        break
# Now, first6_images is a list of 6 image tensors.
#print(first6_images)
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(10, 5))
for i, img_tensor in enumerate(first6_images):
    # Convert the tensor to a NumPy array and change from (C, H, W) to (H, W, C)
    npimg = img_tensor.permute(1, 2, 0).cpu().numpy()
    print('Image shape:', npimg.shape)
    print(first6_labels[i])
    ax = fig.add_subplot(2, 3, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(npimg)
    ax.set_title(f"Image {i+1}; label: {first6_labels[i]}", size=15)
plt.tight_layout()
plt.show()
Image shape: (224, 224, 3)
tensor(7)
Image shape: (224, 224, 3)
tensor(34)
Image shape: (224, 224, 3)
tensor(0)
Image shape: (224, 224, 3)
tensor(35)
Image shape: (224, 224, 3)
tensor(1)
Image shape: (224, 224, 3)
tensor(35)
[Figure: 2x3 grid of the first six pet images, each titled "Image i; label: l"]
class ImageDataset(Dataset):
    def __init__(self, file_list, labels):
        self.file_list = file_list
        self.labels = labels

    def __getitem__(self, index):
        file = self.file_list[index]
        label = self.labels[index]
        return file, label

    def __len__(self):
        return len(self.labels)
image_dataset = ImageDataset(first6_images, first6_labels)
for file, label in image_dataset:
    print(file, label)
tensor([[[0.1176, 0.1137, 0.1176, ..., 0.9098, 0.9020, 0.8941],
[0.1176, 0.1176, 0.1216, ..., 0.9176, 0.9137, 0.9020],
[0.1137, 0.1216, 0.1216, ..., 0.9255, 0.9176, 0.9059],
...,
[0.7843, 0.7686, 0.8353, ..., 0.7529, 0.8039, 0.8314],
[0.6863, 0.6706, 0.7529, ..., 0.7059, 0.7373, 0.7961],
[0.6784, 0.7373, 0.8431, ..., 0.7137, 0.5608, 0.6196]],
[[0.1333, 0.1294, 0.1373, ..., 0.8824, 0.8706, 0.8588],
[0.1373, 0.1333, 0.1373, ..., 0.8941, 0.8863, 0.8706],
[0.1333, 0.1333, 0.1294, ..., 0.9059, 0.8902, 0.8824],
...,
[0.7529, 0.7490, 0.8196, ..., 0.6980, 0.7529, 0.7765],
[0.6314, 0.6157, 0.7098, ..., 0.6627, 0.7059, 0.7686],
[0.6118, 0.6745, 0.7961, ..., 0.7020, 0.5412, 0.5922]],
[[0.1176, 0.1294, 0.1137, ..., 0.8196, 0.8039, 0.7961],
[0.1216, 0.1216, 0.1098, ..., 0.8235, 0.8118, 0.8000],
[0.1255, 0.1255, 0.1176, ..., 0.8353, 0.8314, 0.8235],
...,
[0.6863, 0.6784, 0.7569, ..., 0.6275, 0.6863, 0.7333],
[0.5412, 0.5216, 0.6235, ..., 0.5843, 0.6392, 0.7255],
[0.5255, 0.5961, 0.7216, ..., 0.6431, 0.4471, 0.5176]]]) tensor(7)
tensor([[[0.1804, 0.1804, 0.1961, ..., 0.8549, 0.3882, 0.1412],
[0.1765, 0.1843, 0.2000, ..., 0.8863, 0.4431, 0.1608],
[0.1804, 0.1843, 0.2039, ..., 0.9059, 0.4863, 0.1843],
...,
[0.6863, 0.6353, 0.7333, ..., 0.3412, 0.5451, 0.7255],
[0.7804, 0.8000, 0.6980, ..., 0.4784, 0.4000, 0.5882],
[0.6784, 0.8275, 0.7373, ..., 0.6706, 0.4275, 0.6314]],
[[0.2118, 0.2235, 0.2392, ..., 0.6745, 0.3216, 0.1804],
[0.2157, 0.2275, 0.2431, ..., 0.7020, 0.3569, 0.1804],
[0.2196, 0.2314, 0.2471, ..., 0.7216, 0.3843, 0.1882],
...,
[0.6510, 0.6667, 0.7882, ..., 0.3725, 0.5059, 0.7412],
[0.8000, 0.7882, 0.6667, ..., 0.4784, 0.3765, 0.6000],
[0.6667, 0.7373, 0.6549, ..., 0.6549, 0.4314, 0.6510]],
[[0.1765, 0.1961, 0.2039, ..., 0.5882, 0.2941, 0.2078],
[0.1804, 0.2000, 0.2078, ..., 0.6157, 0.3216, 0.2000],
[0.1882, 0.2039, 0.2118, ..., 0.6314, 0.3412, 0.2000],
...,
[0.4745, 0.3843, 0.5412, ..., 0.1725, 0.3020, 0.5686],
[0.5882, 0.5882, 0.4863, ..., 0.2706, 0.1765, 0.4078],
[0.4510, 0.5961, 0.4902, ..., 0.4039, 0.1686, 0.3725]]]) tensor(34)
tensor([[[0.3686, 0.3725, 0.3725, ..., 0.8706, 0.8863, 0.8980],
[0.3725, 0.3765, 0.3804, ..., 0.8784, 0.8784, 0.8824],
[0.3804, 0.3804, 0.3843, ..., 0.8824, 0.8784, 0.8706],
...,
[0.4353, 0.4392, 0.4431, ..., 0.4863, 0.4902, 0.4863],
[0.4353, 0.4392, 0.4431, ..., 0.4863, 0.4902, 0.4863],
[0.4353, 0.4392, 0.4431, ..., 0.4863, 0.4902, 0.4863]],
[[0.3333, 0.3373, 0.3373, ..., 0.9882, 0.9882, 0.9922],
[0.3333, 0.3373, 0.3412, ..., 0.9922, 0.9804, 0.9804],
[0.3373, 0.3373, 0.3412, ..., 0.9882, 0.9843, 0.9843],
...,
[0.3765, 0.3804, 0.3765, ..., 0.4314, 0.4275, 0.4235],
[0.3765, 0.3804, 0.3765, ..., 0.4314, 0.4275, 0.4235],
[0.3765, 0.3804, 0.3765, ..., 0.4314, 0.4275, 0.4235]],
[[0.2667, 0.2706, 0.2706, ..., 0.8196, 0.8431, 0.8627],
[0.2627, 0.2667, 0.2706, ..., 0.8078, 0.8157, 0.8157],
[0.2667, 0.2667, 0.2706, ..., 0.8000, 0.7882, 0.7725],
...,
[0.2941, 0.2980, 0.3059, ..., 0.3294, 0.3255, 0.3216],
[0.2941, 0.2980, 0.3059, ..., 0.3294, 0.3255, 0.3216],
[0.2941, 0.2980, 0.3059, ..., 0.3255, 0.3255, 0.3216]]]) tensor(0)
tensor([[[0.4118, 0.4000, 0.5961, ..., 0.8745, 0.8667, 0.8706],
[0.3804, 0.3333, 0.5216, ..., 0.8824, 0.8784, 0.8784],
[0.5333, 0.4980, 0.5333, ..., 0.8863, 0.8863, 0.8863],
...,
[0.0627, 0.0902, 0.0941, ..., 0.1529, 0.1686, 0.1098],
[0.0745, 0.0784, 0.1020, ..., 0.1765, 0.1059, 0.0980],
[0.0902, 0.0588, 0.0902, ..., 0.1333, 0.1059, 0.1098]],
[[0.4275, 0.4314, 0.6431, ..., 0.9765, 0.9804, 0.9804],
[0.4157, 0.3804, 0.5765, ..., 0.9804, 0.9765, 0.9765],
[0.5882, 0.5490, 0.5922, ..., 0.9765, 0.9765, 0.9765],
...,
[0.1137, 0.1412, 0.1451, ..., 0.2588, 0.2667, 0.2039],
[0.1255, 0.1255, 0.1529, ..., 0.2667, 0.2000, 0.1843],
[0.1412, 0.1059, 0.1451, ..., 0.2118, 0.1961, 0.1804]],
[[0.2902, 0.3373, 0.5961, ..., 0.9922, 0.9961, 0.9961],
[0.2745, 0.2314, 0.4706, ..., 1.0000, 0.9961, 0.9961],
[0.4392, 0.3922, 0.5137, ..., 1.0000, 1.0000, 1.0000],
...,
[0.0902, 0.1137, 0.1176, ..., 0.1255, 0.1490, 0.0902],
[0.1137, 0.1020, 0.1294, ..., 0.1529, 0.0902, 0.0784],
[0.1412, 0.0863, 0.1216, ..., 0.1020, 0.0863, 0.0784]]]) tensor(35)
tensor([[[0.2784, 0.2863, 0.2745, ..., 0.2235, 0.2078, 0.2000],
[0.2824, 0.2824, 0.2745, ..., 0.2118, 0.1882, 0.1882],
[0.2824, 0.2863, 0.2784, ..., 0.1922, 0.1804, 0.1843],
...,
[0.1608, 0.1647, 0.1569, ..., 0.6902, 0.7020, 0.6941],
[0.1608, 0.1608, 0.1608, ..., 0.6824, 0.6863, 0.6745],
[0.1569, 0.1608, 0.1608, ..., 0.6667, 0.6745, 0.6667]],
[[0.1804, 0.1804, 0.1804, ..., 0.1294, 0.1176, 0.1176],
[0.1804, 0.1765, 0.1765, ..., 0.1255, 0.1098, 0.1098],
[0.1804, 0.1765, 0.1765, ..., 0.1098, 0.1020, 0.1059],
...,
[0.0941, 0.0980, 0.0941, ..., 0.6980, 0.7137, 0.7059],
[0.0980, 0.0980, 0.0980, ..., 0.6980, 0.7020, 0.6863],
[0.0941, 0.0980, 0.0980, ..., 0.6824, 0.6902, 0.6824]],
[[0.0235, 0.0235, 0.0157, ..., 0.1020, 0.0941, 0.0902],
[0.0235, 0.0196, 0.0196, ..., 0.0980, 0.0863, 0.0863],
[0.0235, 0.0235, 0.0196, ..., 0.0863, 0.0784, 0.0824],
...,
[0.0392, 0.0392, 0.0392, ..., 0.7412, 0.7569, 0.7490],
[0.0471, 0.0471, 0.0431, ..., 0.7412, 0.7490, 0.7294],
[0.0431, 0.0471, 0.0510, ..., 0.7255, 0.7333, 0.7255]]]) tensor(1)
tensor([[[0.5765, 0.5843, 0.5804, ..., 0.5725, 0.6039, 0.6157],
[0.5725, 0.5922, 0.5922, ..., 0.5843, 0.6039, 0.6235],
[0.5725, 0.5922, 0.5961, ..., 0.5961, 0.6196, 0.6235],
...,
[0.2039, 0.1647, 0.1608, ..., 0.7451, 0.7490, 0.7490],
[0.2118, 0.2000, 0.1804, ..., 0.7294, 0.7333, 0.7255],
[0.1882, 0.2157, 0.2039, ..., 0.7098, 0.7059, 0.7059]],
[[0.4549, 0.4549, 0.4471, ..., 0.4588, 0.4745, 0.4902],
[0.4510, 0.4588, 0.4549, ..., 0.4706, 0.4745, 0.4980],
[0.4392, 0.4549, 0.4588, ..., 0.4824, 0.4863, 0.4941],
...,
[0.2275, 0.0784, 0.0627, ..., 0.0980, 0.1020, 0.1020],
[0.3020, 0.1647, 0.0667, ..., 0.1020, 0.1020, 0.0941],
[0.3373, 0.2275, 0.0667, ..., 0.1020, 0.0941, 0.0980]],
[[0.3020, 0.2980, 0.2745, ..., 0.3373, 0.3412, 0.3451],
[0.3020, 0.3098, 0.3020, ..., 0.3333, 0.3333, 0.3529],
[0.3020, 0.3098, 0.3059, ..., 0.3529, 0.3569, 0.3686],
...,
[0.2275, 0.0902, 0.0588, ..., 0.1647, 0.1647, 0.1647],
[0.2980, 0.1647, 0.0627, ..., 0.1647, 0.1647, 0.1569],
[0.3216, 0.2196, 0.0706, ..., 0.1608, 0.1569, 0.1569]]]) tensor(35)
fig = plt.figure(figsize=(10, 6))
for i, example in enumerate(image_dataset):
    ax = fig.add_subplot(2, 3, i+1)
    ax.set_xticks([]); ax.set_yticks([])
    # imshow expects a NumPy array with the channel dimension last,
    # so transform the shape from (C, H, W) to (H, W, C)
    ax.imshow(example[0].numpy().transpose((1, 2, 0)))
    ax.set_title(f'{example[1]}', size=15)
plt.tight_layout()
plt.show()
[Figure: the same six images in a 2x3 grid, each titled with its tensor label]
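In practice, a custom image Dataset usually stores file paths rather than pre-loaded tensors, loading and transforming each image lazily inside __getitem__. A sketch of that pattern (the file list and labels are assumed to come from your own data):
from PIL import Image

class ImageFileDataset(Dataset):
    def __init__(self, file_list, labels, transform=None):
        self.file_list = file_list  # list of image file paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        img = Image.open(self.file_list[index]).convert('RGB')  # load lazily, one image at a time
        if self.transform is not None:
            img = self.transform(img)  # e.g., Resize + ToTensor as defined above
        return img, self.labels[index]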