import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import pandas as pd
from corner import corner
sns.set_style("darkgrid")
[docs]
def plot_trajectories(traj):
traj = np.array(traj)
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(traj[0, :, 0], traj[0, :, 1], color="red", s=1, alpha=1)
ax.plot(traj[:, :, 0], traj[:, :, 1], color="white", lw=0.5, alpha=0.7)
ax.scatter(traj[-1, :, 0], traj[-1, :, 1], color="blue", s=2, alpha=1, zorder=2)
ax.set_aspect("equal", adjustable="box")
# set black background
ax.set_facecolor("#A6AEBF")
plt.grid(False)
return fig, ax
# plot marginals using seaborn's PairGrid
[docs]
base_color = "#CD5656" # Base color for the hexbin and kdeplot
[docs]
hist_color = "#202A44" # Color for the histograms
[docs]
true_val_color = "#687FE5"
[docs]
rgb_base = np.array(mcolors.to_rgb(base_color))
[docs]
colors = [
(
rgb_base[0],
rgb_base[1],
rgb_base[2],
0,
), # At data value 0, color is rgb_base with alpha 0
(rgb_base[0], rgb_base[1], rgb_base[2], 1),
] # At data value 1, color is rgb_base with alpha 1
[docs]
transparent_cmap = LinearSegmentedColormap.from_list("transparent_red", colors, N=256)
[docs]
def _parse_range(range_arg, ndim):
if range_arg is None:
return [None] * ndim
if (
isinstance(range_arg, tuple)
and len(range_arg) == 2
and not isinstance(range_arg[0], tuple)
):
return [range_arg] * ndim
if isinstance(range_arg, (list, tuple)) and len(range_arg) == ndim:
return list(range_arg)
raise ValueError(
"range must be a tuple (min, max) or a sequence of such tuples, one per axis"
)
[docs]
def _plot_marginals_2d(
data,
plot_levels=True,
labels=None,
gridsize=15,
hexbin_kwargs={},
histplot_kwargs={},
range=None,
true_param=None,
**kwargs,
):
data = np.array(data)
if true_param is not None:
true_param = np.array(true_param)
ndim = data.shape[1]
fontsize = 12
if labels is None:
labels = ["$\\theta_{{{}}}$".format(i) for i in np.arange(1, data.shape[1] + 1)]
dataframe = pd.DataFrame(data, columns=labels)
axis_ranges = _parse_range(range, ndim)
xlim, ylim = axis_ranges[0], axis_ranges[1]
cmap = hexbin_kwargs.pop("cmap", transparent_cmap)
color = hexbin_kwargs.pop("color", [0, 0, 0, 0])
gridsize = hexbin_kwargs.pop("gridsize", gridsize)
# Set extent for hexbin
extent = None
if xlim is not None and ylim is not None:
extent = xlim + ylim
joint_kws = dict(cmap=cmap, color=color, gridsize=gridsize, **hexbin_kwargs)
if extent is not None:
joint_kws["extent"] = extent
marginal_kws = dict(bins=gridsize, fill=True, color=hist_color, **histplot_kwargs)
if xlim is not None:
marginal_kws["binrange"] = xlim
if ylim is not None:
marginal_kws["binrange"] = ylim
g = sns.jointplot(
data=dataframe,
x=labels[0],
y=labels[1],
kind="hex",
height=6,
gridsize=gridsize,
marginal_kws=marginal_kws,
joint_kws=joint_kws,
**kwargs,
)
if xlim is not None:
g.ax_joint.set_xlim(xlim)
g.ax_marg_x.set_xlim(xlim)
if ylim is not None:
g.ax_joint.set_ylim(ylim)
g.ax_marg_y.set_ylim(ylim)
# Set fontsize for axis labels
g.ax_joint.set_xlabel(labels[0], fontsize=fontsize)
g.ax_joint.set_ylabel(labels[1], fontsize=fontsize)
if plot_levels:
levels = np.sort(1 - np.array([0.6827, 0.9545]))
g.plot_joint(
sns.kdeplot,
color=hist_color,
zorder=3,
levels=levels,
alpha=1,
linewidths=1,
)
# Plot true_param if provided
if true_param is not None:
g.ax_joint.scatter(
true_param[0],
true_param[1],
color=true_val_color,
marker="s",
s=100,
zorder=10,
)
g.ax_joint.axvline(
true_param[0], color=true_val_color, linestyle="-", linewidth=1.5, zorder=5
)
g.ax_joint.axhline(
true_param[1], color=true_val_color, linestyle="-", linewidth=1.5, zorder=5
)
return g
[docs]
def _plot_marginals_nd(
data,
plot_levels=True,
labels=None,
gridsize=15,
range=None,
hexbin_kwargs={},
histplot_kwargs={},
true_param=None,
):
data = np.array(data)
if true_param is not None:
true_param = np.array(true_param)
ndim = data.shape[1]
fontsize = 12
if labels is None:
labels = ["$\\theta_{{{}}}$".format(i) for i in np.arange(1, data.shape[1] + 1)]
axis_ranges = _parse_range(range, ndim)
cmap = hexbin_kwargs.pop("cmap", transparent_cmap)
color = hexbin_kwargs.pop("color", [0, 0, 0, 0])
bins = histplot_kwargs.pop("bins", gridsize)
fill = histplot_kwargs.pop("fill", True)
color_hist = histplot_kwargs.pop("color", hist_color)
fig, axes = plt.subplots(ndim, ndim, figsize=(2.5 * ndim, 2.5 * ndim))
# Hide upper triangle and set all axes off by default
for i in np.arange(ndim):
for j in np.arange(ndim):
if i < j:
axes[i, j].set_visible(False)
else:
axes[i, j].set_visible(True)
# Hide x/y ticks and labels for non-border plots
if i != ndim - 1:
axes[i, j].set_xticklabels([])
axes[i, j].set_xlabel("")
if j != 0 and j != i:
axes[i, j].set_yticklabels([])
axes[i, j].set_ylabel("")
# Lower triangle: hexbin and kde
for i in np.arange(1, ndim):
for j in np.arange(i):
ax = axes[i, j]
x = data[:, j]
y = data[:, i]
extent = None
if axis_ranges[j] is not None and axis_ranges[i] is not None:
extent = axis_ranges[j] + axis_ranges[i]
ax.hexbin(
x,
y,
gridsize=gridsize,
cmap=cmap,
extent=extent,
color=color,
**hexbin_kwargs,
)
if axis_ranges[j] is not None:
ax.set_xlim(axis_ranges[j])
if axis_ranges[i] is not None:
ax.set_ylim(axis_ranges[i])
if plot_levels:
levels = np.sort(1 - np.array([0.6827, 0.9545]))
sns.kdeplot(
x=x,
y=y,
levels=levels,
color=hist_color,
zorder=3,
alpha=1,
linewidths=1,
ax=ax,
)
# Plot true_param if provided
if true_param is not None:
ax.scatter(
true_param[j],
true_param[i],
color=true_val_color,
marker="s",
s=50,
zorder=10,
label="True",
)
ax.axvline(
true_param[j],
color=true_val_color,
linestyle="-",
linewidth=1.5,
zorder=5,
)
ax.axhline(
true_param[i],
color=true_val_color,
linestyle="-",
linewidth=1.5,
zorder=5,
)
# Only set axis labels for border plots
if i == ndim - 1:
ax.set_xlabel(labels[j], fontsize=fontsize)
if j == 0:
ax.set_ylabel(labels[i], fontsize=fontsize)
# Diagonal: histograms
for i in np.arange(ndim):
ax = axes[i, i]
x = data[:, i]
binrange = axis_ranges[i] if axis_ranges[i] is not None else None
sns.histplot(
x,
bins=bins,
color=color_hist,
fill=fill,
binrange=binrange,
ax=ax,
stat="density",
**histplot_kwargs,
)
if true_param is not None:
ax.axvline(
true_param[i], color=true_val_color, linestyle="-", linewidth=1.5, zorder=5
)
if axis_ranges[i] is not None:
ax.set_xlim(axis_ranges[i])
ax.autoscale(enable=True, axis="y", tight=False)
# Only set y label for the top-left diagonal plot (theta_1)
if i == 0:
ax.set_ylabel(labels[i], fontsize=fontsize)
else:
ax.set_ylabel("")
# Only set x label for bottom-right diagonal plot
if i == ndim - 1:
ax.set_xlabel(labels[i], fontsize=14)
else:
ax.set_xlabel("")
plt.tight_layout()
return fig, axes
[docs]
def _plot_marginals_corner(
data,
labels=None,
gridsize=25,
range=None,
true_param=None,
**kwargs,
):
data = np.array(data)
if true_param is not None:
true_param = np.array(true_param)
if labels is None:
labels = ["$\\theta_{{{}}}$".format(i) for i in np.arange(1, data.shape[1] + 1)]
plt.clf()
corner(
data,
truths=true_param,
bins=gridsize,
labels=labels,
color=base_color, # points and 1D hist color
hist_kwargs={
"color": hist_color,
"edgecolor": "white",
"lw": 1,
"histtype": "barstacked",
},
truth_color=true_val_color,
contour_kwargs={"colors": hist_color, "linewidths": 1},
range=range,
**kwargs,
)
return plt.gcf(), plt.gca()
[docs]
def plot_marginals(
data,
backend="corner",
plot_levels=True,
labels=None,
gridsize=15,
hexbin_kwargs={},
histplot_kwargs={},
range=None,
true_param=None,
**kwargs,
):
"""
Plot marginal distributions of multidimensional data using either the 'corner' or 'seaborn' backend.
Parameters
----------
data : array-like, shape (n_samples, n_dim)
The data to plot. Each row is a sample, each column a parameter.
backend : str, default="corner"
Which plotting backend to use. Options:
- 'corner': Use the corner.py package for a classic corner plot.
- 'seaborn': Use seaborn's jointplot (2D) or custom grid (ND) for marginals.
The seaborn backend is slower, but will produce smoother plots with KDE contours.
plot_levels : bool, default=True
If True and using seaborn, plot 1- and 2-sigma KDE contours on off-diagonal plots. When using 'corner', levels are automatically computed.
labels : list of str or None, default=None
Axis labels for each parameter. If None, uses LaTeX-style $\theta_i$.
gridsize : int, default=15
Number of bins for hexbin/histogram (seaborn) or for corner plot.
hexbin_kwargs : dict, default={}
Additional keyword arguments for hexbin plots (seaborn backend only).
histplot_kwargs : dict, default={}
Additional keyword arguments for histogram plots (seaborn backend only).
range : tuple or list of tuples or None, default=None
Axis limits for each parameter, e.g. [(xmin, xmax), (ymin, ymax), ...].
true_param : array-like, shape (n_dim,), default=None
Ground truth parameter values to mark on the plots.
**kwargs :
Additional keyword arguments passed to the underlying plotting functions.
Returns
-------
fig, axes : matplotlib Figure and Axes objects
The figure and axes containing the plot.
Raises
------
ValueError
If an unknown backend is specified.
Notes
-----
- For 'corner', the function uses the corner.py package and supports labels, gridsize, range, and true_param.
- For 'seaborn', 2D data uses jointplot, higher dimensions use a custom grid of hexbin and histogram plots.
"""
if backend == "corner":
return _plot_marginals_corner(
data,
labels=labels,
gridsize=gridsize,
range=range,
true_param=true_param,
**kwargs,
)
elif backend == "seaborn":
if data.shape[1] == 2:
return _plot_marginals_2d(
data,
plot_levels=plot_levels,
labels=labels,
gridsize=gridsize,
hexbin_kwargs=hexbin_kwargs,
histplot_kwargs=histplot_kwargs,
range=range,
true_param=true_param,
**kwargs,
)
else:
return _plot_marginals_nd(
data,
plot_levels=plot_levels,
labels=labels,
gridsize=gridsize,
hexbin_kwargs=hexbin_kwargs,
histplot_kwargs=histplot_kwargs,
range=range,
true_param=true_param,
**kwargs,
)
else:
raise ValueError(f"Unknown backend: {backend}. Use 'corner' or 'seaborn'.")
# code to plot a 2D likelihood
[docs]
cmap_lcontour = sns.cubehelix_palette(
start=0.5, rot=-0.5, light=1.0, dark=0.2, as_cmap=True
)
[docs]
def plot_2d_levels(x, y, Z, ax, levels=[0.6827, 0.9545]):
"""
Plot 2D levels on a given axis.
Parameters
----------
x : array-like
X values.
y : array-like
Y values.
Z : array-like
Z values corresponding to (x, y).
ax : matplotlib Axes
The axes to plot on.
levels : list of float
The contour levels to plot.
"""
# --- 1. Prepare the data ---
x = np.asarray(x) # make sure we have numpy arrays
y = np.asarray(y) # make sure we have numpy arrays
Z = np.asarray(Z) # make sure we have numpy arrays
# --- 2. Define Desired Area Levels ---
# These are the fractions of the total volume you want to enclose.
# For a probability distribution, these are often confidence levels.
area_levels = levels
# --- 3. Calculate Contour Levels (Z-values) from Areas ---
# To find the z-values that enclose a certain area, we follow these steps:
# a. Flatten the 2D Z array into a 1D list of all values.
# b. Sort these values in descending order (from highest to lowest).
z_flat_sorted = np.sort(Z.ravel())[::-1]
# c. Calculate the cumulative sum of the sorted values. Each element in
# this array represents the sum of all preceding (higher) values.
z_cumsum = np.cumsum(z_flat_sorted)
# d. Normalize the cumulative sum by the total sum of all Z values.
# This converts the cumulative sum into a fraction of the total volume,
# ranging from 0 to 1.
z_cumsum_normalized = z_cumsum / z_cumsum[-1]
# e. Find the z-values that correspond to our desired area fractions.
# We use np.searchsorted to find the index where the normalized
# cumulative sum first exceeds our target area level.
indices = np.searchsorted(z_cumsum_normalized, area_levels)
z_levels = z_flat_sorted[indices]
# The levels must be sorted in ascending order for matplotlib's contour functions.
z_levels = np.sort(z_levels)
# --- 4. Plot the Results ---
# To create filled contours, we need to define the boundaries of each color.
# We start at 0, use our calculated z_levels, and end at the max value.
# contour_fill_levels = np.concatenate(([Z.min()], z_levels, [Z.max()]))
# a. Plot the filled contours (contourf).
# b. Plot the contour lines (contour) for clarity.
# These lines will clearly mark the boundaries of the enclosed areas.
cnt = ax.contour(x, y, Z, levels=z_levels, colors=hist_color, linewidths=1.5)
labels = {z: f"{int(a*100)}%" for z, a in zip(z_levels, np.flip(area_levels))}
ax.clabel(cnt, levels=z_levels, inline=True, fontsize=10, fmt=labels)
return
[docs]
def plot_2d_dist_contour(
x,
y,
Z,
true_param=None,
levels=[0.6827, 0.9545],
cmap=cmap_lcontour,
):
"""
Plot a 2D contour plot of a distribution.
Parameters
----------
x : array-like
X values.
y : array-like
Y values.
Z : array-like
Z values corresponding to (x, y).
levels : list or None, optional
Contour levels to plot. If None, contours will not be plotted.
Returns
-------
fig, ax : matplotlib Figure and Axes objects
The figure and axes containing the plot.
"""
fig, ax = plt.subplots(figsize=(8, 6))
x = np.asarray(x) # make sure we have numpy arrays
y = np.asarray(y) # make sure we have numpy arrays
Z = np.asarray(Z) # make sure we have numpy arrays
ax.contourf(x, y, Z, levels=20, cmap=cmap, vmin=0)
if levels is not None:
plot_2d_levels(x, y, Z, ax, levels=levels)
if true_param is not None:
ax.scatter(true_param[0], true_param[1], color=base_color, s=50, marker="s", zorder=10)
ax.axvline(true_param[0], color=base_color, linestyle='-', linewidth=1.5, zorder=9)
ax.axhline(true_param[1], color=base_color, linestyle='-', linewidth=1.5, zorder=9)
# Set aspect ratio to equal for better visualization
ax.set_aspect("equal", adjustable="box")
return fig, ax