-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathpredict.py
More file actions
80 lines (68 loc) · 2.3 KB
/
predict.py
File metadata and controls
80 lines (68 loc) · 2.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
import csv
import os
import pickle
import sys

import numpy as np
# Command-line interface: which trait model(s) to run and which annotation
# file to score against them.
parser = argparse.ArgumentParser(
    description="BacDive-AI package",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
for _arg_name, _arg_help in (
    ("trait", "Trait to predict, all to see all models"),
    ("file", "Interproscan file or file with list of Pfams"),
):
    parser.add_argument(_arg_name, help=_arg_help)
del _arg_name, _arg_help
# Parse sys.argv once; the rest of the script reads options from the plain
# dict `config`.
args = parser.parse_args()
config = vars(args)
# Define constants.
# InterProScan matches whose e-value exceeds this threshold are ignored.
EVALUE = 1e-20
# Directory holding the pickled per-trait models. It keeps a trailing
# separator because model file paths are built by string concatenation.
# abspath() guards against __file__ being a bare relative name (e.g.
# "predict.py" on Python < 3.9). There, dirname() returns "" and the old
# "" + "/models/" expression pointed at the filesystem root.
MODELPATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "models", ""
)
# Set up all available models: CLI trait name -> human-readable label that is
# printed next to each prediction.
traits = {
    "acidophile": 'Acidophilic',
    "gram-positive": 'Gram-positive',
    "spore-forming": 'Spore-forming',
    "aerobic": 'Aerobic',
    "anaerobic": 'Anaerobic',
    "thermophile": 'Thermophilic',
    "psychrophile": 'Psychrophilic',
    "motile2+": 'Flagellated motility',
}
# Check that the selected trait is supported before reading the input file or
# loading any model.
trait = config.get('trait')
if trait != 'all' and trait not in traits:
    # Report on stderr and exit with a non-zero status so shells and calling
    # pipelines can detect the failure. The bare site-provided exit() used
    # here before returned status 0, and it is unavailable when Python runs
    # with -S or as a frozen executable.
    print(trait, 'is not supported', file=sys.stderr)
    print('Supported traits are:', ', '.join(traits), file=sys.stderr)
    sys.exit(1)
# Load the Pfam data set from the input file.
def _read_pfams(path, max_evalue):
    """Return the set of Pfam accessions found in *path*.

    Two layouts are accepted, matching the ``file`` CLI help text:

    * InterProScan TSV output: the accession is taken from field index 4 and
      the e-value from field index 8. Rows whose e-value exceeds
      *max_evalue* are skipped. Rows whose e-value field is not numeric
      (such as a ``-`` placeholder) are also skipped, because the threshold
      cannot be applied to them.
    * A plain list with one Pfam accession per line.

    Blank lines are ignored.

    Args:
        path: Path of the UTF-8 text file to read.
        max_evalue: Largest e-value still accepted for TSV rows.

    Returns:
        A set of accession strings.

    Raises:
        ValueError: If a non-blank line matches neither layout.
    """
    accessions = set()
    # newline='' is what the csv module expects. QUOTE_NONE treats a '"'
    # at the start of a TSV field (e.g. in a description) literally instead of
    # as an opening quote that would swallow the following tabs/lines.
    with open(path, 'r', encoding='utf-8', newline='') as handle:
        reader = csv.reader(handle, delimiter='\t', quoting=csv.QUOTE_NONE)
        for row in reader:
            if not any(cell.strip() for cell in row):
                continue  # blank line
            if len(row) == 1:
                # Plain list layout: one accession per line.
                accessions.add(row[0].strip())
                continue
            if len(row) < 9:
                raise ValueError(
                    f"{path}: line {reader.line_num}: expected an InterProScan "
                    f"TSV row (at least 9 tab-separated columns) or a single "
                    f"Pfam accession, got {len(row)} columns"
                )
            try:
                evalue = float(row[8])
            except ValueError:
                continue  # no numeric e-value, so the threshold can't be applied
            if evalue > max_evalue:
                continue
            accessions.add(row[4])
    return accessions


filename = config.get('file')
pfams = _read_pfams(filename, EVALUE)
# If a specific trait has been selected, reduce the run to that one model.
if trait != 'all':
    traits = {trait: traits[trait]}
# Go through all remaining traits and print one prediction line per model.
for trait, label in traits.items():
    # Load the pickled model bundle. The context manager closes the file
    # handle that a bare open() call inside pickle.load() would leak.
    # NOTE(review): pickle.load can execute arbitrary code, so MODELPATH must
    # only ever contain trusted model files.
    with open(MODELPATH + trait + "_data.p", "rb") as model_file:
        dump = pickle.load(model_file)
    clf = dump.get('model')
    # Ordered feature names the model expects (presumably Pfam accessions,
    # since they are matched against `pfams`).
    categories = dump.get("categories")
    # Transform the sample into the model's feature order. `pfams` is a set,
    # so each feature is simply presence (1) or absence (0). This equals the
    # previous per-accession np.unique counts, which could only ever be 1.
    X = [1 if k in pfams else 0 for k in categories]
    # Predict class probabilities for the single sample.
    proba = clf.predict_proba([X])[0]
    # predict_proba columns follow clf.classes_, so map the winning column
    # back to its class label rather than using the column index as a label.
    best = int(np.argmax(proba))
    y = clf.classes_[best]
    # Print the predicted class and the model's confidence in it.
    # NOTE(review): bool(y) assumes 0/1 or False/True class labels; confirm
    # against how the models were trained.
    result = f"{label}: {bool(y)} ({round(proba[best]*100, 2)}%)"
    print(result)