import os
import numpy as np
import yaml  # type: ignore
import torch
import torch.nn.functional as F
import pyworld as pw
import parselmouth
import torchcrepe
from transformers import HubertModel, Wav2Vec2FeatureExtractor
from fairseq import checkpoint_utils
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from torchaudio.transforms import Resample
from .unit2control import Unit2Control
from .core import frequency_filter, upsample, remove_above_fmax, MaskedAvgPool1d, MedianPool1d
from ..encoder.hubert.model import HubertSoft

CREPE_RESAMPLE_KERNEL = {}


class F0_Extractor:
    def __init__(self, f0_extractor, sample_rate=44100, hop_size=512, f0_min=65, f0_max=800):
        self.f0_extractor = f0_extractor
        self.sample_rate = sample_rate
        self.hop_size = hop_size
        self.f0_min = f0_min
        self.f0_max = f0_max
        if f0_extractor == "crepe":
            key_str = str(sample_rate)
            if key_str not in CREPE_RESAMPLE_KERNEL:
                CREPE_RESAMPLE_KERNEL[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128)
            self.resample_kernel = CREPE_RESAMPLE_KERNEL[key_str]

    def extract(self, audio, uv_interp=False, device=None, silence_front=0):  # audio: 1d numpy array
        # extractor start time
        n_frames = int(len(audio) // self.hop_size) + 1

        start_frame = int(silence_front * self.sample_rate / self.hop_size)
        real_silence_front = start_frame * self.hop_size / self.sample_rate
        audio = audio[int(np.round(real_silence_front * self.sample_rate)) :]

        # extract f0 using parselmouth
        if self.f0_extractor == "parselmouth":
            f0 = parselmouth.Sound(audio, self.sample_rate).to_pitch_ac(time_step=self.hop_size / self.sample_rate, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"]
            pad_size = start_frame + (int(len(audio) // self.hop_size) - len(f0) + 1) // 2
            f0 = np.pad(f0, (pad_size, n_frames - len(f0) - pad_size))

        # extract f0 using dio
        elif self.f0_extractor == "dio":
            _f0, t = pw.dio(audio.astype("double"), self.sample_rate, f0_floor=self.f0_min, f0_ceil=self.f0_max, channels_in_octave=2, frame_period=(1000 * self.hop_size / self.sample_rate))
            f0 = pw.stonemask(audio.astype("double"), _f0, t, self.sample_rate)
            f0 = np.pad(f0.astype("float"), (start_frame, n_frames - len(f0) - start_frame))

        # extract f0 using harvest
        elif self.f0_extractor == "harvest":
            f0, _ = pw.harvest(audio.astype("double"), self.sample_rate, f0_floor=self.f0_min, f0_ceil=self.f0_max, frame_period=(1000 * self.hop_size / self.sample_rate))
            f0 = np.pad(f0.astype("float"), (start_frame, n_frames - len(f0) - start_frame))

        # extract f0 using crepe
        elif self.f0_extractor == "crepe":
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            resample_kernel = self.resample_kernel.to(device)
            wav16k_torch = resample_kernel(torch.FloatTensor(audio).unsqueeze(0).to(device))

            f0, pd = torchcrepe.predict(wav16k_torch, 16000, 80, self.f0_min, self.f0_max, pad=True, model="full", batch_size=512, device=device, return_periodicity=True)
            pd = MedianPool1d(pd, 4)
            f0 = torchcrepe.threshold.At(0.05)(f0, pd)
            f0 = MaskedAvgPool1d(f0, 4)

            f0 = f0.squeeze(0).cpu().numpy()
            f0 = np.array([f0[int(min(int(np.round(n * self.hop_size / self.sample_rate / 0.005)), len(f0) - 1))] for n in range(n_frames - start_frame)])
            f0 = np.pad(f0, (start_frame, 0))

        else:
            raise ValueError(f" [x] Unknown f0 extractor: {f0_extractor}")  # NOQA # type: ignore

        # interpolate the unvoiced f0
        if uv_interp:
            uv = f0 == 0
            if len(f0[~uv]) > 0:
                f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
            f0[f0 < self.f0_min] = self.f0_min
        return f0


class Volume_Extractor:
    def __init__(self, hop_size=512):
        self.hop_size = hop_size

    def extract(self, audio):  # audio: 1d numpy array
        n_frames = int(len(audio) // self.hop_size) + 1
        audio2 = audio**2
        audio2 = np.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode="reflect")
        volume = np.array([np.mean(audio2[int(n * self.hop_size) : int((n + 1) * self.hop_size)]) for n in range(n_frames)])
        volume = np.sqrt(volume)
        return volume


class Units_Encoder:
    def __init__(self, encoder, encoder_ckpt, encoder_sample_rate=16000, encoder_hop_size=320, device=None, cnhubertsoft_gate=10):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        is_loaded_encoder = False
        if encoder == "hubertsoft":
            self.model = Audio2HubertSoft(encoder_ckpt).to(device)
            is_loaded_encoder = True
        if encoder == "hubertbase":
            self.model = Audio2HubertBase(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "hubertbase768":
            self.model = Audio2HubertBase768(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "hubertbase768l12":
            self.model = Audio2HubertBase768L12(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "hubertlarge1024l24":
            self.model = Audio2HubertLarge1024L24(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "contentvec":
            self.model = Audio2ContentVec(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "contentvec768":
            self.model = Audio2ContentVec768(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "contentvec768l12":
            self.model = Audio2ContentVec768L12(encoder_ckpt, device=device)
            is_loaded_encoder = True
        if encoder == "cnhubertsoftfish":
            self.model = CNHubertSoftFish(encoder_ckpt, device=device, gate_size=cnhubertsoft_gate)
            is_loaded_encoder = True
        if not is_loaded_encoder:
            raise ValueError(f" [x] Unknown units encoder: {encoder}")

        self.resample_kernel = {}
        self.encoder_sample_rate = encoder_sample_rate
        self.encoder_hop_size = encoder_hop_size

    def encode(self, audio, sample_rate, hop_size):  # B, T
        # resample
        if sample_rate == self.encoder_sample_rate:
            audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel:
                self.resample_kernel[key_str] = Resample(sample_rate, self.encoder_sample_rate, lowpass_filter_width=128).to(self.device)
            audio_res = self.resample_kernel[key_str](audio)

        # encode
        if audio_res.size(-1) < 400:
            audio_res = torch.nn.functional.pad(audio, (0, 400 - audio_res.size(-1)))
        units = self.model(audio_res)

        # alignment
        n_frames = audio.size(-1) // hop_size + 1
        ratio = (hop_size / sample_rate) / (self.encoder_hop_size / self.encoder_sample_rate)
        index = torch.clamp(torch.round(ratio * torch.arange(n_frames).to(self.device)).long(), max=units.size(1) - 1)
        units_aligned = torch.gather(units, 1, index.unsqueeze(0).unsqueeze(-1).repeat([1, 1, units.size(-1)]))
        return units_aligned


class Audio2HubertSoft(torch.nn.Module):
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320):
        super().__init__()
        print(" [Encoder Model] HuBERT Soft")
        self.hubert = HubertSoft()
        print(" [Loading] " + path)
        checkpoint = torch.load(path)
        consume_prefix_in_state_dict_if_present(checkpoint, "module.")
        self.hubert.load_state_dict(checkpoint)
        self.hubert.eval()

    def forward(self, audio):  # B, T
        with torch.inference_mode():
            units = self.hubert.units(audio.unsqueeze(1))
            return units


class Audio2ContentVec:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] Content Vec")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        # wav_tensor = torch.from_numpy(audio).to(self.device)
        wav_tensor = audio
        feats = wav_tensor.view(1, -1)
        padding_mask = torch.BoolTensor(feats.shape).fill_(False)
        inputs = {
            "source": feats.to(wav_tensor.device),
            "padding_mask": padding_mask.to(wav_tensor.device),
            "output_layer": 9,  # layer 9
        }
        with torch.no_grad():
            logits = self.hubert.extract_features(**inputs)
            feats = self.hubert.final_proj(logits[0])
        units = feats  # .transpose(2, 1)
        return units


class Audio2ContentVec768:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] Content Vec")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        # wav_tensor = torch.from_numpy(audio).to(self.device)
        wav_tensor = audio
        feats = wav_tensor.view(1, -1)
        padding_mask = torch.BoolTensor(feats.shape).fill_(False)
        inputs = {
            "source": feats.to(wav_tensor.device),
            "padding_mask": padding_mask.to(wav_tensor.device),
            "output_layer": 9,  # layer 9
        }
        with torch.no_grad():
            logits = self.hubert.extract_features(**inputs)
            feats = logits[0]
        units = feats  # .transpose(2, 1)
        return units


class Audio2ContentVec768L12:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] Content Vec")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        # wav_tensor = torch.from_numpy(audio).to(self.device)
        wav_tensor = audio
        feats = wav_tensor.view(1, -1)
        padding_mask = torch.BoolTensor(feats.shape).fill_(False)
        inputs = {
            "source": feats.to(wav_tensor.device),
            "padding_mask": padding_mask.to(wav_tensor.device),
            "output_layer": 12,  # layer 12
        }
        with torch.no_grad():
            logits = self.hubert.extract_features(**inputs)
            feats = logits[0]
        units = feats  # .transpose(2, 1)
        return units


class CNHubertSoftFish(torch.nn.Module):
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu", gate_size=10):
        super().__init__()
        self.device = device
        self.gate_size = gate_size

        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./pretrain/TencentGameMate/chinese-hubert-base")
        self.model = HubertModel.from_pretrained("./pretrain/TencentGameMate/chinese-hubert-base")
        self.proj = torch.nn.Sequential(torch.nn.Dropout(0.1), torch.nn.Linear(768, 256))
        # self.label_embedding = nn.Embedding(128, 256)

        state_dict = torch.load(path, map_location=device)
        self.load_state_dict(state_dict)

    @torch.no_grad()
    def forward(self, audio):
        input_values = self.feature_extractor(audio, sampling_rate=16000, return_tensors="pt").input_values
        input_values = input_values.to(self.model.device)

        return self._forward(input_values[0])

    @torch.no_grad()
    def _forward(self, input_values):
        features = self.model(input_values)
        features = self.proj(features.last_hidden_state)

        # Top-k gating
        topk, indices = torch.topk(features, self.gate_size, dim=2)
        features = torch.zeros_like(features).scatter(2, indices, topk)
        features = features / features.sum(2, keepdim=True)

        return features.to(self.device)  # .transpose(1, 2)


class Audio2HubertBase:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] HuBERT Base")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert = self.hubert.float()
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        with torch.no_grad():
            padding_mask = torch.BoolTensor(audio.shape).fill_(False)
            inputs = {
                "source": audio.to(self.device),
                "padding_mask": padding_mask.to(self.device),
                "output_layer": 9,  # layer 9
            }
            logits = self.hubert.extract_features(**inputs)
            units = self.hubert.final_proj(logits[0])
            return units


class Audio2HubertBase768:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] HuBERT Base")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert = self.hubert.float()
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        with torch.no_grad():
            padding_mask = torch.BoolTensor(audio.shape).fill_(False)
            inputs = {
                "source": audio.to(self.device),
                "padding_mask": padding_mask.to(self.device),
                "output_layer": 9,  # layer 9
            }
            logits = self.hubert.extract_features(**inputs)
            units = logits[0]
            return units


class Audio2HubertBase768L12:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] HuBERT Base")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert = self.hubert.float()
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        with torch.no_grad():
            padding_mask = torch.BoolTensor(audio.shape).fill_(False)
            inputs = {
                "source": audio.to(self.device),
                "padding_mask": padding_mask.to(self.device),
                "output_layer": 12,  # layer 12
            }
            logits = self.hubert.extract_features(**inputs)
            units = logits[0]
            return units


class Audio2HubertLarge1024L24:
    def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device="cpu"):
        self.device = device
        print(" [Encoder Model] HuBERT Base")
        print(" [Loading] " + path)
        self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
            [path],
            suffix="",
        )
        self.hubert = self.models[0]
        self.hubert = self.hubert.to(self.device)
        self.hubert = self.hubert.float()
        self.hubert.eval()

    def __call__(self, audio):  # B, T
        with torch.no_grad():
            padding_mask = torch.BoolTensor(audio.shape).fill_(False)
            inputs = {
                "source": audio.to(self.device),
                "padding_mask": padding_mask.to(self.device),
                "output_layer": 24,  # layer 24
            }
            logits = self.hubert.extract_features(**inputs)
            units = logits[0]
            return units


class DotDict(dict):
    def __getattr__(*args):  # type: ignore
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__  # type: ignore
    __delattr__ = dict.__delitem__  # type: ignore


def load_model(model_path, device="cpu"):
    config_file = os.path.join(os.path.split(model_path)[0], "config.yaml")
    with open(config_file, "r") as config:
        args = yaml.safe_load(config)
    args = DotDict(args)

    # load model
    model = None

    if args.model.type == "Sins":
        model = Sins(sampling_rate=args.data.sampling_rate, block_size=args.data.block_size, n_harmonics=args.model.n_harmonics, n_mag_allpass=args.model.n_mag_allpass, n_mag_noise=args.model.n_mag_noise, n_unit=args.data.encoder_out_channels, n_spk=args.model.n_spk)

    elif args.model.type == "CombSub":
        model = CombSub(sampling_rate=args.data.sampling_rate, block_size=args.data.block_size, n_mag_allpass=args.model.n_mag_allpass, n_mag_harmonic=args.model.n_mag_harmonic, n_mag_noise=args.model.n_mag_noise, n_unit=args.data.encoder_out_channels, n_spk=args.model.n_spk)

    elif args.model.type == "CombSubFast":
        model = CombSubFast(sampling_rate=args.data.sampling_rate, block_size=args.data.block_size, n_unit=args.data.encoder_out_channels, n_spk=args.model.n_spk)

    else:
        raise ValueError(f" [x] Unknown Model: {args.model.type}")

    print(" [Loading] " + model_path)
    ckpt = torch.load(model_path, map_location=torch.device(device))
    model.to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    return model, args


class Sins(torch.nn.Module):
    def __init__(self, sampling_rate, block_size, n_harmonics, n_mag_allpass, n_mag_noise, n_unit=256, n_spk=1):
        super().__init__()

        print(" [DDSP Model] Sinusoids Additive Synthesiser")

        # params
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))
        # Unit2Control
        split_map = {
            "amplitudes": n_harmonics,
            "group_delay": n_mag_allpass,
            "noise_magnitude": n_mag_noise,
        }
        self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)

    def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, max_upsample_dim=32):
        """
        units_frames: B x n_frames x n_unit
        f0_frames: B x n_frames x 1
        volume_frames: B x n_frames x 1
        spk_id: B x 1
        """
        # exciter phase
        f0 = upsample(f0_frames, self.block_size)
        if infer:
            x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
        else:
            x = torch.cumsum(f0 / self.sampling_rate, axis=1)
        if initial_phase is not None:
            x += initial_phase.to(x) / 2 / np.pi
        x = x - torch.round(x)
        x = x.to(f0)

        phase = 2 * np.pi * x
        phase_frames = phase[:, :: self.block_size, :]

        # parameter prediction
        ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)

        amplitudes_frames = torch.exp(ctrls["amplitudes"]) / 128
        group_delay = np.pi * torch.tanh(ctrls["group_delay"])
        noise_param = torch.exp(ctrls["noise_magnitude"]) / 128

        # sinusoids exciter signal
        amplitudes_frames = remove_above_fmax(amplitudes_frames, f0_frames, self.sampling_rate / 2, level_start=1)
        n_harmonic = amplitudes_frames.shape[-1]
        level_harmonic = torch.arange(1, n_harmonic + 1).to(phase)
        sinusoids = 0.0
        for n in range((n_harmonic - 1) // max_upsample_dim + 1):
            start = n * max_upsample_dim
            end = (n + 1) * max_upsample_dim
            phases = phase * level_harmonic[start:end]
            amplitudes = upsample(amplitudes_frames[:, :, start:end], self.block_size)
            sinusoids += (torch.sin(phases) * amplitudes).sum(-1)

        # harmonic part filter (apply group-delay)
        harmonic = frequency_filter(sinusoids, torch.exp(1.0j * torch.cumsum(group_delay, axis=-1)), hann_window=False)

        # noise part filter
        noise = torch.rand_like(harmonic) * 2 - 1
        noise = frequency_filter(noise, torch.complex(noise_param, torch.zeros_like(noise_param)), hann_window=True)

        signal = harmonic + noise

        return signal, phase, (harmonic, noise)  # , (noise_param, noise_param)


class CombSubFast(torch.nn.Module):
    def __init__(self, sampling_rate, block_size, n_unit=256, n_spk=1):
        super().__init__()

        print(" [DDSP Model] Combtooth Subtractive Synthesiser")
        # params
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))
        self.register_buffer("window", torch.sqrt(torch.hann_window(2 * block_size)))
        # Unit2Control
        split_map = {"harmonic_magnitude": block_size + 1, "harmonic_phase": block_size + 1, "noise_magnitude": block_size + 1}
        self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)

    def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, **kwargs):
        """
        units_frames: B x n_frames x n_unit
        f0_frames: B x n_frames x 1
        volume_frames: B x n_frames x 1
        spk_id: B x 1
        """
        # exciter phase
        f0 = upsample(f0_frames, self.block_size)
        if infer:
            x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
        else:
            x = torch.cumsum(f0 / self.sampling_rate, axis=1)
        if initial_phase is not None:
            x += initial_phase.to(x) / 2 / np.pi
        x = x - torch.round(x)
        x = x.to(f0)

        phase_frames = 2 * np.pi * x[:, :: self.block_size, :]

        # parameter prediction
        ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)

        src_filter = torch.exp(ctrls["harmonic_magnitude"] + 1.0j * np.pi * ctrls["harmonic_phase"])
        src_filter = torch.cat((src_filter, src_filter[:, -1:, :]), 1)
        noise_filter = torch.exp(ctrls["noise_magnitude"]) / 128
        noise_filter = torch.cat((noise_filter, noise_filter[:, -1:, :]), 1)

        # combtooth exciter signal
        combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
        combtooth = combtooth.squeeze(-1)
        combtooth_frames = F.pad(combtooth, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size)
        combtooth_frames = combtooth_frames * self.window
        combtooth_fft = torch.fft.rfft(combtooth_frames, 2 * self.block_size)

        # noise exciter signal
        noise = torch.rand_like(combtooth) * 2 - 1
        noise_frames = F.pad(noise, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size)
        noise_frames = noise_frames * self.window
        noise_fft = torch.fft.rfft(noise_frames, 2 * self.block_size)

        # apply the filters
        signal_fft = combtooth_fft * src_filter + noise_fft * noise_filter

        # take the ifft to resynthesize audio.
        signal_frames_out = torch.fft.irfft(signal_fft, 2 * self.block_size) * self.window

        # overlap add
        fold = torch.nn.Fold(output_size=(1, (signal_frames_out.size(1) + 1) * self.block_size), kernel_size=(1, 2 * self.block_size), stride=(1, self.block_size))
        signal = fold(signal_frames_out.transpose(1, 2))[:, 0, 0, self.block_size : -self.block_size]

        return signal, phase_frames, (signal, signal)


class CombSub(torch.nn.Module):
    def __init__(self, sampling_rate, block_size, n_mag_allpass, n_mag_harmonic, n_mag_noise, n_unit=256, n_spk=1):
        super().__init__()

        print(" [DDSP Model] Combtooth Subtractive Synthesiser (Old Version)")
        # params
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))
        # Unit2Control
        split_map = {"group_delay": n_mag_allpass, "harmonic_magnitude": n_mag_harmonic, "noise_magnitude": n_mag_noise}
        self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)

    def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, **kwargs):
        """
        units_frames: B x n_frames x n_unit
        f0_frames: B x n_frames x 1
        volume_frames: B x n_frames x 1
        spk_id: B x 1
        """
        # exciter phase
        f0 = upsample(f0_frames, self.block_size)
        if infer:
            x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
        else:
            x = torch.cumsum(f0 / self.sampling_rate, axis=1)
        if initial_phase is not None:
            x += initial_phase.to(x) / 2 / np.pi
        x = x - torch.round(x)
        x = x.to(f0)

        phase_frames = 2 * np.pi * x[:, :: self.block_size, :]

        # parameter prediction
        ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)

        group_delay = np.pi * torch.tanh(ctrls["group_delay"])
        src_param = torch.exp(ctrls["harmonic_magnitude"])
        noise_param = torch.exp(ctrls["noise_magnitude"]) / 128

        # combtooth exciter signal
        combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
        combtooth = combtooth.squeeze(-1)

        # harmonic part filter (using dynamic-windowed LTV-FIR, with group-delay prediction)
        harmonic = frequency_filter(combtooth, torch.exp(1.0j * torch.cumsum(group_delay, axis=-1)), hann_window=False)
        harmonic = frequency_filter(harmonic, torch.complex(src_param, torch.zeros_like(src_param)), hann_window=True, half_width_frames=1.5 * self.sampling_rate / (f0_frames + 1e-3))

        # noise part filter (using constant-windowed LTV-FIR, without group-delay)
        noise = torch.rand_like(harmonic) * 2 - 1
        noise = frequency_filter(noise, torch.complex(noise_param, torch.zeros_like(noise_param)), hann_window=True)

        signal = harmonic + noise

        return signal, phase_frames, (harmonic, noise)