mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-24 14:05:00 +03:00
48 lines
1.3 KiB
Python
48 lines
1.3 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
|
||
|
# Copyright 2022 Reo Yoneyama (Nagoya University)
|
||
|
# MIT License (https://opensource.org/licenses/MIT)
|
||
|
|
||
|
"""Snake Activation Function Module.
|
||
|
|
||
|
References:
|
||
|
- Neural Networks Fail to Learn Periodic Functions and How to Fix It
|
||
|
https://arxiv.org/pdf/2006.08195.pdf
|
||
|
- BigVGAN: A Universal Neural Vocoder with Large-Scale Training
|
||
|
https://arxiv.org/pdf/2206.04658.pdf
|
||
|
|
||
|
"""
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
|
||
|
class Snake(nn.Module):
    """Snake activation function module.

    Applies ``x + sin(alpha * x) ** 2 / alpha`` element-wise, where ``alpha``
    is a learnable parameter holding one value per channel.
    """

    # Small constant added to alpha in the denominator so that a channel
    # whose alpha is exactly zero yields x (the limit of the formula as
    # alpha -> 0) instead of NaN from 0 / 0. Kept as a class attribute so
    # it is not part of the state dict and is available on instances that
    # were created without running __init__.
    _EPS = 1e-9

    def __init__(self, channels: int, init: float = 50):
        """Initialize Snake module.

        Args:
            channels (int): Number of feature channels.
            init (float): Initial value of the learnable parameter alpha.
                According to the original paper, 5 ~ 50 would be
                suitable for periodic data (i.e. voices).

        """
        super().__init__()
        # One alpha per channel, shaped (1, channels, 1) so that it
        # broadcasts over the batch and time axes of a (B, channels, T) input.
        self.alpha = nn.Parameter(torch.full((1, channels, 1), float(init)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, channels, T).

        Returns:
            Tensor: Output tensor (B, channels, T).

        """
        periodic = torch.sin(self.alpha * x) ** 2
        # Guard the division: without the epsilon, alpha == 0 gives 0 / 0.
        return x + periodic / (self.alpha + self._EPS)
|