<?xml version="1.0" encoding="utf-8"?><?xml-stylesheet type="text/xsl" href="atom.xsl"?>
<feed xmlns="http://www.w3.org/2005/Atom">
    <id>https://kiwinicki.github.io/</id>
    <title>ML blog Blog</title>
    <updated>2025-04-04T00:00:00.000Z</updated>
    <generator>https://github.com/jpmonette/feed</generator>
    <link rel="alternate" href="https://kiwinicki.github.io/"/>
    <subtitle>ML blog Blog</subtitle>
    <entry>
        <title type="html"><![CDATA[Random training tips and tricks]]></title>
        <id>https://kiwinicki.github.io/2025/04/04/training-tricks</id>
        <link href="https://kiwinicki.github.io/2025/04/04/training-tricks"/>
        <updated>2025-04-04T00:00:00.000Z</updated>
        <summary type="html"><![CDATA[A list of all sorts of things you can use to make your training/finetuning go faster, from one-line tricks to whole architecture changes. Order is random and grouping is somewhat arbitrary. Treat this as a knowledge dump.]]></summary>
        <content type="html"><![CDATA[<p>A list of all sorts of things you can use to make your training/finetuning go faster, from one-line tricks to whole architecture changes. Order is random and grouping is somewhat arbitrary. Treat this as a knowledge dump.</p>
<ol>
<li>Buy a bigger GPU.</li>
</ol>
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="memory-optimization-techniques">Memory Optimization Techniques<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#memory-optimization-techniques" class="hash-link" aria-label="Direct link to Memory Optimization Techniques" title="Direct link to Memory Optimization Techniques">​</a></h2>
<h3 class="anchor anchorWithStickyNavbar_LWe7" id="optimizer-memory-reduction">Optimizer Memory Reduction<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#optimizer-memory-reduction" class="hash-link" aria-label="Direct link to Optimizer Memory Reduction" title="Direct link to Optimizer Memory Reduction">​</a></h3>
<ul>
<li>8-bit Adam -  2*8-bit vs 2*32-bit = ~75% less memory for optimizer states</li>
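<li>Activation (gradient) checkpointing - not storing activations from the forward pass but recomputing them in the backward pass (memory-compute trade-off). Strictly speaking this saves activation memory, not optimizer state.</li>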
<li>Optimizers with less state - plain SGD (no momentum) stores no extra buffers, and Lion stores a single momentum buffer (an additional 1x model size). In contrast, Adam keeps two moment estimates per parameter (an additional 2x model size). The catch is that these are not drop-in replacements: Adam was specially made for faster convergence &amp; more stable training.</li>
<li>Optimizer states offloading to CPU, using <a href="https://github.com/pytorch/ao/tree/main/torchao/optim#optimizer-cpu-offload" target="_blank" rel="noopener noreferrer"><code>torchao</code></a>:<!-- -->
<div class="language-python codeBlockContainer_Ckt0 theme-code-block" style="--prism-color:#393A34;--prism-background-color:#f6f8fa"><div class="codeBlockContent_biex"><pre tabindex="0" class="prism-code language-python codeBlock_bY9V thin-scrollbar" style="color:#393A34;background-color:#f6f8fa"><code class="codeBlockLines_e6Vv"><span class="token-line" style="color:#393A34"><span class="token plain">optim </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> CPUOffloadOptimizer</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">model</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">parameters</span><span class="token punctuation" style="color:#393A34">(</span><span class="token punctuation" style="color:#393A34">)</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> torch</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">optim</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">AdamW</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> fused</span><span class="token operator" style="color:#393A34">=</span><span class="token boolean" style="color:#36acaa">True</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">optim</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">load_state_dict</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">ckpt</span><span class="token punctuation" style="color:#393A34">[</span><span class="token string" style="color:#e3116c">"optim"</span><span class="token punctuation" style="color:#393A34">]</span><span class="token punctuation" 
style="color:#393A34">)</span><br></span></code></pre><div class="buttonGroup__atx"><button type="button" aria-label="Copy code to clipboard" title="Copy" class="clean-btn"><span class="copyButtonIcons_eSgA" aria-hidden="true"><svg viewBox="0 0 24 24" class="copyButtonIcon_y97N"><path fill="currentColor" d="M19,21H8V7H19M19,5H8A2,2 0 0,0 6,7V21A2,2 0 0,0 8,23H19A2,2 0 0,0 21,21V7A2,2 0 0,0 19,5M16,1H4A2,2 0 0,0 2,3V17H4V3H16V1Z"></path></svg><svg viewBox="0 0 24 24" class="copyButtonSuccessIcon_LjdS"><path fill="currentColor" d="M21,7L9,19L3.5,13.5L4.91,12.09L9,16.17L19.59,5.59L21,7Z"></path></svg></span></button></div></div></div>
</li>
</ul>
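<p>The optimizer-state arithmetic above can be sketched in a few lines of plain Python. This is a rough estimate only: it assumes the state is the only thing counted (no quantization metadata, no master weights), and the 7B model size is a hypothetical example, not a measurement.</p>

```python
# Rough per-parameter optimizer state, in bytes.
# Assumption: only the state buffers are counted, nothing else.
STATE_BYTES = {
    "adam_fp32": 2 * 4,   # two moment buffers stored in 32-bit
    "adam_8bit": 2 * 1,   # two moment buffers stored in 8-bit
    "sgd_momentum": 4,    # one momentum buffer stored in 32-bit
    "sgd_plain": 0,       # no state at all
}


def optimizer_state_gib(num_params, optimizer):
    """Optimizer state size in GiB for a model with num_params parameters."""
    return num_params * STATE_BYTES[optimizer] / 2**30


def saving_vs_adam(optimizer):
    """Fraction of FP32 Adam state memory saved by the given optimizer."""
    return 1 - STATE_BYTES[optimizer] / STATE_BYTES["adam_fp32"]


if __name__ == "__main__":
    n = 7_000_000_000  # hypothetical 7B-parameter model
    for name in STATE_BYTES:
        size = optimizer_state_gib(n, name)
        print(f"{name:13s} {size:6.1f} GiB  saving {saving_vs_adam(name):.0%}")
```

<p>The 75% figure for 8-bit Adam falls straight out of <code>1 - 2/8</code>; real libraries may add a small overhead on top of this.</p>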
<h3 class="anchor anchorWithStickyNavbar_LWe7" id="parameter-efficient-methods">Parameter-Efficient Methods<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#parameter-efficient-methods" class="hash-link" aria-label="Direct link to Parameter-Efficient Methods" title="Direct link to Parameter-Efficient Methods">​</a></h3>
<ul>
<li>
<p>LoRA (Low-Rank Adaptation) - a thin trainable layer on top of the frozen pretrained layers of your model. It uses the trick of decomposing a large matrix into two smaller low-rank (<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>n</mi><mo>×</mo><mi>m</mi></mrow><annotation encoding="application/x-tex">n\times m</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6667em;vertical-align:-0.0833em"></span><span class="mord mathnormal">n</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:0.4306em"></span><span class="mord mathnormal">m</span></span></span></span> where <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>n</mi><mo>&lt;</mo><mo>&lt;</mo><mi>m</mi></mrow><annotation encoding="application/x-tex">n &lt;&lt; m</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.5782em;vertical-align:-0.0391em"></span><span class="mord mathnormal">n</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">&lt;&lt;</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:0.4306em"></span><span class="mord mathnormal">m</span></span></span></span>) matrices, which gives huge memory savings because only the two small matrices are trained. Slightly more formal:</p>
<p><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mtext>act</mtext><mo stretchy="false">[</mo><mo stretchy="false">(</mo><mi>W</mi><mo>+</mo><mi mathvariant="normal">Δ</mi><mi>W</mi><mo stretchy="false">)</mo><mo>⋅</mo><mi>x</mi><mo>+</mo><mi>b</mi><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">\text{act}[(W + \Delta W)\cdot x + b]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord text"><span class="mord">act</span></span><span class="mopen">[(</span><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord">Δ</span><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:0.6667em;vertical-align:-0.0833em"></span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord mathnormal">b</span><span class="mclose">]</span></span></span></span>, where:</p>
<ul>
<li><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>W</mi><mo>=</mo><mo stretchy="false">(</mo><mi>d</mi><mo>×</mo><mi>k</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">W=(d \times k)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mopen">(</span><span class="mord mathnormal">d</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord mathnormal" style="margin-right:0.03148em">k</span><span class="mclose">)</span></span></span></span> pretrained weight tensor</li>
<li><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi mathvariant="normal">Δ</mi><mi>W</mi><mo>=</mo><mi>A</mi><mi>B</mi></mrow><annotation encoding="application/x-tex">\Delta W = AB</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord">Δ</span><span class="mord mathnormal" style="margin-right:0.13889em">W</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal">A</span><span class="mord mathnormal" style="margin-right:0.05017em">B</span></span></span></span>, (init LoRA tensors) where:<!-- -->
<ul>
<li><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>A</mi><mo>=</mo><mo stretchy="false">(</mo><mi>d</mi><mo>×</mo><mtext>rank</mtext><mo stretchy="false">)</mo><mo>=</mo><mi mathvariant="script">N</mi><mo stretchy="false">(</mo><mn>0</mn><mo separator="true">,</mo><msup><mi>σ</mi><mn>2</mn></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">A = (d\times \text{rank}) =\mathcal{N}(0, \sigma^2)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal">A</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mopen">(</span><span class="mord mathnormal">d</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord text"><span class="mord">rank</span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:1.0641em;vertical-align:-0.25em"></span><span class="mord mathcal" style="margin-right:0.14736em">N</span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em">σ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8141em"><span 
style="top:-3.063em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span> at init</li>
<li><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>B</mi><mo>=</mo><mo stretchy="false">(</mo><mtext>rank</mtext><mo>×</mo><mi>k</mi><mo stretchy="false">)</mo><mo>=</mo><mn>0</mn></mrow><annotation encoding="application/x-tex">B=(\text{rank} \times k)=0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal" style="margin-right:0.05017em">B</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mopen">(</span><span class="mord text"><span class="mord">rank</span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">×</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord mathnormal" style="margin-right:0.03148em">k</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:0.6444em"></span><span class="mord">0</span></span></span></span> at init</li>
<li>at the beginning <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>A</mi><mi>B</mi></mrow><annotation encoding="application/x-tex">AB</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal">A</span><span class="mord mathnormal" style="margin-right:0.05017em">B</span></span></span></span> is a no-op, but thanks to <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>A</mi></mrow><annotation encoding="application/x-tex">A</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal">A</span></span></span></span> being Gaussian there will be symmetry breaking</li>
</ul>
</li>
<li>some good links about it: <a href="https://magazine.sebastianraschka.com/p/lora-and-dora-from-scratch" target="_blank" rel="noopener noreferrer">Sebastian Raschka</a>, <a href="https://youtu.be/KEv-F5UkhxU?si=JcCjdEV-7tXPvk2r" target="_blank" rel="noopener noreferrer">AI Coffee Break</a></li>
</ul>
</li>
<li>
<p>QLoRA - differs from LoRA in that the base model is quantized (usually to 4-bit) while the LoRA layers are kept in 16-bit precision.</p>
</li>
<li>
<p>After training it's good to merge the LoRA layers into the base model for lower latency and less computational overhead (but once merged, you can no longer swap between them if you have several adapters).</p>
</li>
<li>
<p>DoRA - decomposition of weight matrix into magnitude vector <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>m</mi></mrow><annotation encoding="application/x-tex">m</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em"></span><span class="mord mathnormal">m</span></span></span></span> (euclidean distance) &amp; directional matrix <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>V</mi></mrow><annotation encoding="application/x-tex">V</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal" style="margin-right:0.22222em">V</span></span></span></span> (angle) and train them separately.</p>
</li>
<li>
<p>GaLore - LoRA but for the gradient matrix. Supposedly it also works for pretraining, in contrast to LoRA, but I haven't tried it.</p>
</li>
<li>
<p>Prefix Tuning - in place of a text pre-prompt you put a randomly initialized vector (the so-called prefix) and optimize it until the model gives the correct answer.</p>
<ul>
<li>✅ tiny amount of parameters to tune</li>
<li>❌ takes up context length (but otherwise you would put a pre-prompt there anyway)</li>
<li>❌ interpretability - the prefix is not made of words; you can decode it to the "nearest" words, but the result is often gibberish</li>
</ul>
</li>
</ul>
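<p>To make the LoRA initialization above concrete, here is a minimal pure-Python sketch (no framework, tiny matrices). The names <code>init_lora</code> and <code>trainable_fraction</code> and the value <code>sigma=0.02</code> are made up for illustration; real implementations differ in details such as scaling factors.</p>

```python
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def init_lora(d, k, rank, sigma=0.02, seed=0):
    """A is Gaussian (d x rank) and B is all zeros (rank x k), as in the text.

    sigma and seed are illustrative assumptions, not recommended values.
    """
    rng = random.Random(seed)
    A = [[rng.gauss(0.0, sigma) for _ in range(rank)] for _ in range(d)]
    B = [[0.0] * k for _ in range(rank)]
    return A, B


def trainable_fraction(d, k, rank):
    """Trainable LoRA parameters relative to the frozen d x k weight."""
    return (d * rank + rank * k) / (d * k)


if __name__ == "__main__":
    d, k, rank = 8, 6, 2
    A, B = init_lora(d, k, rank)
    delta_w = matmul(A, B)  # shape d x k, same as the pretrained weight W
    # B is zero, so W + delta_w equals W at the start of training.
    print(all(v == 0.0 for row in delta_w for v in row))
    # Share of weights that are trained for a 4096 x 4096 layer at rank 8.
    print(f"{trainable_fraction(4096, 4096, 8):.4%}")
```

<p>Because only the two small matrices receive gradients and optimizer state, the savings grow as the rank shrinks relative to the layer dimensions.</p>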
<h3 class="anchor anchorWithStickyNavbar_LWe7" id="mixed-precision-strategies">Mixed Precision Strategies<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#mixed-precision-strategies" class="hash-link" aria-label="Direct link to Mixed Precision Strategies" title="Direct link to Mixed Precision Strategies">​</a></h3>
<p><a href="https://sebastianraschka.com/blog/2023/llm-mixed-precision-copy.html" target="_blank" rel="noopener noreferrer">https://sebastianraschka.com/blog/2023/llm-mixed-precision-copy.html</a></p>
<p><a href="https://medium.com/@jbensnyder/solving-the-limits-of-mixed-precision-training-231019128b4b" target="_blank" rel="noopener noreferrer">https://medium.com/@jbensnyder/solving-the-limits-of-mixed-precision-training-231019128b4b</a></p>
<p>On consumer cards, types other than FP32 are only faster thanks to Tensor Cores, which among consumer GPUs are only available on RTX cards (data-center GPUs from Volta onwards have them too). Tensor Cores are used automatically when using mixed precision.</p>
<ul>
<li>
<p><strong>FP32+FP16/BF16</strong> (speed &amp; memory*) - you store an extra copy of the model in 16-bit to be able to do faster calculations (~2x) (memory-compute trade-off). Gradients are also <strong>computed</strong> in 16-bit, but <strong>stored</strong> in FP32. BF16 should be more stable than FP16, as range matters more for NNs than precision. BF16 stands for "Brain Float", from Google Brain btw.</p>
<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mtable rowspacing="0.25em" columnalign="right" columnspacing=""><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><munder><munder><mrow><mo stretchy="false">(</mo><mn>4</mn><mo>+</mo><mn>2</mn><mtext>&nbsp;bytes</mtext><mo stretchy="false">)</mo></mrow><mo stretchy="true">⏟</mo></munder><mtext>model&nbsp;FP32+FP16</mtext></munder><mo>+</mo><munder><munder><mrow><mo stretchy="false">(</mo><mn>2</mn><mtext>&nbsp;bytes</mtext><mo stretchy="false">)</mo></mrow><mo stretchy="true">⏟</mo></munder><mtext>activation&nbsp;FP16</mtext></munder><mo>+</mo><munder><munder><mrow><mo stretchy="false">(</mo><mn>4</mn><mtext>&nbsp;bytes</mtext><mo stretchy="false">)</mo></mrow><mo stretchy="true">⏟</mo></munder><mtext>gradients&nbsp;FP32</mtext></munder><mo>=</mo><mn>12</mn><mtext>&nbsp;bytes&nbsp;(+8&nbsp;bytes&nbsp;for&nbsp;Adam)</mtext></mrow></mstyle></mtd></mtr></mtable><annotation encoding="application/x-tex">\begin{align*}
\underbrace{(4+2\ \text{bytes})}_{\text{{model FP32+FP16}}} + 
\underbrace{(2\ \text{bytes})}_{\text{{activation FP16}}} + \underbrace{(4\ \text{bytes})}_{\text{{gradients FP32}}} = 12\ \text{bytes (+8 bytes for Adam)}\end{align*}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:2.8602em;vertical-align:-1.1801em"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.6801em"><span style="top:-3.8401em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.75em"><span style="top:-1.4159em"><span class="pstrut" style="height:3em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight"><span class="mord mtight">model&nbsp;FP32+FP16</span></span></span></span></span></span><span style="top:-3em"><span class="pstrut" style="height:3em"></span><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.75em"><span class="svg-align" style="top:-2.102em"><span class="pstrut" style="height:3em"></span><span class="stretchy" style="height:0.548em;min-width:1.6em"><span class="brace-left" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMinYMin slice"><path d="M0 6l6-6h17c12.688 0 19.313.3 20 1 4 4 7.313 8.3 10 13
 35.313 51.3 80.813 93.8 136.5 127.5 55.688 33.7 117.188 55.8 184.5 66.5.688
 0 2 .3 4 1 18.688 2.7 76 4.3 172 5h399450v120H429l-6-1c-124.688-8-235-61.7
-331-161C60.687 138.7 32.312 99.3 7 54L0 41V6z"></path></svg></span><span class="brace-center" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMidYMin slice"><path d="M199572 214
c100.7 8.3 195.3 44 280 108 55.3 42 101.7 93 139 153l9 14c2.7-4 5.7-8.7 9-14
 53.3-86.7 123.7-153 211-199 66.7-36 137.3-56.3 212-62h199568v120H200432c-178.3
 11.7-311.7 78.3-403 201-6 8-9.7 12-11 12-.7.7-6.7 1-18 1s-17.3-.3-18-1c-1.3 0
-5-4-11-12-44.7-59.3-101.3-106.3-170-141s-145.3-54.3-229-60H0V214z"></path></svg></span><span class="brace-right" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMaxYMin slice"><path d="M399994 0l6 6v35l-6 11c-56 104-135.3 181.3-238 232-57.3
 28.7-117 45-179 50H-300V214h399897c43.3-7 81-15 113-26 100.7-33 179.7-91 237
-174 2.7-5 6-9 10-13 .7-1 7.3-1 20-1h17z"></path></svg></span></span></span><span style="top:-3em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mopen">(</span><span class="mord">4</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mord">2</span><span class="mspace">&nbsp;</span><span class="mord text"><span class="mord">bytes</span></span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.898em"><span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.6424em"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.75em"><span style="top:-1.4237em"><span class="pstrut" style="height:3em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight"><span class="mord mtight">activation&nbsp;FP16</span></span></span></span></span></span><span style="top:-3em"><span class="pstrut" style="height:3em"></span><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.75em"><span class="svg-align" style="top:-2.102em"><span class="pstrut" style="height:3em"></span><span class="stretchy" style="height:0.548em;min-width:1.6em"><span class="brace-left" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMinYMin slice"><path d="M0 6l6-6h17c12.688 0 19.313.3 20 1 4 4 7.313 8.3 10 13
 35.313 51.3 80.813 93.8 136.5 127.5 55.688 33.7 117.188 55.8 184.5 66.5.688
 0 2 .3 4 1 18.688 2.7 76 4.3 172 5h399450v120H429l-6-1c-124.688-8-235-61.7
-331-161C60.687 138.7 32.312 99.3 7 54L0 41V6z"></path></svg></span><span class="brace-center" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMidYMin slice"><path d="M199572 214
c100.7 8.3 195.3 44 280 108 55.3 42 101.7 93 139 153l9 14c2.7-4 5.7-8.7 9-14
 53.3-86.7 123.7-153 211-199 66.7-36 137.3-56.3 212-62h199568v120H200432c-178.3
 11.7-311.7 78.3-403 201-6 8-9.7 12-11 12-.7.7-6.7 1-18 1s-17.3-.3-18-1c-1.3 0
-5-4-11-12-44.7-59.3-101.3-106.3-170-141s-145.3-54.3-229-60H0V214z"></path></svg></span><span class="brace-right" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMaxYMin slice"><path d="M399994 0l6 6v35l-6 11c-56 104-135.3 181.3-238 232-57.3
 28.7-117 45-179 50H-300V214h399897c43.3-7 81-15 113-26 100.7-33 179.7-91 237
-174 2.7-5 6-9 10-13 .7-1 7.3-1 20-1h17z"></path></svg></span></span></span><span style="top:-3em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mopen">(</span><span class="mord">2</span><span class="mspace">&nbsp;</span><span class="mord text"><span class="mord">bytes</span></span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.898em"><span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.5763em"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.75em"><span style="top:-1.4159em"><span class="pstrut" style="height:3em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight"><span class="mord mtight">gradients&nbsp;FP32</span></span></span></span></span></span><span style="top:-3em"><span class="pstrut" style="height:3em"></span><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.75em"><span class="svg-align" style="top:-2.102em"><span class="pstrut" style="height:3em"></span><span class="stretchy" style="height:0.548em;min-width:1.6em"><span class="brace-left" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMinYMin slice"><path d="M0 6l6-6h17c12.688 0 19.313.3 20 1 4 4 7.313 8.3 10 13
 35.313 51.3 80.813 93.8 136.5 127.5 55.688 33.7 117.188 55.8 184.5 66.5.688
 0 2 .3 4 1 18.688 2.7 76 4.3 172 5h399450v120H429l-6-1c-124.688-8-235-61.7
-331-161C60.687 138.7 32.312 99.3 7 54L0 41V6z"></path></svg></span><span class="brace-center" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMidYMin slice"><path d="M199572 214
c100.7 8.3 195.3 44 280 108 55.3 42 101.7 93 139 153l9 14c2.7-4 5.7-8.7 9-14
 53.3-86.7 123.7-153 211-199 66.7-36 137.3-56.3 212-62h199568v120H200432c-178.3
 11.7-311.7 78.3-403 201-6 8-9.7 12-11 12-.7.7-6.7 1-18 1s-17.3-.3-18-1c-1.3 0
-5-4-11-12-44.7-59.3-101.3-106.3-170-141s-145.3-54.3-229-60H0V214z"></path></svg></span><span class="brace-right" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMaxYMin slice"><path d="M399994 0l6 6v35l-6 11c-56 104-135.3 181.3-238 232-57.3
 28.7-117 45-179 50H-300V214h399897c43.3-7 81-15 113-26 100.7-33 179.7-91 237
-174 2.7-5 6-9 10-13 .7-1 7.3-1 20-1h17z"></path></svg></span></span></span><span style="top:-3em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mopen">(</span><span class="mord">4</span><span class="mspace">&nbsp;</span><span class="mord text"><span class="mord">bytes</span></span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.898em"><span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.7202em"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mord">12</span><span class="mspace">&nbsp;</span><span class="mord text"><span class="mord">bytes&nbsp;(+8&nbsp;bytes&nbsp;for&nbsp;Adam)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.1801em"><span></span></span></span></span></span></span></span></span></span></span></span>
<p>*memory savings will only be visible if your batch size is large enough for the activation savings to outweigh the additional cost of storing a 16-bit copy of the weights</p>
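<p>The footnote can be illustrated with a simplified memory model. It is an assumption-laden sketch: it splits the formula above into a per-parameter part and a per-activation part, treats the number of stored activation values as a single number that grows with batch size, and ignores everything else (buffers, fragmentation, framework overhead). The model and activation sizes are hypothetical.</p>

```python
def fp32_bytes(num_params, num_activations, adam=True):
    """Full FP32 training: 4-byte weights, gradients and activations."""
    per_param = 4 + 4 + (8 if adam else 0)
    return num_params * per_param + num_activations * 4


def mixed_bytes(num_params, num_activations, adam=True):
    """FP32 weights plus a 16-bit copy, FP32 gradients, 16-bit activations."""
    per_param = (4 + 2) + 4 + (8 if adam else 0)
    return num_params * per_param + num_activations * 2


if __name__ == "__main__":
    params = 100_000_000  # hypothetical 100M-parameter model
    for acts in (50_000_000, 100_000_000, 400_000_000):
        diff = fp32_bytes(params, acts) - mixed_bytes(params, acts)
        # A negative number means mixed precision uses more memory.
        print(f"{acts:>11,} activations: difference {diff / 2**20:8.1f} MiB")
```

<p>In this model the extra 2 bytes per parameter are paid back by 2 bytes saved per stored activation value, so mixed precision only wins on memory once the stored activations outnumber the parameters.</p>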
</li>
<li>
<p>turn on <strong>TF32</strong> (Tensor Float) (speed) - not all operations are supported in 16-bit mixed precision, so some have to be done in FP32. Turning on TF32 replaces FP32 in computation (storage is still in FP32) at speeds similar to FP16. TF32 is supported on the Ampere architecture and newer. Fun fact: TF32 is a 19-bit format despite having "32" in its name.</p>
<div class="language-python codeBlockContainer_Ckt0 theme-code-block" style="--prism-color:#393A34;--prism-background-color:#f6f8fa"><div class="codeBlockContent_biex"><pre tabindex="0" class="prism-code language-python codeBlock_bY9V thin-scrollbar" style="color:#393A34;background-color:#f6f8fa"><code class="codeBlockLines_e6Vv"><span class="token-line" style="color:#393A34"><span class="token comment" style="color:#999988;font-style:italic"># The flag below controls whether to allow TF32 on matmul.</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain"></span><span class="token comment" style="color:#999988;font-style:italic"># This flag defaults to False in PyTorch 1.12 and later.</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">torch</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">backends</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">cuda</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">matmul</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">allow_tf32 </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> </span><span class="token boolean" style="color:#36acaa">True</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain" style="display:inline-block"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain"></span><span class="token comment" style="color:#999988;font-style:italic"># The flag below controls whether to allow TF32 on cuDNN. 
</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain"></span><span class="token comment" style="color:#999988;font-style:italic"># This flag defaults to True.</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">torch</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">backends</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">cudnn</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">allow_tf32 </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> </span><span class="token boolean" style="color:#36acaa">True</span><br></span></code></pre><div class="buttonGroup__atx"><button type="button" aria-label="Copy code to clipboard" title="Copy" class="clean-btn"><span class="copyButtonIcons_eSgA" aria-hidden="true"><svg viewBox="0 0 24 24" class="copyButtonIcon_y97N"><path fill="currentColor" d="M19,21H8V7H19M19,5H8A2,2 0 0,0 6,7V21A2,2 0 0,0 8,23H19A2,2 0 0,0 21,21V7A2,2 0 0,0 19,5M16,1H4A2,2 0 0,0 2,3V17H4V3H16V1Z"></path></svg><svg viewBox="0 0 24 24" class="copyButtonSuccessIcon_LjdS"><path fill="currentColor" d="M21,7L9,19L3.5,13.5L4.91,12.09L9,16.17L19.59,5.59L21,7Z"></path></svg></span></button></div></div></div>
</li>
<li>
<p><strong>FP32+FP8</strong> (speed &amp; memory) - it is possible to do computations in 8-bit with FP32 accumulation, but it's more complicated to set up than AMP, which PyTorch supports natively. Use the <a href="https://huggingface.co/docs/accelerate/usage_guides/low_precision_training" target="_blank" rel="noopener noreferrer">HF Accelerate</a> library to do this (from what I understand, it's a wrapper around 3 other packages [<code>TransformersEngine</code>, <code>MS-AMP</code>, and <code>torchao</code>], but not only that).</p>
<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mtable rowspacing="0.25em" columnalign="right" columnspacing=""><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><munder><munder><mrow><mo stretchy="false">(</mo><mn>1</mn><mtext>&nbsp;byte</mtext><mo stretchy="false">)</mo></mrow><mo stretchy="true">⏟</mo></munder><mtext>model&nbsp;FP8&nbsp;E4M3</mtext></munder><mo>+</mo><munder><munder><mrow><mo stretchy="false">(</mo><mn>1</mn><mtext>&nbsp;byte</mtext><mo stretchy="false">)</mo></mrow><mo stretchy="true">⏟</mo></munder><mtext>activation&nbsp;FP8&nbsp;E4M3</mtext></munder><mo>+</mo><munder><munder><mrow><mo stretchy="false">(</mo><mn>1</mn><mtext>&nbsp;byte</mtext><mo stretchy="false">)</mo></mrow><mo stretchy="true">⏟</mo></munder><mtext>gradients&nbsp;FP8&nbsp;E5M2</mtext></munder><mo>=</mo><mn>3</mn><mtext>&nbsp;bytes&nbsp;(+8&nbsp;bytes&nbsp;for&nbsp;Adam)</mtext></mrow></mstyle></mtd></mtr></mtable><annotation encoding="application/x-tex">\begin{align*}
\underbrace{(1\ \text{byte})}_{\text{{model FP8 E4M3}}} + 
\underbrace{(1\ \text{byte})}_{\text{{activation FP8 E4M3}}} + 
\underbrace{(1\ \text{byte})}_{\text{{gradients FP8 E5M2}}} = 3\ \text{bytes (+8 bytes for Adam)}\end{align*}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:2.8602em;vertical-align:-1.1801em"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.6801em"><span style="top:-3.8401em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.75em"><span style="top:-1.4159em"><span class="pstrut" style="height:3em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight"><span class="mord mtight">model&nbsp;FP8&nbsp;E4M3</span></span></span></span></span></span><span style="top:-3em"><span class="pstrut" style="height:3em"></span><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.75em"><span class="svg-align" style="top:-2.102em"><span class="pstrut" style="height:3em"></span><span class="stretchy" style="height:0.548em;min-width:1.6em"><span class="brace-left" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMinYMin slice"><path d="M0 6l6-6h17c12.688 0 19.313.3 20 1 4 4 7.313 8.3 10 13
 35.313 51.3 80.813 93.8 136.5 127.5 55.688 33.7 117.188 55.8 184.5 66.5.688
 0 2 .3 4 1 18.688 2.7 76 4.3 172 5h399450v120H429l-6-1c-124.688-8-235-61.7
-331-161C60.687 138.7 32.312 99.3 7 54L0 41V6z"></path></svg></span><span class="brace-center" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMidYMin slice"><path d="M199572 214
c100.7 8.3 195.3 44 280 108 55.3 42 101.7 93 139 153l9 14c2.7-4 5.7-8.7 9-14
 53.3-86.7 123.7-153 211-199 66.7-36 137.3-56.3 212-62h199568v120H200432c-178.3
 11.7-311.7 78.3-403 201-6 8-9.7 12-11 12-.7.7-6.7 1-18 1s-17.3-.3-18-1c-1.3 0
-5-4-11-12-44.7-59.3-101.3-106.3-170-141s-145.3-54.3-229-60H0V214z"></path></svg></span><span class="brace-right" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMaxYMin slice"><path d="M399994 0l6 6v35l-6 11c-56 104-135.3 181.3-238 232-57.3
 28.7-117 45-179 50H-300V214h399897c43.3-7 81-15 113-26 100.7-33 179.7-91 237
-174 2.7-5 6-9 10-13 .7-1 7.3-1 20-1h17z"></path></svg></span></span></span><span style="top:-3em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mopen">(</span><span class="mord">1</span><span class="mspace">&nbsp;</span><span class="mord text"><span class="mord">byte</span></span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.898em"><span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.5841em"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.75em"><span style="top:-1.4237em"><span class="pstrut" style="height:3em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight"><span class="mord mtight">activation&nbsp;FP8&nbsp;E4M3</span></span></span></span></span></span><span style="top:-3em"><span class="pstrut" style="height:3em"></span><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.75em"><span class="svg-align" style="top:-2.102em"><span class="pstrut" style="height:3em"></span><span class="stretchy" style="height:0.548em;min-width:1.6em"><span class="brace-left" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMinYMin slice"><path d="M0 6l6-6h17c12.688 0 19.313.3 20 1 4 4 7.313 8.3 10 13
 35.313 51.3 80.813 93.8 136.5 127.5 55.688 33.7 117.188 55.8 184.5 66.5.688
 0 2 .3 4 1 18.688 2.7 76 4.3 172 5h399450v120H429l-6-1c-124.688-8-235-61.7
-331-161C60.687 138.7 32.312 99.3 7 54L0 41V6z"></path></svg></span><span class="brace-center" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMidYMin slice"><path d="M199572 214
c100.7 8.3 195.3 44 280 108 55.3 42 101.7 93 139 153l9 14c2.7-4 5.7-8.7 9-14
 53.3-86.7 123.7-153 211-199 66.7-36 137.3-56.3 212-62h199568v120H200432c-178.3
 11.7-311.7 78.3-403 201-6 8-9.7 12-11 12-.7.7-6.7 1-18 1s-17.3-.3-18-1c-1.3 0
-5-4-11-12-44.7-59.3-101.3-106.3-170-141s-145.3-54.3-229-60H0V214z"></path></svg></span><span class="brace-right" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMaxYMin slice"><path d="M399994 0l6 6v35l-6 11c-56 104-135.3 181.3-238 232-57.3
 28.7-117 45-179 50H-300V214h399897c43.3-7 81-15 113-26 100.7-33 179.7-91 237
-174 2.7-5 6-9 10-13 .7-1 7.3-1 20-1h17z"></path></svg></span></span></span><span style="top:-3em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mopen">(</span><span class="mord">1</span><span class="mspace">&nbsp;</span><span class="mord text"><span class="mord">byte</span></span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.898em"><span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.5763em"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.75em"><span style="top:-1.4159em"><span class="pstrut" style="height:3em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight"><span class="mord mtight">gradients&nbsp;FP8&nbsp;E5M2</span></span></span></span></span></span><span style="top:-3em"><span class="pstrut" style="height:3em"></span><span class="mord munder"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.75em"><span class="svg-align" style="top:-2.102em"><span class="pstrut" style="height:3em"></span><span class="stretchy" style="height:0.548em;min-width:1.6em"><span class="brace-left" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMinYMin slice"><path d="M0 6l6-6h17c12.688 0 19.313.3 20 1 4 4 7.313 8.3 10 13
 35.313 51.3 80.813 93.8 136.5 127.5 55.688 33.7 117.188 55.8 184.5 66.5.688
 0 2 .3 4 1 18.688 2.7 76 4.3 172 5h399450v120H429l-6-1c-124.688-8-235-61.7
-331-161C60.687 138.7 32.312 99.3 7 54L0 41V6z"></path></svg></span><span class="brace-center" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMidYMin slice"><path d="M199572 214
c100.7 8.3 195.3 44 280 108 55.3 42 101.7 93 139 153l9 14c2.7-4 5.7-8.7 9-14
 53.3-86.7 123.7-153 211-199 66.7-36 137.3-56.3 212-62h199568v120H200432c-178.3
 11.7-311.7 78.3-403 201-6 8-9.7 12-11 12-.7.7-6.7 1-18 1s-17.3-.3-18-1c-1.3 0
-5-4-11-12-44.7-59.3-101.3-106.3-170-141s-145.3-54.3-229-60H0V214z"></path></svg></span><span class="brace-right" style="height:0.548em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="0.548em" viewBox="0 0 400000 548" preserveAspectRatio="xMaxYMin slice"><path d="M399994 0l6 6v35l-6 11c-56 104-135.3 181.3-238 232-57.3
 28.7-117 45-179 50H-300V214h399897c43.3-7 81-15 113-26 100.7-33 179.7-91 237
-174 2.7-5 6-9 10-13 .7-1 7.3-1 20-1h17z"></path></svg></span></span></span><span style="top:-3em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mopen">(</span><span class="mord">1</span><span class="mspace">&nbsp;</span><span class="mord text"><span class="mord">byte</span></span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.898em"><span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.7202em"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mord">3</span><span class="mspace">&nbsp;</span><span class="mord text"><span class="mord">bytes&nbsp;(+8&nbsp;bytes&nbsp;for&nbsp;Adam)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.1801em"><span></span></span></span></span></span></span></span></span></span></span></span>
</li>
</ul>
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="computational-efficiency">Computational Efficiency<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#computational-efficiency" class="hash-link" aria-label="Direct link to Computational Efficiency" title="Direct link to Computational Efficiency">​</a></h2>
<h3 class="anchor anchorWithStickyNavbar_LWe7" id="forwardbackward-pass-optimization">Forward/Backward Pass Optimization<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#forwardbackward-pass-optimization" class="hash-link" aria-label="Direct link to Forward/Backward Pass Optimization" title="Direct link to Forward/Backward Pass Optimization">​</a></h3>
<ul>
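<li>set gradients to <code>None</code> instead of <code>0</code> with <code>optimizer.zero_grad(set_to_none=True)</code> (the default since PyTorch 2.0). This can cause some unexpected behavior: <code>None</code> isn't a tensor, so code that reads <code>.grad</code> directly has to check for it, because arithmetic on <code>None</code> raises an error instead of behaving like zero.</li>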
<li>non-global Cross-Entropy calculation reduces memory usage spike at the end (especially beneficial for LLMs). <a href="https://arxiv.org/abs/2411.09009" target="_blank" rel="noopener noreferrer">paper</a></li>
</ul>
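<p>A minimal sketch of what setting gradients to <code>None</code> looks like in practice. The tiny <code>nn.Linear</code> model and the <code>grad_sq_sum</code> helper value are made up for illustration; the point is only that code touching <code>.grad</code> needs a <code>None</code> check.</p>

```python
import torch
from torch import nn

# Hypothetical toy model, only here to have some parameters with gradients.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 4)).sum()
loss.backward()

# Release the gradient tensors instead of filling them with zeros.
optimizer.zero_grad(set_to_none=True)
grads_are_none = all(p.grad is None for p in model.parameters())

# Code that reads .grad directly has to skip parameters without a gradient.
grad_sq_sum = sum(
    float(p.grad.pow(2).sum()) for p in model.parameters() if p.grad is not None
)
```

<p>After the call, every <code>p.grad</code> is <code>None</code> until the next <code>backward()</code>, so the guarded sum above runs over no elements.</p>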
<h3 class="anchor anchorWithStickyNavbar_LWe7" id="hardware-utilization">Hardware Utilization<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#hardware-utilization" class="hash-link" aria-label="Direct link to Hardware Utilization" title="Direct link to Hardware Utilization">​</a></h3>
<ul>
<li>avoid creating tensors on one device and then moving them with <code>.to(device)</code>; create them directly on the target device instead. If nothing later in the code forces an immediate synchronization, you can use <code>.to(device, non_blocking=True)</code> to make the copy asynchronous.</li>
<li>use <code>torch.compile()</code> if it works; for me it usually doesn't.</li>
<li><code>torch.backends.cudnn.benchmark = True</code></li>
</ul>
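<p>A short sketch of the device tips above. It falls back to the CPU when no GPU is present, and the tensor shapes are arbitrary. Pinning the host tensor before the copy is an addition of mine: as far as I know, <code>non_blocking=True</code> only makes host-to-GPU copies asynchronous when the source is in pinned memory.</p>

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Slower pattern: allocate on the CPU, then copy to the target device.
slow = torch.zeros(1024, 1024).to(device)

# Preferred: allocate directly on the target device.
fast = torch.zeros(1024, 1024, device=device)

# For data that has to start on the host, pinned memory lets the copy overlap
# with GPU work when non_blocking=True is passed.
batch = torch.randn(64, 3, 32, 32)
if device.type == "cuda":
    batch = batch.pin_memory()
batch = batch.to(device, non_blocking=True)
```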
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="training-dynamics">Training Dynamics<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#training-dynamics" class="hash-link" aria-label="Direct link to Training Dynamics" title="Direct link to Training Dynamics">​</a></h2>
<h3 class="anchor anchorWithStickyNavbar_LWe7" id="initialization--learning-rate">Initialization &amp; Learning Rate<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#initialization--learning-rate" class="hash-link" aria-label="Direct link to Initialization &amp; Learning Rate" title="Direct link to Initialization &amp; Learning Rate">​</a></h3>
<ul>
<li>init weights properly (relying on the default PyTorch init isn't always optimal)</li>
<li>LR scheduler (OneCycleLR, lr warmup, etc.) and lr search. Similar for BS (batch size warmup etc.)</li>
<li>some rule regarding relation of LR &amp; BS - <a href="https://x.com/cloneofsimo/status/1907731069878825400" target="_blank" rel="noopener noreferrer"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>L</mi><mi>R</mi><mo>≈</mo><msqrt><mrow><mi>B</mi><mi>S</mi></mrow></msqrt></mrow><annotation encoding="application/x-tex">LR \approx \sqrt{BS}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal">L</span><span class="mord mathnormal" style="margin-right:0.00773em">R</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">≈</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:1.04em;vertical-align:-0.1133em"></span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9267em"><span class="svg-align" style="top:-3em"><span class="pstrut" style="height:3em"></span><span class="mord" style="padding-left:0.833em"><span class="mord mathnormal" style="margin-right:0.05764em">BS</span></span></span><span style="top:-2.8867em"><span class="pstrut" style="height:3em"></span><span class="hide-tail" style="min-width:0.853em;height:1.08em"><svg xmlns="http://www.w3.org/2000/svg" width="400em" height="1.08em" viewBox="0 0 400000 1080" preserveAspectRatio="xMinYMin slice"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"></path></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.1133em"><span></span></span></span></span></span></span></span></span></a> (there is also issue of "<a href="https://x.com/SeunghyunSEO7/status/1877188952920125836" target="_blank" rel="noopener noreferrer">critical batch size</a>")</li>
</ul>
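<p>One way to read the LR/BS rule above is as a proportionality: when the batch size changes by a factor <code>k</code>, scale the learning rate by <code>sqrt(k)</code>. The sketch below assumes that reading; the helper name <code>scale_lr_sqrt</code> and the starting values are hypothetical, and the result is a starting point for an LR search rather than a guarantee.</p>

```python
import math

def scale_lr_sqrt(base_lr: float, base_bs: int, new_bs: int) -> float:
    """Scale the learning rate with the square root of the batch-size ratio."""
    return base_lr * math.sqrt(new_bs / base_bs)

# Hypothetical starting point: lr=1e-3 tuned at batch size 64.
# The batch gets 4x larger, so the learning rate gets 2x larger.
lr_at_256 = scale_lr_sqrt(1e-3, 64, 256)
```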
<h3 class="anchor anchorWithStickyNavbar_LWe7" id="batch-processing">Batch Processing<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#batch-processing" class="hash-link" aria-label="Direct link to Batch Processing" title="Direct link to Batch Processing">​</a></h3>
<ul>
<li>max out batch size to fill whole VRAM (<a href="https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.BatchSizeFinder.html" target="_blank" rel="noopener noreferrer">batch size finder</a>)</li>
<li>if you can't fit a batch size large enough for stable training (e.g. bsz=1), use gradient accumulation. In pure PyTorch it boils down to calling <code>opt.step()</code> less frequently, which results in an effectively larger batch size.</li>
</ul>
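<p>Gradient accumulation in pure PyTorch can be sketched roughly like this. The model, the random data, and <code>accum_steps = 4</code> are placeholders; dividing the loss by <code>accum_steps</code> is a common convention so that the accumulated gradient matches the mean over the larger effective batch.</p>

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

accum_steps = 4   # effective batch size = micro-batch size * accum_steps
micro_batches = [(torch.randn(2, 10), torch.randn(2, 1)) for _ in range(8)]

optimizer_steps = 0
optimizer.zero_grad()
for i, (x, y) in enumerate(micro_batches, start=1):
    # Divide so the accumulated gradient is the mean over the effective batch.
    loss = loss_fn(model(x), y) / accum_steps
    loss.backward()   # gradients add up in p.grad across micro-batches
    if i % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        optimizer_steps += 1
```

<p>With 8 micro-batches of size 2 and <code>accum_steps = 4</code>, the optimizer steps twice, each time on gradients from an effective batch of 8 samples.</p>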
<h3 class="anchor anchorWithStickyNavbar_LWe7" id="stability-techniques">Stability Techniques<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#stability-techniques" class="hash-link" aria-label="Direct link to Stability Techniques" title="Direct link to Stability Techniques">​</a></h3>
<ul>
<li>gradient clipping - prevents exploding gradients by capping their magnitude during backpropagation.</li>
<li>weight clipping - a regularization technique, similar in spirit to weight decay, but applied directly to the weights instead of going through the loss calculation.</li>
<li>start from pretrained model (transfer learning) and swap last layer.<!-- -->
<ul>
<li>Freeze the whole model except the last layer and, after a few epochs, gradually unfreeze the rest of the layers (ULMFiT, gradual unfreezing).</li>
</ul>
</li>
<li>constrain the latent space of your model to align with some general pretrained model; this stabilizes training and reduces overfitting (can be done with feature-matching losses).</li>
</ul>
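<p>For gradient clipping, a minimal sketch using <code>torch.nn.utils.clip_grad_norm_</code>. The toy model, the deliberately oversized inputs, and <code>max_norm = 1.0</code> are assumptions for illustration; a sensible threshold depends on the model.</p>

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Large inputs to provoke a large gradient.
x, y = torch.randn(16, 10) * 100, torch.randn(16, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

max_norm = 1.0  # hypothetical threshold
# Rescales all gradients in place so their global L2 norm is at most max_norm;
# the return value is the norm measured before clipping.
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
optimizer.step()
```

<p>The call goes between <code>loss.backward()</code> and <code>optimizer.step()</code>, so the update uses the clipped gradients.</p>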
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="model-specific-optimizations">Model-Specific Optimizations<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#model-specific-optimizations" class="hash-link" aria-label="Direct link to Model-Specific Optimizations" title="Direct link to Model-Specific Optimizations">​</a></h2>
<h3 class="anchor anchorWithStickyNavbar_LWe7" id="llm-techniques">LLM Techniques<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#llm-techniques" class="hash-link" aria-label="Direct link to LLM Techniques" title="Direct link to LLM Techniques">​</a></h3>
<ul>
<li>SuperBPE - groups frequent word sequences into single tokens, improving efficiency and performance. Common word combinations get treated as one unit by the tokenizer, which reduces the number of easy-to-predict sequences. This creates a more balanced prediction difficulty across tokens, allowing the model to distribute computational effort more effectively. <a href="https://x.com/alisawuffles/status/1903125390618661068" target="_blank" rel="noopener noreferrer">Author's explanation</a></li>
</ul>
<h3 class="anchor anchorWithStickyNavbar_LWe7" id="diffusion-model-techniques">Diffusion Model Techniques<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#diffusion-model-techniques" class="hash-link" aria-label="Direct link to Diffusion Model Techniques" title="Direct link to Diffusion Model Techniques">​</a></h3>
<ul>
<li>latent diffusion - diffuse in latent space, not pixel space. VAE encoder <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>→</mo></mrow><annotation encoding="application/x-tex">\rightarrow</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.3669em"></span><span class="mrel">→</span></span></span></span> latent (diffuse <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>N</mi><mo>×</mo></mrow><annotation encoding="application/x-tex">N\times</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7667em;vertical-align:-0.0833em"></span><span class="mord mathnormal" style="margin-right:0.10903em">N</span><span class="mord">×</span></span></span></span>) <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>→</mo></mrow><annotation encoding="application/x-tex">\rightarrow</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.3669em"></span><span class="mrel">→</span></span></span></span> VAE decoder. 
The SD1.5 VAE may be big, but you can use <a href="https://huggingface.co/madebyollin/taesd" target="_blank" rel="noopener noreferrer">TAESD</a> (<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">(</mo><mn>3</mn><mo>⋅</mo><mn>512</mn><mo>⋅</mo><mn>512</mn><mo stretchy="false">)</mo><mi mathvariant="normal">/</mi><mo stretchy="false">(</mo><mn>16</mn><mo>⋅</mo><mn>64</mn><mo>⋅</mo><mn>64</mn><mo stretchy="false">)</mo><mo>=</mo><mn>12</mn><mo>×</mo></mrow><annotation encoding="application/x-tex">(3\cdot512\cdot512)/(16\cdot64\cdot64)=12\times</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mopen">(</span><span class="mord">3</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:0.6444em"></span><span class="mord">512</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord">512</span><span class="mclose">)</span><span class="mord">/</span><span class="mopen">(</span><span class="mord">16</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:0.6444em"></span><span class="mord">64</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord">64</span><span 
class="mclose">)</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:0.7278em;vertical-align:-0.0833em"></span><span class="mord">12</span><span class="mord">×</span></span></span></span> compression, 5MB for enc/dec each and minimal computational overhead)<!-- -->
<ul>
<li>if you want to train on ImageNet: <a href="https://huggingface.co/datasets/fal/cosmos-imagenet" target="_blank" rel="noopener noreferrer">https://huggingface.co/datasets/fal/cosmos-imagenet</a> (compressed to 2.45GB)</li>
</ul>
</li>
<li><a href="https://arxiv.org/abs/2303.09556" target="_blank" rel="noopener noreferrer">Min-SNR</a> - a method that weights the loss based on the SNR (signal-to-noise ratio) of the timestep. It prevents conflicting gradients from different denoising phases (beginning, middle, and final refinements).</li>
</ul>
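<p>A rough sketch of the Min-SNR weighting mentioned above, for a model that predicts the noise. As I understand the paper, the weight is <code>min(SNR, gamma) / SNR</code> with <code>gamma = 5</code> as the suggested default. The linear beta schedule and the function name <code>min_snr_weight</code> are assumptions for illustration.</p>

```python
import torch

def min_snr_weight(alphas_cumprod, timesteps, gamma=5.0):
    """Min-SNR-gamma loss weight for noise (epsilon) prediction."""
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)
    # Cap the SNR at gamma, then normalize by the SNR itself.
    return torch.clamp(snr, max=gamma) / snr

# Hypothetical linear beta schedule with 1000 steps.
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

t = torch.tensor([0, 500, 999])
weights = min_snr_weight(alphas_cumprod, t)
# In a training step the per-sample MSE would be multiplied by these weights
# before averaging, e.g. (weights * per_sample_mse).mean()
```

<p>Low-noise timesteps (very high SNR) get strongly down-weighted, while noisy timesteps with SNR below <code>gamma</code> keep a weight of 1.</p>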
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="implementation-tricks">Implementation Tricks<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#implementation-tricks" class="hash-link" aria-label="Direct link to Implementation Tricks" title="Direct link to Implementation Tricks">​</a></h2>
<h3 class="anchor anchorWithStickyNavbar_LWe7" id="pytorch-data-handling">PyTorch Data Handling<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#pytorch-data-handling" class="hash-link" aria-label="Direct link to PyTorch Data Handling" title="Direct link to PyTorch Data Handling">​</a></h3>
<ul>
<li>use <code>.as_tensor()</code> rather than <code>.tensor()</code>. <code>torch.tensor()</code> always copies data. If you have a numpy array that you want to convert, use <code>torch.as_tensor()</code> or <code>torch.from_numpy()</code> to avoid copying the data.</li>
<li>try the "channels last" memory format for tensors and model (NCHW =&gt; NHWC); sometimes it's faster. <a href="https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html" target="_blank" rel="noopener noreferrer">link</a></li>
<li>slow dataloader optimizations <a href="https://x.com/cloneofsimo/status/1855608988681080910" target="_blank" rel="noopener noreferrer">Simo tweet</a>, <a href="https://discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548/19" target="_blank" rel="noopener noreferrer">PyTorch forum</a></li>
</ul>
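<p>The copy-versus-share difference between <code>torch.tensor()</code>, <code>torch.as_tensor()</code> and <code>torch.from_numpy()</code> can be checked with a small experiment. The array contents are arbitrary; sharing only happens when the dtype and device already match.</p>

```python
import numpy as np
import torch

arr = np.zeros(3, dtype=np.float32)

copied = torch.tensor(arr)           # always allocates new memory
shared = torch.as_tensor(arr)        # reuses the NumPy buffer when dtype/device match
also_shared = torch.from_numpy(arr)  # always shares memory with the array

arr[0] = 1.0  # mutate the source array in place

copied_first = copied[0].item()
shared_first = shared[0].item()
```

<p>The in-place change shows up in <code>shared</code> and <code>also_shared</code> but not in <code>copied</code>, which is also a reminder that shared tensors can change under you if the array is modified elsewhere.</p>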
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="other">Other<a href="https://kiwinicki.github.io/2025/04/04/training-tricks#other" class="hash-link" aria-label="Direct link to Other" title="Direct link to Other">​</a></h2>
<ul>
<li>normalization layers (stabilize and speed up training - you can use a higher LR)</li>
<li>turn off bias in the layer right before BatchNorm (BN already applies its own shift)</li>
<li>MoE - model architecture which selectively activates only part of the model. Memory-computation tradeoff. MoE reaches the same loss faster than dense models under the same computational budget.</li>
<li>MoD (Mixture of Depths) - learned skip connection for each transformer block, model learns to not waste compute on easy tokens.</li>
<li>Stochastic Depth - each layer in a <strong>deep</strong> ConvNet has a survival probability (the chance of not being dropped) that goes from 1.0 (for the first layer) down to 0.5 (for the last layer). Essentially dropout applied to whole layers (prevents vanishing gradients, faster training, better performance)</li>
<li>checkpoint averaging - a weighted average of previous checkpoints makes the loss landscape smoother and more convex, which speeds up training and reduces overfitting (applies to pretraining &amp; finetuning)</li>
</ul>]]></content>
    </entry>
    <entry>
        <title type="html"><![CDATA[GANs]]></title>
        <id>https://kiwinicki.github.io/2025/03/04/gans</id>
        <link href="https://kiwinicki.github.io/2025/03/04/gans"/>
        <updated>2025-03-04T00:00:00.000Z</updated>
        <summary type="html"><![CDATA[Notes about GANs from vanilla GAN to CycleGAN and StyleGAN.]]></summary>
        <content type="html"><![CDATA[<p>Notes about GANs from vanilla GAN to CycleGAN and StyleGAN.</p>
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="vanilla-gan">(Vanilla) GAN<a href="https://kiwinicki.github.io/2025/03/04/gans#vanilla-gan" class="hash-link" aria-label="Direct link to (Vanilla) GAN" title="Direct link to (Vanilla) GAN">​</a></h2>
<p>In short:</p>
<ul>
<li>Generator <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>G</mi></mrow><annotation encoding="application/x-tex">G</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal">G</span></span></span></span> - tries to generate realistic looking images (maximize <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>D</mi><mo stretchy="false">(</mo><mi>G</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">D(G(z))</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord mathnormal" style="margin-right:0.02778em">D</span><span class="mopen">(</span><span class="mord mathnormal">G</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.04398em">z</span><span class="mclose">))</span></span></span></span>).</li>
<li>Discriminator <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>D</mi></mrow><annotation encoding="application/x-tex">D</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span></span></span> - tries to distinguish between real and generated images (minimize <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>D</mi><mo stretchy="false">(</mo><mi>G</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">D(G(z))</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord mathnormal" style="margin-right:0.02778em">D</span><span class="mopen">(</span><span class="mord mathnormal">G</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.04398em">z</span><span class="mclose">))</span></span></span></span> and maximize <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>D</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">D(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord mathnormal" style="margin-right:0.02778em">D</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span>). The discriminator acts as a learned loss function for the generator.</li>
<li>The original GAN uses binary cross-entropy loss.</li>
<li>Discriminator has a sigmoid at the end and outputs values in <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">[</mo><mn>0</mn><mo separator="true">,</mo><mn>1</mn><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">[0, 1]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mopen">[</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em"></span><span class="mord">1</span><span class="mclose">]</span></span></span></span> (0 - fake, 1 - real).</li>
<li>The ideal outcome for the Generator is for the Discriminator to output ~0.5, which means that D can't tell if an image is real or fake (Nash equilibrium).</li>
<li>Loss isn't interpretable because both the Generator and the Discriminator can improve at the same time, so even if G's image quality improves, the loss can stay the same.</li>
</ul>
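<p>To make the points above a bit more concrete, the BCE objectives can be written out by hand. This is only a minimal sketch in plain Python: the helper names (<code>bce</code>, <code>d_loss</code>, <code>g_loss</code>) are made up for this example, and the generator loss uses the common non-saturating form, which is an assumption rather than something stated above.</p>

```python
import math

def bce(pred, target):
    """Binary cross-entropy for a single probability `pred` in (0, 1)."""
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))

def d_loss(d_real, d_fake):
    """Discriminator loss: real images are labelled 1, generated images 0."""
    return bce(d_real, 1.0) + bce(d_fake, 0.0)

def g_loss(d_fake):
    """Non-saturating generator loss: G wants D(G(z)) to be scored as real."""
    return bce(d_fake, 1.0)

# At the equilibrium D outputs 0.5 everywhere, so its loss is 2 * ln(2).
equilibrium = d_loss(0.5, 0.5)
# A confident, correct discriminator has a lower loss than at equilibrium.
confident = d_loss(0.99, 0.01)
```

<p>Note that the discriminator loss sitting near 2·ln(2) says nothing about how good the images are, which is one way to see why the loss isn't interpretable.</p>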
<div class="theme-admonition theme-admonition-note admonition_xJq3 alert alert--secondary"><div class="admonitionHeading_Gvgb"><span class="admonitionIcon_Rf37"><svg viewBox="0 0 14 16"><path fill-rule="evenodd" d="M6.3 5.69a.942.942 0 0 1-.28-.7c0-.28.09-.52.28-.7.19-.18.42-.28.7-.28.28 0 .52.09.7.28.18.19.28.42.28.7 0 .28-.09.52-.28.7a1 1 0 0 1-.7.3c-.28 0-.52-.11-.7-.3zM8 7.99c-.02-.25-.11-.48-.31-.69-.2-.19-.42-.3-.69-.31H6c-.27.02-.48.13-.69.31-.2.2-.3.44-.31.69h1v3c.02.27.11.5.31.69.2.2.42.31.69.31h1c.27 0 .48-.11.69-.31.2-.19.3-.42.31-.69H8V7.98v.01zM7 2.3c-3.14 0-5.7 2.54-5.7 5.68 0 3.14 2.56 5.7 5.7 5.7s5.7-2.55 5.7-5.7c0-3.15-2.56-5.69-5.7-5.69v.01zM7 .98c3.86 0 7 3.14 7 7s-3.14 7-7 7-7-3.12-7-7 3.14-7 7-7z"></path></svg></span>note</div><div class="admonitionContent_BuS1"><p>The original GAN from 2014 used MLP layers instead of Conv layers.</p></div></div>
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="dcgan">DCGAN<a href="https://kiwinicki.github.io/2025/03/04/gans#dcgan" class="hash-link" aria-label="Direct link to DCGAN" title="Direct link to DCGAN">​</a></h2>
<p>The simplest GAN that uses Conv layers instead of FFN layers: it takes a noise vector and gradually upsamples it through ConvTranspose2d layers.</p>
<ul>
<li>Use conv layers instead of FFN layers.</li>
<li>Use conv with a larger stride instead of pooling.</li>
<li>Use batchnorm.</li>
<li>ReLU in <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>G</mi></mrow><annotation encoding="application/x-tex">G</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal">G</span></span></span></span> and LeakyReLU in <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>D</mi></mrow><annotation encoding="application/x-tex">D</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span></span></span>.</li>
<li><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>β</mi><mn>1</mn></msub><mo>=</mo><mn>0.5</mn></mrow><annotation encoding="application/x-tex">\beta_1=0.5</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05278em">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em"><span style="top:-2.55em;margin-left:-0.0528em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:0.6444em"></span><span class="mord">0.5</span></span></span></span> (PyTorch default is <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>0.9</mn></mrow><annotation encoding="application/x-tex">0.9</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6444em"></span><span class="mord">0.9</span></span></span></span>)</li>
</ul>
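<p>The guidelines above could look roughly like this in PyTorch. Treat it as a sketch under assumptions: the channel widths, the 32x32 output size, the <code>Tanh</code> output and the learning rate of 2e-4 are illustrative choices, not something prescribed by the list.</p>

```python
import torch
from torch import nn

class Generator(nn.Module):
    """DCGAN-style generator sketch: noise vector to a 32x32 image."""

    def __init__(self, z_dim=100, ch=64, img_channels=3):
        super().__init__()
        self.net = nn.Sequential(
            # [bs, z_dim, 1, 1] to [bs, ch*4, 4, 4]
            nn.ConvTranspose2d(z_dim, ch * 4, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ch * 4),
            nn.ReLU(inplace=True),
            # 4x4 to 8x8 (strided transposed conv instead of an upsampling layer)
            nn.ConvTranspose2d(ch * 4, ch * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ch * 2),
            nn.ReLU(inplace=True),
            # 8x8 to 16x16
            nn.ConvTranspose2d(ch * 2, ch, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True),
            # 16x16 to 32x32, output squashed into [-1, 1] (assumed output activation)
            nn.ConvTranspose2d(ch, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        # Reshape the flat noise vector into a 1x1 feature map.
        return self.net(z.view(z.size(0), -1, 1, 1))

G = Generator()
# beta1 = 0.5 instead of the PyTorch default of 0.9
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
fake = G(torch.randn(8, 100))
```

<p>The Discriminator would mirror this with strided <code>Conv2d</code> layers and LeakyReLU.</p>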
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="cgan---conditional-gan">cGAN - Conditional GAN<a href="https://kiwinicki.github.io/2025/03/04/gans#cgan---conditional-gan" class="hash-link" aria-label="Direct link to cGAN - Conditional GAN" title="Direct link to cGAN - Conditional GAN">​</a></h2>
<p>Refers to the general idea of adding class information to the GAN.  In the context of the Generator, this usually involves adding labels to the noise (concatenating two vectors). The Discriminator checks if the image is real or fake, having additional information about the class. (The Discriminator returns a scalar for each image).</p>
<p><img decoding="async" loading="lazy" src="https://www.researchgate.net/publication/378656489/figure/fig4/AS:11431281277856805@1726279533229/Conditional-GAN-CGAN-architecture.png" alt="image.png" class="img_ev3q"></p>
<div class="language-python codeBlockContainer_Ckt0 theme-code-block" style="--prism-color:#393A34;--prism-background-color:#f6f8fa"><div class="codeBlockContent_biex"><pre tabindex="0" class="prism-code language-python codeBlock_bY9V thin-scrollbar" style="color:#393A34;background-color:#f6f8fa"><code class="codeBlockLines_e6Vv"><span class="token-line" style="color:#393A34"><span class="token comment" style="color:#999988;font-style:italic"># Adding conditioning in the Generator</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain"></span><span class="token keyword" style="color:#00009f">def</span><span class="token plain"> </span><span class="token function" style="color:#d73a49">forward</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">self</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> x</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> y</span><span class="token punctuation" style="color:#393A34">)</span><span class="token punctuation" style="color:#393A34">:</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    cond </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> self</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">embedding</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">y</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    x </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> torch</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">cat</span><span class="token punctuation" style="color:#393A34">(</span><span class="token punctuation" style="color:#393A34">[</span><span class="token plain">x</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> cond</span><span class="token punctuation" style="color:#393A34">]</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> dim</span><span class="token operator" style="color:#393A34">=</span><span class="token number" style="color:#36acaa">1</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain">  </span><span class="token comment" style="color:#999988;font-style:italic"># x.shape = [bs, z_dim], cond.shape = [bs, embed_dim]</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    </span><span class="token keyword" style="color:#00009f">return</span><span class="token plain"> self</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">backbone</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">x</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain" style="display:inline-block"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain"></span><span class="token comment" style="color:#999988;font-style:italic"># Adding conditioning in the Discriminator</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain"></span><span class="token keyword" style="color:#00009f">def</span><span class="token plain"> </span><span class="token function" style="color:#d73a49">forward</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">self</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> x</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> y</span><span class="token punctuation" style="color:#393A34">)</span><span class="token punctuation" style="color:#393A34">:</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    out </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> self</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">net</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">x</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    cond </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> self</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">embed</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">y</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain">  </span><span class="token comment" style="color:#999988;font-style:italic"># cond must be broadcastable to out's shape</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    </span><span class="token keyword" style="color:#00009f">return</span><span class="token plain"> </span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">out </span><span class="token operator" style="color:#393A34">*</span><span class="token plain"> cond</span><span class="token punctuation" style="color:#393A34">)</span><span class="token punctuation" style="color:#393A34">.</span><span class="token builtin">sum</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">dim</span><span class="token operator" style="color:#393A34">=</span><span class="token punctuation" style="color:#393A34">[</span><span class="token number" style="color:#36acaa">1</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> </span><span class="token number" style="color:#36acaa">2</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> </span><span class="token number" style="color:#36acaa">3</span><span class="token punctuation" style="color:#393A34">]</span><span class="token punctuation" style="color:#393A34">)</span><br></span></code></pre><div class="buttonGroup__atx"><button type="button" aria-label="Copy code to clipboard" title="Copy" class="clean-btn"><span class="copyButtonIcons_eSgA" aria-hidden="true"><svg viewBox="0 0 24 24" class="copyButtonIcon_y97N"><path fill="currentColor" d="M19,21H8V7H19M19,5H8A2,2 0 0,0 6,7V21A2,2 0 0,0 8,23H19A2,2 0 0,0 21,21V7A2,2 0 0,0 19,5M16,1H4A2,2 0 0,0 2,3V17H4V3H16V1Z"></path></svg><svg viewBox="0 0 24 24" class="copyButtonSuccessIcon_LjdS"><path fill="currentColor" d="M21,7L9,19L3.5,13.5L4.91,12.09L9,16.17L19.59,5.59L21,7Z"></path></svg></span></button></div></div></div>
<div class="theme-admonition theme-admonition-note admonition_xJq3 alert alert--secondary"><div class="admonitionHeading_Gvgb"><span class="admonitionIcon_Rf37"><svg viewBox="0 0 14 16"><path fill-rule="evenodd" d="M6.3 5.69a.942.942 0 0 1-.28-.7c0-.28.09-.52.28-.7.19-.18.42-.28.7-.28.28 0 .52.09.7.28.18.19.28.42.28.7 0 .28-.09.52-.28.7a1 1 0 0 1-.7.3c-.28 0-.52-.11-.7-.3zM8 7.99c-.02-.25-.11-.48-.31-.69-.2-.19-.42-.3-.69-.31H6c-.27.02-.48.13-.69.31-.2.2-.3.44-.31.69h1v3c.02.27.11.5.31.69.2.2.42.31.69.31h1c.27 0 .48-.11.69-.31.2-.19.3-.42.31-.69H8V7.98v.01zM7 2.3c-3.14 0-5.7 2.54-5.7 5.68 0 3.14 2.56 5.7 5.7 5.7s5.7-2.55 5.7-5.7c0-3.15-2.56-5.69-5.7-5.69v.01zM7 .98c3.86 0 7 3.14 7 7s-3.14 7-7 7-7-3.12-7-7 3.14-7 7-7z"></path></svg></span>note</div><div class="admonitionContent_BuS1"><p>There is also <a href="https://paperswithcode.com/method/projection-discriminator" target="_blank" rel="noopener noreferrer"><strong>Projection Discriminator</strong></a> as a better conditioning method.</p></div></div>
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="acgan-auxiliary-classifier-gan">acGAN (Auxiliary Classifier GAN)<a href="https://kiwinicki.github.io/2025/03/04/gans#acgan-auxiliary-classifier-gan" class="hash-link" aria-label="Direct link to acGAN (Auxiliary Classifier GAN)" title="Direct link to acGAN (Auxiliary Classifier GAN)">​</a></h2>
<p>In acGAN, the Discriminator has two outputs: one for real/fake (as in cGAN) and one that predicts the class. The Discriminator doesn't receive the class as input because it has to predict it.</p>
<div class="language-python codeBlockContainer_Ckt0 theme-code-block" style="--prism-color:#393A34;--prism-background-color:#f6f8fa"><div class="codeBlockContent_biex"><pre tabindex="0" class="prism-code language-python codeBlock_bY9V thin-scrollbar" style="color:#393A34;background-color:#f6f8fa"><code class="codeBlockLines_e6Vv"><span class="token-line" style="color:#393A34"><span class="token comment" style="color:#999988;font-style:italic"># Discriminator (Generator remains the same as in cGAN)</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain"></span><span class="token keyword" style="color:#00009f">def</span><span class="token plain"> </span><span class="token function" style="color:#d73a49">forward</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">self</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> img</span><span class="token punctuation" style="color:#393A34">)</span><span class="token punctuation" style="color:#393A34">:</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    features </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> self</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">model</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">img</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    validity </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> self</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">real_fake_head</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">features</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    class_logits </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> self</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">class_head</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">features</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    </span><span class="token keyword" style="color:#00009f">return</span><span class="token plain"> validity</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> class_logits</span><br></span><span class="token-line" style="color:#393A34"><span class="token plain" style="display:inline-block"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">loss_fn_bce </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> nn</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">BCELoss</span><span class="token punctuation" style="color:#393A34">(</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain">  </span><span class="token comment" style="color:#999988;font-style:italic"># For real/fake (expects sigmoid outputs)</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">loss_fn_ce </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> nn</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">CrossEntropyLoss</span><span class="token punctuation" style="color:#393A34">(</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain">  </span><span class="token comment" style="color:#999988;font-style:italic"># For classification</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain" style="display:inline-block"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain"></span><span class="token keyword" style="color:#00009f">for</span><span class="token plain"> i </span><span class="token keyword" style="color:#00009f">in</span><span class="token plain"> </span><span class="token builtin">range</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">epochs</span><span class="token punctuation" style="color:#393A34">)</span><span class="token punctuation" style="color:#393A34">:</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    </span><span class="token comment" style="color:#999988;font-style:italic"># ...</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    </span><span class="token comment" style="color:#999988;font-style:italic"># Discriminator</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    d_loss_real </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> loss_fn_bce</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">real_validity</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> torch</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">ones_like</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">real_validity</span><span class="token punctuation" style="color:#393A34">)</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain"> </span><span class="token operator" style="color:#393A34">+</span><span class="token plain"> \</span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">                  loss_fn_ce</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">real_class_logits</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> real_labels</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    d_loss_fake </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> loss_fn_bce</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">fake_validity</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> torch</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">zeros_like</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">fake_validity</span><span class="token punctuation" style="color:#393A34">)</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain"> </span><span class="token operator" style="color:#393A34">+</span><span class="token plain"> \</span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">                  loss_fn_ce</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">fake_class_logits</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> labels</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    d_loss </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> d_loss_real </span><span class="token operator" style="color:#393A34">+</span><span class="token plain"> d_loss_fake</span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    </span><span class="token comment" style="color:#999988;font-style:italic"># ...</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    </span><span class="token comment" style="color:#999988;font-style:italic"># Generator</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    g_loss_validity </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> loss_fn_bce</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">fake_validity</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> torch</span><span class="token punctuation" style="color:#393A34">.</span><span class="token plain">ones_like</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">fake_validity</span><span class="token punctuation" style="color:#393A34">)</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    g_loss_class </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> loss_fn_ce</span><span class="token punctuation" style="color:#393A34">(</span><span class="token plain">fake_class_logits</span><span class="token punctuation" style="color:#393A34">,</span><span class="token plain"> labels</span><span class="token punctuation" style="color:#393A34">)</span><span class="token plain"></span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    g_loss </span><span class="token operator" style="color:#393A34">=</span><span class="token plain"> g_loss_validity </span><span class="token operator" style="color:#393A34">+</span><span class="token plain"> g_loss_class</span><br></span><span class="token-line" style="color:#393A34"><span class="token plain">    </span><span class="token comment" style="color:#999988;font-style:italic"># ...</span><br></span></code></pre><div class="buttonGroup__atx"><button type="button" aria-label="Copy code to clipboard" title="Copy" class="clean-btn"><span class="copyButtonIcons_eSgA" aria-hidden="true"><svg viewBox="0 0 24 24" class="copyButtonIcon_y97N"><path fill="currentColor" d="M19,21H8V7H19M19,5H8A2,2 0 0,0 6,7V21A2,2 0 0,0 8,23H19A2,2 0 0,0 21,21V7A2,2 0 0,0 19,5M16,1H4A2,2 0 0,0 2,3V17H4V3H16V1Z"></path></svg><svg viewBox="0 0 24 24" class="copyButtonSuccessIcon_LjdS"><path fill="currentColor" d="M21,7L9,19L3.5,13.5L4.91,12.09L9,16.17L19.59,5.59L21,7Z"></path></svg></span></button></div></div></div>
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="wgan--wgan-gp-wasserstein-gan-with-gradient-penalty">WGAN &amp; WGAN-GP (Wasserstein GAN with Gradient Penalty)<a href="https://kiwinicki.github.io/2025/03/04/gans#wgan--wgan-gp-wasserstein-gan-with-gradient-penalty" class="hash-link" aria-label="Direct link to WGAN &amp; WGAN-GP (Wasserstein GAN with Gradient Penalty)" title="Direct link to WGAN &amp; WGAN-GP (Wasserstein GAN with Gradient Penalty)">​</a></h2>
<p>Replaces the JS divergence (from the vanilla GAN) with the Earth Mover's distance (aka Wasserstein distance).</p>
<ul>
<li>JS divergence behaves poorly if two distributions have little or no overlap (leading to mode collapse and vanishing gradients).</li>
<li>Wasserstein distance measures the "cost" of transforming one distribution into another, works even if there is no overlap between distributions, and gives meaningful values (a decreasing critic loss indicates that the Generator is improving in quality).</li>
</ul>
<p>In the original WGAN (without GP), <strong>weights</strong> are clipped (not the gradient) to ensure Lipschitz continuity.  The parameter <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>c</mi></mrow><annotation encoding="application/x-tex">c</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em"></span><span class="mord mathnormal">c</span></span></span></span> decides how strong the clipping is.</p>
<p>The purpose of the gradient penalty is to keep the critic in the space of 1-Lipschitz functions, so that the critic provides high-quality gradients and the Wasserstein distance is estimated better. (GP only for the critic).</p>
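<p>Both ways of enforcing the Lipschitz constraint can be sketched in a few lines of PyTorch. The function names (<code>clip_weights</code>, <code>gradient_penalty</code>), the toy linear critic, and the values <code>c=0.01</code> and <code>lambda_gp=10</code> are assumptions made for this illustration; tune them for a real model.</p>

```python
import torch
from torch import nn

def clip_weights(critic, c=0.01):
    """Original WGAN: clamp every critic weight into [-c, c] after each update."""
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-c, c)

def gradient_penalty(critic, real, fake):
    """WGAN-GP: push the critic's gradient norm towards 1 at interpolated samples."""
    bs = real.size(0)
    # One interpolation coefficient per sample, broadcast over the other dims.
    eps = torch.rand(bs, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )
    grad_norm = grads.reshape(bs, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()

# Toy critic: no sigmoid at the end, it outputs an unbounded score.
critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
real, fake = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)

lambda_gp = 10.0  # assumed penalty weight
gp = gradient_penalty(critic, real, fake)
d_loss = critic(fake).mean() - critic(real).mean() + lambda_gp * gp
d_loss.backward()
```

<p>In practice you would use one or the other: <code>clip_weights</code> after each critic step for plain WGAN, or the penalty term in the critic loss for WGAN-GP.</p>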
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="lsgan-least-squares-gan">LSGAN (Least Squares GAN)<a href="https://kiwinicki.github.io/2025/03/04/gans#lsgan-least-squares-gan" class="hash-link" aria-label="Direct link to LSGAN (Least Squares GAN)" title="Direct link to LSGAN (Least Squares GAN)">​</a></h2>
<p>Uses <strong>L2 loss</strong> instead of <strong>BCE loss</strong>:</p>
<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mtable rowspacing="0.25em" columnalign="right left" columnspacing="0em"><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mi>l</mi><mi>o</mi><mi>s</mi><msub><mi>s</mi><mrow><mi>f</mi><mi>a</mi><mi>k</mi><mi>e</mi></mrow></msub><mo>=</mo><mo stretchy="false">(</mo><mi>D</mi><mo stretchy="false">(</mo><mi>G</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo><mo>−</mo><mn>0</mn><msup><mo stretchy="false">)</mo><mn>2</mn></msup></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mi>l</mi><mi>o</mi><mi>s</mi><msub><mi>s</mi><mrow><mi>r</mi><mi>e</mi><mi>a</mi><mi>l</mi></mrow></msub><mo>=</mo><mo stretchy="false">(</mo><mi>D</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo>−</mo><mn>1</mn><msup><mo stretchy="false">)</mo><mn>2</mn></msup></mrow></mstyle></mtd></mtr></mtable><annotation encoding="application/x-tex">\begin{align*} &amp; loss_{fake} = (D(G(z)) - 0)^2 \\ &amp; loss_{real} = (D(x) - 1)^2 \end{align*}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:3.0482em;vertical-align:-1.2741em"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.7741em"><span style="top:-3.7741em"><span class="pstrut" style="height:2.8641em"></span><span class="mord"></span></span><span style="top:-2.25em"><span class="pstrut" style="height:2.8641em"></span><span class="mord"></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:1.2741em"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.7741em"><span style="top:-3.91em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mord"></span><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="mord mathnormal">os</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.10764em">f</span><span class="mord mathnormal mtight" style="margin-right:0.03148em">ak</span><span class="mord mathnormal mtight">e</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em">D</span><span class="mopen">(</span><span class="mord mathnormal">G</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.04398em">z</span><span class="mclose">))</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mord">0</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8641em"><span 
style="top:-3.113em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span><span style="top:-2.3859em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mord"></span><span class="mord mathnormal" style="margin-right:0.01968em">l</span><span class="mord mathnormal">os</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">re</span><span class="mord mathnormal mtight">a</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em">D</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mord">1</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8641em"><span style="top:-3.113em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord 
mtight">2</span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.2741em"><span></span></span></span></span></span></span></span></span></span></span></span>
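<p>A minimal sketch of the two squared-error terms above, in plain Python with no framework. The function name <code>lsgan_d_loss</code>, the batch averaging, and the sample scores are illustrative assumptions, not something the formulas prescribe.</p>

```python
def lsgan_d_loss(d_real, d_fake):
    """Squared-error discriminator terms, averaged over a batch (assumed).

    d_real: D(x) scores for real samples, pulled towards 1.
    d_fake: D(G(z)) scores for generated samples, pulled towards 0.
    """
    loss_real = sum((s - 1.0) ** 2 for s in d_real) / len(d_real)
    loss_fake = sum((s - 0.0) ** 2 for s in d_fake) / len(d_fake)
    return loss_real, loss_fake


# Hypothetical scores: D is close to the target on real data, further off on fakes.
loss_real, loss_fake = lsgan_d_loss(d_real=[0.9, 1.1], d_fake=[0.5, -0.5])
```

<p>With these made-up scores the fake term dominates, so most of the gradient would come from the generated samples.</p>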
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="progressive-growing-gan">Progressive Growing GAN<a href="https://kiwinicki.github.io/2025/03/04/gans#progressive-growing-gan" class="hash-link" aria-label="Direct link to Progressive Growing GAN" title="Direct link to Progressive Growing GAN">​</a></h2>
<p>Training starts at a low resolution, and new layers that increase the resolution are added periodically (to both the generator and the discriminator). Advantages:</p>
<ul>
<li>Training stability: Starting with small images makes the model learn the most important features of the image (shapes, colors) first, before textures and background details. This reduces the risk of training diverging.</li>
<li>Training speed: Training on small images is much cheaper, and the model quickly adapts to a higher resolution.</li>
<li>Better image quality.</li>
</ul>
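<p>The "add a layer periodically" idea can be sketched as a simple resolution schedule. This is only an illustration: the function name <code>resolution_at</code> and all default values are hypothetical, and the sketch covers only the schedule, not how new layers are blended in.</p>

```python
def resolution_at(step, start_res=4, max_res=1024, steps_per_stage=10_000):
    """Return the training resolution for a given step.

    The resolution doubles every `steps_per_stage` steps until `max_res`.
    All defaults are illustrative assumptions, not values from the post.
    """
    stage = step // steps_per_stage
    return min(start_res * 2 ** stage, max_res)


# Resolution at a few hypothetical points in training.
schedule = [resolution_at(s) for s in (0, 10_000, 25_000, 200_000)]
```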
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="pix2pix">Pix2Pix<a href="https://kiwinicki.github.io/2025/03/04/gans#pix2pix" class="hash-link" aria-label="Direct link to Pix2Pix" title="Direct link to Pix2Pix">​</a></h2>
<p>An architecture created for paired img2img translation (e.g., sketch → image). The most important features are the use of <strong>adversarial loss + L1 loss</strong> (L1 was chosen over L2 because it encourages less blurring) and the introduction of <strong>PatchGAN</strong>: the discriminator scores each image tile instead of producing a single scalar for the entire image, which gives G more detailed, local feedback.</p>
<p><strong>Loss</strong></p>
<ul>
<li>L1 keeps the output faithful to the target's overall layout, but on its own gives blurry results (low-freq structure).</li>
<li>The adversarial loss is responsible for sharp edges and fine detail (high-freq structure).</li>
</ul>
<p><strong>Model structure and other details</strong></p>
<ul>
<li>UNet for the generator.</li>
<li>PatchGAN for the discriminator: it outputs a real/fake score per image tile.</li>
<li>Block structure: Conv + BatchNorm + ReLU.</li>
<li>Dropout provides output diversity: it stays <strong>enabled</strong> during inference, and there is no noise vector as a source of randomness.</li>
<li>D is trained about 2x slower than G (the paper divides D's objective by 2).</li>
<li>Adam with b1=0.5 and b2=0.999.</li>
</ul>
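<p>A rough sketch of how the two loss terms could be combined for the generator, in plain Python. The function name <code>pix2pix_g_loss</code>, the non-saturating form of the adversarial term, and the default <code>l1_weight=100.0</code> are assumptions made for this example; D's output is treated as a flat list of per-tile probabilities.</p>

```python
import math


def pix2pix_g_loss(d_patch_probs, fake, target, l1_weight=100.0):
    """Generator loss: adversarial term over PatchGAN tiles plus weighted L1.

    d_patch_probs: D's per-tile probabilities that the generated image is real.
    fake, target: flattened pixel values of the generated and ground-truth images.
    l1_weight: weighting of the L1 term (assumed default).
    """
    # Non-saturating adversarial loss, averaged over all tiles.
    adv = -sum(math.log(p) for p in d_patch_probs) / len(d_patch_probs)
    # Mean absolute error between generated and target pixels.
    l1 = sum(abs(f - t) for f, t in zip(fake, target)) / len(target)
    return adv + l1_weight * l1


# Hypothetical values: D is undecided on every tile, pixels are slightly off.
loss = pix2pix_g_loss(d_patch_probs=[0.5, 0.5, 0.5, 0.5],
                      fake=[0.2, 0.4], target=[0.0, 0.5])
```

<p>Because every tile contributes its own term, a single unrealistic region is penalized even when the rest of the image looks fine.</p>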
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="cyclegan">CycleGAN<a href="https://kiwinicki.github.io/2025/03/04/gans#cyclegan" class="hash-link" aria-label="Direct link to CycleGAN" title="Direct link to CycleGAN">​</a></h2>
<p>The CycleGAN architecture was made for unpaired image translation such as style transfer: it needs no paired examples of the mapping <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">(</mo><mi>A</mi><mo>→</mo><mi>B</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">(A \rightarrow B)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mopen">(</span><span class="mord mathnormal">A</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord mathnormal" style="margin-right:0.05017em">B</span><span class="mclose">)</span></span></span></span>. To achieve this, it uses a combination of 3 losses:</p>
<ol>
<li>Cycle-consistency loss</li>
</ol>
<p>An image is translated into the other domain and back again, and the reconstruction is compared with the original (the circle/cycle closes).</p>
<span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mtable rowspacing="0.25em" columnalign="right left" columnspacing="0em"><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mi>A</mi><mo>→</mo><msub><mi>G</mi><mn>1</mn></msub><mo stretchy="false">(</mo><mi>A</mi><mo stretchy="false">)</mo><mo>→</mo><msup><mi>B</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mo>→</mo><msub><mi>G</mi><mn>2</mn></msub><mo stretchy="false">(</mo><msup><mi>B</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mo stretchy="false">)</mo><mo>→</mo><msup><mi>A</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><msub><mtext>CycleLoss</mtext><mn>1</mn></msub><mo>=</mo><mi mathvariant="normal">∣</mi><mi mathvariant="normal">∣</mi><mi>A</mi><mo>−</mo><msup><mi>A</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mi mathvariant="normal">∣</mi><mi mathvariant="normal">∣</mi></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><mi>B</mi><mo>→</mo><msub><mi>G</mi><mn>2</mn></msub><mo stretchy="false">(</mo><mi>B</mi><mo stretchy="false">)</mo><mo>→</mo><msup><mi>A</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mo>→</mo><msub><mi>G</mi><mn>1</mn></msub><mo stretchy="false">(</mo><msup><mi>A</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mo stretchy="false">)</mo><mo>→</mo><msup><mi>B</mi><mo 
mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><msub><mtext>CycleLoss</mtext><mn>2</mn></msub><mo>=</mo><mi mathvariant="normal">∣</mi><mi mathvariant="normal">∣</mi><mi>B</mi><mo>−</mo><msup><mi>B</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mi mathvariant="normal">∣</mi><mi mathvariant="normal">∣</mi></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow></mrow></mstyle></mtd><mtd><mstyle scriptlevel="0" displaystyle="true"><mrow><mrow></mrow><msub><mtext>CycleLoss</mtext><mrow><mi>t</mi><mi>o</mi><mi>t</mi><mi>a</mi><mi>l</mi></mrow></msub><mo>=</mo><msub><mtext>CycleLoss</mtext><mn>1</mn></msub><mo>+</mo><msub><mtext>CycleLoss</mtext><mn>2</mn></msub></mrow></mstyle></mtd></mtr></mtable><annotation encoding="application/x-tex">\begin{align*} &amp;A \to G_1(A) \to B' \to G_2(B') \to A' \\ &amp;\text{CycleLoss}_1 = ||A - A'|| \\ \\ &amp;B \to G_2(B) \to A' \to G_1(A') \to B' \\ &amp;\text{CycleLoss}_2 = ||B - B'|| \\ \\ &amp;\text{CycleLoss}_{total} = \text{CycleLoss}_1 + \text{CycleLoss}_2 \end{align*}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:10.5em;vertical-align:-5em"></span><span class="mord"><span class="mtable"><span class="col-align-r"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:5.5em"><span style="top:-7.5em"><span class="pstrut" style="height:2.84em"></span><span class="mord"></span></span><span style="top:-6em"><span class="pstrut" style="height:2.84em"></span><span class="mord"></span></span><span style="top:-4.5em"><span class="pstrut" style="height:2.84em"></span><span 
class="mord"></span></span><span style="top:-3em"><span class="pstrut" style="height:2.84em"></span><span class="mord"></span></span><span style="top:-1.5em"><span class="pstrut" style="height:2.84em"></span><span class="mord"></span></span><span style="top:0em"><span class="pstrut" style="height:2.84em"></span><span class="mord"></span></span><span style="top:1.5em"><span class="pstrut" style="height:2.84em"></span><span class="mord"></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:5em"><span></span></span></span></span></span><span class="col-align-l"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:5.5em"><span style="top:-7.66em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mord"></span><span class="mord mathnormal">A</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mord"><span class="mord mathnormal">G</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">A</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05017em">B</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8019em"><span 
style="top:-3.113em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mord"><span class="mord mathnormal">G</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05017em">B</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8019em"><span style="top:-3.113em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8019em"><span style="top:-3.113em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord 
mtight">′</span></span></span></span></span></span></span></span></span></span></span><span style="top:-6.16em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mord"></span><span class="mord"><span class="mord text"><span class="mord">CycleLoss</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.207em"><span style="top:-2.4559em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2441em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mord">∣∣</span><span class="mord mathnormal">A</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8019em"><span style="top:-3.113em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mord">∣∣</span></span></span><span style="top:-3.16em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mord"></span><span class="mord mathnormal" style="margin-right:0.05017em">B</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mord"><span class="mord mathnormal">G</span><span class="msupsub"><span 
class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.05017em">B</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8019em"><span style="top:-3.113em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mord"><span class="mord mathnormal">G</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3011em"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" 
style="height:0.8019em"><span style="top:-3.113em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">→</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.05017em">B</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8019em"><span style="top:-3.113em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span></span></span><span style="top:-1.66em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mord"></span><span class="mord"><span class="mord text"><span class="mord">CycleLoss</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.207em"><span style="top:-2.4559em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2441em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mord">∣∣</span><span class="mord mathnormal" style="margin-right:0.05017em">B</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em"></span><span 
class="mord"><span class="mord mathnormal" style="margin-right:0.05017em">B</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8019em"><span style="top:-3.113em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mord">∣∣</span></span></span><span style="top:1.34em"><span class="pstrut" style="height:3em"></span><span class="mord"><span class="mord"></span><span class="mord"><span class="mord text"><span class="mord">CycleLoss</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.242em"><span style="top:-2.4559em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mord mathnormal mtight">o</span><span class="mord mathnormal mtight">t</span><span class="mord mathnormal mtight">a</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2441em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mord"><span class="mord text"><span class="mord">CycleLoss</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.207em"><span style="top:-2.4559em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2441em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mord"><span class="mord text"><span class="mord">CycleLoss</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.207em"><span style="top:-2.4559em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2441em"><span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:5em"><span></span></span></span></span></span></span></span></span></span></span></span>
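<p>The cycle above can be sketched in plain Python. The formulas leave the norm unspecified; the mean absolute difference used here is an assumption, and the toy "generators" are hypothetical stand-ins for real networks.</p>

```python
def l1_distance(a, b):
    """Mean absolute difference between two flattened images (assumed norm)."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def cycle_loss(real_a, real_b, g1, g2):
    """CycleLoss_total = ||A - G2(G1(A))|| + ||B - G1(G2(B))||."""
    rec_a = g2(g1(real_a))  # A to B' and back to A'
    rec_b = g1(g2(real_b))  # B to A' and back to B'
    return l1_distance(real_a, rec_a) + l1_distance(real_b, rec_b)


# Toy generators (hypothetical): G1 brightens by 1, G2 darkens by 0.5,
# so every round trip ends up 0.5 away from the original per pixel.
g1 = lambda img: [p + 1.0 for p in img]
g2 = lambda img: [p - 0.5 for p in img]
total = cycle_loss([0.0, 1.0], [2.0, 3.0], g1, g2)
```

<p>The loss is zero only when G2 undoes G1 and vice versa, which is what ties the two unpaired domains together.</p>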
<ol start="2">
<li>Adversarial loss</li>
</ol>
<p>The usual loss from the discriminator (each G has its own D).</p>
<ol start="3">
<li>Identity loss</li>
</ol>
<p>The generator is sometimes fed an image that already belongs to its target domain. In that case it shouldn't change anything and should just return the image (a form of regularization).</p>
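<p>A minimal sketch of that idea for the A to B generator, in plain Python. The helper name <code>identity_loss</code>, the use of a mean absolute difference, and the toy generators are assumptions for illustration.</p>

```python
def identity_loss(real_b, g1):
    """Penalize G1 (A to B) for changing an image that is already in domain B."""
    same_b = g1(real_b)
    return sum(abs(x - y) for x, y in zip(real_b, same_b)) / len(real_b)


# Toy generators (hypothetical): one leaves B images untouched, one shifts every pixel.
keeps_b = lambda img: list(img)
shifts_b = lambda img: [p + 0.25 for p in img]

loss_ok = identity_loss([0.5, 1.0], keeps_b)
loss_bad = identity_loss([0.5, 1.0], shifts_b)
```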
<p><strong>Features</strong></p>
<ul>
<li>2 generators and 2 discriminators.</li>
<li>Doesn't need a dataset consisting of image pairs (it can be 2 arbitrary datasets from different domains, e.g., photos and paintings).</li>
<li>Generators are <strong>autoencoder-style networks, not UNets (unlike Pix2Pix)</strong>.</li>
<li>Uses <a href="https://www.notion.so/GAN-DCGAN-WGAN-gradient-penalty-CycleGAN-14e885efc8fe806998d6f2a6666cb1d2?pvs=21" target="_blank" rel="noopener noreferrer">PatchGAN from Pix2Pix</a> as the discriminator.</li>
</ul>
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="r3gan">R3GAN<a href="https://kiwinicki.github.io/2025/03/04/gans#r3gan" class="hash-link" aria-label="Direct link to R3GAN" title="Direct link to R3GAN">​</a></h2>
<p><strong>Summary</strong></p>
<p>The authors developed a new loss function that is stable and has mathematical guarantees of convergence. With stable training achieved, they start from StyleGAN2, remove all the now-superfluous tricks, and modernize the architecture. Ultimately, they arrive at a simpler, stable GAN architecture that rivals previous GANs and diffusion-based SOTA models.</p>
<p><strong>Details</strong></p>
<ol>
<li>
<p><strong>Re-formulated Loss:</strong> Instead of computing the real and fake terms separately, they are coupled inside a single term.</p>
<ul>
<li>
<p>Traditional GAN Loss:</p>
<p><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>L</mi><mo stretchy="false">(</mo><mi>θ</mi><mo separator="true">,</mo><mi>ψ</mi><mo stretchy="false">)</mo><mo>=</mo><msub><mi mathvariant="double-struck">E</mi><mrow><mi>z</mi><mo>∼</mo><msub><mi>p</mi><mi>z</mi></msub></mrow></msub><mrow><mo fence="true">[</mo><mi>f</mi><mrow><mo fence="true">(</mo><msub><mi>D</mi><mi>ψ</mi></msub><mo stretchy="false">(</mo><msub><mi>G</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo><mo fence="true">)</mo></mrow><mo fence="true">]</mo></mrow><mo>+</mo><msub><mi mathvariant="double-struck">E</mi><mrow><mi>x</mi><mo>∼</mo><msub><mi>p</mi><mi>D</mi></msub></mrow></msub><mrow><mo fence="true">[</mo><mi>f</mi><mrow><mo fence="true">(</mo><mo>−</mo><msub><mi>D</mi><mi>ψ</mi></msub><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo fence="true">)</mo></mrow><mo fence="true">]</mo></mrow></mrow><annotation encoding="application/x-tex">L(\theta, \psi) = \mathbb{E}_{z \sim p_z} \left[ f \left( D_\psi(G_\theta(z)) \right) \right] + \mathbb{E}_{x \sim p_D} \left[ f \left( -D_\psi(x) \right) \right]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord mathnormal">L</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em"></span><span class="mord mathnormal" style="margin-right:0.03588em">ψ</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:1.0361em;vertical-align:-0.2861em"></span><span class="mord"><span 
class="mord mathbb">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span><span class="mrel mtight">∼</span><span class="mord mtight"><span class="mord mathnormal mtight">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1645em"><span style="top:-2.357em;margin-left:0em;margin-right:0.0714em"><span class="pstrut" style="height:2.5em"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em"></span><span class="minner"><span class="mopen delimcenter" style="top:0em">[</span><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="mspace" style="margin-right:0.1667em"></span><span class="minner"><span class="mopen delimcenter" style="top:0em">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em">D</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" 
style="margin-right:0.03588em">ψ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">G</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.04398em">z</span><span class="mclose">))</span><span class="mclose delimcenter" style="top:0em">)</span></span><span class="mclose delimcenter" style="top:0em">]</span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1.0361em;vertical-align:-0.2861em"></span><span class="mord"><span class="mord mathbb">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mrel mtight">∼</span><span class="mord mtight"><span class="mord mathnormal mtight">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em"><span 
style="top:-2.3567em;margin-left:0em;margin-right:0.0714em"><span class="pstrut" style="height:2.5em"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em">D</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.1433em"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em"></span><span class="minner"><span class="mopen delimcenter" style="top:0em">[</span><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="mspace" style="margin-right:0.1667em"></span><span class="minner"><span class="mopen delimcenter" style="top:0em">(</span><span class="mord">−</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em">D</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em">ψ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose delimcenter" style="top:0em">)</span></span><span class="mclose delimcenter" style="top:0em">]</span></span></span></span></span>  <strong>G and D are calculated separately</strong>, causing real and fake samples to be far apart (D pushes them apart).</p>
</li>
<li>
<p>"Relativistic" GAN Loss:</p>
<p><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>L</mi><mo stretchy="false">(</mo><mi>θ</mi><mo separator="true">,</mo><mi>ψ</mi><mo stretchy="false">)</mo><mo>=</mo><msub><mi mathvariant="double-struck">E</mi><mrow><mi>z</mi><mo>∼</mo><msub><mi>p</mi><mi>z</mi></msub><mo separator="true">,</mo><mtext>&nbsp;</mtext><mi>x</mi><mo>∼</mo><msub><mi>p</mi><mi>D</mi></msub></mrow></msub><mrow><mo fence="true">[</mo><mi>f</mi><mrow><mo fence="true">(</mo><msub><mi>D</mi><mi>ψ</mi></msub><mo stretchy="false">(</mo><msub><mi>G</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo><mo>−</mo><msub><mi>D</mi><mi>ψ</mi></msub><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo fence="true">)</mo></mrow><mo fence="true">]</mo></mrow></mrow><annotation encoding="application/x-tex">L(\theta, \psi) = \mathbb{E}_{z \sim p_z, \ x \sim p_D} \left[ f \left( D_\psi(G_\theta(z)) - D_\psi(x) \right) \right]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord mathnormal">L</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em">θ</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em"></span><span class="mord mathnormal" style="margin-right:0.03588em">ψ</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:1.0361em;vertical-align:-0.2861em"></span><span class="mord"><span class="mord mathbb">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span 
style="top:-2.55em;margin-left:0em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span><span class="mrel mtight">∼</span><span class="mord mtight"><span class="mord mathnormal mtight">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1645em"><span style="top:-2.357em;margin-left:0em;margin-right:0.0714em"><span class="pstrut" style="height:2.5em"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.143em"><span></span></span></span></span></span></span><span class="mpunct mtight">,</span><span class="mspace mtight"><span class="mtight">&nbsp;</span></span><span class="mord mathnormal mtight">x</span><span class="mrel mtight">∼</span><span class="mord mtight"><span class="mord mathnormal mtight">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em"><span style="top:-2.3567em;margin-left:0em;margin-right:0.0714em"><span class="pstrut" style="height:2.5em"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em">D</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.1433em"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em"></span><span class="minner"><span class="mopen delimcenter" style="top:0em">[</span><span class="mord 
mathnormal" style="margin-right:0.10764em">f</span><span class="mspace" style="margin-right:0.1667em"></span><span class="minner"><span class="mopen delimcenter" style="top:0em">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em">D</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em">ψ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">G</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.04398em">z</span><span class="mclose">))</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em">D</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em"><span class="pstrut" 
style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em">ψ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose delimcenter" style="top:0em">)</span></span><span class="mclose delimcenter" style="top:0em">]</span></span></span></span></span>  Here, <strong>G and D are linked</strong>. D evaluates the "realness" <strong>relative</strong> to real samples, so fake samples should always be in the vicinity of real samples.</p>
</li>
</ul>
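<p>The difference between the two losses above can be sketched in plain Python. This is a minimal illustration, not a training-ready implementation: the choice f(t) = -log(1 + exp(-t)) and the function names are assumptions, and the discriminator scores are passed in as plain lists of numbers.</p>

```python
import math

def f(t):
    # Assumed choice: f(t) = -log(1 + exp(-t)), written in a numerically stable form.
    return -(max(-t, 0.0) + math.log1p(math.exp(-abs(t))))

def standard_loss(d_fake, d_real):
    # Mean of f(D(G(z))) plus mean of f(-D(x)): the two terms never interact.
    fake_term = sum(f(s) for s in d_fake) / len(d_fake)
    real_term = sum(f(-s) for s in d_real) / len(d_real)
    return fake_term + real_term

def relativistic_loss(d_fake, d_real):
    # Mean of f(D(G(z)) - D(x)) over paired samples: only the score difference matters.
    pairs = list(zip(d_fake, d_real))
    return sum(f(df - dr) for df, dr in pairs) / len(pairs)
```

<p>Adding the same constant to every discriminator score leaves <code>relativistic_loss</code> unchanged but changes <code>standard_loss</code>, which is one way to see that the relativistic form only compares fake samples against real ones.</p>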
</li>
<li>
<p><strong>Addition of Zero-Centered Gradient Penalties (R1 and R2):</strong></p>
<ul>
<li>R1: Penalizes the squared norm of D's gradient with respect to its input, evaluated on real samples.</li>
<li>R2: Applies the same penalty, evaluated on fake (generated) samples.</li>
</ul>
<p>The goal is to reduce the gradients for G when it's close to generating real samples (eliminating training oscillations). When <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>p</mi><mrow><mi>f</mi><mi>a</mi><mi>k</mi><mi>e</mi></mrow></msub><mo>=</mo><msub><mi>p</mi><mrow><mi>r</mi><mi>e</mi><mi>a</mi><mi>l</mi></mrow></msub></mrow><annotation encoding="application/x-tex">p_{fake} = p_{real}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.7167em;vertical-align:-0.2861em"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.10764em">f</span><span class="mord mathnormal mtight" style="margin-right:0.03148em">ak</span><span class="mord mathnormal mtight">e</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span 
class="mord mathnormal mtight">re</span><span class="mord mathnormal mtight">a</span><span class="mord mathnormal mtight" style="margin-right:0.01968em">l</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span>, the gradients from D should be 0.</p>
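<p>A rough sketch of both penalties in plain Python follows. It assumes the common form (γ/2) · E[‖∇D(x)‖²] and approximates the input gradient with finite differences; in a real training loop this gradient would come from automatic differentiation. The toy discriminator, the sample batches and the function names are hypothetical.</p>

```python
import math

def input_gradient(d, x, h=1e-5):
    # Central finite differences as a stand-in for autograd: dD/dx_i for each coordinate.
    grad = []
    for i in range(len(x)):
        up = list(x)
        down = list(x)
        up[i] += h
        down[i] -= h
        grad.append((d(up) - d(down)) / (2.0 * h))
    return grad

def gradient_penalty(d, samples, gamma=1.0):
    # (gamma / 2) * mean over samples of the squared norm of dD/dx.
    total = 0.0
    for x in samples:
        total += sum(g * g for g in input_gradient(d, x))
    return 0.5 * gamma * total / len(samples)

def toy_discriminator(x):
    # Hypothetical discriminator, used only for illustration.
    return math.tanh(0.8 * x[0] - 0.3 * x[1] + 0.1)

real_batch = [[0.2, -0.1], [1.0, 0.5]]
fake_batch = [[0.0, 0.0], [-0.7, 0.9]]

r1 = gradient_penalty(toy_discriminator, real_batch)  # R1: penalty on real samples
r2 = gradient_penalty(toy_discriminator, fake_batch)  # R2: penalty on fake samples
```

<p>Both terms would be added to the discriminator loss. A discriminator that is flat around the samples gets zero penalty, which matches the goal of having no gradient signal once the fake and real distributions coincide.</p>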
</li>
<li>
<p><strong>Simplification and Modernization of the Architecture:</strong> The authors start from StyleGAN2 and strip out everything that was added to improve stability and convergence. They then switch to a ResNeXt-style architecture without normalization layers, relying on <a href="https://arxiv.org/abs/1901.09321" target="_blank" rel="noopener noreferrer">Fixup initialization</a>, which does not require normalization.</p>
</li>
</ol>
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="stylegan">StyleGAN<a href="https://kiwinicki.github.io/2025/03/04/gans#stylegan" class="hash-link" aria-label="Direct link to StyleGAN" title="Direct link to StyleGAN">​</a></h2>
<p>A complex architecture that relies on many tricks to achieve stable training and high image quality. It also allows smooth interpolation between images by interpolating their latent codes. <a href="https://blog.paperspace.com/evolution-of-stylegan" target="_blank" rel="noopener noreferrer">Many versions of StyleGAN have been released</a>. In a nutshell:</p>
<ol>
<li>The <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em"></span><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span> noise vector is transformed by the Mapping network into the vector <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>w</mi></mrow><annotation encoding="application/x-tex">w</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em"></span><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span> (latent code).</li>
<li>The Synthesis network starts building the image from a learned constant tensor of shape 4×4×512 (Const 4x4x512 in the diagram).</li>
<li>The feature map is gradually upsampled as it passes through successive blocks. Each block receives the latent code <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>w</mi></mrow><annotation encoding="application/x-tex">w</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em"></span><span class="mord mathnormal" style="margin-right:0.02691em">w</span></span></span></span> (through AdaIN, a type of normalization) and random noise scaled by the learned per-channel factors <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>B</mi></mrow><annotation encoding="application/x-tex">B</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal" style="margin-right:0.05017em">B</span></span></span></span>.</li>
</ol>
<p><img decoding="async" loading="lazy" src="https://machinelearningmastery.com/wp-content/uploads/2019/06/Summary-of-the-StyleGAN-Generator-Model-Architecture.png" alt="Summary of the StyleGAN generator architecture" class="img_ev3q"></p>
<p>Injecting noise through <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>B</mi></mrow><annotation encoding="application/x-tex">B</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6833em"></span><span class="mord mathnormal" style="margin-right:0.05017em">B</span></span></span></span> adds high-frequency details to the image (left with noise, right without noise).</p>
<p><img decoding="async" loading="lazy" src="https://viso.ai/wp-content/uploads/2024/07/effect-noise-adding.jpg" alt="Effect of noise injection: with noise (left), without noise (right)" class="img_ev3q"></p>
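<p>The two injection mechanisms described in this section can be sketched in plain Python on a tiny feature map. This is only an illustration under stated assumptions: the style scale and bias are passed in directly (in StyleGAN they are produced from the latent code by a learned affine transform), noise is applied before AdaIN, and all values and names are hypothetical.</p>

```python
import math
import random

def add_noise(features, noise, b):
    # features[c][p]: channel c, flattened pixel p. One noise map is shared by all
    # channels and scaled per channel by b[c].
    return [[v + b[c] * noise[p] for p, v in enumerate(channel)]
            for c, channel in enumerate(features)]

def adain(features, style_scale, style_bias, eps=1e-8):
    # Normalize each channel to zero mean and unit variance, then apply the style.
    out = []
    for c, channel in enumerate(features):
        mean = sum(channel) / len(channel)
        var = sum((v - mean) ** 2 for v in channel) / len(channel)
        std = math.sqrt(var + eps)
        out.append([style_scale[c] * (v - mean) / std + style_bias[c] for v in channel])
    return out

rng = random.Random(0)
features = [[rng.gauss(0.0, 1.0) for _ in range(16)] for _ in range(3)]  # 3 channels, 4x4 pixels
noise = [rng.gauss(0.0, 1.0) for _ in range(16)]
b = [0.1, 0.0, 0.5]            # hypothetical learned noise scales, one per channel
style_scale = [1.5, 0.5, 2.0]  # hypothetical style parameters derived from w
style_bias = [0.0, 1.0, -1.0]

styled = adain(add_noise(features, noise, b), style_scale, style_bias)
```

<p>After AdaIN, each channel's mean and standard deviation are set by the style alone, so the latent code controls per-channel statistics while the noise only perturbs individual pixels.</p>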
<h2 class="anchor anchorWithStickyNavbar_LWe7" id="other-stuff-tips-tricks-minor-papers">Other stuff (tips, tricks, minor papers)<a href="https://kiwinicki.github.io/2025/03/04/gans#other-stuff-tips-tricks-minor-papers" class="hash-link" aria-label="Direct link to Other stuff (tips, tricks, minor papers)" title="Direct link to Other stuff (tips, tricks, minor papers)">​</a></h2>
<h3 class="anchor anchorWithStickyNavbar_LWe7" id="gated-shortcut-arxiv220111351">Gated Shortcut (arXiv:2201.11351)<a href="https://kiwinicki.github.io/2025/03/04/gans#gated-shortcut-arxiv220111351" class="hash-link" aria-label="Direct link to Gated Shortcut (arXiv:2201.11351)" title="Direct link to Gated Shortcut (arXiv:2201.11351)">​</a></h3>
<p>The gated shortcut decides what is kept from the residual stream and what is added from the features. <strong>Used only in the Generator.</strong> Notation:<br>
<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">f_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span> - input<br>
<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>c</mi></msub></mrow><annotation encoding="application/x-tex">f_c</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span> - feature (the output of the conv, BN, ReLU block)<br>
<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>o</mi></msub></mrow><annotation encoding="application/x-tex">f_o</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">o</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span> - output<br>
<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>∗</mo></mrow><annotation encoding="application/x-tex">*</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4653em"></span><span class="mord">∗</span></span></span></span> - conv, <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>⊙</mo></mrow><annotation encoding="application/x-tex">\odot</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6667em;vertical-align:-0.0833em"></span><span class="mord">⊙</span></span></span></span> - concat, <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>⊗</mo></mrow><annotation encoding="application/x-tex">\otimes</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6667em;vertical-align:-0.0833em"></span><span class="mord">⊗</span></span></span></span> - element-wise multiplication</p>
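<p>With this notation, the three steps of the gated shortcut listed in this section can be sketched in plain Python. To stay self-contained, the sketch works on the channel vector of a single pixel and assumes 1×1 convolutions, which reduce to matrix-vector products; the kernel sizes used in the paper may differ, and the weight names are hypothetical.</p>

```python
import math

def conv1x1(weight, x):
    # A 1x1 convolution at a single pixel is a matrix-vector product over channels.
    return [sum(w * v for w, v in zip(row, x)) for row in weight]

def sigmoid(t):
    # Numerically stable logistic function.
    if t >= 0.0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def gated_shortcut(f_i, f_c, w_gate, w_refine, w_out):
    combined = f_c + f_i                                   # concat along channels
    f_g = [sigmoid(t) for t in conv1x1(w_gate, combined)]  # gate, each value in (0, 1)
    f_r = conv1x1(w_refine, combined)                      # refinement
    mixed = [g * c + (1.0 - g) * r for g, c, r in zip(f_g, f_c, f_r)]
    return conv1x1(w_out, mixed)                           # f_o
```

<p>With C channels, <code>w_gate</code> and <code>w_refine</code> have shape C × 2C because they see the concatenation, while <code>w_out</code> has shape C × C. A gate value near 1 passes the feature through, and a value near 0 falls back to the refinement branch.</p>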
<ol>
<li><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>g</mi></msub><mo>=</mo><mtext>sigmoid</mtext><mo stretchy="false">(</mo><mtext>conv</mtext><mo stretchy="false">(</mo><msub><mi>f</mi><mi>c</mi></msub><mo>⊙</mo><msub><mi>f</mi><mi>i</mi></msub><mo stretchy="false">)</mo><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">f_g = \text{sigmoid}(\text{conv}(f_c \odot f_i))</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9805em;vertical-align:-0.2861em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em">g</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord text"><span class="mord">sigmoid</span></span><span class="mopen">(</span><span class="mord text"><span class="mord">conv</span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" 
style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">⊙</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mclose">))</span></span></span></span> - <em>gate</em>, value 0-1</li>
<li><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>r</mi></msub><mo>=</mo><mtext>conv</mtext><mo stretchy="false">(</mo><msub><mi>f</mi><mi>c</mi></msub><mo>⊙</mo><msub><mi>f</mi><mi>i</mi></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">f_r = \text{conv}(f_c \odot f_i)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em">r</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord text"><span class="mord">conv</span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">⊙</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span> - <em>refinement</em>, learned transformation of combined features; what to add</li>
<li><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>o</mi></msub><mo>=</mo><mtext>conv</mtext><mo stretchy="false">[</mo><msub><mi>f</mi><mi>g</mi></msub><mo>⊗</mo><msub><mi>f</mi><mi>c</mi></msub><mo>+</mo><mo stretchy="false">(</mo><mn>1</mn><mo>−</mo><msub><mi>f</mi><mi>g</mi></msub><mo stretchy="false">)</mo><mo>⊗</mo><msub><mi>f</mi><mi>r</mi></msub><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">f_o = \text{conv}[f_g \otimes f_c + (1-f_g) \otimes f_r]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">o</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:1.0361em;vertical-align:-0.2861em"></span><span class="mord text"><span class="mord">conv</span></span><span class="mopen">[</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span 
class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em">g</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">⊗</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1.0361em;vertical-align:-0.2861em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing 
reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em">g</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">⊗</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em">r</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mclose">]</span></span></span></span> here:
<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>g</mi></msub><mo>⊗</mo><msub><mi>f</mi><mi>c</mi></msub></mrow><annotation encoding="application/x-tex">f_g \otimes f_c</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9805em;vertical-align:-0.2861em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em">g</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">⊗</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span> means how much to take from the feature based on <span class="katex"><span 
class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">f_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>c</mi></msub></mrow><annotation encoding="application/x-tex">f_c</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span>.
<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">(</mo><mn>1</mn><mo>−</mo><msub><mi>f</mi><mi>g</mi></msub><mo stretchy="false">)</mo><mo>⊗</mo><msub><mi>f</mi><mi>r</mi></msub></mrow><annotation encoding="application/x-tex">(1-f_g) \otimes f_r</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1.0361em;vertical-align:-0.2861em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em">g</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">⊗</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" 
style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em">r</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span> means the rest, which comes from the combination of <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">f_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>c</mi></msub></mrow><annotation encoding="application/x-tex">f_c</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span 
style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span>.</li>
</ol>
<p>The authors claim it works better than other gated residuals because it can add information to, or remove it from, the residual stream/features more effectively (<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mrow><mi>E</mi><mi>G</mi><mi>S</mi></mrow></msub></mrow><annotation encoding="application/x-tex">f_{EGS}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05764em">EGS</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span> is another way to implement a gated shortcut):
<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mrow><mi>E</mi><mi>G</mi><mi>S</mi></mrow></msub><mo>=</mo><msub><mi>f</mi><mi>g</mi></msub><mo>⊗</mo><msub><mi>f</mi><mi>i</mi></msub><mo>+</mo><mo stretchy="false">(</mo><mn>1</mn><mo>−</mo><msub><mi>f</mi><mi>g</mi></msub><mo stretchy="false">)</mo><mo>⊗</mo><msub><mi>f</mi><mi>c</mi></msub></mrow><annotation encoding="application/x-tex">f_{EGS} = f_g \otimes f_i + (1-f_g) \otimes f_c</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05764em">EGS</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em"></span></span><span class="base"><span class="strut" style="height:0.9805em;vertical-align:-0.2861em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal 
mtight" style="margin-right:0.03588em">g</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">⊗</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em"></span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:1.0361em;vertical-align:-0.2861em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" 
style="margin-right:0.03588em">g</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em"></span><span class="mbin">⊗</span><span class="mspace" style="margin-right:0.2222em"></span></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span></p>
<blockquote>
<p>In Eq. 6, <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mrow><mi>E</mi><mi>G</mi><mi>S</mi></mrow></msub></mrow><annotation encoding="application/x-tex">f_{EGS}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3283em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05764em">EGS</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span> can be interpreted as the weighted summation of the <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">f_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span> and <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>c</mi></msub></mrow><annotation encoding="application/x-tex">f_c</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span>, where weight value is <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>g</mi></msub></mrow><annotation encoding="application/x-tex">f_g</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9805em;vertical-align:-0.2861em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" 
style="margin-right:0.03588em">g</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span></span></span></span>. Thus, if <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>g</mi></msub></mrow><annotation encoding="application/x-tex">f_g</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9805em;vertical-align:-0.2861em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em">g</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em"><span></span></span></span></span></span></span></span></span></span> is 0.5, it is equal to the scaled-identity shortcut; it cannot effectively keep (or remove) the relevant (or irrelevant) information in <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>c</mi></msub></mrow><annotation encoding="application/x-tex">f_c</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span 
style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">c</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span>. In contrast, instead of directly summing <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">f_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span>, the proposed method produces the refinement feature, i.e. 
<span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>f</mi><mi>r</mi></msub></mrow><annotation encoding="application/x-tex">f_r</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10764em">f</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.1514em"><span style="top:-2.55em;margin-left:-0.1076em;margin-right:0.05em"><span class="pstrut" style="height:2.7em"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em">r</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em"><span></span></span></span></span></span></span></span></span></span>.</p>
</blockquote>
<p>This can probably be used in other architectures as well. Gating is not specific to GANs; LSTMs, for example, are built around it.</p>
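<p>To make the quoted argument concrete, here is a minimal NumPy sketch of the <code>f_EGS</code> formula only (not the paper's proposed block). The function name <code>egs_shortcut</code>, the tensor shapes, and the idea that the gate comes from a sigmoid over some logits are all assumptions made for illustration; reading <code>f_i</code> as the shortcut input and <code>f_c</code> as the conv-branch output is also my interpretation of the notation.</p>

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def egs_shortcut(f_i, f_c, gate_logits):
    """f_EGS = f_g * f_i + (1 - f_g) * f_c, with f_g = sigmoid(gate_logits).

    The multiplications are element-wise. Where gate_logits come from
    (e.g. a small conv layer) is left open here; that part is hypothetical.
    """
    f_g = sigmoid(gate_logits)
    return f_g * f_i + (1.0 - f_g) * f_c


rng = np.random.default_rng(0)
f_i = rng.normal(size=(2, 4, 8, 8))  # assumed: shortcut/input features (N, C, H, W)
f_c = rng.normal(size=(2, 4, 8, 8))  # assumed: features from the conv branch

# Zero logits give f_g = 0.5 everywhere, so the gate degenerates
# into a plain scaled sum of the two inputs.
half = egs_shortcut(f_i, f_c, np.zeros_like(f_i))
assert np.allclose(half, 0.5 * (f_i + f_c))

# Large positive logits push the gate towards 1, so the output is almost f_i.
open_gate = egs_shortcut(f_i, f_c, np.full_like(f_i, 20.0))
assert np.allclose(open_gate, f_i, atol=1e-6)
```

<p>The first check is the case from the quote: at a gate value of 0.5 nothing is selected, both inputs are just averaged.</p>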
<h3 class="anchor anchorWithStickyNavbar_LWe7" id="checkerboard-artifacts">Checkerboard Artifacts<a href="https://kiwinicki.github.io/2025/03/04/gans#checkerboard-artifacts" class="hash-link" aria-label="Direct link to Checkerboard Artifacts" title="Direct link to Checkerboard Artifacts">​</a></h3>
<p>Transposed convolutions whose kernel size is not divisible by the stride overlap unevenly and can cause <a href="https://distill.pub/2016/deconv-checkerboard/" target="_blank" rel="noopener noreferrer"><strong>checkerboard artifacts</strong></a>. Solution: upsample first with <code>Upsample(mode="nearest")</code>, then apply a regular <code>Conv2d</code>; the conv kernel can then overlap freely.</p>]]></content>
    </entry>
</feed>