import torch
from kmeans import kmeans
from color import rgb_to_lab, lab_to_srgb


class BuildMosaic(torch.nn.Module):
    """
    PyTorch module that builds a mosaic image by replacing each pixel with the closest dominant color.

    This is a thin wrapper around `build_mosaic` that stores its configuration;
    it registers no parameters or buffers. See `build_mosaic` for details on
    the expected input and the produced output.

    Args:
        n_colors: Number of dominant colors to extract (default: 5).
        shuffle: Whether to shuffle the pixels in the mosaic (default: False).
    """

    def __init__(self, n_colors=5, shuffle=False):
        super().__init__()
        self.n_colors = n_colors
        self.shuffle = shuffle

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Apply `build_mosaic` to `image` using this module's configuration."""
        return build_mosaic(image, self.n_colors, self.shuffle)

    def extra_repr(self) -> str:
        # Included by nn.Module.__repr__, so printed models reveal the configuration.
        return f'n_colors={self.n_colors}, shuffle={self.shuffle}'


def build_mosaic(image: torch.Tensor, n_colors=5, shuffle=False) -> torch.Tensor:
    """
    Finds the most dominant colors in the input image and replaces each pixel with the closest dominant color.

    Dominant colors are found with k-means in Lab space (one set of colors per
    image in the batch), and each pixel is replaced by the color it is closest
    to (Euclidean distance in Lab space).

    Args:
        image: Tensor of shape (C, H, W) or (B, C, H, W) representing an image or a batch of images.
            C must be 3 (RGB). Integer tensors (e.g. uint8) are treated as having values
            in [0, 255]; floating-point tensors are treated as [0, 255] if any value
            exceeds 1.0, otherwise as [0, 1].
        n_colors: Number of dominant colors to extract (default: 5).
        shuffle: Whether to shuffle the pixels in the mosaic (default: False). The same
            pixel permutation is applied to every image of a batch.

    Returns:
        Tensor of shape (C, H, W) or (B, C, H, W) containing the mosaic image, with the
        same dtype and value range as the input.

    Raises:
        AssertionError: If `image` does not have 3 or 4 dimensions or does not have 3 channels.
    """
    assert image.ndim in [3, 4], 'Input should have shape (C, H, W) or (B, C, H, W)'
    assert image.shape[-3] == 3, 'Input should have 3 channels (RGB)'

    # Work in float32 regardless of the input dtype; the dtype is restored at the end.
    dtype = image.dtype
    image = image.to(torch.float32)

    # Treat a single image as a batch of one.
    is_single = image.ndim == 3
    if is_single:
        image = image.unsqueeze(0)

    # Flatten each image into a list of pixels: (B, H*W, C).
    # An explicit H*W (instead of -1) keeps the reshape well-defined for empty images.
    B, C, H, W = image.shape
    pixels = image.permute(0, 2, 3, 1).reshape(B, H * W, C)

    # Integer images are always in [0, 255]; for floats fall back to a range heuristic.
    not_scaled = not dtype.is_floating_point or bool(pixels.max() > 1.0)
    max_value = 255.0 if not_scaled else 1.0

    # Normalize to [0, 1] before the color conversion. NOTE(review): rgb_to_lab is
    # assumed to expect [0, 1] input, mirroring lab_to_srgb whose output is rescaled
    # by max_value below; dividing is harmless if rgb_to_lab also rescales itself.
    pixels = rgb_to_lab(pixels / max_value)

    # Run k-means to get the dominant colors of each image: (B, K, C) in Lab space.
    colors = kmeans(pixels, n_colors, runs=5, tol=1e-4)

    # Assign each pixel to the closest dominant color: (B, H*W) indices into K.
    labels = assign_labels(pixels, colors)

    # Replace each pixel with its assigned color (per-image gather).
    batch_idx = torch.arange(B, device=pixels.device).unsqueeze(-1).expand(B, H * W)
    mosaic = colors[batch_idx, labels]

    # Convert back to sRGB. Centroids computed in Lab space need not map exactly
    # into [0, 1] sRGB, so clamp before rescaling; this also prevents out-of-range
    # values from wrapping around when casting back to an integer dtype.
    mosaic = lab_to_srgb(mosaic).clamp(0.0, 1.0) * max_value

    # Shuffle pixels if requested (one permutation shared across the batch).
    if shuffle:
        permutation = torch.randperm(H * W, device=mosaic.device)
        mosaic = mosaic[:, permutation]

    # Restore the (B, C, H, W) layout.
    mosaic = mosaic.reshape(B, H, W, C).permute(0, 3, 1, 2)

    if is_single:
        mosaic = mosaic.squeeze(0)

    # Round rather than truncate when converting back to integer pixel values.
    if not dtype.is_floating_point:
        mosaic = mosaic.round()

    return mosaic.to(dtype)



def assign_labels(data: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
    """
    Assigns each point to the closest centroid (Euclidean distance).

    Args:
        data: Tensor of shape (N, D) or (B, N, D) containing data points.
        centroids: Tensor of shape (K, D) or (B, K, D) representing cluster centroids.
            Must have the same number of dimensions as `data` (and the same batch
            size B when batched).

    Returns:
        labels: A tensor of shape (N,) or (B, N) containing the index of the closest centroid for each point.
            When several centroids are equally close, the choice follows `torch.argmin`.

    Raises:
        AssertionError: If the shapes of `data` and `centroids` are incompatible.
    """
    assert data.ndim in [2, 3], "Data should have shape (N, D) or (B, N, D)"
    assert data.ndim == centroids.ndim, "Data and centroids should have the same number of dimensions"
    assert data.shape[-1] == centroids.shape[-1], "Data and centroids should have the same feature dimension (D)"
    assert data.ndim != 3 or data.shape[0] == centroids.shape[0], "Batch size should match"

    # Pairwise Euclidean distances: (N, K) or (B, N, K).
    distances = torch.cdist(data, centroids)

    # Index of the nearest centroid for every point: (N,) or (B, N).
    return distances.argmin(dim=-1)


