-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfeature_extraction.py
146 lines (121 loc) · 6.6 KB
/
feature_extraction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import torch
import torchvision.models as models
from torchvision import transforms
def load_vit_model(weights=models.ViT_B_16_Weights.DEFAULT):
    """
    Build a ViT-B/16 network in inference mode together with the transforms bundled with its weights.

    Parameters:
    - weights (torchvision.models.ViT_B_16_Weights, optional): Pre-trained weights handed to
      ``torchvision.models.vit_b_16``. Defaults to ``ViT_B_16_Weights.DEFAULT``; any other member
      of that enum can be supplied instead.

    Returns:
    - model (torch.nn.Module): The network returned by ``models.vit_b_16``, already switched to
      ``eval()`` mode.
    - preprocess (callable): The preprocessing pipeline obtained from ``weights.transforms()``,
      so it always corresponds to the weights that were loaded.

    Notes:
    - ``weights`` must be a weights enum member; ``weights.transforms()`` is called on it.
    """
    # ``Module.eval()`` returns the module itself, so construction and mode switch can be chained.
    network = models.vit_b_16(weights=weights).eval()
    return network, weights.transforms()
def load_resnet50_model(weights=models.ResNet50_Weights.IMAGENET1K_V2, image_size=(512, 512)):
    """
    Load a ResNet-50 backbone for feature extraction, with its final fully connected layer removed.

    This setup is useful for feature extraction, where the model's output serves as high-level
    image features rather than class scores.

    Parameters:
    - weights (torchvision.models.ResNet50_Weights, optional): Pre-trained weights passed to
      ``torchvision.models.resnet50``. Defaults to ``ResNet50_Weights.IMAGENET1K_V2``.
    - image_size (int or tuple of int, optional): ``size`` argument for ``transforms.Resize``.
      Defaults to ``(512, 512)``. A ``(height, width)`` tuple resizes to exactly that size; an
      int resizes the shorter side and keeps the aspect ratio.

    Returns:
    - model (torch.nn.Sequential): Every child module of the ResNet-50 except the last one (the
      final fully connected classification layer), in eval mode.
    - preprocess (torchvision.transforms.Compose): Resize -> ToTensor -> Normalize with the
      standard ImageNet channel mean/std.
    """
    resnet = models.resnet50(weights=weights)
    resnet.eval()
    # Drop the last child module (the fc classifier) so the network outputs features, not logits.
    backbone = torch.nn.Sequential(*list(resnet.children())[:-1])
    # A freshly built Sequential defaults to training mode. Set it explicitly so that
    # ``backbone.training`` agrees with the (already eval-mode) children.
    backbone.eval()
    preprocess = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return backbone, preprocess
def load_vgg16_model(weights=models.VGG16_Weights.IMAGENET1K_V1, image_size=(512, 512)):
    """
    Load a VGG-16 backbone for feature extraction, with its last child module removed.

    Parameters:
    - weights (torchvision.models.VGG16_Weights, optional): Pre-trained weights passed to
      ``torchvision.models.vgg16``. Defaults to ``VGG16_Weights.IMAGENET1K_V1``.
    - image_size (int or tuple of int, optional): ``size`` argument for ``transforms.Resize``.
      Defaults to ``(512, 512)``. A ``(height, width)`` tuple resizes to exactly that size; an
      int resizes the shorter side and keeps the aspect ratio.

    Returns:
    - model (torch.nn.Sequential): Every child module of the VGG-16 except the last one (in
      torchvision's VGG that is the ``classifier`` head), in eval mode.
    - preprocess (torchvision.transforms.Compose): Resize -> ToTensor -> Normalize with the
      standard ImageNet channel mean/std.
    """
    vgg16 = models.vgg16(weights=weights)
    vgg16.eval()
    # Drop the last child module (the classifier) so the network outputs features, not logits.
    backbone = torch.nn.Sequential(*list(vgg16.children())[:-1])
    # A freshly built Sequential defaults to training mode. Set it explicitly so that
    # ``backbone.training`` agrees with the (already eval-mode) children.
    backbone.eval()
    preprocess = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return backbone, preprocess
def load_mobilenetv3_model(weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V2, image_size=(512, 512)):
    """
    Load the convolutional ``features`` stage of a MobileNetV3-Large for feature extraction.

    Parameters:
    - weights (torchvision.models.MobileNet_V3_Large_Weights, optional): Pre-trained weights
      passed to ``torchvision.models.mobilenet_v3_large``. Defaults to
      ``MobileNet_V3_Large_Weights.IMAGENET1K_V2``.
    - image_size (int or tuple of int, optional): ``size`` argument for ``transforms.Resize``.
      Defaults to ``(512, 512)``. A ``(height, width)`` tuple resizes to exactly that size; an
      int resizes the shorter side and keeps the aspect ratio.

    Returns:
    - model (torch.nn.Module): ``mobilenet.features`` of the loaded network, in eval mode (it
      is a child of the model on which ``eval()`` was called).
    - preprocess (torchvision.transforms.Compose): Resize -> ToTensor -> Normalize with the
      standard ImageNet channel mean/std.

    Notes:
    - NOTE(review): only the ``features`` stage is kept, with no pooling layer after it. The
      flattened feature length therefore presumably grows with ``image_size``. Confirm this is
      intended, since it differs from the pooled ResNet-50 output.
    """
    mobilenet = models.mobilenet_v3_large(weights=weights)
    mobilenet.eval()
    backbone = mobilenet.features
    preprocess = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return backbone, preprocess
def extract_features(image, model, preprocess, model_type='vit'):
    """
    Extract a 1-D feature vector for a single image using a pre-trained backbone.

    The image is run through ``preprocess`` and given a leading batch dimension of size 1.
    It is then moved to the device of the model's first parameter (when the model has
    parameters) and forwarded through the model with autograd disabled.

    Parameters:
    - image (PIL.Image or numpy.ndarray): The input image. It must be a type that
      ``preprocess`` accepts.
    - model (torch.nn.Module): The backbone, e.g. as returned by ``load_vit_model``,
      ``load_resnet50_model``, ``load_vgg16_model`` or ``load_mobilenetv3_model``.
    - preprocess (callable): Transform that turns ``image`` into a single-image tensor with
      no batch dimension (one is added here).
    - model_type (str, optional): Selects the forward path.
        * 'vit' (default): patch-embed the input with ``model._process_input``, prepend
          ``model.class_token``, run ``model.encoder``, and take the output at the
          class-token position (index 0). The classification head is never applied.
        * 'resnet', 'vgg16', 'mobilenetv3': call ``model`` directly and flatten every
          dimension after the batch dimension.

    Returns:
    - features (numpy.ndarray): 1-D feature vector for the image, copied to the CPU. Its
      length depends on the backbone, and may depend on the input size for unpooled
      backbones.

    Raises:
    - ValueError: If ``model_type`` is not one of the values listed above.

    Notes:
    - Use a ``model`` and ``preprocess`` that come from the same loader, so that the input
      size and normalization match what the model expects.
    """
    cnn_types = ('resnet', 'vgg16', 'mobilenetv3')
    if model_type != 'vit' and model_type not in cnn_types:
        # Fail early with a clear message instead of reaching the end with no features.
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected 'vit', 'resnet', "
            f"'vgg16' or 'mobilenetv3'."
        )

    batch = preprocess(image).unsqueeze(0)  # add a batch dimension of size 1

    # Move the input to the model's device so GPU-resident models work. The result is
    # copied back to the CPU before returning.
    parameters = getattr(model, 'parameters', None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        batch = batch.to(first_param.device)

    # Inference only: no branch needs an autograd graph.
    with torch.no_grad():
        if model_type == 'vit':
            tokens = model._process_input(batch)
            class_tokens = model.class_token.expand(batch.shape[0], -1, -1)
            tokens = model.encoder(torch.cat([class_tokens, tokens], dim=1))
            output = tokens[:, 0]
        else:
            # All CNN backbones share the same path: forward, then flatten per sample.
            output = torch.flatten(model(batch), start_dim=1)

    return output[0].detach().cpu().numpy()