11#!/usr/bin/env python3
22# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
33
4- from typing import List
4+ from typing import List , Optional , Tuple
55
66import torch
77from reagent .core import types as rlt
1616@torch .fx .wrap
1717def fetch_id_list_features (
1818 state : rlt .FeatureData , action : rlt .FeatureData
19- ) -> KeyedJaggedTensor :
20- assert state .id_list_features is not None or action .id_list_features is not None
21- if state .id_list_features is not None and action .id_list_features is None :
22- sparse_features = state .id_list_features
23- elif state .id_list_features is None and action .id_list_features is not None :
24- sparse_features = action .id_list_features
25- elif state .id_list_features is not None and action .id_list_features is not None :
26- sparse_features = KeyedJaggedTensor .concat (
27- [state .id_list_features , action .id_list_features ]
28- )
29- else :
19+ ) -> Tuple [Optional [KeyedJaggedTensor ], Optional [KeyedJaggedTensor ]]:
20+ assert (
21+ state .id_list_features is not None
22+ or state .id_list_features_ro is not None
23+ or action .id_list_features is not None
24+ or action .id_list_features_ro is not None
25+ )
26+
27+ def _get_sparse_features (
28+ id_list_features_1 , id_list_features_2
29+ ) -> Optional [KeyedJaggedTensor ]:
30+ sparse_features = None
31+ if id_list_features_1 is not None and id_list_features_2 is None :
32+ sparse_features = id_list_features_1
33+ elif id_list_features_1 is None and id_list_features_2 is not None :
34+ sparse_features = id_list_features_2
35+ elif id_list_features_1 is not None and id_list_features_2 is not None :
36+ sparse_features = KeyedJaggedTensor .concat (
37+ [id_list_features_1 , id_list_features_2 ]
38+ )
39+ return sparse_features
40+
41+ sparse_features = _get_sparse_features (
42+ state .id_list_features , action .id_list_features
43+ )
44+ sparse_features_ro = _get_sparse_features (
45+ state .id_list_features_ro , action .id_list_features_ro
46+ )
47+ if sparse_features is None and sparse_features_ro is None :
3048 raise ValueError
49+
3150 # TODO: add id_list_score_features
32- return sparse_features
51+ return sparse_features , sparse_features_ro
3352
3453
3554class SparseDQN (ModelBase ):
@@ -41,7 +60,8 @@ class SparseDQN(ModelBase):
4160 def __init__ (
4261 self ,
4362 state_dense_dim : int ,
44- embedding_bag_collection : EmbeddingBagCollection ,
63+ embedding_bag_collection : Optional [EmbeddingBagCollection ],
64+ embedding_bag_collection_ro : Optional [EmbeddingBagCollection ],
4565 action_dense_dim : int ,
4666 overarch_dims : List [int ],
4767 activation : str = "relu" ,
@@ -51,17 +71,43 @@ def __init__(
5171 output_dim : int = 1 ,
5272 ) -> None :
5373 super ().__init__ ()
54- self .sparse_arch : SparseArch = SparseArch (embedding_bag_collection )
5574
56- self .sparse_embedding_dim : int = sum (
57- [
58- len (embc .feature_names ) * embc .embedding_dim
59- for embc in embedding_bag_collection .embedding_bag_configs ()
60- ]
75+ self .sparse_arch : Optional [SparseArch ] = (
76+ SparseArch (embedding_bag_collection ) if embedding_bag_collection else None
77+ )
78+
79+ self .sparse_arch_ro : Optional [SparseArch ] = (
80+ SparseArch (embedding_bag_collection_ro )
81+ if embedding_bag_collection_ro
82+ else None
83+ )
84+ self .sparse_embedding_dim : int = (
85+ sum (
86+ [
87+ len (embc .feature_names ) * embc .embedding_dim
88+ for embc in embedding_bag_collection .embedding_bag_configs ()
89+ ]
90+ )
91+ if embedding_bag_collection is not None
92+ else 0
93+ )
94+
95+ self .sparse_embedding_dim_ro : int = (
96+ sum (
97+ [
98+ len (embc .feature_names ) * embc .embedding_dim
99+ for embc in embedding_bag_collection_ro .embedding_bag_configs ()
100+ ]
101+ )
102+ if embedding_bag_collection_ro is not None
103+ else 0
61104 )
62105
63106 self .input_dim : int = (
64- state_dense_dim + self .sparse_embedding_dim + action_dense_dim
107+ state_dense_dim
108+ + self .sparse_embedding_dim
109+ + self .sparse_embedding_dim_ro
110+ + action_dense_dim
65111 )
66112 layers = [self .input_dim ] + overarch_dims + [output_dim ]
67113 activations = [activation ] * len (overarch_dims ) + [final_activation ]
@@ -76,11 +122,23 @@ def forward(self, state: rlt.FeatureData, action: rlt.FeatureData) -> torch.Tens
76122 (state .float_features , action .float_features ), dim = - 1
77123 )
78124 batch_size = dense_features .shape [0 ]
79- sparse_features = fetch_id_list_features (state , action )
125+ sparse_features , sparse_features_ro = fetch_id_list_features (state , action )
80126 # shape: batch_size, num_sparse_features, embedding_dim
81- embedded_sparse = self .sparse_arch (sparse_features )
82- # shape: batch_size, num_sparse_features * embedding_dim
83- embedded_sparse = embedded_sparse .reshape (batch_size , - 1 )
84- concatenated_dense = torch .cat ((dense_features , embedded_sparse ), dim = - 1 )
127+ embedded_sparse = (
128+ self .sparse_arch (sparse_features ) if self .sparse_arch else None
129+ )
130+ embedded_sparse_ro = (
131+ self .sparse_arch_ro (sparse_features_ro ) if self .sparse_arch_ro else None
132+ )
133+ features_list : List [torch .Tensor ] = [dense_features ]
134+ if embedded_sparse is not None :
135+ # shape: batch_size, num_sparse_features * embedding_dim
136+ embedded_sparse = embedded_sparse .reshape (batch_size , - 1 )
137+ features_list .append (embedded_sparse )
138+ if embedded_sparse_ro is not None :
139+ # shape: batch_size, num_sparse_features * embedding_dim
140+ embedded_sparse_ro = embedded_sparse_ro .reshape (batch_size , - 1 )
141+ features_list .append (embedded_sparse_ro )
85142
143+ concatenated_dense = torch .cat (features_list , dim = - 1 )
86144 return self .q_network (concatenated_dense )
0 commit comments