Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ mf/*.pbs
dataset/*
jobs/*
mf/*.json
*.json
*.csv
raw_data/*/*
baselines/*
Expand All @@ -39,3 +38,7 @@ ae-hmf/*
rank/*
ae-word2vec/*
dlstm/*
log/*
config/environment/*

!config/environment/local.json
111 changes: 56 additions & 55 deletions attributes/input_attribute.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,73 @@
from comb_attribute import HET, MIX
import attribute
import cPickle as pickle
import sys, os
import sys
import os

sys.path.insert(0, '../utils')
from load_data import load_raw_data


def read_data(raw_data_dir='../raw_data/data/', data_dir='../cache/data/',
              combine_att='mix', logits_size_tr='10000',
              thresh=2, use_user_feature=True, use_item_feature=True,
              no_user_id=False, test=False, mylog=None, config=None):
    """Load the preprocessed dataset, building and caching it if needed.

    If ``<data_dir>/data`` exists it is unpickled and returned as-is; none of
    the preprocessing arguments are consulted in that case.  Otherwise the raw
    data is loaded through ``load_raw_data``, optionally trimmed, combined
    with ``HET`` or ``MIX``, written to ``<data_dir>/data`` with
    ``pickle_save`` and returned.

    Args:
        raw_data_dir: directory handed to ``load_raw_data`` as ``data_dir``.
        data_dir: cache directory; created (with parents) when missing.
        combine_att: ``'het'`` or ``'mix'``; selects the attribute combiner.
        logits_size_tr: forwarded to ``HET`` / ``MIX``.
        thresh: forwarded to ``HET`` / ``MIX`` as ``threshold``.
        use_user_feature: when false, only column 0 of ``users`` and the
            first entry of each ``user_features`` element are kept.
        use_item_feature: same trimming, applied to ``items`` /
            ``item_features``.
        no_user_id: when true, column 0 of ``users`` is overwritten with 0.
        test: passed to ``load_raw_data`` as ``_submit`` (1 if true, else 0).
        mylog: callable taking one message; defaults to printing it.
        config: forwarded unchanged to ``load_raw_data``.

    Returns:
        Tuple ``(data_tr, data_va, u_attr, i_attr, item_ind2logit_ind,
        logit_ind2item_ind, user_index, item_index)``.

    Raises:
        ValueError: if the cache is absent and ``combine_att`` is neither
            ``'het'`` nor ``'mix'``.
    """
    # Kept as a truthiness test (not ``is None``) so any falsy placeholder a
    # caller passes still falls back to printing, exactly as before.
    if not mylog:
        def mylog(val):
            print(val)

    data_filename = os.path.join(data_dir, 'data')
    if os.path.isfile(data_filename):
        mylog("data file {} exists! loading cached data. \nCaution: change cached data dir (--data_dir) if new data (or new preprocessing) is used.".format(data_filename))
        # NOTE(security): unpickling executes arbitrary code; the cache
        # directory must only ever contain files written by this project.
        # The context manager closes the handle even if unpickling fails.
        with open(data_filename, 'rb') as cache_file:
            (data_tr, data_va, u_attr, i_attr, item_ind2logit_ind,
             logit_ind2item_ind, user_index, item_index) = pickle.load(
                 cache_file)
        # u_attr.overview(mylog)
        # i_attr.overview(mylog)

    else:
        # Fail fast: an unknown combiner would otherwise surface only after
        # the raw-data load, as a NameError on the never-assigned u_attr.
        if combine_att not in ('het', 'mix'):
            raise ValueError(
                "unsupported combine_att {!r}; expected 'het' or "
                "'mix'".format(combine_att))

        # makedirs (not mkdir) so a missing parent directory is created too.
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)

        _submit = 1 if test else 0
        (users, items, data_tr, data_va, user_features, item_features,
         user_index, item_index) = load_raw_data(
             data_dir=raw_data_dir, _submit=_submit, config=config,
             mylog=mylog)

        if not use_user_feature:
            n = len(users)
            users = users[:, 0].reshape(n, 1)
            user_features = ([user_features[0][0]], [user_features[1][0]])
        if not use_item_feature:
            m = len(items)
            items = items[:, 0].reshape(m, 1)
            item_features = ([item_features[0][0]], [item_features[1][0]])

        if no_user_id:
            users[:, 0] = 0

        if combine_att == 'het':
            het = HET(data_dir=data_dir, logits_size_tr=logits_size_tr,
                      threshold=thresh)
            (u_attr, i_attr, item_ind2logit_ind,
             logit_ind2item_ind) = het.get_attributes(
                 users, items, data_tr, user_features, item_features)
        else:  # combine_att == 'mix' (validated above)
            mix = MIX(data_dir=data_dir, logits_size_tr=logits_size_tr,
                      threshold=thresh)
            users2, items2, user_features, item_features = mix.mix_attr(
                users, items, user_features, item_features)
            (u_attr, i_attr, item_ind2logit_ind,
             logit_ind2item_ind) = mix.get_attributes(
                 users2, items2, data_tr, user_features, item_features)

        mylog("saving data format to data directory")
        from preprocess import pickle_save
        pickle_save((data_tr, data_va, u_attr, i_attr,
                     item_ind2logit_ind, logit_ind2item_ind, user_index,
                     item_index), data_filename)

    mylog('length of item_ind2logit_ind: {}'.format(len(item_ind2logit_ind)))

    # if FLAGS.dataset in ['ml', 'yelp']:
    #   mylog('disabling the lstm-rec fake feature')
    #   u_attr.num_features_cat = 1

    return (data_tr, data_va, u_attr, i_attr, item_ind2logit_ind,
            logit_ind2item_ind, user_index, item_index)
14 changes: 14 additions & 0 deletions config/environment/local.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"recommendations": {
"host": "localhost",
"user": "root",
"password": "",
"database": "recommendations"
},
"recommendation_service": {
"host": "localhost",
"user": "root",
"password": "",
"database": "recommendation_service"
}
}
92 changes: 92 additions & 0 deletions config/user_item_recommender_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
{
"model": "lstm",
"dataset_name": "ml1m",
"raw_data_dir": "examples/dataset",
"cache_dir": "examples/cache/lstm_ml1m",
"train_dir": "examples/train/lstm_ml1m",
"test": false,
"combine_att": "mix",
"use_user_feature": true,
"use_item_feature": true,
"user_vocab_size": 150000,
"item_vocab_size": 3100,
"vocab_min_thresh": 1,

"loss": "ce",
"loss_func": "log",
"loss_exp_p": 1.0005,
"learning_rate": 1,
"keep_prob": 0.5,
"learning_rate_decay_factor": 1.0,
"batch_size": 64,
"size": 64,
"max_gradient_norm": 5.0,
"patience": 20,
"power": 0.5,
"num_layers": 1,
"n_epoch": 10,
"steps_per_checkpoint": 5,
"L": 30,
"n_bucket": 10,

"recommend": true,
"saverec": false,
"top_N_items": 100,
"topk": 100,
"recommend_new": false,

"ensemble": false,
"ensemble_suffix": "",
"seed": 0,

"nonlinear": "linear",
"hidden_size": 500,

"n_resample": 50,
"n_sampled": 1024,

"sample_type": "random",
"user_sample": 1.0,

"output_feat": 1,
"use_sep_item": false,
"no_input_item_feature": false,
"use_concat": false,
"no_user_id": true,

"N": "000",
"withAdagrad": true,
"fromScratch": true,
"saveCheckpoint": false,

"gpu": -1,
"profile": false,
"device_log": false,
"eval": true,
"use_more_train": false,
"model_option": "loss",

"ta": 1,
"after40": false,
"split": "last",

"beam_search": false,
"beam_size": 10,

"max_train_data_size": 0,
"old_att": false,

"entity1": "users",
"entity2": "items",
"entity1_ID": "userId",
"entity2_ID": "itemId",
"recs_past_N_days": 3,
"source_data_from_S3": true,
"data_source_config": { "mys3bucket": "recommendation-engine",
"src_directory": "data_user_item/",
"src_file_prefix": "data_raw_"
},
"rec_output_config": { "output_table": "recommended_items",
"archive_table": "recommended_items_archive"
}
}
Loading