Skip to content

Commit aea4525

Browse files
Fangyu Luofacebook-github-bot
authored andcommitted
Authoring Aware ROO on LTV ReAgent model
Summary: [Free] Authoring Aware ROO on LTV ReAgent model Differential Revision: D40456561 fbshipit-source-id: 5f04033b615b0aa6c4ae43cd4d3b6f9743919933
1 parent 30715aa commit aea4525

File tree

3 files changed

+97
-28
lines changed

3 files changed

+97
-28
lines changed

reagent/core/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ class FeatureData(TensorDataClass):
313313
float_features: torch.Tensor
314314
# For sparse features saved in KeyedJaggedTensor format
315315
id_list_features: Optional[KeyedJaggedTensor] = None
316+
id_list_features_ro: Optional[KeyedJaggedTensor] = None
316317
id_score_list_features: Optional[KeyedJaggedTensor] = None
317318

318319
# For sparse features saved in dictionary format
@@ -339,6 +340,7 @@ def __post_init__(self):
339340
def has_float_features_only(self) -> bool:
340341
return (
341342
not self.id_list_features
343+
and not self.id_list_features_ro
342344
and not self.id_score_list_features
343345
and self.time_since_first is None
344346
and self.candidate_docs is None

reagent/models/sparse_dqn.py

Lines changed: 76 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#!/usr/bin/env python3
22
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
33

4-
from typing import List
4+
from typing import List, Optional, Tuple
55

66
import torch
77
from reagent.core import types as rlt
88
from reagent.models import FullyConnectedNetwork
99
from reagent.models.base import ModelBase
10-
from torchrec.models.dlrm import SparseArch
10+
from torchrec.models.dlrm import SparseArchRO
1111
from torchrec.modules.embedding_modules import EmbeddingBagCollection
1212
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
1313

@@ -16,20 +16,39 @@
1616
@torch.fx.wrap
1717
def fetch_id_list_features(
1818
state: rlt.FeatureData, action: rlt.FeatureData
19-
) -> KeyedJaggedTensor:
20-
assert state.id_list_features is not None or action.id_list_features is not None
21-
if state.id_list_features is not None and action.id_list_features is None:
22-
sparse_features = state.id_list_features
23-
elif state.id_list_features is None and action.id_list_features is not None:
24-
sparse_features = action.id_list_features
25-
elif state.id_list_features is not None and action.id_list_features is not None:
26-
sparse_features = KeyedJaggedTensor.concat(
27-
[state.id_list_features, action.id_list_features]
28-
)
29-
else:
19+
) -> Tuple[Optional[KeyedJaggedTensor], Optional[KeyedJaggedTensor]]:
20+
assert (
21+
state.id_list_features is not None
22+
or state.id_list_features_ro is not None
23+
or action.id_list_features is not None
24+
or action.id_list_features_ro is not None
25+
)
26+
27+
def _get_sparse_features(
28+
id_list_features_1, id_list_features_2
29+
) -> Optional[KeyedJaggedTensor]:
30+
sparse_features = None
31+
if id_list_features_1 is not None and id_list_features_2 is None:
32+
sparse_features = id_list_features_1
33+
elif id_list_features_1 is None and id_list_features_2 is not None:
34+
sparse_features = id_list_features_2
35+
elif id_list_features_1 is not None and id_list_features_2 is not None:
36+
sparse_features = KeyedJaggedTensor.concat(
37+
[id_list_features_1, id_list_features_2]
38+
)
39+
return sparse_features
40+
41+
sparse_features = _get_sparse_features(
42+
state.id_list_features, action.id_list_features
43+
)
44+
sparse_features_ro = _get_sparse_features(
45+
state.id_list_features_ro, action.id_list_features_ro
46+
)
47+
if sparse_features is None and sparse_features_ro is None:
3048
raise ValueError
49+
3150
# TODO: add id_list_score_features
32-
return sparse_features
51+
return sparse_features, sparse_features_ro
3352

3453

3554
class SparseDQN(ModelBase):
@@ -41,7 +60,8 @@ class SparseDQN(ModelBase):
4160
def __init__(
4261
self,
4362
state_dense_dim: int,
44-
embedding_bag_collection: EmbeddingBagCollection,
63+
embedding_bag_collection: Optional[EmbeddingBagCollection],
64+
embedding_bag_collection_ro: Optional[EmbeddingBagCollection],
4565
action_dense_dim: int,
4666
overarch_dims: List[int],
4767
activation: str = "relu",
@@ -51,17 +71,37 @@ def __init__(
5171
output_dim: int = 1,
5272
) -> None:
5373
super().__init__()
54-
self.sparse_arch: SparseArch = SparseArch(embedding_bag_collection)
74+
self.sparse_arch: SparseArchRO = SparseArchRO(
75+
embedding_bag_collection, embedding_bag_collection_ro
76+
)
77+
78+
self.sparse_embedding_dim: int = (
79+
sum(
80+
[
81+
len(embc.feature_names) * embc.embedding_dim
82+
for embc in embedding_bag_collection.embedding_bag_configs()
83+
]
84+
)
85+
if embedding_bag_collection is not None
86+
else 0
87+
)
5588

56-
self.sparse_embedding_dim: int = sum(
57-
[
58-
len(embc.feature_names) * embc.embedding_dim
59-
for embc in embedding_bag_collection.embedding_bag_configs()
60-
]
89+
self.sparse_embedding_dim_ro: int = (
90+
sum(
91+
[
92+
len(embc.feature_names) * embc.embedding_dim
93+
for embc in embedding_bag_collection.embedding_bag_configs()
94+
]
95+
)
96+
if embedding_bag_collection is not None
97+
else 0
6198
)
6299

63100
self.input_dim: int = (
64-
state_dense_dim + self.sparse_embedding_dim + action_dense_dim
101+
state_dense_dim
102+
+ self.sparse_embedding_dim
103+
+ self.sparse_embedding_dim_ro
104+
+ action_dense_dim
65105
)
66106
layers = [self.input_dim] + overarch_dims + [output_dim]
67107
activations = [activation] * len(overarch_dims) + [final_activation]
@@ -76,11 +116,20 @@ def forward(self, state: rlt.FeatureData, action: rlt.FeatureData) -> torch.Tens
76116
(state.float_features, action.float_features), dim=-1
77117
)
78118
batch_size = dense_features.shape[0]
79-
sparse_features = fetch_id_list_features(state, action)
119+
sparse_features, sparse_features_ro = fetch_id_list_features(state, action)
80120
# shape: batch_size, num_sparse_features, embedding_dim
81-
embedded_sparse = self.sparse_arch(sparse_features)
82-
# shape: batch_size, num_sparse_features * embedding_dim
83-
embedded_sparse = embedded_sparse.reshape(batch_size, -1)
84-
concatenated_dense = torch.cat((dense_features, embedded_sparse), dim=-1)
121+
embedded_sparse, embedded_sparse_ro = self.sparse_arch(
122+
sparse_features, sparse_features_ro
123+
)
124+
features_list: List[torch.Tensor] = [dense_features]
125+
if embedded_sparse is not None:
126+
# shape: batch_size, num_sparse_features * embedding_dim
127+
embedded_sparse = embedded_sparse.reshape(batch_size, -1)
128+
features_list.append(embedded_sparse)
129+
if embedded_sparse_ro is not None:
130+
# shape: batch_size, num_sparse_features * embedding_dim
131+
embedded_sparse_ro = embedded_sparse_ro.reshape(batch_size, -1)
132+
features_list.append(embedded_sparse_ro)
85133

134+
concatenated_dense = torch.cat(features_list, dim=-1)
86135
return self.q_network(concatenated_dense)

reagent/test/models/test_sparse_dqn_net.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def test_single_step_sparse_dqn(self):
1616
embedding_table_size = 1000
1717
embedding_dim = 32
1818
num_sparse_features = 2 # refer to watched_ids and liked_ids below
19+
1920
embedding_bag_configs = [
2021
EmbeddingBagConfig(
2122
name="video_id",
@@ -24,13 +25,27 @@ def test_single_step_sparse_dqn(self):
2425
embedding_dim=embedding_dim,
2526
)
2627
]
28+
29+
num_sparse_features_ro = 2 # refer to watched_page_ids and liked_ids below
30+
embedding_bag_configs_ro = [
31+
EmbeddingBagConfig(
32+
name="watched_page_ids",
33+
feature_names=["watched_page_ids", "liked_ids"],
34+
num_embeddings=embedding_table_size,
35+
embedding_dim=embedding_dim,
36+
)
37+
]
2738
embedding_bag_col = EmbeddingBagCollection(
2839
device=torch.device("cpu"), tables=embedding_bag_configs
2940
)
41+
embedding_bag_col_ro = EmbeddingBagCollection(
42+
device=torch.device("cpu"), tables=embedding_bag_configs_ro
43+
)
3044

3145
net = SparseDQN(
3246
state_dense_dim=state_dense_dim,
3347
embedding_bag_collection=embedding_bag_col,
48+
embedding_bag_collection_ro=embedding_bag_col_ro,
3449
action_dense_dim=action_dense_dim,
3550
overarch_dims=dense_sizes,
3651
activation=activation,
@@ -42,7 +57,10 @@ def test_single_step_sparse_dqn(self):
4257
# number of sparse features times embedding dimension for sparse features
4358
assert (
4459
net[1].in_features
45-
== state_dense_dim + action_dense_dim + num_sparse_features * embedding_dim
60+
== state_dense_dim
61+
+ action_dense_dim
62+
+ num_sparse_features * embedding_dim
63+
+ num_sparse_features_ro * embedding_dim
4664
)
4765
assert net[1].out_features == dense_sizes[0]
4866
assert net[4].in_features == dense_sizes[0]

0 commit comments

Comments
 (0)