doc: update docs/pytorch.md #649 a1a1bd60f8

This commit is contained in:
jaywcjlove
2024-05-13 09:08:21 +00:00
parent 9910c6d135
commit 4cd8d87f5d
4 changed files with 77 additions and 49 deletions

View File

@ -35,13 +35,14 @@
备忘清单为您提供了 <a href="https://pytorch.org/">Pytorch</a> 基本语法和初步应用参考</p>
</div></header><div class="menu-tocs"><div class="menu-btn"><svg aria-hidden="true" fill="currentColor" height="1em" width="1em" viewBox="0 0 16 16" version="1.1" data-view-component="true">
<path fill-rule="evenodd" d="M2 4a1 1 0 100-2 1 1 0 000 2zm3.75-1.5a.75.75 0 000 1.5h8.5a.75.75 0 000-1.5h-8.5zm0 5a.75.75 0 000 1.5h8.5a.75.75 0 000-1.5h-8.5zm0 5a.75.75 0 000 1.5h8.5a.75.75 0 000-1.5h-8.5zM3 8a1 1 0 11-2 0 1 1 0 012 0zm-1 6a1 1 0 100-2 1 1 0 000 2z"></path>
</svg></div><div class="menu-modal"><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#入门">入门</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#介绍">介绍</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#认识-pytorch">认识 Pytorch</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#创建一个全零矩阵">创建一个全零矩阵</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#数据创建张量">数据创建张量</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#pytorch-的基本语法">Pytorch 的基本语法</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作1">加法操作(1)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作2">加法操作(2)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作3">加法操作(3)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作4">加法操作(4)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#张量操作">张量操作</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#张量形状">张量形状</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#取张量元素">取张量元素</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#torch-tensor-和-numpy-array互换">Torch Tensor 和 Numpy array互换</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#torch-tensor-转换为-numpy-array">Torch Tensor 转换为 Numpy array</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#numpy-array转换为torch-tensor">Numpy array转换为Torch Tensor</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#squeeze函数">squeeze函数</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#unsqueeze函数">unsqueeze函数</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#cuda-相关">Cuda 相关</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#检查-cuda-是否可用">检查 Cuda 是否可用</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#列出-gpu-设备">列出 GPU 设备</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#将模型张量等数据在-gpu-和内存之间进行搬运">将模型、张量等数据在 GPU 和内存之间进行搬运</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#导入-imports">导入 Imports</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#一般">一般</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#神经网络-api">神经网络 API</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#torchscript-和-jit">Torchscript 和 JIT</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#onnx">ONNX</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#vision">Vision</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#分布式训练">分布式训练</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#另见">另见</a></div></div><div class="h1wrap-body"><div class="wrap h2body-exist"><div class="wrap-header h2wrap"><h2 id="入门"><a aria-hidden="true" tabindex="-1" href="#入门"><span class="icon icon-link"></span></a>入门</h2><div class="wrap-body">
</svg></div><div class="menu-modal"><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#入门">入门</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#介绍">介绍</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#认识-pytorch">认识 Pytorch</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#创建一个全零矩阵">创建一个全零矩阵</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#数据创建张量">数据创建张量</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#pytorch-的基本语法">Pytorch 的基本语法</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作1">加法操作(1)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作2">加法操作(2)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作3">加法操作(3)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作4">加法操作(4)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#张量操作">张量操作</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#张量形状">张量形状</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#取张量元素">取张量元素</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#torch-tensor-和-numpy-array互换">Torch Tensor 和 Numpy array互换</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#torch-tensor-转换为-numpy-array">Torch Tensor 转换为 Numpy array</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#numpy-array转换为torch-tensor">Numpy array转换为Torch Tensor</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#squeeze函数">squeeze函数</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#unsqueeze函数">unsqueeze函数</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#cuda-相关">Cuda 相关</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#检查-cuda-是否可用">检查 Cuda 是否可用</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#列出-gpu-设备">列出 GPU 设备</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#将模型张量等数据在-gpu-和内存之间进行搬运">将模型、张量等数据在 GPU 和内存之间进行搬运</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#导入-imports">导入 Imports</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#一般">一般</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#神经网络-api">神经网络 API</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#onnx">ONNX</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#torchscript-和-jit">Torchscript 和 JIT</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#vision">Vision</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#分布式训练">分布式训练</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#另见">另见</a></div></div><div class="h1wrap-body"><div class="wrap h2body-exist"><div class="wrap-header h2wrap"><h2 id="入门"><a aria-hidden="true" tabindex="-1" href="#入门"><span class="icon icon-link"></span></a>入门</h2><div class="wrap-body">
</div></div><div class="h2wrap-body"><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="介绍"><a aria-hidden="true" tabindex="-1" href="#介绍"><span class="icon icon-link"></span></a>介绍</h3><div class="wrap-body">
<ul>
<li><a href="https://pytorch.org/">Pytorch 官网</a> <em>(pytorch.org)</em></li>
<li><a href="https://pytorch.org/tutorials/beginner/ptcheat.html">Pytorch 官方备忘清单</a> <em>(pytorch.org)</em></li>
</ul>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="认识-pytorch"><a aria-hidden="true" tabindex="-1" href="#认识-pytorch"><span class="icon icon-link"></span></a>认识 Pytorch</h3><div class="wrap-body">
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="认识-pytorch"><a aria-hidden="true" tabindex="-1" href="#认识-pytorch"><span class="icon icon-link"></span></a>认识 Pytorch</h3><div class="wrap-body">
<!--rehype:wrap-class=row-span-2-->
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">from</span> __future__ <span class="token keyword">import</span> print_function
</span><span class="code-line"><span class="token keyword">import</span> torch
</span><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>empty<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span>
@ -56,7 +57,8 @@
</span></code></pre>
<!--rehype:className=wrap-text-->
<p>Tensors 张量: 张量的概念类似于Numpy中的ndarray数据结构, 最大的区别在于Tensor可以利用GPU的加速功能.</p>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="创建一个全零矩阵"><a aria-hidden="true" tabindex="-1" href="#创建一个全零矩阵"><span class="icon icon-link"></span></a>创建一个全零矩阵</h3><div class="wrap-body">
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="创建一个全零矩阵"><a aria-hidden="true" tabindex="-1" href="#创建一个全零矩阵"><span class="icon icon-link"></span></a>创建一个全零矩阵</h3><div class="wrap-body">
<!--rehype:wrap-class=row-span-2-->
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">,</span> dtype<span class="token operator">=</span>torch<span class="token punctuation">.</span><span class="token builtin">long</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">)</span>
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
@ -101,7 +103,8 @@
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.6883</span><span class="token punctuation">,</span> <span class="token number">0.9775</span><span class="token punctuation">,</span> <span class="token number">1.1764</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">2.6784</span><span class="token punctuation">,</span> <span class="token number">0.1209</span><span class="token punctuation">,</span> <span class="token number">1.5542</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="加法操作4"><a aria-hidden="true" tabindex="-1" href="#加法操作4"><span class="icon icon-link"></span></a>加法操作(4)</h3><div class="wrap-body">
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="加法操作4"><a aria-hidden="true" tabindex="-1" href="#加法操作4"><span class="icon icon-link"></span></a>加法操作(4)</h3><div class="wrap-body">
<!--rehype:wrap-class=row-span-2-->
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">y<span class="token punctuation">.</span>add_<span class="token punctuation">(</span>x<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>y<span class="token punctuation">)</span>
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">[</span> <span class="token number">1.6978</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1.6979</span><span class="token punctuation">,</span> <span class="token number">0.3093</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
@ -117,7 +120,8 @@
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">2.0902</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">0.4489</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">0.1441</span><span class="token punctuation">,</span> <span class="token number">0.8035</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">0.8341</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
<!--rehype:className=wrap-text-->
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="张量形状"><a aria-hidden="true" tabindex="-1" href="#张量形状"><span class="icon icon-link"></span></a>张量形状</h3><div class="wrap-body">
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="张量形状"><a aria-hidden="true" tabindex="-1" href="#张量形状"><span class="icon icon-link"></span></a>张量形状</h3><div class="wrap-body">
<!--rehype:wrap-class=row-span-2-->
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>randn<span class="token punctuation">(</span><span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># tensor.view()操作需要保证数据元素的总数量不变</span>
</span><span class="code-line">y <span class="token operator">=</span> x<span class="token punctuation">.</span>view<span class="token punctuation">(</span><span class="token number">16</span><span class="token punctuation">)</span>
@ -159,20 +163,31 @@
<p>注意: 所有在CPU上的Tensors, 除了CharTensor, 都可以转换为Numpy array并可以反向转换.</p>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="squeeze函数"><a aria-hidden="true" tabindex="-1" href="#squeeze函数"><span class="icon icon-link"></span></a>squeeze函数</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x <span class="token operator">=</span> torch<span class="token punctuation">.</span>rand<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape <span class="token comment"># squeeze不加参数默认去除所有为1的维度</span>
</span><span class="code-line">
</span><span class="code-line"><span class="token comment"># squeeze不加参数默认去除所有为1的维度</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span>dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape <span class="token comment"># squeeze加参数去除指定为1的维度</span>
</span><span class="code-line">
</span><span class="code-line"><span class="token comment"># squeeze加参数去除指定为1的维度</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span>dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape <span class="token comment"># squeeze加参数如果不为1则不变</span>
</span><span class="code-line">
</span><span class="code-line"><span class="token comment"># squeeze加参数如果不为1则不变</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> torch<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span>x<span class="token punctuation">,</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape <span class="token comment"># 既可以是函数,也可以是方法</span>
</span><span class="code-line">
</span><span class="code-line"><span class="token comment"># 既可以是函数,也可以是方法</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> torch<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span>x<span class="token punctuation">,</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="unsqueeze函数"><a aria-hidden="true" tabindex="-1" href="#unsqueeze函数"><span class="icon icon-link"></span></a>unsqueeze函数</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x <span class="token operator">=</span> torch<span class="token punctuation">.</span>rand<span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>unsqueeze<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape <span class="token comment"># unsqueeze必须加参数 _ 2 _ 28 _</span>
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token comment"># 参数代表在哪里添加维度 0 1 2</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> torch<span class="token punctuation">.</span>unsqueeze<span class="token punctuation">(</span>x<span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape <span class="token comment"># 既可以是函数,也可以是方法</span>
</span><span class="code-line"><span class="token comment"># unsqueeze必须加参数 _ 2 _ 28 _</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>unsqueeze<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
</span><span class="code-line"><span class="token comment"># 参数代表在哪里添加维度 0 1 2</span>
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># 既可以是函数,也可以是方法</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> torch<span class="token punctuation">.</span>unsqueeze<span class="token punctuation">(</span>x<span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span></code></pre>
</div></div></div></div></div><div class="wrap h2body-exist"><div class="wrap-header h2wrap"><h2 id="cuda-相关"><a aria-hidden="true" tabindex="-1" href="#cuda-相关"><span class="icon icon-link"></span></a>Cuda 相关</h2><div class="wrap-body">
@ -181,50 +196,76 @@
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>is_available<span class="token punctuation">(</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token boolean">True</span>
</span></code></pre>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="列出-gpu-设备"><a aria-hidden="true" tabindex="-1" href="#列出-gpu-设备"><span class="icon icon-link"></span></a>列出 GPU 设备</h3><div class="wrap-body">
</div></div></div><div class="wrap h3body-not-exist col-span-2 row-span-2"><div class="wrap-header h3wrap"><h3 id="列出-gpu-设备"><a aria-hidden="true" tabindex="-1" href="#列出-gpu-设备"><span class="icon icon-link"></span></a>列出 GPU 设备</h3><div class="wrap-body">
<!--rehype:wrap-class=col-span-2 row-span-2-->
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">import</span> torch
</span><span class="code-line">
</span><span class="code-line">device_count <span class="token operator">=</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>device_count<span class="token punctuation">(</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"CUDA 设备"</span><span class="token punctuation">)</span>
</span><span class="code-line">
</span><span class="code-line"><span class="token keyword">for</span> i <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span>device_count<span class="token punctuation">)</span><span class="token punctuation">:</span>
</span><span class="code-line"> device_name <span class="token operator">=</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>get_device_name<span class="token punctuation">(</span>i<span class="token punctuation">)</span>
</span><span class="code-line"> total_memory <span class="token operator">=</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>get_device_properties<span class="token punctuation">(</span>i<span class="token punctuation">)</span><span class="token punctuation">.</span>total_memory <span class="token operator">/</span> <span class="token punctuation">(</span><span class="token number">1024</span> <span class="token operator">**</span> <span class="token number">3</span><span class="token punctuation">)</span>
</span><span class="code-line"> <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string-interpolation"><span class="token string">f"├── 设备 </span><span class="token interpolation"><span class="token punctuation">{</span>i<span class="token punctuation">}</span></span><span class="token string">: </span><span class="token interpolation"><span class="token punctuation">{</span>device_name<span class="token punctuation">}</span></span><span class="token string">, 容量: </span><span class="token interpolation"><span class="token punctuation">{</span>total_memory<span class="token punctuation">:</span><span class="token format-spec">.2f</span><span class="token punctuation">}</span></span><span class="token string"> GiB"</span></span><span class="token punctuation">)</span>
</span><span class="code-line">
</span><span class="code-line"><span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"└── (结束)"</span><span class="token punctuation">)</span>
</span></code></pre>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="将模型张量等数据在-gpu-和内存之间进行搬运"><a aria-hidden="true" tabindex="-1" href="#将模型张量等数据在-gpu-和内存之间进行搬运"><span class="icon icon-link"></span></a>将模型、张量等数据在 GPU 和内存之间进行搬运</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">import</span> torch
</span><span class="code-line"><span class="token comment"># Replace 0 to your GPU device index. or use "cuda" directly.</span>
</span><span class="code-line"><span class="token comment"># 将 0 替换为您的 GPU 设备索引或者直接使用 "cuda"</span>
</span><span class="code-line">device <span class="token operator">=</span> <span class="token string-interpolation"><span class="token string">f"cuda:0"</span></span>
</span><span class="code-line"><span class="token comment"># Move to GPU</span>
</span><span class="code-line"><span class="token comment"># 移动到GPU</span>
</span><span class="code-line">tensor_m <span class="token operator">=</span> torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</span><span class="code-line">tensor_g <span class="token operator">=</span> tensor_m<span class="token punctuation">.</span>to<span class="token punctuation">(</span>device<span class="token punctuation">)</span>
</span><span class="code-line">model_m <span class="token operator">=</span> torch<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
</span><span class="code-line">model_g <span class="token operator">=</span> model_m<span class="token punctuation">.</span>to<span class="token punctuation">(</span>device<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># Move back.</span>
</span><span class="code-line"><span class="token comment"># 向后移动</span>
</span><span class="code-line">tensor_m <span class="token operator">=</span> tensor_g<span class="token punctuation">.</span>cpu<span class="token punctuation">(</span><span class="token punctuation">)</span>
</span><span class="code-line">model_m <span class="token operator">=</span> model_g<span class="token punctuation">.</span>cpu<span class="token punctuation">(</span><span class="token punctuation">)</span>
</span></code></pre>
</div></div></div></div></div><div class="wrap h2body-exist"><div class="wrap-header h2wrap"><h2 id="导入-imports"><a aria-hidden="true" tabindex="-1" href="#导入-imports"><span class="icon icon-link"></span></a>导入 Imports</h2><div class="wrap-body">
</div></div><div class="h2wrap-body"><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="一般"><a aria-hidden="true" tabindex="-1" href="#一般"><span class="icon icon-link"></span></a>一般</h3><div class="wrap-body">
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># 根包</span>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># 根包</span>
</span><span class="code-line"><span class="token keyword">import</span> torch
</span><span class="code-line"><span class="token comment"># 数据集表示和加载</span>
</span><span class="code-line"><span class="token keyword">from</span> torch<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>data <span class="token keyword">import</span> Dataset<span class="token punctuation">,</span> DataLoader
</span></code></pre>
<p>数据集表示和加载</p>
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">from</span> torch<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>data <span class="token keyword">import</span> Dataset<span class="token punctuation">,</span> DataLoader
</span></code></pre>
<!--rehype:className=wrap-text-->
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="神经网络-api"><a aria-hidden="true" tabindex="-1" href="#神经网络-api"><span class="icon icon-link"></span></a>神经网络 API</h3><div class="wrap-body">
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="神经网络-api"><a aria-hidden="true" tabindex="-1" href="#神经网络-api"><span class="icon icon-link"></span></a>神经网络 API</h3><div class="wrap-body">
<!--rehype:wrap-class=row-span-2-->
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># 计算图</span>
</span><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>autograd <span class="token keyword">as</span> autograd
</span><span class="code-line"><span class="token comment"># 计算图中的张量节点</span>
</span><span class="code-line"><span class="token keyword">from</span> torch <span class="token keyword">import</span> Tensor
</span><span class="code-line"><span class="token comment"># 神经网络</span>
</span><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>nn <span class="token keyword">as</span> nn
</span></code></pre>
<p>神经网络</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>nn <span class="token keyword">as</span> nn
</span><span class="code-line">
</span><span class="code-line"><span class="token comment"># 层、激活等</span>
</span><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>functional <span class="token keyword">as</span> F
</span><span class="code-line"><span class="token comment"># 优化器,例如 梯度下降、ADAM等</span>
</span><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>optim <span class="token keyword">as</span> optim
</span><span class="code-line"><span class="token comment"># 混合前端装饰器和跟踪 jit</span>
</span><span class="code-line"><span class="token keyword">from</span> torch<span class="token punctuation">.</span>jit <span class="token keyword">import</span> script<span class="token punctuation">,</span> trace
</span></code></pre>
<p>混合前端装饰器和跟踪 jit</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">from</span> torch<span class="token punctuation">.</span>jit <span class="token keyword">import</span> script<span class="token punctuation">,</span> trace
</span></code></pre>
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="onnx"><a aria-hidden="true" tabindex="-1" href="#onnx"><span class="icon icon-link"></span></a>ONNX</h3><div class="wrap-body">
<!--rehype:wrap-class=row-span-2-->
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line">torch<span class="token punctuation">.</span>onnx<span class="token punctuation">.</span>export<span class="token punctuation">(</span>model<span class="token punctuation">,</span> dummy data<span class="token punctuation">,</span> xxxx<span class="token punctuation">.</span>proto<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># 导出 ONNX 格式</span>
</span><span class="code-line"><span class="token comment"># 使用经过训练的模型模型dummy</span>
</span><span class="code-line"><span class="token comment"># 数据和所需的文件名</span>
</span></code></pre>
<!--rehype:className=wrap-text-->
<p>加载 ONNX 模型</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">model <span class="token operator">=</span> onnx<span class="token punctuation">.</span>load<span class="token punctuation">(</span><span class="token string">"alexnet.proto"</span><span class="token punctuation">)</span>
</span></code></pre>
<p>检查模型IT 是否结构良好</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">onnx<span class="token punctuation">.</span>checker<span class="token punctuation">.</span>check_model<span class="token punctuation">(</span>model<span class="token punctuation">)</span>
</span></code></pre>
<p>打印一个人类可读的,图的表示</p>
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">onnx<span class="token punctuation">.</span>helper<span class="token punctuation">.</span>printable_graph<span class="token punctuation">(</span>model<span class="token punctuation">.</span>graph<span class="token punctuation">)</span>
</span></code></pre>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="torchscript-和-jit"><a aria-hidden="true" tabindex="-1" href="#torchscript-和-jit"><span class="icon icon-link"></span></a>Torchscript 和 JIT</h3><div class="wrap-body">
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">torch<span class="token punctuation">.</span>jit<span class="token punctuation">.</span>trace<span class="token punctuation">(</span><span class="token punctuation">)</span>
@ -233,22 +274,8 @@
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token decorator annotation punctuation">@script</span>
</span></code></pre>
<p>装饰器用于指示被跟踪代码中的数据相关控制流</p>
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="onnx"><a aria-hidden="true" tabindex="-1" href="#onnx"><span class="icon icon-link"></span></a>ONNX</h3><div class="wrap-body">
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line">torch<span class="token punctuation">.</span>onnx<span class="token punctuation">.</span>export<span class="token punctuation">(</span>model<span class="token punctuation">,</span> dummy data<span class="token punctuation">,</span> xxxx<span class="token punctuation">.</span>proto<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># 导出 ONNX 格式</span>
</span><span class="code-line"><span class="token comment"># 使用经过训练的模型模型dummy</span>
</span><span class="code-line"><span class="token comment"># 数据和所需的文件名</span>
</span><span class="code-line">
</span><span class="code-line">model <span class="token operator">=</span> onnx<span class="token punctuation">.</span>load<span class="token punctuation">(</span><span class="token string">"alexnet.proto"</span><span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># 加载 ONNX 模型</span>
</span><span class="code-line">onnx<span class="token punctuation">.</span>checker<span class="token punctuation">.</span>check_model<span class="token punctuation">(</span>model<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># 检查模型IT 是否结构良好</span>
</span><span class="code-line">
</span><span class="code-line">onnx<span class="token punctuation">.</span>helper<span class="token punctuation">.</span>printable_graph<span class="token punctuation">(</span>model<span class="token punctuation">.</span>graph<span class="token punctuation">)</span>
</span><span class="code-line"><span class="token comment"># 打印一个人类可读的,图的表示</span>
</span></code></pre>
<!--rehype:className=wrap-text-->
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="vision"><a aria-hidden="true" tabindex="-1" href="#vision"><span class="icon icon-link"></span></a>Vision</h3><div class="wrap-body">
</div></div></div><div class="wrap h3body-not-exist col-span-2"><div class="wrap-header h3wrap"><h3 id="vision"><a aria-hidden="true" tabindex="-1" href="#vision"><span class="icon icon-link"></span></a>Vision</h3><div class="wrap-body">
<!--rehype:wrap-class=col-span-2-->
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># 视觉数据集,架构 &#x26; 变换</span>
</span><span class="code-line"><span class="token keyword">from</span> torchvision <span class="token keyword">import</span> datasets<span class="token punctuation">,</span> models<span class="token punctuation">,</span> transforms
</span><span class="code-line"><span class="token comment"># 组合转换</span>