Merge pull request #40 from hexuan21/main
add auto-metrics for video
Showing 9 changed files with 545 additions and 0 deletions.
@@ -0,0 +1,47 @@
from brisque import BRISQUE
from PIL import Image
import numpy as np
from typing import List


ROUND_DIGIT=3
NUM_ASPECT=5


BRISQUE_POINT_LOW=10
BRISQUE_POINT_MID=30
BRISQUE_POINT_HIGH=50


class MetricBRISQUE():
    def __init__(self) -> None:
        """
        Initialize a class MetricBRISQUE for testing the visual quality of a given video.
        """
        pass

    def evaluate(self, frame_list: List[Image.Image]):
        """
        Calculate BRISQUE for visual quality for each frame of the given video and take the average value,
        then quantize the original output based on predefined thresholds.
        Args:
            frame_list: List[Image.Image], frames of the video used in calculation.
        Returns:
            brisque_avg: float, the computed average BRISQUE among the frames.
            quantized_ans: int, the quantized value of the above average score based on the predefined thresholds.
        """
        brisque_list=[]
        for frame in frame_list:
            # Lower BRISQUE means better perceptual quality.
            brisque_score=BRISQUE().score(frame)
            brisque_list.append(brisque_score)
        brisque_avg=np.mean(brisque_list)
        quantized_ans=0
        if brisque_avg < BRISQUE_POINT_LOW:
            quantized_ans=4
        elif brisque_avg < BRISQUE_POINT_MID:
            quantized_ans=3
        elif brisque_avg < BRISQUE_POINT_HIGH:
            quantized_ans=2
        else:
            quantized_ans=1
        return brisque_avg, quantized_ans
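
A minimal usage sketch for the class above (not part of the commit); the frame file names are hypothetical placeholders, and any list of RGB PIL frames works:

# Hypothetical example: score 16 pre-extracted frames with MetricBRISQUE.
from PIL import Image

frames = [Image.open(f"frames/{i:04d}.png").convert("RGB") for i in range(16)]
metric = MetricBRISQUE()
brisque_avg, rating = metric.evaluate(frames)
print(f"avg BRISQUE: {brisque_avg:.3f}, quantized rating (4 = best): {rating}")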
@@ -0,0 +1,63 @@
import numpy as np
from PIL import Image
import torch.nn.functional as F
from typing import List
from transformers import CLIPProcessor, CLIPModel


ROUND_DIGIT=3
NUM_ASPECT=5


CLIP_POINT_HIGH=0.97
CLIP_POINT_MID=0.9
CLIP_POINT_LOW=0.8


class MetricCLIP_sim():
    def __init__(self, device="cuda") -> None:
        """
        Initialize a class MetricCLIP_sim with the specified device for testing the temporal consistency of a given video.
        Args:
            device (str, optional): The device on which the model will run. Defaults to "cuda".
        """
        self.device = device
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.model.to(self.device)
        self.tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    def evaluate(self, frame_list: List[Image.Image]):
        """
        Calculate the cosine similarity between the CLIP features of adjacent frames of a given video to test temporal consistency,
        then quantize the original output based on predefined thresholds.
        Args:
            frame_list: List[Image.Image], frames of the video used in calculation.
        Returns:
            clip_frame_score: float, the CLIP feature cosine similarity computed between each adjacent pair of frames and averaged over all the pairs.
            quantized_ans: int, the quantized value of the above average CLIP-Sim score based on the predefined thresholds.
        """
        device=self.model.device
        frame_sim_list=[]
        for f_idx in range(len(frame_list)-1):
            frame_1 = frame_list[f_idx]
            frame_2 = frame_list[f_idx+1]
            input_1 = self.tokenizer(images=frame_1, return_tensors="pt", padding=True).to(device)
            input_2 = self.tokenizer(images=frame_2, return_tensors="pt", padding=True).to(device)
            output_1 = self.model.get_image_features(**input_1).flatten()
            output_2 = self.model.get_image_features(**input_2).flatten()
            cos_sim = F.cosine_similarity(output_1, output_2, dim=0).item()
            frame_sim_list.append(cos_sim)

        clip_frame_score = np.mean(frame_sim_list)
        quantized_ans=0
        if clip_frame_score >= CLIP_POINT_HIGH:
            quantized_ans=4
        elif clip_frame_score >= CLIP_POINT_MID:
            quantized_ans=3
        elif clip_frame_score >= CLIP_POINT_LOW:
            quantized_ans=2
        else:
            quantized_ans=1
        return clip_frame_score, quantized_ans
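
As a hedged illustration (not from the commit), frames can be decoded from a video file with OpenCV before calling the class above; the video path and frame cap below are placeholders:

# Hypothetical example: decode up to 16 frames with OpenCV, then measure temporal consistency.
import cv2
from PIL import Image

cap = cv2.VideoCapture("sample_video.mp4")   # placeholder path
frames = []
while len(frames) < 16:
    ok, bgr = cap.read()
    if not ok:
        break
    frames.append(Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)))
cap.release()

metric = MetricCLIP_sim(device="cuda")
clip_sim, rating = metric.evaluate(frames)
print(f"CLIP-Sim: {clip_sim:.3f} -> rating {rating} (4 = most consistent)")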
@@ -0,0 +1,65 @@
import numpy as np
from PIL import Image
import torch.nn.functional as F
from typing import List
from transformers import CLIPProcessor, CLIPModel


NUM_ASPECT=5
ROUND_DIGIT=3
MAX_LENGTH = 76


MAX_NUM_FRAMES=8


CLIP_POINT_LOW=0.27
CLIP_POINT_MID=0.31
CLIP_POINT_HIGH=0.35


class MetricCLIPScore():
    def __init__(self, device="cuda") -> None:
        """
        Initialize a MetricCLIPScore object with the specified device.
        Args:
            device (str, optional): The device on which the model will run. Defaults to "cuda".
        """
        self.device = device
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.model.to(self.device)
        self.tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    def evaluate(self, frame_list: List[Image.Image], text: str):
        """
        Calculate the cosine similarity between the CLIP features of the text prompt and each frame of a given video to test text-to-video alignment,
        then quantize the original output based on predefined thresholds.
        Args:
            frame_list: List[Image.Image], frames of the video used in calculation.
            text: str, text prompt used for generating the video.
        Returns:
            clip_score_avg: float, the computed average CLIP-Score between each frame and the text prompt.
            quantized_ans: int, the quantized value of the above average CLIP-Score based on the predefined thresholds.
        """
        device=self.model.device
        # The prompt is tokenized once and truncated to MAX_LENGTH tokens.
        input_t = self.tokenizer(text=text, max_length=MAX_LENGTH, truncation=True, return_tensors="pt", padding=True).to(device)
        cos_sim_list=[]
        for image in frame_list:
            input_f = self.tokenizer(images=image, return_tensors="pt", padding=True).to(device)
            output_t = self.model.get_text_features(**input_t).flatten()
            output_f = self.model.get_image_features(**input_f).flatten()
            cos_sim = F.cosine_similarity(output_t, output_f, dim=0).item()
            cos_sim_list.append(cos_sim)
        clip_score_avg=np.mean(cos_sim_list)
        quantized_ans=0
        if clip_score_avg < CLIP_POINT_LOW:
            quantized_ans=1
        elif clip_score_avg < CLIP_POINT_MID:
            quantized_ans=2
        elif clip_score_avg < CLIP_POINT_HIGH:
            quantized_ans=3
        else:
            quantized_ans=4
        return clip_score_avg, quantized_ans
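
A short, assumed usage sketch for text-to-video alignment; the prompt and frame paths are illustrative only, and prompts longer than MAX_LENGTH tokens are truncated by the processor:

# Hypothetical example: CLIP-Score between a prompt and 8 frames of the generated video.
from PIL import Image

prompt = "a corgi running on a beach at sunset"   # hypothetical prompt
frames = [Image.open(f"frames/{i:04d}.png").convert("RGB") for i in range(8)]   # placeholder paths
metric = MetricCLIPScore(device="cuda")
clip_score, rating = metric.evaluate(frames, prompt)
print(f"CLIP-Score: {clip_score:.3f} -> rating {rating} (4 = best aligned)")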
@@ -0,0 +1,71 @@
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from typing import List
from torchvision.models import vit_b_16
import torchvision.transforms as transforms


ROUND_DIGIT=3
NUM_ASPECT=5


DINO_POINT_HIGH=0.97
DINO_POINT_MID=0.9
DINO_POINT_LOW=0.8


class MetricDINO_sim():
    def __init__(self, device="cuda") -> None:
        """
        Initialize a class MetricDINO_sim with the specified device for testing the temporal consistency of a given video.
        Args:
            device (str, optional): The device on which the model will run. Defaults to "cuda".
        """
        self.device = device
        # An ImageNet-pretrained ViT-B/16 from torchvision is used here as the frame feature extractor.
        self.model = vit_b_16(pretrained=True)
        self.model.to(self.device).eval()
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def evaluate(self, frame_list: List[Image.Image]):
        """
        Calculate the cosine similarity between the DINO features of adjacent frames of a given video to test temporal consistency,
        then quantize the original output based on predefined thresholds.
        Args:
            frame_list: List[Image.Image], frames of the video used in calculation.
        Returns:
            dino_frame_score: float, the DINO feature cosine similarity computed between each adjacent pair of frames and averaged over all the pairs.
            quantized_ans: int, the quantized value of the above average DINO-Sim score based on the predefined thresholds.
        """
        device = self.device
        frame_sim_list=[]
        for f_idx in range(len(frame_list)-1):
            frame_1=frame_list[f_idx]
            frame_2=frame_list[f_idx+1]
            frame_tensor_1 = self.preprocess(frame_1).unsqueeze(0).to(device)
            frame_tensor_2 = self.preprocess(frame_2).unsqueeze(0).to(device)
            with torch.no_grad():
                feat_1 = self.model(frame_tensor_1).flatten()
                feat_2 = self.model(frame_tensor_2).flatten()
            cos_sim=F.cosine_similarity(feat_1, feat_2, dim=0).item()
            frame_sim_list.append(cos_sim)

        dino_frame_score = np.mean(frame_sim_list)
        quantized_ans=0
        if dino_frame_score >= DINO_POINT_HIGH:
            quantized_ans=4
        elif dino_frame_score >= DINO_POINT_MID:
            quantized_ans=3
        elif dino_frame_score >= DINO_POINT_LOW:
            quantized_ans=2
        else:
            quantized_ans=1
        return dino_frame_score, quantized_ans
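
Since MetricCLIP_sim and MetricDINO_sim share the same interface and the same 0.8/0.9/0.97 thresholds, one illustrative sketch (not from the commit, and assuming both classes are importable in one script) is to report the two consistency scores side by side on the same frames:

# Hypothetical comparison of the two temporal-consistency metrics on identical frames.
from PIL import Image

frames = [Image.open(f"frames/{i:04d}.png").convert("RGB") for i in range(16)]   # placeholder paths
clip_metric = MetricCLIP_sim(device="cuda")
dino_metric = MetricDINO_sim(device="cuda")
clip_sim, clip_rating = clip_metric.evaluate(frames)
dino_sim, dino_rating = dino_metric.evaluate(frames)
print(f"CLIP-Sim {clip_sim:.3f} (rating {clip_rating})  DINO-Sim {dino_sim:.3f} (rating {dino_rating})")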
@@ -0,0 +1,59 @@
import numpy as np
import cv2
from PIL import Image
from typing import List
from skimage.metrics import structural_similarity as ssim
from skimage import io, color


ROUND_DIGIT=3
DYN_SAMPLE_STEP=4
NUM_ASPECT=5


MSE_POINT_HIGH=3000
MSE_POINT_MID=1000
MSE_POINT_LOW=100


class MetricMSE_dyn():
    def __init__(self) -> None:
        """
        Initialize a class MetricMSE_dyn for testing the dynamic degree of a given video.
        """
        pass

    def evaluate(self, frame_list: List[Image.Image]):
        """
        Calculate the MSE (Mean Squared Error) between frames sampled at regular intervals of a given video to test dynamic degree,
        then quantize the original output based on predefined thresholds.
        Args:
            frame_list: List[Image.Image], frames of the video used in calculation.
        Returns:
            mse_avg: float, the MSE computed between frames sampled at regular intervals and averaged over all the pairs.
            quantized_ans: int, the quantized value of the above average MSE based on the predefined thresholds.
        """
        mse_list=[]
        # Compare frames DYN_SAMPLE_STEP apart rather than adjacent ones.
        sampled_list = frame_list[::DYN_SAMPLE_STEP]
        for f_idx in range(len(sampled_list)-1):
            imageA = cv2.cvtColor(np.array(sampled_list[f_idx]), cv2.COLOR_RGB2BGR)
            imageB = cv2.cvtColor(np.array(sampled_list[f_idx+1]), cv2.COLOR_RGB2BGR)

            err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2)
            err /= float(imageA.shape[0] * imageA.shape[1])
            mse_value = err
            mse_list.append(mse_value)
        mse_avg=np.mean(mse_list)
        quantized_ans=0
        # Unlike the quality metrics, a larger MSE (more change between frames) maps to a higher rating.
        if mse_avg >= MSE_POINT_HIGH:
            quantized_ans=4
        elif mse_avg >= MSE_POINT_MID:
            quantized_ans=3
        elif mse_avg >= MSE_POINT_LOW:
            quantized_ans=2
        else:
            quantized_ans=1

        return mse_avg, quantized_ans
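
Note that this metric rewards motion: a larger average MSE yields a higher rating. A hedged usage sketch with placeholder frame paths:

# Hypothetical example: dynamic-degree rating from 32 frames (sampled every DYN_SAMPLE_STEP=4 internally).
from PIL import Image

frames = [Image.open(f"frames/{i:04d}.png").convert("RGB") for i in range(32)]   # placeholder paths
metric = MetricMSE_dyn()
mse_avg, rating = metric.evaluate(frames)
print(f"avg inter-frame MSE: {mse_avg:.1f} -> dynamic-degree rating {rating} (4 = most dynamic)")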
@@ -0,0 +1,49 @@
from pypiqe import piqe
from PIL import Image
import numpy as np
from typing import List


ROUND_DIGIT=3
NUM_ASPECT=5


PIQE_POINT_LOW=15
PIQE_POINT_MID=30
PIQE_POINT_HIGH=50


class MetricPIQE():
    def __init__(self) -> None:
        """
        Initialize a class MetricPIQE for testing the visual quality of a given video.
        """
        pass

    def evaluate(self, frame_list: List[Image.Image]):
        """
        Calculate PIQE for visual quality for each frame of the given video and take the average value,
        then quantize the original output based on predefined thresholds.
        Args:
            frame_list: List[Image.Image], frames of the video used in calculation.
        Returns:
            piqe_avg: float, the computed average PIQE among the frames.
            quantized_ans: int, the quantized value of the above average score based on the predefined thresholds.
        """
        piqe_list=[]
        for frame in frame_list:
            frame=np.array(frame)
            # pypiqe returns the overall score plus per-block masks; only the score is used here.
            piqe_score, _, _, _ = piqe(frame)
            piqe_list.append(piqe_score)
        piqe_avg=np.mean(piqe_list)
        quantized_ans=0
        # Lower PIQE means better perceptual quality, so smaller averages get higher ratings.
        if piqe_avg < PIQE_POINT_LOW:
            quantized_ans=4
        elif piqe_avg < PIQE_POINT_MID:
            quantized_ans=3
        elif piqe_avg < PIQE_POINT_HIGH:
            quantized_ans=2
        else:
            quantized_ans=1
        return piqe_avg, quantized_ans
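
Finally, a speculative sketch (not in the commit) of how the per-aspect ratings from the classes added in this PR could be collected into a single report, assuming all of them are importable in one script; the dictionary keys, prompt, and frame paths are illustrative labels only:

# Hypothetical aggregation of the quantized 1-4 ratings from each metric.
from PIL import Image

frames = [Image.open(f"frames/{i:04d}.png").convert("RGB") for i in range(16)]   # placeholder paths
prompt = "a corgi running on a beach at sunset"                                  # placeholder prompt

report = {
    "visual_quality_piqe": MetricPIQE().evaluate(frames)[1],
    "visual_quality_brisque": MetricBRISQUE().evaluate(frames)[1],
    "temporal_consistency_clip": MetricCLIP_sim().evaluate(frames)[1],
    "temporal_consistency_dino": MetricDINO_sim().evaluate(frames)[1],
    "dynamic_degree_mse": MetricMSE_dyn().evaluate(frames)[1],
    "text_alignment_clipscore": MetricCLIPScore().evaluate(frames, prompt)[1],
}
print(report)  # each value is a quantized rating between 1 and 4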