From 2abddb8292cc397e2ee4620592af3d158d6f23a7 Mon Sep 17 00:00:00 2001 From: Arda Atahan Ibis Date: Fri, 31 Jan 2025 12:59:25 -0800 Subject: [PATCH 1/5] add network monitor subclass --- Sources/Hub/HubApi.swift | 46 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index 6e687ac..1377d5f 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -7,6 +7,7 @@ import Foundation import CryptoKit +import Network import os public struct HubApi { @@ -14,11 +15,12 @@ public struct HubApi { var hfToken: String? var endpoint: String var useBackgroundSession: Bool + var useOfflineMode: Bool? = nil public typealias RepoType = Hub.RepoType public typealias Repo = Hub.Repo - public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co", useBackgroundSession: Bool = false) { + public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co", useBackgroundSession: Bool = false, useOfflineMode: Bool? = nil) { self.hfToken = hfToken ?? Self.hfTokenFromEnv() if let downloadBase { self.downloadBase = downloadBase @@ -28,6 +30,12 @@ public struct HubApi { } self.endpoint = endpoint self.useBackgroundSession = useBackgroundSession + + if let useOfflineMode { + self.useOfflineMode = useOfflineMode + } else { + self.useOfflineMode = NetworkMonitor.shared.shouldUseOfflineMode() + } } public static let shared = HubApi() @@ -529,6 +537,42 @@ public extension HubApi { } } +/// Network monitor helper class to help decide whether to use offline mode +private extension HubApi { + private final class NetworkMonitor { + static let shared = NetworkMonitor() + private let monitor = NWPathMonitor() + private var currentPath: NWPath? + + private init() { + monitor.pathUpdateHandler = { [weak self] path in + self?.currentPath = path + } + monitor.start(queue: DispatchQueue.global(qos: .background)) + } + + deinit { + monitor.cancel() + } + + var isConnected: Bool { + currentPath?.status == .satisfied + } + + var isExpensive: Bool { + currentPath?.isExpensive ?? false + } + + var isConstrained: Bool { + currentPath?.isConstrained ?? false + } + + func shouldUseOfflineMode() -> Bool { + return !self.isConnected || self.isExpensive || self.isConstrained + } + } +} + /// Stateless wrappers that use `HubApi` instances public extension Hub { static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] { From 4cfefc5d774b5811e86c8f3c1a427e23a00162a9 Mon Sep 17 00:00:00 2001 From: Arda Atahan Ibis Date: Fri, 31 Jan 2025 14:16:32 -0800 Subject: [PATCH 2/5] add offline mode check to download logic --- Sources/Hub/HubApi.swift | 34 ++++++++++++++++++++++++++++---- Tests/HubTests/HubApiTests.swift | 3 ++- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index 1377d5f..0b559d9 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -161,10 +161,12 @@ public extension HubApi { public extension HubApi { enum EnvironmentError: LocalizedError { case invalidMetadataError(String) - + case offlineModeError(String) + public var errorDescription: String? { switch self { - case .invalidMetadataError(let message): + case .invalidMetadataError(let message), + .offlineModeError(let message): return message } } @@ -217,6 +219,7 @@ public extension HubApi { let hfToken: String? let endpoint: String? let backgroundSession: Bool + let offlineMode: Bool? let sha256Pattern = "^[0-9a-f]{64}$" let commitHashPattern = "^[0-9a-f]{40}$" @@ -360,8 +363,30 @@ public extension HubApi { @discardableResult func download(progressHandler: @escaping (Double) -> Void) async throws -> URL { let metadataRelativePath = "\(relativeFilename).metadata" - let localMetadata = try readDownloadMetadata(localDir: metadataDestination, filePath: metadataRelativePath) + + if let offlineMode = offlineMode, offlineMode { + if !downloaded { + throw EnvironmentError.offlineModeError("File not available locally in offline mode") + } + + guard let localMetadata = localMetadata else { + throw EnvironmentError.offlineModeError("Metadata not available or invalid in offline mode") + } + + let localEtag = localMetadata.etag + + // LFS file so check file integrity + if self.isValidHash(hash: localEtag, pattern: self.sha256Pattern) { + let fileHash = try computeFileHash(file: destination) + if fileHash != localEtag { + throw EnvironmentError.offlineModeError("File integrity check failed in offline mode") + } + } + + return destination + } + let remoteMetadata = try await HubApi.shared.getFileMetadata(url: source) let localCommitHash = localMetadata?.commitHash ?? "" @@ -433,7 +458,8 @@ public extension HubApi { relativeFilename: filename, hfToken: hfToken, endpoint: endpoint, - backgroundSession: useBackgroundSession + backgroundSession: useBackgroundSession, + offlineMode: useOfflineMode ) try await downloader.download { fractionDownloaded in fileProgress.completedUnitCount = Int64(100 * fractionDownloaded) diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index 7248c11..5e7a5b7 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -662,7 +662,8 @@ class SnapshotDownloadTests: XCTestCase { relativeFilename: "x.bin", hfToken: nil, endpoint: nil, - backgroundSession: false + backgroundSession: false, + offlineMode: false ) XCTAssertTrue(downloader.isValidHash(hash: commitHash, pattern: downloader.commitHashPattern)) From 2f7e85161c012dda6df7eba236e32ff9293971df Mon Sep 17 00:00:00 2001 From: Arda Atahan Ibis Date: Mon, 3 Feb 2025 09:06:10 -0800 Subject: [PATCH 3/5] add offline mode test --- Sources/Hub/HubApi.swift | 35 ++++++-- Tests/HubTests/HubApiTests.swift | 145 +++---------------------------- 2 files changed, 41 insertions(+), 139 deletions(-) diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index 0b559d9..fd51300 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -21,6 +21,7 @@ public struct HubApi { public typealias Repo = Hub.Repo public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co", useBackgroundSession: Bool = false, useOfflineMode: Bool? = nil) { + print("HubApi init useOfflineMode:", useOfflineMode as Any) // Debug print self.hfToken = hfToken ?? Self.hfTokenFromEnv() if let downloadBase { self.downloadBase = downloadBase @@ -362,10 +363,12 @@ public extension HubApi { // (See for example PipelineLoader in swift-coreml-diffusers) @discardableResult func download(progressHandler: @escaping (Double) -> Void) async throws -> URL { + print("Download method offlineMode:", offlineMode as Any) // Debug print let metadataRelativePath = "\(relativeFilename).metadata" let localMetadata = try readDownloadMetadata(localDir: metadataDestination, filePath: metadataRelativePath) if let offlineMode = offlineMode, offlineMode { + print("Entering offline mode block") // Debug print if !downloaded { throw EnvironmentError.offlineModeError("File not available locally in offline mode") } @@ -461,6 +464,7 @@ public extension HubApi { backgroundSession: useBackgroundSession, offlineMode: useOfflineMode ) + print("Creating downloader with offlineMode:", useOfflineMode as Any) // Debug print try await downloader.download { fractionDownloaded in fileProgress.completedUnitCount = Int64(100 * fractionDownloaded) progressHandler(progress) @@ -568,13 +572,26 @@ private extension HubApi { private final class NetworkMonitor { static let shared = NetworkMonitor() private let monitor = NWPathMonitor() - private var currentPath: NWPath? + private var currentPath: NWPath? { + didSet { + print("NetworkMonitor path updated:", currentPath?.status as Any) + } + } private init() { + // Start monitoring and wait for initial update + let group = DispatchGroup() + group.enter() + monitor.pathUpdateHandler = { [weak self] path in self?.currentPath = path + group.leave() } + monitor.start(queue: DispatchQueue.global(qos: .background)) + + // Wait for initial path update with shorter timeout + _ = group.wait(timeout: .now() + 0.1) // 100ms timeout } deinit { @@ -582,19 +599,27 @@ private extension HubApi { } var isConnected: Bool { - currentPath?.status == .satisfied + let connected = currentPath?.status == .satisfied + print("NetworkMonitor isConnected:", connected) // Debug print + return connected } var isExpensive: Bool { - currentPath?.isExpensive ?? false + let expensive = currentPath?.isExpensive ?? false + print("NetworkMonitor isExpensive:", expensive) // Debug print + return expensive } var isConstrained: Bool { - currentPath?.isConstrained ?? false + let constrained = currentPath?.isConstrained ?? false + print("NetworkMonitor isConstrained:", constrained) // Debug print + return constrained } func shouldUseOfflineMode() -> Bool { - return !self.isConnected || self.isExpensive || self.isConstrained + let offline = !self.isConnected || self.isExpensive || self.isConstrained + print("NetworkMonitor shouldUseOfflineMode:", offline) // Debug print + return offline } } } diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index 5e7a5b7..eed42d4 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -673,156 +673,33 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertFalse(downloader.isValidHash(hash: "\(etag)a", pattern: downloader.sha256Pattern)) } - func testLFSFileNoMetadata() async throws { - let hubApi = HubApi(downloadBase: downloadDestination) + func testOfflineMode() async throws { + var hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? = nil - let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in - print("Total Progress: \(progress.fractionCompleted)") - print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") - lastProgress = progress - } - - XCTAssertEqual(lastProgress?.fractionCompleted, 1) - XCTAssertEqual(lastProgress?.completedUnitCount, 1) - XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) - - let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: lfsRepo) - XCTAssertEqual(Set(downloadedFilenames), Set(["x.bin"])) - - let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") - - let filePath = downloadedTo.appending(path: "x.bin") - var attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) - let originalTimestamp = attributes[.modificationDate] as! Date - - let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: lfsRepo) - XCTAssertEqual( - Set(downloadedMetadataFilenames), - Set([".cache/huggingface/download/x.bin.metadata"]) - ) - - let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") - try FileManager.default.removeItem(atPath: metadataFile.path) - - let _ = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in - print("Total Progress: \(progress.fractionCompleted)") - print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") - lastProgress = progress - } - - attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) - let secondDownloadTimestamp = attributes[.modificationDate] as! Date - - // File will not be downloaded again thus last modified date will remain unchanged - XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) - XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) - - let metadataString = try String(contentsOfFile: metadataFile.path) - let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4" - - XCTAssertTrue(metadataString.contains(expected)) - } - - func testLFSFileCorruptedMetadata() async throws { - let hubApi = HubApi(downloadBase: downloadDestination) - var lastProgress: Progress? = nil - - let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in + var downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress } XCTAssertEqual(lastProgress?.fractionCompleted, 1) - XCTAssertEqual(lastProgress?.completedUnitCount, 1) - XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) - - let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: lfsRepo) - XCTAssertEqual(Set(downloadedFilenames), Set(["x.bin"])) - - let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") - - let filePath = downloadedTo.appending(path: "x.bin") - var attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) - let originalTimestamp = attributes[.modificationDate] as! Date - - let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: lfsRepo) - XCTAssertEqual( - Set(downloadedMetadataFilenames), - Set([".cache/huggingface/download/x.bin.metadata"]) - ) - - let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") - try "a".write(to: metadataFile, atomically: true, encoding: .utf8) - - let _ = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in - print("Total Progress: \(progress.fractionCompleted)") - print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") - lastProgress = progress - } - - attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) - let secondDownloadTimestamp = attributes[.modificationDate] as! Date - - // File will not be downloaded again thus last modified date will remain unchanged - XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) - XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) - - let metadataString = try String(contentsOfFile: metadataFile.path) - let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4" + XCTAssertEqual(lastProgress?.completedUnitCount, 6) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - XCTAssertTrue(metadataString.contains(expected)) - } + hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true) - func testNonLFSFileRedownload() async throws { - let hubApi = HubApi(downloadBase: downloadDestination) - var lastProgress: Progress? = nil - - let downloadedTo = try await hubApi.snapshot(from: repo, matching: "config.json") { progress in + downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") lastProgress = progress } XCTAssertEqual(lastProgress?.fractionCompleted, 1) - XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 6) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - - let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) - XCTAssertEqual(Set(downloadedFilenames), Set(["config.json"])) - - let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") - - let filePath = downloadedTo.appending(path: "config.json") - var attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) - let originalTimestamp = attributes[.modificationDate] as! Date - - let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: repo) - XCTAssertEqual( - Set(downloadedMetadataFilenames), - Set([".cache/huggingface/download/config.json.metadata"]) - ) - - let metadataFile = metadataDestination.appendingPathComponent("config.json.metadata") - try FileManager.default.removeItem(atPath: metadataFile.path) - - let _ = try await hubApi.snapshot(from: repo, matching: "config.json") { progress in - print("Total Progress: \(progress.fractionCompleted)") - print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") - lastProgress = progress - } - - attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) - let secondDownloadTimestamp = attributes[.modificationDate] as! Date - - // File will be downloaded again thus last modified date will change - XCTAssertTrue(originalTimestamp != secondDownloadTimestamp) - XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) - - let metadataString = try String(contentsOfFile: metadataFile.path) - let expected = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb\nd6ceb92ce9e3c83ab146dc8e92a93517ac1cc66f" - - XCTAssertTrue(metadataString.contains(expected)) } + + } From d99eec1ff963e5886b479f2ca0371ee1c76e408e Mon Sep 17 00:00:00 2001 From: Arda Atahan Ibis Date: Mon, 3 Feb 2025 09:57:19 -0800 Subject: [PATCH 4/5] only redownload lfs files when missing or checksum does not match --- Tests/HubTests/HubApiTests.swift | 145 ++++++++++++++++++++++++++++--- 1 file changed, 134 insertions(+), 11 deletions(-) diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index eed42d4..5e7a5b7 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -673,33 +673,156 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertFalse(downloader.isValidHash(hash: "\(etag)a", pattern: downloader.sha256Pattern)) } - func testOfflineMode() async throws { - var hubApi = HubApi(downloadBase: downloadDestination) + func testLFSFileNoMetadata() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? = nil - var downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") - lastProgress = progress } XCTAssertEqual(lastProgress?.fractionCompleted, 1) - XCTAssertEqual(lastProgress?.completedUnitCount, 6) - XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: lfsRepo) + XCTAssertEqual(Set(downloadedFilenames), Set(["x.bin"])) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") - hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true) + let filePath = downloadedTo.appending(path: "x.bin") + var attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) + let originalTimestamp = attributes[.modificationDate] as! Date + + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: lfsRepo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([".cache/huggingface/download/x.bin.metadata"]) + ) + + let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") + try FileManager.default.removeItem(atPath: metadataFile.path) + + let _ = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) + let secondDownloadTimestamp = attributes[.modificationDate] as! Date + + // File will not be downloaded again thus last modified date will remain unchanged + XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) + XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) + + let metadataString = try String(contentsOfFile: metadataFile.path) + let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4" + + XCTAssertTrue(metadataString.contains(expected)) + } - downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + func testLFSFileCorruptedMetadata() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? = nil + + let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") lastProgress = progress } XCTAssertEqual(lastProgress?.fractionCompleted, 1) - XCTAssertEqual(lastProgress?.completedUnitCount, 6) - XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: lfsRepo) + XCTAssertEqual(Set(downloadedFilenames), Set(["x.bin"])) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + + let filePath = downloadedTo.appending(path: "x.bin") + var attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) + let originalTimestamp = attributes[.modificationDate] as! Date + + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: lfsRepo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([".cache/huggingface/download/x.bin.metadata"]) + ) + + let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") + try "a".write(to: metadataFile, atomically: true, encoding: .utf8) + + let _ = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) + let secondDownloadTimestamp = attributes[.modificationDate] as! Date + + // File will not be downloaded again thus last modified date will remain unchanged + XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) + XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) + + let metadataString = try String(contentsOfFile: metadataFile.path) + let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4" + + XCTAssertTrue(metadataString.contains(expected)) } - + func testNonLFSFileRedownload() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? = nil + + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "config.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) + XCTAssertEqual(Set(downloadedFilenames), Set(["config.json"])) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + + let filePath = downloadedTo.appending(path: "config.json") + var attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) + let originalTimestamp = attributes[.modificationDate] as! Date + + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: repo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([".cache/huggingface/download/config.json.metadata"]) + ) + + let metadataFile = metadataDestination.appendingPathComponent("config.json.metadata") + try FileManager.default.removeItem(atPath: metadataFile.path) + + let _ = try await hubApi.snapshot(from: repo, matching: "config.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) + let secondDownloadTimestamp = attributes[.modificationDate] as! Date + + // File will be downloaded again thus last modified date will change + XCTAssertTrue(originalTimestamp != secondDownloadTimestamp) + XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) + + let metadataString = try String(contentsOfFile: metadataFile.path) + let expected = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb\nd6ceb92ce9e3c83ab146dc8e92a93517ac1cc66f" + + XCTAssertTrue(metadataString.contains(expected)) + } } From bcfb044c7288c4e56037d2b1fe33cd3057a98648 Mon Sep 17 00:00:00 2001 From: Arda Atahan Ibis Date: Wed, 5 Feb 2025 10:05:37 -0800 Subject: [PATCH 5/5] add more tests --- Sources/Hub/HubApi.swift | 400 ++++++++++++++++--------------- Tests/HubTests/HubApiTests.swift | 173 +++++++++++-- 2 files changed, 369 insertions(+), 204 deletions(-) diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index fd51300..5c0e2c5 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -16,12 +16,12 @@ public struct HubApi { var endpoint: String var useBackgroundSession: Bool var useOfflineMode: Bool? = nil - + + private let networkMonitor = NetworkMonitor() public typealias RepoType = Hub.RepoType public typealias Repo = Hub.Repo public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co", useBackgroundSession: Bool = false, useOfflineMode: Bool? = nil) { - print("HubApi init useOfflineMode:", useOfflineMode as Any) // Debug print self.hfToken = hfToken ?? Self.hfTokenFromEnv() if let downloadBase { self.downloadBase = downloadBase @@ -31,14 +31,13 @@ public struct HubApi { } self.endpoint = endpoint self.useBackgroundSession = useBackgroundSession - - if let useOfflineMode { - self.useOfflineMode = useOfflineMode - } else { - self.useOfflineMode = NetworkMonitor.shared.shouldUseOfflineMode() - } + self.useOfflineMode = useOfflineMode + NetworkMonitor.shared.startMonitoring() } + let sha256Pattern = "^[0-9a-f]{64}$" + let commitHashPattern = "^[0-9a-f]{40}$" + public static let shared = HubApi() private static let logger = Logger() @@ -213,18 +212,112 @@ public extension HubApi { downloadBase.appending(component: repo.type.rawValue).appending(component: repo.id) } + /// Reads metadata about a file in the local directory related to a download process. + /// + /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L263 + /// + /// - Parameters: + /// - localDir: The local directory where metadata files are downloaded. + /// - filePath: The path of the file for which metadata is being read. + /// - Throws: An `EnvironmentError.invalidMetadataError` if the metadata file is invalid and cannot be removed. + /// - Returns: A `LocalDownloadFileMetadata` object if the metadata file exists and is valid, or `nil` if the file is missing or invalid. + func readDownloadMetadata(metadataPath: URL) throws -> LocalDownloadFileMetadata? { + if FileManager.default.fileExists(atPath: metadataPath.path) { + do { + let attributes = try FileManager.default.attributesOfItem(atPath: metadataPath.path) + print("File attributes: \(attributes)") + let contents = try String(contentsOf: metadataPath, encoding: .utf8) + let lines = contents.components(separatedBy: .newlines) + + guard lines.count >= 3 else { + throw EnvironmentError.invalidMetadataError("Metadata file is missing required fields.") + } + + let commitHash = lines[0].trimmingCharacters(in: .whitespacesAndNewlines) + let etag = lines[1].trimmingCharacters(in: .whitespacesAndNewlines) + guard let timestamp = Double(lines[2].trimmingCharacters(in: .whitespacesAndNewlines)) else { + throw EnvironmentError.invalidMetadataError("Missing or invalid timestamp.") + } + let timestampDate = Date(timeIntervalSince1970: timestamp) + + // TODO: check if file hasn't been modified since the metadata was saved + // Reference: https://github.com/huggingface/huggingface_hub/blob/2fdc6f48ef5e6b22ee9bcdc1945948ac070da675/src/huggingface_hub/_local_folder.py#L303 + + let filename = metadataPath.lastPathComponent.replacingOccurrences(of: ".metadata", with: "") + + return LocalDownloadFileMetadata(commitHash: commitHash, etag: etag, filename: filename, timestamp: timestampDate) + } catch { + do { + HubApi.logger.warning("Invalid metadata file \(metadataPath): \(error). Removing it from disk and continue.") + try FileManager.default.removeItem(at: metadataPath) + } catch { + throw EnvironmentError.invalidMetadataError("Could not remove corrupted metadata file \(metadataPath): \(error)") + } + return nil + } + } + + // metadata file does not exist + return nil + } + + func isValidHash(hash: String, pattern: String) -> Bool { + let regex = try? NSRegularExpression(pattern: pattern) + let range = NSRange(location: 0, length: hash.utf16.count) + return regex?.firstMatch(in: hash, options: [], range: range) != nil + } + + func computeFileHash(file url: URL) throws -> String { + // Open file for reading + guard let fileHandle = try? FileHandle(forReadingFrom: url) else { + throw Hub.HubClientError.unexpectedError + } + + defer { + try? fileHandle.close() + } + + var hasher = SHA256() + let chunkSize = 1024 * 1024 // 1MB chunks + + while autoreleasepool(invoking: { + let nextChunk = try? fileHandle.read(upToCount: chunkSize) + + guard let nextChunk, + !nextChunk.isEmpty + else { + return false + } + + hasher.update(data: nextChunk) + + return true + }) { } + + let digest = hasher.finalize() + return digest.map { String(format: "%02x", $0) }.joined() + } + + /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391 + func writeDownloadMetadata(commitHash: String, etag: String, metadataPath: URL) throws { + let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n" + do { + try FileManager.default.createDirectory(at: metadataPath.deletingLastPathComponent(), withIntermediateDirectories: true) + try metadataContent.write(to: metadataPath, atomically: true, encoding: .utf8) + } catch { + throw EnvironmentError.invalidMetadataError("Failed to write metadata file \(metadataPath)") + } + } + struct HubFileDownloader { let repo: Repo let repoDestination: URL + let repoMetadataDestination: URL let relativeFilename: String let hfToken: String? let endpoint: String? let backgroundSession: Bool - let offlineMode: Bool? - let sha256Pattern = "^[0-9a-f]{64}$" - let commitHashPattern = "^[0-9a-f]{40}$" - var source: URL { // https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/tokenizer.json?download=true var url = URL(string: endpoint ?? "https://huggingface.co")! @@ -242,10 +335,7 @@ public extension HubApi { } var metadataDestination: URL { - repoDestination - .appendingPathComponent(".cache") - .appendingPathComponent("huggingface") - .appendingPathComponent("download") + repoMetadataDestination.appending(path: relativeFilename + ".metadata") } var downloaded: Bool { @@ -258,145 +348,23 @@ public extension HubApi { } func prepareMetadataDestination() throws { - try FileManager.default.createDirectory(at: metadataDestination, withIntermediateDirectories: true, attributes: nil) - } - - /// Reads metadata about a file in the local directory related to a download process. - /// - /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L263 - /// - /// - Parameters: - /// - localDir: The local directory where metadata files are downloaded. - /// - filePath: The path of the file for which metadata is being read. - /// - Throws: An `EnvironmentError.invalidMetadataError` if the metadata file is invalid and cannot be removed. - /// - Returns: A `LocalDownloadFileMetadata` object if the metadata file exists and is valid, or `nil` if the file is missing or invalid. - func readDownloadMetadata(localDir: URL, filePath: String) throws -> LocalDownloadFileMetadata? { - let metadataPath = localDir.appending(path: filePath) - if FileManager.default.fileExists(atPath: metadataPath.path) { - do { - let contents = try String(contentsOf: metadataPath, encoding: .utf8) - let lines = contents.components(separatedBy: .newlines) - - guard lines.count >= 3 else { - throw EnvironmentError.invalidMetadataError("Metadata file is missing required fields.") - } - - let commitHash = lines[0].trimmingCharacters(in: .whitespacesAndNewlines) - let etag = lines[1].trimmingCharacters(in: .whitespacesAndNewlines) - guard let timestamp = Double(lines[2].trimmingCharacters(in: .whitespacesAndNewlines)) else { - throw EnvironmentError.invalidMetadataError("Missing or invalid timestamp.") - } - let timestampDate = Date(timeIntervalSince1970: timestamp) - - // TODO: check if file hasn't been modified since the metadata was saved - // Reference: https://github.com/huggingface/huggingface_hub/blob/2fdc6f48ef5e6b22ee9bcdc1945948ac070da675/src/huggingface_hub/_local_folder.py#L303 - - return LocalDownloadFileMetadata(commitHash: commitHash, etag: etag, filename: filePath, timestamp: timestampDate) - } catch { - do { - logger.warning("Invalid metadata file \(metadataPath): \(error). Removing it from disk and continue.") - try FileManager.default.removeItem(at: metadataPath) - } catch { - throw EnvironmentError.invalidMetadataError("Could not remove corrupted metadata file \(metadataPath): \(error)") - } - return nil - } - } - - // metadata file does not exist - return nil - } - - func isValidHash(hash: String, pattern: String) -> Bool { - let regex = try? NSRegularExpression(pattern: pattern) - let range = NSRange(location: 0, length: hash.utf16.count) - return regex?.firstMatch(in: hash, options: [], range: range) != nil - } - - /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391 - func writeDownloadMetadata(commitHash: String, etag: String, metadataRelativePath: String) throws { - let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n" - let metadataPath = metadataDestination.appending(component: metadataRelativePath) - - do { - try FileManager.default.createDirectory(at: metadataPath.deletingLastPathComponent(), withIntermediateDirectories: true) - try metadataContent.write(to: metadataPath, atomically: true, encoding: .utf8) - } catch { - throw EnvironmentError.invalidMetadataError("Failed to write metadata file \(metadataPath)") - } - } - - func computeFileHash(file url: URL) throws -> String { - // Open file for reading - guard let fileHandle = try? FileHandle(forReadingFrom: url) else { - throw Hub.HubClientError.unexpectedError - } - - defer { - try? fileHandle.close() - } - - var hasher = SHA256() - let chunkSize = 1024 * 1024 // 1MB chunks - - while autoreleasepool(invoking: { - let nextChunk = try? fileHandle.read(upToCount: chunkSize) - - guard let nextChunk, - !nextChunk.isEmpty - else { - return false - } - - hasher.update(data: nextChunk) - - return true - }) { } - - let digest = hasher.finalize() - return digest.map { String(format: "%02x", $0) }.joined() + let directoryURL = metadataDestination.deletingLastPathComponent() + try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil) } - // Note we go from Combine in Downloader to callback-based progress reporting // We'll probably need to support Combine as well to play well with Swift UI // (See for example PipelineLoader in swift-coreml-diffusers) @discardableResult func download(progressHandler: @escaping (Double) -> Void) async throws -> URL { - print("Download method offlineMode:", offlineMode as Any) // Debug print - let metadataRelativePath = "\(relativeFilename).metadata" - let localMetadata = try readDownloadMetadata(localDir: metadataDestination, filePath: metadataRelativePath) - - if let offlineMode = offlineMode, offlineMode { - print("Entering offline mode block") // Debug print - if !downloaded { - throw EnvironmentError.offlineModeError("File not available locally in offline mode") - } - - guard let localMetadata = localMetadata else { - throw EnvironmentError.offlineModeError("Metadata not available or invalid in offline mode") - } - - let localEtag = localMetadata.etag - - // LFS file so check file integrity - if self.isValidHash(hash: localEtag, pattern: self.sha256Pattern) { - let fileHash = try computeFileHash(file: destination) - if fileHash != localEtag { - throw EnvironmentError.offlineModeError("File integrity check failed in offline mode") - } - } - - return destination - } - + let localMetadata = try HubApi.shared.readDownloadMetadata(metadataPath: metadataDestination) let remoteMetadata = try await HubApi.shared.getFileMetadata(url: source) let localCommitHash = localMetadata?.commitHash ?? "" let remoteCommitHash = remoteMetadata.commitHash ?? "" // Local file exists + metadata exists + commit_hash matches => return file - if isValidHash(hash: remoteCommitHash, pattern: commitHashPattern) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash { + if HubApi.shared.isValidHash(hash: remoteCommitHash, pattern: HubApi.shared.commitHashPattern) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash { return destination } @@ -411,7 +379,7 @@ public extension HubApi { if downloaded { // etag matches => update metadata and return file if localMetadata?.etag == remoteEtag { - try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) + try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) return destination } @@ -419,10 +387,10 @@ public extension HubApi { // => means it's an LFS file (large) // => let's compute local hash and compare // => if match, update metadata and return file - if isValidHash(hash: remoteEtag, pattern: sha256Pattern) { - let fileHash = try computeFileHash(file: destination) + if HubApi.shared.isValidHash(hash: remoteEtag, pattern: HubApi.shared.sha256Pattern) { + let fileHash = try HubApi.shared.computeFileHash(file: destination) if fileHash == remoteEtag { - try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) + try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) return destination } } @@ -431,7 +399,7 @@ public extension HubApi { // Otherwise, let's download the file! try prepareDestination() try prepareMetadataDestination() - + let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession) let downloadSubscriber = downloader.downloadState.sink { state in if case .downloading(let progress) = state { @@ -442,7 +410,7 @@ public extension HubApi { try downloader.waitUntilDone() } - try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) + try HubApi.shared.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination) return destination } @@ -450,21 +418,60 @@ public extension HubApi { @discardableResult func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL { + let repoDestination = localRepoLocation(repo) + let repoMetadataDestination = repoDestination + .appendingPathComponent(".cache") + .appendingPathComponent("huggingface") + .appendingPathComponent("download") + + if useOfflineMode ?? NetworkMonitor.shared.shouldUseOfflineMode() { + if !FileManager.default.fileExists(atPath: repoDestination.path) { + throw EnvironmentError.offlineModeError("File not available locally in offline mode") + } + + let fileUrls = try FileManager.default.getFileUrls(at: repoDestination) + if fileUrls.isEmpty { + throw EnvironmentError.offlineModeError("File not available locally in offline mode") + } + + for fileUrl in fileUrls { + let metadataPath = URL(fileURLWithPath: fileUrl.path.replacingOccurrences( + of: repoDestination.path, + with: repoMetadataDestination.path + ) + ".metadata") + + let localMetadata = try readDownloadMetadata(metadataPath: metadataPath) + + guard let localMetadata = localMetadata else { + throw EnvironmentError.offlineModeError("Metadata not available or invalid in offline mode") + } + let localEtag = localMetadata.etag + + // LFS file so check file integrity + if self.isValidHash(hash: localEtag, pattern: self.sha256Pattern) { + let fileHash = try computeFileHash(file: fileUrl) + if fileHash != localEtag { + throw EnvironmentError.offlineModeError("File integrity check failed in offline mode") + } + } + } + + return repoDestination + } + let filenames = try await getFilenames(from: repo, matching: globs) let progress = Progress(totalUnitCount: Int64(filenames.count)) - let repoDestination = localRepoLocation(repo) for filename in filenames { let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1) let downloader = HubFileDownloader( repo: repo, repoDestination: repoDestination, + repoMetadataDestination: repoMetadataDestination, relativeFilename: filename, hfToken: hfToken, endpoint: endpoint, - backgroundSession: useBackgroundSession, - offlineMode: useOfflineMode + backgroundSession: useBackgroundSession ) - print("Creating downloader with offlineMode:", useOfflineMode as Any) // Debug print try await downloader.download { fractionDownloaded in fileProgress.completedUnitCount = Int64(100 * fractionDownloaded) progressHandler(progress) @@ -570,56 +577,43 @@ public extension HubApi { /// Network monitor helper class to help decide whether to use offline mode private extension HubApi { private final class NetworkMonitor { + private var monitor: NWPathMonitor + private var queue: DispatchQueue + + private(set) var isConnected: Bool = false + private(set) var isExpensive: Bool = false + private(set) var isConstrained: Bool = false + static let shared = NetworkMonitor() - private let monitor = NWPathMonitor() - private var currentPath: NWPath? { - didSet { - print("NetworkMonitor path updated:", currentPath?.status as Any) - } + + init() { + monitor = NWPathMonitor() + queue = DispatchQueue(label: "HubApi.NetworkMonitor") + startMonitoring() } - private init() { - // Start monitoring and wait for initial update - let group = DispatchGroup() - group.enter() - + func startMonitoring() { monitor.pathUpdateHandler = { [weak self] path in - self?.currentPath = path - group.leave() + guard let self = self else { return } + + self.isConnected = path.status == .satisfied + self.isExpensive = path.isExpensive + self.isConstrained = path.isConstrained } - monitor.start(queue: DispatchQueue.global(qos: .background)) - - // Wait for initial path update with shorter timeout - _ = group.wait(timeout: .now() + 0.1) // 100ms timeout + monitor.start(queue: queue) } - deinit { + func stopMonitoring() { monitor.cancel() } - var isConnected: Bool { - let connected = currentPath?.status == .satisfied - print("NetworkMonitor isConnected:", connected) // Debug print - return connected - } - - var isExpensive: Bool { - let expensive = currentPath?.isExpensive ?? false - print("NetworkMonitor isExpensive:", expensive) // Debug print - return expensive - } - - var isConstrained: Bool { - let constrained = currentPath?.isConstrained ?? false - print("NetworkMonitor isConstrained:", constrained) // Debug print - return constrained + func shouldUseOfflineMode() -> Bool { + return !isConnected || isExpensive || isConstrained } - func shouldUseOfflineMode() -> Bool { - let offline = !self.isConnected || self.isExpensive || self.isConstrained - print("NetworkMonitor shouldUseOfflineMode:", offline) // Debug print - return offline + deinit { + stopMonitoring() } } } @@ -689,6 +683,34 @@ public extension [String] { } } +public extension FileManager { + func getFileUrls(at directoryUrl: URL) throws -> [URL] { + var fileUrls = [URL]() + + // Get all contents including subdirectories + guard let enumerator = FileManager.default.enumerator( + at: directoryUrl, + includingPropertiesForKeys: [.isRegularFileKey, .isHiddenKey], + options: [.skipsHiddenFiles] + ) else { + return fileUrls + } + + for case let fileURL as URL in enumerator { + do { + let resourceValues = try fileURL.resourceValues(forKeys: [.isRegularFileKey, .isHiddenKey]) + if resourceValues.isRegularFile == true && resourceValues.isHidden != true { + fileUrls.append(fileURL) + } + } catch { + throw error + } + } + + return fileUrls + } +} + /// Only allow relative redirects and reject others /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/file_download.py#L258 private class RedirectDelegate: NSObject, URLSessionTaskDelegate { diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index 5e7a5b7..756060b 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -655,22 +655,11 @@ class SnapshotDownloadTests: XCTestCase { let commitHash = metadataArr[0] let etag = metadataArr[1] - // Not needed for the downloads, just to test validation function - let downloader = HubApi.HubFileDownloader( - repo: Hub.Repo(id: lfsRepo), - repoDestination: downloadedTo, - relativeFilename: "x.bin", - hfToken: nil, - endpoint: nil, - backgroundSession: false, - offlineMode: false - ) - - XCTAssertTrue(downloader.isValidHash(hash: commitHash, pattern: downloader.commitHashPattern)) - XCTAssertTrue(downloader.isValidHash(hash: etag, pattern: downloader.sha256Pattern)) + XCTAssertTrue(hubApi.isValidHash(hash: commitHash, pattern: hubApi.commitHashPattern)) + XCTAssertTrue(hubApi.isValidHash(hash: etag, pattern: hubApi.sha256Pattern)) - XCTAssertFalse(downloader.isValidHash(hash: "\(commitHash)a", pattern: downloader.commitHashPattern)) - XCTAssertFalse(downloader.isValidHash(hash: "\(etag)a", pattern: downloader.sha256Pattern)) + XCTAssertFalse(hubApi.isValidHash(hash: "\(commitHash)a", pattern: hubApi.commitHashPattern)) + XCTAssertFalse(hubApi.isValidHash(hash: "\(etag)a", pattern: hubApi.sha256Pattern)) } func testLFSFileNoMetadata() async throws { @@ -825,4 +814,158 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertTrue(metadataString.contains(expected)) } + + func testOfflineModeReturnsDestination() async throws { + var hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? = nil + + var downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 6) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true) + + downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 6) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + } + + func testOfflineModeThrowsError() async throws { + let hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true) + + do { + try await hubApi.snapshot(from: repo, matching: "*.json") + XCTFail("Expected an error to be thrown") + } catch let error as HubApi.EnvironmentError { + switch error { + case .offlineModeError(let message): + XCTAssertEqual(message, "File not available locally in offline mode") + default: + XCTFail("Wrong error type: \(error)") + } + } catch { + XCTFail("Unexpected error: \(error)") + } + } + + func testOfflineModeWithoutMetadata() async throws { + var hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? = nil + + let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "*") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 2) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + + let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") + try FileManager.default.removeItem(atPath: metadataFile.path) + + hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true) + + do { + try await hubApi.snapshot(from: lfsRepo, matching: "*") + XCTFail("Expected an error to be thrown") + } catch let error as HubApi.EnvironmentError { + switch error { + case .offlineModeError(let message): + XCTAssertEqual(message, "Metadata not available or invalid in offline mode") + default: + XCTFail("Wrong error type: \(error)") + } + } catch { + XCTFail("Unexpected error: \(error)") + } + } + + func testOfflineModeWithCorruptedLFSMetadata() async throws { + var hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? = nil + + let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "*") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 2) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) + + let metadataDestination = downloadedTo.appendingPathComponent(".cache/huggingface/download").appendingPathComponent("x.bin.metadata") + + try "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2ab4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4\n0\n".write(to: metadataDestination, atomically: true, encoding: .utf8) + + hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true) + + do { + try await hubApi.snapshot(from: lfsRepo, matching: "*") + XCTFail("Expected an error to be thrown") + } catch let error as HubApi.EnvironmentError { + switch error { + case .offlineModeError(let message): + XCTAssertEqual(message, "File integrity check failed in offline mode") + default: + XCTFail("Wrong error type: \(error)") + } + } catch { + XCTFail("Unexpected error: \(error)") + } + } + + func testOfflineModeWithNoFiles() async throws { + var hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? = nil + + let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) + + let fileDestination = downloadedTo.appendingPathComponent("x.bin") + try FileManager.default.removeItem(at: fileDestination) + + hubApi = HubApi(downloadBase: downloadDestination, useOfflineMode: true) + + do { + try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") + XCTFail("Expected an error to be thrown") + } catch let error as HubApi.EnvironmentError { + switch error { + case .offlineModeError(let message): + XCTAssertEqual(message, "File not available locally in offline mode") + default: + XCTFail("Wrong error type: \(error)") + } + } catch { + XCTFail("Unexpected error: \(error)") + } + } }