-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytorchUtility.py
123 lines (99 loc) · 3.9 KB
/
pytorchUtility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import time
import torch
import torch.nn as nn
import numpy as np
from collections import Counter
def calAccuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor. ``topk`` is taken along dim 1, so this is
            presumably shaped (batch, numClasses) — confirm with callers.
        target: tensor of ground-truth class indices, one per sample.
            ``target.size(0)`` is used as the batch size.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k, in ``topk`` order. Each holds the
        percentage of samples whose target is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    # Indices of the maxk largest scores along dim 1, sorted descending.
    _, pred = output.topk(maxk, 1, True, True)
    # Transpose so that row r holds every sample's rank-r guess.
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # `correct` keeps the transposed (non-contiguous) layout of `pred`.
        # `.view(-1)` therefore fails for k > 1, while `.reshape` copies
        # when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def calAveragePredictionVectorAccuracy(predictionVectorsList, target, modelsList=None, topk=(1,)):
    """Top-k accuracy of the ensemble formed by averaging model outputs.

    Args:
        predictionVectorsList: sequence of per-model output tensors. They are
            stacked along a new leading (model) dimension, so all must have
            the same shape.
        target: ground-truth labels, passed unchanged to ``calAccuracy``.
        modelsList: optional indices into ``predictionVectorsList`` selecting
            which models to average. ``None`` or an empty sequence means all
            models are used.
        topk: k values forwarded to ``calAccuracy``.

    Returns:
        The result of ``calAccuracy`` on the element-wise mean of the
        selected models' outputs.
    """
    predictionVectorsStack = torch.stack(predictionVectorsList)
    # The default is None, so test for it before calling len(). len() is
    # still used, rather than truthiness, so that index arrays/tensors
    # (whose truth value is ambiguous) keep working.
    if modelsList is not None and len(modelsList) > 0:
        predictionVectorsStack = predictionVectorsStack[modelsList, ...]
    averagePrediction = torch.mean(predictionVectorsStack, dim=0)
    return calAccuracy(averagePrediction, target, topk)
def calNegativeSamplesSet(predictionVectorsList, target):
    """Collect, for each model, the indices of the samples it misclassifies.

    Args:
        predictionVectorsList: sequence of per-model score tensors. Each
            model's predicted class is the argmax along dim 1.
        target: tensor of ground-truth class indices. Only the first
            ``target.size(0)`` samples are examined.

    Returns:
        A list with one ``set`` of int sample indices per model, in the same
        order as ``predictionVectorsList``. The set holds the indices where
        that model's prediction differs from ``target``.
    """
    batchSize = target.size(0)
    # Flatten so that a (batch, 1) target compares element-wise instead of
    # silently broadcasting to (batch, batch).
    flatTarget = target.reshape(-1)
    negativeSamplesSet = []
    for pVL in predictionVectorsList:
        _, pred = pVL.max(dim=1)
        pred = pred[:batchSize]
        # One vectorised comparison per model replaces a Python-level 0-dim
        # tensor comparison per (sample, model) pair.
        wrong = pred != flatTarget.to(pred.device)
        negativeSamplesSet.append(set(torch.nonzero(wrong).flatten().tolist()))
    return negativeSamplesSet
def calDisagreementSamplesNoGroundTruth(predictionVectorsList, target):
    """Find samples on which the models do not all predict the same class.

    Ground truth is not used to decide disagreement. ``target`` only supplies
    the batch size and the label reported for each selected sample.

    Args:
        predictionVectorsList: sequence of per-model score tensors. Each
            model's predicted class is the argmax along dim 1.
        target: tensor of labels. The first ``target.size(0)`` samples are
            examined.

    Returns:
        Tuple ``(sampleID, sampleTarget, predictions, predVectors)`` of
        parallel lists, with one entry per disagreed sample:
          - sampleID: the sample's index in the batch (int);
          - sampleTarget: ``target[i].item()``;
          - predictions: each model's predicted class (ints), in model order;
          - predVectors: each model's score row ``predictionVectorsList[j][i]``.
    """
    batchSize = target.size(0)
    # Convert each model's argmax predictions to Python ints once. This
    # avoids indexing and comparing 0-dim tensors per (sample, model) pair.
    predLists = [pVL.max(dim=1)[1].tolist() for pVL in predictionVectorsList]
    sampleID = []
    sampleTarget = []
    predictions = []
    predVectors = []
    for i in range(batchSize):
        pred = [pl[i] for pl in predLists]
        # Disagreement: some model differs from the first one. With zero or
        # one model this is never true.
        if any(p != pred[0] for p in pred[1:]):
            sampleID.append(i)
            sampleTarget.append(target[i].item())
            predictions.append(pred)
            predVectors.append([pVL[i] for pVL in predictionVectorsList])
    return sampleID, sampleTarget, predictions, predVectors
def calDisagreementSamplesOneTargetNegative(predictionVectorsList, target, oneTargetIdx):
    """Select the samples that the model at ``oneTargetIdx`` misclassifies.

    Args:
        predictionVectorsList: sequence of per-model score tensors. Each
            model's predicted class is the argmax along dim 1.
        target: tensor of ground-truth class indices. The first
            ``target.size(0)`` samples are examined.
        oneTargetIdx: index into ``predictionVectorsList`` of the model whose
            errors define the selection.

    Returns:
        Tuple ``(sampleID, sampleTarget, predictions, predVectors)`` of
        parallel lists, with one entry per sample where the chosen model's
        prediction differs from ``target``:
          - sampleID: the sample's index in the batch (int);
          - sampleTarget: ``target[i].item()``;
          - predictions: every model's predicted class (ints), in model order;
          - predVectors: every model's score row ``predictionVectorsList[j][i]``.
    """
    batchSize = target.size(0)
    sampleID = []
    sampleTarget = []
    predictions = []
    predVectors = []
    if batchSize == 0:
        return sampleID, sampleTarget, predictions, predVectors
    predictionList = [pVL.max(dim=1)[1] for pVL in predictionVectorsList]
    focusPred = predictionList[oneTargetIdx][:batchSize]
    # Compare flat vectors in one operation. The reshape keeps a (batch, 1)
    # target from broadcasting to (batch, batch).
    wrong = focusPred != target.reshape(-1).to(focusPred.device)
    # Convert predictions to Python ints once rather than calling .item()
    # per (sample, model) pair.
    predLists = [p.tolist() for p in predictionList]
    for i in torch.nonzero(wrong).flatten().tolist():
        sampleID.append(i)
        sampleTarget.append(target[i].item())
        predictions.append([pl[i] for pl in predLists])
        predVectors.append([pVL[i] for pVL in predictionVectorsList])
    return sampleID, sampleTarget, predictions, predVectors
def filterModelsFixed(sampleID, sampleTarget, predictions, predVectors, selectModels):
    """Restrict per-sample model outputs to a fixed subset of models.

    ``predictions`` and ``predVectors`` are indexed as ``x[:, selectModels]``,
    so they must support 2-D fancy indexing (e.g. numpy arrays or torch
    tensors; plain nested lists raise TypeError). ``sampleID`` and
    ``sampleTarget`` are passed through unchanged.

    Returns:
        ``(sampleID, sampleTarget, filteredPredictions, filteredPredVectors)``.
    """
    # Same index as [:, selectModels]: keep every row, pick model columns.
    keepModels = (slice(None), selectModels)
    return (
        sampleID,
        sampleTarget,
        predictions[keepModels],
        predVectors[keepModels],
    )