Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Byte-compiled / optimized / DLL files
mmi_test/
result/
__pycache__/
*.py[cod]
*$py.class
Expand Down Expand Up @@ -139,4 +141,4 @@ cython_debug/
.idea/

# shuhe
shuhe/
shuhe/
1 change: 1 addition & 0 deletions mmi_fairseq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from mmi_fairseq.feature import *
5 changes: 5 additions & 0 deletions mmi_fairseq/feature/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .data import *
from .loss import *
from .model import *
from .tasks import *
from .scrtpts import *
5 changes: 5 additions & 0 deletions mmi_fairseq/feature/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .feature_dataset import FeatureDataset
from .mmi_text_and_feature_dataset import MMITextImageDataset
from .object_dataset import ObjectDataset
from .mmi_text_and_object_dataset import MMITextObjectDataset
from .utils import *
24 changes: 24 additions & 0 deletions mmi_fairseq/feature/data/feature_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# encoding: utf-8

import numpy as np
from torch.utils.data import Dataset
from mmi_fairseq.feature.data.utils import sent_num_file, offsets_file, feature_file, warmup_mmap_file


class FeatureDataset(Dataset):
    """Per-sentence image-feature dataset backed by a read-only memory map."""

    def __init__(self, data_dir, split="train"):
        self.data_dir = data_dir
        # Per-group sentence counts and their cumulative start offsets.
        self.sent_num = np.load(sent_num_file(data_dir, split))
        self.offsets = np.load(offsets_file(data_dir, split))
        self.dim = 1000  # width of each stored feature vector
        # Rows in the mmap = last group's offset plus its sentence count.
        self.total_num = self.offsets[-1] + self.sent_num[-1]
        warmup_mmap_file(feature_file(data_dir, split))
        self.features = np.memmap(
            feature_file(data_dir, split),
            dtype='float32',
            mode='r',
            shape=(self.total_num, self.dim),
        )

    def __getitem__(self, item):
        """Return the raw feature row stored at index ``item``."""
        return self.features[item]

    def __len__(self):
        return self.total_num
144 changes: 144 additions & 0 deletions mmi_fairseq/feature/data/mmi_text_and_feature_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# encoding: utf-8
"""
@author: Yuxian Meng
@contact: [email protected]

@version: 1.0
@file: text_and_image_dataset
@time: 2020/11/14 15:26
@desc: Combine Text and Image Datasets

"""

import numpy as np
import torch
from fairseq.data.fairseq_dataset import FairseqDataset
from mmi_fairseq.feature.data.feature_dataset import FeatureDataset
from fairseq.data import data_utils


class MMITextImageDataset(FairseqDataset):
    """Pair pre-extracted image features with text for MMI training.

    Each row of ``span_idxs`` is ``(is_true, start_idx, end_idx)``: the image
    feature at ``start_idx`` is paired with the sentence at ``end_idx``, and
    ``is_true`` labels whether the pair is a genuine image/text match.
    """

    def __init__(self, image_dataset: FeatureDataset, text_dataset, vocab_dict, span_idxs, shuffle=False):
        self.img_dataset = image_dataset
        self.text_dataset = text_dataset
        self.vocab_dict = vocab_dict
        self.span_idxs = span_idxs
        self.shuffle = shuffle

    def __getitem__(self, index):
        is_true, start_idx, end_idx = self.span_idxs[index].tolist()
        source_imgs = self.img_dataset[start_idx]  # image feature, shape (dim,)
        source_texts = self.text_dataset[end_idx]  # token ids of paired sentence
        target = self.text_dataset[end_idx]  # same sentence; loss does not use it

        return {
            'id': index,
            'is_true': is_true,
            'source_imgs': source_imgs,
            'source_texts': source_texts,
            'target': torch.LongTensor(target)
        }

    def __len__(self):
        return len(self.span_idxs)

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        # NOTE(review): size is taken from text_dataset[start_idx] while
        # __getitem__ returns text_dataset[end_idx] -- confirm this is intended.
        is_true, start_idx, end_idx = self.span_idxs[index].tolist()
        return len(self.text_dataset[start_idx])

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.num_tokens(index)

    def ordered_indices(self):
        """Return example indices in the order they should be batched."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        # TODO: add length-bucketing, cf. LanguagePairDataset.ordered_indices
        return indices

    def collater(self, samples):
        """Merge a list of samples into a mini-batch dict for the model."""
        if len(samples) == 0:
            return {}

        indices = []
        source_imgs = []
        source_texts = []
        source_lengths = []
        source_label = []
        targets = []
        target_ntokens = 0

        for sample in samples:
            indices.append(sample['id'])
            source_imgs.append(sample['source_imgs'])
            source_texts.append(torch.LongTensor(sample['source_texts']))
            source_lengths.append(len(sample['source_texts']))
            source_label.append(sample['is_true'])
            targets.append(sample['target'])
            target_ntokens += len(sample["target"])

        # Fix: was `len(sample)` inside the loop -- the key count of the last
        # sample dict (a constant 5), not the number of sentences in the batch.
        num_sentences = len(samples)

        indices = torch.tensor(indices, dtype=torch.long)
        source_label_tensor = torch.tensor(source_label, dtype=torch.float)
        source_lengths_tensor = torch.tensor(source_lengths, dtype=torch.long)
        # Stack in numpy first: torch.tensor on a list of ndarrays is very slow
        # and emits a UserWarning on recent torch versions.
        image_tensor = torch.tensor(np.stack(source_imgs), dtype=torch.float)

        source_texts_batch = data_utils.collate_tokens(source_texts,
                                                       pad_idx=self.vocab_dict.pad(),
                                                       eos_idx=self.vocab_dict.eos(),
                                                       move_eos_to_beginning=False)

        # All-ones mask over the padded token grid, shape B * T.
        mask_ones = torch.ones((source_texts_batch.shape[0], source_texts_batch.shape[1]), dtype=torch.float)

        target_batch = data_utils.collate_tokens(targets,
                                                 pad_idx=self.vocab_dict.pad(),
                                                 eos_idx=self.vocab_dict.eos(),
                                                 move_eos_to_beginning=False)
        prev_target_batch = data_utils.collate_tokens(targets,
                                                      pad_idx=self.vocab_dict.pad(),
                                                      eos_idx=self.vocab_dict.eos(),
                                                      move_eos_to_beginning=True)

        return {
            'id': indices,
            'net_input': {
                'src_tokens': source_texts_batch,
                'mask_ones': mask_ones,
                'src_label': source_label_tensor,
                'src_imgs': image_tensor,
                'src_lengths': source_lengths_tensor,
                'prev_output_tokens': prev_target_batch,
            },
            'target': target_batch,
            'ntokens': target_ntokens,
            'nsentences': num_sentences,
        }
145 changes: 145 additions & 0 deletions mmi_fairseq/feature/data/mmi_text_and_object_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# encoding: utf-8
"""
@author: Yuxian Meng
@contact: [email protected]

@version: 1.0
@file: text_and_image_dataset
@time: 2020/11/14 15:26
@desc: Combine Text and Object Datasets

"""

import numpy as np
import torch
from fairseq.data.fairseq_dataset import FairseqDataset
from mmi_fairseq.feature.data.object_dataset import ObjectDataset
from fairseq.data import data_utils


class MMITextObjectDataset(FairseqDataset):
    """Pair detected-object features with text for MMI training.

    Each row of ``span_idxs`` is ``(is_true, start_idx, end_idx)``: the objects
    at ``start_idx`` are paired with the sentence at ``end_idx``, and
    ``is_true`` labels whether the pair is a genuine image/text match.
    """

    def __init__(self, image_dataset: ObjectDataset, text_dataset, vocab_dict, span_idxs, shuffle=False):
        self.img_dataset = image_dataset
        self.text_dataset = text_dataset
        self.vocab_dict = vocab_dict
        self.span_idxs = span_idxs
        self.shuffle = shuffle
        self.max_obj = image_dataset.max_obj

    def __getitem__(self, index):
        # todo: try to add [bos] at the beginning of text sequence to separate objects/texts
        is_true, start_idx, end_idx = self.span_idxs[index].tolist()
        objects, objects_mask = self.img_dataset[start_idx]  # max_obj * dim, max_obj
        source_texts = self.text_dataset[end_idx]  # token ids of paired sentence
        target = self.text_dataset[end_idx]  # same sentence; loss does not use it

        return {
            'id': index,
            'is_true': is_true,
            'objects': objects,
            'objects_mask': objects_mask,
            'source_texts': source_texts,
            'target': torch.LongTensor(target)
        }

    def __len__(self):
        return len(self.span_idxs)

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        # NOTE(review): size is taken from text_dataset[start_idx] while
        # __getitem__ returns text_dataset[end_idx] -- confirm this is intended.
        is_true, start_idx, end_idx = self.span_idxs[index].tolist()
        return len(self.text_dataset[start_idx])

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.num_tokens(index)

    def ordered_indices(self):
        """Return example indices in the order they should be batched."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        # TODO: add length-bucketing, cf. LanguagePairDataset.ordered_indices
        return indices

    def collater(self, samples):
        """Merge a list of samples into a mini-batch dict for the model."""
        if len(samples) == 0:
            return {}

        indices = []
        source_objects = []
        objects_mask = []
        source_texts = []
        source_lengths = []
        source_label = []
        targets = []
        target_ntokens = 0

        for sample in samples:
            indices.append(sample['id'])
            source_objects.append(sample["objects"])
            objects_mask.append(sample["objects_mask"])
            source_texts.append(torch.LongTensor(sample['source_texts']))
            source_lengths.append(len(sample['source_texts']))
            source_label.append(sample['is_true'])
            targets.append(sample['target'])
            target_ntokens += len(sample["target"])

        # Hoisted out of the loop: the value only depends on the batch size.
        num_sentences = len(samples)

        indices = torch.tensor(indices, dtype=torch.long)
        source_label_tensor = torch.tensor(source_label, dtype=torch.float)
        source_lengths_tensor = torch.tensor(source_lengths, dtype=torch.long)
        # Stack in numpy first: torch.tensor on a list of ndarrays is very slow
        # and emits a UserWarning on recent torch versions.
        image_tensor = torch.tensor(np.stack(source_objects), dtype=torch.float)
        mask_tensor = torch.tensor(np.stack(objects_mask), dtype=torch.float)

        source_texts_batch = data_utils.collate_tokens(source_texts,
                                                       pad_idx=self.vocab_dict.pad(),
                                                       eos_idx=self.vocab_dict.eos(),
                                                       move_eos_to_beginning=False)

        # All-ones mask over the padded token grid, shape B * T.
        mask_ones = torch.ones((source_texts_batch.shape[0], source_texts_batch.shape[1]), dtype=torch.float)

        target_batch = data_utils.collate_tokens(targets,
                                                 pad_idx=self.vocab_dict.pad(),
                                                 eos_idx=self.vocab_dict.eos(),
                                                 move_eos_to_beginning=False)
        prev_target_batch = data_utils.collate_tokens(targets,
                                                      pad_idx=self.vocab_dict.pad(),
                                                      eos_idx=self.vocab_dict.eos(),
                                                      move_eos_to_beginning=True)

        return {
            'id': indices,
            'net_input': {
                'src_tokens': source_texts_batch,
                'mask_ones': mask_ones,
                'src_label': source_label_tensor,
                'objs': image_tensor,
                'objs_mask': mask_tensor,
                'src_lengths': source_lengths_tensor,
                'prev_output_tokens': prev_target_batch,
            },
            'target': target_batch,
            'ntokens': target_ntokens,
            'nsentences': num_sentences,
        }
30 changes: 30 additions & 0 deletions mmi_fairseq/feature/data/object_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# encoding: utf-8

import numpy as np
from torch.utils.data import Dataset

from mmi_fairseq.feature.data.utils import sent_num_file, offsets_file, object_file, object_mask_file, warmup_mmap_file

class ObjectDataset(Dataset):
    """Load per-sentence object-detection features from a memory-mapped file.

    The mmap stores ``MAX_OBJ`` objects per sentence; ``__getitem__`` returns
    only the first ``max_obj`` of them together with their validity mask.
    """

    MAX_OBJ = 20  # max-obj in mmap file

    def __init__(self, data_dir, split="train", max_obj=20):
        self.data_dir = data_dir
        # Per-group sentence counts and their cumulative start offsets.
        self.sent_num = np.load(sent_num_file(data_dir, split))
        self.offsets = np.load(offsets_file(data_dir, split))
        self.total_sent_num = self.offsets[-1] + self.sent_num[-1]
        self.dim = 2048  # todo add x,y,w,h
        self.max_obj = max_obj  # max-obj when getting item
        warmup_mmap_file(object_file(data_dir, split, 0))
        print(self.total_sent_num, self.MAX_OBJ, self.dim)
        self.objects = np.memmap(object_file(data_dir, split, 0), dtype=np.float32, mode='r',
                                 shape=(self.total_sent_num, self.MAX_OBJ, self.dim))
        warmup_mmap_file(object_mask_file(data_dir, split, 0))
        # Fix: np.bool was deprecated in NumPy 1.20 and removed in 1.24
        # (AttributeError on current NumPy); the builtin bool is equivalent.
        self.objects_mask = np.memmap(object_mask_file(data_dir, split, 0), dtype=bool, mode='r',
                                      shape=(self.total_sent_num, self.MAX_OBJ))

    def __getitem__(self, item):
        """Return ``(objects, mask)``: (max_obj, dim) features, (max_obj,) mask."""
        return self.objects[item][: self.max_obj], self.objects_mask[item][: self.max_obj]

    def __len__(self):
        return self.total_sent_num
Loading