注意!!!!!!!!!!!!!这个代码计算的SSIM值是错误的,PSNR和MES的值是正确的。新的计算代码点击此处跳转

如题,批量计算图像的psnr,ssim,mse,并将计算结果汇总写入文件

import os
import numpy as np
import math
from PIL import Image

import time

start = time.clock()

def psnr(img1, img2):
    mse = np.mean((img1 / 1. - img2 / 1.) ** 2)
    if mse < 1.0e-10:
        return 100 * 1.0
    return 10 * math.log10(255.0 * 255.0 / mse)


def mse(img1, img2):
    mse = np.mean((img1 / 1. - img2 / 1.) ** 2)
    return mse


def ssim(y_true, y_pred):
    u_true = np.mean(y_true)
    u_pred = np.mean(y_pred)
    var_true = np.var(y_true)
    var_pred = np.var(y_pred)
    std_true = np.sqrt(var_true)
    std_pred = np.sqrt(var_pred)
    c1 = np.square(0.01 * 7)
    c2 = np.square(0.03 * 7)
    ssim = (2 * u_true * u_pred + c1) * (2 * std_pred * std_true + c2)
    denom = (u_true ** 2 + u_pred ** 2 + c1) * (var_pred + var_true + c2)
    return ssim / denom


path1 = './images/test1/'  # 指定输出结果文件夹
path2 = './images/test2/'  # 指定原图文件夹
f_nums = len(os.listdir(path1))
list_psnr = []
list_ssim = []
list_mse = []
file = open(r'psnr-ssim-scene3-70-100.txt', mode='w',encoding='utf-8')
for i in range(0, f_nums):
    print('第%s张图片' % i)
    img_a = Image.open(path1 + str(i) + '.png')
    img_b = Image.open(path2 + str(i) + '.png')
    img_a = np.array(img_a)
    img_b = np.array(img_b)

    psnr_num = psnr(img_a, img_b)
    ssim_num = ssim(img_a, img_b)
    mse_num = mse(img_a, img_b)
    print('psnr_num', psnr_num)
    print('ssim_num', ssim_num)
    print('mse_num', mse_num)
    list_ssim.append(ssim_num)
    list_psnr.append(psnr_num)
    list_mse.append(mse_num)
    file.write('第%s张图片:' % i + '\n')
    file.write('psnr_num, {:.5f}'.format(psnr_num) + '\n')
    file.write('ssim_num, {:.5f}'.format(ssim_num) + '\n')
    file.write('mse_num, {:.5f}'.format(mse_num) + '\n')


print("平均PSNR:", np.mean(list_psnr))  # ,list_psnr)
print("平均SSIM:", np.mean(list_ssim))  # ,list_ssim)
print("平均MSE:", np.mean(list_mse))  # ,list_mse)

elapsed = (time.clock() - start)
print("Time used:", elapsed)
file.write("\n")
file.write("汇总:\n")
file.write('平均PSNR, {:.5f}'.format(np.mean(list_psnr)) + '\n')
file.write('平均SSIM, {:.5f}'.format(np.mean(list_ssim)) + '\n')
file.write('平均MSE, {:.5f}'.format(np.mean(list_mse)) + '\n')
file.write('Time used, {:.5f}'.format(elapsed) + '\n')
file.close()
Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐