voice-changer/server/voice_changer/RVC/embedder/Embedder.py

64 lines
1.6 KiB
Python
Raw Normal View History

2023-07-07 21:18:23 +03:00
from typing import Any
2023-05-02 06:11:00 +03:00
import torch
from torch import device
from const import EmbedderType
2023-07-07 21:18:23 +03:00
from voice_changer.RVC.embedder.EmbedderProtocol import EmbedderProtocol
2023-05-02 06:11:00 +03:00
2023-07-07 21:18:23 +03:00
class Embedder(EmbedderProtocol):
    """Base class holding the configuration shared by RVC embedders.

    Tracks the embedder type, the file it was configured with, the torch
    device and whether half precision is requested. ``self.model`` starts as
    ``None``; ``setHalf`` and ``setDevice`` only convert/move the model once
    it has been assigned (presumably by a subclass's loader — confirm).
    """

    def __init__(self):
        self.embedderType: EmbedderType = "hubert_base"
        # Assigned by setProps(); only annotated here, so they may be absent
        # until then (getEmbedderInfo() tolerates that).
        self.file: str
        self.dev: device
        # Initialised to setProps()'s default so the attribute always exists,
        # even if getEmbedderInfo()/setHalf() readers run before setProps().
        self.isHalf: bool = True

        self.model: Any | None = None

    def getEmbedderInfo(self):
        """Return a dict describing the current embedder configuration.

        Returns:
            dict with keys "embedderType", "file", "isHalf", "devType" and
            "devIndex". "file", "devType" and "devIndex" are None when the
            corresponding attribute has not been set yet (i.e. before
            setProps()/setDevice()), instead of raising AttributeError.
        """
        dev = getattr(self, "dev", None)
        return {
            "embedderType": self.embedderType,
            "file": getattr(self, "file", None),
            "isHalf": self.isHalf,
            "devType": dev.type if dev is not None else None,
            "devIndex": dev.index if dev is not None else None,
        }

    def setProps(
        self,
        embedderType: EmbedderType,
        file: str,
        dev: device,
        isHalf: bool = True,
    ):
        """Record the embedder's type, source file, device and precision flag.

        Only stores the values; it does not move or convert ``self.model``.
        """
        self.embedderType = embedderType
        self.file = file
        self.isHalf = isHalf
        self.dev = dev

    def setHalf(self, isHalf: bool):
        """Switch between half (``.half()``) and full (``.float()``) precision.

        The flag is always stored. If a model is loaded, it is converted and
        reassigned to ``self.model``.
        """
        self.isHalf = isHalf
        if self.model is None:
            return
        # Branch on truthiness rather than `is False`: any falsy value must
        # convert back to float, otherwise the stored flag and the model's
        # actual precision would silently disagree.
        self.model = self.model.half() if isHalf else self.model.float()

    def setDevice(self, dev: device):
        """Store ``dev`` and move the loaded model (if any) onto it.

        Returns:
            self, allowing call chaining.
        """
        self.dev = dev
        if self.model is not None:
            self.model = self.model.to(self.dev)
        return self

    def matchCondition(self, embedderType: EmbedderType) -> bool:
        """Return True if this embedder was configured with ``embedderType``.

        On mismatch, prints both types to stdout and returns False.
        """
        if self.embedderType == embedderType:
            return True
        print(
            "[Voice Changer] embeder type is not match",
            self.embedderType,
            embedderType,
        )
        return False