diff --git a/dataset.lua b/dataset.lua index 514b258..17741c7 100644 --- a/dataset.lua +++ b/dataset.lua @@ -38,32 +38,18 @@ function DataSet:__init(loader, options) end function DataSet:load(loader) - local filename = "data/vocab.t7" - - if path.exists(filename) then - print("Loading vocabulary from " .. filename .. " ...") - local data = torch.load(filename) - self.word2id = data.word2id - self.id2word = data.id2word - self.wordsCount = data.wordsCount - self.goToken = data.goToken - self.eosToken = data.eosToken - self.unknownToken = data.unknownToken - self.examplesCount = data.examplesCount - else - print("" .. filename .. " not found") - self:visit(loader:load()) - print("Writing " .. filename .. " ...") - torch.save(filename, { - word2id = self.word2id, - id2word = self.id2word, - wordsCount = self.wordsCount, - goToken = self.goToken, - eosToken = self.eosToken, - unknownToken = self.unknownToken, - examplesCount = self.examplesCount - }) - end + local filenam = 'data/vocab.t7' + self:visit(loader:load()) + print("Writing " .. filename .. " ...") + torch.save(filename, { + word2id = self.word2id, + id2word = self.id2word, + wordsCount = self.wordsCount, + goToken = self.goToken, + eosToken = self.eosToken, + unknownToken = self.unknownToken, + examplesCount = self.examplesCount + }) end function DataSet:visit(conversations) diff --git a/seq2seq.lua b/seq2seq.lua index 12d7734..2c76bff 100644 --- a/seq2seq.lua +++ b/seq2seq.lua @@ -43,6 +43,17 @@ function Seq2Seq:float() end end +-- function Seq2Seq:double() +-- created by zhaopku to fix the problem of CPU mode +function Seq2Seq:double() + self.encoder:double() + self.decoder:double() + + if self.criterion then + self.criterion:double() + end +end + function Seq2Seq:cl() self.encoder:cl() self.decoder:cl() diff --git a/train.lua b/train.lua index b29f8a2..1e9d11a 100644 --- a/train.lua +++ b/train.lua @@ -158,7 +158,7 @@ for epoch = 1, options.maxEpoch do params, gradParams = nil,nil collectgarbage() -- Model is saved as CPU - model:float() + model:double() torch.save("data/model.t7", model) collectgarbage() if options.cuda then