# A bit hacky way to make the encoder head obtain its input shape dynamically
2+ import tensorflow as tf
3+ import tensorflow .keras .layers as L
4+ from NN .utils import sMLP
5+ from NN .encoding import CCoordsGridLayer , CCoordsEncodingLayer
6+ from NN .layers import MixerConvLayer , Patches , TransformerBlock
7+ from Utils .utils import dumb_deepcopy
8+
def block_params_from_config(config):
  '''
  Derives the per-layer parameter list for a conv block from `config`.

  If `config` contains an explicit 'layers' list it is returned as-is.
  Otherwise the list is assembled from 'conv before' (an int repeat count,
  a list of channel counts, or a list of parameter dicts) plus one final
  layer whose settings may be overridden via the '... last' keys.
  '''
  explicitLayers = config.get('layers', None)
  if explicitLayers is not None: return explicitLayers

  baseParams = {
    'kernel size': config.get('kernel size', 3),
    'activation': config.get('activation', 'relu'),
    'name': config.get('name', 'Conv2D'),
  }
  before = config['conv before']
  # a plain integer means "repeat the same layer that many times"
  if isinstance(before, int):
    before = [{'channels': config['channels'], **baseParams}] * before
  assert isinstance(before, list), 'convBefore must be a list'
  # a list of integers means "one layer per entry, entry = channel count"
  if before and isinstance(before[0], int):
    before = [{'channels': channels, **baseParams} for channels in before]

  # the final layer is configured separately and always appended
  finalLayer = {
    'channels': config.get('channels last', config['channels']),
    'kernel size': config.get('kernel size last', baseParams['kernel size']),
    'activation': config.get('final activation', baseParams['activation']),
    'name': config.get('last name', 'Conv2D'),
  }
  return before + [finalLayer]
38+
def conv_block_from_config(data, config, defaults, name='CB'):
  '''
  Applies a configurable stack of layers to `data`.

  Args:
    data: input tensor (image-like for the Conv2D branches — TODO confirm rank).
    config: dict describing the block; merged over `defaults` (config wins).
      Layer list schema is produced by `block_params_from_config`.
    defaults: fallback values for `config`.
    name: prefix used for all created layer names.

  Returns:
    The transformed tensor.

  Raises:
    NotImplementedError: if a layer entry has an unknown 'name'.
  '''
  config = {**defaults, **config}  # merge defaults and config
  convParams = block_params_from_config(config)
  # apply the configured layers sequentially
  for i, parameters in enumerate(convParams):
    parameters = dumb_deepcopy(parameters)  # avoid mutating shared configs
    Name = parameters.get('name', 'Conv2D')
    if 'Conv2D' == Name:
      data = L.Conv2D(
        filters=parameters['channels'],
        padding='same',
        kernel_size=parameters['kernel size'],
        activation=parameters['activation'],
        name='%s/conv-%d' % (name, i)
      )(data)
      continue

    if 'MLP Mixer' == Name:
      data = MixerConvLayer(
        token_mixing=parameters.get('token mixing', 512),
        channel_mixing=parameters.get('channel mixing', 512),
        name='%s/conv-mixer-%d' % (name, i)
      )(data)
      continue

    if 'Patches' == Name:
      data = Patches(
        patch_size=parameters['patch size'],
        name='%s/patches-%d' % (name, i)
      )(data)
      continue

    if 'CoordsGrid' == Name:
      N = parameters.get('N', 32)
      # FIX: 'N' and 'name' are passed explicitly below; leaving them inside
      # `parameters` raised TypeError (duplicate keyword 'N') whenever the
      # layer config specified its own N
      parameters = {k: v for k, v in parameters.items() if k not in ['name', 'N']}
      data = CCoordsGridLayer(
        CCoordsEncodingLayer(
          N=N,
          # distinct sub-name, consistent with _withPositionConfig
          name='%s/coordsGrid-%d/encoding' % (name, i),
          **parameters
        ),
        name='%s/coordsGrid-%d' % (name, i)
      )(data)
      continue

    if 'Transformer' == Name:
      # `parameters` is already a private deep copy, safe to mutate in place
      parameters['name'] = '%s/transformer-%d' % (name, i)
      parameters['intermediate_dim'] = parameters.pop('intermediate dim', 512)
      parameters['num_heads'] = parameters.pop('num heads', 8)
      data = TransformerBlock(**parameters)(data)
      continue

    if 'Reshape' == Name:
      # sizes <= -2 are placeholders resolved from the current tensor shape
      # (-2 -> last dim of `data`, -3 -> second to last, ...); -1 keeps its
      # usual "infer" meaning for L.Reshape
      shape = [
        data.shape[sz + 1] if sz <= -2 else sz
        for sz in parameters['shape']
      ]
      data = L.Reshape(
        shape,
        name='%s/reshape-%d' % (name, i)
      )(data)
      continue

    if 'MLP' == Name:
      parameters['name'] = '%s/mlp-%d' % (name, i)
      data = sMLP(**parameters)(data)
      continue

    raise NotImplementedError('Unknown layer: {}'.format(Name))
  return data
111+
def _createGCMv2(dataShape, config, latentDim, name):
  '''
  Builds the 'v2' global-context backbone: a stack of strided (stride 2)
  downsampling convolutions, each followed by `layers` same-resolution
  convolutions, as described by config['downsample steps'].

  NOTE: `latentDim` is accepted for interface parity but is not used here.
  Returns a tf.keras.Model mapping an input of `dataShape` to the features.
  '''
  inputs = L.Input(shape=dataShape)
  x = inputs
  for stepIndex, step in enumerate(config['downsample steps'], start=1):
    filtersN = step['channels']
    kernel = step['kernel size']
    stagePrefix = name + '/downsample-%d' % (stepIndex,)
    # strided convolution halves the spatial resolution
    x = L.Conv2D(
      filters=filtersN, kernel_size=kernel,
      strides=2, padding='same', activation='relu',
      name=stagePrefix
    )(x)
    # extra convolutions at the new resolution
    for layerIndex in range(1, step['layers'] + 1):
      x = L.Conv2D(
        filters=filtersN, kernel_size=kernel,
        padding='same', activation='relu',
        name=stagePrefix + '/layer-%d' % (layerIndex,)
      )(x)

  return tf.keras.Model(inputs=[inputs], outputs=x, name=name)
139+
def _createGlobalContextModel(X, config, latentDim, name):
  '''
  Computes a global context vector of size `latentDim` from features `X`.
  config['name'] selects the variant: 'v1' (conv block applied directly)
  or 'v2' (a separate keras sub-model built by _createGCMv2).
  '''
  def toContext(features):
    # shared tail: flatten -> MLP mixer -> dense projection to latentDim
    flat = L.Flatten()(features)
    mixed = sMLP(sizes=config['mlp'], activation='relu', name=name + '/globalMixer')(flat)
    return L.Dense(latentDim, activation=config['final activation'], name=name + '/dense-latent')(mixed)

  kind = config.get('name', 'v1')
  if 'v1' == kind:  # simple convolutional model
    features = conv_block_from_config(
      data=X, config=config, defaults={
        'conv before': 0,  # by default, no convolutions before the last layer
      }
    )
    return toContext(features)

  if 'v2' == kind:  # wrapped in its own keras sub-model, then applied to X
    inp = L.Input(shape=X.shape[1:])
    features = _createGCMv2(inp.shape[1:], config, latentDim, name)(inp)
    wrapper = tf.keras.Model(inputs=[inp], outputs=toContext(features), name=name)
    return wrapper(X)

  raise NotImplementedError('Unknown global context model: {}'.format(kind))
165+
166+ def _withPositionConfig (config , name ):
167+ if config is None :
168+ print ('[Encoder] Positions: No' )
169+ return lambda x , _ : x
170+
171+ print ('[Encoder] Positions: Yes' )
172+
173+ if isinstance (config , bool ): config = { 'N' : 32 }
174+ assert isinstance (config , dict ), 'config must be a dictionary'
175+
176+ def withPosition (x , i ):
177+ if not config .get ('stage-%d' % i , True ): return x
178+
179+ encoding = config .get ('encoding' , {})
180+ encoding = dict (** encoding )
181+ encoding ['N' ] = config .get ('stage-%d N' % i , config .get ('N' , 32 ))
182+ return CCoordsGridLayer (
183+ CCoordsEncodingLayer (** encoding , name = '%s/coordsGrid-%d/encoding' % (name , i )),
184+ name = '%s/coordsGrid-%d' % (name , i )
185+ )(x )
186+ return withPosition
187+
188+ ##################
def createEncoderHead_full(
  imgWidth,
  config,
  channels, downsampleSteps, latentDim,
  ConvBeforeStage, ConvAfterStage,
  localContext, globalContext,
  positionsConfigs,
  name
):
  '''
  Builds the full encoder head as a keras Model over a square input image
  of shape (imgWidth, imgWidth, channels).

  Per downsampling stage: optional strided downsample, optional positional
  encoding, ConvBeforeStage convs, an optional per-stage local-context head,
  then ConvAfterStage convs. Outputs a dict with the per-stage
  'intermediate' tensors and a global 'context' vector (a dummy zero tensor
  of width 1 when globalContext is None, to keep the interface stable).
  '''
  assert config is not None, 'config must be a dictionary'
  assert isinstance(downsampleSteps, list) and (0 < len(downsampleSteps)), 'downsampleSteps must be a list of integers'
  imageInput = L.Input(shape=(imgWidth, imgWidth, channels))

  addPositions = _withPositionConfig(positionsConfigs, name)
  features = imageInput
  intermediate = []
  for stage, stageChannels in enumerate(downsampleSteps):
    if config.get('use downsampling', True):
      features = L.Conv2D(stageChannels, 3, strides=2, padding='same', activation='relu')(features)
    features = addPositions(features, stage)  # positional encoding, if configured
    for _ in range(ConvBeforeStage):
      features = L.Conv2D(stageChannels, 3, padding='same', activation='relu')(features)

    if localContext is not None:
      # per-stage local context head, projected to latentDim channels
      intermediate.append(
        conv_block_from_config(
          data=features, config=localContext, defaults={
            'channels': stageChannels,
            'channels last': latentDim,  # last layer should have latentDim channels
          },
          name='%s/intermediate-%d' % (name, stage)
        )
      )

    for _ in range(ConvAfterStage):
      features = L.Conv2D(stageChannels, 3, padding='same', activation='relu')(features)

  if globalContext is not None:  # global context branch
    features = addPositions(features, len(downsampleSteps))
    context = _createGlobalContextModel(features, globalContext, latentDim, name + '/globalContext')
  else:
    # dummy zero context keeps the output signature consistent for callers
    context = L.Lambda(
      lambda x: tf.zeros((tf.shape(x)[0], 1), dtype=features.dtype)
    )(features)

  return tf.keras.Model(
    inputs=[imageInput],
    outputs={
      'intermediate': intermediate,  # per-stage representations
      'context': context,  # global context vector
    },
    name=name
  )
245+
class CEncoderHead(tf.keras.Model):
  '''
  Keras model wrapper that defers construction of the actual encoder head
  until the input shape is known (see build), which is the "hacky" dynamic
  input-shape mechanism mentioned at the top of this file.
  '''
  def __init__(self,
    config,
    downsampleSteps, latentDim,
    ConvBeforeStage, ConvAfterStage,
    localContext, globalContext,
    positionsConfigs,
    **kwargs
  ):
    super().__init__(**kwargs)
    # stash every construction parameter as a private attribute;
    # the real sub-model is assembled lazily in build()
    params = dict(
      config=config,
      downsampleSteps=downsampleSteps,
      latentDim=latentDim,
      ConvBeforeStage=ConvBeforeStage,
      ConvAfterStage=ConvAfterStage,
      localContext=localContext,
      globalContext=globalContext,
      positionsConfigs=positionsConfigs,
    )
    for key, value in params.items():
      setattr(self, '_' + key, value)

  def build(self, inputShape):
    # inputShape is (batch, H, W, C); the head is built for a square image
    # using H as the width — assumes H == W, TODO confirm upstream
    height, _, channelsN = inputShape[1:]
    self._encoderHead = createEncoderHead_full(
      imgWidth=height, config=self._config,
      channels=channelsN, downsampleSteps=self._downsampleSteps, latentDim=self._latentDim,
      ConvBeforeStage=self._ConvBeforeStage, ConvAfterStage=self._ConvAfterStage,
      localContext=self._localContext, globalContext=self._globalContext,
      positionsConfigs=self._positionsConfigs,
      name=self.name + '/EncoderHead'
    )
    self._encoderHead.build(inputShape)
    return super().build(inputShape)

  def call(self, src, training=None):
    # delegate entirely to the lazily-built sub-model
    return self._encoderHead(src, training=training)
def createEncoderHead(
  config,
  downsampleSteps, latentDim,
  ConvBeforeStage, ConvAfterStage,
  localContext, globalContext,
  positionsConfigs,
  name
):
  '''
  Simple encoder that takes image as input and returns corresponding latent
  vector with intermediate representations.

  Thin factory around CEncoderHead; the spatial input shape is resolved
  lazily when the returned model is first built/called.
  (FIX: this text previously sat as a no-op string statement *before* the
  def, so it never became the function's __doc__.)
  '''
  return CEncoderHead(
    config=config,
    downsampleSteps=downsampleSteps,
    latentDim=latentDim,
    ConvBeforeStage=ConvBeforeStage,
    ConvAfterStage=ConvAfterStage,
    localContext=localContext,
    globalContext=globalContext,
    positionsConfigs=positionsConfigs,
    name=name
  )
0 commit comments