diff --git a/reflow/solver.py b/reflow/solver.py
index de8f0c1..c669074 100644
--- a/reflow/solver.py
+++ b/reflow/solver.py
@@ -8,6 +8,7 @@
 from torch import autocast
 from torch.cuda.amp import GradScaler
 from nsf_hifigan.nvSTFT import STFT
+from reflow.ssim import calculate_ssim
 
 def calculate_mel_snr(gt_mel, pred_mel):
     # compute the error image
@@ -47,6 +48,15 @@ def calculate_mel_psnr(gt_mel, pred_mel):
     psnr = 10 * torch.log10(max_power / mse)
     return psnr
 
+def calculate_mel_ssim(gt_mel, pred_mel):
+    # transpose and add a channel dim -> B × 1 × M × T, so SSIM sees single-channel images
+    pred_mel = pred_mel.transpose(-1, -2)
+    pred_mel = pred_mel[:, None]
+    gt_mel = gt_mel.transpose(-1, -2)
+    gt_mel = gt_mel[:, None]
+    ssim = calculate_ssim(pred_mel, gt_mel, size_average=True)
+    return ssim
+
 def test(args, model, vocoder, loader_test, saver):
     print(' [*] testing...')
     model.eval()
@@ -61,6 +71,7 @@ def test(args, model, vocoder, loader_test, saver):
     mel_val_snr_all = 0
     mel_val_psnr_all = 0
     mel_val_sisnr_all = 0
+    mel_val_ssim_all = 0
 
     # initialization
     num_batches = len(loader_test)
@@ -157,6 +168,7 @@ def test(args, model, vocoder, loader_test, saver):
             mel_val_snr_all += calculate_mel_snr(gt_mel_norm, pre_mel_norm).detach().cpu().numpy()
             mel_val_psnr_all += calculate_mel_psnr(gt_mel_norm, pre_mel_norm).detach().cpu().numpy()
             mel_val_sisnr_all += calculate_mel_si_snr(gt_mel_norm, pre_mel_norm).detach().cpu().numpy()
+            mel_val_ssim_all += calculate_mel_ssim(gt_mel_norm, pre_mel_norm).detach().cpu().numpy()
             mel_val_mse_all_num += 1
 
     # report
@@ -166,6 +178,7 @@
     mel_val_snr_all /= mel_val_mse_all_num
     mel_val_psnr_all /= mel_val_mse_all_num
     mel_val_sisnr_all /= mel_val_mse_all_num
+    mel_val_ssim_all /= mel_val_mse_all_num
 
     # check
     print(' [test_ddsp_loss] test_ddsp_loss:', test_ddsp_loss)
@@ -187,6 +200,10 @@
     saver.log_value({
         'validation/mel_val_sisnr': mel_val_sisnr_all
     })
+    print(' Mel Val SSIM', mel_val_ssim_all)
+    saver.log_value({
+        'validation/mel_val_ssim': mel_val_ssim_all
+    })
 
     return test_ddsp_loss, test_reflow_loss
 
diff --git a/reflow/ssim.py b/reflow/ssim.py
new file mode 100644
index 0000000..86da8ea
--- /dev/null
+++ b/reflow/ssim.py
@@ -0,0 +1,58 @@
+"""
+Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
+"""
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from math import exp
+
+
+window = None
+
+
+def gaussian(window_size, sigma):
+    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+    return gauss / gauss.sum()
+
+
+def create_window(window_size, channel):
+    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+    return window
+
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1 * mu2
+
+    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+    C1 = 0.01 ** 2
+    C2 = 0.03 ** 2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+    if size_average:
+        return ssim_map.mean()
+    else:
+        return ssim_map.mean(1)
+
+
+def calculate_ssim(img1, img2, window_size=11, size_average=True):
+    (_, channel, _, _) = img1.size()
+    global window
+    if window is None:
+        window = create_window(window_size, channel)
+        if img1.is_cuda:
+            window = window.cuda(img1.get_device())
+        window = window.type_as(img1)
+    return _ssim(img1, img2, window, window_size, channel, size_average)