diff --git a/src/optimized_transaction.rs b/src/optimized_transaction.rs index 4fdeddb..2e030c3 100644 --- a/src/optimized_transaction.rs +++ b/src/optimized_transaction.rs @@ -6,6 +6,7 @@ use crate::types::{ SmartTransactionConfig, Timeout, }; use crate::Helius; +use std::collections::HashSet; use std::str::FromStr; use std::sync::Arc; @@ -64,6 +65,42 @@ const MIN_TIP_LAMPORTS_DUAL: u64 = 200_000; // 0.0002 SOL /// Minimum tip: 0.000005 SOL (5,000 lamports). const MIN_TIP_LAMPORTS_SWQOS: u64 = 5_000; // 0.000005 SOL +fn collect_unique_signers(signers: &[Arc], fee_payer: Option<&Arc>) -> Vec> { + let mut all_signers: Vec> = Vec::with_capacity(signers.len() + usize::from(fee_payer.is_some())); + let mut seen: HashSet = HashSet::with_capacity(all_signers.capacity()); + + if let Some(fee_payer) = fee_payer { + if seen.insert(fee_payer.pubkey()) { + all_signers.push(fee_payer.clone()); + } + } + + for signer in signers { + if seen.insert(signer.pubkey()) { + all_signers.push(signer.clone()); + } + } + + all_signers +} + +fn collect_unique_keypair_refs<'a>(signers: &'a [Keypair], fee_payer: &'a Keypair) -> Vec<&'a Keypair> { + let mut all_signers: Vec<&Keypair> = Vec::with_capacity(signers.len() + 1); + let mut seen: HashSet = HashSet::with_capacity(all_signers.capacity()); + + if seen.insert(fee_payer.pubkey()) { + all_signers.push(fee_payer); + } + + for signer in signers { + if seen.insert(signer.pubkey()) { + all_signers.push(signer); + } + } + + all_signers +} + /// URL to fetch current Jito bundle tip floor prices. /// /// This endpoint returns the minimum tip amounts required for different @@ -355,8 +392,8 @@ impl Helius { { return Ok(txt_sig); } - if status.err.is_some() { - return Err(HeliusError::TransactionError(status.err.unwrap())); + if let Some(err) = status.err { + return Err(HeliusError::TransactionError(err)); } } None => { @@ -451,12 +488,13 @@ impl Helius { final_instructions.push(compute_budget_ix); // Get the optimal compute units + let all_signers: Vec> = collect_unique_signers(&config.signers, config.fee_payer.as_ref()); let units: Option = self .get_compute_units( updated_instructions, payer_pubkey, config.lookup_tables.clone().unwrap_or_default(), - Some(&config.signers), + Some(&all_signers), ) .await?; @@ -491,26 +529,9 @@ impl Helius { let v0_message: v0::Message = v0::Message::try_compile(&payer_pubkey, &final_instructions, lookup_tables, recent_blockhash)?; let versioned_message: VersionedMessage = VersionedMessage::V0(v0_message); - - let all_signers: Vec> = if let Some(fee_payer) = config.fee_payer.as_ref() { - let mut all_signers = config.signers.clone(); - if !all_signers.iter().any(|signer| signer.pubkey() == fee_payer.pubkey()) { - all_signers.push(fee_payer.clone()); - } - all_signers - } else { - config.signers.clone() - }; - - let signatures: Vec = all_signers - .iter() - .map(|signer| signer.try_sign_message(versioned_message.serialize().as_slice())) - .collect::, _>>()?; - - let versioned_transaction = VersionedTransaction { - signatures, - message: versioned_message, - }; + let versioned_transaction: VersionedTransaction = + VersionedTransaction::try_new(versioned_message, all_signers.as_slice()) + .map_err(|e| HeliusError::InvalidInput(format!("Signing error: {:?}", e)))?; Ok(( SmartTransaction::Versioned(versioned_transaction), @@ -518,11 +539,7 @@ impl Helius { )) } else { let mut tx: Transaction = Transaction::new_with_payer(&final_instructions, Some(&payer_pubkey)); - tx.try_partial_sign(&config.signers, recent_blockhash)?; - - if let Some(fee_payer) = config.fee_payer.as_ref() { - tx.try_partial_sign(&[fee_payer], recent_blockhash)?; - } + tx.try_partial_sign(&all_signers, recent_blockhash)?; Ok((SmartTransaction::Legacy(tx), last_valid_block_hash)) } @@ -637,16 +654,8 @@ impl Helius { let versioned_message: VersionedMessage = VersionedMessage::V0(v0_message); let transaction: VersionedTransaction = if let Some(keypairs) = keypairs { - let mut tx = VersionedTransaction { - signatures: vec![Signature::default(); keypairs.len()], - message: versioned_message.clone(), - }; - - for (i, keypair) in keypairs.iter().enumerate() { - tx.signatures[i] = keypair.sign_message(&versioned_message.serialize()); - } - - tx + VersionedTransaction::try_new(versioned_message, keypairs) + .map_err(|e| HeliusError::InvalidInput(format!("Signing error: {:?}", e)))? } else { VersionedTransaction { signatures: vec![], @@ -759,14 +768,7 @@ impl Helius { let mut test_instructions: Vec = final_instructions.clone(); test_instructions.extend(create_config.instructions.clone()); - let mut all_signers: Vec<&Keypair> = vec![&fee_payer]; - let mut seen: std::collections::HashSet = std::collections::HashSet::new(); - seen.insert(fee_payer.pubkey()); - for kp in &keypairs { - if seen.insert(kp.pubkey()) { - all_signers.push(kp); - } - } + let all_signers: Vec<&Keypair> = collect_unique_keypair_refs(&keypairs, &fee_payer); let units: Option = self .get_compute_units_thread_safe( @@ -804,29 +806,13 @@ impl Helius { )?; let versioned_message: VersionedMessage = VersionedMessage::V0(message); - - let fee_payer_copy: Keypair = fee_payer.insecure_clone(); - let mut all_signers: Vec = vec![fee_payer_copy]; - all_signers.extend(keypairs.into_iter().filter(|k| k.pubkey() != fee_payer.pubkey())); - - let mut tx: VersionedTransaction = VersionedTransaction { - signatures: vec![Signature::default(); all_signers.len()], - message: versioned_message.clone(), - }; - - // Sign message with all keypairs - for (i, keypair) in all_signers.iter().enumerate() { - tx.signatures[i] = keypair.sign_message(&versioned_message.serialize()); - } + let tx: VersionedTransaction = VersionedTransaction::try_new(versioned_message, all_signers.as_slice()) + .map_err(|e| HeliusError::InvalidInput(format!("Signing error: {:?}", e)))?; SmartTransaction::Versioned(tx) } else { let mut tx: Transaction = Transaction::new_with_payer(&final_instructions, Some(&fee_payer.pubkey())); - - let mut signers: Vec<&Keypair> = vec![&fee_payer]; - signers.extend(keypairs.iter().filter(|k| k.pubkey() != fee_payer.pubkey())); - - tx.sign(&signers, recent_blockhash); + tx.sign(&all_signers, recent_blockhash); SmartTransaction::Legacy(tx) }; @@ -1210,3 +1196,190 @@ impl Helius { } } } + +#[cfg(test)] +mod tests { + use super::{collect_unique_keypair_refs, collect_unique_signers}; + use solana_sdk::{ + hash::Hash, + instruction::{AccountMeta, Instruction}, + message::{v0, VersionedMessage}, + pubkey::Pubkey, + signature::{Keypair, Signature, Signer}, + transaction::{Transaction, VersionedTransaction}, + }; + use std::sync::Arc; + + fn build_versioned_message( + payer: &Keypair, + writable_signer: &Keypair, + readonly_signer: &Keypair, + ) -> VersionedMessage { + let instruction = Instruction { + program_id: Pubkey::new_unique(), + accounts: vec![ + AccountMeta::new(writable_signer.pubkey(), true), + AccountMeta::new_readonly(readonly_signer.pubkey(), true), + ], + data: vec![], + }; + + VersionedMessage::V0( + v0::Message::try_compile(&payer.pubkey(), &[instruction], &[], Hash::new_unique()).unwrap(), + ) + } + + #[test] + fn collect_unique_signers_includes_fee_payer_once() { + let fee_payer: Arc = Arc::new(Keypair::new()); + let signer: Arc = Arc::new(Keypair::new()); + + let signers: Vec> = vec![signer.clone(), fee_payer.clone(), signer.clone()]; + let all_signers = collect_unique_signers(&signers, Some(&fee_payer)); + + let signer_pubkeys: Vec = all_signers.iter().map(|signer| signer.pubkey()).collect(); + assert_eq!(signer_pubkeys, vec![fee_payer.pubkey(), signer.pubkey()]); + } + + #[test] + fn collect_unique_keypair_refs_includes_fee_payer_once() { + let fee_payer = Keypair::new(); + let signer = Keypair::new(); + let signers = vec![ + signer.insecure_clone(), + fee_payer.insecure_clone(), + signer.insecure_clone(), + ]; + + let all_signers = collect_unique_keypair_refs(&signers, &fee_payer); + let signer_pubkeys: Vec = all_signers.iter().map(|signer| signer.pubkey()).collect(); + + assert_eq!(signer_pubkeys, vec![fee_payer.pubkey(), signer.pubkey()]); + } + + #[test] + fn versioned_try_new_reorders_arc_signers_to_match_message() { + let fee_payer = Keypair::new(); + let writable_signer = Keypair::new(); + let readonly_signer = Keypair::new(); + let fee_payer_signer: Arc = Arc::new(fee_payer.insecure_clone()); + let writable_signer_arc: Arc = Arc::new(writable_signer.insecure_clone()); + let readonly_signer_arc: Arc = Arc::new(readonly_signer.insecure_clone()); + + let message = build_versioned_message(&fee_payer, &writable_signer, &readonly_signer); + let signers: Vec> = vec![readonly_signer_arc.clone(), writable_signer_arc.clone()]; + let all_signers = collect_unique_signers(&signers, Some(&fee_payer_signer)); + let tx = VersionedTransaction::try_new(message.clone(), all_signers.as_slice()).unwrap(); + let message_bytes = message.serialize(); + + assert_eq!( + tx.signatures, + vec![ + Signature::from(fee_payer.sign_message(&message_bytes)), + Signature::from(writable_signer.sign_message(&message_bytes)), + Signature::from(readonly_signer.sign_message(&message_bytes)), + ] + ); + } + + #[test] + fn manual_fee_payer_appended_signature_order_fails_verification() { + let fee_payer = Keypair::new(); + let writable_signer = Keypair::new(); + let readonly_signer = Keypair::new(); + let message = build_versioned_message(&fee_payer, &writable_signer, &readonly_signer); + let message_bytes = message.serialize(); + + // This mirrors the pre-fix separate fee payer path: caller signers first, fee payer appended last. + let manual_signatures = [ + readonly_signer.sign_message(&message_bytes), + writable_signer.sign_message(&message_bytes), + fee_payer.sign_message(&message_bytes), + ]; + + let verification_results: Vec = manual_signatures + .iter() + .zip(message.static_account_keys().iter()) + .map(|(signature, pubkey)| signature.verify(pubkey.as_ref(), &message_bytes)) + .collect(); + + assert_eq!(verification_results, vec![false, true, false]); + } + + #[test] + fn manual_non_payer_caller_order_can_fail_verification() { + let fee_payer = Keypair::new(); + let writable_signer = Keypair::new(); + let readonly_signer = Keypair::new(); + let message = build_versioned_message(&fee_payer, &writable_signer, &readonly_signer); + let message_bytes = message.serialize(); + + // This mirrors the pre-fix seed path: fee payer first, remaining signers left in caller order. + let manual_signatures = [ + fee_payer.sign_message(&message_bytes), + readonly_signer.sign_message(&message_bytes), + writable_signer.sign_message(&message_bytes), + ]; + + let verification_results: Vec = manual_signatures + .iter() + .zip(message.static_account_keys().iter()) + .map(|(signature, pubkey)| signature.verify(pubkey.as_ref(), &message_bytes)) + .collect(); + + assert_eq!(verification_results, vec![true, false, false]); + } + + #[test] + fn versioned_try_new_reorders_keypair_signers_to_match_message() { + let fee_payer = Keypair::new(); + let writable_signer = Keypair::new(); + let readonly_signer = Keypair::new(); + + let message = build_versioned_message(&fee_payer, &writable_signer, &readonly_signer); + let signers = vec![readonly_signer.insecure_clone(), writable_signer.insecure_clone()]; + let all_signers = collect_unique_keypair_refs(&signers, &fee_payer); + let tx = VersionedTransaction::try_new(message.clone(), all_signers.as_slice()).unwrap(); + let message_bytes = message.serialize(); + + assert_eq!( + tx.signatures, + vec![ + Signature::from(fee_payer.sign_message(&message_bytes)), + Signature::from(writable_signer.sign_message(&message_bytes)), + Signature::from(readonly_signer.sign_message(&message_bytes)), + ] + ); + } + + #[test] + fn legacy_try_partial_sign_reorders_keypairs_to_match_message() { + let fee_payer = Keypair::new(); + let writable_signer = Keypair::new(); + let readonly_signer = Keypair::new(); + let recent_blockhash = Hash::new_unique(); + let instruction = Instruction { + program_id: Pubkey::new_unique(), + accounts: vec![ + AccountMeta::new(writable_signer.pubkey(), true), + AccountMeta::new_readonly(readonly_signer.pubkey(), true), + ], + data: vec![], + }; + let mut tx = Transaction::new_with_payer(&[instruction], Some(&fee_payer.pubkey())); + let signers = vec![readonly_signer.insecure_clone(), writable_signer.insecure_clone()]; + let all_signers = collect_unique_keypair_refs(&signers, &fee_payer); + + tx.try_partial_sign(&all_signers, recent_blockhash).unwrap(); + + let message_bytes = tx.message_data(); + assert_eq!( + tx.signatures, + vec![ + fee_payer.sign_message(&message_bytes), + writable_signer.sign_message(&message_bytes), + readonly_signer.sign_message(&message_bytes), + ] + ); + } +}