Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c87c1b1

Browse files
committedApr 21, 2024··
Download 10 split files for DBRX
1 parent d3a610c commit c87c1b1

File tree

7 files changed

+164
-161
lines changed

7 files changed

+164
-161
lines changed
 

‎src/main/java/ee/carlrobert/codegpt/completions/HuggingFaceModel.java

+18-11
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import java.net.MalformedURLException;
66
import java.net.URL;
7+
import java.util.List;
8+
import java.util.stream.IntStream;
79

810
public enum HuggingFaceModel {
911

@@ -89,21 +91,26 @@ public String getCode() {
8991
return name();
9092
}
9193

92-
public String getFileName() {
94+
public List<String> getFileNames() {
9395
if ("TheBloke".equals(user)) {
94-
return modelName.toLowerCase().replace("-gguf", format(".Q%d_K_M.gguf", quantization));
96+
return List.of(modelName.toLowerCase()
97+
.replace("-gguf", format(".Q%d_K_M.gguf", quantization)));
9598
}
96-
// TODO: Download all 10 files ;(
97-
return modelName.toLowerCase().replace("-gguf", "-00001-of-00010.gguf");
99+
if ("phymbert".equals(user)) {
100+
return IntStream.range(1, 11).mapToObj(i -> modelName
101+
.replace("-gguf", "-000%02d-of-00010.gguf".formatted(i))).toList();
102+
}
103+
return List.of(modelName);
98104
}
99105

100-
public URL getFileURL() {
101-
try {
102-
return new URL(
103-
"https://huggingface.co/%s/%s/resolve/main/%s".formatted(user, getDirectory(), getFileName()));
104-
} catch (MalformedURLException ex) {
105-
throw new RuntimeException(ex);
106-
}
106+
public List<URL> getFileURLs() {
107+
return getFileNames().stream().map(file -> {
108+
try {
109+
return new URL("https://huggingface.co/%s/%s/resolve/main/%s".formatted(user, getDirectory(), file));
110+
} catch (MalformedURLException ex) {
111+
throw new RuntimeException(ex);
112+
}
113+
}).toList();
107114
}
108115

109116
public URL getHuggingFaceURL() {

‎src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/DownloadModelAction.java

-106
This file was deleted.

‎src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaModelPreferencesForm.java

+5-3
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ public InfillPromptTemplate getInfillPromptTemplate() {
195195
public String getActualModelPath() {
196196
return isUseCustomLlamaModel()
197197
? getCustomLlamaModelPath()
198-
: CodeGPTPlugin.getLlamaModelsPath() + File.separator + getSelectedModel().getFileName();
198+
: CodeGPTPlugin.getLlamaModelsPath() + File.separator
199+
+ getSelectedModel().getFileNames().get(0);
199200
}
200201

201202
private JPanel createFormPanelCards() {
@@ -394,8 +395,9 @@ private TextFieldWithBrowseButton createBrowsableCustomModelTextField(boolean en
394395
}
395396

396397
private boolean isModelExists(HuggingFaceModel model) {
397-
return FileUtil.exists(
398-
CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileName());
398+
return model.getFileNames().stream().allMatch(filename ->
399+
FileUtil.exists(CodeGPTPlugin.getLlamaModelsPath() + File.separator + filename)
400+
);
399401
}
400402

401403
private AnActionLink createCancelDownloadLink(

‎src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaServerPreferencesForm.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ private boolean validateSelectedModel() {
290290

291291
private boolean isModelExists(HuggingFaceModel model) {
292292
return FileUtil.exists(
293-
CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileName());
293+
CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileNames());
294294
}
295295

296296
private void enableForm(JButton serverButton, ServerProgressPanel progressPanel) {

‎src/main/java/ee/carlrobert/codegpt/util/DownloadingUtil.java

-40
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package ee.carlrobert.codegpt.settings.service.llama.form
2+
3+
import com.intellij.openapi.actionSystem.AnAction
4+
import com.intellij.openapi.actionSystem.AnActionEvent
5+
import com.intellij.openapi.diagnostic.Logger
6+
import com.intellij.openapi.progress.ProgressIndicator
7+
import com.intellij.openapi.progress.ProgressManager
8+
import com.intellij.openapi.progress.Task
9+
import com.intellij.openapi.project.Project
10+
import ee.carlrobert.codegpt.CodeGPTBundle
11+
import ee.carlrobert.codegpt.completions.HuggingFaceModel
12+
import ee.carlrobert.codegpt.util.DownloadingUtil
13+
import ee.carlrobert.codegpt.util.file.FileUtil.copyFileWithProgress
14+
import java.io.IOException
15+
import java.util.concurrent.Executors
16+
import java.util.concurrent.ScheduledFuture
17+
import java.util.concurrent.TimeUnit
18+
import java.util.function.Consumer
19+
import javax.swing.DefaultComboBoxModel
20+
21+
class DownloadModelAction(
22+
private val onDownload: Consumer<ProgressIndicator>,
23+
private val onDownloaded: Runnable,
24+
private val onFailed: Consumer<Exception>,
25+
private val onUpdateProgress: Consumer<String>,
26+
private val comboBoxModel: DefaultComboBoxModel<HuggingFaceModel>
27+
) : AnAction() {
28+
29+
override fun actionPerformed(e: AnActionEvent) {
30+
ProgressManager.getInstance().run(DownloadBackgroundTask(e.project))
31+
}
32+
33+
internal inner class DownloadBackgroundTask(project: Project?) : Task.Backgroundable(
34+
project,
35+
CodeGPTBundle.get("settingsConfigurable.service.llama.progress.downloadingModel.title"),
36+
true
37+
) {
38+
override fun run(indicator: ProgressIndicator) {
39+
val model = comboBoxModel.selectedItem as HuggingFaceModel
40+
val urls = model.fileURLs
41+
val numberOfFiles = urls.size
42+
var errorOccured = false
43+
for (i in 1..numberOfFiles + 1) {
44+
if (errorOccured || indicator.isCanceled) {
45+
break
46+
}
47+
val executorService = Executors.newSingleThreadScheduledExecutor()
48+
var progressUpdateScheduler: ScheduledFuture<*>? = null
49+
val url = urls[i - 1]
50+
51+
try {
52+
onDownload.accept(indicator)
53+
54+
indicator.isIndeterminate = false
55+
indicator.text = String.format(
56+
CodeGPTBundle.get(
57+
"settingsConfigurable.service.llama.progress.downloadingModelIndicator.text"
58+
),
59+
model.fileNames[i - 1]
60+
)
61+
62+
val fileSize = url.openConnection().contentLengthLong
63+
val bytesRead = longArrayOf(0)
64+
val startTime = System.currentTimeMillis()
65+
66+
progressUpdateScheduler = executorService.scheduleAtFixedRate(
67+
{
68+
onUpdateProgress.accept(
69+
DownloadingUtil.getFormattedDownloadProgress(
70+
i,
71+
numberOfFiles,
72+
startTime,
73+
fileSize,
74+
bytesRead[0]
75+
)
76+
)
77+
},
78+
0, 1, TimeUnit.SECONDS
79+
)
80+
copyFileWithProgress(model.fileNames[i - 1], url, bytesRead, fileSize, indicator)
81+
} catch (ex: IOException) {
82+
LOG.error("Unable to download", ex, url.toString())
83+
onFailed.accept(ex)
84+
errorOccured = true
85+
} finally {
86+
progressUpdateScheduler?.cancel(true)
87+
executorService.shutdown()
88+
}
89+
}
90+
}
91+
92+
override fun onSuccess() {
93+
onDownloaded.run()
94+
}
95+
}
96+
97+
companion object {
98+
private val LOG = Logger.getInstance(DownloadModelAction::class.java)
99+
}
100+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package ee.carlrobert.codegpt.util
2+
3+
import ee.carlrobert.codegpt.util.file.FileUtil.convertFileSize
4+
5+
object DownloadingUtil {
6+
private const val BYTES_IN_MB = 1024 * 1024
7+
8+
fun getFormattedDownloadProgress(
9+
fileNumber: Int, fileCount: Int, startTime: Long,
10+
fileSize: Long, bytesRead: Long
11+
): String {
12+
val timeElapsed = System.currentTimeMillis() - startTime
13+
14+
val speed = (bytesRead.toDouble() / timeElapsed) * 1000 / BYTES_IN_MB
15+
val percent = bytesRead.toDouble() / fileSize * 100
16+
val downloadedMB = bytesRead.toDouble() / BYTES_IN_MB
17+
val totalMB = fileSize.toDouble() / BYTES_IN_MB
18+
val remainingMB = totalMB - downloadedMB
19+
20+
return String.format(
21+
"File %d/%d: %s of %s (%.2f%%), Speed: %.2f MB/sec, Time left: %s",
22+
fileNumber,
23+
fileCount,
24+
convertFileSize(downloadedMB.toLong() * BYTES_IN_MB),
25+
convertFileSize(totalMB.toLong() * BYTES_IN_MB),
26+
percent,
27+
speed,
28+
getTimeLeftFormattedString(speed, remainingMB)
29+
)
30+
}
31+
32+
private fun getTimeLeftFormattedString(speed: Double, remainingMB: Double): String {
33+
val timeLeftSec = if (speed > 0) remainingMB / speed else 0.0
34+
val hours = (timeLeftSec / 3600).toLong()
35+
val minutes = ((timeLeftSec % 3600) / 60).toLong()
36+
val seconds = (timeLeftSec % 60).toLong()
37+
38+
return String.format("%02d:%02d:%02d", hours, minutes, seconds)
39+
}
40+
}

0 commit comments

Comments
 (0)
Please sign in to comment.