This demo shows how to perform sentiment analysis on text using the Layers API of TensorFlow.js.
It demonstrates loading a pretrained model hosted at a URL using tf.loadLayersModel().
Two model variants are provided (CNN and LSTM). They were trained on a set of 25,000 movie reviews from IMDB, each labelled as having positive or negative sentiment. This dataset is distributed with Keras (Python), and the models were also trained in Keras, based on the Keras imdb_cnn and imdb_lstm examples.
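Before a loaded model can score a review, the review text has to be turned into a sequence of integer word indices. The sketch below shows one way to do this. It is not the demo's own code: the helper name textToSequence is hypothetical, and the word-index layout (1-based ranks shifted by an offset, with unknown words mapped to index 2) is an assumption modeled on the defaults of the Keras IMDB dataset loader (index_from=3, oov_char=2).

```javascript
// Hypothetical helper: convert raw review text into integer word indices.
// Assumption: wordIndex maps word -> 1-based frequency rank (Keras style),
// ranks are shifted by indexFrom, and unknown words map to OOV_INDEX.
const OOV_INDEX = 2;

function textToSequence(text, wordIndex, indexFrom, vocabularySize) {
  // Lowercase, strip basic punctuation and split on whitespace.
  const words = text
    .trim()
    .toLowerCase()
    .replace(/[.,!?]/g, '')
    .split(/\s+/)
    .filter((w) => w.length > 0);
  return words.map((word) => {
    // Use hasOwnProperty so words like "constructor" are not looked up
    // on Object.prototype by accident.
    if (!Object.prototype.hasOwnProperty.call(wordIndex, word)) {
      return OOV_INDEX; // word not in the vocabulary at all
    }
    const index = wordIndex[word] + indexFrom;
    // Indices beyond the model's vocabulary size are also treated as unknown.
    return index >= vocabularySize ? OOV_INDEX : index;
  });
}

// Usage with a toy vocabulary (made-up values).
const wordIndex = {the: 1, movie: 17, was: 13, great: 84};
console.log(JSON.stringify(textToSequence('The movie was GREAT!', wordIndex, 3, 10000)));
// prints [4,20,16,87]
```

The resulting array still has to be padded or truncated to the length the model expects before it is wrapped in a tensor and passed to the model's predict() method.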
To launch the demo, do
yarn
yarn watch
To train the model using tfjs-node, do
yarn
yarn train <MODEL_TYPE>
where MODEL_TYPE is a required argument that specifies what type of model is to be trained. The available options are:
- multihot: A model that takes a multi-hot encoding of the words in the sequence. In terms of data representation and model complexity, this is the simplest model in this example.
- flatten: A model that flattens the embedding vectors of all words in the sequence.
- cnn: A 1D convolutional model, with a dropout layer included.
- simpleRNN: A model that uses a SimpleRNN layer (tf.layers.simpleRNN).
- lstm: A model that uses an LSTM layer (tf.layers.lstm).
- bidirectionalLSTM: A model that uses a bidirectional LSTM layer (tf.layers.bidirectional and tf.layers.lstm).
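A multi-hot vector records only which vocabulary words occur in a review, discarding both word order and repetition counts, which is why this representation is the simplest of the options. The sketch below illustrates the idea with plain JavaScript arrays; it is not the code in train.js, and the function name multiHotEncode is hypothetical.

```javascript
// Hypothetical sketch of multi-hot encoding: each sequence of word indices
// becomes a vector of length numWords with a 1 at every index that occurs
// at least once (order and counts are lost).
function multiHotEncode(sequences, numWords) {
  return sequences.map((seq) => {
    const vec = new Array(numWords).fill(0);
    for (const index of seq) {
      // Ignore indices that fall outside the vocabulary.
      if (index >= 0 && index < numWords) {
        vec[index] = 1;
      }
    }
    return vec;
  });
}

console.log(JSON.stringify(multiHotEncode([[1, 3, 3], [0, 4]], 5)));
// prints [[0,1,0,1,0],[1,0,0,0,1]]
```

Note how the repeated index 3 in the first sequence still produces a single 1; the other model types keep the full ordered sequence instead.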
By default, training runs on the CPU using the Eigen kernels from tfjs-node.
You can run training on a GPU instead by adding the --gpu flag to the command, e.g.,
yarn train --gpu <MODEL_TYPE>
The training process will download the training data and metadata from the web
if they haven't been downloaded before. After training completes, the model
will be saved to the dist/resources folder, alongside a metadata.json file.
Then, when you run yarn watch, you will see a "Load local model" button in the web
page, which allows you to use the locally trained model for inference in the browser.
Other arguments of the yarn train command include:
- --maxLen allows you to specify the sequence length.
- --numWords allows you to specify the vocabulary size.
- --embeddingSize allows you to adjust the dimensionality of the embedding vectors.
- --epochs, --batchSize, and --validationSplit are training-related settings.
- --modelSavePath allows you to specify where to store the model and metadata after training completes.
- --embeddingFilesPrefix: Optional prefix for the paths to which the embedding-vector and label files are saved. See the section below for details.
- --logDir: This optional string lets you log the loss and accuracy values to a TensorBoard log directory during training. For example, if you start training with the command:
  yarn train lstm --logDir /tmp/my_lstm_logs
  you can use the following commands to start a TensorBoard server in a separate terminal:
  pip install tensorboard  # Unless tensorboard is already installed
  tensorboard --logdir /tmp/my_lstm_logs
  Then you can open a browser tab and navigate to the http:// URL indicated by TensorBoard (by default: http://localhost:6006) to view the loss and accuracy curves.
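Because reviews vary in length, every sequence has to be brought to a fixed length (the value of --maxLen) before batching. The sketch below assumes the common approach of padding short sequences with 0 and truncating long ones, using the same defaults as Keras' pad_sequences ('pre' padding and 'pre' truncation). Whether train.js uses exactly these defaults is an assumption, and padSequences is named here for illustration only.

```javascript
// Sketch: pad or truncate every sequence to exactly maxLen entries.
// Assumed defaults mirror Keras pad_sequences: padding='pre', truncating='pre',
// pad value 0.
function padSequences(sequences, maxLen, padding = 'pre', truncating = 'pre', value = 0) {
  return sequences.map((seq) => {
    let out = seq;
    if (out.length > maxLen) {
      // 'pre' truncation keeps the last maxLen entries; 'post' keeps the first.
      out = truncating === 'pre' ? out.slice(out.length - maxLen) : out.slice(0, maxLen);
    }
    const pad = new Array(maxLen - out.length).fill(value);
    // 'pre' padding puts the pad values in front; 'post' appends them.
    return padding === 'pre' ? pad.concat(out) : out.concat(pad);
  });
}

console.log(JSON.stringify(padSequences([[5, 6], [1, 2, 3, 4, 5]], 4)));
// prints [[0,0,5,6],[2,3,4,5]]
```

Truncating from the front keeps the end of a long review, which is a reasonable choice when the closing sentences tend to summarize the reviewer's verdict; switching to 'post' keeps the opening instead.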
The detailed training code is in the file train.js.
If you train a word embedding-based model (e.g., cnn or lstm), you can have the
yarn train script write the embedding vectors, together with the corresponding
word labels, to files after training completes. This is done using the
--embeddingFilesPrefix flag, e.g.,
yarn train --maxLen 500 cnn --epochs 2 --embeddingFilesPrefix /tmp/imdb_embed
The above command will generate two files:
- /tmp/imdb_embed_vectors.tsv: A tab-separated-values file containing the numeric values of the word embeddings. Each line contains the embedding vector of one word.
- /tmp/imdb_embed_labels.tsv: A file consisting of the word labels that correspond to the vectors in the previous file. Each line is a word.
These files can be uploaded directly to the Embedding Projector (https://projector.tensorflow.org/) for visualization using the t-SNE or PCA algorithm.
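The two-file layout is simple enough to reproduce by hand, for example to export embeddings from some other model. The sketch below builds both file contents from an embedding matrix (one row per word) and the matching word list. It is an illustration, not the exporter used by yarn train, and toProjectorTsv is a hypothetical name. Because the labels file has a single column, it carries no header row, which is what the Embedding Projector expects for single-column metadata.

```javascript
// Hypothetical sketch: build the contents of the vectors and labels files
// for the Embedding Projector from an embedding matrix and its word labels.
function toProjectorTsv(vectors, words) {
  if (vectors.length !== words.length) {
    throw new Error('Need exactly one word label per embedding vector.');
  }
  // Vectors file: one embedding per line, components separated by tabs.
  const vectorsTsv = vectors.map((row) => row.join('\t')).join('\n') + '\n';
  // Labels file: one word per line, single column, so no header row.
  const labelsTsv = words.join('\n') + '\n';
  return {vectorsTsv, labelsTsv};
}

const {vectorsTsv, labelsTsv} = toProjectorTsv([[0.1, -0.2], [0.3, 0.4]], ['good', 'bad']);
console.log(JSON.stringify(vectorsTsv)); // prints "0.1\t-0.2\n0.3\t0.4\n"
console.log(JSON.stringify(labelsTsv)); // prints "good\nbad\n"
```

In Node.js, the two strings can then be written to disk with fs.writeFileSync, using the same _vectors.tsv and _labels.tsv suffixes as above.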
This example comes with unit tests. If you would like to submit changes to the code, be sure to run the tests and ensure they pass first:
yarn
yarn test