mirror of
https://github.com/jaywcjlove/reference.git
synced 2025-06-17 04:31:22 +08:00
309 lines
66 KiB
HTML
309 lines
66 KiB
HTML
<!doctype html>
|
||
<html lang="zh-CN" data-color-mode="dark">
|
||
<head>
|
||
<meta charset="utf-8">
|
||
<title>Pytorch 备忘清单 &amp; pytorch cheatsheet &amp; Quick Reference</title>
|
||
<meta content="width=device-width, initial-scale=1" name="viewport">
|
||
<meta name="description" content="Pytorch 是一种开源机器学习框架，可加速从研究原型设计到生产部署的过程。本备忘清单参考了官网备忘单，为您提供了 Pytorch 基本语法和初步应用参考，为开发人员分享快速参考备忘单，助您快速入门。">
|
||
<meta name="keywords" content="pytorch,reference,Quick,Reference,cheatsheet,cheat,sheet">
|
||
<meta name="author" content="jaywcjlove">
|
||
<meta name="license" content="MIT">
|
||
<meta name="funding" content="https://jaywcjlove.github.io/#/sponsor">
|
||
<link rel="apple-touch-icon" href="../icons/touch-icon-iphone.png">
<link rel="apple-touch-icon" sizes="152x152" href="../icons/touch-icon-ipad.png">
<link rel="apple-touch-icon" sizes="180x180" href="../icons/touch-icon-iphone.png">
<link rel="apple-touch-icon" sizes="167x167" href="../icons/touch-icon-ipad-retina.png">
<link rel="apple-touch-icon" sizes="120x120" href="../icons/touch-icon-iphone-retina.png">
|
||
<link rel="icon" href="../icons/favicon.svg" type="image/svg+xml">
|
||
<link href="../style/style.css" rel="stylesheet">
|
||
<link href="../style/katex.css" rel="stylesheet">
|
||
</head>
|
||
<body><nav class="header-nav"><div class="max-container"><a href="../index.html" class="logo"><svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" height="1em" width="1em">
|
||
<path d="m21.66 10.44-.98 4.18c-.84 3.61-2.5 5.07-5.62 4.77-.5-.04-1.04-.13-1.62-.27l-1.68-.4c-4.17-.99-5.46-3.05-4.48-7.23l.98-4.19c.2-.85.44-1.59.74-2.2 1.17-2.42 3.16-3.07 6.5-2.28l1.67.39c4.19.98 5.47 3.05 4.49 7.23Z" fill="#c9d1d9"></path>
|
||
<path d="M15.06 19.39c-.62.42-1.4.77-2.35 1.08l-1.58.52c-3.97 1.28-6.06.21-7.35-3.76L2.5 13.28c-1.28-3.97-.22-6.07 3.75-7.35l1.58-.52c.41-.13.8-.24 1.17-.31-.3.61-.54 1.35-.74 2.2l-.98 4.19c-.98 4.18.31 6.24 4.48 7.23l1.68.4c.58.14 1.12.23 1.62.27Zm2.43-8.88c-.06 0-.12-.01-.19-.02l-4.85-1.23a.75.75 0 0 1 .37-1.45l4.85 1.23a.748.748 0 0 1-.18 1.47Z" fill="#228e6c"></path>
|
||
<path d="M14.56 13.89c-.06 0-.12-.01-.19-.02l-2.91-.74a.75.75 0 0 1 .37-1.45l2.91.74c.4.1.64.51.54.91-.08.34-.38.56-.72.56Z" fill="#228e6c"></path>
|
||
</svg>
|
||
<span class="title">Quick Reference</span></a><div class="menu"><a href="javascript:void(0);" class="searchbtn" id="searchbtn"><svg xmlns="http://www.w3.org/2000/svg" height="1em" width="1em" viewBox="0 0 18 18">
|
||
<path fill="currentColor" d="M17.71,16.29 L14.31,12.9 C15.4069846,11.5024547 16.0022094,9.77665502 16,8 C16,3.581722 12.418278,0 8,0 C3.581722,0 0,3.581722 0,8 C0,12.418278 3.581722,16 8,16 C9.77665502,16.0022094 11.5024547,15.4069846 12.9,14.31 L16.29,17.71 C16.4777666,17.8993127 16.7333625,18.0057983 17,18.0057983 C17.2666375,18.0057983 17.5222334,17.8993127 17.71,17.71 C17.8993127,17.5222334 18.0057983,17.2666375 18.0057983,17 C18.0057983,16.7333625 17.8993127,16.4777666 17.71,16.29 Z M2,8 C2,4.6862915 4.6862915,2 8,2 C11.3137085,2 14,4.6862915 14,8 C14,11.3137085 11.3137085,14 8,14 C4.6862915,14 2,11.3137085 2,8 Z"></path>
|
||
</svg><span>搜索</span><span>⌘K</span></a><a href="https://github.com/jaywcjlove/reference/blob/main/docs/pytorch.md" class="edit" target="_blank" rel="noopener noreferrer"><svg viewBox="0 0 36 36" fill="currentColor" height="1em" width="1em"><path d="m33 6.4-3.7-3.7a1.71 1.71 0 0 0-2.36 0L23.65 6H6a2 2 0 0 0-2 2v22a2 2 0 0 0 2 2h22a2 2 0 0 0 2-2V11.76l3-3a1.67 1.67 0 0 0 0-2.36ZM18.83 20.13l-4.19.93 1-4.15 9.55-9.57 3.23 3.23ZM29.5 9.43 26.27 6.2l1.85-1.85 3.23 3.23Z"></path><path fill="none" d="M0 0h36v36H0z"></path></svg><span>编辑</span></a><button id="darkMode" type="button"><svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="light" height="1em" width="1em">
|
||
<path d="M6.995 12c0 2.761 2.246 5.007 5.007 5.007s5.007-2.246 5.007-5.007-2.246-5.007-5.007-5.007S6.995 9.239 6.995 12zM11 19h2v3h-2zm0-17h2v3h-2zm-9 9h3v2H2zm17 0h3v2h-3zM5.637 19.778l-1.414-1.414 2.121-2.121 1.414 1.414zM16.242 6.344l2.122-2.122 1.414 1.414-2.122 2.122zM6.344 7.759 4.223 5.637l1.415-1.414 2.12 2.122zm13.434 10.605-1.414 1.414-2.122-2.122 1.414-1.414z"></path>
|
||
</svg>
|
||
<svg xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24" class="dark" height="1em" width="1em">
|
||
<path d="M12 11.807A9.002 9.002 0 0 1 10.049 2a9.942 9.942 0 0 0-5.12 2.735c-3.905 3.905-3.905 10.237 0 14.142 3.906 3.906 10.237 3.905 14.143 0a9.946 9.946 0 0 0 2.735-5.119A9.003 9.003 0 0 1 12 11.807z"></path>
|
||
</svg>
|
||
</button><script src="../js/dark.js?v=1.8.3"></script><a href="https://github.com/jaywcjlove/reference" class="" target="_blank" rel="noopener noreferrer" aria-label="GitHub"><svg viewBox="0 0 16 16" fill="currentColor" height="1em" width="1em" aria-hidden="true"><path d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.950-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.012 8.012 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path></svg></a></div></div></nav><div class="wrap h1body-exist max-container"><header class="wrap-header h1wrap"><h1 id="pytorch--备忘清单"><svg viewBox="0 0 24 24" fill="currentColor" xmlns="http://www.w3.org/2000/svg" height="1em" width="1em">
|
||
<path d="M12.005 0 4.952 7.053a9.865 9.865 0 0 0 0 14.022 9.866 9.866 0 0 0 14.022 0c3.984-3.9 3.986-10.205.085-14.023l-1.744 1.743c2.904 2.905 2.904 7.634 0 10.538s-7.634 2.904-10.538 0-2.904-7.634 0-10.538l4.647-4.646.582-.665zm3.568 3.899a1.327 1.327 0 0 0-1.327 1.327 1.327 1.327 0 0 0 1.327 1.328A1.327 1.327 0 0 0 16.9 5.226 1.327 1.327 0 0 0 15.573 3.9z"></path>
|
||
</svg>
|
||
<a aria-hidden="true" tabindex="-1" href="#pytorch--备忘清单"><span class="icon icon-link"></span></a>Pytorch 备忘清单</h1><div class="wrap-body">
|
||
<p>Pytorch 是一种开源机器学习框架，可加速从研究原型设计到生产部署的过程。本备忘清单参考了官网备忘单，为您提供了 <a href="https://pytorch.org/">Pytorch</a> 基本语法和初步应用参考。</p>
|
||
</div></header><div class="menu-tocs"><div class="menu-btn"><svg aria-hidden="true" fill="currentColor" height="1em" width="1em" viewBox="0 0 16 16" version="1.1" data-view-component="true">
|
||
<path fill-rule="evenodd" d="M2 4a1 1 0 100-2 1 1 0 000 2zm3.75-1.5a.75.75 0 000 1.5h8.5a.75.75 0 000-1.5h-8.5zm0 5a.75.75 0 000 1.5h8.5a.75.75 0 000-1.5h-8.5zm0 5a.75.75 0 000 1.5h8.5a.75.75 0 000-1.5h-8.5zM3 8a1 1 0 11-2 0 1 1 0 012 0zm-1 6a1 1 0 100-2 1 1 0 000 2z"></path>
|
||
</svg></div><div class="menu-modal"><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#入门">入门</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#介绍">介绍</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#认识-pytorch">认识 Pytorch</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#创建一个全零矩阵">创建一个全零矩阵</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#数据创建张量">数据创建张量</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#pytorch-的基本语法">Pytorch 的基本语法</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作1">加法操作(1)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作2">加法操作(2)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作3">加法操作(3)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#加法操作4">加法操作(4)</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#张量操作">张量操作</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#张量形状">张量形状</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#取张量元素">取张量元素</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#torch-tensor-和-numpy-array互换">Torch Tensor 和 Numpy array互换</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#torch-tensor-转换为-numpy-array">Torch Tensor 转换为 Numpy array</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#numpy-array转换为torch-tensor">Numpy array转换为Torch Tensor</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#squeeze函数">squeeze函数</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#unsqueeze函数">unsqueeze函数</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#cuda-相关">Cuda 相关</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#检查-cuda-是否可用">检查 Cuda 是否可用</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#列出-gpu-设备">列出 GPU 设备</a><a aria-hidden="true" class="leve3 
tocs-link" data-num="3" href="#将模型张量等数据在-gpu-和内存之间进行搬运">将模型、张量等数据在 GPU 和内存之间进行搬运</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#导入-imports">导入 Imports</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#一般">一般</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#神经网络-api">神经网络 API</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#onnx">ONNX</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#torchscript-和-jit">Torchscript 和 JIT</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#vision">Vision</a><a aria-hidden="true" class="leve3 tocs-link" data-num="3" href="#分布式训练">分布式训练</a><a aria-hidden="true" class="leve2 tocs-link" data-num="2" href="#另见">另见</a></div></div><div class="h1wrap-body"><div class="wrap h2body-exist"><div class="wrap-header h2wrap"><h2 id="入门"><a aria-hidden="true" tabindex="-1" href="#入门"><span class="icon icon-link"></span></a>入门</h2><div class="wrap-body">
|
||
</div></div><div class="h2wrap-body"><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="介绍"><a aria-hidden="true" tabindex="-1" href="#介绍"><span class="icon icon-link"></span></a>介绍</h3><div class="wrap-body">
|
||
<ul>
|
||
<li><a href="https://pytorch.org/">Pytorch 官网</a> <em>(pytorch.org)</em></li>
|
||
<li><a href="https://pytorch.org/tutorials/beginner/ptcheat.html">Pytorch 官方备忘清单</a> <em>(pytorch.org)</em></li>
|
||
</ul>
|
||
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="认识-pytorch"><a aria-hidden="true" tabindex="-1" href="#认识-pytorch"><span class="icon icon-link"></span></a>认识 Pytorch</h3><div class="wrap-body">
|
||
<!--rehype:wrap-class=row-span-2-->
|
||
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">from</span> __future__ <span class="token keyword">import</span> print_function
|
||
</span><span class="code-line"><span class="token keyword">import</span> torch
|
||
</span><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>empty<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">)</span>
|
||
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">2.4835e+27</span><span class="token punctuation">,</span> <span class="token number">2.5428e+30</span><span class="token punctuation">,</span> <span class="token number">1.0877e-19</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">1.5163e+23</span><span class="token punctuation">,</span> <span class="token number">2.2012e+12</span><span class="token punctuation">,</span> <span class="token number">3.7899e+22</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">5.2480e+05</span><span class="token punctuation">,</span> <span class="token number">1.0175e+31</span><span class="token punctuation">,</span> <span class="token number">9.7056e+24</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">1.6283e+32</span><span class="token punctuation">,</span> <span class="token number">3.7913e+22</span><span class="token punctuation">,</span> <span class="token number">3.9653e+28</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">1.0876e-19</span><span class="token punctuation">,</span> <span class="token number">6.2027e+26</span><span class="token punctuation">,</span> <span class="token number">2.3685e+21</span><span class="token punctuation">]</span>
|
||
</span><span class="code-line"><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
<!--rehype:className=wrap-text-->
|
||
<p>Tensors 张量：张量的概念类似于 Numpy 中的 ndarray 数据结构，最大的区别在于 Tensor 可以利用 GPU 的加速功能。注意 <code>torch.empty</code> 只分配内存而不初始化，因此上面打印出的数值是内存中残留的任意值，每次运行都可能不同。</p>
|
||
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="创建一个全零矩阵"><a aria-hidden="true" tabindex="-1" href="#创建一个全零矩阵"><span class="icon icon-link"></span></a>创建一个全零矩阵</h3><div class="wrap-body">
|
||
<!--rehype:wrap-class=row-span-2-->
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">,</span> dtype<span class="token operator">=</span>torch<span class="token punctuation">.</span><span class="token builtin">long</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">)</span>
|
||
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
<p>创建一个全零矩阵，并通过 <code>dtype</code> 参数将元素类型指定为 <code>long</code></p>
|
||
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="数据创建张量"><a aria-hidden="true" tabindex="-1" href="#数据创建张量"><span class="icon icon-link"></span></a>数据创建张量</h3><div class="wrap-body">
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2.5</span><span class="token punctuation">,</span> <span class="token number">3.5</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">)</span>
|
||
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2.5000</span><span class="token punctuation">,</span> <span class="token number">3.5000</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
</div></div></div></div></div><div class="wrap h2body-exist"><div class="wrap-header h2wrap"><h2 id="pytorch-的基本语法"><a aria-hidden="true" tabindex="-1" href="#pytorch-的基本语法"><span class="icon icon-link"></span></a>Pytorch 的基本语法</h2><div class="wrap-body">
|
||
</div></div><div class="h2wrap-body"><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="加法操作1"><a aria-hidden="true" tabindex="-1" href="#加法操作1"><span class="icon icon-link"></span></a>加法操作(1)</h3><div class="wrap-body">
|
||
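<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># x 需与 y 形状相同 (5x3) 才能逐元素相加</span>
</span><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>randn<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span>
</span><span class="code-line">y <span class="token operator">=</span> torch<span class="token punctuation">.</span>rand<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span>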
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x <span class="token operator">+</span> y<span class="token punctuation">)</span>
|
||
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">[</span> <span class="token number">1.6978</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1.6979</span><span class="token punctuation">,</span> <span class="token number">0.3093</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.4953</span><span class="token punctuation">,</span> <span class="token number">0.3954</span><span class="token punctuation">,</span> <span class="token number">0.0595</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">0.9540</span><span class="token punctuation">,</span> <span class="token number">0.3353</span><span class="token punctuation">,</span> <span class="token number">0.1251</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.6883</span><span class="token punctuation">,</span> <span class="token number">0.9775</span><span class="token punctuation">,</span> <span class="token number">1.1764</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">2.6784</span><span class="token punctuation">,</span> <span class="token number">0.1209</span><span class="token punctuation">,</span> <span class="token number">1.5542</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="加法操作2"><a aria-hidden="true" tabindex="-1" href="#加法操作2"><span class="icon icon-link"></span></a>加法操作(2)</h3><div class="wrap-body">
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>torch<span class="token punctuation">.</span>add<span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">)</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">[</span> <span class="token number">1.6978</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1.6979</span><span class="token punctuation">,</span> <span class="token number">0.3093</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.4953</span><span class="token punctuation">,</span> <span class="token number">0.3954</span><span class="token punctuation">,</span> <span class="token number">0.0595</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">0.9540</span><span class="token punctuation">,</span> <span class="token number">0.3353</span><span class="token punctuation">,</span> <span class="token number">0.1251</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.6883</span><span class="token punctuation">,</span> <span class="token number">0.9775</span><span class="token punctuation">,</span> <span class="token number">1.1764</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">2.6784</span><span class="token punctuation">,</span> <span class="token number">0.1209</span><span class="token punctuation">,</span> <span class="token number">1.5542</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="加法操作3"><a aria-hidden="true" tabindex="-1" href="#加法操作3"><span class="icon icon-link"></span></a>加法操作(3)</h3><div class="wrap-body">
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># 提前设定一个空的张量</span>
|
||
</span><span class="code-line">result <span class="token operator">=</span> torch<span class="token punctuation">.</span>empty<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token comment"># 将空的张量作为加法的结果存储张量</span>
|
||
</span><span class="code-line">torch<span class="token punctuation">.</span>add<span class="token punctuation">(</span>x<span class="token punctuation">,</span> y<span class="token punctuation">,</span> out<span class="token operator">=</span>result<span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>result<span class="token punctuation">)</span>
|
||
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">[</span> <span class="token number">1.6978</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1.6979</span><span class="token punctuation">,</span> <span class="token number">0.3093</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.4953</span><span class="token punctuation">,</span> <span class="token number">0.3954</span><span class="token punctuation">,</span> <span class="token number">0.0595</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">0.9540</span><span class="token punctuation">,</span> <span class="token number">0.3353</span><span class="token punctuation">,</span> <span class="token number">0.1251</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.6883</span><span class="token punctuation">,</span> <span class="token number">0.9775</span><span class="token punctuation">,</span> <span class="token number">1.1764</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">2.6784</span><span class="token punctuation">,</span> <span class="token number">0.1209</span><span class="token punctuation">,</span> <span class="token number">1.5542</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="加法操作4"><a aria-hidden="true" tabindex="-1" href="#加法操作4"><span class="icon icon-link"></span></a>加法操作(4)</h3><div class="wrap-body">
|
||
<!--rehype:wrap-class=row-span-2-->
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">y<span class="token punctuation">.</span>add_<span class="token punctuation">(</span>x<span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>y<span class="token punctuation">)</span>
|
||
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">[</span> <span class="token number">1.6978</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1.6979</span><span class="token punctuation">,</span> <span class="token number">0.3093</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.4953</span><span class="token punctuation">,</span> <span class="token number">0.3954</span><span class="token punctuation">,</span> <span class="token number">0.0595</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">0.9540</span><span class="token punctuation">,</span> <span class="token number">0.3353</span><span class="token punctuation">,</span> <span class="token number">0.1251</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">0.6883</span><span class="token punctuation">,</span> <span class="token number">0.9775</span><span class="token punctuation">,</span> <span class="token number">1.1764</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
|
||
</span><span class="code-line"> <span class="token punctuation">[</span> <span class="token number">2.6784</span><span class="token punctuation">,</span> <span class="token number">0.1209</span><span class="token punctuation">,</span> <span class="token number">1.5542</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
<p>注意: 所有 <code>in-place</code> 的操作函数都有一个下划线的后缀。
|
||
比如 <code>x.copy_(y)</code>, <code>x.add_(y)</code>, 都会直接改变x的值</p>
|
||
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="张量操作"><a aria-hidden="true" tabindex="-1" href="#张量操作"><span class="icon icon-link"></span></a>张量操作</h3><div class="wrap-body">
|
||
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">2.0902</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">0.4489</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">0.1441</span><span class="token punctuation">,</span> <span class="token number">0.8035</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">0.8341</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
<!--rehype:className=wrap-text-->
|
||
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="张量形状"><a aria-hidden="true" tabindex="-1" href="#张量形状"><span class="icon icon-link"></span></a>张量形状</h3><div class="wrap-body">
|
||
<!--rehype:wrap-class=row-span-2-->
|
||
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>randn<span class="token punctuation">(</span><span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token comment"># tensor.view()操作需要保证数据元素的总数量不变</span>
|
||
</span><span class="code-line">y <span class="token operator">=</span> x<span class="token punctuation">.</span>view<span class="token punctuation">(</span><span class="token number">16</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token comment"># -1代表自动匹配个数</span>
|
||
</span><span class="code-line">z <span class="token operator">=</span> x<span class="token punctuation">.</span>view<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">8</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">.</span>size<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> y<span class="token punctuation">.</span>size<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> z<span class="token punctuation">.</span>size<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">]</span><span class="token punctuation">)</span> torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">16</span><span class="token punctuation">]</span><span class="token punctuation">)</span> torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">8</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
<!--rehype:className=wrap-text-->
|
||
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="取张量元素"><a aria-hidden="true" tabindex="-1" href="#取张量元素"><span class="icon icon-link"></span></a>取张量元素</h3><div class="wrap-body">
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">x <span class="token operator">=</span> torch<span class="token punctuation">.</span>randn<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>x<span class="token punctuation">.</span>item<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">0.3531</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">-</span><span class="token number">0.3530771732330322</span>
|
||
</span></code></pre>
|
||
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="torch-tensor-和-numpy-array互换"><a aria-hidden="true" tabindex="-1" href="#torch-tensor-和-numpy-array互换"><span class="icon icon-link"></span></a>Torch Tensor 和 Numpy array互换</h3><div class="wrap-body">
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">a <span class="token operator">=</span> torch<span class="token punctuation">.</span>ones<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>a<span class="token punctuation">)</span>
|
||
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1.</span><span class="token punctuation">,</span> <span class="token number">1.</span><span class="token punctuation">,</span> <span class="token number">1.</span><span class="token punctuation">,</span> <span class="token number">1.</span><span class="token punctuation">,</span> <span class="token number">1.</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
<p>对于 CPU 上的张量,通过 <code>.numpy()</code> 或 <code>torch.from_numpy()</code> 相互转换得到的 Torch Tensor 和 Numpy array 共享底层的内存空间,因此改变其中一个的值,另一个也会随之被改变</p>
|
||
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="torch-tensor-转换为-numpy-array"><a aria-hidden="true" tabindex="-1" href="#torch-tensor-转换为-numpy-array"><span class="icon icon-link"></span></a>Torch Tensor 转换为 Numpy array</h3><div class="wrap-body">
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">b <span class="token operator">=</span> a<span class="token punctuation">.</span>numpy<span class="token punctuation">(</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>b<span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token punctuation">[</span><span class="token number">1.</span> <span class="token number">1.</span> <span class="token number">1.</span> <span class="token number">1.</span> <span class="token number">1.</span><span class="token punctuation">]</span>
|
||
</span></code></pre>
|
||
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="numpy-array转换为torch-tensor"><a aria-hidden="true" tabindex="-1" href="#numpy-array转换为torch-tensor"><span class="icon icon-link"></span></a>Numpy array转换为Torch Tensor</h3><div class="wrap-body">
|
||
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">import</span> numpy <span class="token keyword">as</span> np
|
||
</span><span class="code-line">a <span class="token operator">=</span> np<span class="token punctuation">.</span>ones<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">b <span class="token operator">=</span> torch<span class="token punctuation">.</span>from_numpy<span class="token punctuation">(</span>a<span class="token punctuation">)</span>
|
||
</span><span class="code-line">np<span class="token punctuation">.</span>add<span class="token punctuation">(</span>a<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> out<span class="token operator">=</span>a<span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>a<span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">print</span><span class="token punctuation">(</span>b<span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token punctuation">[</span><span class="token number">2.</span> <span class="token number">2.</span> <span class="token number">2.</span> <span class="token number">2.</span> <span class="token number">2.</span><span class="token punctuation">]</span>
|
||
</span><span class="code-line">tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2.</span><span class="token punctuation">,</span> <span class="token number">2.</span><span class="token punctuation">,</span> <span class="token number">2.</span><span class="token punctuation">,</span> <span class="token number">2.</span><span class="token punctuation">,</span> <span class="token number">2.</span><span class="token punctuation">]</span><span class="token punctuation">,</span> dtype<span class="token operator">=</span>torch<span class="token punctuation">.</span>float64<span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
<!--rehype:className=wrap-text-->
|
||
<p>注意: 所有在CPU上的Tensors, 除了CharTensor, 都可以转换为Numpy array并可以反向转换.</p>
|
||
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="squeeze函数"><a aria-hidden="true" tabindex="-1" href="#squeeze函数"><span class="icon icon-link"></span></a>squeeze函数</h3><div class="wrap-body">
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x <span class="token operator">=</span> torch<span class="token punctuation">.</span>rand<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">
|
||
</span><span class="code-line"><span class="token comment"># squeeze不加参数,默认去除所有为1的维度</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
|
||
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">
|
||
</span><span class="code-line"><span class="token comment"># squeeze加参数,去除指定为1的维度</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span>dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
|
||
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">
|
||
</span><span class="code-line"><span class="token comment"># squeeze加参数,如果不为1,则不变</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
|
||
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">
|
||
</span><span class="code-line"><span class="token comment"># 既可以是函数,也可以是方法</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> torch<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span>x<span class="token punctuation">,</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
|
||
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="unsqueeze函数"><a aria-hidden="true" tabindex="-1" href="#unsqueeze函数"><span class="icon icon-link"></span></a>unsqueeze函数</h3><div class="wrap-body">
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x <span class="token operator">=</span> torch<span class="token punctuation">.</span>rand<span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token comment"># unsqueeze必须加参数, _ 2 _ 28 _</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> x<span class="token punctuation">.</span>unsqueeze<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
|
||
</span><span class="code-line"><span class="token comment"># 参数代表在哪里添加维度 0 1 2</span>
|
||
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token comment"># 既可以是函数,也可以是方法</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> torch<span class="token punctuation">.</span>unsqueeze<span class="token punctuation">(</span>x<span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>shape
|
||
</span><span class="code-line">torch<span class="token punctuation">.</span>Size<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
</div></div></div></div></div><div class="wrap h2body-exist"><div class="wrap-header h2wrap"><h2 id="cuda-相关"><a aria-hidden="true" tabindex="-1" href="#cuda-相关"><span class="icon icon-link"></span></a>Cuda 相关</h2><div class="wrap-body">
|
||
</div></div><div class="h2wrap-body"><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="检查-cuda-是否可用"><a aria-hidden="true" tabindex="-1" href="#检查-cuda-是否可用"><span class="icon icon-link"></span></a>检查 Cuda 是否可用</h3><div class="wrap-body">
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token keyword">import</span> torch<span class="token punctuation">.</span>cuda
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>is_available<span class="token punctuation">(</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token operator">>></span><span class="token operator">></span> <span class="token boolean">True</span>
|
||
</span></code></pre>
|
||
</div></div></div><div class="wrap h3body-not-exist col-span-2 row-span-2"><div class="wrap-header h3wrap"><h3 id="列出-gpu-设备"><a aria-hidden="true" tabindex="-1" href="#列出-gpu-设备"><span class="icon icon-link"></span></a>列出 GPU 设备</h3><div class="wrap-body">
|
||
<!--rehype:wrap-class=col-span-2 row-span-2-->
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">import</span> torch
|
||
</span><span class="code-line">
|
||
</span><span class="code-line">device_count <span class="token operator">=</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>device_count<span class="token punctuation">(</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"CUDA 设备"</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">
|
||
</span><span class="code-line"><span class="token keyword">for</span> i <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span>device_count<span class="token punctuation">)</span><span class="token punctuation">:</span>
|
||
</span><span class="code-line"> device_name <span class="token operator">=</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>get_device_name<span class="token punctuation">(</span>i<span class="token punctuation">)</span>
|
||
</span><span class="code-line"> total_memory <span class="token operator">=</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>get_device_properties<span class="token punctuation">(</span>i<span class="token punctuation">)</span><span class="token punctuation">.</span>total_memory <span class="token operator">/</span> <span class="token punctuation">(</span><span class="token number">1024</span> <span class="token operator">**</span> <span class="token number">3</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line"> <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string-interpolation"><span class="token string">f"├── 设备 </span><span class="token interpolation"><span class="token punctuation">{</span>i<span class="token punctuation">}</span></span><span class="token string">: </span><span class="token interpolation"><span class="token punctuation">{</span>device_name<span class="token punctuation">}</span></span><span class="token string">, 容量: </span><span class="token interpolation"><span class="token punctuation">{</span>total_memory<span class="token punctuation">:</span><span class="token format-spec">.2f</span><span class="token punctuation">}</span></span><span class="token string"> GiB"</span></span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">
|
||
</span><span class="code-line"><span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"└── (结束)"</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="将模型张量等数据在-gpu-和内存之间进行搬运"><a aria-hidden="true" tabindex="-1" href="#将模型张量等数据在-gpu-和内存之间进行搬运"><span class="icon icon-link"></span></a>将模型、张量等数据在 GPU 和内存之间进行搬运</h3><div class="wrap-body">
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">import</span> torch
|
||
</span><span class="code-line"><span class="token comment"># 将 0 替换为您的 GPU 设备索引或者直接使用 "cuda"</span>
|
||
</span><span class="code-line">device <span class="token operator">=</span> <span class="token string-interpolation"><span class="token string">f"cuda:0"</span></span>
|
||
</span><span class="code-line"><span class="token comment"># 移动到GPU</span>
|
||
</span><span class="code-line">tensor_m <span class="token operator">=</span> torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">tensor_g <span class="token operator">=</span> tensor_m<span class="token punctuation">.</span>to<span class="token punctuation">(</span>device<span class="token punctuation">)</span>
|
||
</span><span class="code-line">model_m <span class="token operator">=</span> torch<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">model_g <span class="token operator">=</span> model_m<span class="token punctuation">.</span>to<span class="token punctuation">(</span>device<span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token comment"># 向后移动</span>
|
||
</span><span class="code-line">tensor_m <span class="token operator">=</span> tensor_g<span class="token punctuation">.</span>cpu<span class="token punctuation">(</span><span class="token punctuation">)</span>
|
||
</span><span class="code-line">model_m <span class="token operator">=</span> model_g<span class="token punctuation">.</span>cpu<span class="token punctuation">(</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
</div></div></div></div></div><div class="wrap h2body-exist"><div class="wrap-header h2wrap"><h2 id="导入-imports"><a aria-hidden="true" tabindex="-1" href="#导入-imports"><span class="icon icon-link"></span></a>导入 Imports</h2><div class="wrap-body">
|
||
</div></div><div class="h2wrap-body"><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="一般"><a aria-hidden="true" tabindex="-1" href="#一般"><span class="icon icon-link"></span></a>一般</h3><div class="wrap-body">
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># 根包</span>
|
||
</span><span class="code-line"><span class="token keyword">import</span> torch
|
||
</span></code></pre>
|
||
<p>数据集表示和加载</p>
|
||
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">from</span> torch<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>data <span class="token keyword">import</span> Dataset<span class="token punctuation">,</span> DataLoader
|
||
</span></code></pre>
|
||
<!--rehype:className=wrap-text-->
|
||
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="神经网络-api"><a aria-hidden="true" tabindex="-1" href="#神经网络-api"><span class="icon icon-link"></span></a>神经网络 API</h3><div class="wrap-body">
|
||
<!--rehype:wrap-class=row-span-2-->
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># 计算图</span>
|
||
</span><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>autograd <span class="token keyword">as</span> autograd
|
||
</span><span class="code-line"><span class="token comment"># 计算图中的张量节点</span>
|
||
</span><span class="code-line"><span class="token keyword">from</span> torch <span class="token keyword">import</span> Tensor
|
||
</span></code></pre>
|
||
<p>神经网络</p>
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>nn <span class="token keyword">as</span> nn
|
||
</span><span class="code-line">
|
||
</span><span class="code-line"><span class="token comment"># 层、激活等</span>
|
||
</span><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>functional <span class="token keyword">as</span> F
|
||
</span><span class="code-line"><span class="token comment"># 优化器,例如 梯度下降、ADAM等</span>
|
||
</span><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>optim <span class="token keyword">as</span> optim
|
||
</span></code></pre>
|
||
<p>混合前端装饰器和跟踪 jit</p>
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token keyword">from</span> torch<span class="token punctuation">.</span>jit <span class="token keyword">import</span> script<span class="token punctuation">,</span> trace
|
||
</span></code></pre>
|
||
</div></div></div><div class="wrap h3body-not-exist row-span-2"><div class="wrap-header h3wrap"><h3 id="onnx"><a aria-hidden="true" tabindex="-1" href="#onnx"><span class="icon icon-link"></span></a>ONNX</h3><div class="wrap-body">
|
||
<!--rehype:wrap-class=row-span-2-->
|
||
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line">torch<span class="token punctuation">.</span>onnx<span class="token punctuation">.</span>export<span class="token punctuation">(</span>model<span class="token punctuation">,</span> dummy data<span class="token punctuation">,</span> xxxx<span class="token punctuation">.</span>proto<span class="token punctuation">)</span>
|
||
</span><span class="code-line"><span class="token comment"># 导出 ONNX 格式</span>
|
||
</span><span class="code-line"><span class="token comment"># 使用经过训练的模型模型,dummy</span>
|
||
</span><span class="code-line"><span class="token comment"># 数据和所需的文件名</span>
|
||
</span></code></pre>
|
||
<!--rehype:className=wrap-text-->
|
||
<p>加载 ONNX 模型</p>
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">model <span class="token operator">=</span> onnx<span class="token punctuation">.</span>load<span class="token punctuation">(</span><span class="token string">"alexnet.proto"</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
<p>检查模型是否结构良好</p>
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">onnx<span class="token punctuation">.</span>checker<span class="token punctuation">.</span>check_model<span class="token punctuation">(</span>model<span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
<p>打印计算图的人类可读表示</p>
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">onnx<span class="token punctuation">.</span>helper<span class="token punctuation">.</span>printable_graph<span class="token punctuation">(</span>model<span class="token punctuation">.</span>graph<span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="torchscript-和-jit"><a aria-hidden="true" tabindex="-1" href="#torchscript-和-jit"><span class="icon icon-link"></span></a>Torchscript 和 JIT</h3><div class="wrap-body">
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line">torch<span class="token punctuation">.</span>jit<span class="token punctuation">.</span>trace<span class="token punctuation">(</span><span class="token punctuation">)</span>
|
||
</span></code></pre>
|
||
<p>传入你的模块或函数以及一个示例输入数据,追踪数据在模型中前向传播时所经过的计算步骤</p>
|
||
<pre class="language-python"><code class="language-python code-highlight"><span class="code-line"><span class="token decorator annotation punctuation">@script</span>
|
||
</span></code></pre>
|
||
<p>装饰器用于指示被跟踪代码中的数据相关控制流</p>
|
||
</div></div></div><div class="wrap h3body-not-exist col-span-2"><div class="wrap-header h3wrap"><h3 id="vision"><a aria-hidden="true" tabindex="-1" href="#vision"><span class="icon icon-link"></span></a>Vision</h3><div class="wrap-body">
|
||
<!--rehype:wrap-class=col-span-2-->
|
||
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># 视觉数据集,架构 & 变换</span>
|
||
</span><span class="code-line"><span class="token keyword">from</span> torchvision <span class="token keyword">import</span> datasets<span class="token punctuation">,</span> models<span class="token punctuation">,</span> transforms
|
||
</span><span class="code-line"><span class="token comment"># 组合转换</span>
|
||
</span><span class="code-line"><span class="token keyword">import</span> torchvision<span class="token punctuation">.</span>transforms <span class="token keyword">as</span> transforms
|
||
</span></code></pre>
|
||
<!--rehype:className=wrap-text-->
|
||
</div></div></div><div class="wrap h3body-not-exist"><div class="wrap-header h3wrap"><h3 id="分布式训练"><a aria-hidden="true" tabindex="-1" href="#分布式训练"><span class="icon icon-link"></span></a>分布式训练</h3><div class="wrap-body">
|
||
<pre class="wrap-text"><code class="language-python code-highlight"><span class="code-line"><span class="token comment"># 分布式通信</span>
</span><span class="code-line"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>distributed <span class="token keyword">as</span> dist
</span><span class="code-line"><span class="token comment"># 共享内存的多进程</span>
</span><span class="code-line"><span class="token keyword">from</span> torch<span class="token punctuation">.</span>multiprocessing <span class="token keyword">import</span> Process
</span></code></pre>
|
||
<!--rehype:className=wrap-text-->
|
||
</div></div></div></div></div><div class="wrap h2body-not-exist"><div class="wrap-header h2wrap"><h2 id="另见"><a aria-hidden="true" tabindex="-1" href="#另见"><span class="icon icon-link"></span></a>另见</h2><div class="wrap-body">
|
||
<ul>
<li><a href="https://pytorch.org/">PyTorch 官网</a> <em>(pytorch.org)</em></li>
<li><a href="https://pytorch.org/tutorials/beginner/ptcheat.html">PyTorch 官方备忘清单</a> <em>(pytorch.org)</em></li>
</ul>
|
||
</div></div><div class="h2wrap-body"></div></div></div><script src="https://giscus.app/client.js" data-repo="jaywcjlove/reference" data-repo-id="R_kgDOID2-Mw" data-category="Q&A" data-category-id="DIC_kwDOID2-M84CS5wo" data-mapping="pathname" data-strict="0" data-reactions-enabled="1" data-emit-metadata="0" data-input-position="bottom" data-theme="dark" data-lang="zh-CN" crossorigin="anonymous" async></script><div class="giscus"></div></div><footer class="footer-wrap"><footer class="max-container">© 2022 <a href="https://wangchujiang.com/#/app" target="_blank">Kenny Wang</a>.</footer></footer><script src="../data.js?v=1.8.3" defer></script><script src="../js/fuse.min.js?v=1.8.3" defer></script><script src="../js/main.js?v=1.8.3" defer></script><div id="mysearch"><div class="mysearch-box"><div class="mysearch-input"><div><svg xmlns="http://www.w3.org/2000/svg" height="1em" width="1em" viewBox="0 0 18 18">
|
||
<path fill="currentColor" d="M17.71,16.29 L14.31,12.9 C15.4069846,11.5024547 16.0022094,9.77665502 16,8 C16,3.581722 12.418278,0 8,0 C3.581722,0 0,3.581722 0,8 C0,12.418278 3.581722,16 8,16 C9.77665502,16.0022094 11.5024547,15.4069846 12.9,14.31 L16.29,17.71 C16.4777666,17.8993127 16.7333625,18.0057983 17,18.0057983 C17.2666375,18.0057983 17.5222334,17.8993127 17.71,17.71 C17.8993127,17.5222334 18.0057983,17.2666375 18.0057983,17 C18.0057983,16.7333625 17.8993127,16.4777666 17.71,16.29 Z M2,8 C2,4.6862915 4.6862915,2 8,2 C11.3137085,2 14,4.6862915 14,8 C14,11.3137085 11.3137085,14 8,14 C4.6862915,14 2,11.3137085 2,8 Z"></path>
|
||
</svg><input id="mysearch-input" type="search" placeholder="搜索" autocomplete="off"><div class="mysearch-clear"></div></div><button id="mysearch-close" type="button">搜索</button></div><div class="mysearch-result"><div id="mysearch-menu"></div><div id="mysearch-content"></div></div></div></div></body>
|
||
</html>
|