import torch
from torch.nn import Module, Sequential, Identity, BatchNorm2d
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights


def make_trainable(model: Module, grad: bool):
    """Set the requires_grad attribute of all parameters in a module"""
    for param in model.parameters():
        param.requires_grad = grad

def unfreeze_layers(model: Sequential, count: int):
    """Unfreeze the last `count` layers of a Sequential model"""
    assert 0 <= count <= len(model)
    if count == 0:
        return  # model[-0:] would select all layers, not none
    for layer in model[-count:]:
        make_trainable(layer, True)

def freeze_by_type(model: Module, layer_type: type[Module] | tuple[type[Module], ...]):
    """Freeze the modules of a certain type"""
    for layer in model.modules():
        if isinstance(layer, layer_type):
            make_trainable(layer, False)

def set_eval_mode(model: Module, layer_type: type[Module]):
    """Set the modules of a certain type to evaluation mode"""
    for layer in model.modules():
        if isinstance(layer, layer_type):
            layer.eval()
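
# A minimal usage sketch (illustrative, not part of the library API) showing
# how the helpers above compose for transfer learning; the toy Sequential
# model below is a hypothetical stand-in.
def _demo_freezing():
    from torch.nn import Conv2d, ReLU
    feature_extractor = Sequential(
        Conv2d(3, 16, 3), BatchNorm2d(16), ReLU(),
        Conv2d(16, 32, 3), BatchNorm2d(32), ReLU(),
    )
    make_trainable(feature_extractor, False)        # freeze everything
    unfreeze_layers(feature_extractor, 3)           # fine-tune only the last 3 layers
    freeze_by_type(feature_extractor, BatchNorm2d)  # keep BatchNorm weights frozen
    set_eval_mode(feature_extractor, BatchNorm2d)   # and their running stats fixed
    return feature_extractor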



#-----------------------------#
#----- Pretrained models -----#
#-----------------------------#

class MobileNet(Module):

    def __init__(self, head: Module | None = None):
        super().__init__()
        weights = MobileNet_V3_Large_Weights.DEFAULT
        mobilenet = mobilenet_v3_large(weights=weights)
        self.backbone = mobilenet.features
        # Call the factory so `self.transforms` is the preprocessing pipeline itself
        self.transforms = weights.transforms()
        self.head_input_shape = 960, 7, 7  # backbone output shape for a 224x224 input
        # Avoid a shared mutable default argument for the head module
        self.head = head if head is not None else Identity()

    def forward(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x

    def train(self, mode: bool = True):
        """Set the model to training mode, except for BatchNorm layers"""
        super().train(mode)
        set_eval_mode(self.backbone, BatchNorm2d)
        return self  # match the return convention of nn.Module.train

    def freeze(self):
        """Freeze the backbone of the model"""
        make_trainable(self.backbone, False)

    def unfreeze(self, count: int):
        """Unfreeze the last layers of the backbone, except for BatchNorm layers"""
        unfreeze_layers(self.backbone, count)
        freeze_by_type(self.backbone, BatchNorm2d)



#----------------------------#
#----- Model Inspection -----#
#----------------------------#
# (If you don't want to install torchinfo)

def summary(model: torch.nn.Module):
    """
     - Print a summary of the model architecture at the first two levels of nesting.
     - Count the number of parameters at each level.
    """
    
    # Print the model name
    text = type(model).__name__
    print(text)
    print('-' * len(text))
    
    # Loop through the first level of modules
    for name, module in model.named_children():

        # Print the module name and number of parameters
        count = count_parameters(module)
        print(f"\n{name} ({type(module).__name__}): {count}")
        
        # Loop through the second level of modules
        for subname, submodule in module.named_children():
            # Print the submodule name, type, and number of parameters
            count = count_parameters(submodule)
            print(f"  - {subname} ({type(submodule).__name__}): {count}")


def count_parameters(model: torch.nn.Module, trainable: bool = True):
    """Count the number of trainable or non-trainable parameters in the given model."""
    # Sum tensor sizes rather than counting tensors, so the result is a parameter count
    return sum(param.numel() for param in model.parameters() if param.requires_grad == trainable)
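

# Example (illustrative): inspecting the model defined above. `summary` prints
# the class name, each top-level child with its trainable-parameter count, and
# that child's immediate children.
def _demo_summary():
    model = MobileNet()
    model.freeze()
    summary(model)  # counts drop to zero for the frozen backbone
    print(f"Frozen parameters: {count_parameters(model, trainable=False)}")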