import torch


def rgb_to_lab(x: torch.Tensor, srgb: bool = True) -> torch.Tensor:
    """
    Converts RGB to CIE Lab color space (D65 reference white).

    Args:
        x: Tensor of shape (..., 3) with R, G, B in the last dimension.
           Non-floating-point tensors, and floating-point tensors whose max
           exceeds 1.0, are rescaled from [0, 255] to [0, 1] by ``normalize``.
        srgb: if True, ``x`` is sRGB gamma-encoded and is linearized first;
              otherwise ``x`` is treated as linear RGB.

    Returns:
        Tensor of shape (..., 3) holding L, a, b. Black maps to L = 0 and
        RGB (1, 1, 1) to approximately L = 100; no clamping is applied.
    """
    assert x.shape[-1] == 3, 'Input should have 3 channels (RGB)'

    # [0..255] → [0..1]
    x = normalize(x)

    # sRGB → linear RGB
    if srgb:
        # torch.where back-propagates through BOTH branches. The power branch is
        # therefore evaluated on a clamped copy, so values it does not select
        # can never produce a NaN/inf derivative that poisons the gradient.
        # Forward values are unchanged: the clamp is a no-op wherever mask is True.
        mask = x > 0.04045
        decoded = ((x.clamp(min=0.04045) + 0.055) / 1.055) ** 2.4
        linear = torch.where(mask, decoded, x / 12.92)
    else:
        linear = x

    # linear RGB → XYZ (D65)
    M = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
                      [0.2126729, 0.7151522, 0.0721750],
                      [0.0193339, 0.1191920, 0.9503041]], dtype=x.dtype, device=x.device)
    # (...,3) @ (3,3).T → (...,3)
    xyz = linear @ M.T

    # normalize by D65 white point
    white = torch.tensor([0.95047, 1.00000, 1.08883], dtype=x.dtype, device=x.device)
    xyz_scaled = xyz / white

    def f(t: torch.Tensor) -> torch.Tensor:
        # Lab companding: cube root above epsilon ≈ 0.008856, linear segment
        # t * (k/116) + 16/116 below it (k/116 = 7.787037 where k ≈ 903.3).
        # The cube root is taken on a clamped copy because d/dt t^(1/3) is
        # infinite at t = 0 (black pixels) and NaN for t < 0. torch.where would
        # otherwise turn that into NaN gradients even though the branch is unused.
        mask = t > 0.008856
        cube_root = t.clamp(min=0.008856) ** (1 / 3)
        return torch.where(mask, cube_root, t * 7.787037 + 4 / 29)

    fX = f(xyz_scaled[..., 0])
    fY = f(xyz_scaled[..., 1])
    fZ = f(xyz_scaled[..., 2])

    L = (116 * fY) - 16
    a = 500 * (fX - fY)
    b = 200 * (fY - fZ)

    return torch.stack([L, a, b], dim=-1)



def lab_to_srgb(x: torch.Tensor, srgb: bool = True) -> torch.Tensor:
    """
    Converts CIE Lab color space to sRGB color space (in [0,1]).

    Args:
        x: Floating-point tensor of shape (..., 3) representing L, a, b.
           The computation runs in x's dtype (float32 as before; float64 etc.
           are also accepted). L is expected in [0, 100], a and b roughly in
           [-128, +127], but the input is not clamped.
        srgb: if True, apply sRGB gamma; otherwise output linear RGB.

    Returns:
        Tensor of shape (..., 3), same dtype as x, clamped to [0,1].
    """
    assert x.shape[-1] == 3, 'Input should have 3 channels (Lab)'
    assert torch.is_floating_point(x), 'Input should be a floating-point tensor'

    # Lab → f‑space
    L, a, b = x.unbind(-1)
    fY = (L + 16.0) / 116.0
    fX = a / 500.0 + fY
    fZ = fY - b / 200.0

    # inverse f(t)
    def finv(f: torch.Tensor) -> torch.Tensor:
        # Cube above epsilon ≈ 0.008856, inverse of the linear segment below it.
        # f ** 3 is smooth everywhere, so both branches have finite gradients.
        t1 = f ** 3
        t2 = (f - 4.0 / 29.0) / 7.787037  # k/116 = 7.787037 where k ≈ 903.3
        return torch.where(t1 > 0.008856, t1, t2)

    X = finv(fX)
    Y = finv(fY)
    Z = finv(fZ)

    # scale back by D65 white point
    white = torch.tensor([0.95047, 1.00000, 1.08883], dtype=x.dtype, device=x.device)
    xyz = torch.stack([X, Y, Z], dim=-1) * white

    # XYZ → linear RGB
    M_inv = torch.tensor([[ 3.2404542, -1.5371385, -0.4985314],
                          [-0.9692660,  1.8760108,  0.0415560],
                          [ 0.0556434, -0.2040259,  1.0572252]], dtype=x.dtype, device=x.device)
    rgb_linear = xyz @ M_inv.T

    # linear RGB → sRGB
    if srgb:
        # Out-of-gamut Lab values yield negative linear RGB. torch.where
        # back-propagates through the unselected pow branch too, and
        # pow(negative or zero, 1/2.4) has a NaN/inf derivative, so the base is
        # clamped. This is a no-op wherever mask is True, so forward values
        # are unchanged.
        mask = rgb_linear > 0.0031308
        encoded = 1.055 * torch.pow(rgb_linear.clamp(min=0.0031308), 1.0 / 2.4) - 0.055
        rgb = torch.where(mask, encoded, 12.92 * rgb_linear)
    else:
        rgb = rgb_linear

    # clamp to [0,1]
    return torch.clamp(rgb, 0.0, 1.0)


def normalize(x: torch.Tensor) -> torch.Tensor:
    """
    Scales an image tensor to the [0, 1] range.

    Non-floating-point tensors (e.g. uint8) are cast to float32 and divided by
    255. Floating-point tensors whose maximum exceeds 1.0 are assumed to be in
    [0, 255] and divided by 255. Other floating-point tensors, including empty
    ones, are returned unchanged.
    """
    # integer/bool → float32 in [0..1]
    if not torch.is_floating_point(x):
        return x.to(torch.float32).div(255.0)

    # Heuristic: a float tensor with any value above 1.0 is taken to be in
    # [0..255]. The numel() guard is needed because x.max() raises on an empty
    # tensor.
    if x.numel() > 0 and x.max() > 1.0:
        return x / 255.0

    return x