import jax
from jax import numpy as jnp
from jax import jit
from jax import Array
from typing import Callable, Optional, Any
from einops import repeat
def edm_sampler(
    sde: Any,
    model: Callable,
    x_1: Array,
    *,
    key: Array,
    condition_mask: Optional[Array] = None,
    condition_value: Optional[Array] = None,
    return_intermediates: bool = False,
    n_steps: int = 18,
    S_churn: float = 0,
    S_min: float = 0,
    S_max: float = float("inf"),
    S_noise: float = 1,
    method: str = "Heun",
    model_kwargs: Optional[dict] = None,
) -> Array:
    """EDM sampler for diffusion models.

    Integrates from the first noise level returned by ``sde.timesteps`` down
    to zero. Each step may temporarily raise the noise level ("churn"), then
    takes an Euler step, optionally followed by a Heun (2nd-order) correction,
    in the style of Karras et al. (2022). The values from ``sde.timesteps``
    are used directly as noise levels ``sigma``.

    Args:
        sde: Scheduler object. Must provide
            ``timesteps(step_indices, n_steps)`` returning ``n_steps`` noise
            levels, and ``denoise(model, x, sigma, **model_kwargs)`` returning
            a denoised estimate of ``x`` (presumably with the same shape as
            ``x`` -- confirm against the scheduler implementation).
        model (Callable): Model function, forwarded to ``sde.denoise``.
        x_1 (Array): Initial noise; it is multiplied by the first noise level
            before sampling starts.
        key (Array): JAX random key used for the churn noise.
        condition_mask (Optional[Array]): Mask for conditioning. After every
            noise injection and update the state is replaced by
            ``x * (1 - condition_mask) + condition_value * condition_mask``.
        condition_value (Optional[Array]): Value for conditioning. Required
            when ``condition_mask`` is given.
        return_intermediates (bool): If True, return the state after every
            step, stacked along a new leading axis of length ``n_steps``.
        n_steps (int): Number of steps.
        S_churn (float): Churn strength; the per-step factor is
            ``gamma = min(S_churn / n_steps, sqrt(2) - 1)``.
        S_min (float): Churn is applied only when ``S_min <= sigma <= S_max``.
        S_max (float): Upper bound of the churn range (see ``S_min``).
        S_noise (float): Scale applied to the churn noise.
        method (str): Integration method ("Euler" or "Heun").
        model_kwargs (Optional[dict]): Extra keyword arguments forwarded to
            ``sde.denoise``. ``None`` means no extra arguments.

    Returns:
        Array: The final sample, or the stacked per-step states when
        ``return_intermediates`` is True.
    """
    assert method in ["Euler", "Heun"], f"Unknown method: {method}"
    if condition_mask is not None:
        assert (
            condition_value is not None
        ), "Condition value must be provided if condition mask is provided"
    else:
        # Neutral conditioning: x * (1 - 0) + 0 * 0 == x.
        condition_mask = 0
        condition_value = 0
    # A fresh dict per call avoids the shared mutable-default pitfall.
    if model_kwargs is None:
        model_kwargs = {}

    def apply_condition(x):
        # Overwrite (or blend, for fractional masks) the conditioned entries.
        return x * (1 - condition_mask) + condition_value * condition_mask

    # Time step discretization; a final noise level of 0 is appended so the
    # last step lands on the denoised estimate.
    t_steps = sde.timesteps(jnp.arange(n_steps), n_steps)
    t_steps = jnp.append(t_steps, 0)

    # Churn factor used inside [S_min, S_max]; loop-invariant, so computed once.
    gamma_in_range = jnp.minimum(S_churn / n_steps, jnp.sqrt(2.0) - 1.0)

    def one_step(carry, i):
        x_cur, rng = carry
        rng, subkey = jax.random.split(rng)
        t_cur = t_steps[i]
        t_next = t_steps[i + 1]

        # Increase noise temporarily.
        in_range = jnp.logical_and(t_cur >= S_min, t_cur <= S_max)
        gamma = jnp.where(in_range, gamma_in_range, 0.0)
        t_hat = t_cur + gamma * t_cur  # raised noise level for this step
        # Std of the extra noise that lifts the level from t_cur to t_hat;
        # clamped at 0 to guard against round-off.
        noise_std = jnp.sqrt(jnp.maximum(t_hat**2 - t_cur**2, 0.0))
        x_hat = x_cur + noise_std * S_noise * jax.random.normal(
            subkey, x_cur.shape
        )
        x_hat = apply_condition(x_hat)

        # Euler step.
        denoised = sde.denoise(model, x_hat, t_hat[..., None], **model_kwargs)
        d_cur = (x_hat - denoised) / t_hat
        x_euler = apply_condition(x_hat + (t_next - t_hat) * d_cur)

        if method == "Heun":
            # 2nd-order correction. Skipped on the last step, where
            # t_next == 0 would make the derivative divide by zero.
            def heun_correction():
                denoised_next = sde.denoise(
                    model, x_euler, t_next[..., None], **model_kwargs
                )
                d_next = (x_euler - denoised_next) / t_next
                x_corr = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_next)
                return apply_condition(x_corr)

            x_new = jax.lax.cond(i < n_steps - 1, heun_correction, lambda: x_euler)
        else:
            x_new = x_euler

        return (x_new, rng), (x_new if return_intermediates else None)

    x_init = x_1 * t_steps[0]
    (x_final, _), trajectory = jax.lax.scan(
        one_step, (x_init, key), jnp.arange(n_steps)
    )
    return trajectory if return_intermediates else x_final
def edm_ablation_sampler(
    sde: Any,
    model: Callable,
    x_1: Array,
    *,
    key: Array,
    condition_mask: Optional[Array] = None,
    condition_value: Optional[Array] = None,
    return_intermediates: bool = False,
    n_steps: int = 18,
    S_churn: float = 0,
    S_min: float = 0,
    S_max: float = float("inf"),
    S_noise: float = 1,
    method: str = "Heun",
    model_kwargs: Optional[dict] = None,
) -> Array:
    """Generalized EDM sampler with a noise schedule sigma(t) and scaling s(t).

    Mirrors the structure of the EDM reference ``ablation_sampler``
    (Karras et al., 2022) with alpha = 1. ``sde.timesteps`` yields times ``t``.
    These are mapped to noise levels via ``sde.sigma`` and to a state scaling
    via ``sde.s``. The state is advanced with the derivative

        (sigma'(t) / sigma(t) + s'(t) / s(t)) * x
            - sigma'(t) * s(t) / sigma(t) * D(x / s(t); sigma(t))

    where ``D`` is ``sde.denoise``. Each step may temporarily raise the noise
    level ("churn"), then takes an Euler step, optionally followed by a Heun
    correction.

    Args:
        sde: Scheduler object. Must provide ``timesteps(step_indices, n_steps)``,
            ``sigma``, ``sigma_inv``, ``sigma_deriv``, ``s``, ``s_deriv`` and
            ``denoise(model, x, sigma, **model_kwargs)``.
        model (Callable): Model function, forwarded to ``sde.denoise``.
        x_1 (Array): Initial noise; it is multiplied by
            ``sigma(t_0) * s(t_0)`` before sampling starts.
        key (Array): JAX random key used for the churn noise.
        condition_mask (Optional[Array]): Mask for conditioning. After every
            noise injection and update the state is replaced by
            ``x * (1 - condition_mask) + condition_value * condition_mask``.
            NOTE(review): this is applied to the scaled state, so for
            s(t) != 1 ``condition_value`` is not rescaled -- confirm intended.
        condition_value (Optional[Array]): Value for conditioning. Required
            when ``condition_mask`` is given.
        return_intermediates (bool): If True, return the state after every
            step, stacked along a new leading axis of length ``n_steps``.
        n_steps (int): Number of steps.
        S_churn (float): Churn strength; the per-step factor is
            ``gamma = min(S_churn / n_steps, sqrt(2) - 1)``.
        S_min (float): Churn is applied only when ``S_min <= t <= S_max``.
        S_max (float): Upper bound of the churn range (see ``S_min``).
        S_noise (float): Scale applied to the churn noise.
        method (str): Integration method ("Euler" or "Heun").
        model_kwargs (Optional[dict]): Extra keyword arguments forwarded to
            ``sde.denoise``. ``None`` means no extra arguments.

    Returns:
        Array: The final sample, or the stacked per-step states when
        ``return_intermediates`` is True.
    """
    assert method in ["Euler", "Heun"], f"Unknown method: {method}"
    if condition_mask is not None:
        assert (
            condition_value is not None
        ), "Condition value must be provided if condition mask is provided"
    else:
        # Neutral conditioning: x * (1 - 0) + 0 * 0 == x.
        condition_mask = 0
        condition_value = 0
    # A fresh dict per call avoids the shared mutable-default pitfall.
    if model_kwargs is None:
        model_kwargs = {}

    def apply_condition(x):
        # Overwrite (or blend, for fractional masks) the conditioned entries.
        return x * (1 - condition_mask) + condition_value * condition_mask

    def denoise_at(x, t):
        # The model sees the unscaled state x / s(t) at noise level sigma(t).
        return sde.denoise(
            model, x / sde.s(t), sde.sigma(t)[..., None], **model_kwargs
        )

    def derivative(x, t, denoised):
        # dx/dt for general sigma(t) and s(t); same expression at t_hat and t'.
        return (
            sde.sigma_deriv(t) / sde.sigma(t) + sde.s_deriv(t) / sde.s(t)
        ) * x - sde.sigma_deriv(t) * sde.s(t) / sde.sigma(t) * denoised

    # Time step discretization; a final time of 0 is appended.
    t_steps = sde.timesteps(jnp.arange(n_steps), n_steps)
    t_steps = jnp.append(t_steps, 0)

    # Churn factor used inside [S_min, S_max]; loop-invariant, so computed once.
    gamma_in_range = jnp.minimum(S_churn / n_steps, jnp.sqrt(2.0) - 1.0)

    def one_step(carry, i):
        x_cur, rng = carry
        rng, subkey = jax.random.split(rng)
        t_cur = t_steps[i]
        t_next = t_steps[i + 1]

        # Increase noise temporarily: raise sigma by a factor (1 + gamma) and
        # map it back to a time t_hat.
        in_range = jnp.logical_and(t_cur >= S_min, t_cur <= S_max)
        gamma = jnp.where(in_range, gamma_in_range, 0.0)
        t_hat = sde.sigma_inv(sde.sigma(t_cur) + gamma * sde.sigma(t_cur))
        # Variance of the extra noise; clamped at 0 to guard against round-off.
        noise_var = jnp.maximum(sde.sigma(t_hat) ** 2 - sde.sigma(t_cur) ** 2, 0.0)
        x_hat = sde.s(t_hat) / sde.s(t_cur) * x_cur + jnp.sqrt(noise_var) * sde.s(
            t_hat
        ) * S_noise * jax.random.normal(subkey, x_cur.shape)
        x_hat = apply_condition(x_hat)

        # Euler step.
        h = t_next - t_hat
        d_cur = derivative(x_hat, t_hat, denoise_at(x_hat, t_hat))
        x_prime = apply_condition(x_hat + h * d_cur)

        if method == "Heun":
            # 2nd-order correction at t' = t_next. Skipped on the last step,
            # where t_next == 0 and sigma(t_next) appears in a denominator.
            def heun_correction():
                d_prime = derivative(x_prime, t_next, denoise_at(x_prime, t_next))
                x_corr = x_hat + h * (0.5 * d_cur + 0.5 * d_prime)
                return apply_condition(x_corr)

            x_new = jax.lax.cond(i < n_steps - 1, heun_correction, lambda: x_prime)
        else:
            x_new = x_prime

        return (x_new, rng), (x_new if return_intermediates else None)

    x_init = x_1 * (sde.sigma(t_steps[0]) * sde.s(t_steps[0]))
    (x_final, _), trajectory = jax.lax.scan(
        one_step, (x_init, key), jnp.arange(n_steps)
    )
    return trajectory if return_intermediates else x_final