import torch
import torch.nn.functional as F



def kmeans(data: torch.Tensor, k: int, runs: int = 20, tol: float = 1e-5) -> torch.Tensor:
    """
    Run k-means clustering multiple times and return the best result.

    Every run is an independent call to `_kmeans_batched`. For each batch
    element, the centroids of the run with the lowest distortion are kept.

    Args:
        data: Tensor of shape (N,) or (N, D) or (B, N, D), where
            - N: Number of points.
            - D: Dimensionality.
            - B: Number of batches.
        k: Number of clusters.
        runs: Number of times to run k-means. Must be at least 1.
        tol: Convergence threshold forwarded to every run.

    Returns:
        centroids: Tensor of shape (K,) or (K, D) or (B, K, D) representing
            cluster centroids. The layout mirrors the layout of `data`.

    Raises:
        ValueError: If `runs` is smaller than 1, or if `data` is not a
            1-, 2- or 3-dimensional tensor.
    """
    # Without at least one run there is no result to return.
    if runs < 1:
        raise ValueError(f"runs={runs} must be at least 1.")

    input_ndim = data.ndim
    if input_ndim not in (1, 2, 3):
        raise ValueError(
            f"data must have shape (N,), (N, D) or (B, N, D), got {input_ndim} dimension(s)."
        )

    # Ensure data has shape (B, N, D)
    if input_ndim == 2:
        data = data.unsqueeze(0)                # Convert (N, D) -> (1, N, D)
    elif input_ndim == 1:
        data = data.unsqueeze(1).unsqueeze(0)   # Convert (N,) -> (1, N, 1)

    # Get batch size
    B = data.shape[0]

    best_codes = None
    best_distort = None

    # Run k-means multiple times
    for _ in range(runs):

        # Run k-means
        codes, distort = _kmeans_batched(data, k, tol)

        if best_codes is None:
            # First run: adopt its centroids unconditionally and create the
            # distortion tracker with the same dtype/device as the distortion
            # actually returned, so no dtype mismatch can occur later on.
            best_codes = codes
            best_distort = torch.full_like(distort, float('inf'))

        # Per batch element, keep whichever run has the lower distortion
        better_mask = distort < best_distort
        best_distort = torch.where(better_mask, distort, best_distort)
        best_codes = torch.where(better_mask.view(B, 1, 1), codes, best_codes)

    # Undo the dimensions that were added to the input
    if input_ndim == 2:
        best_codes = best_codes.squeeze(0)              # (1, K, D) -> (K, D)
    elif input_ndim == 1:
        best_codes = best_codes.squeeze(0).squeeze(1)   # (1, K, 1) -> (K,)

    return best_codes



def _kmeans_batched(data: torch.Tensor, k: int, tol: float = 1e-5) -> torch.Tensor:
    """"
    Perform batched k-means clustering.

    Args:
        data: Tensor of shape (B, N, D), where
            - B: Number of batches.
            - N: Number of points.
            - D: Dimensionality.
        k: Number of clusters.
        tol: Convergence threshold.

    Returns:
        (centroids, distort): 
            - centroids: Tensor of shape (B, N, D) representing cluster centroids.
            - distort: Tensor of shape (B,) containing the distortion value.
    """
    assert data.ndim == 3, "Input data must have shape (B, N, D)."

    # Extract dimensions
    B, N, _ = data.shape

    # Check if k is valid
    if k > N:
        raise ValueError(f"k={k} cannot be greater than the number of points N={N}.")

    # Randomly select K initial centroids for each run
    weights = torch.ones((B, N), device=data.device)
    idx = torch.multinomial(weights, k, replacement=False)  # Shape: (B, K)

    # Select initial centroids
    centroids = gather_points(data, idx)  # Shape: (B, K, D)

    # Initialize convergence error
    error = distort = torch.tensor(float('inf'), device=data.device)

    # Iterate until convergence
    while torch.any(error > tol):

        # Compute distances to centroids
        distances = torch.cdist(data, centroids)  # Shape: (B, N, K)

        # Assign each point to the closest centroid
        labels = torch.argmin(distances, dim=-1)  # Shape: (B, N) - Values: [0, K)

        # One-hot encoding
        one_hot = F.one_hot(labels, num_classes=k).float().permute(0, 2, 1)  # Shape: (B, K, N)

        # Compute new centroids
        sum_one_hot = one_hot.sum(dim=-1, keepdim=True)  # Shape: (B, K, 1)
        centroids = torch.bmm(one_hot, data) / torch.clamp(sum_one_hot, min=1)  # Shape: (B, K, D)

        # Compute distortion
        previous = distort
        vq = gather_points(centroids, labels)                 # Shape: (B, N, D)
        distort = torch.norm(vq - data, dim=-1).mean(dim=-1)  # Shape: (B,)
        error = torch.abs(previous - distort)
        
    return centroids, distort


def gather_points(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    Select entries of `points` along dimension 1, separately for each batch.

    Args:
        points: Tensor of shape (B, N, D) holding the source points.
        idx: Tensor of shape (B, K) holding, per batch, the positions along
            dimension 1 that should be picked.

    Returns:
        Tensor of shape (B, K, D) whose entry [b, j] is points[b, idx[b, j]].
    """
    batch_size, num_selected = idx.shape
    feature_dim = points.size(-1)

    # torch.gather expects one index per output element, so repeat each
    # selected position across the feature dimension.
    expanded_idx = idx[:, :, None].expand(batch_size, num_selected, feature_dim)
    return points.gather(1, expanded_idx)




#------------------------#
#------ Unit Tests ------#
#------------------------#

if __name__ == "__main__":

    # Smoke tests: each entry pairs an input shape with the centroid shape
    # that kmeans(..., k=5) is expected to return for it.
    shape_cases = (
        ((10, 1000, 5), (10, 5, 5)),  # 3D input (B, N, D): 10 batches, 1000 points, 5 dimensions
        ((1000, 5), (5, 5)),          # 2D input (N, D): 1000 points, 5 dimensions
        ((1000,), (5,)),              # 1D input (N,): 1000 points
    )

    for input_shape, expected_shape in shape_cases:
        centroids = kmeans(torch.randn(*input_shape), k=5)
        assert centroids.shape == expected_shape

    print("All tests passed!")