Compare commits
3 Commits
d6ba875297
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
| 7e001d8769 | |||
| 2f34976f3e | |||
| d1208f7e39 |
@@ -1,12 +1,9 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:tools="http://schemas.android.com/tools">
|
||||
xmlns:tools="http://tools.android.com/tools">
|
||||
|
||||
<uses-permission android:name="android.permission.INTERNET" />
|
||||
<uses-permission android:name="android.permission.POST_NOTIFICATIONS"/>
|
||||
<uses-permission android:name="android.permission.FOREGROUND_SERVICE"/>
|
||||
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC"/>
|
||||
|
||||
|
||||
<application
|
||||
android:name=".ChatApplication"
|
||||
@@ -22,7 +19,6 @@
|
||||
|
||||
<service
|
||||
android:name=".core.service.MessageStreamService"
|
||||
android:foregroundServiceType="dataSync"
|
||||
android:exported="false"/>
|
||||
|
||||
<activity
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package dev.zxq5.chatapp.android
|
||||
|
||||
import android.Manifest
|
||||
import android.os.Build
|
||||
import android.os.Bundle
|
||||
import androidx.activity.ComponentActivity
|
||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||
import androidx.activity.compose.setContent
|
||||
import androidx.activity.enableEdgeToEdge
|
||||
import androidx.activity.result.contract.ActivityResultContracts
|
||||
import androidx.compose.foundation.border
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
@@ -62,14 +66,37 @@ class MainActivity : ComponentActivity() {
|
||||
val currentScreen by chatViewModel.currentScreen.collectAsState()
|
||||
val selectedChannelId by chatViewModel.channelId.collectAsState()
|
||||
|
||||
// Permission request launcher
|
||||
val launcher = rememberLauncherForActivityResult(
|
||||
contract = ActivityResultContracts.RequestPermission(),
|
||||
onResult = { isGranted ->
|
||||
if (isGranted && authState == AuthState.Authenticated) {
|
||||
MessageStreamService.start(this@MainActivity)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
LaunchedEffect(authState) {
|
||||
when (authState) {
|
||||
AuthState.Authenticated -> MessageStreamService.start(this@MainActivity)
|
||||
AuthState.Authenticated -> {
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
|
||||
launcher.launch(Manifest.permission.POST_NOTIFICATIONS)
|
||||
}
|
||||
MessageStreamService.start(this@MainActivity)
|
||||
chatViewModel.loadAccessibleChannels()
|
||||
}
|
||||
AuthState.Unauthenticated -> MessageStreamService.stop(this@MainActivity)
|
||||
AuthState.AwaitingTotp -> {}
|
||||
}
|
||||
}
|
||||
|
||||
LaunchedEffect(Unit) {
|
||||
chatViewModel.onUnauthorized = {
|
||||
authViewModel.logout()
|
||||
chatViewModel.clearChat()
|
||||
}
|
||||
}
|
||||
|
||||
LaunchedEffect(Unit) {
|
||||
intent.getIntExtra("channel_id", -1).takeIf { it != -1 }?.let {
|
||||
chatViewModel.switchChannel(it.toLong())
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
@file:OptIn(ExperimentalUuidApi::class)
|
||||
|
||||
package dev.zxq5.chatapp.android.api
|
||||
|
||||
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
|
||||
import dev.zxq5.chatapp.android.api.model.Message
|
||||
import dev.zxq5.chatapp.android.api.model.ChatEvent
|
||||
import dev.zxq5.chatapp.android.api.model.SendMessage
|
||||
import dev.zxq5.chatapp.android.api.model.SpaceDto
|
||||
import io.ktor.client.HttpClient
|
||||
@@ -25,6 +27,8 @@ import kotlinx.coroutines.flow.flow
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlin.time.Clock
|
||||
import kotlin.time.ExperimentalTime
|
||||
import kotlin.uuid.ExperimentalUuidApi
|
||||
import kotlin.uuid.Uuid
|
||||
|
||||
class ChatClient(private val token: String) {
|
||||
|
||||
@@ -45,18 +49,18 @@ class ChatClient(private val token: String) {
|
||||
suspend fun sendMessage(channelId: Long, userId: Int, text: String) {
|
||||
http.post("${BASE_URL}/api/chat/$channelId") {
|
||||
contentType(ContentType.Application.Json)
|
||||
setBody(SendMessage(user_id = userId, text = text, timestamp = Clock.System.now()))
|
||||
setBody(SendMessage(id = Uuid.random(), user_id = userId, text = text, timestamp = Clock.System.now()))
|
||||
}
|
||||
}
|
||||
|
||||
fun messageStream(channelId: Long): Flow<Message> = flow {
|
||||
fun eventStream(channelId: Long): Flow<ChatEvent> = flow {
|
||||
http.prepareGet("${BASE_URL}/api/events/$channelId").execute { response ->
|
||||
val channel = response.bodyAsChannel()
|
||||
while (!channel.isClosedForRead) {
|
||||
val line = channel.readLine() ?: break
|
||||
if (line.startsWith("data:")) {
|
||||
val json = line.removePrefix("data:").trim()
|
||||
runCatching { Json.decodeFromString<Message>(json) }
|
||||
runCatching { Json.decodeFromString<ChatEvent>(json) }
|
||||
.onSuccess { emit(it) }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import android.util.Log
|
||||
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
|
||||
import dev.zxq5.chatapp.android.api.model.AccountDeleteRequest
|
||||
import dev.zxq5.chatapp.android.api.model.DisplayNameRequest
|
||||
import dev.zxq5.chatapp.android.api.model.InviteRequest
|
||||
import dev.zxq5.chatapp.android.api.model.PasswordChangeRequest
|
||||
import dev.zxq5.chatapp.android.api.model.QrResponse
|
||||
import dev.zxq5.chatapp.android.api.model.TOTPSixDigitCode
|
||||
@@ -45,6 +46,22 @@ class SettingsClient(private val token: String) {
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun createInvite(request: InviteRequest): ApiResult<String> {
|
||||
return try {
|
||||
val response = http.post("${BASE_URL}/api/invite") {
|
||||
contentType(ContentType.Application.Json)
|
||||
setBody(request)
|
||||
}
|
||||
if (response.status.isSuccess()) {
|
||||
ApiResult.Success(response.body<String>())
|
||||
} else {
|
||||
ApiResult.HttpError(response.status.value, "Failed to create invite")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
ApiResult.NetworkError(e.localizedMessage ?: "Network error")
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun getTotpQr(password: String): ApiResult<QrResponse> {
|
||||
return try {
|
||||
val response = http.post("${BASE_URL}/api/totp.jpg") {
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
@file:OptIn(ExperimentalUuidApi::class, ExperimentalTime::class)
|
||||
|
||||
package dev.zxq5.chatapp.android.api.model
|
||||
|
||||
import kotlinx.serialization.ExperimentalSerializationApi
|
||||
import kotlinx.serialization.SerialName
|
||||
import kotlinx.serialization.json.JsonClassDiscriminator
|
||||
import kotlin.time.ExperimentalTime
|
||||
import kotlin.time.Instant
|
||||
import kotlinx.serialization.Serializable
|
||||
import kotlin.uuid.ExperimentalUuidApi
|
||||
import kotlin.uuid.Uuid
|
||||
|
||||
@OptIn(ExperimentalSerializationApi::class)
|
||||
@Serializable
|
||||
@JsonClassDiscriminator("type")
|
||||
sealed class ChatEvent {
|
||||
|
||||
@Serializable
|
||||
@SerialName("SendMessage")
|
||||
data class SendMessage(
|
||||
val data: Message
|
||||
) : ChatEvent()
|
||||
|
||||
@Serializable
|
||||
@SerialName("EditMessage")
|
||||
data class EditMessage(
|
||||
val data: EditMessageContent
|
||||
) : ChatEvent()
|
||||
|
||||
@Serializable
|
||||
@SerialName("MessageAppendContent")
|
||||
data class MessageAppendContent(
|
||||
val data: AppendContent
|
||||
) : ChatEvent()
|
||||
}
|
||||
|
||||
// tuple variants like (i64, ChatMsg) and (i64, String)
|
||||
// need wrapper classes since kotlinx can't deserialise
|
||||
// bare JSON arrays into data classes directly
|
||||
@Serializable
|
||||
data class EditMessageContent(
|
||||
val id: Uuid,
|
||||
val message: Message
|
||||
)
|
||||
|
||||
@Serializable
|
||||
data class AppendContent (
|
||||
val id: Uuid,
|
||||
val content: String
|
||||
)
|
||||
|
||||
@Serializable
|
||||
data class Message (
|
||||
val id: Uuid,
|
||||
val user_id: Int,
|
||||
val display_name: String,
|
||||
val text: String,
|
||||
val timestamp: Instant
|
||||
)
|
||||
@@ -0,0 +1,13 @@
|
||||
package dev.zxq5.chatapp.android.api.model
|
||||
|
||||
import kotlinx.serialization.Serializable
|
||||
import kotlin.time.ExperimentalTime
|
||||
import kotlin.time.Instant
|
||||
|
||||
@Serializable
|
||||
data class InviteRequest @OptIn(ExperimentalTime::class) constructor(
|
||||
val name: String,
|
||||
val max_uses: Int,
|
||||
val expiry_date: Instant,
|
||||
val start_date: Instant
|
||||
)
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
package dev.zxq5.chatapp.android.model
|
||||
package dev.zxq5.chatapp.android.api.model
|
||||
|
||||
sealed class LoginState {
|
||||
object Idle : LoginState()
|
||||
@@ -1,13 +0,0 @@
|
||||
package dev.zxq5.chatapp.android.api.model
|
||||
|
||||
import kotlinx.serialization.Serializable
|
||||
import kotlin.time.ExperimentalTime
|
||||
import kotlin.time.Instant
|
||||
|
||||
@Serializable
|
||||
data class Message @OptIn(ExperimentalTime::class) constructor(
|
||||
val user_id: Int,
|
||||
val display_name: String,
|
||||
val text: String,
|
||||
val timestamp: Instant
|
||||
)
|
||||
@@ -3,9 +3,12 @@ package dev.zxq5.chatapp.android.api.model
|
||||
import kotlinx.serialization.Serializable
|
||||
import kotlin.time.ExperimentalTime
|
||||
import kotlin.time.Instant
|
||||
import kotlin.uuid.ExperimentalUuidApi
|
||||
import kotlin.uuid.Uuid
|
||||
|
||||
@Serializable
|
||||
data class SendMessage @OptIn(ExperimentalTime::class) constructor(
|
||||
data class SendMessage @OptIn(ExperimentalTime::class, ExperimentalUuidApi::class) constructor(
|
||||
val id: Uuid,
|
||||
val user_id: Int,
|
||||
val text: String,
|
||||
val timestamp: Instant
|
||||
|
||||
@@ -3,10 +3,12 @@ package dev.zxq5.chatapp.android.core.data
|
||||
import android.content.Context
|
||||
import android.content.SharedPreferences
|
||||
import android.util.Base64
|
||||
import android.util.Log
|
||||
import androidx.core.content.edit
|
||||
import androidx.security.crypto.EncryptedSharedPreferences
|
||||
import androidx.security.crypto.MasterKey
|
||||
import org.json.JSONObject
|
||||
import java.time.Instant
|
||||
|
||||
private const val KEY = "auth_token"
|
||||
private const val TWOFA_KEY = "twofa_enabled"
|
||||
@@ -27,11 +29,37 @@ class TokenStore(appContext: Context) {
|
||||
)
|
||||
}
|
||||
|
||||
fun save(token: String) =
|
||||
prefs().edit { putString(KEY, token) }
|
||||
fun save(token: String) {
|
||||
Log.d("TokenStore", "Saving token: $token")
|
||||
|
||||
prefs().edit { putString(KEY, token) }
|
||||
}
|
||||
|
||||
fun get(): String? {
|
||||
val ret = prefs().getString(KEY, null)
|
||||
Log.d("TokenStore", "Retrieved token: $ret")
|
||||
return ret
|
||||
}
|
||||
|
||||
fun isExpired(): Boolean {
|
||||
val token = get() ?: return true
|
||||
return try {
|
||||
val payload = token.split(".")[1]
|
||||
val padded = payload + "==".take((4 - payload.length % 4) % 4)
|
||||
val jsonString = String(Base64.decode(padded, Base64.URL_SAFE))
|
||||
val json = JSONObject(jsonString)
|
||||
if (json.has("exp")) {
|
||||
val exp = json.getLong("exp")
|
||||
val now = Instant.now().epochSecond
|
||||
now >= exp
|
||||
} else {
|
||||
false // If no exp claim, assume not expired or handle differently
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
true // If we can't parse it, treat it as expired
|
||||
}
|
||||
}
|
||||
|
||||
fun get(): String? =
|
||||
prefs().getString(KEY, null)
|
||||
|
||||
fun save2faEnabled( enabled: Boolean) =
|
||||
prefs().edit { putBoolean(TWOFA_KEY, enabled) }
|
||||
|
||||
+18
-30
@@ -3,10 +3,10 @@ package dev.zxq5.chatapp.android.core.service
|
||||
import android.app.Service
|
||||
import android.content.Context
|
||||
import android.content.Intent
|
||||
import android.os.Build
|
||||
import android.os.IBinder
|
||||
import android.util.Log
|
||||
import dev.zxq5.chatapp.android.ChatApplication
|
||||
import dev.zxq5.chatapp.android.api.model.ChatEvent
|
||||
import dev.zxq5.chatapp.android.data.repository.ChatRepository
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
@@ -15,21 +15,17 @@ import kotlinx.coroutines.cancel
|
||||
import kotlinx.coroutines.flow.catch
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
// core/service/MessageStreamService.kt
|
||||
class MessageStreamService : Service() {
|
||||
|
||||
private val serviceScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
|
||||
private lateinit var notificationService: NotificationService
|
||||
private lateinit var chatRepository: ChatRepository
|
||||
|
||||
// which channel the user is currently looking at
|
||||
// set by the ViewModel when the user opens/closes a channel
|
||||
var activeChannelId: Long? = null
|
||||
set(value) {
|
||||
field = value
|
||||
Log.d("Service", "activeChannelId set to $value")
|
||||
if (value != null) {
|
||||
// restart stream with new channel
|
||||
currentStreamJob?.cancel()
|
||||
observeMessages()
|
||||
}
|
||||
@@ -42,11 +38,7 @@ class MessageStreamService : Service() {
|
||||
|
||||
fun start(context: Context) {
|
||||
val intent = Intent(context, MessageStreamService::class.java)
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
|
||||
context.startForegroundService(intent)
|
||||
} else {
|
||||
context.startService(intent)
|
||||
}
|
||||
context.startService(intent)
|
||||
}
|
||||
|
||||
fun stop(context: Context) {
|
||||
@@ -62,33 +54,29 @@ class MessageStreamService : Service() {
|
||||
}
|
||||
|
||||
override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
|
||||
startForeground(
|
||||
NotificationService.FOREGROUND_NOTIFICATION_ID,
|
||||
notificationService.buildForegroundNotification()
|
||||
)
|
||||
observeMessages()
|
||||
return START_STICKY // restart if killed
|
||||
return START_STICKY
|
||||
}
|
||||
|
||||
private fun observeMessages() {
|
||||
val channelId = activeChannelId ?: chatRepository.getLastActiveChannel()
|
||||
Log.d("Service", "observeMessages called, channelId=$channelId")
|
||||
if (channelId == null) {
|
||||
Log.d("Service", "No channel to observe, waiting for switchChannel")
|
||||
return
|
||||
}
|
||||
if (channelId == null) return
|
||||
|
||||
Log.d("Service", "Starting stream for channel $channelId")
|
||||
currentStreamJob = serviceScope.launch {
|
||||
chatRepository.messageStream(channelId)
|
||||
chatRepository.eventStream(channelId)
|
||||
.catch { e -> Log.e("Service", "Stream error", e) }
|
||||
.collect { message ->
|
||||
if (!ChatApplication.AppState.isInForeground) { // no channel focused, always notify
|
||||
notificationService.showMessageNotification(
|
||||
conversationId = activeChannelId.toString(),
|
||||
senderName = message.display_name,
|
||||
messagePreview = message.text.take(80)
|
||||
)
|
||||
.collect { event ->
|
||||
// Only show notification when an event (new message) is received
|
||||
// and the app is not in the foreground on this channel.
|
||||
if (!ChatApplication.AppState.isInForeground || activeChannelId != channelId) {
|
||||
when (event) {
|
||||
is ChatEvent.SendMessage -> notificationService.showMessageNotification(
|
||||
conversationId = channelId.toString(),
|
||||
senderName = event.data.display_name,
|
||||
messagePreview = event.data.text
|
||||
)
|
||||
else -> {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -101,4 +89,4 @@ class MessageStreamService : Service() {
|
||||
instance = null
|
||||
serviceScope.cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+21
-31
@@ -15,51 +15,41 @@ class NotificationService(private val context: Context) {
|
||||
|
||||
companion object {
|
||||
const val CHANNEL_ID = "messages"
|
||||
const val FOREGROUND_NOTIFICATION_ID = 1 // ← this needs to exist
|
||||
const val SERVICE_CHANNEL_ID = "service"
|
||||
const val FOREGROUND_NOTIFICATION_ID = 1
|
||||
}
|
||||
|
||||
private val manager = context.getSystemService(NotificationManager::class.java)
|
||||
|
||||
fun createChannels() {
|
||||
// channel for new message notifications
|
||||
val messageChannel = NotificationChannel(
|
||||
CHANNEL_ID,
|
||||
"Messages",
|
||||
NotificationManager.IMPORTANCE_HIGH
|
||||
).apply {
|
||||
enableVibration(true)
|
||||
fun createForegroundNotification(): Notification {
|
||||
val intent = Intent(context, MainActivity::class.java).apply {
|
||||
flags = Intent.FLAG_ACTIVITY_NEW_TASK or Intent.FLAG_ACTIVITY_CLEAR_TASK
|
||||
}
|
||||
|
||||
// channel for the persistent foreground service notification
|
||||
// low importance so it doesn't make noise
|
||||
val serviceChannel = NotificationChannel(
|
||||
"service",
|
||||
"Background connection",
|
||||
NotificationManager.IMPORTANCE_LOW
|
||||
|
||||
val pendingIntent = PendingIntent.getActivity(
|
||||
context,
|
||||
0,
|
||||
intent,
|
||||
PendingIntent.FLAG_IMMUTABLE
|
||||
)
|
||||
|
||||
val mgr = context.getSystemService(NotificationManager::class.java)
|
||||
mgr.createNotificationChannel(messageChannel)
|
||||
mgr.createNotificationChannel(serviceChannel)
|
||||
}
|
||||
|
||||
fun buildForegroundNotification(): Notification {
|
||||
return NotificationCompat.Builder(context, "service")
|
||||
return NotificationCompat.Builder(context, SERVICE_CHANNEL_ID)
|
||||
.setSmallIcon(R.drawable.ic_notification)
|
||||
.setContentTitle("chatapp")
|
||||
.setContentText("Connected")
|
||||
.setContentTitle("Chat App")
|
||||
.setContentText("Connecting to message stream...")
|
||||
.setPriority(NotificationCompat.PRIORITY_LOW)
|
||||
.setCategory(NotificationCompat.CATEGORY_SERVICE)
|
||||
.setContentIntent(pendingIntent)
|
||||
.setOngoing(true)
|
||||
.setSilent(true)
|
||||
.build()
|
||||
}
|
||||
|
||||
fun showMessageNotification(
|
||||
conversationId: String,
|
||||
senderName: String,
|
||||
messagePreview: String, // for E2E this would be "New message" — no plaintext
|
||||
messagePreview: String,
|
||||
notificationId: Int = conversationId.hashCode()
|
||||
) {
|
||||
// intent that opens the app to the right conversation when tapped
|
||||
val intent = Intent(context, MainActivity::class.java).apply {
|
||||
flags = Intent.FLAG_ACTIVITY_SINGLE_TOP
|
||||
putExtra("conversation_id", conversationId)
|
||||
@@ -72,13 +62,13 @@ class NotificationService(private val context: Context) {
|
||||
PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE
|
||||
)
|
||||
|
||||
val notification = NotificationCompat.Builder(context, "messages")
|
||||
val notification = NotificationCompat.Builder(context, CHANNEL_ID)
|
||||
.setSmallIcon(R.drawable.ic_notification)
|
||||
.setContentTitle(senderName)
|
||||
.setContentText(messagePreview)
|
||||
.setPriority(NotificationCompat.PRIORITY_HIGH)
|
||||
.setContentIntent(pendingIntent)
|
||||
.setAutoCancel(true) // dismiss on tap
|
||||
.setAutoCancel(true)
|
||||
.build()
|
||||
|
||||
manager.notify(notificationId, notification)
|
||||
@@ -91,4 +81,4 @@ class NotificationService(private val context: Context) {
|
||||
fun dismissAll() {
|
||||
manager.cancelAll()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,6 +56,10 @@ class AuthRepository(
|
||||
|
||||
fun getAuthState(): AuthState {
|
||||
val token = tokenStore.get() ?: return AuthState.Unauthenticated
|
||||
if (tokenStore.isExpired()) {
|
||||
tokenStore.clear()
|
||||
return AuthState.Unauthenticated
|
||||
}
|
||||
return when (getScopeFromToken(token)) {
|
||||
TokenScope.FULL -> AuthState.Authenticated
|
||||
TokenScope.TOTP_PENDING -> AuthState.AwaitingTotp
|
||||
|
||||
+3
-2
@@ -1,6 +1,7 @@
|
||||
package dev.zxq5.chatapp.android.data.repository
|
||||
|
||||
import dev.zxq5.chatapp.android.api.ChatClient
|
||||
import dev.zxq5.chatapp.android.api.model.ChatEvent
|
||||
import dev.zxq5.chatapp.android.core.data.TokenStore
|
||||
import dev.zxq5.chatapp.android.api.model.Message
|
||||
import dev.zxq5.chatapp.android.api.model.SpaceDto
|
||||
@@ -43,8 +44,8 @@ class ChatRepository(private val tokenStore: TokenStore) {
|
||||
getChatClient()?.sendMessage(channelId, userId, text)
|
||||
}
|
||||
|
||||
fun messageStream(channelId: Long): Flow<Message> {
|
||||
fun eventStream(channelId: Long): Flow<ChatEvent> {
|
||||
_lastActiveChannel = channelId
|
||||
return getChatClient()?.messageStream(channelId) ?: emptyFlow()
|
||||
return getChatClient()?.eventStream(channelId) ?: emptyFlow()
|
||||
}
|
||||
}
|
||||
|
||||
+5
@@ -2,6 +2,7 @@ package dev.zxq5.chatapp.android.data.repository
|
||||
|
||||
import dev.zxq5.chatapp.android.api.model.QrResponse
|
||||
import dev.zxq5.chatapp.android.api.SettingsClient
|
||||
import dev.zxq5.chatapp.android.api.model.InviteRequest
|
||||
import dev.zxq5.chatapp.android.api.model.TotpStatus
|
||||
import dev.zxq5.chatapp.android.core.data.TokenStore
|
||||
import dev.zxq5.chatapp.android.core.error.ApiResult
|
||||
@@ -25,6 +26,10 @@ class SettingsRepository(private val tokenStore: TokenStore) {
|
||||
_lastToken = null
|
||||
}
|
||||
|
||||
suspend fun createInvite(request: InviteRequest): ApiResult<String> {
|
||||
return getSettingsClient()?.createInvite(request) ?: ApiResult.NetworkError("Not authenticated")
|
||||
}
|
||||
|
||||
suspend fun getTotpQr(password: String): ApiResult<QrResponse?> {
|
||||
val settingsClient = getSettingsClient() ?: return ApiResult.NetworkError("Not authenticated")
|
||||
return settingsClient.getTotpQr(password)
|
||||
|
||||
@@ -3,7 +3,7 @@ package dev.zxq5.chatapp.android.feature.auth
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import dev.zxq5.chatapp.android.model.LoginState
|
||||
import dev.zxq5.chatapp.android.api.model.LoginState
|
||||
|
||||
@Composable
|
||||
fun AuthScreen(viewModel: AuthViewModel) {
|
||||
|
||||
@@ -7,7 +7,7 @@ import dev.zxq5.chatapp.android.data.repository.AuthRepository
|
||||
import dev.zxq5.chatapp.android.data.repository.LoginResult
|
||||
import dev.zxq5.chatapp.android.data.repository.SignupResult
|
||||
import dev.zxq5.chatapp.android.data.repository.AuthState
|
||||
import dev.zxq5.chatapp.android.model.LoginState
|
||||
import dev.zxq5.chatapp.android.api.model.LoginState
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
@@ -28,7 +28,7 @@ import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.unit.dp
|
||||
import dev.zxq5.chatapp.android.model.LoginState
|
||||
import dev.zxq5.chatapp.android.api.model.LoginState
|
||||
import dev.zxq5.chatapp.android.ui.components.TextField
|
||||
|
||||
@Composable
|
||||
|
||||
@@ -28,7 +28,7 @@ import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.unit.dp
|
||||
import dev.zxq5.chatapp.android.model.LoginState
|
||||
import dev.zxq5.chatapp.android.api.model.LoginState
|
||||
import dev.zxq5.chatapp.android.ui.components.TextField
|
||||
|
||||
@Composable
|
||||
|
||||
@@ -1,20 +1,25 @@
|
||||
@file:OptIn(ExperimentalUuidApi::class)
|
||||
|
||||
package dev.zxq5.chatapp.android.feature.chat
|
||||
|
||||
import android.util.Log
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import dev.zxq5.chatapp.android.api.model.Channel
|
||||
import dev.zxq5.chatapp.android.api.model.ChatEvent
|
||||
import dev.zxq5.chatapp.android.data.repository.ChatRepository
|
||||
import dev.zxq5.chatapp.android.api.model.Message
|
||||
import dev.zxq5.chatapp.android.api.model.Space
|
||||
import dev.zxq5.chatapp.android.api.model.SpaceDto
|
||||
import dev.zxq5.chatapp.android.core.service.MessageStreamService
|
||||
import io.ktor.client.plugins.ResponseException
|
||||
import io.ktor.http.HttpStatusCode
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.catch
|
||||
import kotlinx.coroutines.flow.update
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlin.time.ExperimentalTime
|
||||
import kotlin.uuid.ExperimentalUuidApi
|
||||
|
||||
class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
|
||||
|
||||
@@ -40,6 +45,8 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
|
||||
val channelError: StateFlow<String?> = _channelError
|
||||
|
||||
private var streamJob: Job? = null
|
||||
|
||||
var onUnauthorized: (() -> Unit)? = null
|
||||
|
||||
init {
|
||||
_currentUserId.value = chatRepository.getUserId()
|
||||
@@ -49,6 +56,7 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
|
||||
|
||||
fun loadAccessibleChannels() {
|
||||
_error.value = null
|
||||
_currentUserId.value = chatRepository.getUserId()
|
||||
viewModelScope.launch {
|
||||
runCatching {
|
||||
chatRepository.getAccessibleChannels()
|
||||
@@ -56,11 +64,16 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
|
||||
_spaces.value = data
|
||||
}.onFailure { e ->
|
||||
Log.e("Chat", "Failed to load spaces", e)
|
||||
_error.value = "Failed to load channels: ${e.message}"
|
||||
if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) {
|
||||
onUnauthorized?.invoke()
|
||||
} else {
|
||||
_error.value = "Failed to load channels: ${e.message}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalTime::class)
|
||||
private fun observeChannel() {
|
||||
viewModelScope.launch {
|
||||
_channelId.collect { id ->
|
||||
@@ -69,13 +82,40 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
|
||||
_channelError.value = null
|
||||
if (id != null) {
|
||||
streamJob = launch {
|
||||
chatRepository.messageStream(id)
|
||||
chatRepository.eventStream(id)
|
||||
.catch { e ->
|
||||
Log.e("Chat", "Stream error", e)
|
||||
_channelError.value = "Connection lost: ${e.message}"
|
||||
if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) {
|
||||
onUnauthorized?.invoke()
|
||||
} else {
|
||||
_channelError.value = "Connection lost: ${e.message}"
|
||||
}
|
||||
}
|
||||
.collect { message ->
|
||||
_messages.update { it + message }
|
||||
.collect { event ->
|
||||
when (event) {
|
||||
is ChatEvent.SendMessage -> {
|
||||
_messages.update { it + event.data }
|
||||
}
|
||||
is ChatEvent.EditMessage -> {
|
||||
_messages.update { messages ->
|
||||
messages.map {
|
||||
if (it.id == event.data.id) event.data.message
|
||||
else it
|
||||
}
|
||||
}
|
||||
}
|
||||
is ChatEvent.MessageAppendContent -> {
|
||||
_messages.update { messages ->
|
||||
messages.map { msg ->
|
||||
if (msg.id == event.data.id) {
|
||||
msg.copy(text = msg.text + event.data.content)
|
||||
} else {
|
||||
msg
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -108,7 +148,11 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
|
||||
)
|
||||
}.onFailure { e ->
|
||||
Log.e("Chat", "Send message error", e)
|
||||
_channelError.value = "Failed to send message"
|
||||
if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) {
|
||||
onUnauthorized?.invoke()
|
||||
} else {
|
||||
_channelError.value = "Failed to send message"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -117,6 +161,7 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
|
||||
_messages.value = emptyList()
|
||||
_channelId.value = null
|
||||
_currentUserId.value = null
|
||||
_spaces.value = emptyList()
|
||||
_error.value = null
|
||||
_channelError.value = null
|
||||
streamJob?.cancel()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
@file:OptIn(ExperimentalUuidApi::class)
|
||||
|
||||
package dev.zxq5.chatapp.android.feature.chat
|
||||
|
||||
import androidx.compose.foundation.BorderStroke
|
||||
@@ -65,6 +67,7 @@ import dev.zxq5.chatapp.android.api.model.Message
|
||||
import java.text.DateFormat
|
||||
import java.util.Date
|
||||
import kotlin.time.ExperimentalTime
|
||||
import kotlin.uuid.ExperimentalUuidApi
|
||||
|
||||
@Composable
|
||||
fun ChatScreen(
|
||||
@@ -277,7 +280,7 @@ fun MessageScreen(channelId: Long, viewModel: ChatViewModel, onBack: () -> Unit)
|
||||
modifier = Modifier.weight(1f).padding(horizontal = 16.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(10.dp)
|
||||
) {
|
||||
items(messages) { message ->
|
||||
items(messages, key = { it.id }) { message ->
|
||||
MessageBubble(message, currentUserId)
|
||||
}
|
||||
item { Spacer(Modifier.height(10.dp)) }
|
||||
@@ -378,7 +381,7 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
|
||||
horizontalAlignment = if (isMe) Alignment.End else Alignment.Start
|
||||
) {
|
||||
Surface(
|
||||
color = if (isMe) MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) else MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.2f),
|
||||
color = if (isMe) MaterialTheme.colorScheme.surfaceVariant else MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.2f),
|
||||
shape = RoundedCornerShape(
|
||||
topStart = 14.dp,
|
||||
topEnd = 14.dp,
|
||||
@@ -388,14 +391,7 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
|
||||
border = border(0.5.dp, MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.5f))
|
||||
) {
|
||||
Column(modifier = Modifier.padding(horizontal = 11.dp, vertical = 8.dp)) {
|
||||
if (!isMe) {
|
||||
Text(
|
||||
message.display_name?.lowercase() ?: "unknown",
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = MaterialTheme.colorScheme.primary.copy(alpha = 0.7f),
|
||||
modifier = Modifier.padding(bottom = 2.dp)
|
||||
)
|
||||
}
|
||||
|
||||
Text(
|
||||
text = message.text,
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
@@ -403,10 +399,11 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Text(
|
||||
text = time,
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.3f),
|
||||
text = if (!isMe) message.display_name.lowercase() + " . " + time else time,
|
||||
style = MaterialTheme.typography.labelMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
|
||||
modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp)
|
||||
)
|
||||
}
|
||||
|
||||
+18
@@ -2,6 +2,7 @@ package dev.zxq5.chatapp.android.feature.settings
|
||||
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import dev.zxq5.chatapp.android.api.model.InviteRequest
|
||||
import dev.zxq5.chatapp.android.api.model.QrResponse
|
||||
import dev.zxq5.chatapp.android.core.error.ApiResult
|
||||
import dev.zxq5.chatapp.android.data.repository.SettingsRepository
|
||||
@@ -27,6 +28,9 @@ class SettingsViewModel(private val settingsRepository: SettingsRepository) : Vi
|
||||
private val _isSuccessState = MutableStateFlow<Map<String, Boolean>>(emptyMap())
|
||||
val isSuccessState: StateFlow<Map<String, Boolean>> = _isSuccessState
|
||||
|
||||
private val _lastInviteCode = MutableStateFlow<String?>(null)
|
||||
val lastInviteCode: StateFlow<String?> = _lastInviteCode
|
||||
|
||||
fun clearMessages() {
|
||||
_settingsError.value = null
|
||||
_totpError.value = null
|
||||
@@ -40,6 +44,20 @@ class SettingsViewModel(private val settingsRepository: SettingsRepository) : Vi
|
||||
}
|
||||
}
|
||||
|
||||
fun createInvite(request: InviteRequest) {
|
||||
viewModelScope.launch {
|
||||
_settingsError.value = null
|
||||
when (val result = settingsRepository.createInvite(request)) {
|
||||
is ApiResult.Success -> {
|
||||
_lastInviteCode.value = result.data
|
||||
triggerSuccess("invite")
|
||||
}
|
||||
is ApiResult.HttpError -> _settingsError.value = result.message
|
||||
is ApiResult.NetworkError -> _settingsError.value = result.message
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun fetchTotpStatus() {
|
||||
viewModelScope.launch {
|
||||
when (val result = settingsRepository.getTotpStatus()) {
|
||||
|
||||
@@ -23,11 +23,14 @@ import androidx.compose.foundation.text.KeyboardOptions
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.filled.ArrowBack
|
||||
import androidx.compose.material.icons.filled.ContentCopy
|
||||
import androidx.compose.material.icons.filled.KeyboardArrowDown
|
||||
import androidx.compose.material.icons.filled.KeyboardArrowUp
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.ButtonDefaults
|
||||
import androidx.compose.material3.CircularProgressIndicator
|
||||
import androidx.compose.material3.DatePicker
|
||||
import androidx.compose.material3.DatePickerDialog
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.HorizontalDivider
|
||||
import androidx.compose.material3.Icon
|
||||
@@ -37,8 +40,10 @@ import androidx.compose.material3.OutlinedTextField
|
||||
import androidx.compose.material3.OutlinedTextFieldDefaults
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TextButton
|
||||
import androidx.compose.material3.TopAppBar
|
||||
import androidx.compose.material3.TopAppBarDefaults
|
||||
import androidx.compose.material3.rememberDatePickerState
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
@@ -57,9 +62,15 @@ import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.sp
|
||||
import android.util.Base64
|
||||
import android.graphics.BitmapFactory
|
||||
import androidx.compose.ui.platform.LocalClipboardManager
|
||||
import androidx.compose.ui.text.AnnotatedString
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import dev.zxq5.chatapp.android.api.model.InviteRequest
|
||||
import kotlin.time.Duration.Companion.days
|
||||
import kotlin.time.ExperimentalTime
|
||||
import kotlin.time.Instant
|
||||
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@OptIn(ExperimentalMaterial3Api::class, ExperimentalTime::class)
|
||||
@Composable
|
||||
fun SettingsScreen(
|
||||
viewModel: SettingsViewModel,
|
||||
@@ -70,6 +81,7 @@ fun SettingsScreen(
|
||||
val settingsError by viewModel.settingsError.collectAsState()
|
||||
val isSuccessState by viewModel.isSuccessState.collectAsState()
|
||||
val totpError by viewModel.totpError.collectAsState()
|
||||
val lastInviteCode by viewModel.lastInviteCode.collectAsState()
|
||||
|
||||
LaunchedEffect(Unit) {
|
||||
viewModel.clearMessages()
|
||||
@@ -274,6 +286,120 @@ fun SettingsScreen(
|
||||
}
|
||||
}
|
||||
|
||||
SettingsSection(title = "invite") {
|
||||
var inviteName by remember { mutableStateOf("") }
|
||||
var maxUses by remember { mutableStateOf("1") }
|
||||
val clipboardManager = LocalClipboardManager.current
|
||||
|
||||
var showDatePicker by remember { mutableStateOf(false) }
|
||||
val datePickerState = rememberDatePickerState(
|
||||
initialSelectedDateMillis = System.currentTimeMillis() + 7.days.inWholeMilliseconds
|
||||
)
|
||||
|
||||
Text("create invite token", style = MaterialTheme.typography.bodyMedium, modifier = Modifier.padding(bottom = 8.dp))
|
||||
|
||||
OutlinedTextField(
|
||||
value = inviteName,
|
||||
onValueChange = { inviteName = it },
|
||||
label = { Text("name") },
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
shape = RoundedCornerShape(8.dp)
|
||||
)
|
||||
Spacer(Modifier.height(8.dp))
|
||||
OutlinedTextField(
|
||||
value = maxUses,
|
||||
onValueChange = { if (it.all { c -> c.isDigit() }) maxUses = it },
|
||||
label = { Text("max uses") },
|
||||
keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number),
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
shape = RoundedCornerShape(8.dp)
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(8.dp))
|
||||
|
||||
OutlinedTextField(
|
||||
value = datePickerState.selectedDateMillis?.let { Instant.fromEpochMilliseconds(it).toString().substringBefore("T") } ?: "",
|
||||
onValueChange = {},
|
||||
label = { Text("expiry date") },
|
||||
readOnly = true,
|
||||
trailingIcon = {
|
||||
IconButton(onClick = { showDatePicker = true }) {
|
||||
Icon(Icons.Default.KeyboardArrowDown, contentDescription = "Select Date")
|
||||
}
|
||||
},
|
||||
modifier = Modifier.fillMaxWidth().clickable { showDatePicker = true },
|
||||
shape = RoundedCornerShape(8.dp)
|
||||
)
|
||||
|
||||
if (showDatePicker) {
|
||||
DatePickerDialog(
|
||||
onDismissRequest = { showDatePicker = false },
|
||||
confirmButton = {
|
||||
TextButton(onClick = { showDatePicker = false }) {
|
||||
Text("ok")
|
||||
}
|
||||
}
|
||||
) {
|
||||
DatePicker(state = datePickerState)
|
||||
}
|
||||
}
|
||||
|
||||
Spacer(Modifier.height(12.dp))
|
||||
|
||||
SuccessButton(
|
||||
onClick = {
|
||||
val nowMs = System.currentTimeMillis()
|
||||
val expiryMs = datePickerState.selectedDateMillis ?: (nowMs + 7.days.inWholeMilliseconds)
|
||||
viewModel.createInvite(
|
||||
InviteRequest(
|
||||
name = inviteName,
|
||||
max_uses = maxUses.toIntOrNull() ?: 1,
|
||||
start_date = Instant.fromEpochMilliseconds(nowMs),
|
||||
expiry_date = Instant.fromEpochMilliseconds(expiryMs)
|
||||
)
|
||||
)
|
||||
},
|
||||
label = "generate invite",
|
||||
isSuccess = isSuccessState["invite"] == true,
|
||||
enabled = inviteName.isNotBlank(),
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
|
||||
if (lastInviteCode != null) {
|
||||
Spacer(Modifier.height(16.dp))
|
||||
Row(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.background(MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f), RoundedCornerShape(8.dp))
|
||||
.padding(12.dp),
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.SpaceBetween
|
||||
) {
|
||||
Text(
|
||||
text = lastInviteCode!!,
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
modifier = Modifier.weight(1f)
|
||||
)
|
||||
IconButton(onClick = {
|
||||
clipboardManager.setText(AnnotatedString(lastInviteCode!!))
|
||||
}) {
|
||||
Icon(Icons.Default.ContentCopy, contentDescription = "Copy", modifier = Modifier.size(20.dp))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SettingsSection(title = "session") {
|
||||
Button(
|
||||
onClick = onLogout,
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
shape = RoundedCornerShape(8.dp),
|
||||
colors = ButtonDefaults.buttonColors(containerColor = Color.White, contentColor = Color.Black)
|
||||
) {
|
||||
Text("logout")
|
||||
}
|
||||
}
|
||||
|
||||
SettingsSection(title = "danger zone", color = Color.Red.copy(alpha = 0.7f)) {
|
||||
var deletePassword by remember { mutableStateOf("") }
|
||||
var deleteTotp by remember { mutableStateOf("") }
|
||||
@@ -337,18 +463,6 @@ fun SettingsScreen(
|
||||
}
|
||||
}
|
||||
|
||||
SettingsSection(title = "session") {
|
||||
Spacer(Modifier.height(16.dp))
|
||||
Button(
|
||||
onClick = onLogout,
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
shape = RoundedCornerShape(8.dp),
|
||||
colors = ButtonDefaults.buttonColors(containerColor = Color.White, contentColor = Color.Black)
|
||||
) {
|
||||
Text("logout")
|
||||
}
|
||||
}
|
||||
|
||||
if (settingsError != null) {
|
||||
Text(settingsError!!, color = Color.Red, style = MaterialTheme.typography.bodySmall, modifier = Modifier.padding(top = 8.dp))
|
||||
}
|
||||
@@ -457,6 +571,7 @@ fun SuccessButton(
|
||||
}
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalTime::class)
|
||||
@Composable
|
||||
fun TwoFactorSetup(
|
||||
qrCodeBase64: String?,
|
||||
@@ -511,15 +626,13 @@ fun TwoFactorSetup(
|
||||
Text(error.lowercase(), color = Color.Red, style = MaterialTheme.typography.labelSmall, modifier = Modifier.padding(top = 8.dp))
|
||||
}
|
||||
|
||||
Spacer(Modifier.height(24.dp))
|
||||
|
||||
Button(
|
||||
onClick = { if (code.length == 6) onConfirm(code) },
|
||||
Spacer(Modifier.height(16.dp))
|
||||
SuccessButton(
|
||||
onClick = { onConfirm(code) },
|
||||
label = "verify and enable",
|
||||
isSuccess = false, // Managed by parent
|
||||
enabled = code.length == 6,
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
shape = RoundedCornerShape(8.dp)
|
||||
) {
|
||||
Text("confirm code")
|
||||
}
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,12 +2,14 @@ package dev.zxq5.chatapp.android.ui.components
|
||||
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.foundation.text.KeyboardOptions
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.OutlinedTextField
|
||||
import androidx.compose.material3.OutlinedTextFieldDefaults
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.text.input.KeyboardType
|
||||
import androidx.compose.ui.text.input.PasswordVisualTransformation
|
||||
import androidx.compose.ui.unit.dp
|
||||
|
||||
@@ -31,6 +33,11 @@ fun TextField(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
singleLine = true,
|
||||
textStyle = MaterialTheme.typography.bodyLarge,
|
||||
keyboardOptions = if (isPassword) {
|
||||
KeyboardOptions(keyboardType = KeyboardType.Password)
|
||||
} else {
|
||||
KeyboardOptions.Default
|
||||
},
|
||||
visualTransformation = if (isPassword) PasswordVisualTransformation() else androidx.compose.ui.text.input.VisualTransformation.None,
|
||||
shape = RoundedCornerShape(8.dp),
|
||||
colors = OutlinedTextFieldDefaults.colors(
|
||||
@@ -40,6 +47,6 @@ fun TextField(
|
||||
unfocusedBorderColor = MaterialTheme.colorScheme.outline,
|
||||
focusedTextColor = MaterialTheme.colorScheme.onSurface,
|
||||
unfocusedTextColor = MaterialTheme.colorScheme.onSurface
|
||||
)
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package dev.zxq5.chatapp.android.ui.theme
|
||||
|
||||
import androidx.compose.ui.graphics.Color
|
||||
|
||||
val Black = Color(0xFF0A0A0A)
|
||||
val Black = Color(0xFF000000)
|
||||
val DarkGrey = Color(0xFF0D0D0D)
|
||||
val Grey = Color(0xFF141414)
|
||||
val LightGrey = Color(0xFF1E1E1E)
|
||||
|
||||
Generated
+6
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="uk.co.ben_gibson.git.link.SettingsState">
|
||||
<option name="host" value="e0f86390-1091-4871-8aeb-f534fbc99cf0" />
|
||||
</component>
|
||||
</project>
|
||||
Generated
+1
-1
@@ -1,7 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="DataSourcePerFileMappings">
|
||||
<file url="file://$PROJECT_DIR$/sql/schema.sql" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
|
||||
<file url="file://$PROJECT_DIR$/sql/test.sql" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
|
||||
<file url="file://$PROJECT_DIR$/src/repo/user_repo.rs" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
|
||||
</component>
|
||||
</project>
|
||||
Generated
+2
-1
@@ -1,7 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="SqlDialectMappings">
|
||||
<file url="file://$PROJECT_DIR$/sql/schema.sql" dialect="PostgreSQL" />
|
||||
<file url="file://$PROJECT_DIR$/migrations/20260412200102_message_id_to_uuid.sql" dialect="PostgreSQL" />
|
||||
<file url="file://$PROJECT_DIR$/sql/test.sql" dialect="PostgreSQL" />
|
||||
<file url="PROJECT" dialect="PostgreSQL" />
|
||||
</component>
|
||||
</project>
|
||||
+3
-3
@@ -12,7 +12,7 @@ image = "0.25.8"
|
||||
jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] }
|
||||
rand = "0.8"
|
||||
redis = { version = "0.25.4", features = ["tokio-comp"] }
|
||||
reqwest = { version = "0.12.23", features = ["json"] }
|
||||
reqwest = { version = "0.12.23", features = ["json", "stream"] }
|
||||
rocket = { version = "0.5.1", features = ["json", "secrets"] }
|
||||
rocket_cors = "0.6.0"
|
||||
rocket_db_pools = { version = "0.2.0", features = ["deadpool_redis", "sqlx_macros", "sqlx_postgres"] }
|
||||
@@ -20,11 +20,11 @@ rocket_dyn_templates = { version = "0.2.0", features = ["tera"] }
|
||||
serde = { version = "1.0.228", features = ["derive"] }
|
||||
serde_json = "1.0.145"
|
||||
sha2 = "0.10.9"
|
||||
sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time"] }
|
||||
sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time", "uuid"] }
|
||||
tokio = { version = "1.47.1", features = ["full"] }
|
||||
totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] }
|
||||
tracing = "0.1.44"
|
||||
uuid = { version = "1.18.1", features = ["v4"] }
|
||||
uuid = { version = "1.18.1", features = ["serde", "v4"] }
|
||||
thiserror = "1.0.69"
|
||||
utoipa = { version = "5.4.0", features = ["rocket_extras", "chrono"] }
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
|
||||
+1
-1
@@ -1,7 +1,7 @@
|
||||
[debug]
|
||||
secret_key = "yYhvCGnRh/TrcHtB8sZqCFifrVmJxoKFLBYw/WWBZeU="
|
||||
address = "0.0.0.0"
|
||||
port = 8000
|
||||
port = 8080
|
||||
|
||||
[debug.databases.postgres_db]
|
||||
url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp_dev"
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
ALTER TABLE attachments DROP CONSTRAINT attachments_message_id_fkey;
|
||||
|
||||
ALTER TABLE messages ALTER COLUMN id DROP DEFAULT;
|
||||
ALTER TABLE messages ALTER COLUMN id TYPE uuid USING gen_random_uuid();
|
||||
ALTER TABLE messages ALTER COLUMN id SET DEFAULT gen_random_uuid();
|
||||
|
||||
ALTER TABLE attachments ALTER COLUMN message_id TYPE uuid USING gen_random_uuid();
|
||||
ALTER TABLE attachments ADD CONSTRAINT attachments_message_id_fkey
|
||||
FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE;
|
||||
@@ -0,0 +1,17 @@
|
||||
WITH space1 AS (
|
||||
INSERT INTO spaces (name, description, owner_id)
|
||||
VALUES ('general', 'Boring chat idk', 1)
|
||||
RETURNING id
|
||||
),
|
||||
space2 AS (
|
||||
INSERT INTO spaces (name, description, owner_id)
|
||||
VALUES ('Gaming', 'we lose games', 1)
|
||||
RETURNING id
|
||||
)
|
||||
INSERT INTO channels (name, description, space_id)
|
||||
SELECT 'General', 'General chat', id FROM space1 UNION ALL
|
||||
SELECT 'Coding', 'Coding stuff', id FROM space1 UNION ALL
|
||||
SELECT 'AI', '"/ask" here pls :)', id FROM space1 UNION ALL
|
||||
SELECT 'The Game', '(You lost)', id FROM space2 UNION ALL
|
||||
SELECT 'Backrooms', 'Beware of Smilers', id FROM space2 UNION ALL
|
||||
SELECT 'SE', 'Space/Software engineering.', id FROM space2;
|
||||
+9
-16
@@ -9,15 +9,7 @@ use rocket::{Shutdown, State, ___internal_EventStream as EventStream};
|
||||
use sqlx::FromRow;
|
||||
use tokio::select;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
/// ---------- Rocket routes ----------
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
|
||||
pub struct ChatMsg {
|
||||
pub display_name: Option<String>,
|
||||
pub user_id: i64,
|
||||
pub text: String,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
use crate::model::event::{ChatEvent, ChatMsg};
|
||||
|
||||
#[post("/chat/<channel_id>", format = "json", data = "<msg>")]
|
||||
pub async fn post_message(
|
||||
@@ -36,24 +28,25 @@ pub async fn event_stream(
|
||||
mut shutdown: Shutdown,
|
||||
channel_id: i64,
|
||||
) -> ApiResult<EventStream![]> {
|
||||
let messages = chat.get_messages(channel_id, 100)
|
||||
let messages = chat.fetch_latest_messages_desc(channel_id, 100)
|
||||
.await?; // if get message returned err, inform user.
|
||||
|
||||
let mut rx = chat.subscribe(channel_id).await;
|
||||
let id = s.uid;
|
||||
|
||||
Ok(EventStream! {
|
||||
for msg in messages {
|
||||
yield Event::json(&msg);
|
||||
for msg in messages.into_iter().rev() {
|
||||
// tracing::info!("sending: {:?}", serde_json::to_string(&ChatEvent::SendMessage(msg.clone())).unwrap());
|
||||
yield Event::json(&ChatEvent::SendMessage(msg));
|
||||
}
|
||||
|
||||
loop {
|
||||
select!{
|
||||
_ = &mut shutdown => break, // exit early on shutdown
|
||||
msg = rx.recv() => match msg {
|
||||
Ok(msg) => {
|
||||
tracing::info!("yielding message!");
|
||||
yield Event::json(&msg)
|
||||
event = rx.recv() => match event {
|
||||
Ok(event) => {
|
||||
// tracing::info!("yielding event: {event:?}");
|
||||
yield Event::json(&event)
|
||||
},
|
||||
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||
tracing::warn!("Receiver lagging on channel {channel_id} by {n} events",);
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
use rocket::serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
use chrono::{DateTime, Utc};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// ---------- Rocket routes ----------
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
|
||||
pub struct ChatMsg {
|
||||
pub id: Uuid,
|
||||
pub display_name: Option<String>,
|
||||
pub user_id: i64,
|
||||
pub text: String,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", content = "data")]
|
||||
pub enum ChatEvent {
|
||||
SendMessage(ChatMsg),
|
||||
|
||||
/// for when a user explicitly edits a message
|
||||
EditMessage { id: Uuid, msg: ChatMsg },
|
||||
|
||||
/// used for streaming content to a message
|
||||
/// will not show up as edited
|
||||
MessageAppendContent{ id: Uuid, content: String }
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod auth;
|
||||
pub mod user;
|
||||
pub mod space;
|
||||
pub mod space;
|
||||
pub mod event;
|
||||
@@ -1,68 +1,80 @@
|
||||
use crate::api::chat::ChatMsg;
|
||||
use crate::model::event::ChatMsg;
|
||||
use crate::repo::Repo;
|
||||
use chrono::{DateTime, Utc};
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MessageRepository {
|
||||
pool: PgPool
|
||||
}
|
||||
|
||||
impl Repo for MessageRepository {
|
||||
type Target = ChatMsg;
|
||||
|
||||
fn new(pool: PgPool) -> Self {
|
||||
impl MessageRepository {
|
||||
pub(crate) fn new(pool: PgPool) -> Self {
|
||||
Self { pool }
|
||||
}
|
||||
|
||||
// TODO: caching with redis
|
||||
async fn get_by_id(&self, id: i64) -> Option<Self::Target> {
|
||||
async fn get_by_id(&self, id: Uuid) -> Option<ChatMsg> {
|
||||
sqlx::query!(
|
||||
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at
|
||||
"SELECT m.id, u.username, u.nickname, u.id as user_id, m.content, m.created_at
|
||||
FROM messages m
|
||||
JOIN users u ON m.user_id = u.id
|
||||
WHERE m.id = $1",
|
||||
id
|
||||
).fetch_optional(&self.pool).await.ok().flatten().map(|row| ChatMsg {
|
||||
display_name: Some(row.nickname.unwrap_or(row.username)),
|
||||
user_id: row.user_id,
|
||||
text: row.content,
|
||||
timestamp: row.created_at,
|
||||
})
|
||||
id: row.id,
|
||||
display_name: Some(row.nickname.unwrap_or(row.username)),
|
||||
user_id: row.user_id,
|
||||
text: row.content,
|
||||
timestamp: row.created_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl MessageRepository {
|
||||
|
||||
// TODO! caching with redis
|
||||
pub async fn create_new(
|
||||
&self, uid: i64, channel_id: i64,
|
||||
text: &str, created_at: DateTime<Utc>
|
||||
) -> Result<i64, sqlx::Error> {
|
||||
&self, msg: ChatMsg, channel_id: i64
|
||||
) -> Result<(), sqlx::Error> {
|
||||
sqlx::query!(
|
||||
"INSERT INTO messages (channel_id, user_id, content, created_at)
|
||||
VALUES ($1, $2, $3, $4) RETURNING id",
|
||||
"INSERT INTO messages (id, channel_id, user_id, content, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5)",
|
||||
msg.id,
|
||||
channel_id,
|
||||
uid,
|
||||
msg.user_id,
|
||||
msg.text,
|
||||
msg.timestamp
|
||||
).execute(&self.pool).await.map_err(|_| sqlx::Error::RowNotFound)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_text(&self, id: Uuid, text: &str) -> Result<(), sqlx::Error> {
|
||||
sqlx::query!(
|
||||
"UPDATE messages SET content = $1 WHERE id = $2",
|
||||
text,
|
||||
created_at
|
||||
).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound))
|
||||
id
|
||||
).execute(&self.pool).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// TODO: caching with redis
|
||||
pub async fn get_by_channel(&self, channel_id: i64, limit: usize)
|
||||
-> Result<Vec<ChatMsg>, sqlx::Error> {
|
||||
pub async fn get_latest_by_channel_desc(
|
||||
&self, channel_id: i64, limit: usize, page: usize
|
||||
) -> Result<Vec<ChatMsg>, sqlx::Error> {
|
||||
sqlx::query!(
|
||||
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at
|
||||
"SELECT m.id, u.username, u.nickname, u.id as user_id, m.content, m.created_at
|
||||
FROM messages m
|
||||
JOIN users u ON m.user_id = u.id
|
||||
WHERE m.channel_id = $1
|
||||
ORDER BY m.created_at DESC LIMIT $2",
|
||||
ORDER BY m.created_at DESC
|
||||
LIMIT $2 OFFSET $3",
|
||||
channel_id,
|
||||
limit as i64
|
||||
limit as i64,
|
||||
page as i64
|
||||
).fetch_all(&self.pool).await.map(|messages| {
|
||||
messages.into_iter().rev().map(|msg| {
|
||||
messages.into_iter().map(|msg| {
|
||||
ChatMsg {
|
||||
id: msg.id,
|
||||
display_name: Some(msg.nickname.unwrap_or(msg.username)),
|
||||
user_id: msg.user_id,
|
||||
text: msg.content,
|
||||
|
||||
+108
-46
@@ -1,12 +1,15 @@
|
||||
use crate::api::chat::ChatMsg;
|
||||
use crate::model::event::{ChatEvent, ChatMsg};
|
||||
use crate::error::{ApiResult, AppError};
|
||||
use crate::repo::message_repo::MessageRepository;
|
||||
use crate::repo::{ChannelRepo, Repo, SpaceRepo, UserRepo};
|
||||
use chrono::{DateTime, Utc};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::sync::Arc;
|
||||
use std::sync::mpsc::channel;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::broadcast::Sender;
|
||||
use tokio::sync::{broadcast, Mutex};
|
||||
use uuid::Uuid;
|
||||
use crate::model::space::SpaceDto;
|
||||
use crate::svc::llm_service::LlmService;
|
||||
|
||||
@@ -15,15 +18,13 @@ use crate::svc::llm_service::LlmService;
|
||||
#[derive(Clone)]
|
||||
pub struct ChatService {
|
||||
users: Arc<dyn UserRepo>,
|
||||
channels: Arc<dyn ChannelRepo>,
|
||||
channel_repo: Arc<dyn ChannelRepo>,
|
||||
spaces: Arc<dyn SpaceRepo>,
|
||||
messages: MessageRepository,
|
||||
|
||||
llm: LlmService,
|
||||
buffer_size: usize,
|
||||
senders: Arc<Mutex<HashMap<i64, Sender<ChatMsg>>>>,
|
||||
|
||||
|
||||
channels: Arc<Mutex<HashMap<i64, ChannelState>>>,
|
||||
}
|
||||
|
||||
impl ChatService {
|
||||
@@ -33,13 +34,13 @@ impl ChatService {
|
||||
channels: Arc<dyn ChannelRepo>, spaces: Arc<dyn SpaceRepo>,
|
||||
) -> Self {
|
||||
Self {
|
||||
channels,
|
||||
channel_repo: channels,
|
||||
spaces,
|
||||
llm,
|
||||
users,
|
||||
messages,
|
||||
buffer_size,
|
||||
senders: Arc::new(Mutex::new(std::collections::HashMap::new())),
|
||||
channels: Arc::new(Mutex::new(std::collections::HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +51,7 @@ impl ChatService {
|
||||
|
||||
let mut result = Vec::new();
|
||||
for space in spaces {
|
||||
let channels = self.channels.get_by_space_id(space.id).await?;
|
||||
let channels = self.channel_repo.get_by_space_id(space.id).await?;
|
||||
result.push(SpaceDto {
|
||||
channels,
|
||||
id: space.id,
|
||||
@@ -65,8 +66,10 @@ impl ChatService {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub async fn get_messages(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> {
|
||||
let messages = self.messages.get_by_channel(channel_id, limit).await?;
|
||||
pub async fn fetch_latest_messages_desc(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> {
|
||||
const PAGE: usize = 0;
|
||||
|
||||
let messages = self.messages.get_latest_by_channel_desc(channel_id, limit, PAGE).await?;
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
@@ -113,6 +116,7 @@ impl ChatService {
|
||||
.ok_or(AppError::NotFound)?;
|
||||
|
||||
let message = ChatMsg {
|
||||
id: Uuid::new_v4(),
|
||||
display_name: Some(user
|
||||
.nickname.clone()
|
||||
.unwrap_or_else(|| user.username.clone())),
|
||||
@@ -122,66 +126,124 @@ impl ChatService {
|
||||
};
|
||||
|
||||
self.publish(channel_id, message.clone()).await;
|
||||
|
||||
let _msg_id = self.messages.create_new(uid, channel_id, text, created_at).await?;
|
||||
self.messages.create_new(message.clone(), channel_id).await?;
|
||||
// TODO: caching w redis at repository layer
|
||||
|
||||
let svc_instance = self.clone();
|
||||
|
||||
let Some(text) = text.strip_prefix("/ask ") else {
|
||||
return Ok(())
|
||||
};
|
||||
|
||||
if !svc_instance.llm.enabled() {
|
||||
return Ok(())
|
||||
if !message.text.starts_with("/ask ") {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
||||
let svc_instance = self.clone();
|
||||
if !svc_instance.llm.enabled() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let context = self.get_history(channel_id, 25).await?;
|
||||
tokio::spawn(async move {
|
||||
let sender = match svc_instance.channels.lock().await.get(&channel_id) {
|
||||
Some(s) => s.get_sender(),
|
||||
None => return,
|
||||
};
|
||||
|
||||
let response = svc_instance.llm
|
||||
.query(&message)
|
||||
.query(&message, &context, sender)
|
||||
.await;
|
||||
|
||||
if let Ok(reply) = response {
|
||||
|
||||
tracing::info!("LLM reply: {}", reply.text);
|
||||
|
||||
svc_instance.publish(channel_id, reply.clone()).await;
|
||||
// TODO: cache response (or do with redis!)
|
||||
if let Err(e) = svc_instance.messages
|
||||
.create_new(reply.user_id, channel_id, &reply.text, reply.timestamp).await {
|
||||
tracing::error!("Failed to persist LLM reply: {}", e);
|
||||
}
|
||||
|
||||
tracing::info!("LLM reply persisted");
|
||||
|
||||
} else {
|
||||
let Ok(reply) = response else {
|
||||
tracing::warn!("Error contacting LLM: {:?}", response);
|
||||
return;
|
||||
};
|
||||
|
||||
// TODO: cache response (or do with redis!)
|
||||
if let Err(e) = svc_instance.messages.create_new(reply, channel_id).await {
|
||||
tracing::error!("Failed to persist LLM reply: {}", e);
|
||||
}
|
||||
|
||||
tracing::debug!("Full LLM reply persisted");
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_history(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> {
|
||||
let mut map = self.channels.lock().await;
|
||||
if let Some(channel) = map.get(&channel_id) && !channel.cache.is_empty() {
|
||||
Ok(channel.history().clone().into_iter().take(limit).collect())
|
||||
} else {
|
||||
let messages: Vec<_> = self.messages.get_latest_by_channel_desc(channel_id, limit, 0).await?.into_iter().rev().collect();
|
||||
map.insert(channel_id,
|
||||
ChannelState::new(self.buffer_size, Some(messages.clone().into()))
|
||||
);
|
||||
Ok(messages)
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to the specified channel.
|
||||
pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver<ChatMsg> {
|
||||
let mut map = self.senders.lock().await;
|
||||
let sender = map
|
||||
pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver<ChatEvent> {
|
||||
let mut map = self.channels.lock().await;
|
||||
let channel = map
|
||||
.entry(channel_id)
|
||||
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
|
||||
sender.subscribe()
|
||||
.or_insert_with(|| ChannelState::new(self.buffer_size, None));
|
||||
channel.subscribe()
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
/// Publish a message to the specified channel.
|
||||
async fn publish(&self, channel_id: i64, msg: ChatMsg) {
|
||||
let mut map = self.senders.lock().await;
|
||||
let sender = map
|
||||
let mut map = self.channels.lock().await;
|
||||
let channel = map
|
||||
.entry(channel_id)
|
||||
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
|
||||
let _ = sender.send(msg);
|
||||
.or_insert_with(|| ChannelState::new(self.buffer_size, None));
|
||||
channel.send(msg);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ChannelState {
|
||||
sender: Sender<ChatEvent>,
|
||||
cache: VecDeque<ChatMsg>,
|
||||
last_updated: Instant,
|
||||
}
|
||||
|
||||
impl ChannelState {
|
||||
|
||||
const MAX_HISTORY_SIZE: usize = 100;
|
||||
|
||||
#[must_use]
|
||||
pub fn new(buffer_size: usize, history: Option<VecDeque<ChatMsg>>) -> Self {
|
||||
Self {
|
||||
sender: broadcast::channel(buffer_size).0,
|
||||
cache: history.unwrap_or_default(),
|
||||
last_updated: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn history(&self) -> &VecDeque<ChatMsg> {
|
||||
&self.cache
|
||||
}
|
||||
|
||||
pub fn get_sender(&self) -> Sender<ChatEvent> {
|
||||
self.sender.clone()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn send(&mut self, msg: ChatMsg) {
|
||||
while self.cache.len() >= Self::MAX_HISTORY_SIZE {
|
||||
self.cache.pop_front();
|
||||
}
|
||||
|
||||
self.cache.push_back(msg.clone());
|
||||
|
||||
if self.sender.send(ChatEvent::SendMessage(msg)).is_err() {
|
||||
tracing::warn!("Sent message to channel with no subscribers");
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<ChatEvent> {
|
||||
self.sender.subscribe()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
+119
-41
@@ -1,3 +1,13 @@
|
||||
use std::env;
|
||||
use std::sync::LazyLock;
|
||||
use futures_util::StreamExt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::broadcast::Sender;
|
||||
use uuid::Uuid;
|
||||
use crate::model::event::{ChatEvent, ChatMsg};
|
||||
use crate::error::{ApiResult, AppError};
|
||||
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct LlmService;
|
||||
|
||||
@@ -15,71 +25,139 @@ impl LlmService {
|
||||
LMSTUDIO_URL.is_some()
|
||||
}
|
||||
|
||||
pub async fn query(&self, message: &ChatMsg) -> ApiResult<ChatMsg> {
|
||||
pub async fn query(&self, message: &ChatMsg, context: &[ChatMsg], sender: Sender<ChatEvent>) -> ApiResult<ChatMsg> {
|
||||
let Some(url) = LMSTUDIO_URL.clone() else {
|
||||
return Err(AppError::internal("AI not enabled!"))
|
||||
};
|
||||
|
||||
let model = LMSTUDIO_MODEL.clone().unwrap_or_else(|| "gpt-oss-20b".into());
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Build the request body
|
||||
let payload = LlmRequest {
|
||||
model, // whatever model you run locally
|
||||
messages: vec![Message {
|
||||
role: "user".into(),
|
||||
content: message.text.clone(),
|
||||
}],
|
||||
let reply_id = Uuid::new_v4();
|
||||
let timestamp = chrono::Utc::now();
|
||||
let mut reply = ChatMsg {
|
||||
id: reply_id,
|
||||
display_name: Some("llm".into()),
|
||||
user_id: 0,
|
||||
text: String::new(),
|
||||
timestamp,
|
||||
};
|
||||
let _ = sender.send(ChatEvent::SendMessage(reply.clone()));
|
||||
|
||||
// POST to lm‑studio (default 127.0.0.1:1234)
|
||||
let resp = client
|
||||
let mut messages: Vec<Message> = Vec::new();
|
||||
let system_prompt = format!(
|
||||
"You are a helpful assistant in a group chat. \
|
||||
You are talking to '{}'. \
|
||||
Keep responses concise and conversational.",
|
||||
message.display_name.as_deref().unwrap_or("unknown"),
|
||||
);
|
||||
messages.push(Message { role: "system".into(), content: system_prompt });
|
||||
|
||||
for msg in context {
|
||||
let role = if msg.user_id == 0 {
|
||||
"assistant" // your LLM user_id convention
|
||||
} else {
|
||||
"user"
|
||||
};
|
||||
messages.push(Message {
|
||||
role: role.into(),
|
||||
content: format!(
|
||||
"{}: {}",
|
||||
msg.display_name.as_deref().unwrap_or("unknown"),
|
||||
msg.text
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".into(),
|
||||
content: format!(
|
||||
"{}: {}",
|
||||
message.display_name.as_deref().unwrap_or("unknown"),
|
||||
message.text.trim_start_matches("/ask ") // strip the command prefix
|
||||
),
|
||||
});
|
||||
|
||||
let Ok(resp) = reqwest::Client::new()
|
||||
.post(url)
|
||||
.json(&payload)
|
||||
.json(&LlmRequest {
|
||||
think: false,
|
||||
model: LMSTUDIO_MODEL
|
||||
.clone()
|
||||
.unwrap_or_else(|| "gpt-oss-20b".into()),
|
||||
messages,
|
||||
stream: true,
|
||||
})
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| AppError::internal("Failed to make request to LLM server"))?;
|
||||
else {
|
||||
tracing::warn!("Failed to reach LLM");
|
||||
let _ = sender.send(ChatEvent::MessageAppendContent {
|
||||
id: reply_id,
|
||||
content: String::from("I'm not available right now. Please try again later.")
|
||||
});
|
||||
return Err(AppError::internal("Failed to reach LLM"));
|
||||
};
|
||||
|
||||
// The API returns a JSON with `choices[].message.content`
|
||||
#[derive(Deserialize)]
|
||||
struct LlmResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
#[derive(Deserialize)]
|
||||
struct Choice {
|
||||
message: Message,
|
||||
|
||||
let mut full_text = String::new();
|
||||
let mut buffer = String::new();
|
||||
let mut stream = resp.bytes_stream();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.map_err(|_| AppError::internal("Stream error"))?;
|
||||
buffer.push_str(&String::from_utf8_lossy(&chunk));
|
||||
|
||||
while let Some(pos) = buffer.find('\n') {
|
||||
let line = buffer[..pos].trim().to_string();
|
||||
buffer = buffer[pos + 1..].to_string();
|
||||
|
||||
if line == "data: [DONE]" {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(json) = line.strip_prefix("data: ")
|
||||
&& let Ok(parsed) = serde_json::from_str::<StreamingResponse>(json)
|
||||
&& let Some(content) = parsed.choices[0].delta.content.as_ref() {
|
||||
full_text.push_str(content);
|
||||
let _ = sender.send(ChatEvent::MessageAppendContent {
|
||||
id: reply_id,
|
||||
content: content.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let llm_resp: LlmResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|_| AppError::internal("Failed to parse LLM response"))?;
|
||||
reply.text = full_text;
|
||||
|
||||
Ok(ChatMsg {
|
||||
display_name: Some(String::from("llm")),
|
||||
user_id: 0,
|
||||
text: llm_resp.choices[0].message.content.clone(),
|
||||
timestamp: chrono::Utc::now(),
|
||||
})
|
||||
Ok(reply)
|
||||
}
|
||||
}
|
||||
|
||||
use std::env;
|
||||
use std::sync::LazyLock;
|
||||
// src/llm.rs
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::chat::ChatMsg;
|
||||
use crate::error::{ApiResult, AppError};
|
||||
use crate::svc::chat_svc::ChatService;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct LlmRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
stream: bool,
|
||||
think: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct StreamingResponse {
|
||||
choices: Vec<StreamingChoice>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct StreamingChoice {
|
||||
delta: Delta,
|
||||
#[serde(default)]
|
||||
finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Delta {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
}
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct Message {
|
||||
role: String, // "user" or "assistant"
|
||||
|
||||
@@ -3,4 +3,5 @@ pub mod chat_svc;
|
||||
pub mod settings_svc;
|
||||
pub mod user_svc;
|
||||
pub mod access_token_svc;
|
||||
pub mod llm_service;
|
||||
pub mod llm_service;
|
||||
pub mod relationship_svc;
|
||||
@@ -0,0 +1,6 @@
|
||||
pub struct RelationshipService {}
|
||||
|
||||
impl RelationshipService {
|
||||
pub fn new() -> Self { Self {} }
|
||||
}
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
import argparse
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from getpass import getpass
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
BASE_URL = "http://localhost:8000"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthResult:
|
||||
token: str
|
||||
|
||||
|
||||
def signup(session: requests.Session, email: str, username: str, password: str, access_token: str) -> None:
|
||||
url = f"{BASE_URL}/api/signup"
|
||||
payload = {
|
||||
"email": email,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"access_token": access_token,
|
||||
}
|
||||
resp = session.post(url, json=payload, timeout=10)
|
||||
resp.raise_for_status()
|
||||
|
||||
|
||||
def login(session: requests.Session, username: str, password: str) -> AuthResult:
|
||||
url = f"{BASE_URL}/api/login"
|
||||
payload = {
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
resp = session.post(url, json=payload, timeout=10)
|
||||
resp.raise_for_status()
|
||||
|
||||
data = resp.json()
|
||||
token = data.get("token")
|
||||
if not token:
|
||||
raise RuntimeError(f"Login response did not contain token: {data}")
|
||||
return AuthResult(token=token)
|
||||
|
||||
|
||||
def post_message(session: requests.Session, channel_id: int, token: str, text: str, display_name: str, user_id: int) -> None:
|
||||
url = f"{BASE_URL}/api/chat/{channel_id}"
|
||||
payload = {
|
||||
"display_name": display_name,
|
||||
"user_id": user_id,
|
||||
"text": text,
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
resp = session.post(url, json=payload, headers=headers, timeout=10)
|
||||
resp.raise_for_status()
|
||||
|
||||
|
||||
def read_sse_messages(
|
||||
session: requests.Session,
|
||||
channel_id: int,
|
||||
token: str,
|
||||
expected_count: int,
|
||||
timeout_s: int,
|
||||
capture_live_messages: threading.Event,
|
||||
) -> List[dict]:
|
||||
url = f"{BASE_URL}/api/events/{channel_id}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Accept": "text/event-stream",
|
||||
}
|
||||
|
||||
received: List[dict] = []
|
||||
deadline = time.monotonic() + timeout_s
|
||||
|
||||
try:
|
||||
with session.get(url, headers=headers, stream=True, timeout=(5, timeout_s)) as resp:
|
||||
resp.raise_for_status()
|
||||
|
||||
event_data_lines: List[str] = []
|
||||
|
||||
for raw_line in resp.iter_lines(decode_unicode=True):
|
||||
if time.monotonic() > deadline:
|
||||
break
|
||||
|
||||
if raw_line is None:
|
||||
continue
|
||||
|
||||
line = raw_line.strip()
|
||||
|
||||
if not line:
|
||||
if event_data_lines:
|
||||
joined = "\n".join(event_data_lines)
|
||||
event_data_lines.clear()
|
||||
try:
|
||||
obj = json.loads(joined)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if capture_live_messages.is_set():
|
||||
received.append(obj)
|
||||
if expected_count > 0 and len(received) >= expected_count:
|
||||
break
|
||||
else:
|
||||
print(f"Discarding message: {obj}")
|
||||
continue
|
||||
|
||||
if line.startswith("data:"):
|
||||
event_data_lines.append(line[len("data:"):].strip())
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
print("Timeout while reading SSE.")
|
||||
except requests.exceptions.RequestException as exc:
|
||||
print(f"Error reading SSE: {exc}")
|
||||
|
||||
return received
|
||||
|
||||
|
||||
def prompt_nonempty(label: str, secret: bool = False) -> str:
|
||||
while True:
|
||||
value = getpass(label) if secret else input(label)
|
||||
value = value.strip()
|
||||
if value:
|
||||
return value
|
||||
print("Please enter a value.")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Chat integration test against localhost:8000")
|
||||
parser.add_argument("--existing-account", action="store_true",
|
||||
help="Skip signup and only log in with an existing account")
|
||||
parser.add_argument("--email", default=None)
|
||||
parser.add_argument("--username", default=None)
|
||||
parser.add_argument("--password", default=None)
|
||||
parser.add_argument("--access-token", default=None,
|
||||
help="Required only for signup mode")
|
||||
parser.add_argument("--channel-id", type=int, default=1)
|
||||
parser.add_argument("--message-count", type=int, default=5)
|
||||
parser.add_argument("--timeout", type=int, default=15)
|
||||
args = parser.parse_args()
|
||||
|
||||
session = requests.Session()
|
||||
|
||||
if args.existing_account:
|
||||
username = args.username or prompt_nonempty("Username: ")
|
||||
password = args.password or prompt_nonempty("Password: ", secret=True)
|
||||
else:
|
||||
email = args.email or prompt_nonempty("Email: ")
|
||||
username = args.username or prompt_nonempty("Username: ")
|
||||
password = args.password or prompt_nonempty("Password: ", secret=True)
|
||||
access_token = args.access_token or prompt_nonempty("Access token: ")
|
||||
|
||||
print("[1/5] Signing up...")
|
||||
try:
|
||||
signup(session, email, username, password, access_token)
|
||||
print(" signup ok")
|
||||
except requests.HTTPError as e:
|
||||
print(f" signup returned HTTP error: {e}")
|
||||
print(" continuing to login...")
|
||||
|
||||
print("[2/5] Logging in...")
|
||||
auth = login(session, username, password)
|
||||
print(" login ok")
|
||||
print(f" token: {auth.token[:12]}...")
|
||||
|
||||
print("[3/5] Opening event stream...")
|
||||
received_messages: List[dict] = []
|
||||
capture_live_messages = threading.Event()
|
||||
stream_done = threading.Event()
|
||||
|
||||
def stream_reader() -> None:
|
||||
nonlocal received_messages
|
||||
try:
|
||||
received_messages = read_sse_messages(
|
||||
session=session,
|
||||
channel_id=args.channel_id,
|
||||
token=auth.token,
|
||||
expected_count=args.message_count,
|
||||
timeout_s=args.timeout,
|
||||
capture_live_messages=capture_live_messages,
|
||||
)
|
||||
finally:
|
||||
stream_done.set()
|
||||
|
||||
t = threading.Thread(target=stream_reader, daemon=True)
|
||||
t.start()
|
||||
|
||||
# Give the server time to flush backlog on this same stream connection.
|
||||
time.sleep(1.0)
|
||||
|
||||
print("[4/5] Starting to capture live messages and sending messages...")
|
||||
capture_live_messages.set()
|
||||
|
||||
sent_texts = [f"Message {i}" for i in range(args.message_count)]
|
||||
for i, text in enumerate(sent_texts):
|
||||
post_message(
|
||||
session=session,
|
||||
channel_id=args.channel_id,
|
||||
token=auth.token,
|
||||
text=text,
|
||||
display_name=username,
|
||||
user_id=1,
|
||||
)
|
||||
print(f" sent {i + 1}/{args.message_count}: {text}")
|
||||
time.sleep(0.1)
|
||||
|
||||
stream_done.wait(timeout=args.timeout)
|
||||
t.join(timeout=1)
|
||||
|
||||
print("\nReceived messages:")
|
||||
for i, msg in enumerate(received_messages, start=1):
|
||||
print(f" {i}. {msg}")
|
||||
|
||||
received_texts = [m.get("text") for m in received_messages if isinstance(m, dict)]
|
||||
|
||||
for text in sent_texts:
|
||||
if text not in received_texts:
|
||||
print(f"\nFAIL: missing message: {text}")
|
||||
return 1
|
||||
|
||||
print("\nPASS: login and message delivery test succeeded.")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user