Commit deae6c2

Fangyu Luo authored and facebook-github-bot committed
Authoring Aware ROO on LTV ReAgent model (#742)
Summary:
X-link: meta-pytorch/torchrec#742
Pull Request resolved: #689

[Free] Authoring Aware ROO on LTV ReAgent model

Differential Revision: D40456561

fbshipit-source-id: 797f51078b61dda7c0b7bde5fbec475a8e28191b
1 parent 30715aa commit deae6c2

File tree

reagent/core/types.py
reagent/models/sparse_dqn.py
reagent/test/models/test_sparse_dqn_net.py

3 files changed: 104 additions & 27 deletions

reagent/core/types.py

Lines changed: 2 additions & 0 deletions
@@ -313,6 +313,7 @@ class FeatureData(TensorDataClass):
     float_features: torch.Tensor
     # For sparse features saved in KeyedJaggedTensor format
     id_list_features: Optional[KeyedJaggedTensor] = None
+    id_list_features_ro: Optional[KeyedJaggedTensor] = None
     id_score_list_features: Optional[KeyedJaggedTensor] = None

     # For sparse features saved in dictionary format
@@ -339,6 +340,7 @@ def __post_init__(self):
     def has_float_features_only(self) -> bool:
         return (
             not self.id_list_features
+            and not self.id_list_features_ro
             and not self.id_score_list_features
             and self.time_since_first is None
             and self.candidate_docs is None
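Below is a minimal usage sketch (not part of this commit) of how a FeatureData might carry both the existing id_list_features and the new id_list_features_ro field; the from_lengths_sync construction, feature keys, and tensor values are illustrative assumptions.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from reagent.core import types as rlt

# Hypothetical sparse inputs: one regular bag and one read-only ("_ro") bag.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["page_id"],
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([1, 2]),
)
kjt_ro = KeyedJaggedTensor.from_lengths_sync(
    keys=["watched_page_ids"],
    values=torch.tensor([4, 5]),
    lengths=torch.tensor([1, 1]),
)

state = rlt.FeatureData(
    float_features=torch.randn(2, 4),
    id_list_features=kjt,
    id_list_features_ro=kjt_ro,
)
# Setting either sparse field means has_float_features_only no longer holds (see diff above).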

reagent/models/sparse_dqn.py

Lines changed: 84 additions & 26 deletions
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

-from typing import List
+from typing import List, Optional, Tuple

 import torch
 from reagent.core import types as rlt
@@ -16,20 +16,39 @@
 @torch.fx.wrap
 def fetch_id_list_features(
     state: rlt.FeatureData, action: rlt.FeatureData
-) -> KeyedJaggedTensor:
-    assert state.id_list_features is not None or action.id_list_features is not None
-    if state.id_list_features is not None and action.id_list_features is None:
-        sparse_features = state.id_list_features
-    elif state.id_list_features is None and action.id_list_features is not None:
-        sparse_features = action.id_list_features
-    elif state.id_list_features is not None and action.id_list_features is not None:
-        sparse_features = KeyedJaggedTensor.concat(
-            [state.id_list_features, action.id_list_features]
-        )
-    else:
+) -> Tuple[Optional[KeyedJaggedTensor], Optional[KeyedJaggedTensor]]:
+    assert (
+        state.id_list_features is not None
+        or state.id_list_features_ro is not None
+        or action.id_list_features is not None
+        or action.id_list_features_ro is not None
+    )
+
+    def _get_sparse_features(
+        id_list_features_1, id_list_features_2
+    ) -> Optional[KeyedJaggedTensor]:
+        sparse_features = None
+        if id_list_features_1 is not None and id_list_features_2 is None:
+            sparse_features = id_list_features_1
+        elif id_list_features_1 is None and id_list_features_2 is not None:
+            sparse_features = id_list_features_2
+        elif id_list_features_1 is not None and id_list_features_2 is not None:
+            sparse_features = KeyedJaggedTensor.concat(
+                [id_list_features_1, id_list_features_2]
+            )
+        return sparse_features
+
+    sparse_features = _get_sparse_features(
+        state.id_list_features, action.id_list_features
+    )
+    sparse_features_ro = _get_sparse_features(
+        state.id_list_features_ro, action.id_list_features_ro
+    )
+    if sparse_features is None and sparse_features_ro is None:
         raise ValueError
+
     # TODO: add id_list_score_features
-    return sparse_features
+    return sparse_features, sparse_features_ro


 class SparseDQN(ModelBase):
@@ -41,7 +60,8 @@ class SparseDQN(ModelBase):
     def __init__(
         self,
         state_dense_dim: int,
-        embedding_bag_collection: EmbeddingBagCollection,
+        embedding_bag_collection: Optional[EmbeddingBagCollection],
+        embedding_bag_collection_ro: Optional[EmbeddingBagCollection],
         action_dense_dim: int,
         overarch_dims: List[int],
         activation: str = "relu",
@@ -51,17 +71,43 @@ def __init__(
         output_dim: int = 1,
     ) -> None:
         super().__init__()
-        self.sparse_arch: SparseArch = SparseArch(embedding_bag_collection)

-        self.sparse_embedding_dim: int = sum(
-            [
-                len(embc.feature_names) * embc.embedding_dim
-                for embc in embedding_bag_collection.embedding_bag_configs()
-            ]
+        self.sparse_arch: Optional[SparseArch] = (
+            SparseArch(embedding_bag_collection) if embedding_bag_collection else None
+        )
+
+        self.sparse_arch_ro: Optional[SparseArch] = (
+            SparseArch(embedding_bag_collection_ro)
+            if embedding_bag_collection_ro
+            else None
+        )
+        self.sparse_embedding_dim: int = (
+            sum(
+                [
+                    len(embc.feature_names) * embc.embedding_dim
+                    for embc in embedding_bag_collection.embedding_bag_configs()
+                ]
+            )
+            if embedding_bag_collection is not None
+            else 0
+        )
+
+        self.sparse_embedding_dim_ro: int = (
+            sum(
+                [
+                    len(embc.feature_names) * embc.embedding_dim
+                    for embc in embedding_bag_collection_ro.embedding_bag_configs()
+                ]
+            )
+            if embedding_bag_collection_ro is not None
+            else 0
         )

         self.input_dim: int = (
-            state_dense_dim + self.sparse_embedding_dim + action_dense_dim
+            state_dense_dim
+            + self.sparse_embedding_dim
+            + self.sparse_embedding_dim_ro
+            + action_dense_dim
         )
         layers = [self.input_dim] + overarch_dims + [output_dim]
         activations = [activation] * len(overarch_dims) + [final_activation]
@@ -76,11 +122,23 @@ def forward(self, state: rlt.FeatureData, action: rlt.FeatureData) -> torch.Tensor:
             (state.float_features, action.float_features), dim=-1
         )
         batch_size = dense_features.shape[0]
-        sparse_features = fetch_id_list_features(state, action)
+        sparse_features, sparse_features_ro = fetch_id_list_features(state, action)
         # shape: batch_size, num_sparse_features, embedding_dim
-        embedded_sparse = self.sparse_arch(sparse_features)
-        # shape: batch_size, num_sparse_features * embedding_dim
-        embedded_sparse = embedded_sparse.reshape(batch_size, -1)
-        concatenated_dense = torch.cat((dense_features, embedded_sparse), dim=-1)
+        embedded_sparse = (
+            self.sparse_arch(sparse_features) if self.sparse_arch else None
+        )
+        embedded_sparse_ro = (
+            self.sparse_arch_ro(sparse_features_ro) if self.sparse_arch_ro else None
+        )
+        features_list: List[torch.Tensor] = [dense_features]
+        if embedded_sparse is not None:
+            # shape: batch_size, num_sparse_features * embedding_dim
+            embedded_sparse = embedded_sparse.reshape(batch_size, -1)
+            features_list.append(embedded_sparse)
+        if embedded_sparse_ro is not None:
+            # shape: batch_size, num_sparse_features * embedding_dim
+            embedded_sparse_ro = embedded_sparse_ro.reshape(batch_size, -1)
+            features_list.append(embedded_sparse_ro)

+        concatenated_dense = torch.cat(features_list, dim=-1)
         return self.q_network(concatenated_dense)
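To make the new two-element return of fetch_id_list_features concrete, here is a small sketch (again not from the commit); only the read-only bag is populated on the state side, and the feature keys and tensor values are illustrative assumptions.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from reagent.core import types as rlt
from reagent.models.sparse_dqn import fetch_id_list_features

state = rlt.FeatureData(
    float_features=torch.randn(2, 4),
    # Hypothetical read-only sparse input on the state side only.
    id_list_features_ro=KeyedJaggedTensor.from_lengths_sync(
        keys=["watched_page_ids"],
        values=torch.tensor([7, 8, 9]),
        lengths=torch.tensor([2, 1]),
    ),
)
action = rlt.FeatureData(float_features=torch.randn(2, 3))

# The helper now returns a (regular, read-only) pair; either element may be None,
# but not both, and SparseDQN.forward embeds each through its own SparseArch.
sparse, sparse_ro = fetch_id_list_features(state, action)
assert sparse is None and sparse_ro is not None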

reagent/test/models/test_sparse_dqn_net.py

Lines changed: 18 additions & 1 deletion
@@ -24,13 +24,27 @@ def test_single_step_sparse_dqn(self):
                 embedding_dim=embedding_dim,
             )
         ]
+
+        num_sparse_features_ro = 2  # refer to watched_page_ids and liked_ids below
+        embedding_bag_configs_ro = [
+            EmbeddingBagConfig(
+                name="watched_page_ids",
+                feature_names=["watched_page_ids", "liked_ids"],
+                num_embeddings=embedding_table_size,
+                embedding_dim=embedding_dim,
+            )
+        ]
         embedding_bag_col = EmbeddingBagCollection(
             device=torch.device("cpu"), tables=embedding_bag_configs
         )
+        embedding_bag_col_ro = EmbeddingBagCollection(
+            device=torch.device("cpu"), tables=embedding_bag_configs_ro
+        )

         net = SparseDQN(
             state_dense_dim=state_dense_dim,
             embedding_bag_collection=embedding_bag_col,
+            embedding_bag_collection_ro=embedding_bag_col_ro,
             action_dense_dim=action_dense_dim,
             overarch_dims=dense_sizes,
             activation=activation,
@@ -42,7 +56,10 @@ def test_single_step_sparse_dqn(self):
         # number of sparse features times embedding dimension for sparse features
         assert (
             net[1].in_features
-            == state_dense_dim + action_dense_dim + num_sparse_features * embedding_dim
+            == state_dense_dim
+            + action_dense_dim
+            + num_sparse_features * embedding_dim
+            + num_sparse_features_ro * embedding_dim
         )
         assert net[1].out_features == dense_sizes[0]
         assert net[4].in_features == dense_sizes[0]
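As a quick sanity check on the widened in_features assertion above, the first overarch layer's width is the two dense widths plus one embedding_dim slice per sparse feature name in each bag; the concrete numbers below are assumptions for illustration, not values taken from the test.

# Hypothetical dimensions; the real ones are defined earlier in the test.
state_dense_dim, action_dense_dim = 10, 2
embedding_dim = 4
num_sparse_features = 2      # regular bag
num_sparse_features_ro = 2   # read-only bag: watched_page_ids, liked_ids

expected_in_features = (
    state_dense_dim
    + action_dense_dim
    + num_sparse_features * embedding_dim
    + num_sparse_features_ro * embedding_dim
)
assert expected_in_features == 28  # 10 + 2 + 8 + 8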
