from torch import device

from const import EmbedderType
from voice_changer.RVC.embedder.Embedder import Embedder
from voice_changer.RVC.embedder.FairseqContentvec import FairseqContentvec
from voice_changer.RVC.embedder.FairseqHubert import FairseqHubert
from voice_changer.RVC.embedder.FairseqHubertJp import FairseqHubertJp
from voice_changer.RVC.embedder.OnnxContentvec import OnnxContentvec
from voice_changer.utils.VoiceChangerParams import VoiceChangerParams


class EmbedderManager:
    currentEmbedder: Embedder | None = None
    params: VoiceChangerParams

    @classmethod
    def initialize(cls, params: VoiceChangerParams):
        cls.params = params

    @classmethod
    def getEmbedder(
        cls, embederType: EmbedderType, isHalf: bool, dev: device
    ) -> Embedder:
        if cls.currentEmbedder is None:
            print("[Voice Changer] generate new embedder. (no embedder)")
            cls.currentEmbedder = cls.loadEmbedder(embederType, isHalf, dev)
        elif cls.currentEmbedder.matchCondition(embederType) is False:
            print("[Voice Changer] generate new embedder. (not match)")
            cls.currentEmbedder = cls.loadEmbedder(embederType, isHalf, dev)
        else:
            print("[Voice Changer] generate new embedder. (anyway)")
            cls.currentEmbedder = cls.loadEmbedder(embederType, isHalf, dev)

            # cls.currentEmbedder.setDevice(dev)
            # cls.currentEmbedder.setHalf(isHalf)
        return cls.currentEmbedder

    @classmethod
    def loadEmbedder(
        cls, embederType: EmbedderType, isHalf: bool, dev: device
    ) -> Embedder:
        if embederType == "hubert_base":
            try:
                if cls.params.content_vec_500_onnx_on is False:
                    raise Exception("[Voice Changer][Embedder] onnx is off")
                file = cls.params.content_vec_500_onnx
                return OnnxContentvec().loadModel(file, dev)
            except Exception as e:  # noqa
                print("[Voice Changer] use torch contentvec", e)
                file = cls.params.hubert_base
                return FairseqHubert().loadModel(file, dev, isHalf)
        elif embederType == "hubert-base-japanese":
            file = cls.params.hubert_base_jp
            return FairseqHubertJp().loadModel(file, dev, isHalf)
        elif embederType == "contentvec":
            try:
                if cls.params.content_vec_500_onnx_on is False:
                    raise Exception("[Voice Changer][Embedder] onnx is off")
                file = cls.params.content_vec_500_onnx
                return OnnxContentvec().loadModel(file, dev)
            except Exception as e:
                print(e)
                file = cls.params.hubert_base
                return FairseqContentvec().loadModel(file, dev, isHalf)
        else:
            return FairseqHubert().loadModel(file, dev, isHalf)