1
1
#!/usr/bin/env python3
2
2
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
3
3
4
- from typing import List
4
+ from typing import List , Optional , Tuple
5
5
6
6
import torch
7
7
from reagent .core import types as rlt
8
8
from reagent .models import FullyConnectedNetwork
9
9
from reagent .models .base import ModelBase
10
- from torchrec .models .dlrm import SparseArch
10
+ from torchrec .models .dlrm import SparseArchRO
11
11
from torchrec .modules .embedding_modules import EmbeddingBagCollection
12
12
from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
13
13
16
16
@torch.fx.wrap
def fetch_id_list_features(
    state: rlt.FeatureData, action: rlt.FeatureData
) -> Tuple[Optional[KeyedJaggedTensor], Optional[KeyedJaggedTensor]]:
    """Collect the sparse (id-list) features from a state/action pair.

    Args:
        state: feature bundle that may carry ``id_list_features`` and/or
            ``id_list_features_ro``.
        action: feature bundle with the same optional sparse fields.

    Returns:
        A pair ``(sparse_features, sparse_features_ro)``. Each element is the
        concatenation of the state's and action's corresponding field when both
        are present, whichever single side is present otherwise, or ``None``
        when neither side carries that field.

    Raises:
        AssertionError: if none of the four sparse fields is present
            (ValueError instead when asserts are disabled via ``python -O``).
    """
    assert (
        state.id_list_features is not None
        or state.id_list_features_ro is not None
        or action.id_list_features is not None
        or action.id_list_features_ro is not None
    ), "Neither state nor action carries any id-list features"

    def _merge(
        a: Optional[KeyedJaggedTensor], b: Optional[KeyedJaggedTensor]
    ) -> Optional[KeyedJaggedTensor]:
        # Prefer whichever side is present; concatenate when both are.
        if a is None:
            return b
        if b is None:
            return a
        return KeyedJaggedTensor.concat([a, b])

    sparse_features = _merge(state.id_list_features, action.id_list_features)
    sparse_features_ro = _merge(
        state.id_list_features_ro, action.id_list_features_ro
    )
    if sparse_features is None and sparse_features_ro is None:
        # Backstop for `python -O`, where the assert above is stripped.
        raise ValueError(
            "Expected at least one of state/action id_list_features or "
            "id_list_features_ro to be present"
        )
    # TODO: add id_list_score_features
    return sparse_features, sparse_features_ro
33
52
34
53
35
54
class SparseDQN (ModelBase ):
@@ -41,7 +60,8 @@ class SparseDQN(ModelBase):
41
60
def __init__ (
42
61
self ,
43
62
state_dense_dim : int ,
44
- embedding_bag_collection : EmbeddingBagCollection ,
63
+ embedding_bag_collection : Optional [EmbeddingBagCollection ],
64
+ embedding_bag_collection_ro : Optional [EmbeddingBagCollection ],
45
65
action_dense_dim : int ,
46
66
overarch_dims : List [int ],
47
67
activation : str = "relu" ,
@@ -51,17 +71,37 @@ def __init__(
51
71
output_dim : int = 1 ,
52
72
) -> None :
53
73
super ().__init__ ()
54
- self .sparse_arch : SparseArch = SparseArch (embedding_bag_collection )
74
+ self .sparse_arch : SparseArchRO = SparseArchRO (
75
+ embedding_bag_collection , embedding_bag_collection_ro
76
+ )
77
+
78
+ self .sparse_embedding_dim : int = (
79
+ sum (
80
+ [
81
+ len (embc .feature_names ) * embc .embedding_dim
82
+ for embc in embedding_bag_collection .embedding_bag_configs ()
83
+ ]
84
+ )
85
+ if embedding_bag_collection is not None
86
+ else 0
87
+ )
55
88
56
- self .sparse_embedding_dim : int = sum (
57
- [
58
- len (embc .feature_names ) * embc .embedding_dim
59
- for embc in embedding_bag_collection .embedding_bag_configs ()
60
- ]
89
# Total flattened embedding width of the read-only collection:
# each config contributes (number of features) * (embedding dim).
# BUGFIX: this previously summed over `embedding_bag_collection` (the
# non-RO collection) and guarded on it too; it must use the RO collection,
# mirroring the `sparse_embedding_dim` computation above.
self.sparse_embedding_dim_ro: int = (
    sum(
        len(embc.feature_names) * embc.embedding_dim
        for embc in embedding_bag_collection_ro.embedding_bag_configs()
    )
    if embedding_bag_collection_ro is not None
    else 0
)
62
99
63
100
self .input_dim : int = (
64
- state_dense_dim + self .sparse_embedding_dim + action_dense_dim
101
+ state_dense_dim
102
+ + self .sparse_embedding_dim
103
+ + self .sparse_embedding_dim_ro
104
+ + action_dense_dim
65
105
)
66
106
layers = [self .input_dim ] + overarch_dims + [output_dim ]
67
107
activations = [activation ] * len (overarch_dims ) + [final_activation ]
@@ -76,11 +116,20 @@ def forward(self, state: rlt.FeatureData, action: rlt.FeatureData) -> torch.Tens
76
116
(state .float_features , action .float_features ), dim = - 1
77
117
)
78
118
batch_size = dense_features .shape [0 ]
79
- sparse_features = fetch_id_list_features (state , action )
119
+ sparse_features , sparse_features_ro = fetch_id_list_features (state , action )
80
120
# shape: batch_size, num_sparse_features, embedding_dim
81
- embedded_sparse = self .sparse_arch (sparse_features )
82
- # shape: batch_size, num_sparse_features * embedding_dim
83
- embedded_sparse = embedded_sparse .reshape (batch_size , - 1 )
84
- concatenated_dense = torch .cat ((dense_features , embedded_sparse ), dim = - 1 )
121
+ embedded_sparse , embedded_sparse_ro = self .sparse_arch (
122
+ sparse_features , sparse_features_ro
123
+ )
124
+ features_list : List [torch .Tensor ] = [dense_features ]
125
+ if embedded_sparse is not None :
126
+ # shape: batch_size, num_sparse_features * embedding_dim
127
+ embedded_sparse = embedded_sparse .reshape (batch_size , - 1 )
128
+ features_list .append (embedded_sparse )
129
+ if embedded_sparse_ro is not None :
130
+ # shape: batch_size, num_sparse_features * embedding_dim
131
+ embedded_sparse_ro = embedded_sparse_ro .reshape (batch_size , - 1 )
132
+ features_list .append (embedded_sparse_ro )
85
133
134
+ concatenated_dense = torch .cat (features_list , dim = - 1 )
86
135
return self .q_network (concatenated_dense )
0 commit comments