import jax
from jax import numpy as jnp
from jax import jit
from jax import Array
from typing import Callable, Optional, Any
# [docs]
def edm_sampler(
    sde: Any,
    model: Callable,
    x_1: Array,
    *,
    key: Array,
    condition_mask: Optional[Array] = None,
    condition_value: Optional[Array] = None,
    return_intermediates: bool = False,
    n_steps: int = 18,
    S_churn: float = 0,
    S_min: float = 0,
    S_max: float = float('inf'),
    S_noise: float = 1,
    method: str = "Heun",
    model_kwargs: Optional[dict] = None,
) -> Array:
    """
    EDM sampler for diffusion models.

    The state is initialised as ``x_1 * t_steps[0]`` and then integrated with
    ``jax.lax.scan`` through the levels returned by
    ``sde.timesteps(jnp.arange(n_steps), n_steps)``, followed by a final
    level of 0. Each step optionally injects extra noise ("churn"), takes an
    Euler step and, for ``method="Heun"``, applies a second-order correction
    on every step except the last one (whose target level is 0).

    Args:
        sde: SDE scheduler object. Must provide
            ``timesteps(step_indices, n_steps)`` and
            ``denoise(model, x, t, **model_kwargs)``.
        model (Callable): Model function, forwarded to ``sde.denoise``.
        x_1 (Array): Initial value. It is multiplied by the first time step
            before sampling starts. Its leading axis is used as the batch
            axis when broadcasting the time to shape ``(x.shape[0], 1)``.
        key (Array): JAX random key used for the churn noise.
        condition_mask (Optional[Array]): Mask for conditioning. Where the
            mask is 1 the state is overwritten by ``condition_value``.
        condition_value (Optional[Array]): Value for conditioning. Required
            when ``condition_mask`` is given.
        return_intermediates (bool): Whether to return the state after every
            step instead of only the final state.
        n_steps (int): Number of steps.
        S_churn (float): Churn parameter. The per-step noise increase is
            ``min(S_churn / n_steps, sqrt(2) - 1)``.
        S_min (float): Lower bound of the time range in which churn is applied.
        S_max (float): Upper bound of the time range in which churn is applied.
        S_noise (float): Scale applied to the injected churn noise.
        method (str): Integration method ("Euler" or "Heun").
        model_kwargs (Optional[dict]): Additional keyword arguments forwarded
            to ``sde.denoise``. ``None`` means no extra arguments.

    Returns:
        Array: The final state, or, if ``return_intermediates`` is True, the
        states after each step stacked along a new leading axis of length
        ``n_steps``.

    Raises:
        AssertionError: If ``method`` is unknown, or ``condition_mask`` is
            given without ``condition_value``.
    """
    assert method in ["Euler", "Heun"], f"Unknown method: {method}"
    if condition_mask is not None:
        assert (
            condition_value is not None
        ), "Condition value must be provided if condition mask is provided"
    else:
        # A zero mask turns the conditioning expression into the identity.
        condition_mask = 0
        condition_value = 0
    if model_kwargs is None:
        model_kwargs = {}

    # Time step discretization; the appended 0 is the target of the last step.
    step_indices = jnp.arange(n_steps)
    t_steps = sde.timesteps(step_indices, n_steps)
    t_steps = jnp.append(t_steps, 0)

    def apply_condition(x):
        # Overwrite the masked entries with the conditioning value.
        return x * (1 - condition_mask) + condition_value * condition_mask

    def denoise_at(x, t):
        # The time is passed with one entry per leading-axis element of x.
        t_batched = jnp.broadcast_to(t, (x.shape[0], 1))
        return sde.denoise(model, x, t_batched, **model_kwargs)

    def one_step(carry, i):
        x_curr, step_key = carry
        step_key, noise_key = jax.random.split(step_key)
        t_cur = t_steps[i]
        t_next = t_steps[i + 1]

        # Increase noise temporarily (only inside the [S_min, S_max] window).
        in_range = jnp.logical_and(t_cur >= S_min, t_cur <= S_max)
        gamma = jax.lax.cond(
            in_range,
            lambda: jnp.minimum(S_churn / n_steps, jnp.sqrt(2) - 1),
            lambda: 0.0,
        )
        t_hat = t_cur + gamma * t_cur
        # Floor at 0 so rounding can never feed a negative value to sqrt.
        # jnp.maximum is used instead of jnp.clip(a_min=..., a_max=None)
        # because those keyword names are deprecated in recent JAX releases.
        sqrt_arg = jnp.maximum(t_hat ** 2 - t_cur ** 2, 0)
        noise = jax.random.normal(noise_key, x_curr.shape)
        x_hat = x_curr + jnp.sqrt(sqrt_arg) * S_noise * noise
        x_hat = apply_condition(x_hat)

        # Euler step.
        d_cur = (x_hat - denoise_at(x_hat, t_hat)) / t_hat
        x_euler = apply_condition(x_hat + (t_next - t_hat) * d_cur)

        if method == "Heun":

            def apply_2nd_order_correction():
                d_prime = (x_euler - denoise_at(x_euler, t_next)) / t_next
                x_heun = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
                return apply_condition(x_heun)

            # The correction divides by t_next, which is 0 on the last step,
            # so the last step keeps the plain Euler result.
            x_out = jax.lax.cond(
                i < (n_steps - 1), apply_2nd_order_correction, lambda: x_euler
            )
        else:
            x_out = x_euler

        if return_intermediates:
            return (x_out, step_key), x_out
        return (x_out, step_key), ()

    # Main sampling loop.
    x_init = x_1 * t_steps[0]
    carry, x_scan = jax.lax.scan(one_step, (x_init, key), jnp.arange(n_steps))
    if return_intermediates:
        return x_scan
    return carry[0]
# [docs]
def edm_ablation_sampler(
    sde: Any,
    model: Callable,
    x_1: Array,
    *,
    key: Array,
    condition_mask: Optional[Array] = None,
    condition_value: Optional[Array] = None,
    return_intermediates: bool = False,
    n_steps: int = 18,
    S_churn: float = 0,
    S_min: float = 0,
    S_max: float = float('inf'),
    S_noise: float = 1,
    method: str = "Heun",
    model_kwargs: Optional[dict] = None,
) -> Array:
    """
    EDM-style sampler driven by an explicit noise schedule and scaling.

    Unlike a sampler that uses the time directly as the noise level, this
    variant queries the scheduler for the noise level ``sde.sigma(t)``, the
    scaling ``sde.s(t)`` and their derivatives. The state is initialised as
    ``x_1 * (sde.sigma(t0) * sde.s(t0))`` with ``t0`` the first time step,
    and is then integrated with ``jax.lax.scan`` through the time steps
    returned by ``sde.timesteps``, followed by a final time of 0.

    Args:
        sde: Scheduler object. Must provide ``timesteps(step_indices,
            n_steps)``, ``sigma(t)``, ``sigma_inv(sigma)``, ``sigma_deriv(t)``,
            ``s(t)``, ``s_deriv(t)`` and
            ``denoise(model, x, sigma, **model_kwargs)``.
        model (Callable): Model function, forwarded to ``sde.denoise``.
        x_1 (Array): Initial value, scaled by ``sigma(t0) * s(t0)`` before
            sampling starts. Its leading axis is used as the batch axis when
            broadcasting the noise level to shape ``(x.shape[0], 1)``.
        key (Array): JAX random key used for the churn noise.
        condition_mask (Optional[Array]): Mask for conditioning. Where the
            mask is 1 the state is overwritten by ``condition_value``.
        condition_value (Optional[Array]): Value for conditioning. Required
            when ``condition_mask`` is given.
        return_intermediates (bool): Whether to return the state after every
            step instead of only the final state.
        n_steps (int): Number of steps.
        S_churn (float): Churn parameter. The per-step relative increase of
            the noise level is ``min(S_churn / n_steps, sqrt(2) - 1)``.
        S_min (float): Lower bound of the range in which churn is applied.
        S_max (float): Upper bound of the range in which churn is applied.
        S_noise (float): Scale applied to the injected churn noise.
        method (str): Integration method ("Euler" or "Heun").
        model_kwargs (Optional[dict]): Additional keyword arguments forwarded
            to ``sde.denoise``. ``None`` means no extra arguments.

    Returns:
        Array: The final state, or, if ``return_intermediates`` is True, the
        states after each step stacked along a new leading axis of length
        ``n_steps``.

    Raises:
        AssertionError: If ``method`` is unknown, or ``condition_mask`` is
            given without ``condition_value``.
    """
    assert method in ["Euler", "Heun"], f"Unknown method: {method}"
    if condition_mask is not None:
        assert (
            condition_value is not None
        ), "Condition value must be provided if condition mask is provided"
    else:
        # A zero mask turns the conditioning expression into the identity.
        condition_mask = 0
        condition_value = 0
    if model_kwargs is None:
        model_kwargs = {}

    # Time step discretization; the appended 0 is the target of the last step.
    step_indices = jnp.arange(n_steps)
    t_steps = sde.timesteps(step_indices, n_steps)
    t_steps = jnp.append(t_steps, 0)

    def apply_condition(x):
        # Overwrite the masked entries with the conditioning value.
        return x * (1 - condition_mask) + condition_value * condition_mask

    def denoise_at(x, t):
        # The denoiser receives the unscaled state and the noise level, with
        # one noise-level entry per leading-axis element of x.
        sigma_batched = jnp.broadcast_to(sde.sigma(t), (x.shape[0], 1))
        return sde.denoise(model, x / sde.s(t), sigma_batched, **model_kwargs)

    def drift(x, t, denoised):
        # dx/dt at (x, t) given the denoiser output.
        sigma_t = sde.sigma(t)
        s_t = sde.s(t)
        sigma_dot = sde.sigma_deriv(t)
        return (
            (sigma_dot / sigma_t + sde.s_deriv(t) / s_t) * x
            - sigma_dot * s_t / sigma_t * denoised
        )

    def one_step(carry, i):
        x_curr, step_key = carry
        step_key, noise_key = jax.random.split(step_key)
        t_cur = t_steps[i]
        t_target = t_steps[i + 1]

        # Increase noise temporarily.
        # NOTE(review): the window test compares the time t_cur, not
        # sde.sigma(t_cur), against S_min/S_max; the two coincide only when
        # sigma(t) == t. Confirm which one is intended.
        in_range = jnp.logical_and(t_cur >= S_min, t_cur <= S_max)
        gamma = jax.lax.cond(
            in_range,
            lambda: jnp.minimum(S_churn / n_steps, jnp.sqrt(2) - 1),
            lambda: 0.0,
        )
        sigma_cur = sde.sigma(t_cur)
        t_hat = sde.sigma_inv(sigma_cur + gamma * sigma_cur)
        s_hat = sde.s(t_hat)
        # Floor at 0 so rounding can never feed a negative value to sqrt.
        # jnp.maximum is used instead of jnp.clip(a_min=..., a_max=None)
        # because those keyword names are deprecated in recent JAX releases.
        sqrt_arg = jnp.maximum(sde.sigma(t_hat) ** 2 - sigma_cur ** 2, 0)
        noise = jax.random.normal(noise_key, x_curr.shape)
        x_hat = (
            s_hat / sde.s(t_cur) * x_curr
            + jnp.sqrt(sqrt_arg) * s_hat * S_noise * noise
        )
        x_hat = apply_condition(x_hat)

        # Euler step.
        h = t_target - t_hat
        d_cur = drift(x_hat, t_hat, denoise_at(x_hat, t_hat))
        x_prime = apply_condition(x_hat + h * d_cur)
        t_prime = t_target

        if method == "Heun":

            def apply_2nd_order_correction():
                d_prime = drift(x_prime, t_prime, denoise_at(x_prime, t_prime))
                x_heun = x_hat + h * (0.5 * d_cur + 0.5 * d_prime)
                return apply_condition(x_heun)

            # The correction is skipped on the last step, whose target time
            # is the appended 0; that step keeps the plain Euler result.
            x_out = jax.lax.cond(
                i < (n_steps - 1), apply_2nd_order_correction, lambda: x_prime
            )
        else:
            x_out = x_prime

        if return_intermediates:
            return (x_out, step_key), x_out
        return (x_out, step_key), ()

    # Main sampling loop.
    t_start = t_steps[0]
    x_init = x_1 * (sde.sigma(t_start) * sde.s(t_start))
    carry, x_scan = jax.lax.scan(one_step, (x_init, key), jnp.arange(n_steps))
    if return_intermediates:
        return x_scan
    return carry[0]