Skip to content

Commit d75ae07

Browse files
authored
chore: fix s3 and tests race condition, bump deps (#78)
1 parent c8cb1ac commit d75ae07

File tree

10 files changed

+384
-734
lines changed

10 files changed

+384
-734
lines changed

Cargo.lock

Lines changed: 285 additions & 693 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,29 @@
22
authors = ["Alex Chi <iskyzh@gmail.com>"]
33
edition = "2018"
44
name = "mirror-intel"
5+
license = "Apache-2.0"
56
version = "0.1.0"
67
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
78

89
[dependencies]
9-
actix-http = "3.2"
10-
actix-web = { version = "4.1", features = ["macros"] }
11-
aws-config = { version = "1.8", features = ["behavior-version-latest"] }
12-
aws-sdk-s3 = "1.123"
10+
actix-http = "3.12"
11+
actix-web = { version = "4.13", features = ["macros"] }
12+
# NOTE: avoid massive deps used by aws-config
13+
# aws-config = { version = "1", features = ["behavior-version-latest"] }
14+
aws-sdk-s3 = { version = "1.122.0", features = ["behavior-version-latest"] }
1315
bytes = "1.11"
1416
figment = { version = "0.10", features = ["toml"] }
1517
futures = "0.3"
1618
futures-util = "0.3"
1719
lazy_static = "1.4"
18-
paste = "1"
1920
percent-encoding = "2.1"
2021
prometheus = "0.14"
2122
regex = "1"
22-
reqwest = { version = "0.11", features = ["stream"] }
23+
reqwest = { version = "0.13", default-features = false, features = ["stream"] }
2324
rstest = "0.26"
2425
serde = "1.0"
2526
thiserror = "2.0"
26-
tokio = { version = "1.0", features = ["full"] }
27+
tokio = { version = "1", features = ["full"] }
2728
tracing = "0.1"
2829
tracing-actix-web = "0.7"
2930
tracing-appender = "0.2"
@@ -33,6 +34,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
3334
url = "2.2"
3435

3536
[dev-dependencies]
36-
tempdir = "0.3"
37+
tempfile = "3"
3738
httpmock = "0.8"
3839
figment = { version = "0.10", features = ["test"] }
40+
serial_test = "3"

Rocket.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,4 @@ name = "jCloud S3"
4949
endpoint = "https://s3.jcloud.sjtu.edu.cn"
5050
website_endpoint = "https://s3.jcloud.sjtu.edu.cn"
5151
bucket = "899a892efef34b1b944a19981040f55b-oss01"
52+
sentinel_object_key = "sjtug-internal/mirror-intel/releases/download/v0.1.35/mirror-intel.tar.gz"

src/artifacts.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -391,14 +391,13 @@ pub async fn download_artifacts(
391391
#[cfg(test)]
392392
mod tests {
393393
use httpmock::MockServer;
394-
use tempdir::TempDir;
395394
use tokio::fs;
396395

397396
use super::*;
398397

399398
#[tokio::test]
400399
async fn must_download_payload_to_memory() {
401-
let tmp_dir = TempDir::new("intel").unwrap();
400+
let tmp_dir = tempfile::Builder::new().prefix("intel").tempdir().unwrap();
402401
let config = Config {
403402
buffer_path: tmp_dir.path().to_path_buf(),
404403
file_threshold_mb: 1,
@@ -430,7 +429,7 @@ mod tests {
430429

431430
#[tokio::test]
432431
async fn must_download_payload_to_file() {
433-
let tmp_dir = TempDir::new("intel").unwrap();
432+
let tmp_dir = tempfile::Builder::new().prefix("intel").tempdir().unwrap();
434433
let config = Config {
435434
buffer_path: tmp_dir.path().to_path_buf(),
436435
file_threshold_mb: 0,
@@ -468,7 +467,7 @@ mod tests {
468467

469468
#[tokio::test]
470469
async fn must_reject_large_payload() {
471-
let tmp_dir = TempDir::new("intel").unwrap();
470+
let tmp_dir = tempfile::Builder::new().prefix("intel").tempdir().unwrap();
472471
let config = Config {
473472
buffer_path: tmp_dir.path().to_path_buf(),
474473
file_threshold_mb: 0,

src/common.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ pub struct S3Config {
212212
pub website_endpoint: String,
213213
/// Bucket name.
214214
pub bucket: String,
215+
/// [Test only] Object key used by get-object to check S3 availability.
216+
pub sentinel_object_key: Option<String>,
215217
}
216218

217219
/// Configuration for Github Release endpoint.
@@ -369,13 +371,16 @@ pub fn collect_config() -> Config {
369371
#[cfg(test)]
370372
mod tests {
371373
use figment::Jail;
374+
use serial_test::serial;
372375

373376
use crate::common::{
374377
collect_config, EndpointOverride, Endpoints, GithubReleaseConfig, S3Config,
375378
};
376379
use crate::Config;
377380

381+
#[allow(clippy::result_large_err)]
378382
#[test]
383+
#[serial(cwd_env)]
379384
fn must_collect_config() {
380385
const MIRROR_INTEL_TOML: &str = include_str!("../tests/config/mirror-intel.toml");
381386
const ROCKET_TOML: &str = include_str!("../tests/config/Rocket.toml");
@@ -430,6 +435,10 @@ mod tests {
430435
endpoint: "https://s3.jcloud.sjtu.edu.cn".into(),
431436
website_endpoint: "https://s3.jcloud.sjtu.edu.cn".into(),
432437
bucket: "899a892efef34b1b944a19981040f55b-oss01".into(),
438+
sentinel_object_key: Some(
439+
"sjtug-internal/mirror-intel/releases/download/v0.1.35/mirror-intel.tar.gz"
440+
.into(),
441+
),
433442
},
434443
user_agent: "mirror-intel / 0.1 (siyuan.internal.sjtug.org)".into(),
435444
file_threshold_mb: 4,

src/error.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use thiserror::Error;
77

88
type PutObjectSdkError =
99
aws_sdk_s3::error::SdkError<aws_sdk_s3::operation::put_object::PutObjectError>;
10+
type GetObjectSdkError =
11+
aws_sdk_s3::error::SdkError<aws_sdk_s3::operation::get_object::GetObjectError>;
1012
type ListObjectsSdkError =
1113
aws_sdk_s3::error::SdkError<aws_sdk_s3::operation::list_objects::ListObjectsError>;
1214

@@ -30,6 +32,8 @@ pub enum Error {
3032
InvalidRequest(()),
3133
#[error("Put Object Error {0}")]
3234
PutObjectError(Box<PutObjectSdkError>),
35+
#[error("Get Object Error {0}")]
36+
GetObjectsError(Box<GetObjectSdkError>),
3337
#[error("List Objects Error {0}")]
3438
ListObjectsError(Box<ListObjectsSdkError>),
3539
#[error("Timeout")]
@@ -44,6 +48,11 @@ impl From<PutObjectSdkError> for Error {
4448
Self::PutObjectError(Box::new(error))
4549
}
4650
}
51+
impl From<GetObjectSdkError> for Error {
52+
fn from(error: GetObjectSdkError) -> Self {
53+
Self::GetObjectsError(Box::new(error))
54+
}
55+
}
4756
impl From<ListObjectsSdkError> for Error {
4857
fn from(error: ListObjectsSdkError) -> Self {
4958
Self::ListObjectsError(Box::new(error))

src/repos.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,10 +450,12 @@ mod tests {
450450
use actix_web::test::{call_service, init_service, TestRequest};
451451
use actix_web::App;
452452
use figment::providers::{Format, Toml};
453+
use figment::util::map;
453454
use figment::Figment;
454455
use httpmock::MockServer;
455456
use reqwest::ClientBuilder;
456457
use rstest::rstest;
458+
use serial_test::serial;
457459
use tokio::sync::mpsc::{channel, Receiver};
458460
use url::Url;
459461

@@ -488,6 +490,14 @@ mod tests {
488490
let figment = Figment::new()
489491
.join(("address", "127.0.0.1"))
490492
.join(("port", 8000))
493+
.join(("concurrent_download", 512))
494+
.join(("max_pending_task", 16384))
495+
.join((
496+
"endpoints",
497+
map!["sjtug_internal" => "https://github.com/sjtug"],
498+
))
499+
.join(("s3.name", "Placeholder S3"))
500+
.join(("s3.endpoint", server.base_url()))
491501
.join(("s3.website_endpoint", server.base_url()))
492502
.join(("s3.bucket", "bucket"))
493503
.join(("direct_stream_size_kb", 0))
@@ -591,6 +601,7 @@ mod tests {
591601
#[case(Method::HEAD, missing_object(), StatusCode::FOUND, | o: & Task, _c: & Config | o.upstream_url())]
592602
#[case(Method::GET, forbidden_object(), StatusCode::MOVED_PERMANENTLY, | o: & Task, _c: & Config | o.upstream_url())]
593603
#[case(Method::HEAD, forbidden_object(), StatusCode::MOVED_PERMANENTLY, | o: & Task, _c: & Config | o.upstream_url())]
604+
#[serial(cwd_env)]
594605
#[tokio::test]
595606
async fn test_get_head(
596607
#[case] method: Method,
@@ -632,6 +643,7 @@ mod tests {
632643
#[case("/pytorch-wheels", is_no_route_for("pytorch-wheels"))]
633644
#[case("/pytorch-wheels/?mirror_intel_list", is_index_for("pytorch-wheels"))]
634645
#[case("/pytorch-wheels?mirror_intel_list", is_index_for("pytorch-wheels"))]
646+
#[serial(cwd_env)]
635647
#[tokio::test]
636648
async fn test_index_list_page(#[case] url: &str, #[case] assert_f: impl FnOnce(&str)) {
637649
let (service, _config, _rx, _server) = make_service().await;
@@ -642,6 +654,7 @@ mod tests {
642654
assert_f(text);
643655
}
644656

657+
#[serial(cwd_env)]
645658
#[tokio::test]
646659
async fn test_url_segment() {
647660
// this case is to test if we could process escaped URL correctly
@@ -664,6 +677,7 @@ mod tests {
664677
);
665678
}
666679

680+
#[serial(cwd_env)]
667681
#[tokio::test]
668682
async fn test_url_segment_fail() {
669683
// this case is to test if we could process escaped URL correctly
@@ -682,6 +696,7 @@ mod tests {
682696
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
683697
}
684698

699+
#[serial(cwd_env)]
685700
#[tokio::test]
686701
async fn test_url_segment_query() {
687702
// this case is to test if we could process escaped URL correctly
@@ -741,6 +756,7 @@ mod tests {
741756
assert_eq!(task.origin, "https://storage.googleapis.com/");
742757
}
743758

759+
#[serial(cwd_env)]
744760
#[tokio::test]
745761
async fn test_proxy_head() {
746762
// if an object doesn't exist in s3, we should temporarily redirect users to upstream

src/storage.rs

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
//! S3 storage backend.
22
use std::time::Duration;
33

4-
use aws_config::BehaviorVersion;
5-
use aws_sdk_s3::config::Region;
4+
use aws_sdk_s3::config::{BehaviorVersion, Region};
65
use aws_sdk_s3::Client as S3Client;
76
use tokio::time::timeout;
87

@@ -17,19 +16,11 @@ fn s3_region(s3_config: &S3Config) -> Region {
1716
///
1817
/// The default credential provider is used.
1918
async fn get_s3_client(s3_config: &S3Config) -> S3Client {
20-
let shared_config = aws_config::defaults(BehaviorVersion::latest())
21-
.region(s3_region(s3_config))
22-
.load()
23-
.await;
24-
25-
let mut s3_builder = aws_sdk_s3::Config::builder()
19+
let s3_builder = aws_sdk_s3::Config::builder()
2620
.region(s3_region(s3_config))
2721
.endpoint_url(s3_config.endpoint.clone())
28-
.behavior_version(BehaviorVersion::latest())
2922
.force_path_style(true);
3023

31-
s3_builder.set_credentials_provider(shared_config.credentials_provider());
32-
3324
S3Client::from_conf(s3_builder.build())
3425
}
3526

@@ -41,7 +32,6 @@ pub fn get_anonymous_s3_client(s3_config: &S3Config) -> S3Client {
4132
aws_sdk_s3::Config::builder()
4233
.region(s3_region(s3_config))
4334
.endpoint_url(s3_config.endpoint.clone())
44-
.behavior_version(BehaviorVersion::latest())
4535
.force_path_style(true)
4636
.allow_no_auth()
4737
.build(),
@@ -74,11 +64,28 @@ pub async fn check_s3(s3_config: &S3Config) -> Result<()> {
7464
timeout(Duration::from_secs(1), async move {
7565
let s3_client = get_s3_client(s3_config).await;
7666

77-
s3_client
78-
.list_objects()
79-
.bucket(s3_config.bucket.clone())
80-
.send()
81-
.await?;
67+
// s3_client
68+
// .list_objects()
69+
// .bucket(s3_config.bucket.clone())
70+
// .send()
71+
// .await?;
72+
if let Some(sentinel_object_key) = &s3_config.sentinel_object_key {
73+
s3_client
74+
.get_object()
75+
.bucket(s3_config.bucket.clone())
76+
.key(sentinel_object_key.clone())
77+
.range("bytes=0-0")
78+
.send()
79+
.await?;
80+
}
81+
// NOTE: this can be too heavy for jCloud S3, thus we check only sentinel object instead of listing all objects.
82+
else {
83+
s3_client
84+
.list_objects()
85+
.bucket(s3_config.bucket.clone())
86+
.send()
87+
.await?;
88+
}
8289

8390
Ok::<(), Error>(())
8491
})

src/utils.rs

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use actix_web::body::{BodyStream, SizedStream};
22
use actix_web::http::header::ContentType;
3-
use actix_web::http::{header, StatusCode, Uri};
3+
use actix_web::http::{StatusCode, Uri};
44
use actix_web::{HttpResponse, Responder};
55
use futures::stream::TryStreamExt;
66
use tracing::debug;
@@ -115,15 +115,19 @@ impl IntelObject {
115115
if resp.status().is_success() {
116116
if let Some(content_length) = resp.content_length() {
117117
if content_length <= below_size_kb * 1024 {
118-
let content_type = resp.headers().get(header::CONTENT_TYPE).cloned();
118+
let content_type = resp.headers().get("content-type").cloned();
119119
let text = resp.text().await?;
120120
let text = f(text);
121121

122122
return Ok(if let Some(content_type) = content_type {
123-
HttpResponse::Ok()
124-
.content_type(content_type)
125-
.body(text)
126-
.into()
123+
if let Ok(content_type) = content_type.to_str() {
124+
HttpResponse::Ok()
125+
.content_type(content_type)
126+
.body(text)
127+
.into()
128+
} else {
129+
HttpResponse::Ok().body(text).into()
130+
}
127131
} else {
128132
HttpResponse::Ok().body(text).into()
129133
});
@@ -143,10 +147,12 @@ impl IntelObject {
143147
}
144148
};
145149

146-
let code = upstream_resp.status().normalize();
150+
let code = upstream_resp.status().as_u16().normalize();
147151
let mut resp = HttpResponse::build(code);
148-
if let Some(content_type) = upstream_resp.headers().get(header::CONTENT_TYPE) {
149-
resp.content_type(content_type);
152+
if let Some(content_type) = upstream_resp.headers().get("content-type") {
153+
if let Ok(content_type) = content_type.to_str() {
154+
resp.content_type(content_type);
155+
}
150156
}
151157
let content_length = upstream_resp.content_length();
152158
let stream = upstream_resp
@@ -220,10 +226,16 @@ trait StatusCodeExt {
220226

221227
impl StatusCodeExt for StatusCode {
222228
fn normalize(self) -> StatusCode {
223-
if self.as_u16() == 499 {
224-
Self::NOT_FOUND
229+
self.as_u16().normalize()
230+
}
231+
}
232+
233+
impl StatusCodeExt for u16 {
234+
fn normalize(self) -> StatusCode {
235+
if self == 499 {
236+
StatusCode::NOT_FOUND
225237
} else {
226-
self
238+
StatusCode::from_u16(self).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
227239
}
228240
}
229241
}
@@ -254,6 +266,7 @@ mod tests {
254266
endpoint: "http://localhost:8081".to_string(),
255267
website_endpoint: server.base_url(),
256268
bucket: "bucket".to_string(),
269+
sentinel_object_key: None,
257270
},
258271
..Default::default()
259272
};

tests/config/Rocket.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[default]
22
address = "0.0.0.0"
33
concurrent_download = 256
4+
max_pending_task = 16384
45
user_agent = "mirror-intel / 0.1 (siyuan.internal.sjtug.org)"
56
file_threshold_mb = 4
67
ignore_threshold_mb = 1024
@@ -48,3 +49,4 @@ name = "jCloud S3"
4849
endpoint = "https://s3.jcloud.sjtu.edu.cn"
4950
website_endpoint = "https://s3.jcloud.sjtu.edu.cn"
5051
bucket = "899a892efef34b1b944a19981040f55b-oss01"
52+
sentinel_object_key = "sjtug-internal/mirror-intel/releases/download/v0.1.35/mirror-intel.tar.gz"

0 commit comments

Comments
 (0)