- Karsten Roth ([email protected])
- Biagio Brattoli ([email protected])
- Björn Ommer
Primary Contact: Karsten Roth
For baseline implementations, check out https://github.com/Confusezius/Revisiting_Deep_Metric_Learning_PyTorch and the accompanying paper!
This repository contains the code to run the pipeline proposed in our ICCV 2019 paper Mining Interclass Characteristics for Improved Deep Metric Learning (https://arxiv.org/abs/1909.11574). The results using this pipeline for ProxyNCA and Triplet with Semihard Sampling are better than noted in the paper due to an improved implementation of the baseline methods.
Note: Baseline implementations can be found at https://github.com/Confusezius/Deep-Metric-Learning-Baselines.
Our method was tested with:
- Python Version 3.6.6+
- PyTorch Version 1.0.1+ and CUDA 8.0
- Faiss(-gpu) 1.5.1 (GPU support optional) for CUDA 8.0
- Scikit Image 0.14.2
- Scikit Learn 0.20.3
- Scipy 1.2.1
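To compare a local environment against the versions listed above, a minimal sketch such as the following may help. It is not part of the repository: the helper name `installed_versions` is hypothetical, the distribution names (`torch`, `faiss-gpu`, ...) are assumptions about how each package is installed, and `importlib.metadata` requires Python 3.8 or newer.

```python
from importlib import metadata  # standard library, Python 3.8+

# Versions listed in this README. The distribution names are assumptions
# and may differ in your environment (e.g. faiss-cpu instead of faiss-gpu).
TESTED_VERSIONS = {
    "torch": "1.0.1",
    "faiss-gpu": "1.5.1",
    "scikit-image": "0.14.2",
    "scikit-learn": "0.20.3",
    "scipy": "1.2.1",
}


def installed_versions(tested=TESTED_VERSIONS):
    """Map each distribution name to (tested version, installed version or None)."""
    report = {}
    for name, tested_version in tested.items():
        try:
            report[name] = (tested_version, metadata.version(name))
        except metadata.PackageNotFoundError:
            # Package is not installed under this distribution name.
            report[name] = (tested_version, None)
    return report


if __name__ == "__main__":
    for name, (tested_version, found) in installed_versions().items():
        print(f"{name}: tested with {tested_version}, installed: {found or 'not found'}")
```

The sketch only reports versions side by side; it does not decide whether a newer version is compatible.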
To run with standard batch sizes, at least 11 GB of VRAM is required (e.g. 1080Ti, Titan X).
For a quick start on the standard Deep Metric Learning datasets, simply run the sample setups given in Result_Runs.sh. Assuming the same underlying setup, these give values similar to those reported in the paper; minor differences are due to the choice of seeds and underlying setups.
The main script is main.py. Running it with default flags performs a metric learning run with interclass mining on CUB200-2011 using ResNet50, margin loss and distance-weighted sampling. For all tweakable parameters and their purpose, please refer to the help strings in the main.py ArgumentParser. Most should be fairly self-explanatory. Again, good default setups can be found in Result_Runs.sh.
NOTE regarding ProxyNCA for Online Products, PKU Vehicle ID and In-Shop Clothes: Due to the high number of classes, the number of proxies required is too high for useful training (>10000 proxies).
Repository
│   README.md
│
│   ### Main Scripts
│   main.py         (main training script)
│   losses.py       (collection of loss and sampling impl.)
│   datasets.py     (dataloaders for all datasets)
│
│   ### Utility scripts
│   auxiliaries.py  (set of utilities)
│   evaluate.py     (set of evaluation functions)
│
│   ### Network Scripts
│   netlib.py       (contains impl. for ResNet50 and network utils)
│   googlenet.py    (contains impl. for GoogLeNet)
│
├───Training Results (generated during Training)
│   │   e.g. cub200/Training_Run_Name
│   │   e.g. cars196/Training_Run_Name
│
└───Datasets (should be added, if one does not want to set paths)
    │   cub200, cars196 ...
CUB200-2011
cub200
└───images
| └───001.Black_footed_Albatross
| │ Black_Footed_Albatross_0001_796111
| │ ...
| ...
CARS196
cars196
└───images
| └───Acura Integra Type R 2001
| │ 00128.jpg
| │ ...
| ...
Online Products
online_products
└───images
| └───bicycle_final
| │ 111085122871_0.jpg
| ...
└───Info_Files
| │ bicycle.txt
| │ ...
In-Shop Clothes
in-shop
└───img
| └───MEN
| └───Denim
| └───id_00000080
| │ 01_1_front.jpg
| │ ...
| ...
└───Eval
| │ list_eval_partition.txt
PKU Vehicle ID
vehicle_id
└───image
| │ <img>.jpg
| | ...
└───train_test_split
| | test_list_800.txt
| | ...
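Before starting a run, it can be useful to check that a dataset folder follows the expected layout. The following is a minimal sketch, not part of the repository: the helper name `images_per_class` is hypothetical, and it only covers the layouts with one subfolder per class under images (cub200, cars196, online_products as shown above). The in-shop and vehicle_id layouts are structured differently and would need their own checks.

```python
from pathlib import Path


def images_per_class(dataset_root, image_dir="images"):
    """Count the files in each class folder under <dataset_root>/<image_dir>."""
    base = Path(dataset_root) / image_dir
    if not base.is_dir():
        raise FileNotFoundError(f"expected image folder at {base}")
    # One entry per class subfolder; files placed directly in `base` are ignored.
    return {
        class_dir.name: sum(1 for f in class_dir.iterdir() if f.is_file())
        for class_dir in sorted(base.iterdir())
        if class_dir.is_dir()
    }
```

For example, `images_per_class("Datasets/cub200")` would return a dictionary mapping each class folder name to its number of files, assuming the dataset was placed under a Datasets folder as in the repository tree.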
By default, the following files are saved:
Name_of_Training_Run
| checkpoint.pth.tar -> Contains network state-dict.
| hypa.pkl -> Contains all network parameters as pickle.
| Can be used directly to recreate the network.
| log_train_Class.csv -> Logged training data as CSV.
| log_val_Class.csv -> Logged test metrics as CSV.
| Parameter_Info.txt -> All Parameters stored as readable text-file.
| InfoPlot_Class.svg -> Graphical summary of training/testing metrics progression.
| Curr_Summary_Class.txt -> Summary of training (best metrics...).
| sample_recoveries.png -> Sample recoveries for best validation weights.
| Acts as a sanity test.
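The saved files can be inspected after training with a few lines of standard-library Python. The sketch below is an illustration under stated assumptions, not part of the repository: the helper name `load_run_summary` is hypothetical, the CSV log is assumed to start with a header row, and unpickling hypa.pkl may require the repository's modules (and PyTorch) to be importable, depending on what was stored.

```python
import csv
import pickle
from pathlib import Path


def load_run_summary(run_dir):
    """Read the stored parameters and validation log rows from a training-run folder."""
    run_dir = Path(run_dir)
    # hypa.pkl is described above as a pickle of the run parameters.
    # Only unpickle files you created yourself: pickle can execute code on load.
    with open(run_dir / "hypa.pkl", "rb") as f:
        hyperparameters = pickle.load(f)
    # Column names are not documented in this README, so each row is kept
    # as a plain dict keyed by whatever the CSV header contains (assumed present).
    with open(run_dir / "log_val_Class.csv", newline="") as f:
        val_rows = list(csv.DictReader(f))
    return hyperparameters, val_rows
```

Loading checkpoint.pth.tar is left out here, since it requires PyTorch and the matching network definition from netlib.py.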
If you use this repository or wish to cite our results, please cite our paper (https://arxiv.org/abs/1909.11574):
@conference{roth2019mic,
title={MIC: Mining Interclass Characteristics for Improved Metric Learning},
author={Roth, Karsten and Brattoli, Biagio and Ommer, Bj{\"o}rn},
booktitle={Proceedings of the International Conference on Computer Vision (ICCV)},
year={2019}
}