Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prompt preview #115

Merged
merged 2 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
package com.github.blarc.ai.commits.intellij.plugin

import com.github.blarc.ai.commits.intellij.plugin.AICommitsBundle.message
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.commonBranch
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.computeDiff
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.constructPrompt
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.isPromptTooLarge
import com.github.blarc.ai.commits.intellij.plugin.notifications.Notification
import com.github.blarc.ai.commits.intellij.plugin.notifications.sendNotification
import com.github.blarc.ai.commits.intellij.plugin.settings.AppSettings
import com.intellij.openapi.actionSystem.AnAction
import com.intellij.openapi.actionSystem.AnActionEvent
import com.intellij.openapi.diff.impl.patch.IdeaTextPatchBuilder
import com.intellij.openapi.diff.impl.patch.UnifiedDiffWriter
import com.intellij.openapi.progress.runBackgroundableTask
import com.intellij.openapi.project.DumbAware
import com.intellij.openapi.project.Project
import com.intellij.openapi.vcs.VcsDataKeys
import com.intellij.openapi.vcs.changes.Change
import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
import com.knuddels.jtokkit.Encodings
import com.knuddels.jtokkit.api.ModelType
import git4idea.repo.GitRepositoryManager
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.runBlocking
import java.io.StringWriter

class AICommitAction : AnAction(), DumbAware {
override fun actionPerformed(e: AnActionEvent) {
Expand All @@ -41,14 +39,8 @@ class AICommitAction : AnAction(), DumbAware {
return@runBackgroundableTask
}

var branch = commonBranch(includedChanges, project)
if (branch == null) {
sendNotification(Notification.noCommonBranch())
// hardcoded fallback branch
branch = "main"
}

val prompt = AppSettings.instance.getPrompt(diff, branch)
val branch = commonBranch(includedChanges, project)
val prompt = constructPrompt(AppSettings.instance.currentPrompt.content, diff, branch)
if (isPromptTooLarge(prompt)) {
sendNotification(Notification.promptTooLarge())
return@runBackgroundableTask
Expand All @@ -72,70 +64,4 @@ class AICommitAction : AnAction(), DumbAware {
}
}
}

private fun computeDiff(
includedChanges: List<Change>,
project: Project
): String {

val gitRepositoryManager = GitRepositoryManager.getInstance(project)

// go through included changes, create a map of repository to changes and discard nulls
val changesByRepository = includedChanges
.filter {
it.virtualFile?.path?.let { path ->
AICommitsUtils.isPathExcluded(path, project)
} ?: false
}
.mapNotNull { change ->
change.virtualFile?.let { file ->
gitRepositoryManager.getRepositoryForFileQuick(
file
) to change
}
}
.groupBy({ it.first }, { it.second })


// compute diff for each repository
return changesByRepository
.map { (repository, changes) ->
repository?.let {
val filePatches = IdeaTextPatchBuilder.buildPatch(
project,
changes,
repository.root.toNioPath(), false, true
)

val stringWriter = StringWriter()
stringWriter.write("Repository: ${repository.root.path}\n")
UnifiedDiffWriter.write(project, filePatches, stringWriter, "\n", null)
stringWriter.toString()
}
}
.joinToString("\n")
}

private fun isPromptTooLarge(prompt: String): Boolean {
val registry = Encodings.newDefaultEncodingRegistry()

/*
* Try to find the model type based on the model id by finding the longest matching model type
* If no model type matches, let the request go through and let the OpenAI API handle it
*/
val modelType = ModelType.values()
.filter { AppSettings.instance.openAIModelId.contains(it.name) }
.maxByOrNull { it.name.length }
?: return false

val encoding = registry.getEncoding(modelType.encodingType)
return encoding.countTokens(prompt) > modelType.maxContextLength
}

private fun commonBranch(changes: List<Change>, project: Project): String? {
val repositoryManager = GitRepositoryManager.getInstance(project)
return changes.map {
repositoryManager.getRepositoryForFileQuick(it.virtualFile)?.currentBranchName
}.groupingBy { it }.eachCount().maxByOrNull { it.value }?.key
}
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
package com.github.blarc.ai.commits.intellij.plugin

import com.github.blarc.ai.commits.intellij.plugin.notifications.Notification
import com.github.blarc.ai.commits.intellij.plugin.notifications.sendNotification
import com.github.blarc.ai.commits.intellij.plugin.settings.AppSettings
import com.github.blarc.ai.commits.intellij.plugin.settings.ProjectSettings
import com.intellij.openapi.components.service
import com.intellij.openapi.diff.impl.patch.IdeaTextPatchBuilder
import com.intellij.openapi.diff.impl.patch.UnifiedDiffWriter
import com.intellij.openapi.project.Project
import com.intellij.openapi.vcs.changes.Change
import com.knuddels.jtokkit.Encodings
import com.knuddels.jtokkit.api.ModelType
import git4idea.repo.GitRepositoryManager
import java.io.StringWriter
import java.nio.file.FileSystems

object AICommitsUtils {

fun isPathExcluded(path: String, project: Project) : Boolean {
return !AppSettings.instance.isPathExcluded(path) && !project.service<ProjectSettings>().isPathExcluded(path)
}

fun matchesGlobs(text: String, globs: Set<String>): Boolean {
val fileSystem = FileSystems.getDefault()
for (globString in globs) {
Expand All @@ -21,4 +31,89 @@ object AICommitsUtils {
}
return false
}
}

fun constructPrompt(promptContent: String, diff: String, branch: String): String {
var content = promptContent
content = content.replace("{locale}", AppSettings.instance.locale.displayLanguage)
content = content.replace("{branch}", branch)

return if (content.contains("{diff}")) {
content.replace("{diff}", diff)
} else {
"$content\n$diff"
}
}

fun commonBranch(changes: List<Change>, project: Project): String {
val repositoryManager = GitRepositoryManager.getInstance(project)
var branch = changes.map {
repositoryManager.getRepositoryForFileQuick(it.virtualFile)?.currentBranchName
}.groupingBy { it }.eachCount().maxByOrNull { it.value }?.key

if (branch == null) {
sendNotification(Notification.noCommonBranch())
// hardcoded fallback branch
branch = "main"
}
return branch
}

fun computeDiff(
includedChanges: List<Change>,
project: Project
): String {

val gitRepositoryManager = GitRepositoryManager.getInstance(project)

// go through included changes, create a map of repository to changes and discard nulls
val changesByRepository = includedChanges
.filter {
it.virtualFile?.path?.let { path ->
AICommitsUtils.isPathExcluded(path, project)
} ?: false
}
.mapNotNull { change ->
change.virtualFile?.let { file ->
gitRepositoryManager.getRepositoryForFileQuick(
file
) to change
}
}
.groupBy({ it.first }, { it.second })


// compute diff for each repository
return changesByRepository
.map { (repository, changes) ->
repository?.let {
val filePatches = IdeaTextPatchBuilder.buildPatch(
project,
changes,
repository.root.toNioPath(), false, true
)

val stringWriter = StringWriter()
stringWriter.write("Repository: ${repository.root.path}\n")
UnifiedDiffWriter.write(project, filePatches, stringWriter, "\n", null)
stringWriter.toString()
}
}
.joinToString("\n")
}

fun isPromptTooLarge(prompt: String): Boolean {
val registry = Encodings.newDefaultEncodingRegistry()

/*
* Try to find the model type based on the model id by finding the longest matching model type
* If no model type matches, let the request go through and let the OpenAI API handle it
*/
val modelType = ModelType.entries
.filter { AppSettings.instance.openAIModelId.contains(it.name) }
.maxByOrNull { it.name.length }
?: return false

val encoding = registry.getEncoding(modelType.encodingType)
return encoding.countTokens(prompt) > modelType.maxContextLength
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.components.Service


@Service
@Service(Service.Level.APP)
class OpenAIService {

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,6 @@ class AppSettings : PersistentStateComponent<AppSettings> {
get() = ApplicationManager.getApplication().getService(AppSettings::class.java)
}

fun getPrompt(diff: String, branch: String): String {
var content = currentPrompt.content
content = content.replace("{locale}", locale.displayLanguage)
content = content.replace("{branch}", branch)

return if (content.contains("{diff}")) {
content.replace("{diff}", diff)
} else {
"$content\n$diff"
}
}

fun saveOpenAIToken(token: String) {
try {
PasswordSafe.instance.setPassword(getCredentialAttributes(openAITokenTitle), token)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.github.blarc.ai.commits.intellij.plugin.settings

import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils
import com.intellij.openapi.components.PersistentStateComponent
import com.intellij.openapi.components.Service
import com.intellij.openapi.components.State
import com.intellij.openapi.components.Storage
import com.intellij.util.xmlb.XmlSerializerUtil
Expand All @@ -10,6 +11,7 @@ import com.intellij.util.xmlb.XmlSerializerUtil
name = ProjectSettings.SERVICE_NAME,
storages = [Storage("AICommit.xml")]
)
@Service(Service.Level.PROJECT)
class ProjectSettings : PersistentStateComponent<ProjectSettings?> {

companion object {
Expand All @@ -29,4 +31,4 @@ class ProjectSettings : PersistentStateComponent<ProjectSettings?> {
}


}
}
Loading