Extended Siamese Neural Network

This repository contains the code for the experiments described in the paper Learning similarity measures from data, where we evaluate different types of similarity measures. The types follow from a framework for analyzing similarity functions, with $\mathbb{S}$ as a similarity measure applied to a pair of data points $(\boldsymbol{x},\boldsymbol{y})$:

\begin{equation} \mathbb{S}(\boldsymbol{x},\boldsymbol{y}) = C(G(\boldsymbol{x}),G(\boldsymbol{y})) , \end{equation}

where $G(\boldsymbol{x}) = \hat{\boldsymbol{x}}$ and $G(\boldsymbol{y}) = \hat{\boldsymbol{y}}$ represent embeddings or information extracted from the data points $\boldsymbol{x}$ and $\boldsymbol{y}$, i.e. $G(\cdot)$ highlights the parts of the data points that are most useful for calculating the similarity between them, as modeled in $C(\cdot)$. An illustration of this process can be seen in the figure below:

[Figure 2: embedding data points from the problem space into the embedding space (figs/Fig2-problem-solution-embedding-space.jpeg)]
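
To make the notation concrete, here is a minimal sketch (illustrative only, not code from this repository) of a fully modeled measure, where $G(\cdot)$ is the identity and $C(\cdot)$ is a similarity derived from the Euclidean distance:

import numpy as np

def G(x):
    # Modeled embedding: here simply the identity, so x_hat = x
    return np.asarray(x, dtype=float)

def C(x_hat, y_hat):
    # Modeled similarity: inverse Euclidean distance, mapped into (0, 1]
    return 1.0 / (1.0 + np.linalg.norm(x_hat - y_hat))

def S(x, y):
    return C(G(x), G(y))

print(S([1.0, 2.0], [1.0, 2.5]))  # ~0.67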

The different types of similarity measures can then be listed:

|                             | $C(\boldsymbol{x},\boldsymbol{y})$ Modeled | $C(\boldsymbol{x},\boldsymbol{y})$ Learned |
|-----------------------------+--------------------------------------------+--------------------------------------------|
| $G(\boldsymbol{x})$ Modeled | Type 1                                     | Type 2                                     |
| $G(\boldsymbol{x})$ Learned | Type 3                                     | Type 4                                     |
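
In this terminology, a Type 1 measure is fully modeled (like the sketch above), while a Type 4 measure learns both $G(\cdot)$ and $C(\cdot)$ from data. A minimal Type 4 sketch using the Keras functional API (hypothetical layer sizes and loss, not the exact ESNN architecture from the paper):

from keras.layers import Concatenate, Dense, Input
from keras.models import Model

n_features = 13  # hypothetical input dimensionality

# G: a shared (siamese) embedding layer applied to both data points
embed = Dense(13, activation="relu", name="G")
x = Input(shape=(n_features,))
y = Input(shape=(n_features,))
x_hat = embed(x)  # G(x)
y_hat = embed(y)  # G(y)

# C: a learned similarity over the pair of embeddings
merged = Concatenate()([x_hat, y_hat])
s = Dense(1, activation="sigmoid", name="C")(merged)

model = Model(inputs=[x, y], outputs=s)
model.compile(optimizer="rmsprop", loss="binary_crossentropy")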

Install

Requirements

This code requires internet access on first run, to download the datasets from the UCI ML repository. After the first run the datasets should be cached locally.

Python requirements:

  • Keras = 2.2.4
  • Tensorflow < 2.0 (or Tensorflow-gpu < 2.0 for GPU support)
  • Seaborn
  • requests-cache (to cache UCI ML repo datasets)
  • pandas
  • pandas-datareader
  • sklearn-pandas
  • scikit-learn
  • xlrd
  • matplotlib

I recommend using anaconda to run this code; everything can be installed with the following commands:

conda env create -f environment.yml
conda activate esnn

The requirements are also listed in requirements.txt and Pipfile to enable the use of pip and pipenv, but your mileage may vary.
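
For example, with pip the equivalent should be something like this (untested; assumes Python 3 with the venv module available):

python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt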

Code

Most of the interesting code connected to the innovations of the paper can be found in the models directory (e.g. esnn.py).

Usage

The main code for running locally on one machine is in runner.py. For distributing the cross-evaluation across several machines using MPI, the code can be found in mpi_runner.py.

Both of these files use the same argparse arguments, which are documented through running --help, e.g. python ./runner.py --help

Below we give two scripts and two argument sets to reproduce the results from the paper; the only differences between the two are the number of epochs and the directory the results are written to. All methods are evaluated on all datasets using five-fold cross-validation, repeated five times to produce averages and deviations; see the paper for the details of how the evaluation is done.
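
For clarity, an illustrative sketch of that evaluation protocol (not the repository code; train_and_eval is a hypothetical stand-in for training one model and computing its retrieval loss on the held-out fold):

import numpy as np
from sklearn.model_selection import KFold

def repeated_cv(X, y, train_and_eval, n_splits=5, n_repeats=5):
    # Five repetitions of five-fold CV: one score per fold per repetition
    scores = []
    for repeat in range(n_repeats):
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=repeat)
        for train_idx, test_idx in kf.split(X):
            scores.append(train_and_eval(X[train_idx], y[train_idx],
                                         X[test_idx], y[test_idx]))
    return np.mean(scores), np.std(scores)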

200 epochs

bash ./run_experiments_200.sh

or specify the parameters yourself.

python ./runner.py --kfold=5 --epochs=200 --methods eSNN:rprop:200:split:0.15,chopra:rprop:200:gabel,gabel:rprop:200:gabel,t3i1:rprop:200:split,t1i1,t2i1 --datasets iris,use,eco,glass,heart,car,hay,mam,ttt,pim,bal,who,mon,cmc --onehot True --multigpu False --batchsize 1000 --hiddenlayers 13,13 --gpu 0,1 --prefix=newchopraresults-forpaper-200epochs-n5 --n 5 --cvsummary False --printcv False

The results should be close to those in Table 2 of the paper:

| Dataset | esnn | chopra | gabel | t3i1 | t1i1 | t2i1 |
|---------+------+--------+-------+------+------+------|
| bal     | 0.01 | 0.00   | 0.14  | 0.10 | 0.42 | 0.81 |
| car     | 0.04 | 0.02   | 0.19  | 0.16 | 0.25 | 0.25 |
| cmc     | 0.52 | 0.53   | 0.54  | 0.55 | 0.54 | 0.58 |
| eco     | 0.22 | 0.20   | 0.46  | 0.35 | 0.21 | 0.22 |
| glass   | 0.08 | 0.08   | 0.12  | 0.10 | 0.06 | 0.07 |
| hay     | 0.19 | 0.21   | 0.26  | 0.17 | 0.33 | 0.37 |
| heart   | 0.21 | 0.24   | 0.28  | 0.24 | 0.24 | 0.23 |
| iris    | 0.04 | 0.03   | 0.18  | 0.07 | 0.05 | 0.04 |
| mam     | 0.21 | 0.25   | 0.26  | 0.27 | 0.28 | 0.29 |
| mon     | 0.28 | 0.33   | 0.39  | 0.45 | 0.29 | 0.29 |
| pim     | 0.28 | 0.30   | 0.35  | 0.35 | 0.31 | 0.32 |
| ttt     | 0.03 | 0.03   | 0.17  | 0.07 | 0.32 | 0.07 |
| use     | 0.07 | 0.08   | 0.08  | 0.39 | 0.21 | 0.18 |
| who     | 0.29 | 0.45   | 0.33  | 0.45 | 0.46 | 0.45 |
|---------+------+--------+-------+------+------+------|
| Sum     | 2.47 | 2.75   | 3.75  | 3.72 | 3.97 | 4.17 |
| Average | 0.18 | 0.20   | 0.27  | 0.27 | 0.28 | 0.30 |

2000 epochs

bash ./run_experiments_2000.sh

or specify the parameters yourself.

python ./runner.py --kfold=5 --epochs=2000 --methods eSNN:rprop:2000:split:0.15,chopra:rprop:2000:gabel,gabel:rprop:2000:gabel,t3i1:rprop:2000:split,t1i1,t2i1 --datasets iris,use,eco,glass,heart,car,hay,mam,ttt,pim,bal,who,mon,cmc --onehot True --multigpu False --batchsize 1000 --hiddenlayers 13,13 --gpu 0,1 --prefix=newchopraresults-forpaper-2000epochs-n5 --n 5 --cvsummary False --printcv False

The results should be close to those in Table 2 of the paper:

| Dataset | esnn | chopra | gabel | t3i1 | t1i1 | t2i1 |
|---------+------+--------+-------+------+------+------|
| bal     | 0.02 | 0.00   | 0.08  | 0.01 | 0.43 | 0.83 |
| car     | 0.01 | 0.01   | 0.06  | 0.02 | 0.24 | 0.24 |
| cmc     | 0.52 | 0.53   | 0.54  | 0.53 | 0.54 | 0.58 |
| eco     | 0.22 | 0.20   | 0.22  | 0.18 | 0.19 | 0.21 |
| glass   | 0.06 | 0.07   | 0.08  | 0.09 | 0.05 | 0.06 |
| hay     | 0.18 | 0.21   | 0.20  | 0.15 | 0.32 | 0.34 |
| heart   | 0.21 | 0.27   | 0.23  | 0.22 | 0.24 | 0.23 |
| iris    | 0.08 | 0.05   | 0.07  | 0.04 | 0.06 | 0.05 |
| mam     | 0.21 | 0.27   | 0.25  | 0.27 | 0.29 | 0.28 |
| mon     | 0.26 | 0.30   | 0.33  | 0.27 | 0.32 | 0.32 |
| pim     | 0.27 | 0.31   | 0.25  | 0.30 | 0.30 | 0.31 |
| ttt     | 0.03 | 0.03   | 0.07  | 0.03 | 0.32 | 0.08 |
| use     | 0.08 | 0.10   | 0.07  | 0.08 | 0.18 | 0.16 |
| who     | 0.30 | 0.46   | 0.29  | 0.43 | 0.47 | 0.45 |
|---------+------+--------+-------+------+------+------|
| Sum     | 2.45 | 2.81   | 2.74  | 2.62 | 3.95 | 4.14 |
| Average | 0.18 | 0.20   | 0.20  | 0.19 | 0.28 | 0.30 |

MNIST

Notice that the MNIST experiment does not do the evaluation the same way as the two previous experiments (200 and 2000 epochs): calculating the distance from every data point in the test set to every data point in the training set would take too long ((0.2 * 60000) * (0.8 * 60000) evaluations) and require too much memory with the current implementation. Thus, in the output of the run you will see "avg_retrieve_loss: 1", but the training error still reflects the performance of the models.
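
As a back-of-the-envelope check of that pair count (assuming an 80/20 train/test split of the 60000 MNIST training images, matching the five-fold setup):

test_points = 0.2 * 60000    # 12000 held-out points
train_points = 0.8 * 60000   # 48000 training points
print(int(test_points * train_points))  # 576000000 pairwise evaluations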

bash ./run_mnist.sh

or specify the parameters yourself.

python ./runner.py --kfold=5 --epochs=500 --methods eSNN:rprop:500:ndata,chopra:rprop:500:ndata --datasets mnist --onehot True --multigpu False --batchsize 200 --hiddenlayers 128,128,128 --gpu 1 --prefix mnisttesting --n 1 --cvsummary True --doevaluation False --seed 42 --printcv True

Citation

Please cite our paper if you use code from this repo:

@Article{Mathisen2019,
  author="Mathisen, Bj{\o}rn Magnus and Aamodt, Agnar and Bach, Kerstin and Langseth, Helge",
  title="Learning similarity measures from data",
  journal="Progress in Artificial Intelligence",
  year="2019",
  month="Oct",
  day="30",
  issn="2192-6360",
  doi="10.1007/s13748-019-00201-2",
  url="https://doi.org/10.1007/s13748-019-00201-2"
}
