diff --git a/src/github.rs b/src/github.rs index 5d66295..bfb503a 100644 --- a/src/github.rs +++ b/src/github.rs @@ -55,6 +55,8 @@ impl Knowable { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize)] pub struct Asset { + /// (Asset) ID needed to download assets from private repositories + pub id: u64, pub name: String, pub size: u64, @@ -68,6 +70,7 @@ pub struct Asset { impl Asset> { fn known(self) -> Option { self.mime.known().map(|mime| Asset { + id: self.id, name: self.name, size: self.size, url: self.url, @@ -116,6 +119,11 @@ pub struct GitHubArgs { #[arg(short = 't', long)] pub tag: String, + /// Download Release Assets w/o supplying token, which doesn't work for + /// private repos. + #[arg(long)] + pub public: bool, + /// Filter asset names #[arg(trailing_var_arg = true)] pub filter: Vec, @@ -243,6 +251,26 @@ impl GitHub { }) } + /// Get the GitHub token if available + pub fn token(&self) -> Option<&str> { + self.args.token.as_deref() + } + + /// Returns true if the repository is not public (i.e., `public` flag is false). + pub const fn is_private(&self) -> bool { + !self.args.public + } + + /// Get the owner name + pub fn owner(&self) -> &str { + &self.args.owner + } + + /// Get the repo name + pub fn repo(&self) -> &str { + &self.args.repo + } + pub async fn assets(&self) -> Result> { let url = format!( "https://api.github.com/repos/{}/{}/releases/tags/{}", @@ -250,6 +278,7 @@ impl GitHub { ); let response = self.client.get(&url).send().await?; + let release: Release = response.json().await?; let assets = release diff --git a/src/http/server.rs b/src/http/server.rs index 407d900..e2aba99 100644 --- a/src/http/server.rs +++ b/src/http/server.rs @@ -49,11 +49,35 @@ impl Server { attempt.stop() }); + let mut client_builder = Client::builder().redirect(policy); + if github.is_private() { + // Build client with GitHub authentication if token is available + if let Some(token) = github.token() { + let mut headers = reqwest::header::HeaderMap::new(); + let mut auth_value = format!("token {token}") + .parse::() + .unwrap(); + auth_value.set_sensitive(true); + headers.insert(reqwest::header::AUTHORIZATION, auth_value); + headers.insert( + reqwest::header::USER_AGENT, + concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")) + .parse() + .unwrap(), + ); + headers.insert( + reqwest::header::ACCEPT, + "application/octet-stream".parse().unwrap(), + ); + client_builder = client_builder.default_headers(headers); + } + } + Ok(Self { listener, status, github, - client: Client::builder().redirect(policy).build()?, + client: client_builder.build()?, path, }) } diff --git a/src/http/service.rs b/src/http/service.rs index 215a8ac..9099061 100644 --- a/src/http/service.rs +++ b/src/http/service.rs @@ -45,6 +45,40 @@ impl Service { path, } } + + /// Build a reqwest request for an asset, handling private/public repos and method. + fn build_request( + gh: &GitHub, + asset: &Asset, + client: &Client, + method: &Method, + ) -> reqwest::RequestBuilder { + let url = if gh.is_private() { + format!( + "https://api.github.com/repos/{}/{}/releases/assets/{}", + gh.owner(), + gh.repo(), + asset.id + ) + } else { + asset.url.clone() + }; + + match method { + &Method::HEAD => client.head(url), + // Only HEAD and GET are expected; all others default to GET. + _ => client.get(url), + } + } + + fn inspect_response(response: &reqwest::Response) -> (reqwest::StatusCode, Option<&str>) { + let status_code = response.status(); + let content_length = response + .headers() + .get("content-length") + .and_then(|v| v.to_str().ok()); + (status_code, content_length) + } } impl hyper::service::Service> for Service { @@ -110,7 +144,21 @@ impl hyper::service::Service> for Service { None => return Ok(POWEROFF_EFI.reply(None, Type::Efi, EMPTY)), // Send the request (possibly redirecting...) - Some(asset) => (client.head(asset.url).send().await?, asset.mime), + Some(asset) => { + let request = + Self::build_request(&github, &asset, &client, &Method::HEAD); + + match request.send().await { + Ok(resp) => { + let (_status_code, _content_length) = + Self::inspect_response(&resp); + (resp, asset.mime) + } + Err(_e) => { + return Ok(EMPTY.reply(Code::BAD_GATEWAY, None, None)); + } + } + } } } @@ -122,9 +170,20 @@ impl hyper::service::Service> for Service { // Send the request (possibly redirecting...) Some(asset) => { - let response = client.get(asset.url).send().await?; - status.lock().await.update().downloading(remote); - (response, asset.mime) + let request = + Self::build_request(&github, &asset, &client, &Method::GET); + + match request.send().await { + Ok(resp) => { + let (_status_code, _content_length) = + Self::inspect_response(&resp); + status.lock().await.update().downloading(remote); + (resp, asset.mime) + } + Err(_e) => { + return Ok(EMPTY.reply(Code::BAD_GATEWAY, None, None)); + } + } } } }