Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 314 additions & 1 deletion crates/tokscale-cli/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use serde::{Deserialize, Serialize};
use std::fs;
use std::io::IsTerminal;
use std::io::Write;
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::thread;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

fn home_dir() -> Result<PathBuf> {
dirs::home_dir().context("Could not determine home directory")
Expand Down Expand Up @@ -53,6 +55,18 @@ fn get_credentials_path() -> Result<PathBuf> {
Ok(home_dir()?.join(".config/tokscale/credentials.json"))
}

fn get_source_id_path() -> Result<PathBuf> {
Ok(home_dir()?.join(".config/tokscale/source-id"))
}

fn get_source_id_lock_path() -> Result<PathBuf> {
Ok(home_dir()?.join(".config/tokscale/source-id.lock"))
}

const SOURCE_ID_LOCK_RETRY_DELAY: Duration = Duration::from_millis(25);
const SOURCE_ID_LOCK_STALE_AFTER: Duration = Duration::from_secs(2);
const SOURCE_ID_LOCK_MAX_WAIT: Duration = Duration::from_secs(10);

fn ensure_config_dir() -> Result<()> {
let config_dir = home_dir()?.join(".config/tokscale");

Expand Down Expand Up @@ -125,6 +139,249 @@ fn get_device_name() -> String {
format!("CLI on {}", hostname)
}

fn read_source_id(path: &Path) -> Option<String> {
let content = fs::read_to_string(path).ok()?;
let trimmed = content.trim();
if trimmed.is_empty() {
return None;
}
Some(trimmed.to_string())
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct SourceIdLockState {
pid: u32,
created_at_ms: u128,
}

fn current_unix_ms() -> u128 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis()
}

fn serialize_source_id_lock_state(state: SourceIdLockState) -> String {
format!("pid={}\ncreated_at_ms={}\n", state.pid, state.created_at_ms)
}

fn parse_source_id_lock_state(content: &str) -> Option<SourceIdLockState> {
let mut pid = None;
let mut created_at_ms = None;

for line in content.lines() {
let (key, value) = line.split_once('=')?;
match key.trim() {
"pid" => pid = value.trim().parse::<u32>().ok(),
"created_at_ms" => created_at_ms = value.trim().parse::<u128>().ok(),
_ => {}
}
}

Some(SourceIdLockState {
pid: pid?,
created_at_ms: created_at_ms?,
})
}

fn read_source_id_lock_state(path: &Path) -> Option<SourceIdLockState> {
let content = fs::read_to_string(path).ok()?;
parse_source_id_lock_state(&content)
}

fn lock_age(path: &Path, state: Option<SourceIdLockState>) -> Duration {
if let Some(state) = state {
let now_ms = current_unix_ms();
let age_ms = now_ms.saturating_sub(state.created_at_ms);
return Duration::from_millis(age_ms.min(u64::MAX as u128) as u64);
}

fs::metadata(path)
.and_then(|metadata| metadata.modified())
.ok()
.and_then(|modified| modified.elapsed().ok())
.unwrap_or_default()
}

fn lock_owner_is_alive(pid: u32) -> Option<bool> {
#[cfg(unix)]
{
std::process::Command::new("kill")
.args(["-0", &pid.to_string()])
.status()
.ok()
.map(|status| status.success())
}

#[cfg(windows)]
{
let output = std::process::Command::new("tasklist")
.args(["/FI", &format!("PID eq {}", pid)])
.output();

match output {
Ok(output) if output.status.success() => {
let stdout = String::from_utf8_lossy(&output.stdout);
Some(stdout.contains(&pid.to_string()) && !stdout.contains("No tasks are running"))
}
Ok(_) => None,
Err(_) => None,
}
}

#[cfg(not(any(unix, windows)))]
{
None
}
}

fn write_source_id_lock_state(mut file: fs::File, state: SourceIdLockState) -> Result<()> {
let payload = serialize_source_id_lock_state(state);
file.write_all(payload.as_bytes())?;
file.sync_all()?;
Ok(())
}

fn remove_source_id_lock_if_matches(path: &Path, expected: Option<SourceIdLockState>) -> bool {
let current_state = read_source_id_lock_state(path);
if current_state != expected {
return false;
}

match fs::remove_file(path) {
Ok(()) => true,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => false,
Err(_) => false,
}
}

struct SourceIdLock {
path: PathBuf,
state: SourceIdLockState,
}

impl Drop for SourceIdLock {
fn drop(&mut self) {
let _ = remove_source_id_lock_if_matches(&self.path, Some(self.state));
}
}

fn acquire_source_id_lock() -> Result<SourceIdLock> {
ensure_config_dir()?;
let lock_path = get_source_id_lock_path()?;
let deadline = Instant::now() + SOURCE_ID_LOCK_MAX_WAIT;

loop {
match fs::OpenOptions::new()
.write(true)
.create_new(true)
.open(&lock_path)
{
Ok(file) => {
let state = SourceIdLockState {
pid: std::process::id(),
created_at_ms: current_unix_ms(),
};

if let Err(err) = write_source_id_lock_state(file, state) {
let _ = fs::remove_file(&lock_path);
return Err(err);
}

return Ok(SourceIdLock {
path: lock_path,
state,
});
}
Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
let state = read_source_id_lock_state(&lock_path);
let age = lock_age(&lock_path, state);
let owner_is_dead = match state {
Some(lock_state) => match lock_owner_is_alive(lock_state.pid) {
Some(is_alive) => !is_alive,
None => age >= SOURCE_ID_LOCK_STALE_AFTER,
},
None => true,
};

if owner_is_dead && age >= SOURCE_ID_LOCK_STALE_AFTER {
let _ = remove_source_id_lock_if_matches(&lock_path, state);
continue;
}

if Instant::now() >= deadline {
break;
}

thread::sleep(SOURCE_ID_LOCK_RETRY_DELAY);
}
Err(err) => return Err(err.into()),
}
}

anyhow::bail!("Could not acquire source ID lock after waiting for stale lock cleanup");
}

fn write_source_id(path: &Path, source_id: &str) -> Result<()> {
let temp_path = path.with_extension(format!("tmp-{}", std::process::id()));

#[cfg(unix)]
{
use std::os::unix::fs::OpenOptionsExt;

let mut file = fs::OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.mode(0o600)
.open(&temp_path)?;
file.write_all(source_id.as_bytes())?;
file.write_all(b"\n")?;
}

#[cfg(not(unix))]
{
fs::write(&temp_path, format!("{source_id}\n"))?;
}

fs::rename(&temp_path, path)?;
Ok(())
}

pub fn get_submit_source_id() -> Result<Option<String>> {
if let Some(source_id) = std::env::var_os("TOKSCALE_SOURCE_ID") {
let trimmed = source_id.to_string_lossy().trim().to_string();
if !trimmed.is_empty() {
return Ok(Some(trimmed));
}
}

ensure_config_dir()?;
let path = get_source_id_path()?;

if let Some(existing) = read_source_id(&path) {
return Ok(Some(existing));
}

let _lock = acquire_source_id_lock()?;

if let Some(existing) = read_source_id(&path) {
return Ok(Some(existing));
}

let source_id = uuid::Uuid::new_v4().to_string();
write_source_id(&path, &source_id)?;
Ok(Some(source_id))
}

pub fn get_submit_source_name() -> Option<String> {
std::env::var("TOKSCALE_SOURCE_NAME")
.ok()
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty())
.or_else(|| Some(get_device_name()))
}

#[cfg(target_os = "linux")]
fn has_non_empty_env_var(name: &str) -> bool {
std::env::var_os(name).is_some_and(|value| !value.is_empty())
Expand Down Expand Up @@ -508,6 +765,62 @@ mod tests {
}
}

#[test]
#[serial]
fn test_get_submit_source_id_uses_env_override() {
let temp_dir = TempDir::new().unwrap();
unsafe {
env::set_var("HOME", temp_dir.path());
env::set_var("TOKSCALE_SOURCE_ID", " source-from-env ");
}

let source_id = get_submit_source_id().unwrap();

assert_eq!(source_id.as_deref(), Some("source-from-env"));
assert!(!get_source_id_path().unwrap().exists());

unsafe {
env::remove_var("TOKSCALE_SOURCE_ID");
env::remove_var("HOME");
}
}

#[test]
#[serial]
fn test_get_submit_source_id_persists_generated_value() {
let temp_dir = TempDir::new().unwrap();
unsafe {
env::set_var("HOME", temp_dir.path());
env::remove_var("TOKSCALE_SOURCE_ID");
}

let first = get_submit_source_id().unwrap();
let second = get_submit_source_id().unwrap();
let path = get_source_id_path().unwrap();

assert!(path.exists());
assert_eq!(first, second);
assert_eq!(read_source_id(&path), first);

unsafe {
env::remove_var("HOME");
}
}

#[test]
#[serial]
fn test_get_submit_source_name_uses_trimmed_env_override() {
unsafe {
env::set_var("TOKSCALE_SOURCE_NAME", " Work Laptop ");
}

assert_eq!(get_submit_source_name().as_deref(), Some("Work Laptop"));

unsafe {
env::remove_var("TOKSCALE_SOURCE_NAME");
}
}

#[test]
#[serial]
fn test_save_credentials() {
Expand Down
Loading
Loading