voice-changer/server/voice_changer/MMVCv15/models/snake.py

48 lines
1.3 KiB
Python
Raw Normal View History

2023-06-22 00:56:00 +03:00
# -*- coding: utf-8 -*-
# Copyright 2022 Reo Yoneyama (Nagoya University)
# MIT License (https://opensource.org/licenses/MIT)
"""Snake Activation Function Module.
References:
- Neural Networks Fail to Learn Periodic Functions and How to Fix It
https://arxiv.org/pdf/2006.08195.pdf
- BigVGAN: A Universal Neural Vocoder with Large-Scale Training
https://arxiv.org/pdf/2206.04658.pdf
"""
import torch
import torch.nn as nn
class Snake(nn.Module):
    """Snake activation function module.

    Computes ``y = x + sin(alpha * x) ** 2 / alpha`` with a learnable
    per-channel parameter ``alpha``.
    """

    def __init__(self, channels, init=50, eps=1e-9):
        """Initialize Snake module.

        Args:
            channels (int): Number of feature channels.
            init (float): Initial value of the learnable parameter alpha.
                According to the original paper, 5 ~ 50 would be
                suitable for periodic data (e.g. voices).
            eps (float): Small constant added to alpha in the denominator
                so the output stays finite when alpha is (or is trained
                to) zero. At alpha == 0 the activation then reduces to the
                identity, which matches the analytic limit of
                sin(alpha * x) ** 2 / alpha as alpha -> 0.
        """
        super().__init__()
        # Shape (1, channels, 1) so alpha broadcasts over the batch and
        # time axes of a (B, channels, T) input.
        alpha = init * torch.ones(1, channels, 1)
        self.alpha = nn.Parameter(alpha)
        # Kept as a plain float (not a parameter/buffer) so the state_dict
        # keys are unchanged and existing checkpoints load as before.
        self.eps = eps

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, channels, T).
        """
        # eps guards against 0/0 (NaN) or x/0 (Inf) when alpha hits zero;
        # for typical alpha values it is below float32 resolution.
        return x + torch.sin(self.alpha * x) ** 2 / (self.alpha + self.eps)