114 changes: 100 additions & 14 deletions dataset.lua
@@ -8,18 +8,22 @@ Then flips it around and gets the dialog from the other character's perspective:
{ {word_ids of character2}, {word_ids of character1} }

Also builds the vocabulary.
]]--
]]--
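-- Illustrative shape of one example pair (hypothetical word ids):
--   { {3, 17, 9}, {42, 5} }   -- character1's line, then character2's reply
--   { {42, 5}, {3, 17, 9} }   -- the same exchange, flipped perspective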

local DataSet = torch.class("neuralconvo.DataSet")
local xlua = require "xlua"
local tokenizer = require "tokenizer"
local list = require "pl.List"
local utils = require "pl.utils"
local function_arg = utils.function_arg

function DataSet:__init(loader, options)
options = options or {}

self.examplesFilename = "data/examples.t7"

self.createNewVocabAndExamples = options.createNewVocabAndExamples

-- Reject words once vocab size reaches this threshold
self.maxVocabSize = options.maxVocabSize or 0

@@ -37,10 +41,51 @@ function DataSet:__init(loader, options)
self:load(loader)
end

function DataSet:buildVocab(conversations)

print("-- Building vocab")

-- Add magic tokens
self.goToken = self:makeWordId("<go>") -- Start of sequence
self.eosToken = self:makeWordId("<eos>") -- End of sequence
self.unknownToken = self:makeWordId("<unknown>") -- Word dropped from vocabulary

self.wordFreqs = {}

-- number of conversations to be traversed
local total = self.loadFirst or #conversations

-- traverse all the conversations to count the frequency of words
for i, conversation in ipairs(conversations) do
if i > total then break end
for j = 1, #conversation do
local conversationLine = conversation[j]
-- accumulate the word frequency
self:countWords(conversationLine.text)
end
if i % 1000 == 0 then
xlua.progress(i,total)
end
end

-- sort the words by descending frequency
local sortedCounts = f_sortv(self.wordFreqs, function(x, y) return x > y end)

for word, freq in sortedCounts do
local nWordId = self:addWordToVocab(word)
if self.maxVocabSize > 0 and nWordId >= self.maxVocabSize then
break
end
end

print("-- Vocab built")

end

function DataSet:load(loader)
local filename = "data/vocab.t7"

if path.exists(filename) then
if not self.createNewVocabAndExamples and path.exists(filename) then
print("Loading vocabulary from " .. filename .. " ...")
local data = torch.load(filename)
self.word2id = data.word2id
@@ -52,7 +97,9 @@ function DataSet:load(loader)
self.examplesCount = data.examplesCount
else
print("" .. filename .. " not found")
self:visit(loader:load())
local conversations = loader:load()
self:buildVocab(conversations)
self:visit(conversations)
print("Writing " .. filename .. " ...")
torch.save(filename, {
word2id = self.word2id,
@@ -69,11 +116,6 @@ end
function DataSet:visit(conversations)
self.examples = {}

-- Add magic tokens
self.goToken = self:makeWordId("<go>") -- Start of sequence
self.eosToken = self:makeWordId("<eos>") -- End of sequence
self.unknownToken = self:makeWordId("<unknown>") -- Word dropped from vocabulary

print("-- Pre-processing data")

local total = self.loadFirst or #conversations * 2
@@ -90,7 +132,7 @@ function DataSet:visit(conversations)
self:visitConversation(conversation, 2)
xlua.progress(#conversations + i, total)
end

print("-- Shuffling ")
newIdxs = torch.randperm(#self.examples)
local sExamples = {}
@@ -148,7 +190,7 @@ function DataSet:batches(size)
table.insert(inputSeqs, inputSeq)
table.insert(targetSeqs, targetSeq)
end

local encoderInputs,decoderInputs,decoderTargets = nil,nil,nil
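-- 1-D tensors when the batch has a single sample, otherwise (sequence length x batch size) matrices; zero entries act as padding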
if size == 1 then
encoderInputs = torch.IntTensor(maxInputSeqLen):fill(0)
@@ -159,7 +201,7 @@
decoderInputs = torch.IntTensor(maxTargetOutputSeqLen-1,size):fill(0)
decoderTargets = torch.IntTensor(maxTargetOutputSeqLen-1,size):fill(0)
end

for samplenb = 1, #inputSeqs do
for word = 1,inputSeqs[samplenb]:size(1) do
eosOffset = maxInputSeqLen - inputSeqs[samplenb]:size(1) -- for left padding
@@ -170,7 +212,7 @@ end
end
end
end

for samplenb = 1, #targetSeqs do
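-- decoder inputs: the target sequence with its trailing <eos> removed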
trimmedEosToken = targetSeqs[samplenb]:sub(1,-2)
for word = 1, trimmedEosToken:size(1) do
@@ -181,7 +223,7 @@ end
end
end
end

for samplenb = 1, #targetSeqs do
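-- decoder targets: the target sequence with its leading <go> removed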
trimmedGoToken = targetSeqs[samplenb]:sub(2,-1)
for word = 1, trimmedGoToken:size(1) do
@@ -230,7 +272,11 @@ function DataSet:visitText(text, additionalTokens)
end

for t, word in tokenizer.tokenize(text) do
table.insert(words, self:makeWordId(word))
local cWord = self.word2id[word:lower()]
if not cWord then
cWord = self.unknownToken
end
table.insert(words, cWord)
-- Only keep the first sentence
if t == "endpunct" or #words >= self.maxExampleLen - additionalTokens then
break
@@ -244,6 +290,19 @@
return words
end

function DataSet:countWords(sentence)
--if sentence == "" then
-- return
--end
for t, word in tokenizer.tokenize(sentence) do
local lword = word:lower()
if self.wordFreqs[lword] == nil then
self.wordFreqs[lword] = 0
end
self.wordFreqs[lword] = self.wordFreqs[lword] + 1
end
end

function DataSet:makeWordId(word)
if self.maxVocabSize > 0 and self.wordsCount >= self.maxVocabSize then
-- We've reached the maximum size for the vocab. Replace w/ unknown token
@@ -263,3 +322,30 @@ function DataSet:makeWordId(word)

return id
end

function DataSet:addWordToVocab(word)
word = word:lower()
self.wordsCount = self.wordsCount + 1
self.word2id[word] = self.wordsCount
self.id2word[self.wordsCount] = word
return self.wordsCount
end

-- The penlight release on LuaRocks is outdated; below is a fixed version of its sortv.
--- return an iterator to a table sorted by its values
-- @within Iterating
-- @tab t the table
-- @func f an optional comparison function (f(x,y) is true if x < y)
-- @usage for k,v in f_sortv(t) do print(k,v) end
-- @return an iterator to traverse elements sorted by the values
function f_sortv(t,f)
f = function_arg(2, f or '<')
local keys = {}
for k in pairs(t) do keys[#keys + 1] = k end
table.sort(keys,function(x, y) return f(t[x], t[y]) end)
local i = 0
return function()
i = i + 1
return keys[i], t[keys[i]]
end
end
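-- Illustrative use of f_sortv on a small frequency table (hypothetical counts),
-- mirroring how buildVocab iterates word frequencies in descending order:
--   local freqs = { the = 10, movie = 4, popcorn = 1 }
--   for word, freq in f_sortv(freqs, function(x, y) return x > y end) do
--     print(word, freq)   -- "the 10", "movie 4", "popcorn 1"
--   end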
32 changes: 17 additions & 15 deletions train.lua
@@ -16,6 +16,7 @@ cmd:option('--minLR', 0.00001, 'minimum learning rate')
cmd:option('--saturateEpoch', 20, 'epoch at which linear decayed LR will reach minLR')
cmd:option('--maxEpoch', 50, 'maximum number of epochs to run')
cmd:option('--batchSize', 10, 'mini-batch size')
cmd:option('--createNewVocabAndExamples', true, 'create new vocabulary and examples files; keep it false if the dataset and maxVocabSize are unchanged')

cmd:text()
options = cmd:parse(arg)
@@ -29,7 +30,8 @@ print("-- Loading dataset")
dataset = neuralconvo.DataSet(neuralconvo.CornellMovieDialogs("data/cornell_movie_dialogs"),
{
loadFirst = options.dataset,
maxVocabSize = options.maxVocabSize
maxVocabSize = options.maxVocabSize,
createNewVocabAndExamples = options.createNewVocabAndExamples
})

print("\nDataset stats:")
@@ -67,18 +69,18 @@ for epoch = 1, options.maxEpoch do
collectgarbage()

local nextBatch = dataset:batches(options.batchSize)
local params, gradParams = model:getParameters()
local params, gradParams = model:getParameters()
local optimState = {learningRate=options.learningRate,momentum=options.momentum}

-- Define optimizer
local function feval(x)
if x ~= params then
params:copy(x)
end

gradParams:zero()
local encoderInputs, decoderInputs, decoderTargets = nextBatch()

if options.cuda then
encoderInputs = encoderInputs:cuda()
decoderInputs = decoderInputs:cuda()
@@ -94,28 +96,28 @@ for epoch = 1, options.maxEpoch do
model:forwardConnect(encoderInputs:size(1))
local decoderOutput = model.decoder:forward(decoderInputs)
local loss = model.criterion:forward(decoderOutput, decoderTargets)

local avgSeqLen = nil
if #decoderInputs:size() == 1 then
avgSeqLen = decoderInputs:size(1)
else
avgSeqLen = torch.sum(torch.sign(decoderInputs)) / decoderInputs:size(2)
end
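-- Normalize the loss by the average non-padding sequence length so it is comparable across batches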
loss = loss / avgSeqLen

-- Backward pass
local dloss_doutput = model.criterion:backward(decoderOutput, decoderTargets)
model.decoder:backward(decoderInputs, dloss_doutput)
model:backwardConnect()
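-- backwardConnect carries the gradient into the encoder's recurrent state, so a zeroed output gradient is passed below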
model.encoder:backward(encoderInputs, encoderOutput:zero())

gradParams:clamp(-options.gradientClipping, options.gradientClipping)

return loss,gradParams
end

-- run epoch

print("\n-- Epoch " .. epoch .. " / " .. options.maxEpoch ..
" (LR= " .. optimState.learningRate .. ")")
print("")
@@ -125,10 +127,10 @@ for epoch = 1, options.maxEpoch do

for i=1, dataset.examplesCount/options.batchSize do
collectgarbage()

local _,tloss = optim.adam(feval, params, optimState)
err = tloss[1] -- optim returns a list

model.decoder:forget()
model.encoder:forget()

@@ -138,7 +140,7 @@ for epoch = 1, options.maxEpoch do

xlua.progress(dataset.examplesCount, dataset.examplesCount)
timer:stop()

errors = torch.Tensor(errors)
print("\n\nFinished in " .. xlua.formatTime(timer:time().real) ..
" " .. (dataset.examplesCount / timer:time().real) .. ' examples/sec.')
@@ -168,6 +170,6 @@ for epoch = 1, options.maxEpoch do
minMeanError = errors:mean()
end

optimState.learningRate = optimState.learningRate + decayFactor
optimState.learningRate = math.max(options.minLR, optimState.learningRate)
-- optimState.learningRate = optimState.learningRate + decayFactor
-- optimState.learningRate = math.max(options.minLR, optimState.learningRate)
end