- What is U-Net?
- The Three Key Parts
- The Secret Sauce: Skip Connections
- The Building Blocks
- Complete Data Flow
- What Can U-Net Do?
- When NOT to Use a Decoder
- Summary
If you’ve explored image generation, segmentation, or diffusion models, you’ve probably heard of U-Net. But what exactly is it, and why is it so widely used? In this post, I’ll break down U-Net step by step with concrete examples and visual diagrams.
What is U-Net?
U-Net is a neural network architecture designed for tasks where you put an image in and need an image of the same size back out. It was originally created for medical image segmentation in 2015, but has since become the backbone of many modern AI systems, including Stable Diffusion.
The name comes from its shape—when you draw the architecture, it looks like the letter “U”:
Input Image
│
▼
┌─────────────────────────────────────────┐
│ ENCODER (Downsampling) │
│ ┌─────┐ ┌─────┐ ┌─────┐ │
│ │64ch │ → │128ch│ → │256ch│ → ... │
│ │128² │ │64² │ │32² │ │
│ └──┬──┘ └──┬──┘ └──┬──┘ │
│ │ skip │ skip │ skip │
│ ▼ ▼ ▼ │
│ ┌──┴──┐ ┌──┴──┐ ┌──┴──┐ │
│ │64ch │ ← │128ch│ ← │256ch│ ← ... │
│ │128² │ │64² │ │32² │ │
│ └─────┘ └─────┘ └─────┘ │
│ DECODER (Upsampling) │
└─────────────────────────────────────────┘
│
▼
Output Image
The Three Key Parts
1. Encoder (The Down Path)
The encoder compresses the image, making it spatially smaller but with more channels:
128×128×3 → 64×64×64 → 32×32×128 → 16×16×256 → 8×8×512
│ │ │ │ │
└──────────────┴─────────────┴─────────────┴────────────┘
Shrinking spatially
Growing in channels
At each step:
- Spatial size halves (128 → 64 → 32 → 16 → 8)
- Channels increase (3 → 64 → 128 → 256 → 512)
This is like summarizing a book—you lose details but capture the main ideas.
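Here is a rough sketch of one encoder step in PyTorch (a single convolution per level for brevity; the real building block, covered below, stacks two). It grows the channels and then halves the spatial size:

```python
import torch
import torch.nn as nn

# One hypothetical encoder step: grow channels with a 3x3 convolution,
# then halve the spatial size with 2x2 max pooling.
encoder_step = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),  # (1, 3, 128, 128) -> (1, 64, 128, 128)
    nn.ReLU(),
    nn.MaxPool2d(2),                             # (1, 64, 128, 128) -> (1, 64, 64, 64)
)

x = torch.randn(1, 3, 128, 128)
print(encoder_step(x).shape)  # torch.Size([1, 64, 64, 64])
```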
2. Bottleneck
The bottleneck is the smallest point in the network:
┌─────────────────────────────────┐
│ 8×8×512 │
│ │
│ Only 64 spatial positions │
│ but 512 features each │
│ │
│ "Compressed understanding" │
└─────────────────────────────────┘
At this point, the network has maximum semantic understanding but minimum spatial detail. It knows “what” is in the image but has lost “where” things are precisely.
3. Decoder (The Up Path)
The decoder expands the image back to full resolution:
8×8×512 → 16×16×256 → 32×32×128 → 64×64×64 → 128×128×3
But here’s the problem: how do you recover the spatial details that were lost?
The Secret Sauce: Skip Connections
This is what makes U-Net special. Skip connections pass information directly from the encoder to the decoder, bypassing the bottleneck:
ENCODER DECODER
─────── ───────
128×128 ─────── skip1 ─────────────→ 128×128
│ ▲
64×64 ───────── skip2 ───────────→ 64×64
│ ▲
32×32 ───────── skip3 ─────────→ 32×32
│ ▲
16×16 ───────── skip4 ───────→ 16×16
│ ▲
└──→ 8×8 BOTTLENECK ──────────────────┘
Why Are Skip Connections Needed?
Think of it this way:
| Source | Knows | Problem |
|---|---|---|
| Bottleneck | “What” is in image | Lost “where” exactly |
| Skip | “Where” things are | Doesn’t know context |
| Combined | Both! | Sharp + accurate output |
Visual Example
WITHOUT skip connections: WITH skip connections:
┌────────────────────┐ ┌────────────────────┐
│ │ │ ● │
│ ◯ │ │ ╲ │
│ (blurry, │ │ ╲ │
│ wrong spot) │ │ ● (sharp, │
│ │ │ ╲ correct!) │
│ │ │ ● │
└────────────────────┘ └────────────────────┘
The bottleneck knows “there’s a line somewhere” but lost the exact position. The skip connection says “the line edge is at these exact pixels.” Combined, you get a sharp, accurate output.
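Mechanically, "combined" just means concatenating the two feature maps along the channel dimension. A minimal illustration, with shapes taken from the example network later in this post:

```python
import torch

from_bottleneck = torch.randn(1, 512, 16, 16)  # upsampled decoder features: know "what"
from_skip       = torch.randn(1, 512, 16, 16)  # encoder features: know "where"

combined = torch.cat([from_bottleneck, from_skip], dim=1)  # stack along channels
print(combined.shape)  # torch.Size([1, 1024, 16, 16])
```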
The Building Blocks
ConvBlock: The Basic Unit
Every level of the U-Net uses convolutional blocks:
Input
↓
Conv 3×3 → BatchNorm → ReLU
↓
Conv 3×3 → BatchNorm → ReLU
↓
Output
A 3×3 convolution looks at a pixel and its 8 neighbors to compute each output pixel.
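Here is a minimal PyTorch sketch of such a block (the class name and exact layer choices are mine, not from any particular library):

```python
import torch.nn as nn

class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm and ReLU."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)  # same spatial size, out_channels channels
```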
Understanding Conv2d
Let’s make this concrete with Conv2d(2, 3, 3) — 2 input channels, 3 output channels, 3×3 kernel.
Key insight: Each output channel has its own filter, and each filter looks at ALL input channels.
INPUT (2 channels) OUTPUT (3 channels)
┌─────────┐ ┌─────────┐
│ Ch 0 │──┬─ Filter 0 ─────→│ Ch 0 │
│ │ │ └─────────┘
└─────────┘ │
├─ Filter 1 ─────→┌─────────┐
┌─────────┐ │ │ Ch 1 │
│ Ch 1 │──┤ └─────────┘
│ │ │
└─────────┘ └─ Filter 2 ─────→┌─────────┐
│ Ch 2 │
└─────────┘
Each filter reads ALL input channels to produce ONE output channel.
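You can see this directly in the layer's weight tensor: for Conv2d(2, 3, 3), PyTorch stores three filters, each spanning both input channels.

```python
import torch.nn as nn

conv = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=3)
print(conv.weight.shape)  # torch.Size([3, 2, 3, 3]): 3 filters, each 2 channels x 3 x 3
print(conv.bias.shape)    # torch.Size([3]): one bias per output channel
```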
Concrete Conv2d Example
Input (2 channels, 4×4 each):
Channel 0: Channel 1:
┌────┬────┬────┬────┐ ┌────┬────┬────┬────┐
│ 10 │ 10 │ 0 │ 0 │ │ 5 │ 5 │ 5 │ 5 │
├────┼────┼────┼────┤ ├────┼────┼────┼────┤
│ 10 │ 10 │ 0 │ 0 │ │ 5 │ 5 │ 5 │ 5 │
├────┼────┼────┼────┤ ├────┼────┼────┼────┤
│ 10 │ 10 │ 0 │ 0 │ │ 5 │ 5 │ 5 │ 5 │
├────┼────┼────┼────┤ ├────┼────┼────┼────┤
│ 10 │ 10 │ 0 │ 0 │ │ 5 │ 5 │ 5 │ 5 │
└────┴────┴────┴────┘ └────┴────┴────┴────┘
Filter 0 (one 3×3 kernel per input channel):
For input ch0: For input ch1:
┌────┬────┬────┐ ┌────┬────┬────┐
│ 1 │ 0 │ -1 │ │ 0 │ 0 │ 0 │
├────┼────┼────┤ ├────┼────┼────┤
│ 1 │ 0 │ -1 │ │ 0 │ 1 │ 0 │
├────┼────┼────┤ ├────┼────┼────┤
│ 1 │ 0 │ -1 │ │ 0 │ 0 │ 0 │
└────┴────┴────┘ └────┴────┴────┘
To compute output pixel at (row=1, col=1):
From ch0: 10×1 + 10×0 + 0×(-1) + 10×1 + 10×0 + 0×(-1) + 10×1 + 10×0 + 0×(-1) = 30
From ch1: 5×0 + 5×0 + 5×0 + 5×0 + 5×1 + 5×0 + 5×0 + 5×0 + 5×0 = 5
Total: 30 + 5 = 35 (plus the filter's bias term)
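If you want to check this arithmetic, here is a small sketch that reproduces it with F.conv2d, hand-setting the weights for filter 0 only and using a zero bias:

```python
import torch
import torch.nn.functional as F

# Input: 2 channels, 4x4 each (same values as the tables above)
ch0 = torch.tensor([[10., 10., 0., 0.]] * 4)
ch1 = torch.tensor([[5., 5., 5., 5.]] * 4)
x = torch.stack([ch0, ch1]).unsqueeze(0)           # shape (1, 2, 4, 4)

# Filter 0: one 3x3 kernel per input channel
k_ch0 = torch.tensor([[1., 0., -1.]] * 3)          # vertical edge detector
k_ch1 = torch.tensor([[0., 0., 0.],
                      [0., 1., 0.],
                      [0., 0., 0.]])               # pass through the center pixel
weight = torch.stack([k_ch0, k_ch1]).unsqueeze(0)  # shape (1, 2, 3, 3)

# No padding, so the output is 2x2; its top-left value is the window
# centered at input (row=1, col=1).
out = F.conv2d(x, weight, bias=torch.zeros(1))
print(out[0, 0, 0, 0])                             # tensor(35.)
```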
DownBlock (Encoder Step)
def forward(self, x):
features = self.conv(x) # Process with ConvBlock
pooled = self.pool(features) # Shrink by half
return pooled, features # Return BOTH!
Input: (1, 64, 64, 64)
│
ConvBlock
│
(1, 128, 64, 64) ──→ SAVED as skip connection
│
MaxPool2d (shrink)
│
Output: (1, 128, 32, 32)
The key: it returns TWO things — the pooled result for the next layer AND the features for the skip connection.
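Filled out into a complete module (a minimal sketch that reuses the ConvBlock class from above; with in_channels=64 and out_channels=128 it reproduces the shapes in the diagram):

```python
import torch.nn as nn

class DownBlock(nn.Module):
    """One encoder step: ConvBlock, then 2x2 max pooling."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)  # ConvBlock sketch from above
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        features = self.conv(x)       # (B, out_channels, H, W)   -> saved as skip
        pooled = self.pool(features)  # (B, out_channels, H/2, W/2) -> next level
        return pooled, features
```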
UpBlock (Decoder Step)
def forward(self, x, skip):
x = self.up(x) # Grow spatially (ConvTranspose2d)
x = torch.cat([x, skip], dim=1) # Concatenate with skip
x = self.conv(x) # Process combined features
return x
Input: (1, 512, 8, 8) Skip: (1, 512, 16, 16)
│
ConvTranspose2d (grow 2×)
│
(1, 512, 16, 16)
│
Concat with skip (channels add)
│
(1, 1024, 16, 16)
│
ConvBlock (reduce channels)
│
Output: (1, 256, 16, 16)
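And the corresponding full module (again a sketch built on the ConvBlock above; UpBlock(512, 512, 256) reproduces the shapes in this diagram):

```python
import torch
import torch.nn as nn

class UpBlock(nn.Module):
    """One decoder step: upsample, concatenate the skip, then ConvBlock."""
    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        # Grow spatially by 2x; channel count is kept, matching the diagram above
        self.up = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels + skip_channels, out_channels)

    def forward(self, x, skip):
        x = self.up(x)                   # (B, in_channels, 2H, 2W)
        x = torch.cat([x, skip], dim=1)  # (B, in_channels + skip_channels, 2H, 2W)
        return self.conv(x)              # (B, out_channels, 2H, 2W)
```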
ConvTranspose2d: Growing Images
ConvTranspose2d is roughly the reverse of a strided Conv2d: where Conv2d with stride 2 shrinks the image, ConvTranspose2d with stride 2 grows it:
Conv2d (stride=2): ConvTranspose2d (stride=2):
4×4 → 2×2 2×2 → 4×4
(shrink) (grow)
With stride 2 and a 2×2 kernel, each input pixel spreads into its own 2×2 region of the output. If the kernel were all ones, the result would be exact duplication:
Input (2×2): Output (4×4):
┌───┬───┐ ┌───┬───┬───┬───┐
│ 1 │ 2 │ │ 1 │ 1 │ 2 │ 2 │
├───┼───┤ → ├───┼───┼───┼───┤
│ 3 │ 4 │ │ 1 │ 1 │ 2 │ 2 │
└───┴───┘ ├───┼───┼───┼───┤
│ 3 │ 3 │ 4 │ 4 │
├───┼───┼───┼───┤
│ 3 │ 3 │ 4 │ 4 │
└───┴───┴───┴───┘
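You can reproduce that duplication by fixing the 2×2 kernel to all ones (in a trained network the kernel is learned, so the pattern will not be this clean):

```python
import torch
import torch.nn as nn

up = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2, bias=False)
with torch.no_grad():
    up.weight.fill_(1.0)  # force the 2x2 kernel to all ones

x = torch.tensor([[[[1., 2.],
                    [3., 4.]]]])  # shape (1, 1, 2, 2)
print(up(x).squeeze())
# tensor([[1., 1., 2., 2.],
#         [1., 1., 2., 2.],
#         [3., 3., 4., 4.],
#         [3., 3., 4., 4.]])
```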
Complete Data Flow
Let’s trace through an entire U-Net forward pass:
INPUT: (1, 3, 128, 128) "RGB image"
ENCODER:
enc1: (1, 64, 64, 64)   → skip1 saved before pooling: (1, 64, 128, 128)
enc2: (1, 128, 32, 32)  → skip2 saved before pooling: (1, 128, 64, 64)
enc3: (1, 256, 16, 16)  → skip3 saved before pooling: (1, 256, 32, 32)
enc4: (1, 512, 8, 8)    → skip4 saved before pooling: (1, 512, 16, 16)
BOTTLENECK:
(1, 512, 8, 8) "Compressed understanding"
DECODER:
dec4: (1, 256, 16, 16) ← uses skip4
dec3: (1, 128, 32, 32) ← uses skip3
dec2: (1, 64, 64, 64) ← uses skip2
dec1: (1, 64, 128, 128) ← uses skip1
OUTPUT: (1, 3, 128, 128) "Processed image"
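Assembled from the ConvBlock, DownBlock, and UpBlock sketches above (channel counts chosen to reproduce this exact trace; a real implementation may differ in details), the whole network looks roughly like this:

```python
import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc1 = DownBlock(3, 64)      # skip1: (B, 64, 128, 128)
        self.enc2 = DownBlock(64, 128)    # skip2: (B, 128, 64, 64)
        self.enc3 = DownBlock(128, 256)   # skip3: (B, 256, 32, 32)
        self.enc4 = DownBlock(256, 512)   # skip4: (B, 512, 16, 16)
        self.bottleneck = ConvBlock(512, 512)
        self.dec4 = UpBlock(512, 512, 256)
        self.dec3 = UpBlock(256, 256, 128)
        self.dec2 = UpBlock(128, 128, 64)
        self.dec1 = UpBlock(64, 64, 64)
        self.out = nn.Conv2d(64, 3, kernel_size=1)  # map 64 channels back to 3

    def forward(self, x):
        x, skip1 = self.enc1(x)
        x, skip2 = self.enc2(x)
        x, skip3 = self.enc3(x)
        x, skip4 = self.enc4(x)
        x = self.bottleneck(x)
        x = self.dec4(x, skip4)
        x = self.dec3(x, skip3)
        x = self.dec2(x, skip2)
        x = self.dec1(x, skip1)
        return self.out(x)

unet = UNet()
print(unet(torch.randn(1, 3, 128, 128)).shape)  # torch.Size([1, 3, 128, 128])
```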
What Can U-Net Do?
U-Net is used for any task requiring pixel-level output:
| Task | Input | Output |
|---|---|---|
| Medical segmentation | CT scan | Tumor mask |
| Semantic segmentation | Photo | Labels per pixel |
| Image denoising | Noisy image | Clean image |
| Inpainting | Image with hole | Filled image |
| Super resolution | Low-res | High-res |
| Style transfer | Photo | Stylized image |
| Diffusion models | Noisy latent | Denoised latent |
When NOT to Use a Decoder
Not all tasks need a decoder:
Classification (no decoder):
Image → [shrink, shrink, shrink] → "This is a cat"
U-Net (full decoder):
Image → [shrink] → [expand] → Processed image
If you only need a label, not a pixel-by-pixel output, skip the decoder.
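As a hypothetical contrast, a classifier keeps only the encoder-style half and collapses the spatial grid instead of rebuilding it:

```python
import torch
import torch.nn as nn

classifier = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),    # 128 -> 64
    nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 64 -> 32
    nn.AdaptiveAvgPool2d(1),  # collapse the remaining spatial grid
    nn.Flatten(),
    nn.Linear(128, 10),       # e.g. 10 class scores, no decoder needed
)

print(classifier(torch.randn(1, 3, 128, 128)).shape)  # torch.Size([1, 10])
```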
Summary
U-Net’s power comes from three key ideas:
- Encoder: Compress spatially, extract “what” is in the image
- Decoder: Expand back to full resolution
- Skip connections: Pass “where” information directly from encoder to decoder
This combination allows U-Net to understand both the big picture (global context from bottleneck) and fine details (local information from skips), producing sharp, accurate outputs.
Whether you’re segmenting medical images, generating art with Stable Diffusion, or building your own image editing model, U-Net’s elegant architecture is likely at the core.
This post was created while building a text-conditioned image editing model. The examples and diagrams come from hands-on experimentation with PyTorch.