diff --git a/android/MLCChat/app/src/main/AndroidManifest.xml b/android/MLCChat/app/src/main/AndroidManifest.xml index 36c6be8834..9a14f38ff7 100644 --- a/android/MLCChat/app/src/main/AndroidManifest.xml +++ b/android/MLCChat/app/src/main/AndroidManifest.xml @@ -4,6 +4,7 @@ package="ai.mlc.mlcchat"> + ().toMutableStateList() @@ -511,7 +519,9 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { private var modelPath = "" private val executorService = Executors.newSingleThreadExecutor() private val viewModelScope = CoroutineScope(Dispatchers.Main + Job()) + private var imageUri: Uri? = null private fun mainResetChat() { + imageUri = null executorService.submit { callBackend { engine.reset() } historyMessages = mutableListOf() @@ -660,16 +670,58 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { } } - fun requestGenerate(prompt: String) { + fun requestImageBitmap(uri: Uri?) { + require(chatable()) + switchToGenerating() + executorService.submit { + imageUri = uri + viewModelScope.launch { + report.value = "Image process is done, ask any question." + if (modelChatState.value == ModelChatState.Generating) switchToReady() + } + } + } + + fun bitmapToURL(bm: Bitmap): String { + val targetSize = 336 + val scaledBitmap = Bitmap.createScaledBitmap(bm, targetSize, targetSize, true) + + val outputStream = ByteArrayOutputStream() + scaledBitmap.compress(Bitmap.CompressFormat.JPEG, 100, outputStream) + scaledBitmap.recycle() + + val imageBytes = outputStream.toByteArray() + val imageBase64 = Base64.encodeToString(imageBytes, Base64.NO_WRAP) + return "data:image/jpg;base64,$imageBase64" + } + + fun requestGenerate(prompt: String, activity: Activity) { require(chatable()) switchToGenerating() appendMessage(MessageRole.User, prompt) appendMessage(MessageRole.Assistant, "") + var content = ChatCompletionMessageContent(text=prompt) + if (imageUri != null) { + val uri = imageUri + val bitmap = uri?.let { + activity.contentResolver.openInputStream(it)?.use { input -> + BitmapFactory.decodeStream(input) + } + } + val imageBase64URL = bitmapToURL(bitmap!!) + Log.v("requestGenerate", "image base64 url: $imageBase64URL") + val parts = listOf( + mapOf("type" to "text", "text" to prompt), + mapOf("type" to "image_url", "image_url" to imageBase64URL) + ) + content = ChatCompletionMessageContent(parts=parts) + imageUri = null + } executorService.submit { historyMessages.add(ChatCompletionMessage( role = OpenAIProtocol.ChatCompletionRole.user, - content = prompt + content = content )) viewModelScope.launch { @@ -768,7 +820,7 @@ enum class MessageRole { data class DownloadTask(val url: URL, val file: File) -data class MessageData(val role: MessageRole, val text: String, val id: UUID = UUID.randomUUID()) +data class MessageData(val role: MessageRole, val text: String, val id: UUID = UUID.randomUUID(), var imageUri: Uri? = null) data class AppConfig( @SerializedName("model_libs") var modelLibs: MutableList, diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt index 7024c1e8de..b07e521d6e 100644 --- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt +++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt @@ -1,5 +1,9 @@ package ai.mlc.mlcchat +import android.app.Activity +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import androidx.compose.foundation.Image import androidx.compose.foundation.background import androidx.compose.foundation.gestures.detectTapGestures import androidx.compose.foundation.layout.Arrangement @@ -20,7 +24,9 @@ import androidx.compose.foundation.lazy.rememberLazyListState import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.text.selection.SelectionContainer import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.AddAPhoto import androidx.compose.material.icons.filled.ArrowBack +import androidx.compose.material.icons.filled.Photo import androidx.compose.material.icons.filled.Replay import androidx.compose.material.icons.filled.Send import androidx.compose.material3.Divider @@ -43,6 +49,7 @@ import androidx.compose.runtime.saveable.rememberSaveable import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.graphics.asImageBitmap import androidx.compose.ui.input.pointer.pointerInput import androidx.compose.ui.platform.LocalFocusManager import androidx.compose.ui.text.style.TextAlign @@ -55,9 +62,10 @@ import kotlinx.coroutines.launch @ExperimentalMaterial3Api @Composable fun ChatView( - navController: NavController, chatState: AppViewModel.ChatState + navController: NavController, chatState: AppViewModel.ChatState, activity: Activity ) { val localFocusManager = LocalFocusManager.current + (activity as MainActivity).chatState = chatState Scaffold(topBar = { TopAppBar( title = { @@ -81,7 +89,9 @@ fun ChatView( }, actions = { IconButton( - onClick = { chatState.requestResetChat() }, + onClick = { + chatState.requestResetChat() + activity.hasImage = false }, enabled = chatState.interruptable() ) { Icon( @@ -125,23 +135,23 @@ fun ChatView( items = chatState.messages, key = { message -> message.id }, ) { message -> - MessageView(messageData = message) + MessageView(messageData = message, activity) } item { // place holder item for scrolling to the bottom } } Divider(thickness = 1.dp, modifier = Modifier.padding(top = 5.dp)) - SendMessageView(chatState = chatState) + SendMessageView(chatState = chatState, activity) } } } @Composable -fun MessageView(messageData: MessageData) { +fun MessageView(messageData: MessageData, activity: Activity?) { // default render the Assistant text as MarkdownText var useMarkdown by remember { mutableStateOf(true) } - + var localActivity : MainActivity = activity as MainActivity SelectionContainer { if (messageData.role == MessageRole.Assistant) { Column { @@ -202,19 +212,47 @@ fun MessageView(messageData: MessageData) { horizontalArrangement = Arrangement.End, modifier = Modifier.fillMaxWidth() ) { - Text( - text = messageData.text, - textAlign = TextAlign.Right, - color = MaterialTheme.colorScheme.onPrimaryContainer, - modifier = Modifier - .wrapContentWidth() - .background( - color = MaterialTheme.colorScheme.primaryContainer, - shape = RoundedCornerShape(5.dp) + if (messageData.imageUri != null) { + val uri = messageData.imageUri + val bitmap = uri?.let { + activity.contentResolver.openInputStream(it)?.use { input -> + BitmapFactory.decodeStream(input) + } + } + val displayBitmap = bitmap?.let { Bitmap.createScaledBitmap(it, 224, 224, true) } + if (displayBitmap != null) { + Image( + displayBitmap.asImageBitmap(), + "", + modifier = Modifier + .wrapContentWidth() + .background( + color = MaterialTheme.colorScheme.secondaryContainer, + shape = RoundedCornerShape(5.dp) + ) + .padding(5.dp) + .widthIn(max = 300.dp) ) - .padding(5.dp) - .widthIn(max = 300.dp) - ) + } + if (!localActivity.hasImage) { + localActivity.chatState.requestImageBitmap(messageData.imageUri) + } + localActivity.hasImage = true + } else { + Text( + text = messageData.text, + textAlign = TextAlign.Right, + color = MaterialTheme.colorScheme.onPrimaryContainer, + modifier = Modifier + .wrapContentWidth() + .background( + color = MaterialTheme.colorScheme.primaryContainer, + shape = RoundedCornerShape(5.dp) + ) + .padding(5.dp) + .widthIn(max = 300.dp) + ) + } } } @@ -223,8 +261,9 @@ fun MessageView(messageData: MessageData) { @ExperimentalMaterial3Api @Composable -fun SendMessageView(chatState: AppViewModel.ChatState) { +fun SendMessageView(chatState: AppViewModel.ChatState, activity: Activity) { val localFocusManager = LocalFocusManager.current + val localActivity : MainActivity = activity as MainActivity Row( horizontalArrangement = Arrangement.spacedBy(5.dp), verticalAlignment = Alignment.CenterVertically, @@ -241,10 +280,38 @@ fun SendMessageView(chatState: AppViewModel.ChatState) { modifier = Modifier .weight(9f), ) + IconButton( + onClick = { + activity.takePhoto() + }, + modifier = Modifier + .aspectRatio(1f) + .weight(1f), + enabled = (chatState.chatable() && !localActivity.hasImage) + ) { + Icon( + imageVector = Icons.Filled.AddAPhoto, + contentDescription = "use camera", + ) + } + IconButton( + onClick = { + activity.pickImageFromGallery() + }, + modifier = Modifier + .aspectRatio(1f) + .weight(1f), + enabled = (chatState.chatable() && !localActivity.hasImage) + ) { + Icon( + imageVector = Icons.Filled.Photo, + contentDescription = "select image", + ) + } IconButton( onClick = { localFocusManager.clearFocus() - chatState.requestGenerate(text) + chatState.requestGenerate(text, activity) text = "" }, modifier = Modifier @@ -271,6 +338,6 @@ fun MessageViewPreviewWithMarkdown() { * [Link](https://example.com) Google """ - ) + ), null ) } diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt index f955f508aa..b50bd7b56c 100644 --- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt +++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt @@ -1,29 +1,157 @@ package ai.mlc.mlcchat -import ai.mlc.mlcchat.ui.theme.MLCChatTheme +import android.Manifest +import android.content.ContentValues +import android.content.pm.PackageManager +import android.net.Uri +import android.os.Build import android.os.Bundle +import android.provider.MediaStore +import android.util.Log import androidx.activity.ComponentActivity import androidx.activity.compose.setContent +import androidx.activity.result.contract.ActivityResultContracts +import androidx.annotation.RequiresApi import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Surface import androidx.compose.ui.Modifier - +import androidx.core.content.ContextCompat +import ai.mlc.mlcchat.ui.theme.MLCChatTheme +import java.text.SimpleDateFormat +import java.util.Date +import java.util.Locale +import java.util.UUID class MainActivity : ComponentActivity() { + var hasImage = false + + private val pickImageLauncher = registerForActivityResult( + ActivityResultContracts.GetContent() + ) { uri: Uri? -> + uri?.let { + Log.v("pickImageLauncher", "Selected image uri: $it") + chatState.messages.add( + MessageData( + role = MessageRole.User, + text = "", + id = UUID.randomUUID(), + imageUri = it + ) + ) + } + } + private var cameraImageUri: Uri? = null + private val takePictureLauncher = registerForActivityResult( + ActivityResultContracts.TakePicture() + ) { success: Boolean -> + if (success && cameraImageUri != null) { + Log.v("takePictureLauncher", "Camera image uri: $cameraImageUri") + chatState.messages.add( + MessageData( + role = MessageRole.User, + text = "", + id = UUID.randomUUID(), + imageUri = cameraImageUri + ) + ) + } + } + + private val requestPermissionLauncher = + registerForActivityResult(ActivityResultContracts.RequestMultiplePermissions()) { permissions -> + permissions.entries.forEach { + Log.d("Permissions", "${it.key} = ${it.value}") + } + } + + lateinit var chatState: AppViewModel.ChatState + + @RequiresApi(Build.VERSION_CODES.TIRAMISU) @ExperimentalMaterial3Api override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) + + chatState = AppViewModel(this.application).ChatState() + requestNeededPermissions() + setContent { Surface( - modifier = Modifier - .fillMaxSize() + modifier = Modifier.fillMaxSize() ) { MLCChatTheme { - NavView() + NavView(this) } } } } + + private fun requestNeededPermissions() { + val permissionsToRequest = mutableListOf() + + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { + if (ContextCompat.checkSelfPermission( + this, + Manifest.permission.READ_MEDIA_IMAGES + ) != PackageManager.PERMISSION_GRANTED + ) { + permissionsToRequest.add(Manifest.permission.READ_MEDIA_IMAGES) + } + if (ContextCompat.checkSelfPermission( + this, + Manifest.permission.CAMERA + ) != PackageManager.PERMISSION_GRANTED + ) { + permissionsToRequest.add(Manifest.permission.CAMERA) + } + } else { + if (ContextCompat.checkSelfPermission( + this, + Manifest.permission.READ_EXTERNAL_STORAGE + ) != PackageManager.PERMISSION_GRANTED + ) { + permissionsToRequest.add(Manifest.permission.READ_EXTERNAL_STORAGE) + } + if (ContextCompat.checkSelfPermission( + this, + Manifest.permission.WRITE_EXTERNAL_STORAGE + ) != PackageManager.PERMISSION_GRANTED + ) { + permissionsToRequest.add(Manifest.permission.WRITE_EXTERNAL_STORAGE) + } + if (ContextCompat.checkSelfPermission( + this, + Manifest.permission.CAMERA + ) != PackageManager.PERMISSION_GRANTED + ) { + permissionsToRequest.add(Manifest.permission.CAMERA) + } + } + + if (permissionsToRequest.isNotEmpty()) { + requestPermissionLauncher.launch(permissionsToRequest.toTypedArray()) + } + } + + fun pickImageFromGallery() { + pickImageLauncher.launch("image/*") + } + + fun takePhoto() { + val contentValues = ContentValues().apply { + val timeFormatter = SimpleDateFormat("yyyyMMdd_HHmmss", Locale.getDefault()) + val fileName = "IMG_${timeFormatter.format(Date())}.jpg" + put(MediaStore.Images.Media.DISPLAY_NAME, fileName) + put(MediaStore.Images.Media.MIME_TYPE, "image/jpeg") + put(MediaStore.Images.Media.DATE_ADDED, System.currentTimeMillis() / 1000) + } + + cameraImageUri = contentResolver.insert( + MediaStore.Images.Media.EXTERNAL_CONTENT_URI, + contentValues + ) + + takePictureLauncher.launch(cameraImageUri) + } } diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/NavView.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/NavView.kt index 90ae306a7d..008187cc24 100644 --- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/NavView.kt +++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/NavView.kt @@ -1,5 +1,6 @@ package ai.mlc.mlcchat +import android.app.Activity import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.runtime.Composable import androidx.lifecycle.viewmodel.compose.viewModel @@ -9,10 +10,10 @@ import androidx.navigation.compose.rememberNavController @ExperimentalMaterial3Api @Composable -fun NavView(appViewModel: AppViewModel = viewModel()) { +fun NavView(activity: Activity, appViewModel: AppViewModel = viewModel()) { val navController = rememberNavController() NavHost(navController = navController, startDestination = "home") { composable("home") { StartView(navController, appViewModel) } - composable("chat") { ChatView(navController, appViewModel.chatState) } + composable("chat") { ChatView(navController, appViewModel.chatState, activity) } } } diff --git a/cpp/json_ffi/conv_template.cc b/cpp/json_ffi/conv_template.cc index 1114276af5..4395547398 100644 --- a/cpp/json_ffi/conv_template.cc +++ b/cpp/json_ffi/conv_template.cc @@ -326,7 +326,12 @@ Result> CreatePrompt(const Conversation& conv, int embed_size = (image_size * image_size) / (patch_size * patch_size); - auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device); + NDArray image_data = image_data_res.Unwrap(); + std::vector new_shape = {1, image_size, image_size, 3}; + NDArray image_ndarray = image_data.CreateView(new_shape, image_data.DataType()); + // TODO: Not sure if commenting will affect other functions. But + // python part will do clip preprocessing. auto image_ndarray = + // ClipPreprocessor(image_data_res.Unwrap(), image_size, device); // lazily commit text data if (pending_text.length() != 0) { message_list.push_back(TextData(pending_text));