Activation Functions in JAX: A Deep Dive into Neural Network Non-Linearities

By Manas Gupta · Published April 10, 2026 · 9 min read · Source: Level Up Coding

If you have ever trained a deep neural network that refused to learn, produced exploding loss values, or mysteriously plateaued long before reaching reasonable accuracy, there is a very good chance the culprit was your activation function. This tutorial takes a deep dive into six popular activation functions: Sigmoid, Tanh, ReLU, LeakyReLU, ELU, and Swish, implementing each from scratch in JAX. By the end, you will have a clear mental model for choosing the right activation function.

Why Activation Functions Matter

Before diving into code, it is worth spending a moment on the “why.” A neural network without activation functions is mathematically just a linear transformation of its input, no matter how many layers you stack. The composition of linear functions is still linear, which means that a ten-layer network with no non-linearity is strictly equivalent to a single weight matrix, hardly the kind of universal function approximator we want.

Activation functions introduce non-linearity at each layer, allowing the network to model highly complex relationships in data. But not all non-linearities are created equal. Some introduce training hurdles:

  1. Vanishing gradients: derivatives shrink toward zero, so early layers stop learning
  2. Exploding gradients: outputs or derivatives grow unchecked, destabilizing the loss
  3. Dead neurons: units get stuck outputting zero and never recover

Choosing a good activation function means navigating all three failure modes simultaneously, while also considering computational cost, the smoothness properties that some optimizers expect, and how the function interacts with weight initialization schemes. This tutorial walks through each of these considerations methodically, using JAX to make the analysis concrete and reproducible.

Implementing and Visualizing Six Activation Functions

We are going to implement each activation function as a plain JAX function using jax.numpy operations. This means each one automatically becomes differentiable via jax.grad, JIT-compilable via jax.jit, and vectorizable via jax.vmap. We get all of that for free just by writing functions in terms of jnp operations.
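To make "for free" concrete, here is a minimal sketch; `square_plus_one` is a stand-in function of my own, not something from the tutorial:

```python
import jax
import jax.numpy as jnp

# Any function written with jnp operations is transformable.
def square_plus_one(x):
    return x * x + 1.0

# grad: exact derivative, d/dx (x^2 + 1) = 2x
dfdx = jax.grad(square_plus_one)
print(dfdx(3.0))  # 6.0

# jit: XLA-compiled version, same result
fast = jax.jit(square_plus_one)
print(fast(3.0))  # 10.0

# vmap: vectorize the scalar gradient over a batch
batched_grad = jax.vmap(jax.grad(square_plus_one))
print(batched_grad(jnp.array([0.0, 1.0, 2.0])))  # [0. 2. 4.]
```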

We will also write a helper that plots both the activation function itself and its derivative side by side, so we can immediately see the gradient landscape.


from typing import Callable

# Plotting
import matplotlib.pyplot as plt

# JAX
import jax.numpy as jnp
from jax import grad, vmap


def plot_activation(fn: Callable, fn_name: str, x_range: tuple = (-5, 5)):
    """Plot an activation function and its gradient."""
    x = jnp.linspace(x_range[0], x_range[1], 1000)
    y = fn(x)
    # Use vmap(grad(fn)) to compute element-wise gradient
    dy = vmap(grad(fn))(x)
    fig, axes = plt.subplots(1, 2, figsize=(10, 3.5))
    fig.suptitle(f"Activation: {fn_name}", fontsize=14, fontweight="bold")
    axes[0].plot(x, y, color="steelblue", linewidth=2)
    axes[0].axhline(0, color="black", linewidth=0.6, linestyle="--")
    axes[0].axvline(0, color="black", linewidth=0.6, linestyle="--")
    axes[0].set_xlabel("Input x")
    axes[0].set_ylabel("f(x)")
    axes[0].set_title("Activation function")
    axes[0].grid(True, alpha=0.3)
    axes[1].plot(x, dy, color="darkorange", linewidth=2)
    axes[1].axhline(0, color="black", linewidth=0.6, linestyle="--")
    axes[1].axvline(0, color="black", linewidth=0.6, linestyle="--")
    axes[1].set_xlabel("Input x")
    axes[1].set_ylabel("f'(x)")
    axes[1].set_title("Gradient (derivative)")
    axes[1].grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

Sigmoid Function

The sigmoid function maps any real number to the open interval (0, 1):

σ(x) = 1 / (1 + exp(-x))

It was the go-to activation function in the early days of neural networks because its output can be interpreted as a probability. However, it has a well-known flaw: its gradient is at most 0.25, reached at x = 0, and decays rapidly toward zero as |x| grows. In a deep network, multiplying many such small gradients together as we propagate backward means the signal reaching the early layers can become vanishingly small: the classic vanishing gradient problem.

def sigmoid(x: jnp.ndarray) -> jnp.ndarray:
    """
    Sigmoid activation function.
    σ(x) = 1 / (1 + exp(-x))
    Output range: (0, 1)
    Max gradient: 0.25 at x=0
    """
    return 1.0 / (1.0 + jnp.exp(-x))

plot_activation(sigmoid, "Sigmoid")
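A quick numeric check of the saturation claim, using jax.grad on the sigmoid we just defined (re-defined here so the snippet stands alone):

```python
import jax
import jax.numpy as jnp

def sigmoid(x):
    return 1.0 / (1.0 + jnp.exp(-x))

g = jax.grad(sigmoid)
print(g(0.0))   # 0.25, the maximum
print(g(5.0))   # ~0.0066, already tiny
print(g(10.0))  # ~4.5e-5, effectively gone
```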

Key properties:

  - Output range (0, 1), interpretable as a probability but not zero-centered
  - Maximum gradient of 0.25 at x = 0, decaying rapidly as |x| grows
  - Prone to vanishing gradients when many layers are stacked

Tanh Function

The hyperbolic tangent is closely related to sigmoid, but it is zero-centered, mapping to (−1, 1):

tanh(x) = (exp(x) − exp(−x)) / (exp(x) + exp(−x))

def tanh(x: jnp.ndarray) -> jnp.ndarray:
    """
    Hyperbolic tangent activation.
    tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
    Output range: (-1, 1)
    Max gradient: 1.0 at x=0
    """
    return jnp.tanh(x)  # JAX has a numerically stable built-in

plot_activation(tanh, "Tanh")

Key properties:

  - Output range (−1, 1) and zero-centered, which helps gradient flow between layers
  - Maximum gradient of 1.0 at x = 0, four times that of sigmoid
  - Still saturates for large |x|, so deep stacks can still suffer vanishing gradients

Both sigmoid and tanh were dominant for decades before the community rediscovered and embraced ReLU, whose gradient properties are dramatically better for deep networks.
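To see how dramatic the difference is, a small sketch that composes the same activation ten times (pure composition with no weights, a deliberate simplification of a ten-layer network; the helper `stack` is my own):

```python
import jax
import jax.numpy as jnp

def sigmoid(x):
    return 1.0 / (1.0 + jnp.exp(-x))

def stack(fn, depth):
    """Compose fn with itself `depth` times: fn(fn(...fn(x)))."""
    def composed(x):
        for _ in range(depth):
            x = fn(x)
        return x
    return composed

# Gradient of a 10-deep chain, evaluated at x = 1.0
sig_g = jax.grad(stack(sigmoid, 10))(1.0)
tanh_g = jax.grad(stack(jnp.tanh, 10))(1.0)
relu_g = jax.grad(stack(lambda x: jnp.maximum(0.0, x), 10))(1.0)
print(sig_g)   # on the order of 1e-7: vanished
print(tanh_g)  # a few percent: decayed, but far less
print(relu_g)  # exactly 1.0: perfectly preserved
```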

ReLU Function

The Rectified Linear Unit is a simple function:

ReLU(x) = max(0, x)

Its gradient is 1 for positive inputs and 0 for negative inputs. This property helps preserve the gradients flowing through ReLU neurons with positive pre-activations. This is why very deep networks (10, 20, even 100+ layers) became practically trainable once ReLU replaced sigmoid and tanh in hidden layers.

def relu(x: jnp.ndarray) -> jnp.ndarray:
    """
    Rectified Linear Unit.
    ReLU(x) = max(0, x)
    Output range: [0, ∞)
    Gradient: 1 for x > 0, 0 for x < 0
    """
    return jnp.maximum(0, x)

plot_activation(relu, "ReLU")

Key properties:

  - Output range [0, ∞); gradient is exactly 1 for x > 0 and 0 for x < 0
  - Extremely cheap to compute: a single comparison, no exponentials
  - Neurons stuck in the negative region get zero gradient (the "dying ReLU" problem)

The dying ReLU problem is most prominent when learning rates are high or weight initialization is poor, which can push large fractions of neurons into the negative region, where they output zero, receive zero gradient, and never recover. All three activation functions that follow attempt to address this, each in a slightly different way.
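To watch a neuron die, consider a single ReLU unit with a badly negative bias; the names here (`neuron_loss`, the toy data) are mine, for illustration only:

```python
import jax
import jax.numpy as jnp

def neuron_loss(w, b, xs, ys):
    # Single ReLU neuron: pred = relu(w*x + b), mean squared error loss
    preds = jnp.maximum(0.0, w * xs + b)
    return jnp.mean((preds - ys) ** 2)

xs = jnp.array([0.5, 1.0, 2.0])
ys = jnp.array([1.0, 2.0, 4.0])

# A healthy neuron gets a useful weight gradient...
print(jax.grad(neuron_loss)(1.0, 0.1, xs, ys))
# ...but a "dead" one (bias so negative that every pre-activation < 0)
# gets exactly zero gradient and can never recover.
print(jax.grad(neuron_loss)(1.0, -5.0, xs, ys))  # 0.0
```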

LeakyReLU Function

The most straightforward fix: instead of setting the output to zero for negative inputs, allow a small positive gradient through:

LeakyReLU(x) = x     if x ≥ 0
LeakyReLU(x) = α·x   if x < 0

where α is a small constant, typically 0.01.

def leaky_relu(x: jnp.ndarray, alpha: float = 0.01) -> jnp.ndarray:
    """
    Leaky ReLU.
    f(x) = x for x >= 0, alpha * x for x < 0
    Output range: (-∞, ∞)
    Gradient: 1 for x > 0, alpha for x < 0
    """
    return jnp.where(x >= 0, x, alpha * x)

plot_activation(leaky_relu, "LeakyReLU")
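A one-line check of the gradient claim in the docstring:

```python
import jax
import jax.numpy as jnp

def leaky_relu(x, alpha=0.01):
    return jnp.where(x >= 0, x, alpha * x)

g = jax.vmap(jax.grad(leaky_relu))
grads = g(jnp.array([-2.0, 3.0]))
print(grads)  # alpha (0.01) for the negative input, 1 for the positive
```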

Key properties:

  - Output range (−∞, ∞); gradient is α (typically 0.01) instead of 0 for negative inputs
  - Neurons can always recover, since the negative-side gradient never vanishes entirely
  - Nearly as cheap as ReLU, with α as one extra hyperparameter to choose

ELU Function

The Exponential Linear Unit replaces the hard zero floor for negative inputs with a smooth exponential curve that asymptotically converges to −α:

ELU(x) = x                if x > 0
ELU(x) = α·(exp(x) − 1)   if x ≤ 0

where α is typically 1.0.

def elu(x: jnp.ndarray, alpha: float = 1.0) -> jnp.ndarray:
    """
    Exponential Linear Unit.
    f(x) = x for x > 0, alpha * (exp(x) - 1) for x <= 0
    Output range: (-alpha, ∞)
    Saturates to -alpha for very negative inputs.
    """
    return jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1.0))

plot_activation(elu, "ELU")
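One payoff of the smooth negative branch, shown as a quick sketch: with α = 1 the derivative approaches 1 from both sides of zero, so there is no kink like ReLU's:

```python
import jax
import jax.numpy as jnp

def elu(x, alpha=1.0):
    return jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1.0))

g = jax.vmap(jax.grad(elu))
grads = g(jnp.array([-0.001, 0.001]))
print(grads)  # both values ~1.0: the derivative is continuous at 0
```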

Key properties:

  - Output range (−α, ∞), saturating smoothly to −α for very negative inputs
  - With α = 1 the derivative is continuous at 0, unlike ReLU and LeakyReLU
  - Negative outputs push mean activations toward zero, at the cost of an exponential per call

Swish Function

Swish emerged from a large-scale automated search for better activation functions conducted by Google Brain researchers.

Swish(x) = x · σ(x) = x / (1 + exp(-x))

This is also known as SiLU (Sigmoid Linear Unit). What makes Swish distinctive is that it is both smooth and non-monotonic: unlike all the functions we have seen so far, it has a small dip near x ≈ −1.3 where the output is slightly negative before recovering. This non-monotonicity has been shown empirically to be beneficial for very deep networks.

def swish(x: jnp.ndarray) -> jnp.ndarray:
    """
    Swish (SiLU) activation function.
    f(x) = x * sigmoid(x) = x / (1 + exp(-x))
    Output range: (-∞, ∞), non-monotonic
    Smooth and differentiable everywhere.
    """
    return x * sigmoid(x)

plot_activation(swish, "Swish")
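The dip is easy to verify numerically: moving right from −3 to −1 the output drops, which a monotonic function could never do:

```python
import jax.numpy as jnp

def swish(x):
    return x / (1.0 + jnp.exp(-x))

print(swish(-3.0))  # ~ -0.142
print(swish(-1.0))  # ~ -0.269, lower despite the larger input
print(swish(0.0))   # 0.0
```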

Key properties:

  - Output range roughly (−0.28, ∞); smooth and differentiable everywhere
  - Non-monotonic, with a small dip below zero near x ≈ −1.3
  - Slightly costlier than ReLU because of the embedded sigmoid

Comparing All Six at Once

activation_fns = {
    "Sigmoid": sigmoid,
    "Tanh": tanh,
    "ReLU": relu,
    "LeakyReLU": leaky_relu,
    "ELU": elu,
    "Swish": swish,
}
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
x = jnp.linspace(-5, 5, 1000)
for ax, (name, fn) in zip(axes.flatten(), activation_fns.items()):
    y = fn(x)
    dy = vmap(grad(fn))(x)
    ax.plot(x, y, label="f(x)", color="steelblue", linewidth=2)
    ax.plot(x, dy, label="f'(x)", color="darkorange", linewidth=2, linestyle="--")
    ax.axhline(0, color="black", linewidth=0.5)
    ax.axvline(0, color="black", linewidth=0.5)
    ax.set_title(name, fontsize=12, fontweight="bold")
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-2, 3)
plt.suptitle("Activation Functions and Their Gradients", fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()

Looking at these side-by-side makes several patterns immediately obvious:

  1. Sigmoid and Tanh both have gradients that collapse to zero far from the origin. ReLU’s gradient is binary: exactly 1 or exactly 0, with the zero region causing dead neurons.
  2. LeakyReLU, ELU, and Swish all avoid the zero-gradient region on the left side of the input space, each in a different way.
  3. Swish is the only one that is non-monotonic. Its output temporarily decreases before climbing again in the negative region.
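Pattern 2 can be quantified: measure how much of an input grid produces an exactly-zero gradient (the grid and range here are arbitrary choices of mine, for illustration):

```python
import jax
import jax.numpy as jnp

def sigmoid(x):
    return 1.0 / (1.0 + jnp.exp(-x))

# Fraction of a [-5, 5] grid where the gradient is exactly zero,
# i.e. the region where learning stalls completely.
fns = {
    "ReLU": lambda x: jnp.maximum(0.0, x),
    "LeakyReLU": lambda x: jnp.where(x >= 0, x, 0.01 * x),
    "ELU": lambda x: jnp.where(x > 0, x, jnp.exp(x) - 1.0),
    "Swish": lambda x: x * sigmoid(x),
}
x = jnp.linspace(-5, 5, 1001)
dead = {name: float(jnp.mean(jax.vmap(jax.grad(fn))(x) == 0.0))
        for name, fn in fns.items()}
for name, frac in dead.items():
    print(f"{name:10s} zero-gradient fraction: {frac:.2f}")
```

ReLU stalls on roughly half the grid; the other three never hit an exactly-zero gradient on it.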

In the next article, we will train a multi-layer network on FashionMNIST with Flax and take a closer look at training efficiency and accuracy.



