@@ -413,46 +413,23 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
413413 }
414414 }
415415
416- fn init_or_get_extension < V : Extension > ( & mut self , init : bool ) -> Result < & mut V , ProgramError > {
416+ /// Unpack a portion of the TLV data as the desired type that allows modifying the type
417+ pub fn get_extension_mut < V : Extension > ( & mut self ) -> Result < & mut V , ProgramError > {
417418 if V :: TYPE . get_account_type ( ) != S :: ACCOUNT_TYPE {
418419 return Err ( ProgramError :: InvalidAccountData ) ;
419420 }
420421 let TlvIndices {
421422 type_start,
422423 length_start,
423424 value_start,
424- } = get_extension_indices :: < V > ( self . tlv_data , init ) ?;
425+ } = get_extension_indices :: < V > ( self . tlv_data , false ) ?;
425426
426427 if self . tlv_data [ type_start..] . len ( ) < V :: TYPE . get_tlv_len ( ) {
427428 return Err ( ProgramError :: InvalidAccountData ) ;
428429 }
429- if init {
430- // write extension type
431- let extension_type_array: [ u8 ; 2 ] = V :: TYPE . into ( ) ;
432- let extension_type_ref = & mut self . tlv_data [ type_start..length_start] ;
433- extension_type_ref. copy_from_slice ( & extension_type_array) ;
434- // write length
435- let length_ref =
436- pod_from_bytes_mut :: < Length > ( & mut self . tlv_data [ length_start..value_start] ) ?;
437- // maybe this becomes smarter later for dynamically sized extensions
438- let length = pod_get_packed_len :: < V > ( ) ;
439- * length_ref = Length :: try_from ( length) . unwrap ( ) ;
440-
441- let value_end = value_start. saturating_add ( length) ;
442- let extension_ref =
443- pod_from_bytes_mut :: < V > ( & mut self . tlv_data [ value_start..value_end] ) ?;
444- * extension_ref = V :: default ( ) ;
445- Ok ( extension_ref)
446- } else {
447- let length = pod_from_bytes :: < Length > ( & self . tlv_data [ length_start..value_start] ) ?;
448- let value_end = value_start. saturating_add ( usize:: from ( * length) ) ;
449- pod_from_bytes_mut :: < V > ( & mut self . tlv_data [ value_start..value_end] )
450- }
451- }
452-
453- /// Unpack a portion of the TLV data as the desired type that allows modifying the type
454- pub fn get_extension_mut < V : Extension > ( & mut self ) -> Result < & mut V , ProgramError > {
455- self . init_or_get_extension ( false )
430+ let length = pod_from_bytes :: < Length > ( & self . tlv_data [ length_start..value_start] ) ?;
431+ let value_end = value_start. saturating_add ( usize:: from ( * length) ) ;
432+ pod_from_bytes_mut :: < V > ( & mut self . tlv_data [ value_start..value_end] )
456433 }
457434
458435 /// Unpack a portion of the TLV data as the desired type
@@ -480,9 +457,48 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
480457 }
481458
482459 /// Packs the default extension data into an open slot if not already found in the
483- /// data buffer, otherwise overwrites the existing extension with the default state
484- pub fn init_extension < V : Extension > ( & mut self ) -> Result < & mut V , ProgramError > {
485- self . init_or_get_extension ( true )
460+ /// data buffer. If extension is already found in the buffer, it overwrites the existing
461+ /// extension with the default state if `overwrite` is set. If extension found, but
462+ /// `overwrite` is not set, it returns error.
463+ pub fn init_extension < V : Extension > (
464+ & mut self ,
465+ overwrite : bool ,
466+ ) -> Result < & mut V , ProgramError > {
467+ if V :: TYPE . get_account_type ( ) != S :: ACCOUNT_TYPE {
468+ return Err ( ProgramError :: InvalidAccountData ) ;
469+ }
470+ let TlvIndices {
471+ type_start,
472+ length_start,
473+ value_start,
474+ } = get_extension_indices :: < V > ( self . tlv_data , true ) ?;
475+
476+ if self . tlv_data [ type_start..] . len ( ) < V :: TYPE . get_tlv_len ( ) {
477+ return Err ( ProgramError :: InvalidAccountData ) ;
478+ }
479+ let extension_type = ExtensionType :: try_from ( & self . tlv_data [ type_start..length_start] ) ?;
480+
481+ if extension_type == ExtensionType :: Uninitialized || overwrite {
482+ // write extension type
483+ let extension_type_array: [ u8 ; 2 ] = V :: TYPE . into ( ) ;
484+ let extension_type_ref = & mut self . tlv_data [ type_start..length_start] ;
485+ extension_type_ref. copy_from_slice ( & extension_type_array) ;
486+ // write length
487+ let length_ref =
488+ pod_from_bytes_mut :: < Length > ( & mut self . tlv_data [ length_start..value_start] ) ?;
489+ // maybe this becomes smarter later for dynamically sized extensions
490+ let length = pod_get_packed_len :: < V > ( ) ;
491+ * length_ref = Length :: try_from ( length) . unwrap ( ) ;
492+
493+ let value_end = value_start. saturating_add ( length) ;
494+ let extension_ref =
495+ pod_from_bytes_mut :: < V > ( & mut self . tlv_data [ value_start..value_end] ) ?;
496+ * extension_ref = V :: default ( ) ;
497+ Ok ( extension_ref)
498+ } else {
499+ // extension is already initialized, but no overwrite permission
500+ Err ( TokenError :: ExtensionAlreadyInitialized . into ( ) )
501+ }
486502 }
487503
488504 /// If `extension_type` is an Account-associated ExtensionType that requires initialization on
@@ -498,14 +514,14 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
498514 }
499515 match extension_type {
500516 ExtensionType :: TransferFeeAmount => {
501- self . init_extension :: < TransferFeeAmount > ( ) . map ( |_| ( ) )
517+ self . init_extension :: < TransferFeeAmount > ( true ) . map ( |_| ( ) )
502518 }
503519 // ConfidentialTransfers are currently opt-in only, so this is a no-op for extra safety
504520 // on InitializeAccount
505521 ExtensionType :: ConfidentialTransferAccount => Ok ( ( ) ) ,
506522 #[ cfg( test) ]
507523 ExtensionType :: AccountPaddingTest => {
508- self . init_extension :: < AccountPaddingTest > ( ) . map ( |_| ( ) )
524+ self . init_extension :: < AccountPaddingTest > ( true ) . map ( |_| ( ) )
509525 }
510526 _ => unreachable ! ( ) ,
511527 }
@@ -932,19 +948,27 @@ mod test {
932948 let mut state = StateWithExtensionsMut :: < Mint > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
933949 // fail init account extension
934950 assert_eq ! (
935- state. init_extension:: <TransferFeeAmount >( ) ,
951+ state. init_extension:: <TransferFeeAmount >( true ) ,
936952 Err ( ProgramError :: InvalidAccountData ) ,
937953 ) ;
938954
939955 // success write extension
940956 let close_authority = OptionalNonZeroPubkey :: try_from ( Some ( Pubkey :: new ( & [ 1 ; 32 ] ) ) ) . unwrap ( ) ;
941- let extension = state. init_extension :: < MintCloseAuthority > ( ) . unwrap ( ) ;
957+ let extension = state. init_extension :: < MintCloseAuthority > ( true ) . unwrap ( ) ;
942958 extension. close_authority = close_authority;
943959 assert_eq ! (
944960 & state. get_extension_types( ) . unwrap( ) ,
945961 & [ ExtensionType :: MintCloseAuthority ]
946962 ) ;
947963
964+ // fail init extension when already initialized
965+ assert_eq ! (
966+ state. init_extension:: <MintCloseAuthority >( false ) ,
967+ Err ( ProgramError :: Custom (
968+ TokenError :: ExtensionAlreadyInitialized as u32
969+ ) )
970+ ) ;
971+
948972 // fail unpack as account, a mint extension was written
949973 assert_eq ! (
950974 StateWithExtensionsMut :: <Account >:: unpack_uninitialized( & mut buffer) ,
@@ -1030,7 +1054,7 @@ mod test {
10301054 let mut state = StateWithExtensionsMut :: < Mint > :: unpack ( & mut buffer) . unwrap ( ) ;
10311055 // init one more extension
10321056 let mint_transfer_fee = test_transfer_fee_config ( ) ;
1033- let new_extension = state. init_extension :: < TransferFeeConfig > ( ) . unwrap ( ) ;
1057+ let new_extension = state. init_extension :: < TransferFeeConfig > ( true ) . unwrap ( ) ;
10341058 new_extension. transfer_fee_config_authority =
10351059 mint_transfer_fee. transfer_fee_config_authority ;
10361060 new_extension. withdraw_withheld_authority = mint_transfer_fee. withdraw_withheld_authority ;
@@ -1063,7 +1087,7 @@ mod test {
10631087 // fail to init one more extension that does not fit
10641088 let mut state = StateWithExtensionsMut :: < Mint > :: unpack ( & mut buffer) . unwrap ( ) ;
10651089 assert_eq ! (
1066- state. init_extension:: <MintPaddingTest >( ) ,
1090+ state. init_extension:: <MintPaddingTest >( true ) ,
10671091 Err ( ProgramError :: InvalidAccountData ) ,
10681092 ) ;
10691093 }
@@ -1079,11 +1103,11 @@ mod test {
10791103 let mut state = StateWithExtensionsMut :: < Mint > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
10801104 // write extensions
10811105 let close_authority = OptionalNonZeroPubkey :: try_from ( Some ( Pubkey :: new ( & [ 1 ; 32 ] ) ) ) . unwrap ( ) ;
1082- let extension = state. init_extension :: < MintCloseAuthority > ( ) . unwrap ( ) ;
1106+ let extension = state. init_extension :: < MintCloseAuthority > ( true ) . unwrap ( ) ;
10831107 extension. close_authority = close_authority;
10841108
10851109 let mint_transfer_fee = test_transfer_fee_config ( ) ;
1086- let extension = state. init_extension :: < TransferFeeConfig > ( ) . unwrap ( ) ;
1110+ let extension = state. init_extension :: < TransferFeeConfig > ( true ) . unwrap ( ) ;
10871111 extension. transfer_fee_config_authority = mint_transfer_fee. transfer_fee_config_authority ;
10881112 extension. withdraw_withheld_authority = mint_transfer_fee. withdraw_withheld_authority ;
10891113 extension. withheld_amount = mint_transfer_fee. withheld_amount ;
@@ -1115,15 +1139,15 @@ mod test {
11151139
11161140 // write extensions in a different order
11171141 let mint_transfer_fee = test_transfer_fee_config ( ) ;
1118- let extension = state. init_extension :: < TransferFeeConfig > ( ) . unwrap ( ) ;
1142+ let extension = state. init_extension :: < TransferFeeConfig > ( true ) . unwrap ( ) ;
11191143 extension. transfer_fee_config_authority = mint_transfer_fee. transfer_fee_config_authority ;
11201144 extension. withdraw_withheld_authority = mint_transfer_fee. withdraw_withheld_authority ;
11211145 extension. withheld_amount = mint_transfer_fee. withheld_amount ;
11221146 extension. older_transfer_fee = mint_transfer_fee. older_transfer_fee ;
11231147 extension. newer_transfer_fee = mint_transfer_fee. newer_transfer_fee ;
11241148
11251149 let close_authority = OptionalNonZeroPubkey :: try_from ( Some ( Pubkey :: new ( & [ 1 ; 32 ] ) ) ) . unwrap ( ) ;
1126- let extension = state. init_extension :: < MintCloseAuthority > ( ) . unwrap ( ) ;
1150+ let extension = state. init_extension :: < MintCloseAuthority > ( true ) . unwrap ( ) ;
11271151 extension. close_authority = close_authority;
11281152
11291153 assert_eq ! (
@@ -1169,7 +1193,7 @@ mod test {
11691193 state. init_account_type ( ) . unwrap ( ) ;
11701194
11711195 // write padding
1172- let extension = state. init_extension :: < MintPaddingTest > ( ) . unwrap ( ) ;
1196+ let extension = state. init_extension :: < MintPaddingTest > ( true ) . unwrap ( ) ;
11731197 extension. padding1 = [ 1 ; 128 ] ;
11741198 extension. padding2 = [ 1 ; 48 ] ;
11751199 extension. padding3 = [ 1 ; 9 ] ;
@@ -1206,12 +1230,12 @@ mod test {
12061230 StateWithExtensionsMut :: < Account > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
12071231 // fail init mint extension
12081232 assert_eq ! (
1209- state. init_extension:: <TransferFeeConfig >( ) ,
1233+ state. init_extension:: <TransferFeeConfig >( true ) ,
12101234 Err ( ProgramError :: InvalidAccountData ) ,
12111235 ) ;
12121236 // success write extension
12131237 let withheld_amount = PodU64 :: from ( u64:: MAX ) ;
1214- let extension = state. init_extension :: < TransferFeeAmount > ( ) . unwrap ( ) ;
1238+ let extension = state. init_extension :: < TransferFeeAmount > ( true ) . unwrap ( ) ;
12151239 extension. withheld_amount = withheld_amount;
12161240
12171241 assert_eq ! (
@@ -1305,7 +1329,7 @@ mod test {
13051329 state. init_account_type ( ) . unwrap ( ) ;
13061330
13071331 // write padding
1308- let extension = state. init_extension :: < AccountPaddingTest > ( ) . unwrap ( ) ;
1332+ let extension = state. init_extension :: < AccountPaddingTest > ( true ) . unwrap ( ) ;
13091333 extension. 0 . padding1 = [ 2 ; 128 ] ;
13101334 extension. 0 . padding2 = [ 2 ; 48 ] ;
13111335 extension. 0 . padding3 = [ 2 ; 9 ] ;
@@ -1341,7 +1365,7 @@ mod test {
13411365 let mut state = StateWithExtensionsMut :: < Account > :: unpack ( & mut buffer) . unwrap ( ) ;
13421366 assert_eq ! ( state. base, TEST_ACCOUNT ) ;
13431367 assert_eq ! ( state. account_type[ 0 ] , AccountType :: Account as u8 ) ;
1344- state. init_extension :: < ImmutableOwner > ( ) . unwrap ( ) ; // just confirming initialization works
1368+ state. init_extension :: < ImmutableOwner > ( true ) . unwrap ( ) ; // just confirming initialization works
13451369
13461370 // account with buffer big enough for AccountType only
13471371 let mut buffer = TEST_ACCOUNT_SLICE . to_vec ( ) ;
@@ -1384,7 +1408,7 @@ mod test {
13841408 let mut state = StateWithExtensionsMut :: < Mint > :: unpack ( & mut buffer) . unwrap ( ) ;
13851409 assert_eq ! ( state. base, TEST_MINT ) ;
13861410 assert_eq ! ( state. account_type[ 0 ] , AccountType :: Mint as u8 ) ;
1387- state. init_extension :: < MintCloseAuthority > ( ) . unwrap ( ) ;
1411+ state. init_extension :: < MintCloseAuthority > ( true ) . unwrap ( ) ;
13881412
13891413 // mint with buffer big enough for AccountType only
13901414 let mut buffer = TEST_MINT_SLICE . to_vec ( ) ;
@@ -1499,7 +1523,7 @@ mod test {
14991523
15001524 // fail init extension
15011525 assert_eq ! (
1502- state. init_extension:: <TransferFeeConfig >( ) ,
1526+ state. init_extension:: <TransferFeeConfig >( true ) ,
15031527 Err ( ProgramError :: InvalidAccountData ) ,
15041528 ) ;
15051529
@@ -1514,7 +1538,7 @@ mod test {
15141538 state. base = TEST_MINT ;
15151539 state. pack_base ( ) ;
15161540 state. init_account_type ( ) . unwrap ( ) ;
1517- let extension = state. init_extension :: < MintPaddingTest > ( ) . unwrap ( ) ;
1541+ let extension = state. init_extension :: < MintPaddingTest > ( true ) . unwrap ( ) ;
15181542 assert_eq ! ( extension. padding1, [ 1 ; 128 ] ) ;
15191543 assert_eq ! ( extension. padding2, [ 2 ; 48 ] ) ;
15201544 assert_eq ! ( extension. padding3, [ 3 ; 9 ] ) ;
@@ -1526,7 +1550,9 @@ mod test {
15261550 ExtensionType :: get_account_len :: < Mint > ( & [ ExtensionType :: MintCloseAuthority ] ) ;
15271551 let mut buffer = vec ! [ 0 ; mint_size - 1 ] ;
15281552 let mut state = StateWithExtensionsMut :: < Mint > :: unpack_uninitialized ( & mut buffer) . unwrap ( ) ;
1529- let err = state. init_extension :: < MintCloseAuthority > ( ) . unwrap_err ( ) ;
1553+ let err = state
1554+ . init_extension :: < MintCloseAuthority > ( true )
1555+ . unwrap_err ( ) ;
15301556 assert_eq ! ( err, ProgramError :: InvalidAccountData ) ;
15311557
15321558 state. tlv_data [ 0 ] = 3 ;
@@ -1556,7 +1582,7 @@ mod test {
15561582 state. base = TEST_ACCOUNT ;
15571583 state. pack_base ( ) ;
15581584 state. init_account_type ( ) . unwrap ( ) ;
1559- state. init_extension :: < ImmutableOwner > ( ) . unwrap ( ) ;
1585+ state. init_extension :: < ImmutableOwner > ( true ) . unwrap ( ) ;
15601586
15611587 assert_eq ! (
15621588 get_first_extension_type( state. tlv_data) . unwrap( ) ,
0 commit comments