diff --git a/models/multimodal_preprocessors.py b/models/multimodal_preprocessors.py
index 031dc1e3..ebe7ddba 100644
--- a/models/multimodal_preprocessors.py
+++ b/models/multimodal_preprocessors.py
@@ -25,21 +25,14 @@
 
 def get_sinusoid_encoding_table(n_position, d_hid):
     """Sinusoid position encoding table"""
+
+    sinusoid_table = torch.FloatTensor(
+        [pos_i for pos_i in range(n_position)]).unsqueeze(1) * torch.FloatTensor(
+        [1 / 10000 ** (2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]).unsqueeze(0)
+    sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])  # dim 2i
+    sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])  # dim 2i+1
 
-    # TODO: make it with torch instead of numpy
-    def get_position_angle_vec(position):
-        return [
-            position / np.power(10000, 2 * (hid_j // 2) / d_hid)
-            for hid_j in range(d_hid)
-        ]
-
-    sinusoid_table = np.array(
-        [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
-    )
-    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
-    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
-
-    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
+    return sinusoid_table.unsqueeze(0)
 
 
 def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
@@ -197,7 +190,7 @@ def __init__(
         self,
         rgbt_stem: PatchEmbedGeneric,
         depth_stem: Optional[PatchEmbedGeneric],
-        img_size: Tuple = (3, 224, 224),
+        img_size: List = [3, 224, 224],
         num_cls_tokens: int = 1,
         pos_embed_fn: Optional[Callable] = None,
         use_type_embed: bool = False,
@@ -609,7 +602,7 @@ def __init__(
         kernel_size: int,
         imu_stem: PatchEmbedGeneric,
         embed_dim: int,
-        img_size: Tuple = (6, 2000),
+        img_size: List = [6, 2000],
         num_cls_tokens: int = 1,
         pos_embed_fn: Optional[Callable] = None,
         init_param_style: str = "openclip",
diff --git a/requirements.txt b/requirements.txt
index d35cb65a..2b0d84e6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,3 +15,4 @@ matplotlib
 types-regex
 mayavi
 cartopy
+pillow
\ No newline at end of file