<!doctype html>
<html lang="en" data-color-mode="dark">
<head>
<meta charset="utf-8">
<title>Pytorch Cheat Sheet &#x26; pytorch cheatsheet &#x26; Quick Reference</title>
<meta content="width=device-width, initial-scale=1" name="viewport">
<meta description="Pytorch 是一种开源机器学习框架,可加速从研究原型设计到生产部署的过程,备忘单是 官网
备忘清单为您提供了 Pytorch 基本语法和初步应用参考
入门,为开发人员分享快速参考备忘单。">
<meta keywords="pytorch,reference,Quick,Reference,cheatsheet,cheat,sheet">
<meta name="author" content="jaywcjlove">
<meta name="license" content="MIT">
<meta name="funding" content="https://jaywcjlove.github.io/#/sponsor">
<meta rel="apple-touch-icon" href="../icons/touch-icon-iphone.png">
<meta rel="apple-touch-icon" sizes="152x152" href="../icons/touch-icon-ipad.png">
<meta rel="apple-touch-icon" sizes="180x180" href="../icons/touch-icon-iphone.png">
<meta rel="apple-touch-icon" sizes="167x167" href="../icons/touch-icon-ipad-retina.png">
<meta rel="apple-touch-icon" sizes="120x120" href="../icons/touch-icon-iphone-retina.png">
<link rel="icon" href="../icons/favicon.svg" type="image/svg+xml">
<link href="../style/style.css" rel="stylesheet">
<link href="../style/katex.css" rel="stylesheet">
</head>
<body><nav class="header-nav"><div class="max-container"><a href="../index.html" class="logo"><svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" height="1em" width="1em">
<path d="m21.66 10.44-.98 4.18c-.84 3.61-2.5 5.07-5.62 4.77-.5-.04-1.04-.13-1.62-.27l-1.68-.4c-4.17-.99-5.46-3.05-4.48-7.23l.98-4.19c.2-.85.44-1.59.74-2.2 1.17-2.42 3.16-3.07 6.5-2.28l1.67.39c4.19.98 5.47 3.05 4.49 7.23Z" fill="#c9d1d9"></path>
<path d="M15.06 19.39c-.62.42-1.4.77-2.35 1.08l-1.58.52c-3.97 1.28-6.06.21-7.35-3.76L2.5 13.28c-1.28-3.97-.22-6.07 3.75-7.35l1.58-.52c.41-.13.8-.24 1.17-.31-.3.61-.54 1.35-.74 2.2l-.98 4.19c-.98 4.18.31 6.24 4.48 7.23l1.68.4c.58.14 1.12.23 1.62.27Zm2.43-8.88c-.06 0-.12-.01-.19-.02l-4.85-1.23a.75.75 0 0 1 .37-1.45l4.85 1.23a.748.748 0 0 1-.18 1.47Z" fill="#228e6c"></path>
<path d="M14.56 13.89c-.06 0-.12-.01-.19-.02l-2.91-.74a.75.75 0 0 1 .37-1.45l2.91.74c.4.1.64.51.54.91-.08.34-.38.56-.72.56Z" fill="#228e6c"></path>
</svg>
<span class="title">Quick Reference</span></a><div class="menu"><a href="javascript:void(0);" class="searchbtn" id="searchbtn"><svg xmlns="http://www.w3.org/2000/svg" height="1em" width="1em" viewBox="0 0 18 18">
<path fill="currentColor" d="M17.71,16.29 L14.31,12.9 C15.4069846,11.5024547 16.0022094,9.77665502 16,8 C16,3.581722 12.418278,0 8,0 C3.581722,0 0,3.581722 0,8 C0,12.418278 3.581722,16 8,16 C9.77665502,16.0022094 11.5024547,15.4069846 12.9,14.31 L16.29,17.71 C16.4777666,17.8993127 16.7333625,18.0057983 17,18.0057983 C17.2666375,18.0057983 17.5222334,17.8993127 17.71,17.71 C17.8993127,17.5222334 18.0057983,17.2666375 18.0057983,17 C18.0057983,16.7333625 17.8993127,16.4777666 17.71,16.29 Z M2,8 C2,4.6862915 4.6862915,2 8,2 C11.3137085,2 14,4.6862915 14,8 C14,11.3137085 11.3137085,14 8,14 C4.6862915,14 2,11.3137085 2,8 Z"></path>
</svg><span>搜索</span><span>⌘K</span></a><a href="https://github.com/jaywcjlove/reference/blob/main/docs/pytorch.md" class="edit" target="__blank"><svg viewBox="0 0 36 36" fill="currentColor" height="1em" width="1em"><path d="m33 6.4-3.7-3.7a1.71 1.71 0 0 0-2.36 0L23.65 6H6a2 2 0 0 0-2 2v22a2 2 0 0 0 2 2h22a2 2 0 0 0 2-2V11.76l3-3a1.67 1.67 0 0 0 0-2.36ZM18.83 20.13l-4.19.93 1-4.15 9.55-9.57 3.23 3.23ZM29.5 9.43 26.27 6.2l1.85-1.85 3.23 3.23Z"></path><path fill="none" d="M0 0h36v36H0z"></path></svg><span>编辑</span></a><button id="darkMode" type="button"><svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="light" height="1em" width="1em">
<path d="M6.995 12c0 2.761 2.246 5.007 5.007 5.007s5.007-2.246 5.007-5.007-2.246-5.007-5.007-5.007S6.995 9.239 6.995 12zM11 19h2v3h-2zm0-17h2v3h-2zm-9 9h3v2H2zm17 0h3v2h-3zM5.637 19.778l-1.414-1.414 2.121-2.121 1.414 1.414zM16.242 6.344l2.122-2.122 1.414 1.414-2.122 2.122zM6.344 7.759 4.223 5.637l1.415-1.414 2.12 2.122zm13.434 10.605-1.414 1.414-2.122-2.122 1.414-1.414z"></path>
</svg>
<svg xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24" class="dark" height="1em" width="1em">
<path d="M12 11.807A9.002 9.002 0 0 1 10.049 2a9.942 9.942 0 0 0-5.12 2.735c-3.905 3.905-3.905 10.237 0 14.142 3.906 3.906 10.237 3.905 14.143 0a9.946 9.946 0 0 0 2.735-5.119A9.003 9.003 0 0 1 12 11.807z"></path>
</svg>
</button><script src="../js/dark.js?v=1.8.3"></script><a href="https://github.com/jaywcjlove/reference" class="" target="__blank"><svg viewBox="0 0 16 16" fill="currentColor" height="1em" width="1em"><path d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.012 8.012 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path></svg></a></div></div></nav><div class="wrap h1body-exist max-container"><header class="wrap-header h1wrap"><h1 id="pytorch--备忘清单"><svg viewBox="0 0 24 24" fill="currentColor" xmlns="http://www.w3.org/2000/svg" height="1em" width="1em">
<path d="M12.005 0 4.952 7.053a9.865 9.865 0 0 0 0 14.022 9.866 9.866 0 0 0 14.022 0c3.984-3.9 3.986-10.205.085-14.023l-1.744 1.743c2.904 2.905 2.904 7.634 0 10.538s-7.634 2.904-10.538 0-2.904-7.634 0-10.538l4.647-4.646.582-.665zm3.568 3.899a1.327 1.327 0 0 0-1.327 1.327 1.327 1.327 0 0 0 1.327 1.328A1.327 1.327 0 0 0 16.9 5.226 1.327 1.327 0 0 0 15.573 3.9z"></path>
</svg>
<a aria-hidden="true" tabindex="-1" href="#pytorch--备忘清单"><span class="icon icon-link"></span></a>Pytorch 备忘清单</h1><div class="wrap-body">
<p>Pytorch is an open-source machine learning framework that accelerates the path from research prototyping to production deployment. This cheat sheet, adapted from the official site, gives you a quick reference to <a href="https://pytorch.org/">Pytorch</a> basics and first applications.</p>
</div></header><div class="menu-tocs"><div class="menu-btn"><svg aria-hidden="true" fill="currentColor" height="1em" width="1em" viewBox="0 0 16 16" version="1.1" data-view-component="true">
<path fill-rule="evenodd" d="M2 4a1 1 0 100-2 1 1 0 000 2zm3.75-1.5a.75.75 0 000 1.5h8.5a.75.75 0 000-1.5h-8.5zm0 5a.75.75 0 000 1.5h8.5a.75.75 0 000-1.5h-8.5zm0 5a.75.75 0 000 1.5h8.5a.75.75 0 000-1.5h-8.5zM3 8a1 1 0 11-2 0 1 1 0 012 0zm-1 6a1 1 0 100-2 1 1 0 000 2z"></path>
</svg></div><div class="menu-modal"><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#入门">入门</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#介绍">介绍</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#认识-pytorch">认识 Pytorch</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#创建一个全零矩阵">创建一个全零矩阵</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#数据创建张量">数据创建张量</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#pytorch-的基本语法">Pytorch 的基本语法</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作1">加法操作(1)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作2">加法操作(2)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作3">加法操作(3)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作4">加法操作(4)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#张量操作">张量操作</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#张量形状">张量形状</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#取张量元素">取张量元素</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#torch-tensor-和-numpy-array互换">Torch Tensor 和 Numpy array互换</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#torch-tensor-转换为-numpy-array">Torch Tensor 转换为 Numpy array</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#numpy-array转换为torch-tensor">Numpy array转换为Torch Tensor</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#squeeze函数">squeeze函数</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#unsqueeze函数">unsqueeze函数</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#cuda-相关">Cuda 相关</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#检查-cuda-是否可用">检查 Cuda 是否可用</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#列出-gpu-设备">列出 GPU 设备</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#将模型张量等数据在-gpu-和内存之间进行搬运">将模型、张量等数据在 GPU 和内存之间进行搬运</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#导入-imports">导入 Imports</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#一般">一般</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#神经网络-api">神经网络 API</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#onnx">ONNX</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#torchscript-和-jit">Torchscript 和 JIT</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#vision">Vision</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#分布式训练">分布式训练</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#另见">另见</a></div></div><div class="h1wrap-body"><div class="wrap h2body-exist"><div class="wrap-header h2wrap"><h2 id="入门"><a aria-hidden="true" tabindex="-1" href="#入门"><span class="icon icon-link"></span></a>入门</h2><div class="wrap-body">
</div></div><div class="h2wrap-body"><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="介绍"><a aria-hidden="true" tabindex="-1" href="#介绍"><span class="icon icon-link"></span></a>介绍</h3><div class="wrap-body">
<ul>
<li><a href="https://pytorch.org/">Pytorch 官网</a> <em>(pytorch.org)</em></li>
<li><a href="https://pytorch.org/tutorials/beginner/ptcheat.html">Pytorch 官方备忘清单</a> <em>(pytorch.org)</em></li>
</ul>
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="认识-pytorch"><a aria-hidden="true" tabindex="-1" href="#认识-pytorch"><span class="icon icon-link"></span></a>认识 Pytorch</h3><div class="wrap-body">
<!--rehype:wrap-class=row-span-2-->
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">from</span> __future__ <span class="token keyword">import</span> print_function
</span><span class="code-line"><span class="token keyword">import</span> torch
</span><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>empty<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">)</span>
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span>
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">2.4835e+27</span><span class="token punctuation">,</span> <span class="token number">2.5428e+30</span><span class="token punctuation">,</span> <span class="token number">1.0877e-19</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">1.5163e+23</span><span class="token punctuation">,</span> <span class="token number">2.2012e+12</span><span class="token punctuation">,</span> <span class="token number">3.7899e+22</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">5.2480e+05</span><span class="token punctuation">,</span> <span class="token number">1.0175e+31</span><span class="token punctuation">,</span> <span class="token number">9.7056e+24</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">1.6283e+32</span><span class="token punctuation">,</span> <span class="token number">3.7913e+22</span><span class="token punctuation">,</span> <span class="token number">3.9653e+28</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">1.0876e-19</span><span class="token punctuation">,</span> <span class="token number">6.2027e+26</span><span class="token punctuation">,</span> <span class="token number">2.3685e+21</span><span class="token punctuation">]</span>
</span><span class="code-line"><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
<!--rehype:className=wrap-text-->
<p>Tensors: a Tensor is conceptually similar to Numpy's ndarray data structure; the key difference is that a Tensor can take advantage of GPU acceleration.</p>
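<p>A minimal sketch of putting a tensor on the GPU (a CUDA device may or may not be present, so the move is guarded):</p>
<pre class="language-python"><code class="language-python">import torch

t = torch.ones(2, 3)
# Only use the GPU when one is actually available
if torch.cuda.is_available():
    t = t.to("cuda")
print(t.device)
</code></pre>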
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="创建一个全零矩阵"><a aria-hidden="true" tabindex="-1" href="#创建一个全零矩阵"><span class="icon icon-link"></span></a>创建一个全零矩阵</h3><div class="wrap-body">
<!--rehype:wrap-class=row-span-2-->
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">,</span> dtype<span class="token operator">=</span>torch<span class="token punctuation">.</span><span class="token builtin">long</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">)</span>
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
<p>Create an all-zeros matrix and, optionally, specify the element dtype as long</p>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="数据创建张量"><a aria-hidden="true" tabindex="-1" href="#数据创建张量"><span class="icon icon-link"></span></a>数据创建张量</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2.5</span><span class="token punctuation">,</span> <span class="token number">3.5</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">)</span>
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2.5000</span><span class="token punctuation">,</span> <span class="token number">3.3000</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
</div></div></div></div></div><div class="wrap h2body-exist"><div class="wrap-header h2wrap"><h2 id="pytorch-的基本语法"><a aria-hidden="true" tabindex="-1" href="#pytorch-的基本语法"><span class="icon icon-link"></span></a>Basic Pytorch Syntax</h2><div class="wrap-body">
</div></div><div class="h2wrap-body"><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="加法操作1"><a aria-hidden="true" tabindex="-1" href="#加法操作1"><span class="icon icon-link"></span></a>Addition (1)</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">y <span class="token operator">=</span> torch<span class="token punctuation">.</span>rand<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x <span class="token operator">+</span> y<span class="token punctuation">)</span>
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">[</span> <span class="token number">1.6978</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1.6979</span><span class="token punctuation">,</span> <span class="token number">0.3093</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.4953</span><span class="token punctuation">,</span> <span class="token number">0.3954</span><span class="token punctuation">,</span> <span class="token number">0.0595</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">0.9540</span><span class="token punctuation">,</span> <span class="token number">0.3353</span><span class="token punctuation">,</span> <span class="token number">0.1251</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.6883</span><span class="token punctuation">,</span> <span class="token number">0.9775</span><span class="token punctuation">,</span> <span class="token number">1.1764</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">2.6784</span><span class="token punctuation">,</span> <span class="token number">0.1209</span><span class="token punctuation">,</span> <span class="token number">1.5542</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="加法操作2"><a aria-hidden="true" tabindex="-1" href="#加法操作2"><span class="icon icon-link"></span></a>加法操作(2)</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>torch<span class="token punctuation">.</span>add<span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">)</span><span class="token punctuation">)</span>
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">[</span> <span class="token number">1.6978</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1.6979</span><span class="token punctuation">,</span> <span class="token number">0.3093</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.4953</span><span class="token punctuation">,</span> <span class="token number">0.3954</span><span class="token punctuation">,</span> <span class="token number">0.0595</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">0.9540</span><span class="token punctuation">,</span> <span class="token number">0.3353</span><span class="token punctuation">,</span> <span class="token number">0.1251</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.6883</span><span class="token punctuation">,</span> <span class="token number">0.9775</span><span class="token punctuation">,</span> <span class="token number">1.1764</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">2.6784</span><span class="token punctuation">,</span> <span class="token number">0.1209</span><span class="token punctuation">,</span> <span class="token number">1.5542</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="加法操作3"><a aria-hidden="true" tabindex="-1" href="#加法操作3"><span class="icon icon-link"></span></a>加法操作(3)</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># 提前设定一个空的张量</span>
</span><span class="code-line">result <span class="token operator">=</span> torch<span class="token punctuation">.</span>empty<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># 将空的张量作为加法的结果存储张量</span>
</span><span class="code-line"> torch<span class="token punctuation">.</span>add<span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">,</span> out<span class="token operator">=</span>result<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>result<span class="token punctuation">)</span>
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">[</span> <span class="token number">1.6978</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1.6979</span><span class="token punctuation">,</span> <span class="token number">0.3093</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.4953</span><span class="token punctuation">,</span> <span class="token number">0.3954</span><span class="token punctuation">,</span> <span class="token number">0.0595</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">0.9540</span><span class="token punctuation">,</span> <span class="token number">0.3353</span><span class="token punctuation">,</span> <span class="token number">0.1251</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.6883</span><span class="token punctuation">,</span> <span class="token number">0.9775</span><span class="token punctuation">,</span> <span class="token number">1.1764</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">2.6784</span><span class="token punctuation">,</span> <span class="token number">0.1209</span><span class="token punctuation">,</span> <span class="token number">1.5542</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="加法操作4"><a aria-hidden="true" tabindex="-1" href="#加法操作4"><span class="icon icon-link"></span></a>加法操作(4)</h3><div class="wrap-body">
<!--rehype:wrap-class=row-span-2-->
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">y<span class="token punctuation">.</span>add_<span class="token punctuation">(</span>x<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>y<span class="token punctuation">)</span>
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">[</span> <span class="token number">1.6978</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1.6979</span><span class="token punctuation">,</span> <span class="token number">0.3093</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.4953</span><span class="token punctuation">,</span> <span class="token number">0.3954</span><span class="token punctuation">,</span> <span class="token number">0.0595</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">0.9540</span><span class="token punctuation">,</span> <span class="token number">0.3353</span><span class="token punctuation">,</span> <span class="token number">0.1251</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.6883</span><span class="token punctuation">,</span> <span class="token number">0.9775</span><span class="token punctuation">,</span> <span class="token number">1.1764</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">2.6784</span><span class="token punctuation">,</span> <span class="token number">0.1209</span><span class="token punctuation">,</span> <span class="token number">1.5542</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
<p>Note: every <code>in-place</code> operation has a trailing underscore in its name.
For example, <code>x.copy_(y)</code> and <code>x.add_(y)</code> both change the value of x directly.</p>
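<p>A minimal sketch contrasting the out-of-place and in-place forms (the variable names are illustrative):</p>
<pre class="language-python"><code class="language-python">import torch

x = torch.zeros(2, 2)
y = torch.ones(2, 2)

z = x.add(y)   # out-of-place: x is unchanged, the result comes back as z
x.add_(y)      # in-place: x itself now holds x + y
print(torch.equal(x, z))  # True
</code></pre>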
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="张量操作"><a aria-hidden="true" tabindex="-1" href="#张量操作"><span class="icon icon-link"></span></a>张量操作</h3><div class="wrap-body">
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">2.0902</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">0.4489</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">0.1441</span><span class="token punctuation">,</span> <span class="token number">0.8035</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">0.8341</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
<!--rehype:className=wrap-text-->
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="张量形状"><a aria-hidden="true" tabindex="-1" href="#张量形状"><span class="icon icon-link"></span></a>张量形状</h3><div class="wrap-body">
<!--rehype:wrap-class=row-span-2-->
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>randn<span class="token punctuation">(</span><span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># tensor.view()操作需要保证数据元素的总数量不变</span>
</span><span class="code-line">y <span class="token operator">=</span> x<span class="token punctuation">.</span>view<span class="token punctuation">(</span><span class="token number">16</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># -1代表自动匹配个数</span>
</span><span class="code-line">z <span class="token operator">=</span> x<span class="token punctuation">.</span>view<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">8</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">.</span>size<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> y<span class="token punctuation">.</span>size<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> z<span class="token punctuation">.</span>size<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">]</span><span class="token punctuation">)</span> torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">16</span><span class="token punctuation">]</span><span class="token punctuation">)</span> torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">8</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
<!--rehype:className=wrap-text-->
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="取张量元素"><a aria-hidden="true" tabindex="-1" href="#取张量元素"><span class="icon icon-link"></span></a>取张量元素</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>randn<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">.</span>item<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">0.3531</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">-</span><span class="token number">0.3530771732330322</span>
</span></code></pre>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="torch-tensor-和-numpy-array互换"><a aria-hidden="true" tabindex="-1" href="#torch-tensor-和-numpy-array互换"><span class="icon icon-link"></span></a>Torch Tensor 和 Numpy array互换</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">a <span class="token operator">=</span> torch<span class="token punctuation">.</span>ones<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>a<span class="token punctuation">)</span>
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1.</span><span class="token punctuation">,</span> <span class="token number">1.</span><span class="token punctuation">,</span> <span class="token number">1.</span><span class="token punctuation">,</span> <span class="token number">1.</span><span class="token punctuation">,</span> <span class="token number">1.</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
<p>A Torch Tensor and the Numpy array converted from it share the same underlying memory, so changing the value of one also changes the other</p>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="torch-tensor-转换为-numpy-array"><a aria-hidden="true" tabindex="-1" href="#torch-tensor-转换为-numpy-array"><span class="icon icon-link"></span></a>Torch Tensor 转换为 Numpy array</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">b <span class="token operator">=</span> a<span class="token punctuation">.</span>numpy<span class="token punctuation">(</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>b<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token punctuation">[</span><span class="token number">1.</span> <span class="token number">1.</span> <span class="token number">1.</span> <span class="token number">1.</span> <span class="token number">1.</span><span class="token punctuation">]</span>
</span></code></pre>
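<p>A minimal sketch of the shared memory mentioned above: an in-place update of the tensor is visible through the Numpy array as well.</p>
<pre class="language-python"><code class="language-python">import torch

a = torch.ones(5)
b = a.numpy()
a.add_(1)   # in-place update of the tensor
print(a)    # tensor([2., 2., 2., 2., 2.])
print(b)    # [2. 2. 2. 2. 2.] -- the numpy array changed too
</code></pre>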
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="numpy-array转换为torch-tensor"><a aria-hidden="true" tabindex="-1" href="#numpy-array转换为torch-tensor"><span class="icon icon-link"></span></a>Numpy array转换为Torch Tensor</h3><div class="wrap-body">
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">import</span> numpy <span class="token keyword">as</span> np
</span><span class="code-line">a <span class="token operator">=</span> np<span class="token punctuation">.</span>ones<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">)</span>
</span><span class="code-line">b <span class="token operator">=</span> torch<span class="token punctuation">.</span>from_numpy<span class="token punctuation">(</span>a<span class="token punctuation">)</span>
</span><span class="code-line">np<span class="token punctuation">.</span>add<span class="token punctuation">(</span>a<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> out<span class="token operator">=</span>a<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>a<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>b<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token punctuation">[</span><span class="token number">2.</span> <span class="token number">2.</span> <span class="token number">2.</span> <span class="token number">2.</span> <span class="token number">2.</span><span class="token punctuation">]</span>
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2.</span><span class="token punctuation">,</span> <span class="token number">2.</span><span class="token punctuation">,</span> <span class="token number">2.</span><span class="token punctuation">,</span> <span class="token number">2.</span><span class="token punctuation">,</span> <span class="token number">2.</span><span class="token punctuation">]</span><span class="token punctuation">,</span> dtype<span class="token operator">=</span>torch<span class="token punctuation">.</span>float64<span class="token punctuation">)</span>
</span></code></pre>
<!--rehype:className=wrap-text-->
<p>Note: all Tensors on the CPU, except CharTensor, can be converted to a Numpy array and back again.</p>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="squeeze函数"><a aria-hidden="true" tabindex="-1" href="#squeeze函数"><span class="icon icon-link"></span></a>squeeze函数</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x <span class="token operator">=</span> torch<span class="token punctuation">.</span>rand<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
</span><span class="code-line">
</span><span class="code-line"><span class="token comment"># squeeze不加参数默认去除所有为1的维度</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span><span class="code-line">
</span><span class="code-line"><span class="token comment"># squeeze加参数去除指定为1的维度</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span>dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span><span class="code-line">
</span><span class="code-line"><span class="token comment"># squeeze加参数如果不为1则不变</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span><span class="code-line">
</span><span class="code-line"><span class="token comment"># 既可以是函数,也可以是方法</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> torch<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span>x<span class="token punctuation">,</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="unsqueeze函数"><a aria-hidden="true" tabindex="-1" href="#unsqueeze函数"><span class="icon icon-link"></span></a>unsqueeze函数</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x <span class="token operator">=</span> torch<span class="token punctuation">.</span>rand<span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># unsqueeze必须加参数 _ 2 _ 28 _</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>unsqueeze<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
</span><span class="code-line"><span class="token comment"># 参数代表在哪里添加维度 0 1 2</span>
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># 既可以是函数,也可以是方法</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> torch<span class="token punctuation">.</span>unsqueeze<span class="token punctuation">(</span>x<span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
</div></div></div></div></div><div class="wrap h2body-exist"><div class="wrap-header h2wrap"><h2 id="cuda-相关"><a aria-hidden="true" tabindex="-1" href="#cuda-相关"><span class="icon icon-link"></span></a>Cuda Related</h2><div class="wrap-body">
</div></div><div class="h2wrap-body"><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="检查-cuda-是否可用"><a aria-hidden="true" tabindex="-1" href="#检查-cuda-是否可用"><span class="icon icon-link"></span></a>Check Whether Cuda Is Available</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">import</span> torch<span class="token punctuation">.</span>cuda
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>is_available<span class="token punctuation">(</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token boolean">True</span>
</span></code></pre>
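<p>A minimal sketch of the usual follow-up: pick the device once and fall back to the CPU when no CUDA device is present.</p>
<pre class="language-python"><code class="language-python">import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(5, 3, device=device)
print(x.device)
</code></pre>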
</div></div></div><div class="wrap h3body-not-exist col-span-2 row-span-2"><div class="wrap-header h3wrap"><h3 id="列出-gpu-设备"><a aria-hidden="true" tabindex="-1" href="#列出-gpu-设备"><span class="icon icon-link"></span></a>列出 GPU 设备</h3><div class="wrap-body">
<!--rehype:wrap-class=col-span-2 row-span-2-->
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">import</span> torch
</span><span class="code-line">
</span><span class="code-line">device_count <span class="token operator">=</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>device_count<span class="token punctuation">(</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"CUDA 设备"</span><span class="token punctuation">)</span>
</span><span class="code-line">
</span><span class="code-line"><span class="token keyword">for</span> i <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span>device_count<span class="token punctuation">)</span><span class="token punctuation">:</span>
</span><span class="code-line"> device_name <span class="token operator">=</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>get_device_name<span class="token punctuation">(</span>i<span class="token punctuation">)</span>
</span><span class="code-line"> total_memory <span class="token operator">=</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>get_device_properties<span class="token punctuation">(</span>i<span class="token punctuation">)</span><span class="token punctuation">.</span>total_memory <span class="token operator">/</span> <span class="token punctuation">(</span><span class="token number">1024</span> <span class="token operator">**</span> <span class="token number">3</span><span class="token punctuation">)</span>
</span><span class="code-line"> <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string-interpolation"><span class="token string">f"├── 设备 </span><span class="token interpolation"><span class="token punctuation">{</span>i<span class="token punctuation">}</span></span><span class="token string">: </span><span class="token interpolation"><span class="token punctuation">{</span>device_name<span class="token punctuation">}</span></span><span class="token string">, 容量: </span><span class="token interpolation"><span class="token punctuation">{</span>total_memory<span class="token punctuation">:</span><span class="token format-spec">.2f</span><span class="token punctuation">}</span></span><span class="token string"> GiB"</span></span><span class="token punctuation">)</span>
</span><span class="code-line">
</span><span class="code-line"><span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"└── (结束)"</span><span class="token punctuation">)</span>
</span></code></pre>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="将模型张量等数据在-gpu-和内存之间进行搬运"><a aria-hidden="true" tabindex="-1" href="#将模型张量等数据在-gpu-和内存之间进行搬运"><span class="icon icon-link"></span></a>将模型、张量等数据在 GPU 和内存之间进行搬运</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">import</span> torch
</span><span class="code-line"><span class="token comment"># 将 0 替换为您的 GPU 设备索引或者直接使用 "cuda"</span>
</span><span class="code-line">device <span class="token operator">=</span> <span class="token string-interpolation"><span class="token string">f"cuda:0"</span></span>
</span><span class="code-line"><span class="token comment"># 移动到GPU</span>
</span><span class="code-line">tensor_m <span class="token operator">=</span> torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span><span class="code-line">tensor_g <span class="token operator">=</span> tensor_m<span class="token punctuation">.</span>to<span class="token punctuation">(</span>device<span class="token punctuation">)</span>
</span><span class="code-line">model_m <span class="token operator">=</span> torch<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
</span><span class="code-line">model_g <span class="token operator">=</span> model_m<span class="token punctuation">.</span>to<span class="token punctuation">(</span>device<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># 向后移动</span>
</span><span class="code-line">tensor_m <span class="token operator">=</span> tensor_g<span class="token punctuation">.</span>cpu<span class="token punctuation">(</span><span class="token punctuation">)</span>
</span><span class="code-line">model_m <span class="token operator">=</span> model_g<span class="token punctuation">.</span>cpu<span class="token punctuation">(</span><span class="token punctuation">)</span>
</span></code></pre>
</div></div></div></div></div><div class="wrap h2body-exist"><div class="wrap-header h2wrap"><h2 id="导入-imports"><a aria-hidden="true" tabindex="-1" href="#导入-imports"><span class="icon icon-link"></span></a>Imports</h2><div class="wrap-body">
</div></div><div class="h2wrap-body"><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="一般"><a aria-hidden="true" tabindex="-1" href="#一般"><span class="icon icon-link"></span></a>General</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># Root package</span>
</span><span class="code-line"><span class="token keyword">import</span> torch
</span></code></pre>
<p>Dataset representation and loading</p>
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">from</span> torch<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>data <span class="token keyword">import</span> Dataset<span class="token punctuation">,</span> DataLoader
</span></code></pre>
<!--rehype:className=wrap-text-->
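<p>A minimal sketch of a custom Dataset wired into a DataLoader (the class name, sizes, and toy data below are illustrative):</p>
<pre class="language-python"><code class="language-python">import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """Toy dataset mapping x to x**2."""
    def __init__(self, n=100):
        self.x = torch.arange(n, dtype=torch.float32).unsqueeze(1)
        self.y = self.x ** 2

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

loader = DataLoader(SquaresDataset(), batch_size=16, shuffle=True)
for xb, yb in loader:
    pass  # a training step would go here
</code></pre>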
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="神经网络-api"><a aria-hidden="true" tabindex="-1" href="#神经网络-api"><span class="icon icon-link"></span></a>神经网络 API</h3><div class="wrap-body">
<!--rehype:wrap-class=row-span-2-->
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># 计算图</span>
</span><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>autograd <span class="token keyword">as</span> autograd
</span><span class="code-line"><span class="token comment"># 计算图中的张量节点</span>
</span><span class="code-line"><span class="token keyword">from</span> torch <span class="token keyword">import</span> Tensor
</span></code></pre>
<p>神经网络</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>nn <span class="token keyword">as</span> nn
</span><span class="code-line">
</span><span class="code-line"><span class="token comment"># 层、激活等</span>
</span><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>functional <span class="token keyword">as</span> F
</span><span class="code-line"><span class="token comment"># 优化器,例如 梯度下降、ADAM等</span>
</span><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>optim <span class="token keyword">as</span> optim
</span></code></pre>
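<p>一个最小训练步骤示意(数据与学习率均为假设的示例),展示 nn、functional 与 optim 的配合</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"># 最小训练步骤示意(假设的数据与超参数)
</span><span class="code-line">import torch
</span><span class="code-line">import torch.nn as nn
</span><span class="code-line">import torch.nn.functional as F
</span><span class="code-line">import torch.optim as optim
</span><span class="code-line">
</span><span class="code-line">model = nn.Linear(1, 1)
</span><span class="code-line">opt = optim.SGD(model.parameters(), lr=0.1)
</span><span class="code-line">x = torch.randn(16, 1)
</span><span class="code-line">y = 3 * x
</span><span class="code-line">loss = F.mse_loss(model(x), y)
</span><span class="code-line">opt.zero_grad()
</span><span class="code-line">loss.backward()  # autograd 反向传播计算梯度
</span><span class="code-line">opt.step()       # 按梯度更新参数
</span></code></pre>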
<p>混合前端JIT的装饰器与跟踪</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">from</span> torch<span class="token punctuation">.</span>jit <span class="token keyword">import</span> script<span class="token punctuation">,</span> trace
</span></code></pre>
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="onnx"><a aria-hidden="true" tabindex="-1" href="#onnx"><span class="icon icon-link"></span></a>ONNX</h3><div class="wrap-body">
<!--rehype:wrap-class=row-span-2-->
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line">torch<span class="token punctuation">.</span>onnx<span class="token punctuation">.</span>export<span class="token punctuation">(</span>model<span class="token punctuation">,</span> dummy data<span class="token punctuation">,</span> xxxx<span class="token punctuation">.</span>proto<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># 导出 ONNX 格式</span>
</span><span class="code-line"><span class="token comment"># 使用经过训练的模型模型dummy</span>
</span><span class="code-line"><span class="token comment"># 数据和所需的文件名</span>
</span></code></pre>
<!--rehype:className=wrap-text-->
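<p>一个可运行的导出示意(模型、输入与文件名均为假设的示例dummy 输入的形状需与真实输入一致</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"># 导出示意(假设的模型与文件名)
</span><span class="code-line">import torch
</span><span class="code-line">
</span><span class="code-line">net = torch.nn.Linear(3, 1)
</span><span class="code-line">dummy_data = torch.randn(1, 3)  # 与真实输入同形状
</span><span class="code-line">torch.onnx.export(net, dummy_data, "net.onnx")
</span></code></pre>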
<p>加载 ONNX 模型</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">model <span class="token operator">=</span> onnx<span class="token punctuation">.</span>load<span class="token punctuation">(</span><span class="token string">"alexnet.proto"</span><span class="token punctuation">)</span>
</span></code></pre>
<p>检查模型的 IR中间表示是否结构良好</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">onnx<span class="token punctuation">.</span>checker<span class="token punctuation">.</span>check_model<span class="token punctuation">(</span>model<span class="token punctuation">)</span>
</span></code></pre>
<p>打印图的人类可读表示</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">onnx<span class="token punctuation">.</span>helper<span class="token punctuation">.</span>printable_graph<span class="token punctuation">(</span>model<span class="token punctuation">.</span>graph<span class="token punctuation">)</span>
</span></code></pre>
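<p>printable_graph 返回字符串,通常配合 print 使用;下面是一个组合用法示意(假设已安装 onnx 包,文件名为示例)</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"># 加载、检查并打印模型结构的组合示意(假设的文件名)
</span><span class="code-line">import onnx
</span><span class="code-line">
</span><span class="code-line">m = onnx.load("net.onnx")
</span><span class="code-line">onnx.checker.check_model(m)                  # 结构不合法时抛出异常
</span><span class="code-line">print(onnx.helper.printable_graph(m.graph))  # 打印可读的图结构
</span></code></pre>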
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="torchscript-和-jit"><a aria-hidden="true" tabindex="-1" href="#torchscript-和-jit"><span class="icon icon-link"></span></a>Torchscript 和 JIT</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">torch<span class="token punctuation">.</span>jit<span class="token punctuation">.</span>trace<span class="token punctuation">(</span><span class="token punctuation">)</span>
</span></code></pre>
<p>接收你的模块或函数以及一个示例数据输入,追踪数据在模型中前向传播时经过的计算步骤</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token decorator annotation punctuation">@script</span>
</span></code></pre>
<p>该装饰器用于标注被跟踪代码中依赖数据的控制流</p>
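<p>一个最小示意(模块与输入均为假设的示例):用 trace 记录前向计算,用 script 处理依赖数据的控制流</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"># trace 与 script 的最小示意(假设的模块与输入)
</span><span class="code-line">import torch
</span><span class="code-line">
</span><span class="code-line">net = torch.nn.Linear(1, 1)
</span><span class="code-line">traced = torch.jit.trace(net, torch.randn(1, 1))  # 按示例输入记录计算步骤
</span><span class="code-line">
</span><span class="code-line">@torch.jit.script
</span><span class="code-line">def flip_if_negative(x: torch.Tensor) -&gt; torch.Tensor:
</span><span class="code-line">    # 依赖数据的控制流trace 记录不到分支,需用 script
</span><span class="code-line">    if x.sum() &gt; 0:
</span><span class="code-line">        return x
</span><span class="code-line">    return -x
</span></code></pre>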
</div></div></div><div class="wrap h3body-not-exist col-span-2"><div class="wrap-header h3wrap"><h3 id="vision"><a aria-hidden="true" tabindex="-1" href="#vision"><span class="icon icon-link"></span></a>Vision</h3><div class="wrap-body">
<!--rehype:wrap-class=col-span-2-->
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># 视觉数据集,架构 &#x26; 变换</span>
</span><span class="code-line"><span class="token keyword">from</span> torchvision <span class="token keyword">import</span> datasets<span class="token punctuation">,</span> models<span class="token punctuation">,</span> transforms
</span><span class="code-line"><span class="token comment"># 组合转换</span>
</span><span class="code-line"><span class="token keyword">import</span> torchvision<span class="token punctuation">.</span>transforms <span class="token keyword">as</span> transforms
</span></code></pre>
<!--rehype:className=wrap-text-->
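<p>一个最小示意(假设使用 MNIST 数据集与较新版本的 torchvision首次运行会下载数据</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"># 数据集 + 变换 + 模型的组合示意(假设的数据目录)
</span><span class="code-line">from torchvision import datasets, models, transforms
</span><span class="code-line">
</span><span class="code-line">tfm = transforms.Compose([transforms.ToTensor()])
</span><span class="code-line">train_set = datasets.MNIST(root="data", train=True, download=True, transform=tfm)
</span><span class="code-line">net = models.resnet18(weights=None)  # 较新版 torchvision 的写法,不加载预训练权重
</span></code></pre>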
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="分布式训练"><a aria-hidden="true" tabindex="-1" href="#分布式训练"><span class="icon icon-link"></span></a>分布式训练</h3><div class="wrap-body">
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># 分布式通信</span>
</span><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>distributed <span class="token keyword">as</span> dist
</span><span class="code-line"><span class="token comment"># 内存共享进程</span>
</span><span class="code-line"><span class="token keyword">from</span> torch<span class="token punctuation">.</span>multiprocessing <span class="token keyword">import</span> Process
</span></code></pre>
<!--rehype:className=wrap-text-->
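<p>一个极简的单机多进程示意(假设使用 gloo 后端与本机地址/端口,仅演示 API 形态,并非生产配置)</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"># all_reduce 的极简示意假设单机 2 个进程、gloo 后端)
</span><span class="code-line">import os
</span><span class="code-line">import torch
</span><span class="code-line">import torch.distributed as dist
</span><span class="code-line">from torch.multiprocessing import Process
</span><span class="code-line">
</span><span class="code-line">def worker(rank, world_size):
</span><span class="code-line">    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # 假设的本机地址
</span><span class="code-line">    os.environ.setdefault("MASTER_PORT", "29500")      # 假设的空闲端口
</span><span class="code-line">    dist.init_process_group("gloo", rank=rank, world_size=world_size)
</span><span class="code-line">    t = torch.ones(1)
</span><span class="code-line">    dist.all_reduce(t)  # 默认按 SUM 归约
</span><span class="code-line">    dist.destroy_process_group()
</span><span class="code-line">
</span><span class="code-line">if __name__ == "__main__":
</span><span class="code-line">    ps = [Process(target=worker, args=(r, 2)) for r in range(2)]
</span><span class="code-line">    for p in ps:
</span><span class="code-line">        p.start()
</span><span class="code-line">    for p in ps:
</span><span class="code-line">        p.join()
</span></code></pre>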
</div></div></div></div></div><div class="wrap h2body-not-exist"><div class="wrap-header h2wrap"><h2 id="另见"><a aria-hidden="true" tabindex="-1" href="#另见"><span class="icon icon-link"></span></a>另见</h2><div class="wrap-body">
<ul>
<li><a href="https://pytorch.org/">Pytorch 官网</a> <em>(pytorch.org)</em></li>
<li><a href="https://pytorch.org/tutorials/beginner/ptcheat.html">Pytorch 官方备忘清单</a> <em>(pytorch.org)</em></li>
</ul>
</div></div><div class="h2wrap-body"></div></div></div><script src="https://giscus.app/client.js" data-repo="jaywcjlove/reference" data-repo-id="R_kgDOID2-Mw" data-category="Q&#x26;A" data-category-id="DIC_kwDOID2-M84CS5wo" data-mapping="pathname" data-strict="0" data-reactions-enabled="1" data-emit-metadata="0" data-input-position="bottom" data-theme="dark" data-lang="zh-CN" crossorigin="anonymous" async></script><div class="giscus"></div></div><footer class="footer-wrap"><footer class="max-container">© 2022 <a href="https://wangchujiang.com/#/app" target="_blank">Kenny Wang</a>.</footer></footer><script src="../data.js?v=1.8.3" defer></script><script src="../js/fuse.min.js?v=1.8.3" defer></script><script src="../js/main.js?v=1.8.3" defer></script><div id="mysearch"><div class="mysearch-box"><div class="mysearch-input"><div><svg xmlns="http://www.w3.org/2000/svg" height="1em" width="1em" viewBox="0 0 18 18">
<path fill="currentColor" d="M17.71,16.29 L14.31,12.9 C15.4069846,11.5024547 16.0022094,9.77665502 16,8 C16,3.581722 12.418278,0 8,0 C3.581722,0 0,3.581722 0,8 C0,12.418278 3.581722,16 8,16 C9.77665502,16.0022094 11.5024547,15.4069846 12.9,14.31 L16.29,17.71 C16.4777666,17.8993127 16.7333625,18.0057983 17,18.0057983 C17.2666375,18.0057983 17.5222334,17.8993127 17.71,17.71 C17.8993127,17.5222334 18.0057983,17.2666375 18.0057983,17 C18.0057983,16.7333625 17.8993127,16.4777666 17.71,16.29 Z M2,8 C2,4.6862915 4.6862915,2 8,2 C11.3137085,2 14,4.6862915 14,8 C14,11.3137085 11.3137085,14 8,14 C4.6862915,14 2,11.3137085 2,8 Z"></path>
</svg><input id="mysearch-input" type="search" placeholder="搜索" autocomplete="off"><div class="mysearch-clear"></div></div><button id="mysearch-close" type="button">搜索</button></div><div class="mysearch-result"><div id="mysearch-menu"></div><div id="mysearch-content"></div></div></div></div></body>
</html>