
Sentiment analysis tutorial #22


Open · wants to merge 6 commits into `master`
271 changes: 271 additions & 0 deletions extras/sentiment_analysis/README.md

Large diffs are not rendered by default.

129 changes: 129 additions & 0 deletions extras/sentiment_analysis/imdb.py
@@ -0,0 +1,129 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""A helper class for fetching and importing the IMDB dataset.

This helper downloads the data available at
https://github.com/adeshpande3/LSTM-Sentiment-Analysis, which is a
preprocessed version of the Large Movie Review Dataset available at
http://ai.stanford.edu/~amaas/data/sentiment/. It also provides
functions to access the data once it has been downloaded.
"""

import os
import tarfile

import numpy as np
from six.moves import urllib


class IMDB(object):
"""A helper class for fetching and importing the IMDB dataset.

  The `get` methods each import a component of the data from the
  downloaded files.
"""

def __init__(self, data_path, percentage_train=0.9):
"""Create an IMDB data loader.
Args:
data_path: Where to store the downloaded files.
percentage_train: The fraction of the dataset set to use for training.
"""
# path where the data will be stored
self.data_path = data_path
    # positive reviews will have label 1, and negative reviews label 0
self._POS = 1
self._NEG = 0
# path to where data is hosted
self._DATA_URL = 'https://github.com/adeshpande3/LSTM-Sentiment-Analysis/blob/master/training_data.tar.gz?raw=true'
    # percentage of data used for training
self._PERCENTAGE_TRAIN = percentage_train
# if data is not in data_path download it from _DATA_URL
self._maybe_download()

def _get_word_list(self):
"""Returns list with words available in the word embedding."""
return list(np.load(os.path.join(self.data_path, 'wordsList.npy')))

def get_word_to_index(self):
"""Returns dict mapping a word to an index in the word embedding."""
word_list = self._get_word_list()
    word_dict = {word: i for i, word in enumerate(word_list)}
return word_dict

def get_index_to_word(self):
"""Returns dict mapping an index to a word in the word embedding."""
word_list = self._get_word_list()
    word_dict = {i: word for i, word in enumerate(word_list)}
return word_dict

def get_word_vector(self):
"""Returns the pretrained word embedding."""
return np.load(os.path.join(self.data_path, 'wordVectors.npy'))

def get_data(self):
"""Returns the preprocessed IMDB dataset for training and evaluation.

The data contain 25000 reviews where the first half is positive and the
second half is negative. This function by default will return 90% of the
data as training data and 10% as evaluation data.
"""

data = np.load(os.path.join(self.data_path, 'idsMatrix.npy'))
    # the first half of the examples are positive reviews,
    # the second half are negative reviews
    data_len = data.shape[0]
    label = np.array(
        [self._POS if i < data_len // 2 else self._NEG
         for i in range(data_len)])

# shuffle the data
p = np.random.permutation(data_len)
shuffled_data = data[p]
shuffled_label = label[p]

# separate training and evaluation
train_limit = int(data_len * self._PERCENTAGE_TRAIN)

train_data = shuffled_data[:train_limit]
train_label = shuffled_label[:train_limit]
eval_data = shuffled_data[train_limit:]
eval_label = shuffled_label[train_limit:]

return train_data, train_label, eval_data, eval_label

def _maybe_download(self):
"""Maybe downloads data available at https://github.com/adeshpande3/LSTM-Sentiment-Analysis."""
try:
self.get_word_to_index()
self.get_word_vector()
self.get_data()
except IOError:
      print('Data is not available at %s, downloading it...' % self.data_path)
# if the data_path does not exist we'll create it
if not os.path.exists(self.data_path):
os.makedirs(self.data_path)

# download data
tar_path = os.path.join(self.data_path, 'data.tar.gz')
urllib.request.urlretrieve(self._DATA_URL, tar_path)
      # extract the archive into self.data_path
      with tarfile.open(tar_path) as tar:
        tar.extractall(self.data_path)

print('Download complete!')
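
For reference, here is a minimal usage sketch of the `IMDB` helper above. It is not part of the diff; the data path and split fraction are placeholder assumptions.

```python
# Minimal sketch, assuming imdb.py is importable and '/tmp/imdb_data'
# is a writable scratch directory (both illustrative assumptions).
from imdb import IMDB

loader = IMDB('/tmp/imdb_data', percentage_train=0.9)
word_to_id = loader.get_word_to_index()  # word -> embedding row index
embeddings = loader.get_word_vector()    # pretrained word vectors
train_x, train_y, eval_x, eval_y = loader.get_data()
print(train_x.shape, eval_x.shape)       # 90% / 10% split of 25,000 reviews
```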

160 changes: 160 additions & 0 deletions extras/sentiment_analysis/input_function_lib.py
@@ -0,0 +1,160 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input functions implementations used by sentiment_analysis.py.

You'll find two input function implementations:

* build_input_fn: expects preprocessed numpy data as input
(more details in the tutorial) and will be used to train and evaluate the
model.

* build_classify_input_fn: expects a string as input and will be used
to classify new reviews in real time.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import rnn_common


def build_input_fn(x_in, y_in, batch_size,
shuffle=True, epochs=1,
batch_by_seq_len=False,
max_length=250):
"""Returns an input function created from word and class index arrays.



Args:
x_in: A numpy array of word indexes with shape (num_examples,
max_sequence_length). The array is padded on the right with zeros.
y_in: A numpy array of class indexes with shape (num_examples)
batch_size: Batch size for the input_fn to return
shuffle: A bool, indicating whether to shuffle the data or not.
epochs: Number of epochs for the input fun to generate.
batch_by_seq_len: A bool to activate sequence length batching.
max_length: Truncate sequences longer than max_length.

Returns:
An `input_fn`.
"""
def _length_bin(length, max_seq_len, length_step=10):
"""Sets the sequence length bin."""
bin_id = (length // length_step + 1) * length_step
return tf.cast(tf.minimum(bin_id, max_seq_len), tf.int64)

def _make_batch(key, ds):
"""Removes extra padding and batchs the bin."""
# eliminate the extra padding
key = tf.cast(key, tf.int32)
ds = ds.map(lambda x, x_len, y: (x[:key], x_len, y))

# convert the entire contents of the bin to a batch
ds = ds.batch(batch_size)
return ds

def input_fn():
"""Input function used for train and eval; usually not called directly.
"""
# calculates the length of the sequences
# since the inputs are already padded with zeros in the end
# the length will be the last index that is non zero + 1
x_len = np.array(
[np.nonzero(seq)[0][-1] + 1 for seq in x_in]).astype('int32')

# creates the dataset from in memory data
# x_in: sequence of indexes that map a word to an embedding
# x_len: sequence lengths
# y_in: 1 if positive review, 0 if negative review
ds = tf.contrib.data.Dataset.from_tensor_slices((x_in, x_len, y_in))

# repeats the dataset `epochs` times
ds = ds.repeat(epochs)

if shuffle:
# make sure the buffer is big enough for your data
ds = ds.shuffle(buffer_size=25000 * 2)

if batch_by_seq_len:
# implement a simple `Dataset` version of `bucket_by_sequence_length`
# https://goo.gl/y67FQm
ds = ds.group_by_window(
key_func=lambda x, x_len, y: _length_bin(x_len, max_length),
reduce_func=_make_batch,
window_size=batch_size)
else:
ds = ds.batch(batch_size)

# creates iterator
x, x_len, y = ds.make_one_shot_iterator().get_next()

# feature must be a dictionary
dict_x = {'x': x, rnn_common.RNNKeys.SEQUENCE_LENGTH_KEY: x_len}
return dict_x, y

return input_fn
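
As a reading aid for the function above, a sketch of how `build_input_fn` might be wired up with the `IMDB` loader from `imdb.py`; the path, batch size, and epoch count here are illustrative assumptions, not values from the tutorial.

```python
# Sketch only: the hyperparameters below are assumptions for illustration.
from imdb import IMDB

train_x, train_y, eval_x, eval_y = IMDB('/tmp/imdb_data').get_data()

train_input_fn = build_input_fn(
    train_x, train_y, batch_size=32, shuffle=True, epochs=10,
    batch_by_seq_len=True)  # bucket similar lengths to reduce padding waste
eval_input_fn = build_input_fn(
    eval_x, eval_y, batch_size=32, shuffle=False, epochs=1)
```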


def build_classify_input_fn(review, word_to_id):
"""Returns an Input function from a string review, and a word_to_id mapping.
The input_fn only yields a single batch before throwing an end of
sequence error.
The input_fn does not yield labels, so it cannot be used for training or
evaluation.

Args:
review(str): A string review sentence.
word_to_id(dict): A dict mapping words to embedding indexes.
"""
def _word_to_index(sequence):
"""Convert a sequence of words into a sequence of indexes that map each
word to a row in the embedding.
"""
id_sequence = []
UNK = 399999 # index for unknown words
    for word in sequence:
      # words missing from the vocabulary are mapped to the UNK index
      id_sequence.append(word_to_id.get(word, UNK))
return np.array(id_sequence)

def input_fn():
"""Input function used to classify new reviews manually inserted."""
# make review a sequence of words
review_split = review.split(' ')
# converting words to indexes
review_id = _word_to_index(review_split)
# calculates the length of the sequence
x_len = len(review_split)
# creates the dataset from in memory data
ds = tf.contrib.data.Dataset.from_tensors(review_id)
# the model expects a batch
ds = ds.batch(1)

# creates iterator
x = ds.make_one_shot_iterator().get_next()

dict_x = {'x': x, rnn_common.RNNKeys.SEQUENCE_LENGTH_KEY: [x_len]}
# no label needed since we're only using this input function for prediction
# if training make sure to return a label
return dict_x, None

return input_fn
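
And a corresponding sketch of classifying a single new review; `estimator` stands in for the model constructed in `sentiment_analysis.py`, which is not shown in this diff.

```python
# Sketch only: `estimator` is assumed to be a trained Estimator built
# elsewhere in the tutorial; it is not defined in this file.
from imdb import IMDB

word_to_id = IMDB('/tmp/imdb_data').get_word_to_index()
review = 'this movie was surprisingly good'
predictions = estimator.predict(
    input_fn=build_classify_input_fn(review, word_to_id))
print(next(predictions))  # the input_fn yields exactly one batch
```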
