mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-23 21:45:00 +03:00
48 lines
1.3 KiB
Python
48 lines
1.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright 2022 Reo Yoneyama (Nagoya University)
|
|
# MIT License (https://opensource.org/licenses/MIT)
|
|
|
|
"""Snake Activation Function Module.
|
|
|
|
References:
|
|
- Neural Networks Fail to Learn Periodic Functions and How to Fix It
|
|
https://arxiv.org/pdf/2006.08195.pdf
|
|
- BigVGAN: A Universal Neural Vocoder with Large-Scale Training
|
|
https://arxiv.org/pdf/2206.04658.pdf
|
|
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class Snake(nn.Module):
    """Snake activation function module.

    Computes ``x + sin(alpha * x) ** 2 / alpha`` element-wise, where ``alpha``
    is a learnable per-channel parameter of shape (1, channels, 1).
    """

    def __init__(self, channels, init=50, eps=1e-9):
        """Initialize Snake module.

        Args:
            channels (int): Number of feature channels.
            init (float): Initial value of the learnable parameter alpha.
                According to the original paper, 5 ~ 50 would be
                suitable for periodic data (e.g. voices).
            eps (float): Small constant added to alpha in the denominator so
                the output stays finite when alpha approaches zero.

        """
        super().__init__()
        alpha = init * torch.ones(1, channels, 1)
        self.alpha = nn.Parameter(alpha)
        # Plain float attribute (not a buffer/parameter) so the module's
        # state_dict keys are unchanged and existing checkpoints still load.
        self.eps = eps

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, channels, T).

        """
        # Without eps, alpha == 0 yields 0 / 0 = NaN; with eps the term
        # becomes 0 and the output degrades gracefully to x (the true limit
        # of sin(alpha * x) ** 2 / alpha as alpha -> 0).
        return x + torch.sin(self.alpha * x) ** 2 / (self.alpha + self.eps)
|