diff --git a/android/MLCChat/app/src/main/AndroidManifest.xml b/android/MLCChat/app/src/main/AndroidManifest.xml
index 36c6be8834..9a14f38ff7 100644
--- a/android/MLCChat/app/src/main/AndroidManifest.xml
+++ b/android/MLCChat/app/src/main/AndroidManifest.xml
@@ -4,6 +4,7 @@
package="ai.mlc.mlcchat">
+
().toMutableStateList()
@@ -511,7 +519,9 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
private var modelPath = ""
private val executorService = Executors.newSingleThreadExecutor()
private val viewModelScope = CoroutineScope(Dispatchers.Main + Job())
+ private var imageUri: Uri? = null
private fun mainResetChat() {
+ imageUri = null
executorService.submit {
callBackend { engine.reset() }
historyMessages = mutableListOf()
@@ -660,16 +670,58 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
}
}
- fun requestGenerate(prompt: String) {
+        /**
+         * Registers [uri] as the image to attach to the next prompt sent via [requestGenerate].
+         *
+         * The chat must be chatable (otherwise [require] throws). The state is switched to
+         * generating while the URI is stored on the executor thread; afterwards, on the
+         * view-model scope, the report text is updated and the state goes back to ready
+         * if it is still generating. No bitmap is decoded here, only the URI is kept.
+         *
+         * @param uri location of the selected image, or null to clear the pending image.
+         */
+        fun requestImageBitmap(uri: Uri?) {
+            require(chatable())
+            switchToGenerating()
+            executorService.submit {
+                imageUri = uri
+                viewModelScope.launch {
+                    report.value = "Image process is done, ask any question."
+                    val stillGenerating = modelChatState.value == ModelChatState.Generating
+                    if (stillGenerating) {
+                        switchToReady()
+                    }
+                }
+            }
+        }
+
+        /**
+         * Encodes [bm] as a base64 `data:` URL containing a 336x336 JPEG.
+         *
+         * The bitmap is scaled to a square with filtering (the aspect ratio is not preserved),
+         * compressed as JPEG at quality 100 and base64-encoded without line wraps.
+         *
+         * @param bm source bitmap; it is left untouched and is never recycled here.
+         * @return a string of the form `data:image/jpg;base64,<payload>`.
+         */
+        fun bitmapToURL(bm: Bitmap): String {
+            val targetSize = 336
+            val scaledBitmap = Bitmap.createScaledBitmap(bm, targetSize, targetSize, true)
+
+            val outputStream = ByteArrayOutputStream()
+            scaledBitmap.compress(Bitmap.CompressFormat.JPEG, 100, outputStream)
+            // createScaledBitmap hands back the source itself when it already has the
+            // requested size; recycle only a distinct intermediate copy so the caller's
+            // bitmap stays usable.
+            if (scaledBitmap !== bm) {
+                scaledBitmap.recycle()
+            }
+
+            val imageBase64 = Base64.encodeToString(outputStream.toByteArray(), Base64.NO_WRAP)
+            return "data:image/jpg;base64,$imageBase64"
+        }
+
+ fun requestGenerate(prompt: String, activity: Activity) {
require(chatable())
switchToGenerating()
appendMessage(MessageRole.User, prompt)
appendMessage(MessageRole.Assistant, "")
+ var content = ChatCompletionMessageContent(text=prompt)
+ if (imageUri != null) {
+ val uri = imageUri
+ val bitmap = uri?.let {
+ activity.contentResolver.openInputStream(it)?.use { input ->
+ BitmapFactory.decodeStream(input)
+ }
+ }
+ val imageBase64URL = bitmapToURL(bitmap!!)
+ Log.v("requestGenerate", "image base64 url: $imageBase64URL")
+ val parts = listOf(
+ mapOf("type" to "text", "text" to prompt),
+ mapOf("type" to "image_url", "image_url" to imageBase64URL)
+ )
+ content = ChatCompletionMessageContent(parts=parts)
+ imageUri = null
+ }
executorService.submit {
historyMessages.add(ChatCompletionMessage(
role = OpenAIProtocol.ChatCompletionRole.user,
- content = prompt
+ content = content
))
viewModelScope.launch {
@@ -768,7 +820,7 @@ enum class MessageRole {
data class DownloadTask(val url: URL, val file: File)
-data class MessageData(val role: MessageRole, val text: String, val id: UUID = UUID.randomUUID())
+data class MessageData(val role: MessageRole, val text: String, val id: UUID = UUID.randomUUID(), var imageUri: Uri? = null) // imageUri: optional image attached to the message; null by default
data class AppConfig(
@SerializedName("model_libs") var modelLibs: MutableList,
diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt
index 7024c1e8de..b07e521d6e 100644
--- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt
+++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt
@@ -1,5 +1,9 @@
package ai.mlc.mlcchat
+import android.app.Activity
+import android.graphics.Bitmap
+import android.graphics.BitmapFactory
+import androidx.compose.foundation.Image
import androidx.compose.foundation.background
import androidx.compose.foundation.gestures.detectTapGestures
import androidx.compose.foundation.layout.Arrangement
@@ -20,7 +24,9 @@ import androidx.compose.foundation.lazy.rememberLazyListState
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.text.selection.SelectionContainer
import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.filled.AddAPhoto
import androidx.compose.material.icons.filled.ArrowBack
+import androidx.compose.material.icons.filled.Photo
import androidx.compose.material.icons.filled.Replay
import androidx.compose.material.icons.filled.Send
import androidx.compose.material3.Divider
@@ -43,6 +49,7 @@ import androidx.compose.runtime.saveable.rememberSaveable
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
+import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.input.pointer.pointerInput
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.text.style.TextAlign
@@ -55,9 +62,10 @@ import kotlinx.coroutines.launch
@ExperimentalMaterial3Api
@Composable
fun ChatView(
- navController: NavController, chatState: AppViewModel.ChatState
+ navController: NavController, chatState: AppViewModel.ChatState, activity: Activity
) {
val localFocusManager = LocalFocusManager.current
+ (activity as MainActivity).chatState = chatState
Scaffold(topBar = {
TopAppBar(
title = {
@@ -81,7 +89,9 @@ fun ChatView(
},
actions = {
IconButton(
- onClick = { chatState.requestResetChat() },
+ onClick = {
+ chatState.requestResetChat()
+ activity.hasImage = false },
enabled = chatState.interruptable()
) {
Icon(
@@ -125,23 +135,23 @@ fun ChatView(
items = chatState.messages,
key = { message -> message.id },
) { message ->
- MessageView(messageData = message)
+ MessageView(messageData = message, activity)
}
item {
// place holder item for scrolling to the bottom
}
}
Divider(thickness = 1.dp, modifier = Modifier.padding(top = 5.dp))
- SendMessageView(chatState = chatState)
+ SendMessageView(chatState = chatState, activity)
}
}
}
@Composable
-fun MessageView(messageData: MessageData) {
+fun MessageView(messageData: MessageData, activity: Activity?) {
// default render the Assistant text as MarkdownText
var useMarkdown by remember { mutableStateOf(true) }
-
+ var localActivity : MainActivity = activity as MainActivity
SelectionContainer {
if (messageData.role == MessageRole.Assistant) {
Column {
@@ -202,19 +212,47 @@ fun MessageView(messageData: MessageData) {
horizontalArrangement = Arrangement.End,
modifier = Modifier.fillMaxWidth()
) {
- Text(
- text = messageData.text,
- textAlign = TextAlign.Right,
- color = MaterialTheme.colorScheme.onPrimaryContainer,
- modifier = Modifier
- .wrapContentWidth()
- .background(
- color = MaterialTheme.colorScheme.primaryContainer,
- shape = RoundedCornerShape(5.dp)
+ if (messageData.imageUri != null) {
+ val uri = messageData.imageUri
+ val bitmap = uri?.let {
+ activity.contentResolver.openInputStream(it)?.use { input ->
+ BitmapFactory.decodeStream(input)
+ }
+ }
+ val displayBitmap = bitmap?.let { Bitmap.createScaledBitmap(it, 224, 224, true) }
+ if (displayBitmap != null) {
+ Image(
+ displayBitmap.asImageBitmap(),
+ "",
+ modifier = Modifier
+ .wrapContentWidth()
+ .background(
+ color = MaterialTheme.colorScheme.secondaryContainer,
+ shape = RoundedCornerShape(5.dp)
+ )
+ .padding(5.dp)
+ .widthIn(max = 300.dp)
)
- .padding(5.dp)
- .widthIn(max = 300.dp)
- )
+ }
+ if (!localActivity.hasImage) {
+ localActivity.chatState.requestImageBitmap(messageData.imageUri)
+ }
+ localActivity.hasImage = true
+ } else {
+ Text(
+ text = messageData.text,
+ textAlign = TextAlign.Right,
+ color = MaterialTheme.colorScheme.onPrimaryContainer,
+ modifier = Modifier
+ .wrapContentWidth()
+ .background(
+ color = MaterialTheme.colorScheme.primaryContainer,
+ shape = RoundedCornerShape(5.dp)
+ )
+ .padding(5.dp)
+ .widthIn(max = 300.dp)
+ )
+ }
}
}
@@ -223,8 +261,9 @@ fun MessageView(messageData: MessageData) {
@ExperimentalMaterial3Api
@Composable
-fun SendMessageView(chatState: AppViewModel.ChatState) {
+fun SendMessageView(chatState: AppViewModel.ChatState, activity: Activity) {
val localFocusManager = LocalFocusManager.current
+ val localActivity : MainActivity = activity as MainActivity
Row(
horizontalArrangement = Arrangement.spacedBy(5.dp),
verticalAlignment = Alignment.CenterVertically,
@@ -241,10 +280,38 @@ fun SendMessageView(chatState: AppViewModel.ChatState) {
modifier = Modifier
.weight(9f),
)
+ IconButton(
+ onClick = {
+ activity.takePhoto()
+ },
+ modifier = Modifier
+ .aspectRatio(1f)
+ .weight(1f),
+ enabled = (chatState.chatable() && !localActivity.hasImage)
+ ) {
+ Icon(
+ imageVector = Icons.Filled.AddAPhoto,
+ contentDescription = "use camera",
+ )
+ }
+ IconButton(
+ onClick = {
+ activity.pickImageFromGallery()
+ },
+ modifier = Modifier
+ .aspectRatio(1f)
+ .weight(1f),
+ enabled = (chatState.chatable() && !localActivity.hasImage)
+ ) {
+ Icon(
+ imageVector = Icons.Filled.Photo,
+ contentDescription = "select image",
+ )
+ }
IconButton(
onClick = {
localFocusManager.clearFocus()
- chatState.requestGenerate(text)
+ chatState.requestGenerate(text, activity)
text = ""
},
modifier = Modifier
@@ -271,6 +338,6 @@ fun MessageViewPreviewWithMarkdown() {
* [Link](https://example.com)
Google
"""
- )
+ ), null
)
}
diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt
index f955f508aa..b50bd7b56c 100644
--- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt
+++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt
@@ -1,29 +1,157 @@
package ai.mlc.mlcchat
-import ai.mlc.mlcchat.ui.theme.MLCChatTheme
+import android.Manifest
+import android.content.ContentValues
+import android.content.pm.PackageManager
+import android.net.Uri
+import android.os.Build
import android.os.Bundle
+import android.provider.MediaStore
+import android.util.Log
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
+import androidx.activity.result.contract.ActivityResultContracts
+import androidx.annotation.RequiresApi
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Surface
import androidx.compose.ui.Modifier
-
+import androidx.core.content.ContextCompat
+import ai.mlc.mlcchat.ui.theme.MLCChatTheme
+import java.text.SimpleDateFormat
+import java.util.Date
+import java.util.Locale
+import java.util.UUID
class MainActivity : ComponentActivity() {
+ var hasImage = false
+
+    // Gallery picker: when the user selects an image, append it to the chat as an
+    // empty-text user message carrying the image URI. A null result is ignored.
+    private val pickImageLauncher = registerForActivityResult(
+        ActivityResultContracts.GetContent()
+    ) { uri: Uri? ->
+        if (uri != null) {
+            Log.v("pickImageLauncher", "Selected image uri: $uri")
+            val imageMessage = MessageData(
+                role = MessageRole.User,
+                text = "",
+                id = UUID.randomUUID(),
+                imageUri = uri
+            )
+            chatState.messages.add(imageMessage)
+        }
+    }
+ private var cameraImageUri: Uri? = null
+    // Camera result handler for the URI prepared by takePhoto(). On success the captured
+    // image is appended to the chat as an empty-text user message; otherwise the MediaStore
+    // entry created up front is removed so cancelled captures leave nothing behind.
+    private val takePictureLauncher = registerForActivityResult(
+        ActivityResultContracts.TakePicture()
+    ) { success: Boolean ->
+        // Snapshot the mutable property so the null check and the uses see the same value.
+        val capturedUri = cameraImageUri
+        if (capturedUri != null) {
+            if (success) {
+                Log.v("takePictureLauncher", "Camera image uri: $capturedUri")
+                chatState.messages.add(
+                    MessageData(
+                        role = MessageRole.User,
+                        text = "",
+                        id = UUID.randomUUID(),
+                        imageUri = capturedUri
+                    )
+                )
+            } else {
+                // Nothing was written to the pre-inserted row; delete it.
+                contentResolver.delete(capturedUri, null, null)
+            }
+        }
+    }
+
+    // Permission request callback: only logs the outcome reported for each permission.
+    private val requestPermissionLauncher =
+        registerForActivityResult(ActivityResultContracts.RequestMultiplePermissions()) { results ->
+            for ((permission, granted) in results) {
+                Log.d("Permissions", "$permission = $granted")
+            }
+        }
+
+ lateinit var chatState: AppViewModel.ChatState
+
+ @RequiresApi(Build.VERSION_CODES.TIRAMISU) // NOTE(review): requestNeededPermissions() has a pre-TIRAMISU branch, so confirm this API-level requirement is intended
@ExperimentalMaterial3Api
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
+
+ chatState = AppViewModel(this.application).ChatState() // NOTE(review): builds a separate AppViewModel; ChatView later overwrites chatState with the instance it receives
+ requestNeededPermissions() // request any missing media/camera permissions before the content is set
+
setContent {
Surface(
- modifier = Modifier
- .fillMaxSize()
+ modifier = Modifier.fillMaxSize()
) {
MLCChatTheme {
- NavView()
+ NavView(this) // hand this activity to the composable tree
}
}
}
}
+
+    /**
+     * Requests every permission needed for picking and capturing images that has not
+     * been granted yet.
+     *
+     * On TIRAMISU and newer the set is READ_MEDIA_IMAGES and CAMERA; on older releases it is
+     * READ_EXTERNAL_STORAGE, WRITE_EXTERNAL_STORAGE and CAMERA. Nothing is launched when
+     * all of them are already granted.
+     */
+    private fun requestNeededPermissions() {
+        val requiredPermissions: List<String> =
+            if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
+                listOf(
+                    Manifest.permission.READ_MEDIA_IMAGES,
+                    Manifest.permission.CAMERA
+                )
+            } else {
+                listOf(
+                    Manifest.permission.READ_EXTERNAL_STORAGE,
+                    Manifest.permission.WRITE_EXTERNAL_STORAGE,
+                    Manifest.permission.CAMERA
+                )
+            }
+
+        // Keep only the permissions that still have to be granted, in declaration order.
+        val permissionsToRequest = requiredPermissions.filter { permission ->
+            ContextCompat.checkSelfPermission(this, permission) != PackageManager.PERMISSION_GRANTED
+        }
+
+        if (permissionsToRequest.isNotEmpty()) {
+            requestPermissionLauncher.launch(permissionsToRequest.toTypedArray())
+        }
+    }
+
+    /** Opens the content picker filtered to `image/*`; the selection is handled by [pickImageLauncher]. */
+    fun pickImageFromGallery() = pickImageLauncher.launch("image/*")
+
+    /**
+     * Launches the camera to capture a photo into a newly inserted MediaStore image entry.
+     *
+     * The entry is named `IMG_<yyyyMMdd_HHmmss>.jpg` with MIME type `image/jpeg`. Its URI is
+     * stored in [cameraImageUri] so the result callback can attach the picture to the chat.
+     * If the entry cannot be created, the failure is logged and the camera is not launched.
+     */
+    fun takePhoto() {
+        val timeFormatter = SimpleDateFormat("yyyyMMdd_HHmmss", Locale.getDefault())
+        val contentValues = ContentValues().apply {
+            put(MediaStore.Images.Media.DISPLAY_NAME, "IMG_${timeFormatter.format(Date())}.jpg")
+            put(MediaStore.Images.Media.MIME_TYPE, "image/jpeg")
+            // Milliseconds are converted to seconds for DATE_ADDED.
+            put(MediaStore.Images.Media.DATE_ADDED, System.currentTimeMillis() / 1000)
+        }
+
+        val outputUri = contentResolver.insert(
+            MediaStore.Images.Media.EXTERNAL_CONTENT_URI,
+            contentValues
+        )
+        cameraImageUri = outputUri
+        if (outputUri == null) {
+            // insert() can return null; the camera needs a valid output target.
+            Log.e("takePhoto", "Failed to create a MediaStore entry for the camera image")
+            return
+        }
+
+        takePictureLauncher.launch(outputUri)
+    }
}
diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/NavView.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/NavView.kt
index 90ae306a7d..008187cc24 100644
--- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/NavView.kt
+++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/NavView.kt
@@ -1,5 +1,6 @@
package ai.mlc.mlcchat
+import android.app.Activity
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.runtime.Composable
import androidx.lifecycle.viewmodel.compose.viewModel
@@ -9,10 +10,10 @@ import androidx.navigation.compose.rememberNavController
@ExperimentalMaterial3Api
@Composable
-fun NavView(appViewModel: AppViewModel = viewModel()) {
+fun NavView(activity: Activity, appViewModel: AppViewModel = viewModel()) { // activity is forwarded to the chat screen
val navController = rememberNavController()
NavHost(navController = navController, startDestination = "home") { // two routes: "home" (start) and "chat"
composable("home") { StartView(navController, appViewModel) }
- composable("chat") { ChatView(navController, appViewModel.chatState) }
+ composable("chat") { ChatView(navController, appViewModel.chatState, activity) } // chat screen shares the view model's chatState
}
}
diff --git a/cpp/json_ffi/conv_template.cc b/cpp/json_ffi/conv_template.cc
index 1114276af5..4395547398 100644
--- a/cpp/json_ffi/conv_template.cc
+++ b/cpp/json_ffi/conv_template.cc
@@ -326,7 +326,12 @@ Result> CreatePrompt(const Conversation& conv,
int embed_size = (image_size * image_size) / (patch_size * patch_size);
- auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device);
+ NDArray image_data = image_data_res.Unwrap();
+ std::vector new_shape = {1, image_size, image_size, 3};
+ NDArray image_ndarray = image_data.CreateView(new_shape, image_data.DataType());
+ // TODO: Confirm that skipping ClipPreprocessor here does not affect other callers;
+ // the Python side is expected to perform the CLIP preprocessing instead. Disabled call:
+ // auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device);
// lazily commit text data
if (pending_text.length() != 0) {
message_list.push_back(TextData(pending_text));