from typing import Any

import torch
from torch import device

from const import EmbedderType
from voice_changer.RVC.embedder.EmbedderProtocol import EmbedderProtocol


class Embedder(EmbedderProtocol):
    """Shared state and helpers for RVC feature embedders.

    Holds the embedder type, the model file path, the target torch device,
    the half-precision flag and the loaded model object. ``loadModel`` and
    ``extractFeatures`` are empty stubs in this class; presumably concrete
    embedders override them (TODO confirm against the subclasses).
    """

    def __init__(self):
        self.embedderType: EmbedderType = "hubert_base"
        # ``file`` and ``dev`` are only declared here, not assigned; they are
        # set by setProps (``dev`` also by setDevice).
        self.file: str
        self.dev: device
        # Declared here so the attribute always exists; True matches the
        # default used by setProps and loadModel.
        self.isHalf: bool = True

        # Loaded model object. None until something assigns it (presumably a
        # subclass's loadModel — this class never does).
        self.model: Any | None = None

    def loadModel(self, file: str, dev: device, isHalf: bool = True):
        """Load the embedder model from ``file`` onto ``dev``.

        No-op stub in this class (returns None); expected to be overridden.
        """
        ...

    def extractFeatures(
        self, feats: torch.Tensor, embOutputLayer=9, useFinalProj=True
    ) -> torch.Tensor:
        """Extract embedding features from ``feats``.

        Stub in this class: it returns None despite the annotation and is
        expected to be overridden. ``embOutputLayer`` presumably selects the
        model layer to read from, and ``useFinalProj`` whether to apply a
        final projection — confirm against the implementing subclasses.
        """
        ...

    def getEmbedderInfo(self):
        """Return a dict describing this embedder's configuration.

        Keys: ``embedderType``, ``file``, ``isHalf``, ``devType`` and
        ``devIndex`` (the device's ``type`` and ``index``). ``file`` and
        ``dev`` must have been assigned (e.g. via setProps) first, otherwise
        an AttributeError is raised.
        """
        return {
            "embedderType": self.embedderType,
            "file": self.file,
            "isHalf": self.isHalf,
            "devType": self.dev.type,
            "devIndex": self.dev.index,
        }

    def setProps(
        self,
        embedderType: EmbedderType,
        file: str,
        dev: device,
        isHalf: bool = True,
    ):
        """Record the embedder type, model path, device and precision flag.

        Only stores the values; the loaded model (if any) is not converted
        or moved — use setHalf / setDevice for that.
        """
        self.embedderType = embedderType
        self.file = file
        self.isHalf = isHalf
        self.dev = dev

    def setHalf(self, isHalf: bool):
        """Store the precision flag and convert the loaded model to match.

        A truthy ``isHalf`` converts the model with ``.half()``; a falsy one
        converts it with ``.float()``. When no model is loaded only the flag
        is updated.
        """
        self.isHalf = isHalf
        if self.model is None:
            return
        # Test truthiness instead of ``is False`` so falsy non-bool values
        # (e.g. 0 or numpy.bool_(False)) still convert the model back and the
        # stored flag never disagrees with the model's precision.
        self.model = self.model.half() if isHalf else self.model.float()

    def setDevice(self, dev: device):
        """Store ``dev`` and move the loaded model (if any) onto it.

        Returns ``self`` so the call can be chained or reassigned.
        """
        self.dev = dev
        if self.model is not None:
            self.model = self.model.to(self.dev)
        return self

    def matchCondition(self, embedderType: EmbedderType) -> bool:
        """Return True if this embedder's type equals ``embedderType``.

        On a mismatch, prints both types and returns False.
        """
        if self.embedderType == embedderType:
            return True
        print(
            "[Voice Changer] embeder type is not match",
            self.embedderType,
            embedderType,
        )
        return False