import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles

import logging
import os, sys, math, base64, struct, traceback

import torch, torchaudio
import numpy as np
from scipy.io.wavfile import write
from datetime import datetime

# Usage: python <this_file>.py <port> <mode>. MODE selects the model-loading
# path at import time, so it must be parsed here and not only under __main__;
# uvicorn's reloader re-runs the process with the same argv, so this also
# works with reload=True.
args = sys.argv
PORT = args[1]
MODE = args[2]


logger = logging.getLogger('uvicorn')
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
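# Note: wildcard origins combined with allow_credentials=True is maximally
# permissive. Fine for local experiments; worth restricting to known origins
# for any shared deployment.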


if MODE == "colab":
    print("ENV: colab")
    app.mount("/front", StaticFiles(directory="../frontend/dist", html=True), name="static")

    # Soft-VC pipeline from torch.hub: content encoder (HuBERT-Soft),
    # acoustic model, and HiFi-GAN vocoder.
    hubert_model = torch.hub.load("bshall/hubert:main", "hubert_soft").cuda()
    acoustic_model = torch.hub.load("bshall/acoustic-model:main", "hubert_soft").cuda()
    hifigan_model = torch.hub.load("bshall/hifigan:main", "hifigan_hubert_soft").cuda()
else:
    print("ENV: Docker")

    app.mount("/front", StaticFiles(directory="../frontend/dist", html=True), name="static")

    # The imports below are not referenced directly; they put the model
    # definitions on the import path so torch.load can unpickle the full
    # model objects saved under /models. (The names re-imported from
    # `acoustic` shadow those from `hubert`, which is harmless here since
    # neither set is used by name.)
    sys.path.append("/hubert")
    from hubert import hubert_discrete, hubert_soft, kmeans100

    sys.path.append("/acoustic-model")
    from acoustic import hubert_discrete, hubert_soft

    sys.path.append("/hifigan")
    from hifigan import hifigan

    hubert_model = torch.load("/models/bshall_hubert_main.pt").cuda()
    acoustic_model = torch.load("/models/bshall_acoustic-model_main.pt").cuda()
    hifigan_model = torch.load("/models/bshall_hifigan_main.pt").cuda()
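

def _load_model_cpu_fallback(path):
    # Hedged sketch, unused by the flow above: load one of the pickled model
    # objects with a CPU fallback when CUDA is unavailable. The path format
    # mirrors the /models/*.pt loads above; everything else is an assumption,
    # not part of the original server.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.load(path, map_location=device).to(device)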


def applyVol(i, chunk, vols):
    # Gate rather than scale: chunks whose source frame volume is near zero
    # are muted in the converted audio; everything else passes through
    # unchanged (multiplying by a ones tensor is a no-op).
    curVol = vols[i] / 2
    if curVol < 0.0001:
        line = torch.zeros(chunk.size())
    else:
        line = torch.ones(chunk.size())

    volApplied = torch.mul(line, chunk)
    return volApplied.unsqueeze(0)
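

def applyVolScaled(i, chunk, vols):
    # Hedged alternative sketch, not called by the request handler: scale each
    # chunk by its source frame's (halved) mel-spectrogram peak instead of
    # hard-gating it. The /2 mirrors applyVol; treating the peak as a linear
    # gain is an assumption, not the original behavior.
    curVol = float(vols[i]) / 2
    return (chunk * curVol).unsqueeze(0)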


@app.get("/test")
def get_test(query: str = ""):
    # FastAPI binds query parameters from the function signature; the original
    # used Flask-style `request.args`, which is undefined in FastAPI.
    return query



class VoiceModel(BaseModel):
    # `buffer` carries base64-encoded 16-bit / 24 kHz PCM; the remaining
    # fields are echoed back to the caller unchanged.
    gpu: int
    srcId: int
    dstId: int
    timestamp: int
    buffer: str

@app.post("/test")
def post_test(voice: VoiceModel):
    try:
        print("POST REQUEST PROCESSING....")
        gpu = voice.gpu
        srcId = voice.srcId
        dstId = voice.dstId
        timestamp = voice.timestamp
        buffer = voice.buffer

        # The payload is base64-encoded 16-bit little-endian PCM at 24 kHz.
        wav = base64.b64decode(buffer)
        unpackedData = np.array(struct.unpack('<%sh' % (len(wav) // struct.calcsize('<h')), wav))
        # received_data_file = f"received_data_{timestamp}.wav"
        received_data_file = "received_data.wav"
        write(received_data_file, 24000, unpackedData.astype(np.int16))
        # torchaudio.load defaults to normalize=True, so samples come back as
        # float32 in [-1, 1]: https://pytorch.org/audio/stable/backend.html
        source, sr = torchaudio.load(received_data_file)

        # HuBERT expects 16 kHz input, so resample the 24 kHz source first.
        source_16k = torchaudio.functional.resample(source, 24000, 16000)
        source_16k = source_16k.unsqueeze(0).cuda()

        # Soft-VC: speech -> soft units -> mel spectrogram -> waveform.
        with torch.inference_mode():
            units = hubert_model.units(source_16k)
            mel = acoustic_model.generate(units).transpose(1, 2)
            target = hifigan_model(mel)

        # The vocoder output is at 16 kHz; bring it back to the I/O rate.
        dest = torchaudio.functional.resample(target, 16000, 24000)
        dest = dest.squeeze().cpu()

        # Transfer the source's volume envelope: take per-frame peaks from a
        # mel spectrogram of the source and gate the converted audio chunkwise.
        source = source.cpu()
        specgram = torchaudio.transforms.MelSpectrogram(sample_rate=24000)(source)
        vol_apply_window_size = math.ceil(len(source[0]) / specgram.size()[2])
        specgram = specgram.transpose(1, 2)
        vols = [torch.max(frame) for frame in specgram[0]]
        chunks = torch.split(dest, vol_apply_window_size, 0)

        chunks = [applyVol(i, c, vols) for i, c in enumerate(chunks)]
        dest = torch.cat(chunks, 1)
        arr = np.array(dest.squeeze())

        # Back to full-scale int16, then base64 so the PCM travels as JSON text.
        int_size = 2 ** (16 - 1) - 1  # 32767
        arr = (arr * int_size).astype(np.int16)
        write("converted_data.wav", 24000, arr)
        changedVoiceBase64 = base64.b64encode(arr).decode('utf-8')

        data = {
            "gpu":gpu,
            "srcId":srcId,
            "dstId":dstId,
            "timestamp":timestamp,
            "changedVoiceBase64":changedVoiceBase64
        }

        json_compatible_item_data = jsonable_encoder(data)

        return JSONResponse(content=json_compatible_item_data)
    except Exception as e:
        print("REQUEST PROCESSING EXCEPTION:", e)
        print(traceback.format_exc())
        # Surface the failure as an error status instead of a 200 with a bare string.
        return JSONResponse(content={"error": str(e)}, status_code=500)
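

def _example_client(port=8080):
    # Illustrative only: how a caller might hit POST /test. The `requests`
    # dependency, the port, the test tone, and the id values are assumptions,
    # not part of the original server contract.
    import requests  # assumed to be installed

    # One second of a 440 Hz tone as 16-bit / 24 kHz PCM, matching the
    # decode on the server side.
    tone = (np.sin(2 * np.pi * 440 * np.arange(24000) / 24000) * 10000).astype(np.int16)
    payload = {
        "gpu": 0,
        "srcId": 0,
        "dstId": 100,
        "timestamp": 0,
        "buffer": base64.b64encode(tone.tobytes()).decode("utf-8"),
    }
    return requests.post(f"http://localhost:{port}/test", json=payload).json()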


if __name__ == '__main__':
    # PORT and MODE were already parsed at import time (the model-loading
    # branch above needs them), so no re-parse is required here.
    logger.info('START APP')
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=int(PORT), reload=True, log_level="info")